diff --git a/src/a2a/client/transports/rest.py b/src/a2a/client/transports/rest.py index 82e963142..ed40d31c7 100644 --- a/src/a2a/client/transports/rest.py +++ b/src/a2a/client/transports/rest.py @@ -258,7 +258,7 @@ async def subscribe( ) -> AsyncGenerator[StreamResponse]: """Reconnects to get task updates.""" async for event in self._send_stream_request( - 'GET', + 'POST', f'/tasks/{request.id}:subscribe', request.tenant, context=context, diff --git a/src/a2a/compat/v0_3/rest_adapter.py b/src/a2a/compat/v0_3/rest_adapter.py index fc7d67455..b0296e402 100644 --- a/src/a2a/compat/v0_3/rest_adapter.py +++ b/src/a2a/compat/v0_3/rest_adapter.py @@ -163,6 +163,10 @@ def routes(self) -> dict[tuple[str, str], Callable[[Request], Any]]: self._handle_streaming_request, self.handler.on_subscribe_to_task, ), + ('/v1/tasks/{id}:subscribe', 'POST'): functools.partial( + self._handle_streaming_request, + self.handler.on_subscribe_to_task, + ), ('/v1/tasks/{id}', 'GET'): functools.partial( self._handle_request, self.handler.on_get_task ), diff --git a/src/a2a/compat/v0_3/rest_transport.py b/src/a2a/compat/v0_3/rest_transport.py index 7b04f9d70..0ba38538d 100644 --- a/src/a2a/compat/v0_3/rest_transport.py +++ b/src/a2a/compat/v0_3/rest_transport.py @@ -1,3 +1,4 @@ +import contextlib import json import logging @@ -63,11 +64,14 @@ def __init__( httpx_client: httpx.AsyncClient, agent_card: AgentCard | None, url: str, + subscribe_method_override: str | None = None, ): """Initializes the CompatRestTransport.""" self.url = url.removesuffix('/') self.httpx_client = httpx_client self.agent_card = agent_card + self._subscribe_method_override = subscribe_method_override + self._subscribe_auto_method_override = subscribe_method_override is None async def send_message( self, @@ -273,13 +277,41 @@ async def subscribe( *, context: ClientCallContext | None = None, ) -> AsyncGenerator[StreamResponse]: - """Reconnects to get task updates.""" - async for event in self._send_stream_request( - 'GET', - f'/v1/tasks/{request.id}:subscribe', - context=context, - ): - yield event + """Reconnects to get task updates. + + This method implements backward compatibility logic for the subscribe + endpoint. It first attempts to use POST, which is the official method + for A2A subscribe endpoint. If the server returns 405 Method Not Allowed, + it falls back to GET and remembers this preference for future calls + on this transport instance. If both fail with 405, it will default back + to POST for next calls but will not retry again. + """ + subscribe_method = self._subscribe_method_override or 'POST' + try: + async for event in self._send_stream_request( + subscribe_method, + f'/v1/tasks/{request.id}:subscribe', + context=context, + ): + yield event + except A2AClientError as e: + # Check for 405 Method Not Allowed in the cause (httpx.HTTPStatusError) + cause = e.__cause__ + if ( + isinstance(cause, httpx.HTTPStatusError) + and cause.response.status_code == httpx.codes.METHOD_NOT_ALLOWED + ): + if self._subscribe_method_override: + if self._subscribe_auto_method_override: + self._subscribe_auto_method_override = False + self._subscribe_method_override = 'POST' + raise + else: + self._subscribe_method_override = 'GET' + async for event in self.subscribe(request, context=context): + yield event + else: + raise async def get_extended_agent_card( self, @@ -311,7 +343,14 @@ async def close(self) -> None: def _handle_http_error(self, e: httpx.HTTPStatusError) -> NoReturn: """Handles HTTP status errors and raises the appropriate A2AError.""" try: - error_data = e.response.json() + with contextlib.suppress(httpx.StreamClosed): + e.response.read() + + try: + error_data = e.response.json() + except (json.JSONDecodeError, ValueError, httpx.ResponseNotRead): + error_data = {} + error_type = error_data.get('type') message = error_data.get('message', str(e)) diff --git a/src/a2a/server/apps/rest/rest_adapter.py b/src/a2a/server/apps/rest/rest_adapter.py index 154409923..0ef56c149 100644 --- a/src/a2a/server/apps/rest/rest_adapter.py +++ b/src/a2a/server/apps/rest/rest_adapter.py @@ -237,6 +237,10 @@ def routes(self) -> dict[tuple[str, str], Callable[[Request], Any]]: self._handle_streaming_request, self.handler.on_subscribe_to_task, ), + ('/tasks/{id}:subscribe', 'POST'): functools.partial( + self._handle_streaming_request, + self.handler.on_subscribe_to_task, + ), ('/tasks/{id}', 'GET'): functools.partial( self._handle_request, self.handler.on_get_task ), diff --git a/tests/client/transports/test_rest_client.py b/tests/client/transports/test_rest_client.py index 57b197040..7ed8522fb 100644 --- a/tests/client/transports/test_rest_client.py +++ b/tests/client/transports/test_rest_client.py @@ -730,8 +730,15 @@ async def empty_aiter(): async for _ in method(request=request_obj): pass - # 4. Verify the URL + # 4. Verify the URL and method mock_aconnect_sse.assert_called_once() - args, _ = mock_aconnect_sse.call_args + args, kwargs = mock_aconnect_sse.call_args + # method is 2nd positional argument + assert args[1] == 'POST' + if method_name == 'subscribe': + assert kwargs.get('json') is None + else: + assert kwargs.get('json') == json_format.MessageToDict(request_obj) + # url is 3rd positional argument in aconnect_sse(client, method, url, ...) assert args[2] == f'http://agent.example.com/api{expected_path}' diff --git a/tests/compat/v0_3/test_rest_handler.py b/tests/compat/v0_3/test_rest_handler.py index 24e2b24fe..f0aa4e759 100644 --- a/tests/compat/v0_3/test_rest_handler.py +++ b/tests/compat/v0_3/test_rest_handler.py @@ -186,6 +186,44 @@ async def mock_stream(*args, **kwargs): ] +@pytest.mark.anyio +async def test_on_subscribe_to_task_post( + rest_handler, mock_request, mock_context +): + mock_request.path_params = {'id': 'task-1'} + mock_request.method = 'POST' + request_body = {'name': 'tasks/task-1'} + mock_request.body = AsyncMock( + return_value=json.dumps(request_body).encode('utf-8') + ) + + async def mock_stream(*args, **kwargs): + yield types_v03.SendStreamingMessageSuccessResponse( + id='req-1', + result=types_v03.Message( + message_id='msg-2', + role='agent', + parts=[types_v03.TextPart(text='Update')], + ), + ) + + rest_handler.handler03.on_subscribe_to_task = MagicMock( + side_effect=mock_stream + ) + + results = [ + chunk + async for chunk in rest_handler.on_subscribe_to_task( + mock_request, mock_context + ) + ] + + assert len(results) == 1 + rest_handler.handler03.on_subscribe_to_task.assert_called_once() + called_req = rest_handler.handler03.on_subscribe_to_task.call_args[0][0] + assert called_req.params.id == 'task-1' + + @pytest.mark.anyio async def test_get_push_notification(rest_handler, mock_request, mock_context): mock_request.path_params = {'id': 'task-1', 'push_id': 'push-1'} diff --git a/tests/compat/v0_3/test_rest_transport.py b/tests/compat/v0_3/test_rest_transport.py index 9bcf3dba3..4be7cd425 100644 --- a/tests/compat/v0_3/test_rest_transport.py +++ b/tests/compat/v0_3/test_rest_transport.py @@ -1,4 +1,5 @@ import json + from unittest.mock import AsyncMock, MagicMock, patch import httpx @@ -232,14 +233,49 @@ async def mock_send_stream_request(*args, **kwargs): assert events[1] == StreamResponse(message=Message(message_id='msg-123')) +def create_405_error(): + mock_response = MagicMock(spec=httpx.Response) + mock_response.status_code = 405 + mock_response.json.return_value = { + 'type': 'MethodNotAllowed', + 'message': 'Method Not Allowed', + } + mock_request = MagicMock(spec=httpx.Request) + mock_request.url = 'http://example.com/v1/tasks/task-123:subscribe' + + status_error = httpx.HTTPStatusError( + '405 Method Not Allowed', request=mock_request, response=mock_response + ) + raise A2AClientError('HTTP Error 405') from status_error + + +def create_500_error(): + mock_response = MagicMock(spec=httpx.Response) + mock_response.status_code = 500 + mock_response.json.return_value = { + 'type': 'InternalError', + 'message': 'Internal Error', + } + mock_request = MagicMock(spec=httpx.Request) + + status_error = httpx.HTTPStatusError( + '500 Internal Error', request=mock_request, response=mock_response + ) + raise A2AClientError('HTTP Error 500') from status_error + + @pytest.mark.asyncio -async def test_compat_rest_transport_subscribe(transport): - async def mock_send_stream_request(*args, **kwargs): +async def test_compat_rest_transport_subscribe_post_works_no_retry(transport): + """Scenario: POST works, no retry.""" + + async def mock_stream(method, path, context=None, json=None): + assert method == 'POST' + assert json is None task = Task(id='task-123') task.status.message.role = Role.ROLE_AGENT yield StreamResponse(task=task) - transport._send_stream_request = mock_send_stream_request + transport._send_stream_request = mock_stream req = SubscribeToTaskRequest(id='task-123') events = [event async for event in transport.subscribe(req)] @@ -248,6 +284,170 @@ async def mock_send_stream_request(*args, **kwargs): expected_task = Task(id='task-123') expected_task.status.message.role = Role.ROLE_AGENT assert events[0] == StreamResponse(task=expected_task) + assert transport._subscribe_method_override is None + + +@pytest.mark.asyncio +async def test_compat_rest_transport_subscribe_post_405_retry_get_success( + transport, +): + """Scenario: POST returns 405, automatic retry GET. Second call uses GET directly.""" + call_count = 0 + + async def mock_stream(method, path, context=None, json=None): + nonlocal call_count + call_count += 1 + if method == 'POST': + assert json is None + create_405_error() + if method == 'GET': + assert json is None + task = Task(id='task-123') + task.status.message.role = Role.ROLE_AGENT + yield StreamResponse(task=task) + + transport._send_stream_request = mock_stream + + req = SubscribeToTaskRequest(id='task-123') + events = [event async for event in transport.subscribe(req)] + + assert len(events) == 1 + assert call_count == 2 + assert transport._subscribe_method_override == 'GET' + + # Second call should use GET directly + call_count = 0 + events = [event async for event in transport.subscribe(req)] + assert len(events) == 1 + assert call_count == 1 # Only GET called + assert transport._subscribe_method_override == 'GET' + + +@pytest.mark.asyncio +async def test_compat_rest_transport_subscribe_post_405_get_405_fails( + transport, +): + """Scenario: POST return 405, retry GET, return 405 - error. Second call is just POST.""" + + method_count = {} + + async def mock_stream(method, path, context=None, json=None): + method_count[method] = method_count.get(method, 0) + 1 + if method == 'POST': + assert json is None + elif method == 'GET': + assert json is None + # To make it an async generator even when it raises + if False: + yield + create_405_error() + + transport._send_stream_request = mock_stream + + req = SubscribeToTaskRequest(id='task-123') + with pytest.raises(A2AClientError) as exc_info: + [event async for event in transport.subscribe(req)] + + assert '405' in str(exc_info.value) + assert transport._subscribe_method_override == 'POST' + assert method_count == {'POST': 1, 'GET': 1} + assert transport._subscribe_auto_method_override is False + + # Second call should try POST directly and fail without retry + with pytest.raises(A2AClientError): + [event async for event in transport.subscribe(req)] + assert transport._subscribe_auto_method_override is False + assert transport._subscribe_method_override == 'POST' + assert method_count == {'POST': 2, 'GET': 1} + + +@pytest.mark.asyncio +async def test_compat_rest_transport_subscribe_post_500_no_retry(transport): + """Scenario: POST return 500, no automatic retry.""" + call_count = 0 + + async def mock_stream(method, path, context=None, json=None): + nonlocal call_count + call_count += 1 + assert method == 'POST' + assert json is None + if False: + yield + create_500_error() + + transport._send_stream_request = mock_stream + + req = SubscribeToTaskRequest(id='task-123') + with pytest.raises(A2AClientError) as exc_info: + [event async for event in transport.subscribe(req)] + + assert '500' in str(exc_info.value) + assert call_count == 1 # No retry on 500 + assert transport._subscribe_method_override is None + + +@pytest.mark.asyncio +async def test_compat_rest_transport_subscribe_method_override_avoids_retry_get( + mock_httpx_client, agent_card +): + """Scenario: Init with GET override, server returns 405, no automatic retry.""" + transport = CompatRestTransport( + httpx_client=mock_httpx_client, + agent_card=agent_card, + url='http://example.com', + subscribe_method_override='GET', + ) + call_count = 0 + + async def mock_stream(method, path, context=None, json=None): + nonlocal call_count + call_count += 1 + assert method == 'GET' + assert json is None + if False: + yield + create_405_error() + + transport._send_stream_request = mock_stream + + req = SubscribeToTaskRequest(id='task-123') + with pytest.raises(A2AClientError) as exc_info: + [event async for event in transport.subscribe(req)] + + assert '405' in str(exc_info.value) + assert call_count == 1 + + +@pytest.mark.asyncio +async def test_compat_rest_transport_subscribe_method_override_avoids_retry_post( + mock_httpx_client, agent_card +): + """Scenario: Init with POST override, server returns 405, no automatic retry.""" + transport = CompatRestTransport( + httpx_client=mock_httpx_client, + agent_card=agent_card, + url='http://example.com', + subscribe_method_override='POST', + ) + call_count = 0 + + async def mock_stream(method, path, context=None, json=None): + nonlocal call_count + call_count += 1 + assert method == 'POST' + assert json is None + if False: + yield + create_405_error() + + transport._send_stream_request = mock_stream + + req = SubscribeToTaskRequest(id='task-123') + with pytest.raises(A2AClientError) as exc_info: + [event async for event in transport.subscribe(req)] + + assert '405' in str(exc_info.value) + assert call_count == 1 def test_compat_rest_transport_handle_http_error(transport): diff --git a/tests/server/apps/rest/test_rest_fastapi_app.py b/tests/server/apps/rest/test_rest_fastapi_app.py index 382ebea13..c8510023a 100644 --- a/tests/server/apps/rest/test_rest_fastapi_app.py +++ b/tests/server/apps/rest/test_rest_fastapi_app.py @@ -37,9 +37,9 @@ async def agent_card() -> AgentCard: mock_agent_card = MagicMock(spec=AgentCard) mock_agent_card.url = 'http://mockurl.com' - # Mock the capabilities object with streaming disabled + # Mock the capabilities object with streaming enabled mock_capabilities = MagicMock() - mock_capabilities.streaming = False + mock_capabilities.streaming = True mock_capabilities.push_notifications = True mock_capabilities.extended_agent_card = True mock_agent_card.capabilities = mock_capabilities @@ -405,6 +405,64 @@ async def mock_stream_response(): assert data_lines == expected_data_lines +@pytest.mark.anyio +async def test_subscribe_to_task_get( + streaming_client: AsyncClient, request_handler: MagicMock +) -> None: + """Test that GET /tasks/{id}:subscribe works.""" + + async def mock_stream_response(): + yield Task( + id='task-1', + context_id='ctx-1', + status=TaskStatus(state=TaskState.TASK_STATE_WORKING), + ) + + request_handler.on_subscribe_to_task.return_value = mock_stream_response() + + response = await streaming_client.get( + '/tasks/task-1:subscribe', + headers={'Accept': 'text/event-stream'}, + ) + + response.raise_for_status() + assert response.status_code == 200 + + # Verify handler call + request_handler.on_subscribe_to_task.assert_called_once() + args, _ = request_handler.on_subscribe_to_task.call_args + assert args[0].id == 'task-1' + + +@pytest.mark.anyio +async def test_subscribe_to_task_post( + streaming_client: AsyncClient, request_handler: MagicMock +) -> None: + """Test that POST /tasks/{id}:subscribe works.""" + + async def mock_stream_response(): + yield Task( + id='task-1', + context_id='ctx-1', + status=TaskStatus(state=TaskState.TASK_STATE_WORKING), + ) + + request_handler.on_subscribe_to_task.return_value = mock_stream_response() + + response = await streaming_client.post( + '/tasks/task-1:subscribe', + headers={'Accept': 'text/event-stream'}, + ) + + response.raise_for_status() + assert response.status_code == 200 + + # Verify handler call + request_handler.on_subscribe_to_task.assert_called_once() + args, _ = request_handler.on_subscribe_to_task.call_args + assert args[0].id == 'task-1' + + @pytest.mark.anyio async def test_streaming_endpoint_with_invalid_content_type( streaming_client: AsyncClient, request_handler: MagicMock @@ -493,6 +551,14 @@ class TestTenantExtraction: @pytest.fixture(autouse=True) def configure_mocks(self, request_handler: MagicMock) -> None: # Setup default return values for all handlers + async def mock_stream(*args, **kwargs): + if False: + yield + + request_handler.on_subscribe_to_task.side_effect = ( + lambda *args, **kwargs: mock_stream() + ) + request_handler.on_message_send.return_value = Message( message_id='test', role=Role.ROLE_AGENT, @@ -525,6 +591,8 @@ def extended_card_modifier(self) -> MagicMock: [ ('/message:send', 'POST', 'on_message_send', {'message': {}}), ('/tasks/1:cancel', 'POST', 'on_cancel_task', None), + ('/tasks/1:subscribe', 'GET', 'on_subscribe_to_task', None), + ('/tasks/1:subscribe', 'POST', 'on_subscribe_to_task', None), ('/tasks/1', 'GET', 'on_get_task', None), ('/tasks', 'GET', 'on_list_tasks', None), (