diff --git a/crates/rmcp/Cargo.toml b/crates/rmcp/Cargo.toml index 6ac4b02c0..f01ca58df 100644 --- a/crates/rmcp/Cargo.toml +++ b/crates/rmcp/Cargo.toml @@ -62,6 +62,11 @@ uuid = { version = "1", features = ["v4"], optional = true } http-body = { version = "1", optional = true } http-body-util = { version = "0.1", optional = true } bytes = { version = "1", optional = true } + +# for unix socket transport +hyper = { version = "1", features = ["client", "http1"], optional = true } +hyper-util = { version = "0.1", features = ["tokio"], optional = true } + # macro rmcp-macros = { workspace = true, optional = true } [target.'cfg(not(all(target_family = "wasm", target_os = "unknown")))'.dependencies] @@ -111,6 +116,15 @@ client-side-sse = ["dep:sse-stream", "dep:http"] # Streamable HTTP client transport-streamable-http-client = ["client-side-sse", "transport-worker"] transport-streamable-http-client-reqwest = ["transport-streamable-http-client", "__reqwest"] +transport-streamable-http-client-unix-socket = [ + "transport-streamable-http-client", + "dep:hyper", + "dep:hyper-util", + "dep:http-body-util", + "dep:http", + "dep:bytes", + "tokio/net", +] transport-async-rw = ["tokio/io-util", "tokio-util/codec"] transport-io = ["transport-async-rw", "tokio/io-std"] @@ -259,3 +273,12 @@ path = "tests/test_sse_concurrent_streams.rs" name = "test_client_credentials" required-features = ["auth"] path = "tests/test_client_credentials.rs" + +[[test]] +name = "test_unix_socket_transport" +required-features = [ + "client", + "server", + "transport-streamable-http-client-unix-socket", +] +path = "tests/test_unix_socket_transport.rs" diff --git a/crates/rmcp/src/transport.rs b/crates/rmcp/src/transport.rs index 683f6880f..e9b076fd2 100644 --- a/crates/rmcp/src/transport.rs +++ b/crates/rmcp/src/transport.rs @@ -112,6 +112,8 @@ pub use streamable_http_server::tower::{StreamableHttpServerConfig, StreamableHt #[cfg(feature = "transport-streamable-http-client")] pub mod streamable_http_client; +#[cfg(all(unix, feature = "transport-streamable-http-client-unix-socket"))] +pub use common::unix_socket::UnixSocketHttpClient; #[cfg(feature = "transport-streamable-http-client")] pub use streamable_http_client::StreamableHttpClientTransport; diff --git a/crates/rmcp/src/transport/common.rs b/crates/rmcp/src/transport/common.rs index 615b0e273..3691602b1 100644 --- a/crates/rmcp/src/transport/common.rs +++ b/crates/rmcp/src/transport/common.rs @@ -14,3 +14,6 @@ pub mod client_side_sse; #[cfg(feature = "auth")] pub mod auth; + +#[cfg(all(unix, feature = "transport-streamable-http-client-unix-socket"))] +pub mod unix_socket; diff --git a/crates/rmcp/src/transport/common/http_header.rs b/crates/rmcp/src/transport/common/http_header.rs index 441753260..196d96fff 100644 --- a/crates/rmcp/src/transport/common/http_header.rs +++ b/crates/rmcp/src/transport/common/http_header.rs @@ -3,3 +3,122 @@ pub const HEADER_LAST_EVENT_ID: &str = "Last-Event-Id"; pub const HEADER_MCP_PROTOCOL_VERSION: &str = "MCP-Protocol-Version"; pub const EVENT_STREAM_MIME_TYPE: &str = "text/event-stream"; pub const JSON_MIME_TYPE: &str = "application/json"; + +/// Reserved headers that must not be overridden by user-supplied custom headers. +/// `MCP-Protocol-Version` is in this list but is allowed through because the worker +/// injects it after initialization. +pub(crate) const RESERVED_HEADERS: &[&str] = &[ + "accept", + HEADER_SESSION_ID, + HEADER_MCP_PROTOCOL_VERSION, // allowed through by validate_custom_header; worker injects it post-init + HEADER_LAST_EVENT_ID, +]; + +/// Checks whether a custom header name is allowed. +/// Returns `Ok(())` if allowed, `Err(name)` if rejected as reserved. +/// `MCP-Protocol-Version` is reserved but allowed through (the worker injects it post-init). +#[cfg(feature = "client-side-sse")] +pub(crate) fn validate_custom_header(name: &http::HeaderName) -> Result<(), String> { + if RESERVED_HEADERS + .iter() + .any(|&r| name.as_str().eq_ignore_ascii_case(r)) + { + if name + .as_str() + .eq_ignore_ascii_case(HEADER_MCP_PROTOCOL_VERSION) + { + return Ok(()); + } + return Err(name.to_string()); + } + Ok(()) +} + +/// Extracts the `scope=` parameter from a `WWW-Authenticate` header value. +/// Handles both quoted (`scope="files:read files:write"`) and unquoted (`scope=read:data`) forms. +pub(crate) fn extract_scope_from_header(header: &str) -> Option { + let header_lowercase = header.to_ascii_lowercase(); + let scope_key = "scope="; + + if let Some(pos) = header_lowercase.find(scope_key) { + let start = pos + scope_key.len(); + let value_slice = &header[start..]; + + if let Some(stripped) = value_slice.strip_prefix('"') { + if let Some(end_quote) = stripped.find('"') { + return Some(stripped[..end_quote].to_string()); + } + } else { + let end = value_slice + .find(|c: char| c == ',' || c == ';' || c.is_whitespace()) + .unwrap_or(value_slice.len()); + if end > 0 { + return Some(value_slice[..end].to_string()); + } + } + } + + None +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn extract_scope_quoted() { + let header = r#"Bearer error="insufficient_scope", scope="files:read files:write""#; + assert_eq!( + extract_scope_from_header(header), + Some("files:read files:write".to_string()) + ); + } + + #[test] + fn extract_scope_unquoted() { + let header = r#"Bearer scope=read:data, error="insufficient_scope""#; + assert_eq!( + extract_scope_from_header(header), + Some("read:data".to_string()) + ); + } + + #[test] + fn extract_scope_missing() { + let header = r#"Bearer error="invalid_token""#; + assert_eq!(extract_scope_from_header(header), None); + } + + #[test] + fn extract_scope_empty_header() { + assert_eq!(extract_scope_from_header("Bearer"), None); + } + + #[cfg(feature = "client-side-sse")] + #[test] + fn validate_rejects_reserved_accept() { + let name = http::HeaderName::from_static("accept"); + assert!(validate_custom_header(&name).is_err()); + } + + #[cfg(feature = "client-side-sse")] + #[test] + fn validate_rejects_reserved_session_id() { + let name = http::HeaderName::from_static("mcp-session-id"); + assert!(validate_custom_header(&name).is_err()); + } + + #[cfg(feature = "client-side-sse")] + #[test] + fn validate_allows_mcp_protocol_version() { + let name = http::HeaderName::from_static("mcp-protocol-version"); + assert!(validate_custom_header(&name).is_ok()); + } + + #[cfg(feature = "client-side-sse")] + #[test] + fn validate_allows_custom_header() { + let name = http::HeaderName::from_static("x-custom"); + assert!(validate_custom_header(&name).is_ok()); + } +} diff --git a/crates/rmcp/src/transport/common/reqwest/streamable_http_client.rs b/crates/rmcp/src/transport/common/reqwest/streamable_http_client.rs index 8fca86fbc..bcc2a69c7 100644 --- a/crates/rmcp/src/transport/common/reqwest/streamable_http_client.rs +++ b/crates/rmcp/src/transport/common/reqwest/streamable_http_client.rs @@ -9,8 +9,8 @@ use crate::{ model::{ClientJsonRpcMessage, ServerJsonRpcMessage}, transport::{ common::http_header::{ - EVENT_STREAM_MIME_TYPE, HEADER_LAST_EVENT_ID, HEADER_MCP_PROTOCOL_VERSION, - HEADER_SESSION_ID, JSON_MIME_TYPE, + EVENT_STREAM_MIME_TYPE, HEADER_LAST_EVENT_ID, HEADER_SESSION_ID, JSON_MIME_TYPE, + extract_scope_from_header, validate_custom_header, }, streamable_http_client::*, }, @@ -22,38 +22,13 @@ impl From for StreamableHttpError { } } -/// Reserved headers that must not be overridden by user-supplied custom headers. -/// `MCP-Protocol-Version` is in this list but is allowed through because the worker -/// injects it after initialization. -const RESERVED_HEADERS: &[&str] = &[ - "accept", - HEADER_SESSION_ID, - HEADER_MCP_PROTOCOL_VERSION, - HEADER_LAST_EVENT_ID, -]; - -/// Applies custom headers to a request builder, rejecting reserved headers -/// except `MCP-Protocol-Version` (which the worker injects after init). +/// Applies custom headers to a request builder, rejecting reserved headers. fn apply_custom_headers( mut builder: reqwest::RequestBuilder, custom_headers: HashMap, ) -> Result> { for (name, value) in custom_headers { - if RESERVED_HEADERS - .iter() - .any(|&r| name.as_str().eq_ignore_ascii_case(r)) - { - if name - .as_str() - .eq_ignore_ascii_case(HEADER_MCP_PROTOCOL_VERSION) - { - builder = builder.header(name, value); - continue; - } - return Err(StreamableHttpError::ReservedHeaderConflict( - name.to_string(), - )); - } + validate_custom_header(&name).map_err(StreamableHttpError::ReservedHeaderConflict)?; builder = builder.header(name, value); } Ok(builder) @@ -280,66 +255,10 @@ impl StreamableHttpClientTransport { } } -/// extract scope parameter from WWW-Authenticate header -fn extract_scope_from_header(header: &str) -> Option { - let header_lowercase = header.to_ascii_lowercase(); - let scope_key = "scope="; - - if let Some(pos) = header_lowercase.find(scope_key) { - let start = pos + scope_key.len(); - let value_slice = &header[start..]; - - if let Some(stripped) = value_slice.strip_prefix('"') { - if let Some(end_quote) = stripped.find('"') { - return Some(stripped[..end_quote].to_string()); - } - } else { - let end = value_slice - .find(|c: char| c == ',' || c == ';' || c.is_whitespace()) - .unwrap_or(value_slice.len()); - if end > 0 { - return Some(value_slice[..end].to_string()); - } - } - } - - None -} - #[cfg(test)] mod tests { - use super::extract_scope_from_header; use crate::transport::streamable_http_client::InsufficientScopeError; - #[test] - fn extract_scope_quoted() { - let header = r#"Bearer error="insufficient_scope", scope="files:read files:write""#; - assert_eq!( - extract_scope_from_header(header), - Some("files:read files:write".to_string()) - ); - } - - #[test] - fn extract_scope_unquoted() { - let header = r#"Bearer scope=read:data, error="insufficient_scope""#; - assert_eq!( - extract_scope_from_header(header), - Some("read:data".to_string()) - ); - } - - #[test] - fn extract_scope_missing() { - let header = r#"Bearer error="invalid_token""#; - assert_eq!(extract_scope_from_header(header), None); - } - - #[test] - fn extract_scope_empty_header() { - assert_eq!(extract_scope_from_header("Bearer"), None); - } - #[test] fn insufficient_scope_error_can_upgrade() { let with_scope = InsufficientScopeError { diff --git a/crates/rmcp/src/transport/common/unix_socket.rs b/crates/rmcp/src/transport/common/unix_socket.rs new file mode 100644 index 000000000..3af987973 --- /dev/null +++ b/crates/rmcp/src/transport/common/unix_socket.rs @@ -0,0 +1,545 @@ +use std::{borrow::Cow, collections::HashMap, sync::Arc}; + +use bytes::Bytes; +use futures::{StreamExt, stream::BoxStream}; +use http::{HeaderName, HeaderValue, Method, Request, StatusCode, header::WWW_AUTHENTICATE}; +use http_body_util::{BodyExt, Full}; +use hyper::body::Incoming; +use hyper_util::rt::TokioIo; +use sse_stream::{Sse, SseStream}; +use tokio::net::UnixStream; + +use crate::{ + model::{ClientJsonRpcMessage, ServerJsonRpcMessage}, + transport::{ + common::http_header::{ + EVENT_STREAM_MIME_TYPE, HEADER_LAST_EVENT_ID, HEADER_SESSION_ID, JSON_MIME_TYPE, + extract_scope_from_header, validate_custom_header, + }, + streamable_http_client::*, + }, +}; + +#[derive(Debug, thiserror::Error)] +pub enum UnixSocketError { + #[error("hyper error: {0}")] + Hyper(#[from] hyper::Error), + #[error("IO error: {0}")] + Io(#[from] std::io::Error), + #[error("HTTP error: {0}")] + Http(#[from] http::Error), + #[error("JSON error: {0}")] + Json(#[from] serde_json::Error), +} + +impl From for StreamableHttpError { + fn from(e: UnixSocketError) -> Self { + StreamableHttpError::Client(e) + } +} + +/// HTTP client that routes requests through a Unix domain socket. +/// +/// Implements [`StreamableHttpClient`] using `hyper` over `tokio::net::UnixStream`, +/// enabling MCP hosts in Kubernetes environments to connect through Envoy sidecars +/// or other Unix socket-based proxies. +/// +/// Each request opens a new Unix socket connection (no connection pooling). +/// This is appropriate when connecting through a sidecar proxy that manages +/// its own upstream connection pool. +/// +/// # Example +/// +/// ```rust,no_run +/// use rmcp::transport::{StreamableHttpClientTransport, UnixSocketHttpClient}; +/// use rmcp::transport::streamable_http_client::StreamableHttpClientTransportConfig; +/// +/// let client = UnixSocketHttpClient::new("/var/run/envoy.sock", "http://mcp-server.internal/mcp"); +/// let config = StreamableHttpClientTransportConfig::with_uri("http://mcp-server.internal/mcp"); +/// let transport = StreamableHttpClientTransport::with_client(client, config); +/// ``` +#[derive(Clone, Debug)] +pub struct UnixSocketHttpClient { + socket_path: Arc, + host_header: HeaderValue, +} + +impl UnixSocketHttpClient { + /// Creates a new Unix socket HTTP client. + /// + /// # Arguments + /// + /// * `socket_path` - Path to the Unix domain socket. Use `@name` syntax for Linux + /// abstract sockets (e.g., `@egress.sock` becomes `\0egress.sock`). + /// * `uri` - The MCP server URI. The authority (host:port) is extracted for the + /// HTTP `Host` header, since hyper does not auto-set it for Unix socket connections. + /// + /// # Panics + /// + /// Panics if `socket_path` is empty or is `@` with no name (empty abstract socket). + pub fn new(socket_path: &str, uri: &str) -> Self { + assert!( + !socket_path.is_empty() && socket_path != "@", + "socket_path must not be empty or a bare '@' (empty abstract socket name)" + ); + + let host_header = uri + .parse::() + .ok() + .and_then(|u| u.authority().cloned()) + .and_then(|a| HeaderValue::from_str(a.as_str()).ok()) + .unwrap_or_else(|| HeaderValue::from_static("localhost")); + + Self { + socket_path: resolve_socket_path(socket_path).into(), + host_header, + } + } +} + +/// Converts the `@`-prefixed abstract socket notation to the null-byte prefix +/// expected by the Linux kernel. Filesystem socket paths are returned unchanged. +fn resolve_socket_path(raw: &str) -> String { + if let Some(name) = raw.strip_prefix('@') { + format!("\0{name}") + } else { + raw.to_string() + } +} + +async fn connect_unix(socket_path: &str) -> Result { + #[cfg(target_os = "linux")] + if let Some(abstract_name) = socket_path.strip_prefix('\0') { + let abstract_name = abstract_name.to_string(); + let std_stream = tokio::task::spawn_blocking(move || { + use std::os::linux::net::SocketAddrExt; + let addr = std::os::unix::net::SocketAddr::from_abstract_name(&abstract_name)?; + let stream = std::os::unix::net::UnixStream::connect_addr(&addr)?; + stream.set_nonblocking(true)?; + Ok::<_, std::io::Error>(stream) + }) + .await + .map_err(std::io::Error::other)??; + return UnixStream::from_std(std_stream); + } + + UnixStream::connect(socket_path).await +} + +/// Opens a new Unix socket connection and sends the HTTP request. +/// One connection per request — the sidecar proxy handles connection pooling. +async fn send_http_request( + socket_path: &str, + request: Request>, +) -> Result, UnixSocketError> { + let stream = connect_unix(socket_path).await?; + let io = TokioIo::new(stream); + let (mut sender, conn) = hyper::client::conn::http1::handshake(io).await?; + + tokio::spawn(async move { + if let Err(e) = conn.await { + tracing::warn!("unix socket HTTP/1.1 connection error: {e}"); + } + }); + + Ok(sender.send_request(request).await?) +} + +/// Applies custom headers to a request builder, rejecting reserved headers. +fn apply_custom_headers( + mut builder: http::request::Builder, + custom_headers: HashMap, +) -> Result> { + for (name, value) in custom_headers { + validate_custom_header(&name).map_err(StreamableHttpError::ReservedHeaderConflict)?; + builder = builder.header(name, value); + } + Ok(builder) +} + +impl StreamableHttpClient for UnixSocketHttpClient { + type Error = UnixSocketError; + + async fn post_message( + &self, + uri: Arc, + message: ClientJsonRpcMessage, + session_id: Option>, + auth_token: Option, + custom_headers: HashMap, + ) -> Result> { + let json_body = serde_json::to_string(&message) + .map_err(|e| StreamableHttpError::Client(UnixSocketError::Json(e)))?; + + let mut builder = Request::builder() + .method(Method::POST) + .uri(uri.as_ref()) + .header(http::header::HOST, self.host_header.clone()) + .header(http::header::CONTENT_TYPE, JSON_MIME_TYPE) + .header( + http::header::ACCEPT, + format!("{EVENT_STREAM_MIME_TYPE}, {JSON_MIME_TYPE}"), + ); + + if let Some(auth) = auth_token { + builder = builder.header(http::header::AUTHORIZATION, format!("Bearer {auth}")); + } + + builder = apply_custom_headers(builder, custom_headers)?; + + let session_was_attached = session_id.is_some(); + if let Some(sid) = session_id { + builder = builder.header(HEADER_SESSION_ID, sid.as_ref()); + } + + let request = builder + .body(Full::new(Bytes::from(json_body))) + .map_err(|e| StreamableHttpError::Client(UnixSocketError::Http(e)))?; + + let response = send_http_request(&self.socket_path, request) + .await + .map_err(StreamableHttpError::Client)?; + + let status = response.status(); + + if status == StatusCode::UNAUTHORIZED { + if let Some(header) = response.headers().get(WWW_AUTHENTICATE) { + let www_authenticate_header = header + .to_str() + .map_err(|_| { + StreamableHttpError::UnexpectedServerResponse(Cow::from( + "invalid www-authenticate header value", + )) + })? + .to_string(); + return Err(StreamableHttpError::AuthRequired(AuthRequiredError { + www_authenticate_header, + })); + } + } + + if status == StatusCode::FORBIDDEN { + if let Some(header) = response.headers().get(WWW_AUTHENTICATE) { + let header_str = header.to_str().map_err(|_| { + StreamableHttpError::UnexpectedServerResponse(Cow::from( + "invalid www-authenticate header value", + )) + })?; + let scope = extract_scope_from_header(header_str); + return Err(StreamableHttpError::InsufficientScope( + InsufficientScopeError { + www_authenticate_header: header_str.to_string(), + required_scope: scope, + }, + )); + } + } + + if matches!(status, StatusCode::ACCEPTED | StatusCode::NO_CONTENT) { + return Ok(StreamableHttpPostResponse::Accepted); + } + + if status == StatusCode::NOT_FOUND && session_was_attached { + return Err(StreamableHttpError::SessionExpired); + } + + if !status.is_success() { + let body = response + .into_body() + .collect() + .await + .map(|c| String::from_utf8_lossy(&c.to_bytes()).into_owned()) + .unwrap_or_else(|_| "".to_owned()); + return Err(StreamableHttpError::UnexpectedServerResponse(Cow::Owned( + format!("HTTP {status}: {body}"), + ))); + } + + let content_type = response.headers().get(http::header::CONTENT_TYPE).cloned(); + let session_id = response + .headers() + .get(HEADER_SESSION_ID) + .and_then(|v| v.to_str().ok()) + .map(|s| s.to_string()); + + match content_type { + Some(ref ct) if ct.as_bytes().starts_with(EVENT_STREAM_MIME_TYPE.as_bytes()) => { + let sse_stream = SseStream::new(response.into_body()).boxed(); + Ok(StreamableHttpPostResponse::Sse(sse_stream, session_id)) + } + Some(ref ct) if ct.as_bytes().starts_with(JSON_MIME_TYPE.as_bytes()) => { + let body = response + .into_body() + .collect() + .await + .map_err(|e| StreamableHttpError::Client(UnixSocketError::Hyper(e)))? + .to_bytes(); + match serde_json::from_slice::(&body) { + Ok(message) => Ok(StreamableHttpPostResponse::Json(message, session_id)), + Err(e) => { + tracing::warn!( + "could not parse JSON response as ServerJsonRpcMessage, treating as accepted: {e}" + ); + Ok(StreamableHttpPostResponse::Accepted) + } + } + } + _ => Err(StreamableHttpError::UnexpectedContentType( + content_type.map(|ct| String::from_utf8_lossy(ct.as_bytes()).into_owned()), + )), + } + } + + async fn delete_session( + &self, + uri: Arc, + session_id: Arc, + auth_token: Option, + custom_headers: HashMap, + ) -> Result<(), StreamableHttpError> { + let mut builder = Request::builder() + .method(Method::DELETE) + .uri(uri.as_ref()) + .header(http::header::HOST, self.host_header.clone()) + .header(HEADER_SESSION_ID, session_id.as_ref()); + + if let Some(auth) = auth_token { + builder = builder.header(http::header::AUTHORIZATION, format!("Bearer {auth}")); + } + + builder = apply_custom_headers(builder, custom_headers)?; + + let request = builder + .body(Full::new(Bytes::new())) + .map_err(|e| StreamableHttpError::Client(UnixSocketError::Http(e)))?; + + let response = send_http_request(&self.socket_path, request) + .await + .map_err(StreamableHttpError::Client)?; + + if response.status() == StatusCode::METHOD_NOT_ALLOWED { + tracing::debug!("this server doesn't support deleting session"); + return Ok(()); + } + + if !response.status().is_success() { + return Err(StreamableHttpError::UnexpectedServerResponse(Cow::Owned( + format!("delete_session returned {}", response.status()), + ))); + } + + Ok(()) + } + + async fn get_stream( + &self, + uri: Arc, + session_id: Arc, + last_event_id: Option, + auth_token: Option, + custom_headers: HashMap, + ) -> Result>, StreamableHttpError> + { + let mut builder = Request::builder() + .method(Method::GET) + .uri(uri.as_ref()) + .header(http::header::HOST, self.host_header.clone()) + .header( + http::header::ACCEPT, + format!("{EVENT_STREAM_MIME_TYPE}, {JSON_MIME_TYPE}"), + ) + .header(HEADER_SESSION_ID, session_id.as_ref()); + + if let Some(last_id) = last_event_id { + builder = builder.header(HEADER_LAST_EVENT_ID, last_id); + } + + if let Some(auth) = auth_token { + builder = builder.header(http::header::AUTHORIZATION, format!("Bearer {auth}")); + } + + builder = apply_custom_headers(builder, custom_headers)?; + + let request = builder + .body(Full::new(Bytes::new())) + .map_err(|e| StreamableHttpError::Client(UnixSocketError::Http(e)))?; + + let response = send_http_request(&self.socket_path, request) + .await + .map_err(StreamableHttpError::Client)?; + + if response.status() == StatusCode::METHOD_NOT_ALLOWED { + return Err(StreamableHttpError::ServerDoesNotSupportSse); + } + + if response.status() == StatusCode::UNAUTHORIZED { + if let Some(header) = response.headers().get(WWW_AUTHENTICATE) { + let www_authenticate_header = header + .to_str() + .map_err(|_| { + StreamableHttpError::UnexpectedServerResponse(Cow::from( + "invalid www-authenticate header value", + )) + })? + .to_string(); + return Err(StreamableHttpError::AuthRequired(AuthRequiredError { + www_authenticate_header, + })); + } + } + + if response.status() == StatusCode::FORBIDDEN { + if let Some(header) = response.headers().get(WWW_AUTHENTICATE) { + let header_str = header.to_str().map_err(|_| { + StreamableHttpError::UnexpectedServerResponse(Cow::from( + "invalid www-authenticate header value", + )) + })?; + let scope = extract_scope_from_header(header_str); + return Err(StreamableHttpError::InsufficientScope( + InsufficientScopeError { + www_authenticate_header: header_str.to_string(), + required_scope: scope, + }, + )); + } + } + + if !response.status().is_success() { + return Err(StreamableHttpError::UnexpectedServerResponse(Cow::Owned( + format!("get_stream returned {}", response.status()), + ))); + } + + match response.headers().get(http::header::CONTENT_TYPE) { + Some(ct) => { + if !ct.as_bytes().starts_with(EVENT_STREAM_MIME_TYPE.as_bytes()) + && !ct.as_bytes().starts_with(JSON_MIME_TYPE.as_bytes()) + { + return Err(StreamableHttpError::UnexpectedContentType(Some( + String::from_utf8_lossy(ct.as_bytes()).to_string(), + ))); + } + } + None => { + return Err(StreamableHttpError::UnexpectedContentType(None)); + } + } + + Ok(SseStream::new(response.into_body()).boxed()) + } +} + +impl StreamableHttpClientTransport { + /// Creates a new transport connecting through a Unix domain socket. + /// + /// # Arguments + /// + /// * `socket_path` - Path to the Unix domain socket. Use `@name` for Linux abstract sockets. + /// * `uri` - The MCP server URI (used for HTTP Host header and request path). + pub fn from_unix_socket(socket_path: &str, uri: impl Into>) -> Self { + let uri: Arc = uri.into(); + let client = UnixSocketHttpClient::new(socket_path, &uri); + let config = StreamableHttpClientTransportConfig { + uri, + ..Default::default() + }; + StreamableHttpClientTransport::with_client(client, config) + } + + /// Creates a new transport connecting through a Unix domain socket with custom config. + /// + /// # Arguments + /// + /// * `socket_path` - Path to the Unix domain socket. Use `@name` for Linux abstract sockets. + /// * `config` - Transport configuration (URI, retry policy, custom headers, etc.). + pub fn from_unix_socket_with_config( + socket_path: &str, + config: StreamableHttpClientTransportConfig, + ) -> Self { + let client = UnixSocketHttpClient::new(socket_path, &config.uri); + StreamableHttpClientTransport::with_client(client, config) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn resolve_abstract_socket() { + assert_eq!(resolve_socket_path("@egress.sock"), "\0egress.sock"); + } + + #[test] + fn resolve_filesystem_socket() { + assert_eq!( + resolve_socket_path("/var/run/envoy.sock"), + "/var/run/envoy.sock" + ); + } + + #[test] + fn resolve_empty_abstract() { + assert_eq!(resolve_socket_path("@"), "\0"); + } + + #[test] + #[should_panic(expected = "socket_path must not be empty")] + fn rejects_bare_at_symbol() { + UnixSocketHttpClient::new("@", "http://localhost/mcp"); + } + + #[test] + #[should_panic(expected = "socket_path must not be empty")] + fn rejects_empty_path() { + UnixSocketHttpClient::new("", "http://localhost/mcp"); + } + + #[test] + fn host_header_auto_derived() { + let client = + UnixSocketHttpClient::new("/var/run/envoy.sock", "http://mcp-server.internal/mcp"); + assert_eq!(client.host_header, "mcp-server.internal"); + } + + #[test] + fn host_header_with_port() { + let client = + UnixSocketHttpClient::new("/var/run/envoy.sock", "http://mcp-server.internal:8080/mcp"); + assert_eq!(client.host_header, "mcp-server.internal:8080"); + } + + #[test] + fn host_header_fallback_on_path_only_uri() { + let client = UnixSocketHttpClient::new("/var/run/envoy.sock", "/mcp"); + assert_eq!(client.host_header, "localhost"); + } + + #[test] + fn reserved_header_rejected() { + let mut headers = HashMap::new(); + headers.insert( + HeaderName::from_static("accept"), + HeaderValue::from_static("text/plain"), + ); + let builder = Request::builder(); + let result = apply_custom_headers(builder, headers); + assert!(matches!( + result, + Err(StreamableHttpError::ReservedHeaderConflict(_)) + )); + } + + #[test] + fn mcp_protocol_version_allowed_through() { + let mut headers = HashMap::new(); + headers.insert( + HeaderName::from_static("mcp-protocol-version"), + HeaderValue::from_static("2025-03-26"), + ); + let builder = Request::builder().uri("http://localhost/mcp").method("GET"); + let result = apply_custom_headers(builder, headers); + assert!(result.is_ok()); + } +} diff --git a/crates/rmcp/tests/test_unix_socket_transport.rs b/crates/rmcp/tests/test_unix_socket_transport.rs new file mode 100644 index 000000000..9a9eff3ea --- /dev/null +++ b/crates/rmcp/tests/test_unix_socket_transport.rs @@ -0,0 +1,272 @@ +#![cfg(all(unix, feature = "transport-streamable-http-client-unix-socket"))] + +use std::{collections::HashMap, sync::Arc}; + +use axum::{ + Router, body::Bytes, extract::State, http::StatusCode, response::IntoResponse, routing::post, +}; +use http::{HeaderName, HeaderValue}; +use rmcp::{ + ServiceExt, + transport::{ + StreamableHttpClientTransport, UnixSocketHttpClient, + streamable_http_client::StreamableHttpClientTransportConfig, + }, +}; +use serde_json::json; +use tokio::sync::Mutex; + +#[derive(Clone)] +struct ServerState { + received_headers: Arc>>, + initialize_called: Arc, +} + +async fn mcp_handler( + State(state): State, + headers: http::HeaderMap, + body: Bytes, +) -> impl IntoResponse { + let mut headers_map = HashMap::new(); + for (name, value) in headers.iter() { + let name_str = name.as_str(); + if name_str.starts_with("x-") || name_str == "host" { + if let Ok(v) = value.to_str() { + headers_map.insert(name_str.to_string(), v.to_string()); + } + } + } + + let mut stored = state.received_headers.lock().await; + stored.extend(headers_map); + drop(stored); + + if let Ok(json_body) = serde_json::from_slice::(&body) { + if let Some(method) = json_body.get("method").and_then(|m| m.as_str()) { + if method == "initialize" { + state.initialize_called.notify_one(); + let response = json!({ + "jsonrpc": "2.0", + "id": json_body.get("id"), + "result": { + "protocolVersion": "2024-11-05", + "capabilities": {}, + "serverInfo": { + "name": "test-unix-server", + "version": "1.0.0" + } + } + }); + return ( + StatusCode::OK, + [ + (http::header::CONTENT_TYPE, "application/json"), + ( + http::HeaderName::from_static("mcp-session-id"), + "unix-test-session", + ), + ], + response.to_string(), + ); + } else if method == "notifications/initialized" { + return ( + StatusCode::ACCEPTED, + [ + (http::header::CONTENT_TYPE, "application/json"), + ( + http::HeaderName::from_static("mcp-session-id"), + "unix-test-session", + ), + ], + String::new(), + ); + } + } + } + + let request_id = serde_json::from_slice::(&body) + .ok() + .and_then(|j| j.get("id").cloned()) + .unwrap_or(serde_json::Value::Null); + let response = json!({ + "jsonrpc": "2.0", + "id": request_id, + "result": {} + }); + ( + StatusCode::OK, + [ + (http::header::CONTENT_TYPE, "application/json"), + ( + http::HeaderName::from_static("mcp-session-id"), + "unix-test-session", + ), + ], + response.to_string(), + ) +} + +/// Integration test: MCP client connects and completes handshake over a Unix domain socket. +#[tokio::test(flavor = "current_thread")] +async fn test_unix_socket_mcp_handshake() -> anyhow::Result<()> { + let dir = std::env::temp_dir().join(format!("rmcp-test-{}", std::process::id())); + std::fs::create_dir_all(&dir)?; + let socket_path = dir.join("mcp.sock"); + + // Clean up any leftover socket from a previous run + let _ = std::fs::remove_file(&socket_path); + + let state = ServerState { + received_headers: Arc::new(Mutex::new(HashMap::new())), + initialize_called: Arc::new(tokio::sync::Notify::new()), + }; + + let app = Router::new() + .route("/mcp", post(mcp_handler)) + .with_state(state.clone()); + + let listener = tokio::net::UnixListener::bind(&socket_path)?; + let server_handle = tokio::spawn(async move { + axum::serve(listener, app).await.unwrap(); + }); + + let socket_str = socket_path.to_str().unwrap(); + let uri = "http://mcp-server.internal/mcp"; + let client = UnixSocketHttpClient::new(socket_str, uri); + let config = StreamableHttpClientTransportConfig::with_uri(uri); + let transport = StreamableHttpClientTransport::with_client(client, config); + + let mcp_client = ().serve(transport).await.expect("MCP handshake should succeed"); + + tokio::time::timeout( + std::time::Duration::from_secs(5), + state.initialize_called.notified(), + ) + .await + .expect("Initialize request should be received"); + + // Verify Host header was set correctly + let headers = state.received_headers.lock().await; + assert_eq!( + headers.get("host"), + Some(&"mcp-server.internal".to_string()), + "Host header should be derived from URI" + ); + + drop(mcp_client); + server_handle.abort(); + let _ = std::fs::remove_file(&socket_path); + let _ = std::fs::remove_dir(&dir); + + Ok(()) +} + +/// Integration test: Custom headers are sent through the Unix socket transport. +#[tokio::test(flavor = "current_thread")] +async fn test_unix_socket_custom_headers() -> anyhow::Result<()> { + let dir = std::env::temp_dir().join(format!("rmcp-test-headers-{}", std::process::id())); + std::fs::create_dir_all(&dir)?; + let socket_path = dir.join("mcp.sock"); + let _ = std::fs::remove_file(&socket_path); + + let state = ServerState { + received_headers: Arc::new(Mutex::new(HashMap::new())), + initialize_called: Arc::new(tokio::sync::Notify::new()), + }; + + let app = Router::new() + .route("/mcp", post(mcp_handler)) + .with_state(state.clone()); + + let listener = tokio::net::UnixListener::bind(&socket_path)?; + let server_handle = tokio::spawn(async move { + axum::serve(listener, app).await.unwrap(); + }); + + let mut custom_headers = HashMap::new(); + custom_headers.insert( + HeaderName::from_static("x-test-header"), + HeaderValue::from_static("test-value-123"), + ); + custom_headers.insert( + HeaderName::from_static("x-client-id"), + HeaderValue::from_static("unix-test-client"), + ); + + let socket_str = socket_path.to_str().unwrap(); + let uri = "http://mcp-server.internal/mcp"; + let client = UnixSocketHttpClient::new(socket_str, uri); + let config = StreamableHttpClientTransportConfig::with_uri(uri).custom_headers(custom_headers); + let transport = StreamableHttpClientTransport::with_client(client, config); + + let mcp_client = ().serve(transport).await.expect("MCP handshake should succeed"); + + tokio::time::timeout( + std::time::Duration::from_secs(5), + state.initialize_called.notified(), + ) + .await + .expect("Initialize request should be received"); + + let headers = state.received_headers.lock().await; + assert_eq!( + headers.get("x-test-header"), + Some(&"test-value-123".to_string()), + "Custom header x-test-header should be received" + ); + assert_eq!( + headers.get("x-client-id"), + Some(&"unix-test-client".to_string()), + "Custom header x-client-id should be received" + ); + + drop(mcp_client); + server_handle.abort(); + let _ = std::fs::remove_file(&socket_path); + let _ = std::fs::remove_dir(&dir); + + Ok(()) +} + +/// Integration test: Convenience constructor `from_unix_socket` works end-to-end. +#[tokio::test(flavor = "current_thread")] +async fn test_unix_socket_convenience_constructor() -> anyhow::Result<()> { + let dir = std::env::temp_dir().join(format!("rmcp-test-conv-{}", std::process::id())); + std::fs::create_dir_all(&dir)?; + let socket_path = dir.join("mcp.sock"); + let _ = std::fs::remove_file(&socket_path); + + let state = ServerState { + received_headers: Arc::new(Mutex::new(HashMap::new())), + initialize_called: Arc::new(tokio::sync::Notify::new()), + }; + + let app = Router::new() + .route("/mcp", post(mcp_handler)) + .with_state(state.clone()); + + let listener = tokio::net::UnixListener::bind(&socket_path)?; + let server_handle = tokio::spawn(async move { + axum::serve(listener, app).await.unwrap(); + }); + + let socket_str = socket_path.to_str().unwrap(); + let transport = + StreamableHttpClientTransport::from_unix_socket(socket_str, "http://localhost/mcp"); + + let mcp_client = ().serve(transport).await.expect("MCP handshake should succeed"); + + tokio::time::timeout( + std::time::Duration::from_secs(5), + state.initialize_called.notified(), + ) + .await + .expect("Initialize request should be received"); + + drop(mcp_client); + server_handle.abort(); + let _ = std::fs::remove_file(&socket_path); + let _ = std::fs::remove_dir(&dir); + + Ok(()) +}