From 3cc16ef594c118f98f44341826f22529ac0c3966 Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Fri, 13 Mar 2026 16:25:53 +0000 Subject: [PATCH] refactor: connect-first stream lifecycle for sse and streamable_http MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Apply the websocket_client pattern from #2266 to the other two transports: establish the network connection first, create memory streams only after it succeeds, then own all four stream ends plus the task group in a single merged async with as the innermost scope. This eliminates the try/finally + four explicit aclose() calls. If the connection fails, no streams were ever created — nothing to clean up. The multi-CM async with unwinds in reverse order on exit, so tg.__aexit__ waits for cancelled tasks to finish before any stream end closes. streamable_http has one outer async with (the AsyncExitStack for the conditional httpx client), which is clean on all Python versions. sse has two unavoidable outer layers (httpx_client_factory feeds into aconnect_sse — data dependency, can't merge). On 3.14, coverage.py's static analysis sees a phantom branch on the innermost multi-CM line: each __aexit__ gets a POP_JUMP_IF_TRUE for 'did it suppress the exception?', which memory streams never do. One targeted pragma on the line we own, documented inline. Behavior change: sse_client's ConnectError is no longer wrapped in an ExceptionGroup, since the task group is never entered when the connection fails. Updated the regression test to match. --- src/mcp/client/sse.py | 207 +++++++++--------- src/mcp/client/streamable_http.py | 71 +++--- tests/client/test_transport_stream_cleanup.py | 16 +- 3 files changed, 144 insertions(+), 150 deletions(-) diff --git a/src/mcp/client/sse.py b/src/mcp/client/sse.py index 972efce58..a9b4bfd27 100644 --- a/src/mcp/client/sse.py +++ b/src/mcp/client/sse.py @@ -57,108 +57,107 @@ async def sse_client( write_stream: MemoryObjectSendStream[SessionMessage] write_stream_reader: MemoryObjectReceiveStream[SessionMessage] - read_stream_writer, read_stream = anyio.create_memory_object_stream(0) - write_stream, write_stream_reader = anyio.create_memory_object_stream(0) - - async with anyio.create_task_group() as tg: - try: - logger.debug(f"Connecting to SSE endpoint: {remove_request_params(url)}") - async with httpx_client_factory( - headers=headers, auth=auth, timeout=httpx.Timeout(timeout, read=sse_read_timeout) - ) as client: - async with aconnect_sse( - client, - "GET", - url, - ) as event_source: - event_source.response.raise_for_status() - logger.debug("SSE connection established") - - async def sse_reader(task_status: TaskStatus[str] = anyio.TASK_STATUS_IGNORED): - try: - async for sse in event_source.aiter_sse(): # pragma: no branch - logger.debug(f"Received SSE event: {sse.event}") - match sse.event: - case "endpoint": - endpoint_url = urljoin(url, sse.data) - logger.debug(f"Received endpoint URL: {endpoint_url}") - - url_parsed = urlparse(url) - endpoint_parsed = urlparse(endpoint_url) - if ( # pragma: no cover - url_parsed.netloc != endpoint_parsed.netloc - or url_parsed.scheme != endpoint_parsed.scheme - ): - error_msg = ( # pragma: no cover - f"Endpoint origin does not match connection origin: {endpoint_url}" - ) - logger.error(error_msg) # pragma: no cover - raise ValueError(error_msg) # pragma: no cover - - if on_session_created: - session_id = _extract_session_id_from_endpoint(endpoint_url) - if session_id: - on_session_created(session_id) - - task_status.started(endpoint_url) - - case "message": - # Skip empty data (keep-alive pings) - if not sse.data: - continue - try: - message = types.jsonrpc_message_adapter.validate_json( - sse.data, by_name=False - ) - logger.debug(f"Received server message: {message}") - except Exception as exc: # pragma: no cover - logger.exception("Error parsing server message") # pragma: no cover - await read_stream_writer.send(exc) # pragma: no cover - continue # pragma: no cover - - session_message = SessionMessage(message) - await read_stream_writer.send(session_message) - case _: # pragma: no cover - logger.warning(f"Unknown SSE event: {sse.event}") # pragma: no cover - except SSEError as sse_exc: # pragma: lax no cover - logger.exception("Encountered SSE exception") - raise sse_exc - except Exception as exc: # pragma: lax no cover - logger.exception("Error in sse_reader") - await read_stream_writer.send(exc) - finally: - await read_stream_writer.aclose() - - async def post_writer(endpoint_url: str): - try: - async with write_stream_reader: - async for session_message in write_stream_reader: - logger.debug(f"Sending client message: {session_message}") - response = await client.post( - endpoint_url, - json=session_message.message.model_dump( - by_alias=True, - mode="json", - exclude_unset=True, - ), + logger.debug(f"Connecting to SSE endpoint: {remove_request_params(url)}") + async with httpx_client_factory( + headers=headers, auth=auth, timeout=httpx.Timeout(timeout, read=sse_read_timeout) + ) as client: + async with aconnect_sse( + client, + "GET", + url, + ) as event_source: + event_source.response.raise_for_status() + logger.debug("SSE connection established") + + read_stream_writer, read_stream = anyio.create_memory_object_stream(0) + write_stream, write_stream_reader = anyio.create_memory_object_stream(0) + + async def sse_reader(task_status: TaskStatus[str] = anyio.TASK_STATUS_IGNORED): + try: + async for sse in event_source.aiter_sse(): # pragma: no branch + logger.debug(f"Received SSE event: {sse.event}") + match sse.event: + case "endpoint": + endpoint_url = urljoin(url, sse.data) + logger.debug(f"Received endpoint URL: {endpoint_url}") + + url_parsed = urlparse(url) + endpoint_parsed = urlparse(endpoint_url) + if ( # pragma: no cover + url_parsed.netloc != endpoint_parsed.netloc + or url_parsed.scheme != endpoint_parsed.scheme + ): + error_msg = ( # pragma: no cover + f"Endpoint origin does not match connection origin: {endpoint_url}" ) - response.raise_for_status() - logger.debug(f"Client message sent successfully: {response.status_code}") - except Exception: # pragma: lax no cover - logger.exception("Error in post_writer") - finally: - await write_stream.aclose() - - endpoint_url = await tg.start(sse_reader) - logger.debug(f"Starting post writer with endpoint URL: {endpoint_url}") - tg.start_soon(post_writer, endpoint_url) - - try: - yield read_stream, write_stream - finally: - tg.cancel_scope.cancel() - finally: - await read_stream_writer.aclose() - await write_stream.aclose() - await read_stream.aclose() - await write_stream_reader.aclose() + logger.error(error_msg) # pragma: no cover + raise ValueError(error_msg) # pragma: no cover + + if on_session_created: + session_id = _extract_session_id_from_endpoint(endpoint_url) + if session_id: + on_session_created(session_id) + + task_status.started(endpoint_url) + + case "message": + # Skip empty data (keep-alive pings) + if not sse.data: + continue + try: + message = types.jsonrpc_message_adapter.validate_json(sse.data, by_name=False) + logger.debug(f"Received server message: {message}") + except Exception as exc: # pragma: no cover + logger.exception("Error parsing server message") # pragma: no cover + await read_stream_writer.send(exc) # pragma: no cover + continue # pragma: no cover + + session_message = SessionMessage(message) + await read_stream_writer.send(session_message) + case _: # pragma: no cover + logger.warning(f"Unknown SSE event: {sse.event}") # pragma: no cover + except SSEError as sse_exc: # pragma: lax no cover + logger.exception("Encountered SSE exception") + raise sse_exc + except Exception as exc: # pragma: lax no cover + logger.exception("Error in sse_reader") + await read_stream_writer.send(exc) + finally: + await read_stream_writer.aclose() + + async def post_writer(endpoint_url: str): + try: + async with write_stream_reader: + async for session_message in write_stream_reader: + logger.debug(f"Sending client message: {session_message}") + response = await client.post( + endpoint_url, + json=session_message.message.model_dump( + by_alias=True, + mode="json", + exclude_unset=True, + ), + ) + response.raise_for_status() + logger.debug(f"Client message sent successfully: {response.status_code}") + except Exception: # pragma: lax no cover + logger.exception("Error in post_writer") + finally: + await write_stream.aclose() + + # On Python 3.14, coverage.py reports a phantom branch arc on this + # line (->yield) when nested two async-with levels deep. The branch + # is the unreachable "did __aexit__ suppress?" arm for memory streams. + async with ( # pragma: no branch + read_stream_writer, + read_stream, + write_stream, + write_stream_reader, + anyio.create_task_group() as tg, + ): + endpoint_url = await tg.start(sse_reader) + logger.debug(f"Starting post writer with endpoint URL: {endpoint_url}") + tg.start_soon(post_writer, endpoint_url) + + yield read_stream, write_stream + tg.cancel_scope.cancel() diff --git a/src/mcp/client/streamable_http.py b/src/mcp/client/streamable_http.py index 3416bbc81..4e873364a 100644 --- a/src/mcp/client/streamable_http.py +++ b/src/mcp/client/streamable_http.py @@ -533,9 +533,6 @@ async def streamable_http_client( Example: See examples/snippets/clients/ for usage patterns. """ - read_stream_writer, read_stream = anyio.create_memory_object_stream[SessionMessage | Exception](0) - write_stream, write_stream_reader = anyio.create_memory_object_stream[SessionMessage](0) - # Determine if we need to create and manage the client client_provided = http_client is not None client = http_client @@ -546,36 +543,40 @@ async def streamable_http_client( transport = StreamableHTTPTransport(url) - async with anyio.create_task_group() as tg: - try: - logger.debug(f"Connecting to StreamableHTTP endpoint: {url}") - - async with contextlib.AsyncExitStack() as stack: - # Only manage client lifecycle if we created it - if not client_provided: - await stack.enter_async_context(client) - - def start_get_stream() -> None: - tg.start_soon(transport.handle_get_stream, client, read_stream_writer) - - tg.start_soon( - transport.post_writer, - client, - write_stream_reader, - read_stream_writer, - write_stream, - start_get_stream, - tg, - ) + logger.debug(f"Connecting to StreamableHTTP endpoint: {url}") + + async with contextlib.AsyncExitStack() as stack: + # Only manage client lifecycle if we created it + if not client_provided: + await stack.enter_async_context(client) + + read_stream_writer, read_stream = anyio.create_memory_object_stream[SessionMessage | Exception](0) + write_stream, write_stream_reader = anyio.create_memory_object_stream[SessionMessage](0) + + async with ( + read_stream_writer, + read_stream, + write_stream, + write_stream_reader, + anyio.create_task_group() as tg, + ): + + def start_get_stream() -> None: + tg.start_soon(transport.handle_get_stream, client, read_stream_writer) + + tg.start_soon( + transport.post_writer, + client, + write_stream_reader, + read_stream_writer, + write_stream, + start_get_stream, + tg, + ) - try: - yield read_stream, write_stream - finally: - if transport.session_id and terminate_on_close: - await transport.terminate_session(client) - tg.cancel_scope.cancel() - finally: - await read_stream_writer.aclose() - await write_stream.aclose() - await read_stream.aclose() - await write_stream_reader.aclose() + try: + yield read_stream, write_stream + finally: + if transport.session_id and terminate_on_close: + await transport.terminate_session(client) + tg.cancel_scope.cancel() diff --git a/tests/client/test_transport_stream_cleanup.py b/tests/client/test_transport_stream_cleanup.py index 631b0fff2..f80232c9d 100644 --- a/tests/client/test_transport_stream_cleanup.py +++ b/tests/client/test_transport_stream_cleanup.py @@ -58,23 +58,17 @@ def hook(args: "sys.UnraisableHookArgs") -> None: # pragma: no cover @pytest.mark.anyio async def test_sse_client_closes_all_streams_on_connection_error(free_tcp_port: int) -> None: - """sse_client must close all 4 stream ends when the connection fails. + """sse_client creates streams only after the SSE connection succeeds, so a + ConnectError propagates directly with nothing to leak. - Before the fix, only read_stream_writer and write_stream were closed in - the finally block. read_stream and write_stream_reader were leaked. + Before the fix, streams were created before connecting and only 2 of 4 were + closed in the finally block. """ with _assert_no_memory_stream_leak(): - # sse_client enters a task group BEFORE connecting, so anyio wraps the - # ConnectError from aconnect_sse in an ExceptionGroup. - with pytest.raises(Exception) as exc_info: # noqa: B017 + with pytest.raises(httpx.ConnectError): async with sse_client(f"http://127.0.0.1:{free_tcp_port}/sse"): pytest.fail("should not reach here") # pragma: no cover - assert exc_info.group_contains(httpx.ConnectError) - # exc_info holds the traceback → holds frame locals → keeps leaked - # streams alive. Must drop it before gc.collect() can detect a leak. - del exc_info - @pytest.mark.anyio async def test_streamable_http_client_closes_all_streams_on_exit() -> None: