Skip to main content

Grove/WASM/
FunctionExport.rs

1//! Function Export Module
2//!
3//! Handles exporting host functions to WASM modules.
4//! Provides registration and management of functions that WASM can call.
5
6use std::{collections::HashMap, sync::Arc};
7
8use anyhow::Result;
9use serde::{Deserialize, Serialize};
10use tokio::sync::RwLock;
11use tracing::{debug, info, instrument, warn};
12use wasmtime::{Caller, Linker};
13
14use crate::WASM::HostBridge::{FunctionSignature, HostBridgeImpl as HostBridge, HostFunctionCallback};
15
/// Host function registry for WASM exports.
///
/// Holds every function registered through `FunctionExportImpl`, keyed by the
/// name passed to `register_function` (the configured export prefix is *not*
/// part of the key; it is only applied when exporting to a linker).
pub struct HostFunctionRegistry {
	/// Registered host functions, keyed by unprefixed function name.
	///
	/// Guarded by a `tokio` `RwLock` because it is read and written from the
	/// async methods of `FunctionExportImpl`.
	functions:Arc<RwLock<HashMap<String, RegisteredHostFunction>>>,
	/// Associated host bridge.
	// Holding the `Arc` keeps the bridge alive as long as the registry; nothing in
	// this module reads it, hence `dead_code` is allowed.
	#[allow(dead_code)]
	bridge:Arc<HostBridge>,
}
24
/// A registered host function together with its metadata.
#[derive(Debug, Clone)]
struct RegisteredHostFunction {
	/// Function name (the same string used as the registry key).
	// Not read anywhere in this module, hence `dead_code` is allowed.
	#[allow(dead_code)]
	name:String,
	/// Declared function signature.
	// NOTE(review): not read in this module — the exporter currently wires every
	// function into the linker as `(i32) -> i32` regardless of this value.
	#[allow(dead_code)]
	signature:FunctionSignature,
	/// Synchronous callback invoked when WASM calls the function.
	///
	/// `register_function` always stores `Some`; exporting an entry whose
	/// callback is `None` fails with an error.
	callback:Option<HostFunctionCallback>,
	/// Registration time in whole seconds since the Unix epoch.
	// Not read anywhere in this module, hence `dead_code` is allowed.
	#[allow(dead_code)]
	registered_at:u64,
	/// Call statistics, returned (cloned) by `get_function_stats`.
	stats:FunctionStats,
}
42
/// Per-function call statistics.
///
/// A plain snapshot value: cloning it (e.g. via `get_function_stats`) yields an
/// independent copy that does not change as further calls are made.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct FunctionStats {
	/// Number of times called
	pub call_count:u64,
	/// Total execution time in nanoseconds
	pub total_execution_ns:u64,
	/// Timestamp of the most recent call, or `None` if never called.
	// NOTE(review): presumably whole seconds since the Unix epoch, matching the
	// registration timestamp — confirm against whatever writes this field.
	pub last_call_at:Option<u64>,
	/// Number of errors
	pub error_count:u64,
}

impl FunctionStats {
	/// Mean execution time per call in nanoseconds (integer division).
	///
	/// Returns `None` when the function has never been called, instead of
	/// dividing by zero.
	#[must_use]
	pub fn average_execution_ns(&self) -> Option<u64> { self.total_execution_ns.checked_div(self.call_count) }
}
55
/// Export configuration for WASM functions.
///
/// See the `Default` impl for the default values.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExportConfig {
	/// Enable function export by default.
	// NOTE(review): not read anywhere in this module — confirm whether other code
	// honours it.
	pub auto_export:bool,
	/// Enable timing statistics.
	pub enable_stats:bool,
	/// Maximum number of functions that can be registered; `register_function`
	/// refuses to add new functions once this many are registered.
	pub max_functions:usize,
	/// Prefix prepended to each function's name when it is exported to a
	/// linker. The registry itself stores the unprefixed name.
	pub name_prefix:Option<String>,
}
68
69impl Default for ExportConfig {
70	fn default() -> Self {
71		Self {
72			auto_export:true,
73			enable_stats:true,
74			max_functions:1000,
75			name_prefix:Some("host_".to_string()),
76		}
77	}
78}
79
/// Function export manager for WASM.
///
/// Collects host functions in a shared registry and installs them into a
/// `wasmtime::Linker` via `export_to_linker`.
pub struct FunctionExportImpl {
	/// Shared registry of host functions.
	registry:Arc<HostFunctionRegistry>,
	/// Export behaviour: function limit, name prefix, statistics flag.
	config:ExportConfig,
}
85
86impl FunctionExportImpl {
87	/// Create a new function export manager
88	pub fn new(bridge:Arc<HostBridge>) -> Self {
89		Self {
90			registry:Arc::new(HostFunctionRegistry { functions:Arc::new(RwLock::new(HashMap::new())), bridge }),
91			config:ExportConfig::default(),
92		}
93	}
94
95	/// Create with custom configuration
96	pub fn with_config(bridge:Arc<HostBridge>, config:ExportConfig) -> Self {
97		Self {
98			registry:Arc::new(HostFunctionRegistry { functions:Arc::new(RwLock::new(HashMap::new())), bridge }),
99			config,
100		}
101	}
102
103	/// Register a host function for export to WASM
104	#[instrument(skip(self, callback))]
105	pub async fn register_function(
106		&self,
107		name:&str,
108		signature:FunctionSignature,
109		callback:HostFunctionCallback,
110	) -> Result<()> {
111		info!("Registering host function for export: {}", name);
112
113		let functions = self.registry.functions.read().await;
114
115		// Check max function limit
116		if functions.len() >= self.config.max_functions {
117			return Err(anyhow::anyhow!(
118				"Maximum number of exported functions reached: {}",
119				self.config.max_functions
120			));
121		}
122
123		drop(functions);
124
125		let mut functions = self.registry.functions.write().await;
126
127		let registered_at = std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH)?.as_secs();
128
129		functions.insert(
130			name.to_string(),
131			RegisteredHostFunction {
132				name:name.to_string(),
133				signature,
134				callback:Some(callback),
135				registered_at,
136				stats:FunctionStats::default(),
137			},
138		);
139
140		debug!("Host function registered for WASM export: {}", name);
141		Ok(())
142	}
143
144	/// Register multiple host functions
145	#[instrument(skip(self, callbacks))]
146	pub async fn register_functions(
147		&self,
148		signatures:Vec<FunctionSignature>,
149		callbacks:Vec<HostFunctionCallback>,
150	) -> Result<()> {
151		if signatures.len() != callbacks.len() {
152			return Err(anyhow::anyhow!("Number of signatures must match number of callbacks"));
153		}
154
155		for (sig, callback) in signatures.into_iter().zip(callbacks) {
156			let name = sig.name.clone();
157			self.register_function(&name, sig, callback).await?;
158		}
159
160		Ok(())
161	}
162
163	/// Export all registered functions to a WASMtime linker
164	#[instrument(skip(self, linker))]
165	pub async fn export_to_linker<T>(&self, linker:&mut Linker<T>) -> Result<()>
166	where
167		T: Send + 'static, {
168		info!(
169			"Exporting {} host functions to linker",
170			self.registry.functions.read().await.len()
171		);
172
173		let functions = self.registry.functions.read().await;
174
175		for (name, func) in functions.iter() {
176			self.export_single_function(linker, name, func)?;
177		}
178
179		info!("All host functions exported to linker");
180		Ok(())
181	}
182
183	/// Export a single function to the linker
184	fn export_single_function<T>(&self, linker:&mut Linker<T>, name:&str, func:&RegisteredHostFunction) -> Result<()>
185	where
186		T: Send + 'static, {
187		debug!("Exporting function: {}", name);
188
189		let callback = func
190			.callback
191			.ok_or_else(|| anyhow::anyhow!("No callback available for function: {}", name))?;
192
193		let func_name = if let Some(prefix) = &self.config.name_prefix {
194			format!("{}{}", prefix, name)
195		} else {
196			name.to_string()
197		};
198
199		let func_name_for_debug = func_name.clone();
200		let func_name_inner = func_name.clone();
201
202		// Create a wrapper function that handles stats and error handling
203		let _wrapped_callback =
204			move |_caller:Caller<'_, T>, args:&[wasmtime::Val]| -> Result<Vec<wasmtime::Val>, wasmtime::Trap> {
205				let _start = std::time::Instant::now();
206
207				// Convert args to bytes
208				let args_bytes:Result<Vec<bytes::Bytes>, _> = args
209					.iter()
210					.map(|arg| {
211						match arg {
212							wasmtime::Val::I32(i) => {
213								serde_json::to_vec(i)
214									.map(bytes::Bytes::from)
215									.map_err(|e| anyhow::anyhow!("Serialization error: {}", e))
216							},
217							wasmtime::Val::I64(i) => {
218								serde_json::to_vec(i)
219									.map(bytes::Bytes::from)
220									.map_err(|e| anyhow::anyhow!("Serialization error: {}", e))
221							},
222							wasmtime::Val::F32(f) => {
223								serde_json::to_vec(f)
224									.map(bytes::Bytes::from)
225									.map_err(|e| anyhow::anyhow!("Serialization error: {}", e))
226							},
227							wasmtime::Val::F64(f) => {
228								serde_json::to_vec(f)
229									.map(bytes::Bytes::from)
230									.map_err(|e| anyhow::anyhow!("Serialization error: {}", e))
231							},
232							_ => Err(anyhow::anyhow!("Unsupported argument type")),
233						}
234					})
235					.collect();
236
237				let args_bytes = args_bytes.map_err(|_| {
238					warn!("Error converting arguments for function '{}'", func_name_inner);
239					wasmtime::Trap::StackOverflow
240				})?;
241
242				// Call the callback
243				let result = callback(args_bytes);
244
245				match result {
246					Ok(response_bytes) => {
247						// Deserialize response
248						let result_val:serde_json::Value = serde_json::from_slice(&response_bytes).map_err(|_| {
249							warn!("Error deserializing response for function '{}'", func_name_inner);
250							wasmtime::Trap::StackOverflow
251						})?;
252
253						let ret_val = match result_val {
254							serde_json::Value::Number(n) => {
255								if let Some(i) = n.as_i64() {
256									wasmtime::Val::I32(i as i32)
257								} else if let Some(f) = n.as_f64() {
258									wasmtime::Val::I64(f as i64)
259								} else {
260									warn!("Invalid number format for function '{}'", func_name_inner);
261									return Err(wasmtime::Trap::StackOverflow);
262								}
263							},
264							_ => {
265								warn!("Unsupported response type for function '{}'", func_name_inner);
266								return Err(wasmtime::Trap::StackOverflow);
267							},
268						};
269
270						Ok(vec![ret_val])
271					},
272					Err(e) => {
273						// Error handling
274						debug!("Host function '{}' returned error: {}", func_name_inner, e);
275						Err(wasmtime::Trap::StackOverflow)
276					},
277				}
278			};
279
280		// Define the function signature for WASMtime
281		let _wasmparser_signature = wasmparser::FuncType::new([wasmparser::ValType::I32], [wasmparser::ValType::I32]);
282
283		// Register host function with the linker using simple i32->i32 signature
284		// In Wasmtime 20, func_wrap expects parameters to be inferred from the closure
285		// signature
286		let func_name_for_logging = func_name.clone();
287		linker
288			.func_wrap(
289				"_host", // Module name for host functions
290				&func_name,
291				move |_caller:wasmtime::Caller<'_, T>, input_param:i32| -> i32 {
292					// Track function call for metrics
293					let start = std::time::Instant::now();
294
295					// Convert input parameter to bytes for callback
296					let args_bytes = match serde_json::to_vec(&input_param).map(bytes::Bytes::from) {
297						Ok(b) => b,
298						Err(e) => {
299							warn!("Serialization error for function '{}': {}", func_name_for_logging, e);
300							return -1i32;
301						},
302					};
303
304					// Call the registered callback
305					let result = callback(vec![args_bytes]);
306
307					match result {
308						Ok(response_bytes) => {
309							// Deserialize response
310							let result_val:serde_json::Value = match serde_json::from_slice(&response_bytes) {
311								Ok(v) => v,
312								Err(_) => {
313									warn!("Error deserializing response for function '{}'", func_name_for_logging);
314									return -1i32;
315								},
316							};
317
318							// Extract result value
319							let ret_val = match result_val {
320								serde_json::Value::Number(n) => {
321									if let Some(i) = n.as_i64() {
322										i as i32
323									} else if let Some(f) = n.as_f64() {
324										f as i32
325									} else {
326										warn!("Invalid number format for function '{}'", func_name_for_logging);
327										-1i32
328									}
329								},
330								serde_json::Value::Bool(b) => {
331									if b {
332										1i32
333									} else {
334										0i32
335									}
336								},
337								_ => {
338									warn!(
339										"Unsupported response type for function '{}', expected number or bool",
340										func_name_for_logging
341									);
342									-1i32
343								},
344							};
345
346							// Log successful call
347							let duration = start.elapsed();
348							debug!(
349								"[FunctionExport] Host function '{}' executed successfully in {}µs",
350								func_name_for_logging,
351								duration.as_micros()
352							);
353
354							ret_val
355						},
356						Err(e) => {
357							// Error handling - return error code to WASM caller
358							debug!(
359								"[FunctionExport] Host function '{}' returned error: {}",
360								func_name_for_logging, e
361							);
362							// Return -1 to indicate error to WASM caller
363							-1i32
364						},
365					}
366				},
367			)
368			.map_err(|e| {
369				warn!("[FunctionExport] Failed to wrap host function '{}': {}", func_name_for_debug, e);
370				e
371			})?;
372
373		debug!(
374			"[FunctionExport] Host function '{}' registered successfully",
375			func_name_for_debug
376		);
377
378		Ok(())
379	}
380
381	/// Convert our signature to WASMtime signature type
382	#[allow(dead_code)]
383	fn wasmtime_signature_from_signature(&self, _sig:&FunctionSignature) -> Result<wasmparser::FuncType> {
384		// This is a placeholder - actual implementation depends on the exact types
385		// In production, this would map ParamType and ReturnType to WASMtime types
386		Ok(wasmparser::FuncType::new([], []))
387	}
388
389	/// Get all registered function names
390	pub async fn get_function_names(&self) -> Vec<String> {
391		self.registry.functions.read().await.keys().cloned().collect()
392	}
393
394	/// Get function statistics
395	pub async fn get_function_stats(&self, name:&str) -> Option<FunctionStats> {
396		self.registry.functions.read().await.get(name).map(|f| f.stats.clone())
397	}
398
399	/// Unregister a function
400	#[instrument(skip(self))]
401	pub async fn unregister_function(&self, name:&str) -> Result<bool> {
402		let mut functions = self.registry.functions.write().await;
403		let removed = functions.remove(name).is_some();
404
405		if removed {
406			info!("Unregistered host function: {}", name);
407		} else {
408			warn!("Attempted to unregister non-existent function: {}", name);
409		}
410
411		Ok(removed)
412	}
413
414	/// Clear all registered functions
415	pub async fn clear(&self) {
416		info!("Clearing all registered host functions");
417		self.registry.functions.write().await.clear();
418	}
419}
420
#[cfg(test)]
mod tests {
	use super::*;
	// `super` only brings `HostBridgeImpl` in under its alias `HostBridge` and never
	// imports `ParamType`/`ReturnType`, so import the names these tests use directly.
	// NOTE(review): assumes `ParamType` and `ReturnType` live in
	// `crate::WASM::HostBridge` alongside `FunctionSignature` — confirm.
	use crate::WASM::HostBridge::{HostBridgeImpl, ParamType, ReturnType};

	/// Build a manager with the default configuration around a fresh bridge.
	fn new_export() -> FunctionExportImpl { FunctionExportImpl::new(Arc::new(HostBridgeImpl::new())) }

	/// Build an `(i32) -> i32` signature named `name`.
	fn i32_signature(name:&str) -> FunctionSignature {
		FunctionSignature {
			name:name.to_string(),
			param_types:vec![ParamType::I32],
			return_type:Some(ReturnType::I32),
			is_async:false,
		}
	}

	#[tokio::test]
	async fn test_function_export_creation() {
		let export = new_export();

		assert!(export.get_function_names().await.is_empty());
	}

	#[tokio::test]
	async fn test_register_function() {
		let export = new_export();

		let callback = |args:Vec<bytes::Bytes>| Ok(args.first().cloned().unwrap_or_default());

		let result:anyhow::Result<()> = export.register_function("echo", i32_signature("echo"), callback).await;
		assert!(result.is_ok());
		assert_eq!(export.get_function_names().await, vec!["echo".to_string()]);
	}

	#[tokio::test]
	async fn test_register_function_respects_max_functions() {
		let config = ExportConfig { max_functions:1, ..ExportConfig::default() };
		let export = FunctionExportImpl::with_config(Arc::new(HostBridgeImpl::new()), config);

		let callback = |_:Vec<bytes::Bytes>| Ok(bytes::Bytes::new());

		let first:anyhow::Result<()> = export.register_function("first", i32_signature("first"), callback).await;
		assert!(first.is_ok());

		let second:anyhow::Result<()> = export.register_function("second", i32_signature("second"), callback).await;
		assert!(second.is_err());

		assert_eq!(export.get_function_names().await, vec!["first".to_string()]);
	}

	#[tokio::test]
	async fn test_function_stats_lookup() {
		let export = new_export();

		let callback = |_:Vec<bytes::Bytes>| Ok(bytes::Bytes::new());
		let registered:anyhow::Result<()> = export.register_function("stats", i32_signature("stats"), callback).await;
		assert!(registered.is_ok());

		let stats = export.get_function_stats("stats").await.expect("function was just registered");
		assert_eq!(stats.call_count, 0);
		assert_eq!(stats.error_count, 0);
		assert!(export.get_function_stats("missing").await.is_none());
	}

	#[tokio::test]
	async fn test_unregister_function() {
		let export = new_export();

		let callback = |_:Vec<bytes::Bytes>| Ok(bytes::Bytes::new());
		let registered:anyhow::Result<()> = export.register_function("test", i32_signature("test"), callback).await;
		assert!(registered.is_ok());

		let result:bool = export.unregister_function("test").await.unwrap();
		assert!(result);
		assert!(export.get_function_names().await.is_empty());

		// A second attempt reports that nothing was removed.
		let again:bool = export.unregister_function("test").await.unwrap();
		assert!(!again);
	}

	#[tokio::test]
	async fn test_clear() {
		let export = new_export();

		let callback = |_:Vec<bytes::Bytes>| Ok(bytes::Bytes::new());
		for name in ["a", "b"] {
			let registered:anyhow::Result<()> = export.register_function(name, i32_signature(name), callback).await;
			assert!(registered.is_ok());
		}

		export.clear().await;
		assert!(export.get_function_names().await.is_empty());
	}

	#[test]
	fn test_export_config_default() {
		let config = ExportConfig::default();
		assert!(config.auto_export);
		assert!(config.enable_stats);
		assert_eq!(config.max_functions, 1000);
		assert_eq!(config.name_prefix.as_deref(), Some("host_"));
	}

	#[test]
	fn test_function_stats_default() {
		let stats = FunctionStats::default();
		assert_eq!(stats.call_count, 0);
		assert_eq!(stats.total_execution_ns, 0);
		assert_eq!(stats.last_call_at, None);
		assert_eq!(stats.error_count, 0);
	}
}