Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/a2a/client/transports/rest.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,7 +258,7 @@ async def subscribe(
) -> AsyncGenerator[StreamResponse]:
"""Reconnects to get task updates."""
async for event in self._send_stream_request(
'GET',
'POST',
f'/tasks/{request.id}:subscribe',
request.tenant,
context=context,
Expand Down
4 changes: 4 additions & 0 deletions src/a2a/compat/v0_3/rest_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,10 @@ def routes(self) -> dict[tuple[str, str], Callable[[Request], Any]]:
self._handle_streaming_request,
self.handler.on_subscribe_to_task,
),
('/v1/tasks/{id}:subscribe', 'POST'): functools.partial(
self._handle_streaming_request,
self.handler.on_subscribe_to_task,
),
('/v1/tasks/{id}', 'GET'): functools.partial(
self._handle_request, self.handler.on_get_task
),
Expand Down
55 changes: 47 additions & 8 deletions src/a2a/compat/v0_3/rest_transport.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import contextlib
import json
import logging

Expand Down Expand Up @@ -63,11 +64,14 @@ def __init__(
httpx_client: httpx.AsyncClient,
agent_card: AgentCard | None,
url: str,
subscribe_method_override: str | None = None,
):
"""Initializes the CompatRestTransport."""
self.url = url.removesuffix('/')
self.httpx_client = httpx_client
self.agent_card = agent_card
self._subscribe_method_override = subscribe_method_override
self._subscribe_auto_method_override = subscribe_method_override is None

async def send_message(
self,
Expand Down Expand Up @@ -273,13 +277,41 @@ async def subscribe(
*,
context: ClientCallContext | None = None,
) -> AsyncGenerator[StreamResponse]:
"""Reconnects to get task updates."""
async for event in self._send_stream_request(
'GET',
f'/v1/tasks/{request.id}:subscribe',
context=context,
):
yield event
"""Reconnects to get task updates.

This method implements backward compatibility logic for the subscribe
endpoint. It first attempts to use POST, which is the official method
for A2A subscribe endpoint. If the server returns 405 Method Not Allowed,
it falls back to GET and remembers this preference for future calls
on this transport instance. If both fail with 405, it will default back
to POST for next calls but will not retry again.
"""
subscribe_method = self._subscribe_method_override or 'POST'
try:
async for event in self._send_stream_request(
subscribe_method,
f'/v1/tasks/{request.id}:subscribe',
context=context,
):
yield event
except A2AClientError as e:
# Check for 405 Method Not Allowed in the cause (httpx.HTTPStatusError)
cause = e.__cause__
if (
isinstance(cause, httpx.HTTPStatusError)
and cause.response.status_code == httpx.codes.METHOD_NOT_ALLOWED
):
if self._subscribe_method_override:
if self._subscribe_auto_method_override:
self._subscribe_auto_method_override = False
self._subscribe_method_override = 'POST'
raise
else:
self._subscribe_method_override = 'GET'
async for event in self.subscribe(request, context=context):
yield event
else:
raise

async def get_extended_agent_card(
self,
Expand Down Expand Up @@ -311,7 +343,14 @@ async def close(self) -> None:
def _handle_http_error(self, e: httpx.HTTPStatusError) -> NoReturn:
"""Handles HTTP status errors and raises the appropriate A2AError."""
try:
error_data = e.response.json()
with contextlib.suppress(httpx.StreamClosed):
e.response.read()

try:
error_data = e.response.json()
except (json.JSONDecodeError, ValueError, httpx.ResponseNotRead):
error_data = {}

error_type = error_data.get('type')
message = error_data.get('message', str(e))

Expand Down
4 changes: 4 additions & 0 deletions src/a2a/server/apps/rest/rest_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,10 @@ def routes(self) -> dict[tuple[str, str], Callable[[Request], Any]]:
self._handle_streaming_request,
self.handler.on_subscribe_to_task,
),
('/tasks/{id}:subscribe', 'POST'): functools.partial(
self._handle_streaming_request,
self.handler.on_subscribe_to_task,
),
('/tasks/{id}', 'GET'): functools.partial(
self._handle_request, self.handler.on_get_task
),
Expand Down
11 changes: 9 additions & 2 deletions tests/client/transports/test_rest_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -730,8 +730,15 @@ async def empty_aiter():
async for _ in method(request=request_obj):
pass

# 4. Verify the URL
# 4. Verify the URL and method
mock_aconnect_sse.assert_called_once()
args, _ = mock_aconnect_sse.call_args
args, kwargs = mock_aconnect_sse.call_args
# method is 2nd positional argument
assert args[1] == 'POST'
if method_name == 'subscribe':
assert kwargs.get('json') is None
else:
assert kwargs.get('json') == json_format.MessageToDict(request_obj)

# url is 3rd positional argument in aconnect_sse(client, method, url, ...)
assert args[2] == f'http://agent.example.com/api{expected_path}'
38 changes: 38 additions & 0 deletions tests/compat/v0_3/test_rest_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,44 @@ async def mock_stream(*args, **kwargs):
]


@pytest.mark.anyio
async def test_on_subscribe_to_task_post(
rest_handler, mock_request, mock_context
):
mock_request.path_params = {'id': 'task-1'}
mock_request.method = 'POST'
request_body = {'name': 'tasks/task-1'}
mock_request.body = AsyncMock(
return_value=json.dumps(request_body).encode('utf-8')
)

async def mock_stream(*args, **kwargs):
yield types_v03.SendStreamingMessageSuccessResponse(
id='req-1',
result=types_v03.Message(
message_id='msg-2',
role='agent',
parts=[types_v03.TextPart(text='Update')],
),
)

rest_handler.handler03.on_subscribe_to_task = MagicMock(
side_effect=mock_stream
)

results = [
chunk
async for chunk in rest_handler.on_subscribe_to_task(
mock_request, mock_context
)
]

assert len(results) == 1
rest_handler.handler03.on_subscribe_to_task.assert_called_once()
called_req = rest_handler.handler03.on_subscribe_to_task.call_args[0][0]
assert called_req.params.id == 'task-1'


@pytest.mark.anyio
async def test_get_push_notification(rest_handler, mock_request, mock_context):
mock_request.path_params = {'id': 'task-1', 'push_id': 'push-1'}
Expand Down
Loading
Loading