Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
207 changes: 103 additions & 104 deletions src/mcp/client/sse.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,108 +57,107 @@ async def sse_client(
write_stream: MemoryObjectSendStream[SessionMessage]
write_stream_reader: MemoryObjectReceiveStream[SessionMessage]

read_stream_writer, read_stream = anyio.create_memory_object_stream(0)
write_stream, write_stream_reader = anyio.create_memory_object_stream(0)

async with anyio.create_task_group() as tg:
try:
logger.debug(f"Connecting to SSE endpoint: {remove_request_params(url)}")
async with httpx_client_factory(
headers=headers, auth=auth, timeout=httpx.Timeout(timeout, read=sse_read_timeout)
) as client:
async with aconnect_sse(
client,
"GET",
url,
) as event_source:
event_source.response.raise_for_status()
logger.debug("SSE connection established")

async def sse_reader(task_status: TaskStatus[str] = anyio.TASK_STATUS_IGNORED):
try:
async for sse in event_source.aiter_sse(): # pragma: no branch
logger.debug(f"Received SSE event: {sse.event}")
match sse.event:
case "endpoint":
endpoint_url = urljoin(url, sse.data)
logger.debug(f"Received endpoint URL: {endpoint_url}")

url_parsed = urlparse(url)
endpoint_parsed = urlparse(endpoint_url)
if ( # pragma: no cover
url_parsed.netloc != endpoint_parsed.netloc
or url_parsed.scheme != endpoint_parsed.scheme
):
error_msg = ( # pragma: no cover
f"Endpoint origin does not match connection origin: {endpoint_url}"
)
logger.error(error_msg) # pragma: no cover
raise ValueError(error_msg) # pragma: no cover

if on_session_created:
session_id = _extract_session_id_from_endpoint(endpoint_url)
if session_id:
on_session_created(session_id)

task_status.started(endpoint_url)

case "message":
# Skip empty data (keep-alive pings)
if not sse.data:
continue
try:
message = types.jsonrpc_message_adapter.validate_json(
sse.data, by_name=False
)
logger.debug(f"Received server message: {message}")
except Exception as exc: # pragma: no cover
logger.exception("Error parsing server message") # pragma: no cover
await read_stream_writer.send(exc) # pragma: no cover
continue # pragma: no cover

session_message = SessionMessage(message)
await read_stream_writer.send(session_message)
case _: # pragma: no cover
logger.warning(f"Unknown SSE event: {sse.event}") # pragma: no cover
except SSEError as sse_exc: # pragma: lax no cover
logger.exception("Encountered SSE exception")
raise sse_exc
except Exception as exc: # pragma: lax no cover
logger.exception("Error in sse_reader")
await read_stream_writer.send(exc)
finally:
await read_stream_writer.aclose()

async def post_writer(endpoint_url: str):
try:
async with write_stream_reader:
async for session_message in write_stream_reader:
logger.debug(f"Sending client message: {session_message}")
response = await client.post(
endpoint_url,
json=session_message.message.model_dump(
by_alias=True,
mode="json",
exclude_unset=True,
),
logger.debug(f"Connecting to SSE endpoint: {remove_request_params(url)}")
async with httpx_client_factory(
headers=headers, auth=auth, timeout=httpx.Timeout(timeout, read=sse_read_timeout)
) as client:
async with aconnect_sse(
client,
"GET",
url,
) as event_source:
Comment on lines +64 to +68
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
async with aconnect_sse(
client,
"GET",
url,
) as event_source:
async with aconnect_sse(client, "GET", url) as event_source:

event_source.response.raise_for_status()
logger.debug("SSE connection established")

read_stream_writer, read_stream = anyio.create_memory_object_stream(0)
write_stream, write_stream_reader = anyio.create_memory_object_stream(0)

async def sse_reader(task_status: TaskStatus[str] = anyio.TASK_STATUS_IGNORED):
try:
async for sse in event_source.aiter_sse(): # pragma: no branch
logger.debug(f"Received SSE event: {sse.event}")
match sse.event:
case "endpoint":
endpoint_url = urljoin(url, sse.data)
logger.debug(f"Received endpoint URL: {endpoint_url}")

url_parsed = urlparse(url)
endpoint_parsed = urlparse(endpoint_url)
if ( # pragma: no cover
url_parsed.netloc != endpoint_parsed.netloc
or url_parsed.scheme != endpoint_parsed.scheme
):
error_msg = ( # pragma: no cover
f"Endpoint origin does not match connection origin: {endpoint_url}"
)
response.raise_for_status()
logger.debug(f"Client message sent successfully: {response.status_code}")
except Exception: # pragma: lax no cover
logger.exception("Error in post_writer")
finally:
await write_stream.aclose()

endpoint_url = await tg.start(sse_reader)
logger.debug(f"Starting post writer with endpoint URL: {endpoint_url}")
tg.start_soon(post_writer, endpoint_url)

try:
yield read_stream, write_stream
finally:
tg.cancel_scope.cancel()
finally:
await read_stream_writer.aclose()
await write_stream.aclose()
await read_stream.aclose()
await write_stream_reader.aclose()
logger.error(error_msg) # pragma: no cover
raise ValueError(error_msg) # pragma: no cover

if on_session_created:
session_id = _extract_session_id_from_endpoint(endpoint_url)
if session_id:
on_session_created(session_id)

task_status.started(endpoint_url)

case "message":
# Skip empty data (keep-alive pings)
if not sse.data:
continue
try:
message = types.jsonrpc_message_adapter.validate_json(sse.data, by_name=False)
logger.debug(f"Received server message: {message}")
except Exception as exc: # pragma: no cover
logger.exception("Error parsing server message") # pragma: no cover
await read_stream_writer.send(exc) # pragma: no cover
continue # pragma: no cover

session_message = SessionMessage(message)
await read_stream_writer.send(session_message)
case _: # pragma: no cover
logger.warning(f"Unknown SSE event: {sse.event}") # pragma: no cover
except SSEError as sse_exc: # pragma: lax no cover
logger.exception("Encountered SSE exception")
raise sse_exc
except Exception as exc: # pragma: lax no cover
logger.exception("Error in sse_reader")
await read_stream_writer.send(exc)
finally:
await read_stream_writer.aclose()
Comment on lines +125 to +126
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this still neecessary?


async def post_writer(endpoint_url: str):
try:
async with write_stream_reader:
async for session_message in write_stream_reader:
logger.debug(f"Sending client message: {session_message}")
response = await client.post(
endpoint_url,
json=session_message.message.model_dump(
by_alias=True,
mode="json",
exclude_unset=True,
),
)
response.raise_for_status()
logger.debug(f"Client message sent successfully: {response.status_code}")
except Exception: # pragma: lax no cover
logger.exception("Error in post_writer")
finally:
await write_stream.aclose()

# On Python 3.14, coverage.py reports a phantom branch arc on this
# line (->yield) when nested two async-with levels deep. The branch
# is the unreachable "did __aexit__ suppress?" arm for memory streams.
async with ( # pragma: no branch
read_stream_writer,
read_stream,
write_stream,
write_stream_reader,
anyio.create_task_group() as tg,
):
endpoint_url = await tg.start(sse_reader)
logger.debug(f"Starting post writer with endpoint URL: {endpoint_url}")
tg.start_soon(post_writer, endpoint_url)

yield read_stream, write_stream
tg.cancel_scope.cancel()
71 changes: 36 additions & 35 deletions src/mcp/client/streamable_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -533,9 +533,6 @@ async def streamable_http_client(
Example:
See examples/snippets/clients/ for usage patterns.
"""
read_stream_writer, read_stream = anyio.create_memory_object_stream[SessionMessage | Exception](0)
write_stream, write_stream_reader = anyio.create_memory_object_stream[SessionMessage](0)

# Determine if we need to create and manage the client
client_provided = http_client is not None
client = http_client
Expand All @@ -546,36 +543,40 @@ async def streamable_http_client(

transport = StreamableHTTPTransport(url)

async with anyio.create_task_group() as tg:
try:
logger.debug(f"Connecting to StreamableHTTP endpoint: {url}")

async with contextlib.AsyncExitStack() as stack:
# Only manage client lifecycle if we created it
if not client_provided:
await stack.enter_async_context(client)

def start_get_stream() -> None:
tg.start_soon(transport.handle_get_stream, client, read_stream_writer)

tg.start_soon(
transport.post_writer,
client,
write_stream_reader,
read_stream_writer,
write_stream,
start_get_stream,
tg,
)
logger.debug(f"Connecting to StreamableHTTP endpoint: {url}")

async with contextlib.AsyncExitStack() as stack:
# Only manage client lifecycle if we created it
if not client_provided:
await stack.enter_async_context(client)

read_stream_writer, read_stream = anyio.create_memory_object_stream[SessionMessage | Exception](0)
write_stream, write_stream_reader = anyio.create_memory_object_stream[SessionMessage](0)

async with (
read_stream_writer,
read_stream,
write_stream,
write_stream_reader,
anyio.create_task_group() as tg,
):

def start_get_stream() -> None:
tg.start_soon(transport.handle_get_stream, client, read_stream_writer)

tg.start_soon(
transport.post_writer,
client,
write_stream_reader,
read_stream_writer,
write_stream,
start_get_stream,
tg,
)

try:
yield read_stream, write_stream
finally:
if transport.session_id and terminate_on_close:
await transport.terminate_session(client)
tg.cancel_scope.cancel()
finally:
await read_stream_writer.aclose()
await write_stream.aclose()
await read_stream.aclose()
await write_stream_reader.aclose()
try:
yield read_stream, write_stream
finally:
if transport.session_id and terminate_on_close:
await transport.terminate_session(client)
tg.cancel_scope.cancel()
16 changes: 5 additions & 11 deletions tests/client/test_transport_stream_cleanup.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,23 +58,17 @@ def hook(args: "sys.UnraisableHookArgs") -> None: # pragma: no cover

@pytest.mark.anyio
async def test_sse_client_closes_all_streams_on_connection_error(free_tcp_port: int) -> None:
"""sse_client must close all 4 stream ends when the connection fails.
"""sse_client creates streams only after the SSE connection succeeds, so a
ConnectError propagates directly with nothing to leak.

Before the fix, only read_stream_writer and write_stream were closed in
the finally block. read_stream and write_stream_reader were leaked.
Before the fix, streams were created before connecting and only 2 of 4 were
closed in the finally block.
"""
with _assert_no_memory_stream_leak():
# sse_client enters a task group BEFORE connecting, so anyio wraps the
# ConnectError from aconnect_sse in an ExceptionGroup.
with pytest.raises(Exception) as exc_info: # noqa: B017
with pytest.raises(httpx.ConnectError):
async with sse_client(f"http://127.0.0.1:{free_tcp_port}/sse"):
pytest.fail("should not reach here") # pragma: no cover

assert exc_info.group_contains(httpx.ConnectError)
# exc_info holds the traceback → holds frame locals → keeps leaked
# streams alive. Must drop it before gc.collect() can detect a leak.
del exc_info


@pytest.mark.anyio
async def test_streamable_http_client_closes_all_streams_on_exit() -> None:
Expand Down
Loading