diff --git a/aws_lambda_powertools/event_handler/middlewares/async_utils.py b/aws_lambda_powertools/event_handler/middlewares/async_utils.py index b04db33f1e8..469ed1e96b1 100644 --- a/aws_lambda_powertools/event_handler/middlewares/async_utils.py +++ b/aws_lambda_powertools/event_handler/middlewares/async_utils.py @@ -4,13 +4,16 @@ import asyncio import inspect +import logging import threading from typing import TYPE_CHECKING, Any +logger = logging.getLogger(__name__) + if TYPE_CHECKING: from collections.abc import Callable - from aws_lambda_powertools.event_handler.api_gateway import ApiGatewayResolver, Response + from aws_lambda_powertools.event_handler.api_gateway import ApiGatewayResolver, BedrockResponse, Response def wrap_middleware_async(middleware: Callable, next_handler: Callable) -> Callable: @@ -105,3 +108,60 @@ def run_middleware() -> None: raise middleware_error_holder[0] return middleware_result_holder[0] + + +async def _registered_api_adapter_async( + app: ApiGatewayResolver, + next_middleware: Callable[..., Any], +) -> dict | tuple | Response | BedrockResponse: + """ + Async version of _registered_api_adapter. + + Detects if the route handler is a coroutine and awaits it. + _to_response() stays sync (CPU-bound — no async benefit). + + IMPORTANT: This is an internal building block only. + Nothing calls it in the resolve chain yet. It will be used + by resolve_async() (see issue #8137). + + Parameters + ---------- + app: ApiGatewayResolver + The API Gateway resolver + next_middleware: Callable[..., Any] + The function to handle the API + + Returns + ------- + Response + The API Response Object + """ + route_args: dict = app.context.get("_route_args", {}) + logger.debug(f"Calling API Route Handler: {route_args}") + + route = app.context.get("_route") + if route is not None: + if not route.request_param_name_checked: + from aws_lambda_powertools.event_handler.api_gateway import _find_request_param_name + + route.request_param_name = _find_request_param_name(next_middleware) + route.request_param_name_checked = True + if route.request_param_name: + route_args = {**route_args, route.request_param_name: app.request} + + if route.has_dependencies: + from aws_lambda_powertools.event_handler.depends import build_dependency_tree, solve_dependencies + + dep_values = solve_dependencies( + dependant=build_dependency_tree(route.func), + request=app.request, + dependency_overrides=app.dependency_overrides or None, + ) + route_args.update(dep_values) + + # Call handler — detect if result is a coroutine and await it + result = next_middleware(**route_args) + if inspect.iscoroutine(result): + result = await result + + return app._to_response(result) diff --git a/tests/functional/event_handler/required_dependencies/test_registered_api_adapter_async.py b/tests/functional/event_handler/required_dependencies/test_registered_api_adapter_async.py new file mode 100644 index 00000000000..10d5b4602f0 --- /dev/null +++ b/tests/functional/event_handler/required_dependencies/test_registered_api_adapter_async.py @@ -0,0 +1,335 @@ +import asyncio +import re +from typing import cast + +import pytest +from typing_extensions import Annotated + +from aws_lambda_powertools.event_handler import content_types +from aws_lambda_powertools.event_handler.api_gateway import ( + APIGatewayHttpResolver, + ApiGatewayResolver, + APIGatewayRestResolver, + BaseRouter, + ProxyEventType, + Response, + Route, +) +from aws_lambda_powertools.event_handler.depends import Depends +from aws_lambda_powertools.event_handler.middlewares.async_utils import _registered_api_adapter_async +from aws_lambda_powertools.event_handler.request import Request +from tests.functional.utils import load_event + +API_REST_EVENT = load_event("apiGatewayProxyEvent.json") +API_RESTV2_EVENT = load_event("apiGatewayProxyV2Event_GET.json") + + +def _setup_resolver_context(app: ApiGatewayResolver, event: dict) -> None: + """Populate the resolver context the same way resolve() does, without calling the full chain.""" + BaseRouter.current_event = app._to_proxy_event(cast(dict, event)) + BaseRouter.lambda_context = {} + + +@pytest.mark.parametrize( + "app, event", + [ + (ApiGatewayResolver(proxy_type=ProxyEventType.APIGatewayProxyEvent), API_REST_EVENT), + (APIGatewayRestResolver(), API_REST_EVENT), + (APIGatewayHttpResolver(), API_RESTV2_EVENT), + ], +) +def test_sync_handler_returns_response(app: ApiGatewayResolver, event): + # GIVEN a sync route handler + @app.get("/my/path") + def get_lambda(): + return Response(200, content_types.TEXT_HTML, "sync response") + + # WHEN resolving the event through the normal chain + result = app(event, {}) + + # THEN the sync handler is called and returns correctly + assert result["statusCode"] == 200 + assert result["body"] == "sync response" + + +@pytest.mark.parametrize( + "app, event", + [ + (ApiGatewayResolver(proxy_type=ProxyEventType.APIGatewayProxyEvent), API_REST_EVENT), + (APIGatewayRestResolver(), API_REST_EVENT), + (APIGatewayHttpResolver(), API_RESTV2_EVENT), + ], +) +def test_async_handler_is_awaited(app: ApiGatewayResolver, event): + # GIVEN an async route handler registered on the resolver + @app.get("/my/path") + async def get_lambda(): + return Response(200, content_types.TEXT_HTML, "async response") + + # WHEN populating context and calling the async adapter directly + _setup_resolver_context(app, event) + app.append_context(_route_args={}) + + result = asyncio.run( + _registered_api_adapter_async(app, get_lambda), + ) + + # THEN the async handler is awaited and returns correctly + assert result.status_code == 200 + assert result.body == "async response" + + +@pytest.mark.parametrize( + "app, event", + [ + (ApiGatewayResolver(proxy_type=ProxyEventType.APIGatewayProxyEvent), API_REST_EVENT), + (APIGatewayRestResolver(), API_REST_EVENT), + (APIGatewayHttpResolver(), API_RESTV2_EVENT), + ], +) +def test_sync_handler_through_adapter(app: ApiGatewayResolver, event): + # GIVEN a sync route handler + @app.get("/my/path") + def get_lambda(): + return Response(200, content_types.TEXT_HTML, "sync via adapter") + + # WHEN calling _registered_api_adapter_async with a sync handler + _setup_resolver_context(app, event) + app.append_context(_route_args={}) + + result = asyncio.run( + _registered_api_adapter_async(app, get_lambda), + ) + + # THEN sync handler works through the async adapter without issue + assert result.status_code == 200 + assert result.body == "sync via adapter" + + +@pytest.mark.parametrize( + "app, event", + [ + (ApiGatewayResolver(proxy_type=ProxyEventType.APIGatewayProxyEvent), API_REST_EVENT), + (APIGatewayRestResolver(), API_REST_EVENT), + (APIGatewayHttpResolver(), API_RESTV2_EVENT), + ], +) +def test_adapter_passes_route_args_to_async_handler(app: ApiGatewayResolver, event): + # GIVEN an async handler that expects route arguments + async def get_lambda(name: str): + return Response(200, content_types.TEXT_HTML, name) + + # WHEN route_args are set in the context + _setup_resolver_context(app, event) + app.append_context(_route_args={"name": "powertools"}) + + result = asyncio.run( + _registered_api_adapter_async(app, get_lambda), + ) + + # THEN the route args are passed to the handler + assert result.status_code == 200 + assert result.body == "powertools" + + +@pytest.mark.parametrize( + "app, event", + [ + (ApiGatewayResolver(proxy_type=ProxyEventType.APIGatewayProxyEvent), API_REST_EVENT), + (APIGatewayRestResolver(), API_REST_EVENT), + (APIGatewayHttpResolver(), API_RESTV2_EVENT), + ], +) +def test_adapter_passes_route_args_to_sync_handler(app: ApiGatewayResolver, event): + # GIVEN a sync handler that expects route arguments + def get_lambda(name: str): + return Response(200, content_types.TEXT_HTML, name) + + # WHEN route_args are set in the context + _setup_resolver_context(app, event) + app.append_context(_route_args={"name": "powertools"}) + + result = asyncio.run( + _registered_api_adapter_async(app, get_lambda), + ) + + # THEN the route args are passed to the sync handler + assert result.status_code == 200 + assert result.body == "powertools" + + +def test_adapter_converts_dict_response_from_async_handler(): + # GIVEN an async handler that returns a dict (not a Response object) + app = ApiGatewayResolver(proxy_type=ProxyEventType.APIGatewayProxyEvent) + + async def get_lambda(): + return {"message": "hello"} + + # WHEN calling through the async adapter + _setup_resolver_context(app, API_REST_EVENT) + app.append_context(_route_args={}) + + result = asyncio.run( + _registered_api_adapter_async(app, get_lambda), + ) + + # THEN _to_response normalizes the dict into a Response object + assert result.status_code == 200 + assert result.body is not None + + +def test_adapter_converts_tuple_response_from_async_handler(): + # GIVEN an async handler that returns a (dict, status_code) tuple + app = ApiGatewayResolver(proxy_type=ProxyEventType.APIGatewayProxyEvent) + + async def get_lambda(): + return {"created": True}, 201 + + # WHEN calling through the async adapter + _setup_resolver_context(app, API_REST_EVENT) + app.append_context(_route_args={}) + + result = asyncio.run( + _registered_api_adapter_async(app, get_lambda), + ) + + # THEN _to_response normalizes the tuple into a Response object + assert result.status_code == 201 + + +def test_adapter_with_no_route_in_context(): + # GIVEN a handler and no _route in context + app = ApiGatewayResolver(proxy_type=ProxyEventType.APIGatewayProxyEvent) + + async def get_lambda(): + return Response(200, content_types.TEXT_HTML, "no route") + + # WHEN _route is None in context (default) + _setup_resolver_context(app, API_REST_EVENT) + app.append_context(_route_args={}) + + result = asyncio.run( + _registered_api_adapter_async(app, get_lambda), + ) + + # THEN the adapter skips request injection and dependency resolution + assert result.status_code == 200 + assert result.body == "no route" + + +def test_adapter_injects_request_param(): + # GIVEN an async handler that declares a Request parameter + app = APIGatewayHttpResolver() + + async def get_lambda(request: Request): + return Response(200, content_types.TEXT_HTML, request.method) + + # WHEN a Route is present in context with request_param_name not yet checked + _setup_resolver_context(app, API_RESTV2_EVENT) + route = Route( + method="GET", + path="/my/path", + rule=re.compile(r"^/my/path$"), + func=get_lambda, + cors=False, + compress=False, + ) + app.append_context(_route=route, _route_args={}) + + result = asyncio.run( + _registered_api_adapter_async(app, get_lambda), + ) + + # THEN the Request object is injected and request_param_name is cached + assert result.status_code == 200 + assert route.request_param_name_checked is True + assert route.request_param_name == "request" + + +def test_adapter_uses_cached_request_param_name(): + # GIVEN a Route where request_param_name was already resolved + app = APIGatewayHttpResolver() + + async def get_lambda(req: Request): + return Response(200, content_types.TEXT_HTML, req.method) + + _setup_resolver_context(app, API_RESTV2_EVENT) + route = Route( + method="GET", + path="/my/path", + rule=re.compile(r"^/my/path$"), + func=get_lambda, + cors=False, + compress=False, + ) + route.request_param_name = "req" + route.request_param_name_checked = True + app.append_context(_route=route, _route_args={}) + + # WHEN calling the adapter a second time (cache hit) + result = asyncio.run( + _registered_api_adapter_async(app, get_lambda), + ) + + # THEN it still injects the Request using the cached param name + assert result.status_code == 200 + + +def test_adapter_resolves_dependencies(): + # GIVEN an async handler with Depends() parameters + app = APIGatewayHttpResolver() + + def get_greeting() -> str: + return "hello" + + async def get_lambda(greeting: Annotated[str, Depends(get_greeting)]): + return {"greeting": greeting} + + _setup_resolver_context(app, API_RESTV2_EVENT) + route = Route( + method="GET", + path="/my/path", + rule=re.compile(r"^/my/path$"), + func=get_lambda, + cors=False, + compress=False, + ) + app.append_context(_route=route, _route_args={}) + + # WHEN calling the adapter + result = asyncio.run( + _registered_api_adapter_async(app, get_lambda), + ) + + # THEN dependencies are resolved and injected + assert result.status_code == 200 + + +def test_adapter_resolves_dependencies_with_sync_handler(): + # GIVEN a sync handler with Depends() parameters + app = APIGatewayHttpResolver() + + def get_greeting() -> str: + return "hello" + + def get_lambda(greeting: Annotated[str, Depends(get_greeting)]): + return {"greeting": greeting} + + _setup_resolver_context(app, API_RESTV2_EVENT) + route = Route( + method="GET", + path="/my/path", + rule=re.compile(r"^/my/path$"), + func=get_lambda, + cors=False, + compress=False, + ) + app.append_context(_route=route, _route_args={}) + + # WHEN calling the adapter with a sync handler that has dependencies + result = asyncio.run( + _registered_api_adapter_async(app, get_lambda), + ) + + # THEN dependencies are resolved and injected for sync handler too + assert result.status_code == 200