Skip to main content

Grove/WASM/
ModuleLoader.rs

1//! WASM Module Loader
2//!
3//! Handles loading, compiling, and instantiating WebAssembly modules.
4//! Provides utilities for working with WASM modules from various sources.
5
6use std::{
7	fs,
8	path::{Path, PathBuf},
9	sync::Arc,
10};
11
12use anyhow::{Context, Result};
13use serde::{Deserialize, Serialize};
14use tokio::sync::RwLock;
15use tracing::{debug, info, instrument};
16use wasmtime::{Instance, Linker, Module, Store, StoreLimits};
17
18use crate::WASM::Runtime::{WASMConfig, WASMRuntime};
19
/// WASM module wrapper with metadata.
///
/// Produced by the `ModuleLoaderImpl::load_*` methods; a clone of each loaded
/// module is also kept in the loader's in-memory registry.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WASMModule {
	/// Unique module identifier (built by `generate_module_id`).
	pub id:String,
	/// Module name (if available from the name section).
	///
	/// NOTE(review): the loader does not parse the name section yet, so this is
	/// currently always `None`.
	pub name:Option<String>,
	/// Path to the module file (if loaded from disk); set on modules returned
	/// by `load_from_file`, `None` for other sources.
	pub path:Option<PathBuf>,
	/// Where the module bytes came from.
	pub source_type:ModuleSourceType,
	/// Length of the raw module bytes.
	pub size:usize,
	/// Names of exported functions.
	pub exported_functions:Vec<String>,
	/// Names of exported memories.
	pub exported_memories:Vec<String>,
	/// Names of exported tables.
	pub exported_tables:Vec<String>,
	/// Import declarations.
	pub imports:Vec<ImportDeclaration>,
	/// Compilation timestamp, in whole seconds since the UNIX epoch.
	pub compiled_at:u64,
	/// Module hash (for caching): lowercase hex of a 64-bit `DefaultHasher`
	/// digest of the raw bytes. Per the std docs, that algorithm may change
	/// between Rust releases, so do not persist this value across toolchains.
	pub hash:Option<String>,
}
46
/// Source type of a WASM module, recorded in [`WASMModule::source_type`].
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
pub enum ModuleSourceType {
	/// Module loaded from a file (set by `load_from_file`).
	File,
	/// Module loaded from in-memory bytes.
	Memory,
	/// Module loaded from a network URL (set by `load_from_url`).
	Url,
	/// Module generated dynamically.
	///
	/// NOTE(review): no loader method in this file produces this variant; it is
	/// only reachable if a caller passes it to `load_from_memory` explicitly.
	Generated,
}
59
/// Import declaration for a WASM module: one `(module, name, kind)` entry
/// from the module's import section.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ImportDeclaration {
	/// Module name being imported from (the first half of the import pair).
	pub module:String,
	/// Name of the imported item within that module.
	pub name:String,
	/// Kind of import.
	pub kind:ImportKind,
}
70
71/// Kind of import
72#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
73pub enum ImportKind {
74	/// Function import
75	Function,
76	/// Table import
77	Table,
78	/// Memory import
79	Memory,
80	/// Global import
81	Global,
82	/// Tag import
83	Tag,
84}
85
/// Module loading options.
///
/// NOTE(review): within this file only `validate` is read by the loader; the
/// other fields are carried for configuration but not consulted here yet.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModuleLoadOptions {
	/// Enable lazy compilation
	pub lazy_compilation:bool,
	/// Enable module caching
	pub enable_cache:bool,
	/// Cache directory path (`None` = no explicit directory)
	pub cache_dir:Option<PathBuf>,
	/// Custom linker configuration
	pub custom_linker:bool,
	/// Validate module bytes before compiling them
	pub validate:bool,
	/// Optimized compilation
	pub optimized:bool,
}
102
103impl Default for ModuleLoadOptions {
104	fn default() -> Self {
105		Self {
106			lazy_compilation:false,
107			enable_cache:true,
108			cache_dir:None,
109			custom_linker:false,
110			validate:true,
111			optimized:true,
112		}
113	}
114}
115
/// An instantiated WASM module together with the store that owns its state.
///
/// Returned by `ModuleLoaderImpl::instantiate`; the store must stay with the
/// instance because wasmtime instances are only usable with their own store.
pub struct WASMInstance {
	/// The WASM instance
	pub instance:Instance,
	/// The associated store; its data is a `StoreLimits`.
	///
	/// NOTE(review): whether that limiter is actually installed via
	/// `Store::limiter` is decided by whoever built the store — not visible here.
	pub store:Store<StoreLimits>,
	/// Instance ID (`instance-<uuid>`, from `generate_instance_id`)
	pub id:String,
	/// Module the instance was created from
	pub module:Arc<Module>,
}
127
/// WASM Module Loader.
///
/// Compiles modules through the shared [`WASMRuntime`] and keeps metadata for
/// every loaded module in an async-lock-protected registry.
pub struct ModuleLoaderImpl {
	/// Runtime used for validation, compilation and linker creation.
	runtime:Arc<WASMRuntime>,
	// Stored for future use; no method in this file reads it yet.
	#[allow(dead_code)]
	config:WASMConfig,
	// Placeholder for reusable linkers; never read or written after `new`.
	#[allow(dead_code)]
	linkers:Arc<RwLock<Vec<Linker<()>>>>,
	/// Metadata of every module loaded through this loader.
	loaded_modules:Arc<RwLock<Vec<WASMModule>>>,
}
137
138impl ModuleLoaderImpl {
139	/// Create a new module loader
140	pub fn new(runtime:Arc<WASMRuntime>, config:WASMConfig) -> Self {
141		Self {
142			runtime,
143			config,
144			linkers:Arc::new(RwLock::new(Vec::new())),
145			loaded_modules:Arc::new(RwLock::new(Vec::new())),
146		}
147	}
148
149	/// Load a WASM module from a file
150	#[instrument(skip(self, path))]
151	pub async fn load_from_file(&self, path:&Path) -> Result<WASMModule> {
152		info!("Loading WASM module from file: {:?}", path);
153
154		let wasm_bytes = fs::read(path).context(format!("Failed to read WASM file: {:?}", path))?;
155
156		self.load_from_memory(&wasm_bytes, ModuleSourceType::File)
157			.await
158			.map(|mut module| {
159				module.path = Some(path.to_path_buf());
160				module
161			})
162	}
163
164	/// Load a WASM module from memory
165	#[instrument(skip(self, wasm_bytes))]
166	pub async fn load_from_memory(&self, wasm_bytes:&[u8], source_type:ModuleSourceType) -> Result<WASMModule> {
167		info!("Loading WASM module from memory ({} bytes)", wasm_bytes.len());
168
169		// Validate if option is set
170		if ModuleLoadOptions::default().validate {
171			if !self.runtime.validate_module(wasm_bytes)? {
172				return Err(anyhow::anyhow!("WASM module validation failed"));
173			}
174		}
175
176		// Compile the module
177		let module = self.runtime.compile_module(wasm_bytes)?;
178
179		// Extract module information
180		let module_info = self.extract_module_info(&module);
181
182		// Create module wrapper
183		let wasm_module = WASMModule {
184			id:generate_module_id(&module_info.name),
185			name:module_info.name,
186			path:None,
187			source_type,
188			size:wasm_bytes.len(),
189			exported_functions:module_info.exports.functions,
190			exported_memories:module_info.exports.memories,
191			exported_tables:module_info.exports.tables,
192			imports:module_info.imports,
193			compiled_at:std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH)?.as_secs(),
194			hash:self.compute_hash(wasm_bytes),
195		};
196
197		// Store the module
198		let mut loaded = self.loaded_modules.write().await;
199		loaded.push(wasm_module.clone());
200
201		debug!("WASM module loaded successfully: {}", wasm_module.id);
202
203		Ok(wasm_module)
204	}
205
206	/// Load a WASM module from a URL
207	#[instrument(skip(self, url))]
208	pub async fn load_from_url(&self, url:&str) -> Result<WASMModule> {
209		info!("Loading WASM module from URL: {}", url);
210
211		// Fetch the module
212		let response = reqwest::get(url)
213			.await
214			.context(format!("Failed to fetch WASM module from: {}", url))?;
215
216		if !response.status().is_success() {
217			return Err(anyhow::anyhow!("Failed to fetch WASM module: HTTP {}", response.status()));
218		}
219
220		let wasm_bytes = response.bytes().await?;
221
222		self.load_from_memory(&wasm_bytes, ModuleSourceType::Url).await
223	}
224
225	/// Instantiate a loaded module
226	#[instrument(skip(self, module))]
227	pub async fn instantiate(&self, module:&Module, mut store:Store<StoreLimits>) -> Result<WASMInstance> {
228		debug!("Instantiating WASM module");
229
230		// Create linker with StoreLimits type
231		let linker = self.runtime.create_linker::<StoreLimits>(true)?;
232
233		// Instantiate
234		let instance = linker
235			.instantiate(&mut store, module)
236			.map_err(|e| anyhow::anyhow!("Failed to instantiate WASM module: {}", e))?;
237
238		let instance_id = generate_instance_id();
239
240		debug!("WASM module instantiated: {}", instance_id);
241
242		Ok(WASMInstance { instance, store, id:instance_id, module:Arc::new(module.clone()) })
243	}
244
245	/// Get all loaded modules
246	pub async fn get_loaded_modules(&self) -> Vec<WASMModule> { self.loaded_modules.read().await.clone() }
247
248	/// Get a loaded module by ID
249	pub async fn get_module_by_id(&self, id:&str) -> Option<WASMModule> {
250		let loaded = self.loaded_modules.read().await;
251		loaded.iter().find(|m| m.id == id).cloned()
252	}
253
254	/// Unload a module
255	pub async fn unload_module(&self, id:&str) -> Result<bool> {
256		let mut loaded = self.loaded_modules.write().await;
257		let pos = loaded.iter().position(|m| m.id == id);
258
259		if let Some(pos) = pos {
260			loaded.remove(pos);
261			info!("WASM module unloaded: {}", id);
262			Ok(true)
263		} else {
264			Ok(false)
265		}
266	}
267
268	/// Extract module information from a compiled module
269	fn extract_module_info(&self, module:&Module) -> ModuleInfo {
270		let mut exports = Exports { functions:Vec::new(), memories:Vec::new(), tables:Vec::new(), globals:Vec::new() };
271
272		let mut imports = Vec::new();
273
274		for export in module.exports() {
275			match export.ty() {
276				wasmtime::ExternType::Func(_) => exports.functions.push(export.name().to_string()),
277				wasmtime::ExternType::Memory(_) => exports.memories.push(export.name().to_string()),
278				wasmtime::ExternType::Table(_) => exports.tables.push(export.name().to_string()),
279				wasmtime::ExternType::Global(_) => exports.globals.push(export.name().to_string()),
280				_ => {},
281			}
282		}
283
284		for import in module.imports() {
285			let kind = match import.ty() {
286				wasmtime::ExternType::Func(_) => ImportKind::Function,
287				wasmtime::ExternType::Memory(_) => ImportKind::Memory,
288				wasmtime::ExternType::Table(_) => ImportKind::Table,
289				wasmtime::ExternType::Global(_) => ImportKind::Global,
290				_ => ImportKind::Tag,
291			};
292			imports.push(ImportDeclaration {
293				module:import.module().to_string(),
294				name:import.name().to_string(),
295				kind,
296			});
297		}
298
299		ModuleInfo {
300			name:None, // Would need to parse name section
301			exports,
302			imports,
303		}
304	}
305
306	/// Compute a hash of the WASM bytes for caching
307	fn compute_hash(&self, wasm_bytes:&[u8]) -> Option<String> {
308		use std::{
309			collections::hash_map::DefaultHasher,
310			hash::{Hash, Hasher},
311		};
312
313		let mut hasher = DefaultHasher::new();
314		wasm_bytes.hash(&mut hasher);
315		Some(format!("{:x}", hasher.finish()))
316	}
317}
318
319// Helper structures and functions
320
/// Intermediate result of `extract_module_info`, consumed when building a
/// [`WASMModule`].
struct ModuleInfo {
	/// Module name; currently always `None` (the name section is not parsed).
	name:Option<String>,
	/// Export names grouped by kind.
	exports:Exports,
	/// Import declarations.
	imports:Vec<ImportDeclaration>,
}
326
/// Export names of a compiled module, grouped by extern kind.
struct Exports {
	/// Exported function names.
	functions:Vec<String>,
	/// Exported memory names.
	memories:Vec<String>,
	/// Exported table names.
	tables:Vec<String>,
	/// Exported global names.
	// NOTE(review): collected but never copied into `WASMModule`, which has no
	// globals field; rustc will flag this field as never read.
	globals:Vec<String>,
}
333
334fn generate_module_id(name:&Option<String>) -> String {
335	match name {
336		Some(n) => format!("module-{}", n.to_lowercase().replace(' ', "-")),
337		None => format!("module-{}", uuid::Uuid::new_v4()),
338	}
339}
340
341fn generate_instance_id() -> String { format!("instance-{}", uuid::Uuid::new_v4()) }
342
#[cfg(test)]
mod tests {
	use super::*;

	#[tokio::test]
	async fn test_module_loader_creation() {
		let runtime = Arc::new(WASMRuntime::new(WASMConfig::default()).await.unwrap());
		let config = WASMConfig::default();
		let loader = ModuleLoaderImpl::new(runtime, config);

		// A freshly created loader has an empty registry.
		assert!(loader.get_loaded_modules().await.is_empty());
		// Lookups and unloads of unknown ids are graceful no-ops.
		assert!(loader.get_module_by_id("missing").await.is_none());
		assert!(!loader.unload_module("missing").await.unwrap());
	}

	#[test]
	fn test_module_load_options_default() {
		let options = ModuleLoadOptions::default();
		// `assert!` instead of `assert_eq!(.., true)` (clippy::bool_assert_comparison).
		assert!(options.validate);
		assert!(options.enable_cache);
		assert!(options.optimized);
		assert!(!options.lazy_compilation);
		assert!(!options.custom_linker);
		assert!(options.cache_dir.is_none());
	}

	#[test]
	fn test_generate_module_id() {
		let id1 = generate_module_id(&Some("Test Module".to_string()));
		let id2 = generate_module_id(&None);

		assert_eq!(id1, "module-test-module");
		assert!(id2.starts_with("module-"));
		assert_ne!(id1, id2);
	}

	#[test]
	fn test_generate_instance_id() {
		let a = generate_instance_id();
		let b = generate_instance_id();

		assert!(a.starts_with("instance-"));
		assert_ne!(a, b);
	}
}
374
375// Add uuid dependency to Cargo.toml if needed
376// uuid = { version = "1.6", features = ["v4"] }