Skip to main content

Grove/WASM/
HostBridge.rs

//! Host Bridge
//!
//! Provides bidirectional communication between the host (Grove) and WASM
//! modules. Handles function calls, data transfer, and marshalling between the
//! two environments.
6
7use std::{collections::HashMap, sync::Arc};
8
9use anyhow::Result;
10use bytes::Bytes;
11use serde::{Serialize, de::DeserializeOwned};
12use tokio::sync::{RwLock, mpsc, oneshot};
13use tracing::{debug, instrument, warn};
14#[allow(unused_imports)]
15use wasmtime::{Caller, Extern, Func, Linker, Store};
16
17/// Host bridge error types
18#[derive(Debug, thiserror::Error)]
19pub enum BridgeError {
20	/// Function not found error
21	#[error("Function not found: {0}")]
22	FunctionNotFound(String),
23
24	/// Invalid function signature error
25	#[error("Invalid function signature: {0}")]
26	InvalidSignature(String),
27
28	/// Serialization failed error
29	#[error("Serialization failed: {0}")]
30	SerializationError(String),
31
32	/// Deserialization failed error
33	#[error("Deserialization failed: {0}")]
34	DeserializationError(String),
35
36	/// Host function error
37	#[error("Host function error: {0}")]
38	HostFunctionError(String),
39
40	/// Communication timeout error
41	#[error("Communication timeout")]
42	Timeout,
43
44	/// Bridge closed error
45	#[error("Bridge closed")]
46	BridgeClosed,
47}
48
49/// Type-safe result for operations
50pub type BridgeResult<T> = Result<T, BridgeError>;
51
52/// Function signature information
53#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
54pub struct FunctionSignature {
55	/// Function name
56	pub name:String,
57	/// Parameter types
58	pub param_types:Vec<ParamType>,
59	/// Return type
60	pub return_type:Option<ReturnType>,
61	/// Whether this is an async function
62	pub is_async:bool,
63}
64
65/// Parameter types for WASM functions
66#[derive(Debug, Clone, Copy, serde::Serialize, serde::Deserialize, PartialEq, Eq)]
67pub enum ParamType {
68	/// 32-bit signed integer parameter
69	I32,
70	/// 64-bit signed integer parameter
71	I64,
72	/// 32-bit floating point parameter
73	F32,
74	/// 64-bit floating point parameter
75	F64,
76	/// Pointer to memory
77	Ptr,
78	/// Length parameter following a pointer
79	Len,
80}
81
82/// Return types for WASM functions
83#[derive(Debug, Clone, Copy, serde::Serialize, serde::Deserialize, PartialEq, Eq)]
84pub enum ReturnType {
85	/// 32-bit signed integer return type
86	I32,
87	/// 64-bit signed integer return type
88	I64,
89	/// 32-bit floating point return type
90	F32,
91	/// 64-bit floating point return type
92	F64,
93	/// No return value (void)
94	Void,
95}
96
97/// Message sent from WASM to host
98#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
99pub struct HostMessage {
100	/// Message ID for correlation
101	pub message_id:String,
102	/// Function name to call
103	pub function:String,
104	/// Serialized arguments
105	pub args:Vec<Bytes>,
106	/// Callback token for async responses
107	pub callback_token:Option<u64>,
108}
109
110/// Response sent from host to WASM
111#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
112pub struct HostResponse {
113	/// Correlating message ID
114	pub message_id:String,
115	/// Success flag
116	pub success:bool,
117	/// Response data
118	pub data:Option<Bytes>,
119	/// Error message if failed
120	pub error:Option<String>,
121}
122
123/// Callback for async function responses
124#[derive(Clone)]
125pub struct AsyncCallback {
126	/// Sender for transmitting the response
127	sender:Arc<tokio::sync::Mutex<Option<tokio::sync::oneshot::Sender<HostResponse>>>>,
128	/// Message ID for correlation
129	message_id:String,
130}
131
132impl std::fmt::Debug for AsyncCallback {
133	fn fmt(&self, f:&mut std::fmt::Formatter<'_>) -> std::fmt::Result {
134		f.debug_struct("AsyncCallback").field("message_id", &self.message_id).finish()
135	}
136}
137
138impl AsyncCallback {
139	/// Send response through the callback
140	pub async fn send(self, response:HostResponse) -> Result<()> {
141		let mut sender_opt = self.sender.lock().await;
142		if let Some(sender) = sender_opt.take() {
143			sender.send(response).map_err(|_| BridgeError::BridgeClosed)?;
144			Ok(())
145		} else {
146			Err(BridgeError::BridgeClosed.into())
147		}
148	}
149}
150
151/// Message from host to WASM
152#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
153pub struct WASMMessage {
154	/// Target function in WASM
155	pub function:String,
156	/// Arguments
157	pub args:Vec<Bytes>,
158}
159
160/// Host function callback type
161pub type HostFunctionCallback = fn(Vec<Bytes>) -> Result<Bytes>;
162
163/// Async host function callback type
164pub type AsyncHostFunctionCallback =
165	fn(Vec<Bytes>) -> Box<dyn std::future::Future<Output = Result<Bytes>> + Send + Unpin>;
166
167/// Host function definition
168#[derive(Debug)]
169pub struct HostFunction {
170	/// Function name
171	pub name:String,
172	/// Function signature
173	pub signature:FunctionSignature,
174	/// Synchronous callback - not serializable (skip serde derive)
175	#[allow(dead_code)]
176	pub callback:Option<HostFunctionCallback>,
177	/// Async callback - not serializable (skip serde derive)
178	#[allow(dead_code)]
179	pub async_callback:Option<AsyncHostFunctionCallback>,
180}
181
182/// Host Bridge for WASM communication
183#[derive(Debug)]
184pub struct HostBridgeImpl {
185	/// Registry of host functions exported to WASM
186	host_functions:Arc<RwLock<HashMap<String, HostFunction>>>,
187	/// Channel for receiving messages from WASM
188	wasm_to_host_rx:mpsc::UnboundedReceiver<WASMMessage>,
189	/// Channel for sending messages to WASM
190	host_to_wasm_tx:mpsc::UnboundedSender<WASMMessage>,
191	/// Active async callbacks
192	async_callbacks:Arc<RwLock<HashMap<u64, AsyncCallback>>>,
193	/// Next callback token
194	next_callback_token:Arc<std::sync::atomic::AtomicU64>,
195}
196
197impl HostBridgeImpl {
198	/// Create a new host bridge
199	pub fn new() -> Self {
200		let (_wasm_to_host_tx, wasm_to_host_rx) = mpsc::unbounded_channel();
201		let (host_to_wasm_tx, host_to_wasm_rx) = mpsc::unbounded_channel();
202
203		// In a real implementation, we'd need to wire these up properly
204		// For now, we drop the receiver to avoid unused warnings
205		drop(host_to_wasm_rx);
206
207		Self {
208			host_functions:Arc::new(RwLock::new(HashMap::new())),
209			wasm_to_host_rx,
210			host_to_wasm_tx,
211			async_callbacks:Arc::new(RwLock::new(HashMap::new())),
212			next_callback_token:Arc::new(std::sync::atomic::AtomicU64::new(0)),
213		}
214	}
215
216	/// Register a host function to be exported to WASM
217	#[instrument(skip(self, callback))]
218	pub async fn register_host_function(
219		&self,
220		name:&str,
221		signature:FunctionSignature,
222		callback:HostFunctionCallback,
223	) -> BridgeResult<()> {
224		debug!("Registering host function: {}", name);
225
226		let mut functions = self.host_functions.write().await;
227
228		if functions.contains_key(name) {
229			warn!("Host function already registered: {}", name);
230		}
231
232		functions.insert(
233			name.to_string(),
234			HostFunction { name:name.to_string(), signature, callback:Some(callback), async_callback:None },
235		);
236
237		debug!("Host function registered successfully: {}", name);
238		Ok(())
239	}
240
241	/// Register an async host function
242	#[instrument(skip(self, callback))]
243	pub async fn register_async_host_function(
244		&self,
245		name:&str,
246		signature:FunctionSignature,
247		callback:AsyncHostFunctionCallback,
248	) -> BridgeResult<()> {
249		debug!("Registering async host function: {}", name);
250
251		let mut functions = self.host_functions.write().await;
252
253		functions.insert(
254			name.to_string(),
255			HostFunction { name:name.to_string(), signature, callback:None, async_callback:Some(callback) },
256		);
257
258		debug!("Async host function registered successfully: {}", name);
259		Ok(())
260	}
261
262	/// Call a host function from WASM
263	#[instrument(skip(self, args))]
264	pub async fn call_host_function(&self, function_name:&str, args:Vec<Bytes>) -> BridgeResult<Bytes> {
265		debug!("Calling host function: {}", function_name);
266
267		let functions = self.host_functions.read().await;
268		let func = functions
269			.get(function_name)
270			.ok_or_else(|| BridgeError::FunctionNotFound(function_name.to_string()))?;
271
272		if let Some(callback) = func.callback {
273			// Synchronous call
274			let result =
275				callback(args).map_err(|e| BridgeError::HostFunctionError(format!("{}: {}", function_name, e)))?;
276			debug!("Host function call completed: {}", function_name);
277			Ok(result)
278		} else if let Some(async_callback) = func.async_callback {
279			// Async call
280			let future = async_callback(args);
281			let result = future
282				.await
283				.map_err(|e| BridgeError::HostFunctionError(format!("{}: {}", function_name, e)))?;
284			debug!("Async host function call completed: {}", function_name);
285			Ok(result)
286		} else {
287			Err(BridgeError::FunctionNotFound(format!(
288				"No callback for function: {}",
289				function_name
290			)))
291		}
292	}
293
294	/// Send a message to WASM
295	#[instrument(skip(self, message))]
296	pub async fn send_to_wasm(&self, message:WASMMessage) -> BridgeResult<()> {
297		let function_name = message.function.clone();
298		self.host_to_wasm_tx.send(message).map_err(|_| BridgeError::BridgeClosed)?;
299		debug!("Message sent to WASM: {}", function_name);
300		Ok(())
301	}
302
303	/// Receive a message from WASM (blocking)
304	pub async fn receive_from_wasm(&mut self) -> Option<WASMMessage> { self.wasm_to_host_rx.recv().await }
305
306	/// Create async callback
307	#[instrument(skip(self))]
308	pub async fn create_async_callback(&self, message_id:String) -> (AsyncCallback, u64) {
309		let token = self.next_callback_token.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
310		let (tx, _rx) = oneshot::channel();
311
312		// Create callback with Arc-wrapped sender
313		let callback = AsyncCallback {
314			sender:Arc::new(tokio::sync::Mutex::new(Some(tx))),
315			message_id:message_id.clone(),
316		};
317
318		self.async_callbacks.write().await.insert(token, callback.clone());
319
320		(callback, token)
321	}
322
323	/// Get callback by token
324	#[instrument(skip(self))]
325	pub async fn get_callback(&self, token:u64) -> Option<AsyncCallback> {
326		self.async_callbacks.write().await.remove(&token)
327	}
328
329	/// Get all registered host functions
330	pub async fn get_host_functions(&self) -> Vec<String> { self.host_functions.read().await.keys().cloned().collect() }
331
332	/// Unregister a host function
333	#[instrument(skip(self))]
334	pub async fn unregister_host_function(&self, name:&str) -> bool {
335		let mut functions = self.host_functions.write().await;
336		let removed = functions.remove(name).is_some();
337		if removed {
338			debug!("Host function unregistered: {}", name);
339		}
340		removed
341	}
342
343	/// Clear all registered functions
344	pub async fn clear(&self) {
345		debug!("Clearing all registered host functions");
346		self.host_functions.write().await.clear();
347		self.async_callbacks.write().await.clear();
348	}
349}
350
351impl Default for HostBridgeImpl {
352	fn default() -> Self { Self::new() }
353}
354
355/// Utility function to serialize data to Bytes
356pub fn serialize_to_bytes<T:Serialize>(data:&T) -> Result<Bytes> {
357	serde_json::to_vec(data)
358		.map(Bytes::from)
359		.map_err(|e| anyhow::anyhow!("Serialization error: {}", e))
360}
361
362/// Utility function to deserialize Bytes to data
363pub fn deserialize_from_bytes<T:DeserializeOwned>(bytes:&Bytes) -> Result<T> {
364	serde_json::from_slice(bytes).map_err(|e| anyhow::anyhow!("Deserialization error: {}", e))
365}
366
367/// Marshal arguments for WASM function call
368pub fn marshal_args(args:Vec<Bytes>) -> Result<Vec<wasmtime::Val>> {
369	args.iter()
370		.map(|bytes| {
371			let value:serde_json::Value = serde_json::from_slice(bytes)?;
372			match value {
373				serde_json::Value::Number(n) => {
374					if let Some(i) = n.as_i64() {
375						Ok(wasmtime::Val::I32(i as i32))
376					} else if let Some(f) = n.as_f64() {
377						Ok(wasmtime::Val::F64(f.to_bits()))
378					} else {
379						Err(anyhow::anyhow!("Invalid number value"))
380					}
381				},
382				_ => Err(anyhow::anyhow!("Unsupported argument type")),
383			}
384		})
385		.collect()
386}
387
388/// Unmarshal return values from WASM function call
389pub fn unmarshal_return(val:wasmtime::Val) -> Result<Bytes> {
390	match val {
391		wasmtime::Val::I32(i) => {
392			let json = serde_json::to_string(&i)?;
393			Ok(Bytes::from(json))
394		},
395		wasmtime::Val::I64(i) => {
396			let json = serde_json::to_string(&i)?;
397			Ok(Bytes::from(json))
398		},
399		wasmtime::Val::F32(f) => {
400			let json = serde_json::to_string(&f)?;
401			Ok(Bytes::from(json))
402		},
403		wasmtime::Val::F64(f) => {
404			let json = serde_json::to_string(&f)?;
405			Ok(Bytes::from(json))
406		},
407		_ => Err(anyhow::anyhow!("Unsupported return type")),
408	}
409}
410
#[cfg(test)]
mod tests {
	use super::*;

	/// Minimal host callback that returns its first argument unchanged.
	fn echo_first(args:Vec<Bytes>) -> Result<Bytes> { Ok(args[0].clone()) }

	/// A signature literal keeps the name and parameter list it was built with.
	#[test]
	fn test_function_signature_creation() {
		let params = vec![ParamType::I32, ParamType::Ptr];
		let signature = FunctionSignature {
			name:String::from("test_func"),
			param_types:params,
			return_type:Some(ReturnType::I32),
			is_async:false,
		};

		assert_eq!(signature.name, "test_func");
		assert_eq!(signature.param_types.len(), 2);
	}

	/// A fresh bridge has no registered host functions.
	#[tokio::test]
	async fn test_host_bridge_creation() {
		let bridge = HostBridgeImpl::new();
		let registered = bridge.get_host_functions().await;
		assert_eq!(registered.len(), 0);
	}

	/// Registering one function succeeds and makes it visible in the registry.
	#[tokio::test]
	async fn test_register_host_function() {
		let bridge = HostBridgeImpl::new();

		let echo_signature = FunctionSignature {
			name:String::from("echo"),
			param_types:vec![ParamType::I32],
			return_type:Some(ReturnType::I32),
			is_async:false,
		};

		let outcome = bridge.register_host_function("echo", echo_signature, echo_first).await;

		assert!(outcome.is_ok());
		let registered = bridge.get_host_functions().await;
		assert_eq!(registered.len(), 1);
	}

	/// A value survives a serialize/deserialize round trip.
	#[test]
	fn test_serialize_deserialize() {
		let original:Vec<i32> = (1..=5).collect();
		let encoded = serialize_to_bytes(&original).unwrap();
		let decoded:Vec<i32> = deserialize_from_bytes(&encoded).unwrap();
		assert_eq!(original, decoded);
	}

	/// Integer and float arguments are both accepted by `marshal_args`.
	#[test]
	fn test_marshal_unmarshal() {
		let integer_arg = serialize_to_bytes(&42i32).unwrap();
		let float_arg = serialize_to_bytes(&3.14f64).unwrap();

		// Test that marshaling works (we don't assert on exact type conversion)
		let marshaled = marshal_args(vec![integer_arg, float_arg]);
		assert!(marshaled.is_ok());
	}
}