Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/mcp/server/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,7 @@ async def send_log_message(
related_request_id,
)

async def send_resource_updated(self, uri: str | AnyUrl) -> None: # pragma: no cover
async def send_resource_updated(self, uri: str | AnyUrl) -> None:
"""Send a resource updated notification."""
await self.send_notification(
types.ResourceUpdatedNotification(
Expand Down
70 changes: 35 additions & 35 deletions src/mcp/server/streamable_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ def is_terminated(self) -> bool:
"""Check if this transport has been explicitly terminated."""
return self._terminated

def close_sse_stream(self, request_id: RequestId) -> None: # pragma: no cover
def close_sse_stream(self, request_id: RequestId) -> None:
"""Close SSE connection for a specific request without terminating the stream.

This method closes the HTTP connection for the specified request, triggering
Expand All @@ -200,12 +200,12 @@ def close_sse_stream(self, request_id: RequestId) -> None: # pragma: no cover
writer.close()

# Also close and remove request streams
if request_id in self._request_streams:
if request_id in self._request_streams: # pragma: no branch
send_stream, receive_stream = self._request_streams.pop(request_id)
send_stream.close()
receive_stream.close()

def close_standalone_sse_stream(self) -> None: # pragma: no cover
def close_standalone_sse_stream(self) -> None:
"""Close the standalone GET SSE stream, triggering client reconnection.

This method closes the HTTP connection for the standalone GET stream used
Expand Down Expand Up @@ -240,10 +240,10 @@ def _create_session_message(
# Only provide close callbacks when client supports resumability
if self._event_store and protocol_version >= "2025-11-25":

async def close_stream_callback() -> None: # pragma: no cover
async def close_stream_callback() -> None:
self.close_sse_stream(request_id)

async def close_standalone_stream_callback() -> None: # pragma: no cover
async def close_standalone_stream_callback() -> None:
self.close_standalone_sse_stream()

metadata = ServerMessageMetadata(
Expand Down Expand Up @@ -291,7 +291,7 @@ def _create_error_response(
) -> Response:
"""Create an error response with a simple string message."""
response_headers = {"Content-Type": CONTENT_TYPE_JSON}
if headers: # pragma: no cover
if headers:
response_headers.update(headers)

if self.mcp_session_id:
Expand Down Expand Up @@ -342,7 +342,7 @@ def _create_event_data(self, event_message: EventMessage) -> dict[str, str]:
}

# If an event ID was provided, include it
if event_message.event_id: # pragma: no cover
if event_message.event_id:
event_data["id"] = event_message.event_id

return event_data
Expand Down Expand Up @@ -372,7 +372,7 @@ async def handle_request(self, scope: Scope, receive: Receive, send: Send) -> No
await error_response(scope, receive, send)
return

if self._terminated: # pragma: no cover
if self._terminated:
# If the session has been terminated, return 404 Not Found
response = self._create_error_response(
"Not Found: Session has been terminated",
Expand All @@ -387,7 +387,7 @@ async def handle_request(self, scope: Scope, receive: Receive, send: Send) -> No
await self._handle_get_request(request, send)
elif request.method == "DELETE":
await self._handle_delete_request(request, send)
else: # pragma: no cover
else:
await self._handle_unsupported_request(request, send)

def _check_accept_headers(self, request: Request) -> tuple[bool, bool]:
Expand Down Expand Up @@ -467,7 +467,7 @@ async def _handle_post_request(self, scope: Scope, request: Request, receive: Re

try:
message = jsonrpc_message_adapter.validate_python(raw_message, by_name=False)
except ValidationError as e: # pragma: no cover
except ValidationError as e:
response = self._create_error_response(
f"Validation error: {str(e)}",
HTTPStatus.BAD_REQUEST,
Expand All @@ -493,7 +493,7 @@ async def _handle_post_request(self, scope: Scope, request: Request, receive: Re
)
await response(scope, receive, send)
return
elif not await self._validate_request_headers(request, send): # pragma: no cover
elif not await self._validate_request_headers(request, send):
return

# For notifications and responses only, return 202 Accepted
Expand Down Expand Up @@ -633,7 +633,7 @@ async def sse_writer(): # pragma: lax no cover
finally:
await sse_stream_reader.aclose()

except Exception as err: # pragma: no cover
except Exception as err: # pragma: lax no cover
logger.exception("Error handling POST request")
response = self._create_error_response(
f"Error handling POST request: {err}",
Expand All @@ -659,19 +659,19 @@ async def _handle_get_request(self, request: Request, send: Send) -> None:
# Validate Accept header - must include text/event-stream
_, has_sse = self._check_accept_headers(request)

if not has_sse: # pragma: no cover
if not has_sse:
response = self._create_error_response(
"Not Acceptable: Client must accept text/event-stream",
HTTPStatus.NOT_ACCEPTABLE,
)
await response(request.scope, request.receive, send)
return

if not await self._validate_request_headers(request, send): # pragma: no cover
return
if not await self._validate_request_headers(request, send):
return # pragma: no cover

# Handle resumability: check for Last-Event-ID header
if last_event_id := request.headers.get(LAST_EVENT_ID_HEADER): # pragma: no cover
if last_event_id := request.headers.get(LAST_EVENT_ID_HEADER):
await self._replay_events(last_event_id, request, send)
return

Expand All @@ -681,11 +681,11 @@ async def _handle_get_request(self, request: Request, send: Send) -> None:
"Content-Type": CONTENT_TYPE_SSE,
}

if self.mcp_session_id:
if self.mcp_session_id: # pragma: no branch
headers[MCP_SESSION_ID_HEADER] = self.mcp_session_id

# Check if we already have an active GET stream
if GET_STREAM_KEY in self._request_streams: # pragma: no cover
if GET_STREAM_KEY in self._request_streams:
response = self._create_error_response(
"Conflict: Only one SSE stream is allowed per session",
HTTPStatus.CONFLICT,
Expand Down Expand Up @@ -714,7 +714,7 @@ async def standalone_sse_writer():
# Send the message via SSE
event_data = self._create_event_data(event_message)
await sse_stream_writer.send(event_data)
except Exception: # pragma: no cover
except Exception: # pragma: lax no cover
logger.exception("Error in standalone SSE writer")
finally:
logger.debug("Closing standalone SSE writer")
Expand Down Expand Up @@ -791,13 +791,13 @@ async def terminate(self) -> None:
# During cleanup, we catch all exceptions since streams might be in various states
logger.debug(f"Error closing streams: {e}")

async def _handle_unsupported_request(self, request: Request, send: Send) -> None: # pragma: no cover
async def _handle_unsupported_request(self, request: Request, send: Send) -> None:
"""Handle unsupported HTTP methods."""
headers = {
"Content-Type": CONTENT_TYPE_JSON,
"Allow": "GET, POST, DELETE",
}
if self.mcp_session_id:
if self.mcp_session_id: # pragma: no branch
headers[MCP_SESSION_ID_HEADER] = self.mcp_session_id

response = self._create_error_response(
Expand All @@ -824,7 +824,7 @@ async def _validate_session(self, request: Request, send: Send) -> bool:
request_session_id = self._get_session_id(request)

# If no session ID provided but required, return error
if not request_session_id: # pragma: no cover
if not request_session_id:
response = self._create_error_response(
"Bad Request: Missing session ID",
HTTPStatus.BAD_REQUEST,
Expand All @@ -849,11 +849,11 @@ async def _validate_protocol_version(self, request: Request, send: Send) -> bool
protocol_version = request.headers.get(MCP_PROTOCOL_VERSION_HEADER)

# If no protocol version provided, assume default version
if protocol_version is None: # pragma: no cover
if protocol_version is None:
protocol_version = DEFAULT_NEGOTIATED_VERSION

# Check if the protocol version is supported
if protocol_version not in SUPPORTED_PROTOCOL_VERSIONS: # pragma: no cover
if protocol_version not in SUPPORTED_PROTOCOL_VERSIONS:
supported_versions = ", ".join(SUPPORTED_PROTOCOL_VERSIONS)
response = self._create_error_response(
f"Bad Request: Unsupported protocol version: {protocol_version}. "
Expand All @@ -865,13 +865,13 @@ async def _validate_protocol_version(self, request: Request, send: Send) -> bool

return True

async def _replay_events(self, last_event_id: str, request: Request, send: Send) -> None: # pragma: no cover
async def _replay_events(self, last_event_id: str, request: Request, send: Send) -> None:
"""Replays events that would have been sent after the specified event ID.

Only used when resumability is enabled.
"""
event_store = self._event_store
if not event_store:
if not event_store: # pragma: no cover
return

try:
Expand All @@ -881,7 +881,7 @@ async def _replay_events(self, last_event_id: str, request: Request, send: Send)
"Content-Type": CONTENT_TYPE_SSE,
}

if self.mcp_session_id:
if self.mcp_session_id: # pragma: no branch
headers[MCP_SESSION_ID_HEADER] = self.mcp_session_id

# Get protocol version from header (already validated in _validate_protocol_version)
Expand All @@ -902,7 +902,7 @@ async def send_event(event_message: EventMessage) -> None:
stream_id = await event_store.replay_events_after(last_event_id, send_event)

# If stream ID not in mapping, create it
if stream_id and stream_id not in self._request_streams:
if stream_id and stream_id not in self._request_streams: # pragma: no branch
# Register SSE writer so close_sse_stream() can close it
self._sse_stream_writers[stream_id] = sse_stream_writer

Expand All @@ -919,10 +919,10 @@ async def send_event(event_message: EventMessage) -> None:
event_data = self._create_event_data(event_message)

await sse_stream_writer.send(event_data)
except anyio.ClosedResourceError:
except anyio.ClosedResourceError: # pragma: lax no cover
# Expected when close_sse_stream() is called
logger.debug("Replay SSE stream closed by close_sse_stream()")
except Exception:
except Exception: # pragma: lax no cover
logger.exception("Error in replay sender")

# Create and start EventSourceResponse
Expand All @@ -934,13 +934,13 @@ async def send_event(event_message: EventMessage) -> None:

try:
await response(request.scope, request.receive, send)
except Exception:
except Exception: # pragma: no cover
logger.exception("Error in replay response")
finally:
await sse_stream_writer.aclose()
await sse_stream_reader.aclose()

except Exception:
except Exception: # pragma: no cover
logger.exception("Error replaying events")
response = self._create_error_response(
"Error replaying events",
Expand Down Expand Up @@ -991,7 +991,7 @@ async def message_router():
if isinstance(message, JSONRPCResponse | JSONRPCError) and message.id is not None:
target_request_id = str(message.id)
# Extract related_request_id from meta if it exists
elif ( # pragma: no cover
elif (
session_message.metadata is not None
and isinstance(
session_message.metadata,
Expand All @@ -1015,10 +1015,10 @@ async def message_router():
try:
# Send both the message and the event ID
await self._request_streams[request_stream_id][0].send(EventMessage(message, event_id))
except (anyio.BrokenResourceError, anyio.ClosedResourceError): # pragma: no cover
except (anyio.BrokenResourceError, anyio.ClosedResourceError): # pragma: lax no cover
# Stream might be closed, remove from registry
self._request_streams.pop(request_stream_id, None)
else: # pragma: no cover
else:
logger.debug(
f"""Request stream {request_stream_id} not found
for message. Still processing message as the client
Expand Down
22 changes: 11 additions & 11 deletions src/mcp/server/transport_security.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def __init__(self, settings: TransportSecuritySettings | None = None):
# If not specified, disable DNS rebinding protection by default for backwards compatibility
self.settings = settings or TransportSecuritySettings(enable_dns_rebinding_protection=False)

def _validate_host(self, host: str | None) -> bool: # pragma: no cover
def _validate_host(self, host: str | None) -> bool: # pragma: lax no cover
"""Validate the Host header against allowed values."""
if not host:
logger.warning("Missing Host header in request")
Expand All @@ -62,7 +62,7 @@ def _validate_host(self, host: str | None) -> bool: # pragma: no cover
logger.warning(f"Invalid Host header: {host}")
return False

def _validate_origin(self, origin: str | None) -> bool: # pragma: no cover
def _validate_origin(self, origin: str | None) -> bool: # pragma: lax no cover
"""Validate the Origin header against allowed values."""
# Origin can be absent for same-origin requests
if not origin:
Expand Down Expand Up @@ -103,14 +103,14 @@ async def validate_request(self, request: Request, is_post: bool = False) -> Res
if not self.settings.enable_dns_rebinding_protection:
return None

# Validate Host header # pragma: no cover
host = request.headers.get("host") # pragma: no cover
if not self._validate_host(host): # pragma: no cover
return Response("Invalid Host header", status_code=421) # pragma: no cover
# Validate Host header
host = request.headers.get("host") # pragma: lax no cover
if not self._validate_host(host): # pragma: lax no cover
return Response("Invalid Host header", status_code=421)

# Validate Origin header # pragma: no cover
origin = request.headers.get("origin") # pragma: no cover
if not self._validate_origin(origin): # pragma: no cover
return Response("Invalid Origin header", status_code=403) # pragma: no cover
# Validate Origin header
origin = request.headers.get("origin") # pragma: lax no cover
if not self._validate_origin(origin): # pragma: lax no cover
return Response("Invalid Origin header", status_code=403)

return None # pragma: no cover
return None # pragma: lax no cover
Loading
Loading