From f41dc7babc0d3d818557cde3d2c5f8c71983f638 Mon Sep 17 00:00:00 2001 From: Bartek Wolowiec Date: Mon, 16 Mar 2026 15:12:21 +0000 Subject: [PATCH 1/2] Subscribe post. --- src/a2a/client/transports/rest.py | 3 +- src/a2a/compat/v0_3/rest_adapter.py | 4 + src/a2a/compat/v0_3/rest_transport.py | 58 ++++++- src/a2a/server/apps/rest/rest_adapter.py | 4 + .../server/request_handlers/rest_handler.py | 10 +- tests/client/transports/test_rest_client.py | 8 +- tests/compat/v0_3/test_rest_handler.py | 38 +++++ tests/compat/v0_3/test_rest_transport.py | 145 +++++++++++++++++- .../server/apps/rest/test_rest_fastapi_app.py | 75 ++++++++- 9 files changed, 327 insertions(+), 18 deletions(-) diff --git a/src/a2a/client/transports/rest.py b/src/a2a/client/transports/rest.py index 82e963142..e02290c0e 100644 --- a/src/a2a/client/transports/rest.py +++ b/src/a2a/client/transports/rest.py @@ -258,10 +258,11 @@ async def subscribe( ) -> AsyncGenerator[StreamResponse]: """Reconnects to get task updates.""" async for event in self._send_stream_request( - 'GET', + 'POST', f'/tasks/{request.id}:subscribe', request.tenant, context=context, + json=MessageToDict(request), ): yield event diff --git a/src/a2a/compat/v0_3/rest_adapter.py b/src/a2a/compat/v0_3/rest_adapter.py index fc7d67455..b0296e402 100644 --- a/src/a2a/compat/v0_3/rest_adapter.py +++ b/src/a2a/compat/v0_3/rest_adapter.py @@ -163,6 +163,10 @@ def routes(self) -> dict[tuple[str, str], Callable[[Request], Any]]: self._handle_streaming_request, self.handler.on_subscribe_to_task, ), + ('/v1/tasks/{id}:subscribe', 'POST'): functools.partial( + self._handle_streaming_request, + self.handler.on_subscribe_to_task, + ), ('/v1/tasks/{id}', 'GET'): functools.partial( self._handle_request, self.handler.on_get_task ), diff --git a/src/a2a/compat/v0_3/rest_transport.py b/src/a2a/compat/v0_3/rest_transport.py index 7b04f9d70..9f6d4c19e 100644 --- a/src/a2a/compat/v0_3/rest_transport.py +++ b/src/a2a/compat/v0_3/rest_transport.py @@ -1,3 +1,4 @@ +import contextlib import json import logging @@ -68,6 +69,8 @@ def __init__( self.url = url.removesuffix('/') self.httpx_client = httpx_client self.agent_card = agent_card + self._subscribe_method = 'POST' + self._subscribe_retry_attempted = False async def send_message( self, @@ -273,13 +276,45 @@ async def subscribe( *, context: ClientCallContext | None = None, ) -> AsyncGenerator[StreamResponse]: - """Reconnects to get task updates.""" - async for event in self._send_stream_request( - 'GET', - f'/v1/tasks/{request.id}:subscribe', - context=context, - ): - yield event + """Reconnects to get task updates. + + This method implements backward compatibility logic for the subscribe + endpoint. It first attempts to use POST, which is the official method + for A2A subscribe endpoint. If the server returns 405 Method Not Allowed, + it falls back to GET and remembers this preference for future calls + on this transport instance. If both fail with 405, it will default back + to POST for next calls but will not retry again. + """ + if self._subscribe_method == 'POST': + json_body = MessageToDict(request, preserving_proto_field_name=True) + else: + json_body = None + + try: + async for event in self._send_stream_request( + self._subscribe_method, + f'/v1/tasks/{request.id}:subscribe', + context=context, + json=json_body, + ): + yield event + except A2AClientError as e: + # Check for 405 Method Not Allowed in the cause (httpx.HTTPStatusError) + cause = e.__cause__ + if ( + isinstance(cause, httpx.HTTPStatusError) + and cause.response.status_code == httpx.codes.METHOD_NOT_ALLOWED + ): + if self._subscribe_retry_attempted: + self._subscribe_method = 'POST' + raise + else: + self._subscribe_method = 'GET' + self._subscribe_retry_attempted = True + async for event in self.subscribe(request, context=context): + yield event + else: + raise async def get_extended_agent_card( self, @@ -311,7 +346,14 @@ async def close(self) -> None: def _handle_http_error(self, e: httpx.HTTPStatusError) -> NoReturn: """Handles HTTP status errors and raises the appropriate A2AError.""" try: - error_data = e.response.json() + with contextlib.suppress(httpx.StreamClosed): + e.response.read() + + try: + error_data = e.response.json() + except (json.JSONDecodeError, ValueError, httpx.ResponseNotRead): + error_data = {} + error_type = error_data.get('type') message = error_data.get('message', str(e)) diff --git a/src/a2a/server/apps/rest/rest_adapter.py b/src/a2a/server/apps/rest/rest_adapter.py index 154409923..0ef56c149 100644 --- a/src/a2a/server/apps/rest/rest_adapter.py +++ b/src/a2a/server/apps/rest/rest_adapter.py @@ -237,6 +237,10 @@ def routes(self) -> dict[tuple[str, str], Callable[[Request], Any]]: self._handle_streaming_request, self.handler.on_subscribe_to_task, ), + ('/tasks/{id}:subscribe', 'POST'): functools.partial( + self._handle_streaming_request, + self.handler.on_subscribe_to_task, + ), ('/tasks/{id}', 'GET'): functools.partial( self._handle_request, self.handler.on_get_task ), diff --git a/src/a2a/server/request_handlers/rest_handler.py b/src/a2a/server/request_handlers/rest_handler.py index b809dcb5b..4de704cdf 100644 --- a/src/a2a/server/request_handlers/rest_handler.py +++ b/src/a2a/server/request_handlers/rest_handler.py @@ -159,9 +159,15 @@ async def on_subscribe_to_task( Yields: JSON serialized objects containing streaming events """ - task_id = request.path_params['id'] + params = SubscribeToTaskRequest() + if request.method == 'POST': + body = await request.body() + if body: + Parse(body, params) + + params.id = request.path_params['id'] async for event in self.request_handler.on_subscribe_to_task( - SubscribeToTaskRequest(id=task_id), context + params, context ): yield MessageToDict(proto_utils.to_stream_response(event)) diff --git a/tests/client/transports/test_rest_client.py b/tests/client/transports/test_rest_client.py index 57b197040..8416c6e1f 100644 --- a/tests/client/transports/test_rest_client.py +++ b/tests/client/transports/test_rest_client.py @@ -730,8 +730,12 @@ async def empty_aiter(): async for _ in method(request=request_obj): pass - # 4. Verify the URL + # 4. Verify the URL and method mock_aconnect_sse.assert_called_once() - args, _ = mock_aconnect_sse.call_args + args, kwargs = mock_aconnect_sse.call_args + # method is 2nd positional argument + assert args[1] == 'POST' + assert kwargs.get('json') == json_format.MessageToDict(request_obj) + # url is 3rd positional argument in aconnect_sse(client, method, url, ...) assert args[2] == f'http://agent.example.com/api{expected_path}' diff --git a/tests/compat/v0_3/test_rest_handler.py b/tests/compat/v0_3/test_rest_handler.py index 24e2b24fe..f0aa4e759 100644 --- a/tests/compat/v0_3/test_rest_handler.py +++ b/tests/compat/v0_3/test_rest_handler.py @@ -186,6 +186,44 @@ async def mock_stream(*args, **kwargs): ] +@pytest.mark.anyio +async def test_on_subscribe_to_task_post( + rest_handler, mock_request, mock_context +): + mock_request.path_params = {'id': 'task-1'} + mock_request.method = 'POST' + request_body = {'name': 'tasks/task-1'} + mock_request.body = AsyncMock( + return_value=json.dumps(request_body).encode('utf-8') + ) + + async def mock_stream(*args, **kwargs): + yield types_v03.SendStreamingMessageSuccessResponse( + id='req-1', + result=types_v03.Message( + message_id='msg-2', + role='agent', + parts=[types_v03.TextPart(text='Update')], + ), + ) + + rest_handler.handler03.on_subscribe_to_task = MagicMock( + side_effect=mock_stream + ) + + results = [ + chunk + async for chunk in rest_handler.on_subscribe_to_task( + mock_request, mock_context + ) + ] + + assert len(results) == 1 + rest_handler.handler03.on_subscribe_to_task.assert_called_once() + called_req = rest_handler.handler03.on_subscribe_to_task.call_args[0][0] + assert called_req.params.id == 'task-1' + + @pytest.mark.anyio async def test_get_push_notification(rest_handler, mock_request, mock_context): mock_request.path_params = {'id': 'task-1', 'push_id': 'push-1'} diff --git a/tests/compat/v0_3/test_rest_transport.py b/tests/compat/v0_3/test_rest_transport.py index 9bcf3dba3..695eb913d 100644 --- a/tests/compat/v0_3/test_rest_transport.py +++ b/tests/compat/v0_3/test_rest_transport.py @@ -1,4 +1,5 @@ import json + from unittest.mock import AsyncMock, MagicMock, patch import httpx @@ -232,14 +233,49 @@ async def mock_send_stream_request(*args, **kwargs): assert events[1] == StreamResponse(message=Message(message_id='msg-123')) +def create_405_error(): + mock_response = MagicMock(spec=httpx.Response) + mock_response.status_code = 405 + mock_response.json.return_value = { + 'type': 'MethodNotAllowed', + 'message': 'Method Not Allowed', + } + mock_request = MagicMock(spec=httpx.Request) + mock_request.url = 'http://example.com/v1/tasks/task-123:subscribe' + + status_error = httpx.HTTPStatusError( + '405 Method Not Allowed', request=mock_request, response=mock_response + ) + raise A2AClientError('HTTP Error 405') from status_error + + +def create_500_error(): + mock_response = MagicMock(spec=httpx.Response) + mock_response.status_code = 500 + mock_response.json.return_value = { + 'type': 'InternalError', + 'message': 'Internal Error', + } + mock_request = MagicMock(spec=httpx.Request) + + status_error = httpx.HTTPStatusError( + '500 Internal Error', request=mock_request, response=mock_response + ) + raise A2AClientError('HTTP Error 500') from status_error + + @pytest.mark.asyncio -async def test_compat_rest_transport_subscribe(transport): - async def mock_send_stream_request(*args, **kwargs): +async def test_compat_rest_transport_subscribe_post_works_no_retry(transport): + """Scenario: POST works, no retry.""" + + async def mock_stream(method, path, context=None, json=None): + assert method == 'POST' + assert json == {'id': 'task-123'} task = Task(id='task-123') task.status.message.role = Role.ROLE_AGENT yield StreamResponse(task=task) - transport._send_stream_request = mock_send_stream_request + transport._send_stream_request = mock_stream req = SubscribeToTaskRequest(id='task-123') events = [event async for event in transport.subscribe(req)] @@ -248,6 +284,109 @@ async def mock_send_stream_request(*args, **kwargs): expected_task = Task(id='task-123') expected_task.status.message.role = Role.ROLE_AGENT assert events[0] == StreamResponse(task=expected_task) + assert transport._subscribe_method == 'POST' + assert transport._subscribe_retry_attempted is False + + +@pytest.mark.asyncio +async def test_compat_rest_transport_subscribe_post_405_retry_get_success( + transport, +): + """Scenario: POST returns 405, automatic retry GET. Second call uses GET directly.""" + call_count = 0 + + async def mock_stream(method, path, context=None, json=None): + nonlocal call_count + call_count += 1 + if method == 'POST': + assert json == {'id': 'task-123'} + create_405_error() + if method == 'GET': + assert json is None + task = Task(id='task-123') + task.status.message.role = Role.ROLE_AGENT + yield StreamResponse(task=task) + + transport._send_stream_request = mock_stream + + req = SubscribeToTaskRequest(id='task-123') + events = [event async for event in transport.subscribe(req)] + + assert len(events) == 1 + assert call_count == 2 + assert transport._subscribe_method == 'GET' + assert transport._subscribe_retry_attempted is True + + # Second call should use GET directly + call_count = 0 + events = [event async for event in transport.subscribe(req)] + assert len(events) == 1 + assert call_count == 1 # Only GET called + assert transport._subscribe_method == 'GET' + + +@pytest.mark.asyncio +async def test_compat_rest_transport_subscribe_post_405_get_405_fails( + transport, +): + """Scenario: POST return 405, retry GET, return 405 - error. Second call is just POST.""" + call_count = 0 + + async def mock_stream(method, path, context=None, json=None): + nonlocal call_count + call_count += 1 + if method == 'POST': + assert json == {'id': 'task-123'} + elif method == 'GET': + assert json is None + # To make it an async generator even when it raises + if False: + yield + create_405_error() + + transport._send_stream_request = mock_stream + + req = SubscribeToTaskRequest(id='task-123') + with pytest.raises(A2AClientError) as exc_info: + [event async for event in transport.subscribe(req)] + + assert '405' in str(exc_info.value) + assert call_count == 2 # Tried POST then GET + assert transport._subscribe_method == 'POST' + assert transport._subscribe_retry_attempted is True + + # Second call should try POST directly and fail without retry + call_count = 0 + with pytest.raises(A2AClientError): + [event async for event in transport.subscribe(req)] + assert call_count == 1 + assert transport._subscribe_method == 'POST' + + +@pytest.mark.asyncio +async def test_compat_rest_transport_subscribe_post_500_no_retry(transport): + """Scenario: POST return 500, no automatic retry.""" + call_count = 0 + + async def mock_stream(method, path, context=None, json=None): + nonlocal call_count + call_count += 1 + assert method == 'POST' + assert json == {'id': 'task-123'} + if False: + yield + create_500_error() + + transport._send_stream_request = mock_stream + + req = SubscribeToTaskRequest(id='task-123') + with pytest.raises(A2AClientError) as exc_info: + [event async for event in transport.subscribe(req)] + + assert '500' in str(exc_info.value) + assert call_count == 1 # No retry on 500 + assert transport._subscribe_method == 'POST' + assert transport._subscribe_retry_attempted is False def test_compat_rest_transport_handle_http_error(transport): diff --git a/tests/server/apps/rest/test_rest_fastapi_app.py b/tests/server/apps/rest/test_rest_fastapi_app.py index 382ebea13..76179bd5c 100644 --- a/tests/server/apps/rest/test_rest_fastapi_app.py +++ b/tests/server/apps/rest/test_rest_fastapi_app.py @@ -37,9 +37,9 @@ async def agent_card() -> AgentCard: mock_agent_card = MagicMock(spec=AgentCard) mock_agent_card.url = 'http://mockurl.com' - # Mock the capabilities object with streaming disabled + # Mock the capabilities object with streaming enabled mock_capabilities = MagicMock() - mock_capabilities.streaming = False + mock_capabilities.streaming = True mock_capabilities.push_notifications = True mock_capabilities.extended_agent_card = True mock_agent_card.capabilities = mock_capabilities @@ -405,6 +405,67 @@ async def mock_stream_response(): assert data_lines == expected_data_lines +@pytest.mark.anyio +async def test_subscribe_to_task_get( + streaming_client: AsyncClient, request_handler: MagicMock +) -> None: + """Test that GET /tasks/{id}:subscribe works.""" + + async def mock_stream_response(): + yield Task( + id='task-1', + context_id='ctx-1', + status=TaskStatus(state=TaskState.TASK_STATE_WORKING), + ) + + request_handler.on_subscribe_to_task.return_value = mock_stream_response() + + response = await streaming_client.get( + '/tasks/task-1:subscribe', + headers={'Accept': 'text/event-stream'}, + ) + + response.raise_for_status() + assert response.status_code == 200 + + # Verify handler call + request_handler.on_subscribe_to_task.assert_called_once() + args, _ = request_handler.on_subscribe_to_task.call_args + assert args[0].id == 'task-1' + + +@pytest.mark.anyio +async def test_subscribe_to_task_post( + streaming_client: AsyncClient, request_handler: MagicMock +) -> None: + """Test that POST /tasks/{id}:subscribe works and parses body.""" + + async def mock_stream_response(): + yield Task( + id='task-1', + context_id='ctx-1', + status=TaskStatus(state=TaskState.TASK_STATE_WORKING), + ) + + request_handler.on_subscribe_to_task.return_value = mock_stream_response() + + request = a2a_pb2.SubscribeToTaskRequest(id='task-1') + + response = await streaming_client.post( + '/tasks/task-1:subscribe', + json=json_format.MessageToDict(request), + headers={'Accept': 'text/event-stream'}, + ) + + response.raise_for_status() + assert response.status_code == 200 + + # Verify handler call + request_handler.on_subscribe_to_task.assert_called_once() + args, _ = request_handler.on_subscribe_to_task.call_args + assert args[0].id == 'task-1' + + @pytest.mark.anyio async def test_streaming_endpoint_with_invalid_content_type( streaming_client: AsyncClient, request_handler: MagicMock @@ -493,6 +554,14 @@ class TestTenantExtraction: @pytest.fixture(autouse=True) def configure_mocks(self, request_handler: MagicMock) -> None: # Setup default return values for all handlers + async def mock_stream(*args, **kwargs): + if False: + yield + + request_handler.on_subscribe_to_task.side_effect = ( + lambda *args, **kwargs: mock_stream() + ) + request_handler.on_message_send.return_value = Message( message_id='test', role=Role.ROLE_AGENT, @@ -525,6 +594,8 @@ def extended_card_modifier(self) -> MagicMock: [ ('/message:send', 'POST', 'on_message_send', {'message': {}}), ('/tasks/1:cancel', 'POST', 'on_cancel_task', None), + ('/tasks/1:subscribe', 'GET', 'on_subscribe_to_task', None), + ('/tasks/1:subscribe', 'POST', 'on_subscribe_to_task', {'id': '1'}), ('/tasks/1', 'GET', 'on_get_task', None), ('/tasks', 'GET', 'on_list_tasks', None), ( From c842e3ef76a2ba262786c88cd52067c3b04abd1e Mon Sep 17 00:00:00 2001 From: Bartek Wolowiec Date: Tue, 17 Mar 2026 09:23:22 +0000 Subject: [PATCH 2/2] Subscribe post fix v2. --- src/a2a/client/transports/rest.py | 1 - src/a2a/compat/v0_3/rest_transport.py | 23 ++-- .../server/request_handlers/rest_handler.py | 10 +- tests/client/transports/test_rest_client.py | 5 +- tests/compat/v0_3/test_rest_transport.py | 101 ++++++++++++++---- .../server/apps/rest/test_rest_fastapi_app.py | 7 +- 6 files changed, 99 insertions(+), 48 deletions(-) diff --git a/src/a2a/client/transports/rest.py b/src/a2a/client/transports/rest.py index e02290c0e..ed40d31c7 100644 --- a/src/a2a/client/transports/rest.py +++ b/src/a2a/client/transports/rest.py @@ -262,7 +262,6 @@ async def subscribe( f'/tasks/{request.id}:subscribe', request.tenant, context=context, - json=MessageToDict(request), ): yield event diff --git a/src/a2a/compat/v0_3/rest_transport.py b/src/a2a/compat/v0_3/rest_transport.py index 9f6d4c19e..0ba38538d 100644 --- a/src/a2a/compat/v0_3/rest_transport.py +++ b/src/a2a/compat/v0_3/rest_transport.py @@ -64,13 +64,14 @@ def __init__( httpx_client: httpx.AsyncClient, agent_card: AgentCard | None, url: str, + subscribe_method_override: str | None = None, ): """Initializes the CompatRestTransport.""" self.url = url.removesuffix('/') self.httpx_client = httpx_client self.agent_card = agent_card - self._subscribe_method = 'POST' - self._subscribe_retry_attempted = False + self._subscribe_method_override = subscribe_method_override + self._subscribe_auto_method_override = subscribe_method_override is None async def send_message( self, @@ -285,17 +286,12 @@ async def subscribe( on this transport instance. If both fail with 405, it will default back to POST for next calls but will not retry again. """ - if self._subscribe_method == 'POST': - json_body = MessageToDict(request, preserving_proto_field_name=True) - else: - json_body = None - + subscribe_method = self._subscribe_method_override or 'POST' try: async for event in self._send_stream_request( - self._subscribe_method, + subscribe_method, f'/v1/tasks/{request.id}:subscribe', context=context, - json=json_body, ): yield event except A2AClientError as e: @@ -305,12 +301,13 @@ async def subscribe( isinstance(cause, httpx.HTTPStatusError) and cause.response.status_code == httpx.codes.METHOD_NOT_ALLOWED ): - if self._subscribe_retry_attempted: - self._subscribe_method = 'POST' + if self._subscribe_method_override: + if self._subscribe_auto_method_override: + self._subscribe_auto_method_override = False + self._subscribe_method_override = 'POST' raise else: - self._subscribe_method = 'GET' - self._subscribe_retry_attempted = True + self._subscribe_method_override = 'GET' async for event in self.subscribe(request, context=context): yield event else: diff --git a/src/a2a/server/request_handlers/rest_handler.py b/src/a2a/server/request_handlers/rest_handler.py index 4de704cdf..b809dcb5b 100644 --- a/src/a2a/server/request_handlers/rest_handler.py +++ b/src/a2a/server/request_handlers/rest_handler.py @@ -159,15 +159,9 @@ async def on_subscribe_to_task( Yields: JSON serialized objects containing streaming events """ - params = SubscribeToTaskRequest() - if request.method == 'POST': - body = await request.body() - if body: - Parse(body, params) - - params.id = request.path_params['id'] + task_id = request.path_params['id'] async for event in self.request_handler.on_subscribe_to_task( - params, context + SubscribeToTaskRequest(id=task_id), context ): yield MessageToDict(proto_utils.to_stream_response(event)) diff --git a/tests/client/transports/test_rest_client.py b/tests/client/transports/test_rest_client.py index 8416c6e1f..7ed8522fb 100644 --- a/tests/client/transports/test_rest_client.py +++ b/tests/client/transports/test_rest_client.py @@ -735,7 +735,10 @@ async def empty_aiter(): args, kwargs = mock_aconnect_sse.call_args # method is 2nd positional argument assert args[1] == 'POST' - assert kwargs.get('json') == json_format.MessageToDict(request_obj) + if method_name == 'subscribe': + assert kwargs.get('json') is None + else: + assert kwargs.get('json') == json_format.MessageToDict(request_obj) # url is 3rd positional argument in aconnect_sse(client, method, url, ...) assert args[2] == f'http://agent.example.com/api{expected_path}' diff --git a/tests/compat/v0_3/test_rest_transport.py b/tests/compat/v0_3/test_rest_transport.py index 695eb913d..4be7cd425 100644 --- a/tests/compat/v0_3/test_rest_transport.py +++ b/tests/compat/v0_3/test_rest_transport.py @@ -270,7 +270,7 @@ async def test_compat_rest_transport_subscribe_post_works_no_retry(transport): async def mock_stream(method, path, context=None, json=None): assert method == 'POST' - assert json == {'id': 'task-123'} + assert json is None task = Task(id='task-123') task.status.message.role = Role.ROLE_AGENT yield StreamResponse(task=task) @@ -284,8 +284,7 @@ async def mock_stream(method, path, context=None, json=None): expected_task = Task(id='task-123') expected_task.status.message.role = Role.ROLE_AGENT assert events[0] == StreamResponse(task=expected_task) - assert transport._subscribe_method == 'POST' - assert transport._subscribe_retry_attempted is False + assert transport._subscribe_method_override is None @pytest.mark.asyncio @@ -299,7 +298,7 @@ async def mock_stream(method, path, context=None, json=None): nonlocal call_count call_count += 1 if method == 'POST': - assert json == {'id': 'task-123'} + assert json is None create_405_error() if method == 'GET': assert json is None @@ -314,15 +313,14 @@ async def mock_stream(method, path, context=None, json=None): assert len(events) == 1 assert call_count == 2 - assert transport._subscribe_method == 'GET' - assert transport._subscribe_retry_attempted is True + assert transport._subscribe_method_override == 'GET' # Second call should use GET directly call_count = 0 events = [event async for event in transport.subscribe(req)] assert len(events) == 1 assert call_count == 1 # Only GET called - assert transport._subscribe_method == 'GET' + assert transport._subscribe_method_override == 'GET' @pytest.mark.asyncio @@ -330,13 +328,13 @@ async def test_compat_rest_transport_subscribe_post_405_get_405_fails( transport, ): """Scenario: POST return 405, retry GET, return 405 - error. Second call is just POST.""" - call_count = 0 + + method_count = {} async def mock_stream(method, path, context=None, json=None): - nonlocal call_count - call_count += 1 + method_count[method] = method_count.get(method, 0) + 1 if method == 'POST': - assert json == {'id': 'task-123'} + assert json is None elif method == 'GET': assert json is None # To make it an async generator even when it raises @@ -351,16 +349,16 @@ async def mock_stream(method, path, context=None, json=None): [event async for event in transport.subscribe(req)] assert '405' in str(exc_info.value) - assert call_count == 2 # Tried POST then GET - assert transport._subscribe_method == 'POST' - assert transport._subscribe_retry_attempted is True + assert transport._subscribe_method_override == 'POST' + assert method_count == {'POST': 1, 'GET': 1} + assert transport._subscribe_auto_method_override is False # Second call should try POST directly and fail without retry - call_count = 0 with pytest.raises(A2AClientError): [event async for event in transport.subscribe(req)] - assert call_count == 1 - assert transport._subscribe_method == 'POST' + assert transport._subscribe_auto_method_override is False + assert transport._subscribe_method_override == 'POST' + assert method_count == {'POST': 2, 'GET': 1} @pytest.mark.asyncio @@ -372,7 +370,7 @@ async def mock_stream(method, path, context=None, json=None): nonlocal call_count call_count += 1 assert method == 'POST' - assert json == {'id': 'task-123'} + assert json is None if False: yield create_500_error() @@ -385,8 +383,71 @@ async def mock_stream(method, path, context=None, json=None): assert '500' in str(exc_info.value) assert call_count == 1 # No retry on 500 - assert transport._subscribe_method == 'POST' - assert transport._subscribe_retry_attempted is False + assert transport._subscribe_method_override is None + + +@pytest.mark.asyncio +async def test_compat_rest_transport_subscribe_method_override_avoids_retry_get( + mock_httpx_client, agent_card +): + """Scenario: Init with GET override, server returns 405, no automatic retry.""" + transport = CompatRestTransport( + httpx_client=mock_httpx_client, + agent_card=agent_card, + url='http://example.com', + subscribe_method_override='GET', + ) + call_count = 0 + + async def mock_stream(method, path, context=None, json=None): + nonlocal call_count + call_count += 1 + assert method == 'GET' + assert json is None + if False: + yield + create_405_error() + + transport._send_stream_request = mock_stream + + req = SubscribeToTaskRequest(id='task-123') + with pytest.raises(A2AClientError) as exc_info: + [event async for event in transport.subscribe(req)] + + assert '405' in str(exc_info.value) + assert call_count == 1 + + +@pytest.mark.asyncio +async def test_compat_rest_transport_subscribe_method_override_avoids_retry_post( + mock_httpx_client, agent_card +): + """Scenario: Init with POST override, server returns 405, no automatic retry.""" + transport = CompatRestTransport( + httpx_client=mock_httpx_client, + agent_card=agent_card, + url='http://example.com', + subscribe_method_override='POST', + ) + call_count = 0 + + async def mock_stream(method, path, context=None, json=None): + nonlocal call_count + call_count += 1 + assert method == 'POST' + assert json is None + if False: + yield + create_405_error() + + transport._send_stream_request = mock_stream + + req = SubscribeToTaskRequest(id='task-123') + with pytest.raises(A2AClientError) as exc_info: + [event async for event in transport.subscribe(req)] + + assert '405' in str(exc_info.value) + assert call_count == 1 def test_compat_rest_transport_handle_http_error(transport): diff --git a/tests/server/apps/rest/test_rest_fastapi_app.py b/tests/server/apps/rest/test_rest_fastapi_app.py index 76179bd5c..c8510023a 100644 --- a/tests/server/apps/rest/test_rest_fastapi_app.py +++ b/tests/server/apps/rest/test_rest_fastapi_app.py @@ -438,7 +438,7 @@ async def mock_stream_response(): async def test_subscribe_to_task_post( streaming_client: AsyncClient, request_handler: MagicMock ) -> None: - """Test that POST /tasks/{id}:subscribe works and parses body.""" + """Test that POST /tasks/{id}:subscribe works.""" async def mock_stream_response(): yield Task( @@ -449,11 +449,8 @@ async def mock_stream_response(): request_handler.on_subscribe_to_task.return_value = mock_stream_response() - request = a2a_pb2.SubscribeToTaskRequest(id='task-1') - response = await streaming_client.post( '/tasks/task-1:subscribe', - json=json_format.MessageToDict(request), headers={'Accept': 'text/event-stream'}, ) @@ -595,7 +592,7 @@ def extended_card_modifier(self) -> MagicMock: ('/message:send', 'POST', 'on_message_send', {'message': {}}), ('/tasks/1:cancel', 'POST', 'on_cancel_task', None), ('/tasks/1:subscribe', 'GET', 'on_subscribe_to_task', None), - ('/tasks/1:subscribe', 'POST', 'on_subscribe_to_task', {'id': '1'}), + ('/tasks/1:subscribe', 'POST', 'on_subscribe_to_task', None), ('/tasks/1', 'GET', 'on_get_task', None), ('/tasks', 'GET', 'on_list_tasks', None), (