1use std::{collections::HashMap, sync::Arc};
7
8use anyhow::Result;
9use serde::{Deserialize, Serialize};
10use tokio::sync::RwLock;
11use tracing::{debug, instrument, warn};
12
13#[derive(Debug, Clone, Serialize, Deserialize)]
15pub struct APICallRequest {
16 pub extension_id:String,
18 pub api_method:String,
20 pub arguments:Vec<serde_json::Value>,
22 pub correlation_id:Option<String>,
24}
25
26#[derive(Debug, Clone, Serialize, Deserialize)]
28pub struct APICallResponse {
29 pub success:bool,
31 pub data:Option<serde_json::Value>,
33 pub error:Option<String>,
35 pub correlation_id:Option<String>,
37}
38
39#[allow(dead_code)]
41pub struct APICall {
42 extension_id:String,
44 api_method:String,
46 arguments:Vec<serde_json::Value>,
48 timestamp:u64,
50}
51
52#[allow(dead_code)]
54type APIMethodHandler = fn(&str, Vec<serde_json::Value>) -> Result<serde_json::Value>;
55
56#[allow(dead_code)]
58type AsyncAPIMethodHandler =
59 fn(&str, Vec<serde_json::Value>) -> Box<dyn std::future::Future<Output = Result<serde_json::Value>> + Send + Unpin>;
60
61#[derive(Clone)]
63pub struct APIMethodInfo {
64 #[allow(dead_code)]
66 name:String,
67 #[allow(dead_code)]
69 description:String,
70 #[allow(dead_code)]
72 parameters:Option<serde_json::Value>,
73 #[allow(dead_code)]
75 returns:Option<serde_json::Value>,
76 #[allow(dead_code)]
78 is_async:bool,
79 call_count:u64,
81 total_time_us:u64,
83}
84
85pub struct APIBridgeImpl {
87 api_methods:Arc<RwLock<HashMap<String, APIMethodInfo>>>,
89 stats:Arc<RwLock<APIStats>>,
91 contexts:Arc<RwLock<HashMap<String, APIContext>>>,
93}
94
95#[derive(Debug, Clone, Default, Serialize, Deserialize)]
97pub struct APIStats {
98 pub total_calls:u64,
100 pub successful_calls:u64,
102 pub failed_calls:u64,
104 pub avg_latency_us:u64,
106 pub active_contexts:usize,
108}
109
110#[derive(Debug, Clone)]
112pub struct APIContext {
113 pub extension_id:String,
115 pub context_id:String,
117 pub workspace_folder:Option<String>,
119 pub active_editor:Option<String>,
121 pub selections:Vec<Selection>,
123 pub created_at:u64,
125}
126
127#[derive(Debug, Clone, Serialize, Deserialize)]
129pub struct Selection {
130 pub start_line:u32,
132 pub start_character:u32,
134 pub end_line:u32,
136 pub end_character:u32,
138}
139
140impl Default for Selection {
141 fn default() -> Self { Self { start_line:0, start_character:0, end_line:0, end_character:0 } }
142}
143
144impl APIBridgeImpl {
145 pub fn new() -> Self {
147 let bridge = Self {
148 api_methods:Arc::new(RwLock::new(HashMap::new())),
149 stats:Arc::new(RwLock::new(APIStats::default())),
150 contexts:Arc::new(RwLock::new(HashMap::new())),
151 };
152
153 bridge.register_builtin_methods();
154
155 bridge
156 }
157
158 fn register_builtin_methods(&self) {
160 debug!("Registered built-in VS Code API methods");
169 }
170
171 pub async fn register_method(
173 &self,
174 name:&str,
175 description:&str,
176 parameters:Option<serde_json::Value>,
177 returns:Option<serde_json::Value>,
178 is_async:bool,
179 ) -> Result<()> {
180 let mut methods = self.api_methods.write().await;
181
182 if methods.contains_key(name) {
183 warn!("API method already registered: {}", name);
184 }
185
186 methods.insert(
187 name.to_string(),
188 APIMethodInfo {
189 name:name.to_string(),
190 description:description.to_string(),
191 parameters,
192 returns,
193 is_async,
194 call_count:0,
195 total_time_us:0,
196 },
197 );
198
199 debug!("Registered API method: {}", name);
200
201 Ok(())
202 }
203
204 #[instrument(skip(self))]
206 pub async fn create_context(&self, extension_id:&str) -> Result<APIContext> {
207 let context_id = format!("{}-{}", extension_id, uuid::Uuid::new_v4());
208
209 let context = APIContext {
210 extension_id:extension_id.to_string(),
211 context_id:context_id.clone(),
212 workspace_folder:None,
213 active_editor:None,
214 selections:Vec::new(),
215 created_at:std::time::SystemTime::now()
216 .duration_since(std::time::UNIX_EPOCH)
217 .map(|d| d.as_secs())
218 .unwrap_or(0),
219 };
220
221 let mut contexts = self.contexts.write().await;
222 contexts.insert(context_id.clone(), context.clone());
223
224 let mut stats = self.stats.write().await;
226 stats.active_contexts = contexts.len();
227
228 debug!("Created API context for extension: {}", extension_id);
229
230 Ok(context)
231 }
232
233 pub async fn get_context(&self, context_id:&str) -> Option<APIContext> {
235 self.contexts.read().await.get(context_id).cloned()
236 }
237
238 pub async fn update_context(&self, context:APIContext) -> Result<()> {
240 let mut contexts = self.contexts.write().await;
241 contexts.insert(context.context_id.clone(), context);
242 Ok(())
243 }
244
245 pub async fn remove_context(&self, context_id:&str) -> Result<bool> {
247 let mut contexts = self.contexts.write().await;
248 let removed = contexts.remove(context_id).is_some();
249
250 if removed {
251 let mut stats = self.stats.write().await;
252 stats.active_contexts = contexts.len();
253 }
254
255 Ok(removed)
256 }
257
258 #[instrument(skip(self, request))]
260 pub async fn handle_call(&self, request:APICallRequest) -> Result<APICallResponse> {
261 let start = std::time::Instant::now();
262
263 debug!(
264 "Handling API call: {} from extension {}",
265 request.api_method, request.extension_id
266 );
267
268 let exists = {
270 let methods = self.api_methods.read().await;
271 methods.contains_key(&request.api_method)
272 };
273
274 if !exists {
275 return Ok(APICallResponse {
276 success:false,
277 data:None,
278 error:Some(format!("API method not found: {}", request.api_method)),
279 correlation_id:request.correlation_id,
280 });
281 }
282
283 let result = self
286 .execute_method(&request.extension_id, &request.api_method, &request.arguments)
287 .await;
288
289 let elapsed_us = start.elapsed().as_micros() as u64;
290
291 let mut stats = self.stats.write().await;
293 stats.total_calls += 1;
294 stats.total_calls += 1;
295 if exists {
296 stats.successful_calls += 1;
297 stats.avg_latency_us =
299 (stats.avg_latency_us * (stats.successful_calls - 1) + elapsed_us) / stats.successful_calls;
300 }
301
302 {
304 let mut methods = self.api_methods.write().await;
305 if let Some(method) = methods.get_mut(&request.api_method) {
306 method.call_count += 1;
307 method.total_time_us += elapsed_us;
308 }
309 }
310
311 debug!("API call {} completed in {}µs", request.api_method, elapsed_us);
312
313 match result {
314 Ok(data) => {
315 Ok(
316 APICallResponse {
317 success:true,
318 data:Some(data),
319 error:None,
320 correlation_id:request.correlation_id,
321 },
322 )
323 },
324 Err(e) => {
325 Ok(APICallResponse {
326 success:false,
327 data:None,
328 error:Some(e.to_string()),
329 correlation_id:request.correlation_id,
330 })
331 },
332 }
333 }
334
335 async fn execute_method(
337 &self,
338 _extension_id:&str,
339 _method_name:&str,
340 _arguments:&[serde_json::Value],
341 ) -> Result<serde_json::Value> {
342 Ok(serde_json::Value::Null)
351 }
352
353 pub async fn stats(&self) -> APIStats { self.stats.read().await.clone() }
355
356 pub async fn get_methods(&self) -> Vec<APIMethodInfo> { self.api_methods.read().await.values().cloned().collect() }
358
359 pub async fn unregister_method(&self, name:&str) -> Result<bool> {
361 let mut methods = self.api_methods.write().await;
362 let removed = methods.remove(name).is_some();
363
364 if removed {
365 debug!("Unregistered API method: {}", name);
366 }
367
368 Ok(removed)
369 }
370}
371
372impl Default for APIBridgeImpl {
373 fn default() -> Self { Self::new() }
374}
375
#[cfg(test)]
mod tests {
	use super::*;

	/// Builds a request for `api_method` from `test.ext` with no arguments.
	fn request_for(api_method:&str, correlation_id:Option<&str>) -> APICallRequest {
		APICallRequest {
			extension_id:"test.ext".to_string(),
			api_method:api_method.to_string(),
			arguments:Vec::new(),
			correlation_id:correlation_id.map(str::to_string),
		}
	}

	#[tokio::test]
	async fn test_api_bridge_creation() {
		let stats = APIBridgeImpl::new().stats().await;
		assert_eq!(stats.total_calls, 0);
		assert_eq!(stats.successful_calls, 0);
	}

	#[tokio::test]
	async fn test_context_creation() {
		let bridge = APIBridgeImpl::new();
		let context = bridge.create_context("test.ext").await.unwrap();
		assert_eq!(context.extension_id, "test.ext");
		assert!(!context.context_id.is_empty());
		assert!(context.context_id.starts_with("test.ext-"));
	}

	/// Create, look up and remove a context; `active_contexts` must track the map.
	#[tokio::test]
	async fn test_context_lifecycle() {
		let bridge = APIBridgeImpl::new();
		let context = bridge.create_context("test.ext").await.unwrap();
		assert_eq!(bridge.stats().await.active_contexts, 1);

		let fetched = bridge.get_context(&context.context_id).await.unwrap();
		assert_eq!(fetched.extension_id, "test.ext");

		assert!(bridge.remove_context(&context.context_id).await.unwrap());
		assert!(!bridge.remove_context(&context.context_id).await.unwrap());
		assert!(bridge.get_context(&context.context_id).await.is_none());
		assert_eq!(bridge.stats().await.active_contexts, 0);
	}

	#[tokio::test]
	async fn test_method_registration() {
		let bridge = APIBridgeImpl::new();
		let result: Result<()> = bridge.register_method("test.method", "Test method", None, None, false).await;
		assert!(result.is_ok());

		let methods: Vec<APIMethodInfo> = bridge.get_methods().await;
		assert!(methods.iter().any(|m| m.name == "test.method"));
	}

	/// A call to a registered method succeeds and bumps that method's counter.
	#[tokio::test]
	async fn test_handle_call_registered_method() {
		let bridge = APIBridgeImpl::new();
		bridge
			.register_method("test.method", "Test method", None, None, false)
			.await
			.unwrap();

		let response = bridge.handle_call(request_for("test.method", None)).await.unwrap();
		assert!(response.success);
		assert_eq!(response.data, Some(serde_json::Value::Null));
		assert!(response.error.is_none());
		assert!(response.correlation_id.is_none());

		let methods = bridge.get_methods().await;
		let method = methods.iter().find(|m| m.name == "test.method").unwrap();
		assert_eq!(method.call_count, 1);
	}

	/// A call to an unknown method fails with a descriptive error and echoes the correlation id.
	#[tokio::test]
	async fn test_handle_call_unknown_method() {
		let bridge = APIBridgeImpl::new();
		let response = bridge
			.handle_call(request_for("missing.method", Some("corr-1")))
			.await
			.unwrap();

		assert!(!response.success);
		assert!(response.data.is_none());
		assert_eq!(response.error.as_deref(), Some("API method not found: missing.method"));
		assert_eq!(response.correlation_id.as_deref(), Some("corr-1"));
	}

	#[tokio::test]
	async fn test_api_call_request() {
		let request = APICallRequest {
			extension_id:"test.ext".to_string(),
			api_method:"test.method".to_string(),
			arguments:vec![serde_json::json!("arg1")],
			correlation_id:Some("test-id".to_string()),
		};

		assert_eq!(request.extension_id, "test.ext");
		assert_eq!(request.api_method, "test.method");
		assert_eq!(request.arguments.len(), 1);
	}

	#[test]
	fn test_selection_default() {
		let selection = Selection::default();
		assert_eq!(selection.start_line, 0);
		assert_eq!(selection.end_line, 0);
	}
}