Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 52 additions & 3 deletions crates/rmcp-macros/src/tool.rs
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,10 @@ pub struct ToolAttribute {
pub icons: Option<Expr>,
/// Optional metadata for the tool
pub meta: Option<Expr>,
/// Whether the generated future should be `Send`. Defaults to `true`.
/// Set to `false` for tools that hold non-Send state (e.g., `Rc`, `RefCell`).
/// Note: tools with `send = false` are incompatible with the built-in tool router.
pub send: Option<bool>,
}

#[derive(FromMeta, Debug, Default)]
Expand Down Expand Up @@ -330,9 +334,10 @@ pub fn tool(attr: TokenStream, input: TokenStream) -> syn::Result<TokenStream> {
};
let tool_attr_fn = resolved_tool_attr.into_fn(tool_attr_fn_ident)?;
// modify the the input function
let is_send = attribute.send.unwrap_or(true);
if fn_item.sig.asyncness.is_some() {
// 1. remove asyncness from sig
// 2. make return type: `std::pin::Pin<Box<dyn std::future::Future<Output = #ReturnType> + Send + '_>>`
// 2. make return type: `std::pin::Pin<Box<dyn std::future::Future<Output = #ReturnType> (+ Send)? + '_>>`
// 3. make body: { Box::pin(async move { #body }) }
let new_output = syn::parse2::<ReturnType>({
let mut lt = quote! { 'static };
Expand All @@ -345,12 +350,17 @@ pub fn tool(attr: TokenStream, input: TokenStream) -> syn::Result<TokenStream> {
}
}
}
let send_bound = if is_send {
quote! { Send + }
} else {
quote! {}
};
match &fn_item.sig.output {
syn::ReturnType::Default => {
quote! { -> ::std::pin::Pin<Box<dyn ::std::future::Future<Output = ()> + Send + #lt>> }
quote! { -> ::std::pin::Pin<Box<dyn ::std::future::Future<Output = ()> + #send_bound #lt>> }
}
syn::ReturnType::Type(_, ty) => {
quote! { -> ::std::pin::Pin<Box<dyn ::std::future::Future<Output = #ty> + Send + #lt>> }
quote! { -> ::std::pin::Pin<Box<dyn ::std::future::Future<Output = #ty> + #send_bound #lt>> }
}
}
})?;
Expand Down Expand Up @@ -443,4 +453,43 @@ mod test {
assert!(result_str.contains("include_str"));
Ok(())
}

#[test]
fn async_tool_future_includes_send_bound_by_default() -> syn::Result<()> {
let attr = quote! {};
let input = quote! {
async fn my_tool(&self) -> String {
"hello".to_string()
}
};
let result = tool(attr, input)?;
assert!(result.to_string().contains("Send"));
Ok(())
}

#[test]
fn async_tool_future_omits_send_bound_when_send_is_false() -> syn::Result<()> {
let attr = quote! { send = false };
let input = quote! {
async fn my_tool(&self) -> String {
"hello".to_string()
}
};
let result = tool(attr, input)?;
assert!(!result.to_string().contains("Send"));
Ok(())
}

#[test]
fn async_tool_future_includes_send_bound_when_send_is_true() -> syn::Result<()> {
let attr = quote! { send = true };
let input = quote! {
async fn my_tool(&self) -> String {
"hello".to_string()
}
};
let result = tool(attr, input)?;
assert!(result.to_string().contains("Send"));
Ok(())
}
}