diff --git a/src/mcp/server/session.py b/src/mcp/server/session.py index 759d2131a..f58ac6ef7 100644 --- a/src/mcp/server/session.py +++ b/src/mcp/server/session.py @@ -222,7 +222,7 @@ async def send_log_message( related_request_id, ) - async def send_resource_updated(self, uri: str | AnyUrl) -> None: # pragma: no cover + async def send_resource_updated(self, uri: str | AnyUrl) -> None: """Send a resource updated notification.""" await self.send_notification( types.ResourceUpdatedNotification( diff --git a/src/mcp/server/streamable_http.py b/src/mcp/server/streamable_http.py index aa99e7c88..011deb1f0 100644 --- a/src/mcp/server/streamable_http.py +++ b/src/mcp/server/streamable_http.py @@ -177,7 +177,7 @@ def is_terminated(self) -> bool: """Check if this transport has been explicitly terminated.""" return self._terminated - def close_sse_stream(self, request_id: RequestId) -> None: # pragma: no cover + def close_sse_stream(self, request_id: RequestId) -> None: """Close SSE connection for a specific request without terminating the stream. This method closes the HTTP connection for the specified request, triggering @@ -200,12 +200,12 @@ def close_sse_stream(self, request_id: RequestId) -> None: # pragma: no cover writer.close() # Also close and remove request streams - if request_id in self._request_streams: + if request_id in self._request_streams: # pragma: no branch send_stream, receive_stream = self._request_streams.pop(request_id) send_stream.close() receive_stream.close() - def close_standalone_sse_stream(self) -> None: # pragma: no cover + def close_standalone_sse_stream(self) -> None: """Close the standalone GET SSE stream, triggering client reconnection. This method closes the HTTP connection for the standalone GET stream used @@ -240,10 +240,10 @@ def _create_session_message( # Only provide close callbacks when client supports resumability if self._event_store and protocol_version >= "2025-11-25": - async def close_stream_callback() -> None: # pragma: no cover + async def close_stream_callback() -> None: self.close_sse_stream(request_id) - async def close_standalone_stream_callback() -> None: # pragma: no cover + async def close_standalone_stream_callback() -> None: self.close_standalone_sse_stream() metadata = ServerMessageMetadata( @@ -291,7 +291,7 @@ def _create_error_response( ) -> Response: """Create an error response with a simple string message.""" response_headers = {"Content-Type": CONTENT_TYPE_JSON} - if headers: # pragma: no cover + if headers: response_headers.update(headers) if self.mcp_session_id: @@ -342,7 +342,7 @@ def _create_event_data(self, event_message: EventMessage) -> dict[str, str]: } # If an event ID was provided, include it - if event_message.event_id: # pragma: no cover + if event_message.event_id: event_data["id"] = event_message.event_id return event_data @@ -372,7 +372,7 @@ async def handle_request(self, scope: Scope, receive: Receive, send: Send) -> No await error_response(scope, receive, send) return - if self._terminated: # pragma: no cover + if self._terminated: # If the session has been terminated, return 404 Not Found response = self._create_error_response( "Not Found: Session has been terminated", @@ -387,7 +387,7 @@ async def handle_request(self, scope: Scope, receive: Receive, send: Send) -> No await self._handle_get_request(request, send) elif request.method == "DELETE": await self._handle_delete_request(request, send) - else: # pragma: no cover + else: await self._handle_unsupported_request(request, send) def _check_accept_headers(self, request: Request) -> tuple[bool, bool]: @@ -467,7 +467,7 @@ async def _handle_post_request(self, scope: Scope, request: Request, receive: Re try: message = jsonrpc_message_adapter.validate_python(raw_message, by_name=False) - except ValidationError as e: # pragma: no cover + except ValidationError as e: response = self._create_error_response( f"Validation error: {str(e)}", HTTPStatus.BAD_REQUEST, @@ -493,7 +493,7 @@ async def _handle_post_request(self, scope: Scope, request: Request, receive: Re ) await response(scope, receive, send) return - elif not await self._validate_request_headers(request, send): # pragma: no cover + elif not await self._validate_request_headers(request, send): return # For notifications and responses only, return 202 Accepted @@ -633,7 +633,7 @@ async def sse_writer(): # pragma: lax no cover finally: await sse_stream_reader.aclose() - except Exception as err: # pragma: no cover + except Exception as err: # pragma: lax no cover logger.exception("Error handling POST request") response = self._create_error_response( f"Error handling POST request: {err}", @@ -659,7 +659,7 @@ async def _handle_get_request(self, request: Request, send: Send) -> None: # Validate Accept header - must include text/event-stream _, has_sse = self._check_accept_headers(request) - if not has_sse: # pragma: no cover + if not has_sse: response = self._create_error_response( "Not Acceptable: Client must accept text/event-stream", HTTPStatus.NOT_ACCEPTABLE, @@ -667,11 +667,11 @@ async def _handle_get_request(self, request: Request, send: Send) -> None: await response(request.scope, request.receive, send) return - if not await self._validate_request_headers(request, send): # pragma: no cover - return + if not await self._validate_request_headers(request, send): + return # pragma: no cover # Handle resumability: check for Last-Event-ID header - if last_event_id := request.headers.get(LAST_EVENT_ID_HEADER): # pragma: no cover + if last_event_id := request.headers.get(LAST_EVENT_ID_HEADER): await self._replay_events(last_event_id, request, send) return @@ -681,11 +681,11 @@ async def _handle_get_request(self, request: Request, send: Send) -> None: "Content-Type": CONTENT_TYPE_SSE, } - if self.mcp_session_id: + if self.mcp_session_id: # pragma: no branch headers[MCP_SESSION_ID_HEADER] = self.mcp_session_id # Check if we already have an active GET stream - if GET_STREAM_KEY in self._request_streams: # pragma: no cover + if GET_STREAM_KEY in self._request_streams: response = self._create_error_response( "Conflict: Only one SSE stream is allowed per session", HTTPStatus.CONFLICT, @@ -714,7 +714,7 @@ async def standalone_sse_writer(): # Send the message via SSE event_data = self._create_event_data(event_message) await sse_stream_writer.send(event_data) - except Exception: # pragma: no cover + except Exception: # pragma: lax no cover logger.exception("Error in standalone SSE writer") finally: logger.debug("Closing standalone SSE writer") @@ -791,13 +791,13 @@ async def terminate(self) -> None: # During cleanup, we catch all exceptions since streams might be in various states logger.debug(f"Error closing streams: {e}") - async def _handle_unsupported_request(self, request: Request, send: Send) -> None: # pragma: no cover + async def _handle_unsupported_request(self, request: Request, send: Send) -> None: """Handle unsupported HTTP methods.""" headers = { "Content-Type": CONTENT_TYPE_JSON, "Allow": "GET, POST, DELETE", } - if self.mcp_session_id: + if self.mcp_session_id: # pragma: no branch headers[MCP_SESSION_ID_HEADER] = self.mcp_session_id response = self._create_error_response( @@ -824,7 +824,7 @@ async def _validate_session(self, request: Request, send: Send) -> bool: request_session_id = self._get_session_id(request) # If no session ID provided but required, return error - if not request_session_id: # pragma: no cover + if not request_session_id: response = self._create_error_response( "Bad Request: Missing session ID", HTTPStatus.BAD_REQUEST, @@ -849,11 +849,11 @@ async def _validate_protocol_version(self, request: Request, send: Send) -> bool protocol_version = request.headers.get(MCP_PROTOCOL_VERSION_HEADER) # If no protocol version provided, assume default version - if protocol_version is None: # pragma: no cover + if protocol_version is None: protocol_version = DEFAULT_NEGOTIATED_VERSION # Check if the protocol version is supported - if protocol_version not in SUPPORTED_PROTOCOL_VERSIONS: # pragma: no cover + if protocol_version not in SUPPORTED_PROTOCOL_VERSIONS: supported_versions = ", ".join(SUPPORTED_PROTOCOL_VERSIONS) response = self._create_error_response( f"Bad Request: Unsupported protocol version: {protocol_version}. " @@ -865,13 +865,13 @@ async def _validate_protocol_version(self, request: Request, send: Send) -> bool return True - async def _replay_events(self, last_event_id: str, request: Request, send: Send) -> None: # pragma: no cover + async def _replay_events(self, last_event_id: str, request: Request, send: Send) -> None: """Replays events that would have been sent after the specified event ID. Only used when resumability is enabled. """ event_store = self._event_store - if not event_store: + if not event_store: # pragma: no cover return try: @@ -881,7 +881,7 @@ async def _replay_events(self, last_event_id: str, request: Request, send: Send) "Content-Type": CONTENT_TYPE_SSE, } - if self.mcp_session_id: + if self.mcp_session_id: # pragma: no branch headers[MCP_SESSION_ID_HEADER] = self.mcp_session_id # Get protocol version from header (already validated in _validate_protocol_version) @@ -902,7 +902,7 @@ async def send_event(event_message: EventMessage) -> None: stream_id = await event_store.replay_events_after(last_event_id, send_event) # If stream ID not in mapping, create it - if stream_id and stream_id not in self._request_streams: + if stream_id and stream_id not in self._request_streams: # pragma: no branch # Register SSE writer so close_sse_stream() can close it self._sse_stream_writers[stream_id] = sse_stream_writer @@ -919,10 +919,10 @@ async def send_event(event_message: EventMessage) -> None: event_data = self._create_event_data(event_message) await sse_stream_writer.send(event_data) - except anyio.ClosedResourceError: + except anyio.ClosedResourceError: # pragma: lax no cover # Expected when close_sse_stream() is called logger.debug("Replay SSE stream closed by close_sse_stream()") - except Exception: + except Exception: # pragma: lax no cover logger.exception("Error in replay sender") # Create and start EventSourceResponse @@ -934,13 +934,13 @@ async def send_event(event_message: EventMessage) -> None: try: await response(request.scope, request.receive, send) - except Exception: + except Exception: # pragma: no cover logger.exception("Error in replay response") finally: await sse_stream_writer.aclose() await sse_stream_reader.aclose() - except Exception: + except Exception: # pragma: no cover logger.exception("Error replaying events") response = self._create_error_response( "Error replaying events", @@ -991,7 +991,7 @@ async def message_router(): if isinstance(message, JSONRPCResponse | JSONRPCError) and message.id is not None: target_request_id = str(message.id) # Extract related_request_id from meta if it exists - elif ( # pragma: no cover + elif ( session_message.metadata is not None and isinstance( session_message.metadata, @@ -1015,10 +1015,10 @@ async def message_router(): try: # Send both the message and the event ID await self._request_streams[request_stream_id][0].send(EventMessage(message, event_id)) - except (anyio.BrokenResourceError, anyio.ClosedResourceError): # pragma: no cover + except (anyio.BrokenResourceError, anyio.ClosedResourceError): # pragma: lax no cover # Stream might be closed, remove from registry self._request_streams.pop(request_stream_id, None) - else: # pragma: no cover + else: logger.debug( f"""Request stream {request_stream_id} not found for message. Still processing message as the client diff --git a/src/mcp/server/transport_security.py b/src/mcp/server/transport_security.py index 1ed9842c0..8a0a6c15d 100644 --- a/src/mcp/server/transport_security.py +++ b/src/mcp/server/transport_security.py @@ -40,7 +40,7 @@ def __init__(self, settings: TransportSecuritySettings | None = None): # If not specified, disable DNS rebinding protection by default for backwards compatibility self.settings = settings or TransportSecuritySettings(enable_dns_rebinding_protection=False) - def _validate_host(self, host: str | None) -> bool: # pragma: no cover + def _validate_host(self, host: str | None) -> bool: # pragma: lax no cover """Validate the Host header against allowed values.""" if not host: logger.warning("Missing Host header in request") @@ -62,7 +62,7 @@ def _validate_host(self, host: str | None) -> bool: # pragma: no cover logger.warning(f"Invalid Host header: {host}") return False - def _validate_origin(self, origin: str | None) -> bool: # pragma: no cover + def _validate_origin(self, origin: str | None) -> bool: # pragma: lax no cover """Validate the Origin header against allowed values.""" # Origin can be absent for same-origin requests if not origin: @@ -103,14 +103,14 @@ async def validate_request(self, request: Request, is_post: bool = False) -> Res if not self.settings.enable_dns_rebinding_protection: return None - # Validate Host header # pragma: no cover - host = request.headers.get("host") # pragma: no cover - if not self._validate_host(host): # pragma: no cover - return Response("Invalid Host header", status_code=421) # pragma: no cover + # Validate Host header + host = request.headers.get("host") # pragma: lax no cover + if not self._validate_host(host): # pragma: lax no cover + return Response("Invalid Host header", status_code=421) - # Validate Origin header # pragma: no cover - origin = request.headers.get("origin") # pragma: no cover - if not self._validate_origin(origin): # pragma: no cover - return Response("Invalid Origin header", status_code=403) # pragma: no cover + # Validate Origin header + origin = request.headers.get("origin") # pragma: lax no cover + if not self._validate_origin(origin): # pragma: lax no cover + return Response("Invalid Origin header", status_code=403) - return None # pragma: no cover + return None # pragma: lax no cover diff --git a/tests/shared/test_streamable_http.py b/tests/shared/test_streamable_http.py index f8ca30441..9119b9d14 100644 --- a/tests/shared/test_streamable_http.py +++ b/tests/shared/test_streamable_http.py @@ -6,10 +6,7 @@ from __future__ import annotations as _annotations import json -import multiprocessing -import socket import time -import traceback from collections.abc import AsyncIterator, Generator from contextlib import asynccontextmanager from dataclasses import dataclass, field @@ -21,7 +18,6 @@ import httpx import pytest import requests -import uvicorn from httpx_sse import ServerSentEvent from starlette.applications import Starlette from starlette.requests import Request @@ -65,7 +61,7 @@ TextResourceContents, Tool, ) -from tests.test_helpers import wait_for_server +from tests.test_helpers import run_uvicorn_in_thread # Test constants SERVER_NAME = "test_streamable_http_server" @@ -108,7 +104,7 @@ async def store_event(self, stream_id: StreamId, message: types.JSONRPCMessage | self._events.append((stream_id, event_id, message)) return event_id - async def replay_events_after( # pragma: no cover + async def replay_events_after( self, last_event_id: EventId, send_callback: EventCallback, @@ -117,11 +113,11 @@ async def replay_events_after( # pragma: no cover # Find the stream ID of the last event target_stream_id = None for stream_id, event_id, _ in self._events: - if event_id == last_event_id: + if event_id == last_event_id: # pragma: no branch target_stream_id = stream_id break - if target_stream_id is None: + if target_stream_id is None: # pragma: no cover # If event ID not found, return None return None @@ -132,7 +128,7 @@ async def replay_events_after( # pragma: no cover for stream_id, event_id, message in self._events: if stream_id == target_stream_id and int(event_id) > last_event_id_int: # Skip priming events (None message) - if message is not None: + if message is not None: # pragma: no branch await send_callback(EventMessage(message, event_id)) return target_stream_id @@ -144,18 +140,18 @@ class ServerState: @asynccontextmanager -async def _server_lifespan(_server: Server[ServerState]) -> AsyncIterator[ServerState]: # pragma: no cover +async def _server_lifespan(_server: Server[ServerState]) -> AsyncIterator[ServerState]: yield ServerState() -async def _handle_read_resource( # pragma: no cover +async def _handle_read_resource( ctx: ServerRequestContext[ServerState], params: ReadResourceRequestParams ) -> ReadResourceResult: uri = str(params.uri) parsed = urlparse(uri) if parsed.scheme == "foobar": text = f"Read {parsed.netloc}" - elif parsed.scheme == "slow": + elif parsed.scheme == "slow": # pragma: no cover await anyio.sleep(2.0) text = f"Slow response from {parsed.netloc}" else: @@ -163,7 +159,7 @@ async def _handle_read_resource( # pragma: no cover return ReadResourceResult(contents=[TextResourceContents(uri=uri, text=text, mime_type="text/plain")]) -async def _handle_list_tools( # pragma: no cover +async def _handle_list_tools( ctx: ServerRequestContext[ServerState], params: PaginatedRequestParams | None ) -> ListToolsResult: return ListToolsResult( @@ -228,9 +224,7 @@ async def _handle_list_tools( # pragma: no cover ) -async def _handle_call_tool( # pragma: no cover - ctx: ServerRequestContext[ServerState], params: CallToolRequestParams -) -> CallToolResult: +async def _handle_call_tool(ctx: ServerRequestContext[ServerState], params: CallToolRequestParams) -> CallToolResult: name = params.name args = params.arguments or {} @@ -239,7 +233,7 @@ async def _handle_call_tool( # pragma: no cover await ctx.session.send_resource_updated(uri="http://test_resource") return CallToolResult(content=[TextContent(type="text", text=f"Called {name}")]) - elif name == "long_running_with_checkpoints": + elif name == "long_running_with_checkpoints": # pragma: no cover await ctx.session.send_log_message( level="info", data="Tool started", @@ -272,7 +266,7 @@ async def _handle_call_tool( # pragma: no cover if sampling_result.content.type == "text": response = sampling_result.content.text - else: + else: # pragma: no cover response = str(sampling_result.content) return CallToolResult( content=[ @@ -360,7 +354,7 @@ async def _handle_call_tool( # pragma: no cover related_request_id=ctx.request_id, ) - if ctx.close_sse_stream: + if ctx.close_sse_stream: # pragma: no branch await ctx.close_sse_stream() await anyio.sleep(sleep_time) @@ -371,7 +365,7 @@ async def _handle_call_tool( # pragma: no cover await ctx.session.send_resource_updated(uri="http://notification_1") await anyio.sleep(0.1) - if ctx.close_standalone_sse_stream: + if ctx.close_standalone_sse_stream: # pragma: no branch await ctx.close_standalone_sse_stream() await anyio.sleep(1.5) @@ -382,7 +376,7 @@ async def _handle_call_tool( # pragma: no cover return CallToolResult(content=[TextContent(type="text", text=f"Called {name}")]) -def _create_server() -> Server[ServerState]: # pragma: no cover +def _create_server() -> Server[ServerState]: return Server( SERVER_NAME, lifespan=_server_lifespan, @@ -396,7 +390,7 @@ def create_app( is_json_response_enabled: bool = False, event_store: EventStore | None = None, retry_interval: int | None = None, -) -> Starlette: # pragma: no cover +) -> Starlette: """Create a Starlette application for testing using the session manager. Args: @@ -431,74 +425,11 @@ def create_app( return app -def run_server( - port: int, - is_json_response_enabled: bool = False, - event_store: EventStore | None = None, - retry_interval: int | None = None, -) -> None: # pragma: no cover - """Run the test server. - - Args: - port: Port to listen on. - is_json_response_enabled: If True, use JSON responses instead of SSE streams. - event_store: Optional event store for testing resumability. - retry_interval: Retry interval in milliseconds for SSE polling. - """ - - app = create_app(is_json_response_enabled, event_store, retry_interval) - # Configure server - config = uvicorn.Config( - app=app, - host="127.0.0.1", - port=port, - log_level="info", - limit_concurrency=10, - timeout_keep_alive=5, - access_log=False, - ) - - # Start the server - server = uvicorn.Server(config=config) - - # This is important to catch exceptions and prevent test hangs - try: - server.run() - except Exception: - traceback.print_exc() - - -# Test fixtures - using same approach as SSE tests -@pytest.fixture -def basic_server_port() -> int: - """Find an available port for the basic server.""" - with socket.socket() as s: - s.bind(("127.0.0.1", 0)) - return s.getsockname()[1] - - @pytest.fixture -def json_server_port() -> int: - """Find an available port for the JSON response server.""" - with socket.socket() as s: - s.bind(("127.0.0.1", 0)) - return s.getsockname()[1] - - -@pytest.fixture -def basic_server(basic_server_port: int) -> Generator[None, None, None]: - """Start a basic server.""" - proc = multiprocessing.Process(target=run_server, kwargs={"port": basic_server_port}, daemon=True) - proc.start() - - # Wait for server to be running - wait_for_server(basic_server_port) - - yield - - # Clean up - proc.kill() - proc.join(timeout=2) +def basic_server_url() -> Generator[str, None, None]: + """Start a basic server in a background thread. Yields its base URL.""" + with run_uvicorn_in_thread(create_app(), limit_concurrency=10, timeout_keep_alive=5, access_log=False) as url: + yield url @pytest.fixture @@ -508,69 +439,28 @@ def event_store() -> SimpleEventStore: @pytest.fixture -def event_server_port() -> int: - """Find an available port for the event store server.""" - with socket.socket() as s: - s.bind(("127.0.0.1", 0)) - return s.getsockname()[1] - - -@pytest.fixture -def event_server( - event_server_port: int, event_store: SimpleEventStore -) -> Generator[tuple[SimpleEventStore, str], None, None]: - """Start a server with event store and retry_interval enabled.""" - proc = multiprocessing.Process( - target=run_server, - kwargs={"port": event_server_port, "event_store": event_store, "retry_interval": 500}, - daemon=True, - ) - proc.start() - - # Wait for server to be running - wait_for_server(event_server_port) - - yield event_store, f"http://127.0.0.1:{event_server_port}" - - # Clean up - proc.kill() - proc.join(timeout=2) - - -@pytest.fixture -def json_response_server(json_server_port: int) -> Generator[None, None, None]: - """Start a server with JSON response enabled.""" - proc = multiprocessing.Process( - target=run_server, - kwargs={"port": json_server_port, "is_json_response_enabled": True}, - daemon=True, - ) - proc.start() - - # Wait for server to be running - wait_for_server(json_server_port) - - yield - - # Clean up - proc.kill() - proc.join(timeout=2) - - -@pytest.fixture -def basic_server_url(basic_server_port: int) -> str: - """Get the URL for the basic test server.""" - return f"http://127.0.0.1:{basic_server_port}" +def event_server(event_store: SimpleEventStore) -> Generator[tuple[SimpleEventStore, str], None, None]: + """Start a server with event store and retry_interval enabled in a background thread.""" + with run_uvicorn_in_thread( + create_app(event_store=event_store, retry_interval=500), + limit_concurrency=10, + timeout_keep_alive=5, + access_log=False, + ) as url: + yield event_store, url @pytest.fixture -def json_server_url(json_server_port: int) -> str: - """Get the URL for the JSON response test server.""" - return f"http://127.0.0.1:{json_server_port}" +def json_server_url() -> Generator[str, None, None]: + """Start a server with JSON response enabled in a background thread. Yields its base URL.""" + with run_uvicorn_in_thread( + create_app(is_json_response_enabled=True), limit_concurrency=10, timeout_keep_alive=5, access_log=False + ) as url: + yield url # Basic request validation tests -def test_accept_header_validation(basic_server: None, basic_server_url: str): +def test_accept_header_validation(basic_server_url: str): """Test that Accept header is properly validated.""" # Test without Accept header (suppress requests library default Accept: */*) session = requests.Session() @@ -595,7 +485,7 @@ def test_accept_header_validation(basic_server: None, basic_server_url: str): "application/*;q=0.9, text/*;q=0.8", ], ) -def test_accept_header_wildcard(basic_server: None, basic_server_url: str, accept_header: str): +def test_accept_header_wildcard(basic_server_url: str, accept_header: str): """Test that wildcard Accept headers are accepted per RFC 7231.""" response = requests.post( f"{basic_server_url}/mcp", @@ -616,7 +506,7 @@ def test_accept_header_wildcard(basic_server: None, basic_server_url: str, accep "text/*", ], ) -def test_accept_header_incompatible(basic_server: None, basic_server_url: str, accept_header: str): +def test_accept_header_incompatible(basic_server_url: str, accept_header: str): """Test that incompatible Accept headers are rejected for SSE mode.""" response = requests.post( f"{basic_server_url}/mcp", @@ -630,7 +520,7 @@ def test_accept_header_incompatible(basic_server: None, basic_server_url: str, a assert "Not Acceptable" in response.text -def test_content_type_validation(basic_server: None, basic_server_url: str): +def test_content_type_validation(basic_server_url: str): """Test that Content-Type header is properly validated.""" # Test with incorrect Content-Type response = requests.post( @@ -646,7 +536,7 @@ def test_content_type_validation(basic_server: None, basic_server_url: str): assert "Invalid Content-Type" in response.text -def test_json_validation(basic_server: None, basic_server_url: str): +def test_json_validation(basic_server_url: str): """Test that JSON content is properly validated.""" # Test with invalid JSON response = requests.post( @@ -661,7 +551,7 @@ def test_json_validation(basic_server: None, basic_server_url: str): assert "Parse error" in response.text -def test_json_parsing(basic_server: None, basic_server_url: str): +def test_json_parsing(basic_server_url: str): """Test that JSON content is properly parse.""" # Test with valid JSON but invalid JSON-RPC response = requests.post( @@ -676,7 +566,7 @@ def test_json_parsing(basic_server: None, basic_server_url: str): assert "Validation error" in response.text -def test_method_not_allowed(basic_server: None, basic_server_url: str): +def test_method_not_allowed(basic_server_url: str): """Test that unsupported HTTP methods are rejected.""" # Test with unsupported method (PUT) response = requests.put( @@ -691,7 +581,7 @@ def test_method_not_allowed(basic_server: None, basic_server_url: str): assert "Method Not Allowed" in response.text -def test_session_validation(basic_server: None, basic_server_url: str): +def test_session_validation(basic_server_url: str): """Test session ID validation.""" # session_id not used directly in this test @@ -766,7 +656,7 @@ def test_streamable_http_transport_init_validation(): StreamableHTTPServerTransport(mcp_session_id="test\n") -def test_session_termination(basic_server: None, basic_server_url: str): +def test_session_termination(basic_server_url: str): """Test session termination via DELETE and subsequent request handling.""" response = requests.post( f"{basic_server_url}/mcp", @@ -806,7 +696,7 @@ def test_session_termination(basic_server: None, basic_server_url: str): assert "Session has been terminated" in response.text -def test_response(basic_server: None, basic_server_url: str): +def test_response(basic_server_url: str): """Test response handling for a valid request.""" mcp_url = f"{basic_server_url}/mcp" response = requests.post( @@ -841,7 +731,7 @@ def test_response(basic_server: None, basic_server_url: str): assert tools_response.headers.get("Content-Type") == "text/event-stream" -def test_json_response(json_response_server: None, json_server_url: str): +def test_json_response(json_server_url: str): """Test response handling when is_json_response_enabled is True.""" mcp_url = f"{json_server_url}/mcp" response = requests.post( @@ -856,7 +746,7 @@ def test_json_response(json_response_server: None, json_server_url: str): assert response.headers.get("Content-Type") == "application/json" -def test_json_response_accept_json_only(json_response_server: None, json_server_url: str): +def test_json_response_accept_json_only(json_server_url: str): """Test that json_response servers only require application/json in Accept header.""" mcp_url = f"{json_server_url}/mcp" response = requests.post( @@ -871,7 +761,7 @@ def test_json_response_accept_json_only(json_response_server: None, json_server_ assert response.headers.get("Content-Type") == "application/json" -def test_json_response_missing_accept_header(json_response_server: None, json_server_url: str): +def test_json_response_missing_accept_header(json_server_url: str): """Test that json_response servers reject requests without Accept header.""" mcp_url = f"{json_server_url}/mcp" # Suppress requests library default Accept: */* header @@ -888,7 +778,7 @@ def test_json_response_missing_accept_header(json_response_server: None, json_se assert "Not Acceptable" in response.text -def test_json_response_incorrect_accept_header(json_response_server: None, json_server_url: str): +def test_json_response_incorrect_accept_header(json_server_url: str): """Test that json_response servers reject requests with incorrect Accept header.""" mcp_url = f"{json_server_url}/mcp" # Test with only text/event-stream (wrong for JSON server) @@ -912,7 +802,7 @@ def test_json_response_incorrect_accept_header(json_response_server: None, json_ "application/*;q=0.9", ], ) -def test_json_response_wildcard_accept_header(json_response_server: None, json_server_url: str, accept_header: str): +def test_json_response_wildcard_accept_header(json_server_url: str, accept_header: str): """Test that json_response servers accept wildcard Accept headers per RFC 7231.""" mcp_url = f"{json_server_url}/mcp" response = requests.post( @@ -927,7 +817,7 @@ def test_json_response_wildcard_accept_header(json_response_server: None, json_s assert response.headers.get("Content-Type") == "application/json" -def test_get_sse_stream(basic_server: None, basic_server_url: str): +def test_get_sse_stream(basic_server_url: str): """Test establishing an SSE stream via GET request.""" # First, we need to initialize a session mcp_url = f"{basic_server_url}/mcp" @@ -987,7 +877,7 @@ def test_get_sse_stream(basic_server: None, basic_server_url: str): assert second_get.status_code == 409 -def test_get_validation(basic_server: None, basic_server_url: str): +def test_get_validation(basic_server_url: str): """Test validation for GET requests.""" # First, we need to initialize a session mcp_url = f"{basic_server_url}/mcp" @@ -1044,14 +934,14 @@ def test_get_validation(basic_server: None, basic_server_url: str): # Client-specific fixtures @pytest.fixture -async def http_client(basic_server: None, basic_server_url: str): # pragma: no cover +async def http_client(basic_server_url: str): # pragma: no cover """Create test client matching the SSE test pattern.""" async with httpx.AsyncClient(base_url=basic_server_url) as client: yield client @pytest.fixture -async def initialized_client_session(basic_server: None, basic_server_url: str): +async def initialized_client_session(basic_server_url: str): """Create initialized StreamableHTTP client session.""" async with streamable_http_client(f"{basic_server_url}/mcp") as (read_stream, write_stream): async with ClientSession(read_stream, write_stream) as session: @@ -1060,7 +950,7 @@ async def initialized_client_session(basic_server: None, basic_server_url: str): @pytest.mark.anyio -async def test_streamable_http_client_basic_connection(basic_server: None, basic_server_url: str): +async def test_streamable_http_client_basic_connection(basic_server_url: str): """Test basic client connection with initialization.""" async with streamable_http_client(f"{basic_server_url}/mcp") as (read_stream, write_stream): async with ClientSession(read_stream, write_stream) as session: @@ -1105,7 +995,7 @@ async def test_streamable_http_client_error_handling(initialized_client_session: @pytest.mark.anyio -async def test_streamable_http_client_session_persistence(basic_server: None, basic_server_url: str): +async def test_streamable_http_client_session_persistence(basic_server_url: str): """Test that session ID persists across requests.""" async with streamable_http_client(f"{basic_server_url}/mcp") as (read_stream, write_stream): async with ClientSession(read_stream, write_stream) as session: @@ -1126,7 +1016,7 @@ async def test_streamable_http_client_session_persistence(basic_server: None, ba @pytest.mark.anyio -async def test_streamable_http_client_json_response(json_response_server: None, json_server_url: str): +async def test_streamable_http_client_json_response(json_server_url: str): """Test client with JSON response mode.""" async with streamable_http_client(f"{json_server_url}/mcp") as (read_stream, write_stream): async with ClientSession(read_stream, write_stream) as session: @@ -1147,7 +1037,7 @@ async def test_streamable_http_client_json_response(json_response_server: None, @pytest.mark.anyio -async def test_streamable_http_client_get_stream(basic_server: None, basic_server_url: str): +async def test_streamable_http_client_get_stream(basic_server_url: str): """Test GET stream functionality for server-initiated messages.""" notifications_received: list[types.ServerNotification] = [] @@ -1198,7 +1088,7 @@ async def capture_session_id(response: httpx.Response) -> None: @pytest.mark.anyio -async def test_streamable_http_client_session_termination(basic_server: None, basic_server_url: str): +async def test_streamable_http_client_session_termination(basic_server_url: str): """Test client session termination functionality.""" # Use httpx client with event hooks to capture session ID httpx_client, captured_ids = create_session_id_capturing_client() @@ -1233,9 +1123,7 @@ async def test_streamable_http_client_session_termination(basic_server: None, ba @pytest.mark.anyio -async def test_streamable_http_client_session_termination_204( - basic_server: None, basic_server_url: str, monkeypatch: pytest.MonkeyPatch -): +async def test_streamable_http_client_session_termination_204(basic_server_url: str, monkeypatch: pytest.MonkeyPatch): """Test client session termination functionality with a 204 response. This test patches the httpx client to return a 204 response for DELETEs. @@ -1412,7 +1300,7 @@ async def run_tool(): @pytest.mark.anyio -async def test_streamablehttp_server_sampling(basic_server: None, basic_server_url: str): +async def test_streamablehttp_server_sampling(basic_server_url: str): """Test server-initiated sampling request through streamable HTTP transport.""" # Variable to track if sampling callback was invoked sampling_callback_invoked = False @@ -1462,7 +1350,7 @@ async def sampling_callback( # Context-aware server implementation for testing request context propagation -async def _handle_context_list_tools( # pragma: no cover +async def _handle_context_list_tools( ctx: ServerRequestContext, params: PaginatedRequestParams | None ) -> ListToolsResult: return ListToolsResult( @@ -1487,15 +1375,13 @@ async def _handle_context_list_tools( # pragma: no cover ) -async def _handle_context_call_tool( # pragma: no cover - ctx: ServerRequestContext, params: CallToolRequestParams -) -> CallToolResult: +async def _handle_context_call_tool(ctx: ServerRequestContext, params: CallToolRequestParams) -> CallToolResult: name = params.name args = params.arguments or {} if name == "echo_headers": headers_info: dict[str, Any] = {} - if ctx.request and isinstance(ctx.request, Request): + if ctx.request and isinstance(ctx.request, Request): # pragma: no branch headers_info = dict(ctx.request.headers) return CallToolResult(content=[TextContent(type="text", text=json.dumps(headers_info))]) @@ -1506,19 +1392,18 @@ async def _handle_context_call_tool( # pragma: no cover "method": None, "path": None, } - if ctx.request and isinstance(ctx.request, Request): + if ctx.request and isinstance(ctx.request, Request): # pragma: no branch request = ctx.request context_data["headers"] = dict(request.headers) context_data["method"] = request.method context_data["path"] = request.url.path return CallToolResult(content=[TextContent(type="text", text=json.dumps(context_data))]) - return CallToolResult(content=[TextContent(type="text", text=f"Unknown tool: {name}")]) + return CallToolResult(content=[TextContent(type="text", text=f"Unknown tool: {name}")]) # pragma: no cover -# Server runner for context-aware testing -def run_context_aware_server(port: int): # pragma: no cover - """Run the context-aware test server.""" +def create_context_aware_app() -> Starlette: + """Build the context-aware test app (echoes request headers via tools).""" server = Server( "ContextAwareServer", on_list_tools=_handle_context_list_tools, @@ -1531,7 +1416,7 @@ def run_context_aware_server(port: int): # pragma: no cover json_response=False, ) - app = Starlette( + return Starlette( debug=True, routes=[ Mount("/mcp", app=session_manager.handle_request), @@ -1539,36 +1424,16 @@ def run_context_aware_server(port: int): # pragma: no cover lifespan=lambda app: session_manager.run(), ) - server_instance = uvicorn.Server( - config=uvicorn.Config( - app=app, - host="127.0.0.1", - port=port, - log_level="error", - ) - ) - server_instance.run() - @pytest.fixture -def context_aware_server(basic_server_port: int) -> Generator[None, None, None]: - """Start the context-aware server in a separate process.""" - proc = multiprocessing.Process(target=run_context_aware_server, args=(basic_server_port,), daemon=True) - proc.start() - - # Wait for server to be running - wait_for_server(basic_server_port) - - yield - - proc.kill() - proc.join(timeout=2) - if proc.is_alive(): # pragma: no cover - print("Context-aware server process failed to terminate") +def context_server_url() -> Generator[str, None, None]: + """Start the context-aware server in a background thread. Yields its base URL.""" + with run_uvicorn_in_thread(create_context_aware_app()) as url: + yield url @pytest.mark.anyio -async def test_streamablehttp_request_context_propagation(context_aware_server: None, basic_server_url: str) -> None: +async def test_streamablehttp_request_context_propagation(context_server_url: str) -> None: """Test that request context is properly propagated through StreamableHTTP.""" custom_headers = { "Authorization": "Bearer test-token", @@ -1577,7 +1442,7 @@ async def test_streamablehttp_request_context_propagation(context_aware_server: } async with create_mcp_http_client(headers=custom_headers) as httpx_client: - async with streamable_http_client(f"{basic_server_url}/mcp", http_client=httpx_client) as ( + async with streamable_http_client(f"{context_server_url}/mcp", http_client=httpx_client) as ( read_stream, write_stream, ): @@ -1601,7 +1466,7 @@ async def test_streamablehttp_request_context_propagation(context_aware_server: @pytest.mark.anyio -async def test_streamablehttp_request_context_isolation(context_aware_server: None, basic_server_url: str) -> None: +async def test_streamablehttp_request_context_isolation(context_server_url: str) -> None: """Test that request contexts are isolated between StreamableHTTP clients.""" contexts: list[dict[str, Any]] = [] @@ -1614,7 +1479,7 @@ async def test_streamablehttp_request_context_isolation(context_aware_server: No } async with create_mcp_http_client(headers=headers) as httpx_client: - async with streamable_http_client(f"{basic_server_url}/mcp", http_client=httpx_client) as ( + async with streamable_http_client(f"{context_server_url}/mcp", http_client=httpx_client) as ( read_stream, write_stream, ): @@ -1639,9 +1504,9 @@ async def test_streamablehttp_request_context_isolation(context_aware_server: No @pytest.mark.anyio -async def test_client_includes_protocol_version_header_after_init(context_aware_server: None, basic_server_url: str): +async def test_client_includes_protocol_version_header_after_init(context_server_url: str): """Test that client includes mcp-protocol-version header after initialization.""" - async with streamable_http_client(f"{basic_server_url}/mcp") as (read_stream, write_stream): + async with streamable_http_client(f"{context_server_url}/mcp") as (read_stream, write_stream): async with ClientSession(read_stream, write_stream) as session: # Initialize and get the negotiated version init_result = await session.initialize() @@ -1659,7 +1524,7 @@ async def test_client_includes_protocol_version_header_after_init(context_aware_ assert headers_data[MCP_PROTOCOL_VERSION_HEADER] == negotiated_version -def test_server_validates_protocol_version_header(basic_server: None, basic_server_url: str): +def test_server_validates_protocol_version_header(basic_server_url: str): """Test that server returns 400 Bad Request version if header unsupported or invalid.""" # First initialize a session to get a valid session ID init_response = requests.post( @@ -1717,7 +1582,7 @@ def test_server_validates_protocol_version_header(basic_server: None, basic_serv assert response.status_code == 200 -def test_server_backwards_compatibility_no_protocol_version(basic_server: None, basic_server_url: str): +def test_server_backwards_compatibility_no_protocol_version(basic_server_url: str): """Test server accepts requests without protocol version header.""" # First initialize a session to get a valid session ID init_response = requests.post( @@ -1747,7 +1612,7 @@ def test_server_backwards_compatibility_no_protocol_version(basic_server: None, @pytest.mark.anyio -async def test_client_crash_handled(basic_server: None, basic_server_url: str): +async def test_client_crash_handled(basic_server_url: str): """Test that cases where the client crashes are handled gracefully.""" # Simulate bad client that crashes after init @@ -2219,9 +2084,7 @@ async def message_handler( @pytest.mark.anyio -async def test_streamable_http_client_does_not_mutate_provided_client( - basic_server: None, basic_server_url: str -) -> None: +async def test_streamable_http_client_does_not_mutate_provided_client(basic_server_url: str) -> None: """Test that streamable_http_client does not mutate the provided httpx client's headers.""" # Create a client with custom headers original_headers = { @@ -2252,9 +2115,7 @@ async def test_streamable_http_client_does_not_mutate_provided_client( @pytest.mark.anyio -async def test_streamable_http_client_mcp_headers_override_defaults( - context_aware_server: None, basic_server_url: str -) -> None: +async def test_streamable_http_client_mcp_headers_override_defaults(context_server_url: str) -> None: """Test that MCP protocol headers override httpx.AsyncClient default headers.""" # httpx.AsyncClient has default "accept: */*" header # We need to verify that our MCP accept header overrides it in actual requests @@ -2263,7 +2124,10 @@ async def test_streamable_http_client_mcp_headers_override_defaults( # Verify client has default accept header assert client.headers.get("accept") == "*/*" - async with streamable_http_client(f"{basic_server_url}/mcp", http_client=client) as (read_stream, write_stream): + async with streamable_http_client(f"{context_server_url}/mcp", http_client=client) as ( + read_stream, + write_stream, + ): async with ClientSession(read_stream, write_stream) as session: # pragma: no branch await session.initialize() @@ -2283,9 +2147,7 @@ async def test_streamable_http_client_mcp_headers_override_defaults( @pytest.mark.anyio -async def test_streamable_http_client_preserves_custom_with_mcp_headers( - context_aware_server: None, basic_server_url: str -) -> None: +async def test_streamable_http_client_preserves_custom_with_mcp_headers(context_server_url: str) -> None: """Test that both custom headers and MCP protocol headers are sent in requests.""" custom_headers = { "X-Custom-Header": "custom-value", @@ -2294,7 +2156,10 @@ async def test_streamable_http_client_preserves_custom_with_mcp_headers( } async with httpx.AsyncClient(headers=custom_headers, follow_redirects=True) as client: - async with streamable_http_client(f"{basic_server_url}/mcp", http_client=client) as (read_stream, write_stream): + async with streamable_http_client(f"{context_server_url}/mcp", http_client=client) as ( + read_stream, + write_stream, + ): async with ClientSession(read_stream, write_stream) as session: # pragma: no branch await session.initialize() diff --git a/tests/test_helpers.py b/tests/test_helpers.py index 5c04c269f..7a51c028c 100644 --- a/tests/test_helpers.py +++ b/tests/test_helpers.py @@ -1,18 +1,28 @@ """Common test utilities for MCP server tests.""" import socket +import threading import time +from collections.abc import Generator +from contextlib import contextmanager +from typing import Any + +import uvicorn def wait_for_server(port: int, timeout: float = 20.0) -> None: """Wait for server to be ready to accept connections. Polls the server port until it accepts connections or timeout is reached. - This eliminates race conditions without arbitrary sleeps. + + .. deprecated:: + This has a race: the port may be bound by a different server (another + pytest-xdist worker). Prefer :func:`run_uvicorn_in_thread` which holds + the port atomically from bind until shutdown. Args: port: The port number to check - timeout: Maximum time to wait in seconds (default 5.0) + timeout: Maximum time to wait in seconds Raises: TimeoutError: If server doesn't start within the timeout period @@ -23,9 +33,54 @@ def wait_for_server(port: int, timeout: float = 20.0) -> None: with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: s.settimeout(0.1) s.connect(("127.0.0.1", port)) - # Server is ready return except (ConnectionRefusedError, OSError): - # Server not ready yet, retry quickly time.sleep(0.01) raise TimeoutError(f"Server on port {port} did not start within {timeout} seconds") # pragma: no cover + + +@contextmanager +def run_uvicorn_in_thread(app: Any, **config_kwargs: Any) -> Generator[str, None, None]: + """Run a uvicorn server in a background thread with an ephemeral port. + + This eliminates the TOCTOU race that occurs when a test picks a free port + with ``socket.bind((host, 0))``, releases it, then starts a server hoping + to rebind the same port — between release and rebind, another pytest-xdist + worker may claim it, causing connection errors or cross-test contamination. + + We bind the listening socket here with ``port=0`` and pass it to uvicorn + via ``server.run(sockets=[sock])`` — the OS assigns the port atomically at + bind time and we hold it until shutdown. No polling; the port is known + before the server thread even starts, and the kernel's listen queue buffers + any connections that arrive during startup. + + Args: + app: ASGI application to serve. + **config_kwargs: Additional keyword arguments for :class:`uvicorn.Config` + (e.g. ``log_level``, ``limit_concurrency``). ``host`` defaults to + ``127.0.0.1``. + + Yields: + The base URL of the running server, e.g. ``http://127.0.0.1:54321``. + """ + host = config_kwargs.setdefault("host", "127.0.0.1") + config_kwargs.setdefault("log_level", "error") + + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + sock.bind((host, 0)) + sock.listen() + port = sock.getsockname()[1] + + config = uvicorn.Config(app=app, **config_kwargs) + server = uvicorn.Server(config=config) + thread = threading.Thread(target=server.run, args=([sock],), daemon=True) + thread.start() + + try: + yield f"http://{host}:{port}" + finally: + server.should_exit = True + server.force_exit = True + thread.join(timeout=5) + sock.close()