diff --git a/.github/actions/spelling/allow.txt b/.github/actions/spelling/allow.txt index 579c2ff15..8afe0ca65 100644 --- a/.github/actions/spelling/allow.txt +++ b/.github/actions/spelling/allow.txt @@ -14,6 +14,7 @@ agentic AGrpc aio aiomysql +AIP alg amannn aproject diff --git a/src/a2a/client/transports/rest.py b/src/a2a/client/transports/rest.py index 27c0b6a0a..82e963142 100644 --- a/src/a2a/client/transports/rest.py +++ b/src/a2a/client/transports/rest.py @@ -34,16 +34,12 @@ Task, TaskPushNotificationConfig, ) -from a2a.utils.errors import JSON_RPC_ERROR_CODE_MAP, MethodNotFoundError +from a2a.utils.errors import A2A_REASON_TO_ERROR, MethodNotFoundError from a2a.utils.telemetry import SpanKind, trace_class logger = logging.getLogger(__name__) -_A2A_ERROR_NAME_TO_CLS = { - error_type.__name__: error_type for error_type in JSON_RPC_ERROR_CODE_MAP -} - @trace_class(kind=SpanKind.CLIENT) class RestTransport(ClientTransport): @@ -297,15 +293,36 @@ def _get_path(self, base_path: str, tenant: str) -> str: def _handle_http_error(self, e: httpx.HTTPStatusError) -> NoReturn: """Handles HTTP status errors and raises the appropriate A2AError.""" try: - error_data = e.response.json() - error_type = error_data.get('type') - message = error_data.get('message', str(e)) + error_payload = e.response.json() + error_data = error_payload.get('error', {}) - if isinstance(error_type, str): - # TODO(#723): Resolving imports by name is temporary until proper error handling structure is added in #723. - exception_cls = _A2A_ERROR_NAME_TO_CLS.get(error_type) + message = error_data.get('message', str(e)) + details = error_data.get('details', []) + if not isinstance(details, list): + details = [] + + # The `details` array can contain multiple different error objects. + # We extract the first `ErrorInfo` object because it contains the + # specific `reason` code needed to map this back to a Python A2AError. + error_info = {} + for d in details: + if ( + isinstance(d, dict) + and d.get('@type') + == 'type.googleapis.com/google.rpc.ErrorInfo' + ): + error_info = d + break + reason = error_info.get('reason') + metadata = error_info.get('metadata') or {} + + if isinstance(reason, str): + exception_cls = A2A_REASON_TO_ERROR.get(reason) if exception_cls: - raise exception_cls(message) from e + exc = exception_cls(message) + if metadata: + exc.data = metadata + raise exc from e except (json.JSONDecodeError, ValueError): pass diff --git a/src/a2a/server/apps/rest/fastapi_app.py b/src/a2a/server/apps/rest/fastapi_app.py index c828610a3..ea9a501b9 100644 --- a/src/a2a/server/apps/rest/fastapi_app.py +++ b/src/a2a/server/apps/rest/fastapi_app.py @@ -7,12 +7,14 @@ if TYPE_CHECKING: from fastapi import APIRouter, FastAPI, Request, Response from fastapi.responses import JSONResponse + from starlette.exceptions import HTTPException as StarletteHTTPException _package_fastapi_installed = True else: try: from fastapi import APIRouter, FastAPI, Request, Response from fastapi.responses import JSONResponse + from starlette.exceptions import HTTPException as StarletteHTTPException _package_fastapi_installed = True except ImportError: @@ -20,6 +22,7 @@ FastAPI = Any Request = Any Response = Any + StarletteHTTPException = Any _package_fastapi_installed = False @@ -36,6 +39,23 @@ logger = logging.getLogger(__name__) +_HTTP_TO_GRPC_STATUS_MAP = { + 400: 'INVALID_ARGUMENT', + 401: 'UNAUTHENTICATED', + 403: 'PERMISSION_DENIED', + 404: 'NOT_FOUND', + 405: 'UNIMPLEMENTED', + 409: 'ALREADY_EXISTS', + 415: 'INVALID_ARGUMENT', + 422: 'INVALID_ARGUMENT', + 500: 'INTERNAL', + 501: 'UNIMPLEMENTED', + 502: 'INTERNAL', + 503: 'UNAVAILABLE', + 504: 'DEADLINE_EXCEEDED', +} + + class A2ARESTFastAPIApplication: """A FastAPI application implementing the A2A protocol server REST endpoints. @@ -121,6 +141,34 @@ def build( A configured FastAPI application instance. """ app = FastAPI(**kwargs) + + @app.exception_handler(StarletteHTTPException) + async def http_exception_handler( + request: Request, exc: StarletteHTTPException + ) -> Response: + """Catches framework-level HTTP exceptions. + + For example, 404 Not Found for bad routes, 422 Unprocessable Entity + for schema validation, and formats them into the A2A standard + google.rpc.Status JSON format (AIP-193). + """ + grpc_status = _HTTP_TO_GRPC_STATUS_MAP.get( + exc.status_code, 'UNKNOWN' + ) + return JSONResponse( + status_code=exc.status_code, + content={ + 'error': { + 'code': exc.status_code, + 'status': grpc_status, + 'message': str(exc.detail) + if hasattr(exc, 'detail') + else 'HTTP Exception', + } + }, + media_type='application/json', + ) + if self.enable_v0_3_compat and self._v03_adapter: v03_adapter = self._v03_adapter v03_router = APIRouter() diff --git a/src/a2a/utils/error_handlers.py b/src/a2a/utils/error_handlers.py index 00843fcf6..30916b6f0 100644 --- a/src/a2a/utils/error_handlers.py +++ b/src/a2a/utils/error_handlers.py @@ -2,7 +2,7 @@ import logging from collections.abc import Awaitable, Callable, Coroutine -from typing import TYPE_CHECKING, Any, cast +from typing import TYPE_CHECKING, Any if TYPE_CHECKING: @@ -17,70 +17,40 @@ from google.protobuf.json_format import ParseError -from a2a.server.jsonrpc_models import ( - InternalError as JSONRPCInternalError, -) -from a2a.server.jsonrpc_models import ( - JSONParseError, - JSONRPCError, -) from a2a.utils.errors import ( + A2A_REST_ERROR_MAPPING, A2AError, - ContentTypeNotSupportedError, - ExtendedAgentCardNotConfiguredError, - ExtensionSupportRequiredError, InternalError, - InvalidAgentResponseError, - InvalidParamsError, - InvalidRequestError, - MethodNotFoundError, - PushNotificationNotSupportedError, - TaskNotCancelableError, - TaskNotFoundError, - UnsupportedOperationError, - VersionNotSupportedError, + RestErrorMap, ) logger = logging.getLogger(__name__) -_A2AErrorType = ( - type[JSONRPCError] - | type[JSONParseError] - | type[InvalidRequestError] - | type[MethodNotFoundError] - | type[InvalidParamsError] - | type[InternalError] - | type[JSONRPCInternalError] - | type[TaskNotFoundError] - | type[TaskNotCancelableError] - | type[PushNotificationNotSupportedError] - | type[UnsupportedOperationError] - | type[ContentTypeNotSupportedError] - | type[InvalidAgentResponseError] - | type[ExtendedAgentCardNotConfiguredError] - | type[ExtensionSupportRequiredError] - | type[VersionNotSupportedError] -) -A2AErrorToHttpStatus: dict[_A2AErrorType, int] = { - JSONRPCError: 500, - JSONParseError: 400, - InvalidRequestError: 400, - MethodNotFoundError: 404, - InvalidParamsError: 422, - InternalError: 500, - JSONRPCInternalError: 500, - TaskNotFoundError: 404, - TaskNotCancelableError: 409, - PushNotificationNotSupportedError: 501, - UnsupportedOperationError: 501, - ContentTypeNotSupportedError: 415, - InvalidAgentResponseError: 502, - ExtendedAgentCardNotConfiguredError: 400, - ExtensionSupportRequiredError: 400, - VersionNotSupportedError: 400, -} +def _build_error_payload( + code: int, + status: str, + message: str, + reason: str | None = None, + metadata: dict[str, Any] | None = None, +) -> dict[str, Any]: + """Helper function to build the JSON error payload.""" + payload: dict[str, Any] = { + 'code': code, + 'status': status, + 'message': message, + } + if reason: + payload['details'] = [ + { + '@type': 'type.googleapis.com/google.rpc.ErrorInfo', + 'reason': reason, + 'domain': 'a2a-protocol.org', + 'metadata': metadata if metadata is not None else {}, + } + ] + return {'error': payload} def rest_error_handler( @@ -93,9 +63,12 @@ async def wrapper(*args: Any, **kwargs: Any) -> Response: try: return await func(*args, **kwargs) except A2AError as error: - http_code = A2AErrorToHttpStatus.get( - cast('_A2AErrorType', type(error)), 500 + mapping = A2A_REST_ERROR_MAPPING.get( + type(error), RestErrorMap(500, 'INTERNAL', 'INTERNAL_ERROR') ) + http_code = mapping.http_code + grpc_status = mapping.grpc_status + reason = mapping.reason log_level = ( logging.ERROR @@ -107,32 +80,46 @@ async def wrapper(*args: Any, **kwargs: Any) -> Response: "Request error: Code=%s, Message='%s'%s", getattr(error, 'code', 'N/A'), getattr(error, 'message', str(error)), - ', Data=' + str(getattr(error, 'data', '')) - if getattr(error, 'data', None) - else '', + f', Data={error.data}' if error.data else '', ) - # TODO(#722): Standardize error response format. + + # SECURITY WARNING: Data attached to A2AError.data is serialized unaltered and exposed publicly to the client in the REST API response. + metadata = getattr(error, 'data', None) or {} + return JSONResponse( - content={ - 'message': getattr(error, 'message', str(error)), - 'type': type(error).__name__, - }, + content=_build_error_payload( + code=http_code, + status=grpc_status, + message=getattr(error, 'message', str(error)), + reason=reason, + metadata=metadata, + ), status_code=http_code, + media_type='application/json', ) except ParseError as error: logger.warning('Parse error: %s', str(error)) return JSONResponse( - content={ - 'message': str(error), - 'type': 'ParseError', - }, + content=_build_error_payload( + code=400, + status='INVALID_ARGUMENT', + message=str(error), + reason='INVALID_REQUEST', + metadata={}, + ), status_code=400, + media_type='application/json', ) except Exception: logger.exception('Unknown error occurred') return JSONResponse( - content={'message': 'unknown exception', 'type': 'Exception'}, + content=_build_error_payload( + code=500, + status='INTERNAL', + message='unknown exception', + ), status_code=500, + media_type='application/json', ) return wrapper @@ -158,9 +145,7 @@ async def wrapper(*args: Any, **kwargs: Any) -> Any: "Request error: Code=%s, Message='%s'%s", getattr(error, 'code', 'N/A'), getattr(error, 'message', str(error)), - ', Data=' + str(getattr(error, 'data', '')) - if getattr(error, 'data', None) - else '', + f', Data={error.data}' if error.data else '', ) # Since the stream has started, we can't return a JSONResponse. # Instead, we run the error handling logic (provides logging) diff --git a/src/a2a/utils/errors.py b/src/a2a/utils/errors.py index ac4da027a..a16542d97 100644 --- a/src/a2a/utils/errors.py +++ b/src/a2a/utils/errors.py @@ -4,11 +4,22 @@ as well as server exception classes. """ +from typing import NamedTuple + + +class RestErrorMap(NamedTuple): + """Named tuple mapping HTTP status, gRPC status, and reason strings.""" + + http_code: int + grpc_status: str + reason: str + class A2AError(Exception): """Base exception for A2A errors.""" message: str = 'A2A Error' + data: dict | None = None def __init__(self, message: str | None = None): if message: @@ -100,6 +111,7 @@ class VersionNotSupportedError(A2AError): __all__ = [ 'A2A_ERROR_REASONS', 'A2A_REASON_TO_ERROR', + 'A2A_REST_ERROR_MAPPING', 'JSON_RPC_ERROR_CODE_MAP', 'ExtensionSupportRequiredError', 'InternalError', @@ -108,6 +120,7 @@ class VersionNotSupportedError(A2AError): 'InvalidRequestError', 'MethodNotFoundError', 'PushNotificationNotSupportedError', + 'RestErrorMap', 'TaskNotCancelableError', 'TaskNotFoundError', 'UnsupportedOperationError', @@ -132,16 +145,53 @@ class VersionNotSupportedError(A2AError): } +A2A_REST_ERROR_MAPPING: dict[type[A2AError], RestErrorMap] = { + TaskNotFoundError: RestErrorMap(404, 'NOT_FOUND', 'TASK_NOT_FOUND'), + TaskNotCancelableError: RestErrorMap( + 409, 'FAILED_PRECONDITION', 'TASK_NOT_CANCELABLE' + ), + PushNotificationNotSupportedError: RestErrorMap( + 400, + 'UNIMPLEMENTED', + 'PUSH_NOTIFICATION_NOT_SUPPORTED', + ), + UnsupportedOperationError: RestErrorMap( + 400, 'UNIMPLEMENTED', 'UNSUPPORTED_OPERATION' + ), + ContentTypeNotSupportedError: RestErrorMap( + 415, + 'INVALID_ARGUMENT', + 'CONTENT_TYPE_NOT_SUPPORTED', + ), + InvalidAgentResponseError: RestErrorMap( + 502, 'INTERNAL', 'INVALID_AGENT_RESPONSE' + ), + ExtendedAgentCardNotConfiguredError: RestErrorMap( + 400, + 'FAILED_PRECONDITION', + 'EXTENDED_AGENT_CARD_NOT_CONFIGURED', + ), + ExtensionSupportRequiredError: RestErrorMap( + 400, + 'FAILED_PRECONDITION', + 'EXTENSION_SUPPORT_REQUIRED', + ), + VersionNotSupportedError: RestErrorMap( + 400, 'UNIMPLEMENTED', 'VERSION_NOT_SUPPORTED' + ), + InvalidParamsError: RestErrorMap(400, 'INVALID_ARGUMENT', 'INVALID_PARAMS'), + InvalidRequestError: RestErrorMap( + 400, 'INVALID_ARGUMENT', 'INVALID_REQUEST' + ), + MethodNotFoundError: RestErrorMap(404, 'NOT_FOUND', 'METHOD_NOT_FOUND'), + InternalError: RestErrorMap(500, 'INTERNAL', 'INTERNAL_ERROR'), +} + + A2A_ERROR_REASONS = { - TaskNotFoundError: 'TASK_NOT_FOUND', - TaskNotCancelableError: 'TASK_NOT_CANCELABLE', - PushNotificationNotSupportedError: 'PUSH_NOTIFICATION_NOT_SUPPORTED', - UnsupportedOperationError: 'UNSUPPORTED_OPERATION', - ContentTypeNotSupportedError: 'CONTENT_TYPE_NOT_SUPPORTED', - InvalidAgentResponseError: 'INVALID_AGENT_RESPONSE', - ExtendedAgentCardNotConfiguredError: 'EXTENDED_AGENT_CARD_NOT_CONFIGURED', - ExtensionSupportRequiredError: 'EXTENSION_SUPPORT_REQUIRED', - VersionNotSupportedError: 'VERSION_NOT_SUPPORTED', + cls: mapping.reason for cls, mapping in A2A_REST_ERROR_MAPPING.items() } -A2A_REASON_TO_ERROR = {reason: cls for cls, reason in A2A_ERROR_REASONS.items()} +A2A_REASON_TO_ERROR = { + mapping.reason: cls for cls, mapping in A2A_REST_ERROR_MAPPING.items() +} diff --git a/tests/client/transports/test_rest_client.py b/tests/client/transports/test_rest_client.py index d76873918..57b197040 100644 --- a/tests/client/transports/test_rest_client.py +++ b/tests/client/transports/test_rest_client.py @@ -29,7 +29,7 @@ TaskState, ) from a2a.utils.constants import TransportProtocol -from a2a.utils.errors import JSON_RPC_ERROR_CODE_MAP +from a2a.utils.errors import A2A_REST_ERROR_MAPPING @pytest.fixture @@ -102,7 +102,7 @@ async def test_send_message_streaming_timeout( assert 'Client Request timed out' in str(exc_info.value) - @pytest.mark.parametrize('error_cls', list(JSON_RPC_ERROR_CODE_MAP.keys())) + @pytest.mark.parametrize('error_cls', list(A2A_REST_ERROR_MAPPING.keys())) @pytest.mark.asyncio async def test_rest_mapped_errors( self, @@ -127,9 +127,23 @@ async def test_rest_mapped_errors( mock_response = AsyncMock(spec=httpx.Response) mock_response.status_code = 500 + + reason = A2A_REST_ERROR_MAPPING[error_cls][2] + mock_response.json.return_value = { - 'type': error_cls.__name__, - 'message': 'Mapped Error', + 'error': { + 'code': 500, + 'status': 'UNKNOWN', + 'message': 'Mapped Error', + 'details': [ + { + '@type': 'type.googleapis.com/google.rpc.ErrorInfo', + 'reason': reason, + 'domain': 'a2a-protocol.org', + 'metadata': {}, + } + ], + } } error = httpx.HTTPStatusError( diff --git a/tests/server/apps/rest/test_rest_fastapi_app.py b/tests/server/apps/rest/test_rest_fastapi_app.py index 0731f0e76..382ebea13 100644 --- a/tests/server/apps/rest/test_rest_fastapi_app.py +++ b/tests/server/apps/rest/test_rest_fastapi_app.py @@ -624,5 +624,33 @@ async def test_tenant_extraction_extended_agent_card( assert context.tenant == '' +@pytest.mark.anyio +async def test_global_http_exception_handler_returns_rpc_status( + client: AsyncClient, +) -> None: + """Test that a standard FastAPI 404 is transformed into the A2A google.rpc.Status format.""" + + # Send a request to an endpoint that does not exist + response = await client.get('/non-existent-route') + + # Verify it returns a 404 with standard application/json + assert response.status_code == 404 + assert response.headers.get('content-type') == 'application/json' + + data = response.json() + + # Assert the payload is wrapped in the "error" envelope + assert 'error' in data + error_payload = data['error'] + + # Assert it has the correct AIP-193 format + assert error_payload['code'] == 404 + assert error_payload['status'] == 'NOT_FOUND' + assert 'Not Found' in error_payload['message'] + + # Standard HTTP errors shouldn't leak details + assert 'details' not in error_payload + + if __name__ == '__main__': pytest.main([__file__]) diff --git a/tests/utils/test_error_handlers.py b/tests/utils/test_error_handlers.py index e20c402a1..3fd189eb9 100644 --- a/tests/utils/test_error_handlers.py +++ b/tests/utils/test_error_handlers.py @@ -13,16 +13,16 @@ MethodNotFoundError, ) from a2a.utils.error_handlers import ( - A2AErrorToHttpStatus, rest_error_handler, rest_stream_error_handler, ) class MockJSONResponse: - def __init__(self, content, status_code): + def __init__(self, content, status_code, media_type=None): self.content = content self.status_code = status_code + self.media_type = media_type @pytest.mark.asyncio @@ -39,9 +39,21 @@ async def failing_func(): assert isinstance(result, MockJSONResponse) assert result.status_code == 400 + assert result.media_type == 'application/json' assert result.content == { - 'message': 'Bad request', - 'type': 'InvalidRequestError', + 'error': { + 'code': 400, + 'status': 'INVALID_ARGUMENT', + 'message': 'Bad request', + 'details': [ + { + '@type': 'type.googleapis.com/google.rpc.ErrorInfo', + 'reason': 'INVALID_REQUEST', + 'domain': 'a2a-protocol.org', + 'metadata': {}, + } + ], + } } @@ -58,9 +70,13 @@ async def failing_func(): assert isinstance(result, MockJSONResponse) assert result.status_code == 500 + assert result.media_type == 'application/json' assert result.content == { - 'message': 'unknown exception', - 'type': 'Exception', + 'error': { + 'code': 500, + 'status': 'INTERNAL', + 'message': 'unknown exception', + } } @@ -89,11 +105,3 @@ async def failing_stream(): with pytest.raises(RuntimeError, match='Stream failed'): await failing_stream() - - -def test_a2a_error_to_http_status_mapping(): - """Test A2AErrorToHttpStatus mapping.""" - assert A2AErrorToHttpStatus[InvalidRequestError] == 400 - assert A2AErrorToHttpStatus[MethodNotFoundError] == 404 - assert A2AErrorToHttpStatus[TaskNotFoundError] == 404 - assert A2AErrorToHttpStatus[InternalError] == 500