Skip to content
62 changes: 61 additions & 1 deletion aws_lambda_powertools/event_handler/middlewares/async_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,16 @@

import asyncio
import inspect
import logging
import threading
from typing import TYPE_CHECKING, Any

logger = logging.getLogger(__name__)

if TYPE_CHECKING:
from collections.abc import Callable

from aws_lambda_powertools.event_handler.api_gateway import ApiGatewayResolver, Response
from aws_lambda_powertools.event_handler.api_gateway import ApiGatewayResolver, BedrockResponse, Response


def wrap_middleware_async(middleware: Callable, next_handler: Callable) -> Callable:
Expand Down Expand Up @@ -105,3 +108,60 @@ def run_middleware() -> None:
raise middleware_error_holder[0]

return middleware_result_holder[0]


async def _registered_api_adapter_async(
app: ApiGatewayResolver,
next_middleware: Callable[..., Any],
) -> dict | tuple | Response | BedrockResponse:
"""
Async version of _registered_api_adapter.

Detects if the route handler is a coroutine and awaits it.
_to_response() stays sync (CPU-bound — no async benefit).

IMPORTANT: This is an internal building block only.
Nothing calls it in the resolve chain yet. It will be used
by resolve_async() (see issue #8137).

Parameters
----------
app: ApiGatewayResolver
The API Gateway resolver
next_middleware: Callable[..., Any]
The function to handle the API

Returns
-------
Response
The API Response Object
"""
route_args: dict = app.context.get("_route_args", {})
logger.debug(f"Calling API Route Handler: {route_args}")

route = app.context.get("_route")
if route is not None:
if not route.request_param_name_checked:
from aws_lambda_powertools.event_handler.api_gateway import _find_request_param_name

route.request_param_name = _find_request_param_name(next_middleware)
route.request_param_name_checked = True
if route.request_param_name:
route_args = {**route_args, route.request_param_name: app.request}

if route.has_dependencies:
from aws_lambda_powertools.event_handler.depends import build_dependency_tree, solve_dependencies

dep_values = solve_dependencies(
dependant=build_dependency_tree(route.func),
request=app.request,
dependency_overrides=app.dependency_overrides or None,
)
route_args.update(dep_values)

# Call handler — detect if result is a coroutine and await it
result = next_middleware(**route_args)
if inspect.iscoroutine(result):
result = await result

return app._to_response(result)
Original file line number Diff line number Diff line change
@@ -0,0 +1,214 @@
from __future__ import annotations

import asyncio
from typing import cast

import pytest

from aws_lambda_powertools.event_handler import content_types
from aws_lambda_powertools.event_handler.api_gateway import (
APIGatewayHttpResolver,
ApiGatewayResolver,
APIGatewayRestResolver,
BaseRouter,
ProxyEventType,
Response,
)
from aws_lambda_powertools.event_handler.middlewares.async_utils import _registered_api_adapter_async
from tests.functional.utils import load_event

API_REST_EVENT = load_event("apiGatewayProxyEvent.json")
API_RESTV2_EVENT = load_event("apiGatewayProxyV2Event_GET.json")


def _setup_resolver_context(app: ApiGatewayResolver, event: dict) -> None:
"""Populate the resolver context the same way resolve() does, without calling the full chain."""
BaseRouter.current_event = app._to_proxy_event(cast(dict, event))
BaseRouter.lambda_context = {}


@pytest.mark.parametrize(
"app, event",
[
(ApiGatewayResolver(proxy_type=ProxyEventType.APIGatewayProxyEvent), API_REST_EVENT),
(APIGatewayRestResolver(), API_REST_EVENT),
(APIGatewayHttpResolver(), API_RESTV2_EVENT),
],
)
def test_sync_handler_returns_response(app: ApiGatewayResolver, event):
# GIVEN a sync route handler
@app.get("/my/path")
def get_lambda():
return Response(200, content_types.TEXT_HTML, "sync response")

# WHEN resolving the event through the normal chain
result = app(event, {})

# THEN the sync handler is called and returns correctly
assert result["statusCode"] == 200
assert result["body"] == "sync response"


@pytest.mark.parametrize(
"app, event",
[
(ApiGatewayResolver(proxy_type=ProxyEventType.APIGatewayProxyEvent), API_REST_EVENT),
(APIGatewayRestResolver(), API_REST_EVENT),
(APIGatewayHttpResolver(), API_RESTV2_EVENT),
],
)
def test_async_handler_is_awaited(app: ApiGatewayResolver, event):
# GIVEN an async route handler registered on the resolver
@app.get("/my/path")
async def get_lambda():
return Response(200, content_types.TEXT_HTML, "async response")

# WHEN populating context and calling the async adapter directly
_setup_resolver_context(app, event)
app.append_context(_route_args={})

result = asyncio.run(
_registered_api_adapter_async(app, get_lambda),
)

# THEN the async handler is awaited and returns correctly
assert result.status_code == 200
assert result.body == "async response"


@pytest.mark.parametrize(
"app, event",
[
(ApiGatewayResolver(proxy_type=ProxyEventType.APIGatewayProxyEvent), API_REST_EVENT),
(APIGatewayRestResolver(), API_REST_EVENT),
(APIGatewayHttpResolver(), API_RESTV2_EVENT),
],
)
def test_sync_handler_through_adapter(app: ApiGatewayResolver, event):
# GIVEN a sync route handler
@app.get("/my/path")
def get_lambda():
return Response(200, content_types.TEXT_HTML, "sync via adapter")

# WHEN calling _registered_api_adapter_async with a sync handler
_setup_resolver_context(app, event)
app.append_context(_route_args={})

result = asyncio.run(
_registered_api_adapter_async(app, get_lambda),
)

# THEN sync handler works through the async adapter without issue
assert result.status_code == 200
assert result.body == "sync via adapter"


@pytest.mark.parametrize(
"app, event",
[
(ApiGatewayResolver(proxy_type=ProxyEventType.APIGatewayProxyEvent), API_REST_EVENT),
(APIGatewayRestResolver(), API_REST_EVENT),
(APIGatewayHttpResolver(), API_RESTV2_EVENT),
],
)
def test_adapter_passes_route_args_to_async_handler(app: ApiGatewayResolver, event):
# GIVEN an async handler that expects route arguments
async def get_lambda(name: str):

Check warning on line 116 in tests/functional/event_handler/required_dependencies/test_registered_api_adapter_async.py

View check run for this annotation

SonarQubeCloud / SonarCloud Code Analysis

Use asynchronous features in this function or remove the `async` keyword.

See more on https://sonarcloud.io/project/issues?id=aws-powertools_powertools-lambda-python&issues=AZ2lAXhqWySueEBar__d&open=AZ2lAXhqWySueEBar__d&pullRequest=8157
return Response(200, content_types.TEXT_HTML, name)

# WHEN route_args are set in the context
_setup_resolver_context(app, event)
app.append_context(_route_args={"name": "powertools"})

result = asyncio.run(
_registered_api_adapter_async(app, get_lambda),
)

# THEN the route args are passed to the handler
assert result.status_code == 200
assert result.body == "powertools"


@pytest.mark.parametrize(
"app, event",
[
(ApiGatewayResolver(proxy_type=ProxyEventType.APIGatewayProxyEvent), API_REST_EVENT),
(APIGatewayRestResolver(), API_REST_EVENT),
(APIGatewayHttpResolver(), API_RESTV2_EVENT),
],
)
def test_adapter_passes_route_args_to_sync_handler(app: ApiGatewayResolver, event):
# GIVEN a sync handler that expects route arguments
def get_lambda(name: str):
return Response(200, content_types.TEXT_HTML, name)

# WHEN route_args are set in the context
_setup_resolver_context(app, event)
app.append_context(_route_args={"name": "powertools"})

result = asyncio.run(
_registered_api_adapter_async(app, get_lambda),
)

# THEN the route args are passed to the sync handler
assert result.status_code == 200
assert result.body == "powertools"


def test_adapter_converts_dict_response_from_async_handler():
# GIVEN an async handler that returns a dict (not a Response object)
app = ApiGatewayResolver(proxy_type=ProxyEventType.APIGatewayProxyEvent)

async def get_lambda():

Check warning on line 162 in tests/functional/event_handler/required_dependencies/test_registered_api_adapter_async.py

View check run for this annotation

SonarQubeCloud / SonarCloud Code Analysis

Use asynchronous features in this function or remove the `async` keyword.

See more on https://sonarcloud.io/project/issues?id=aws-powertools_powertools-lambda-python&issues=AZ2lAXhqWySueEBar__e&open=AZ2lAXhqWySueEBar__e&pullRequest=8157
return {"message": "hello"}

# WHEN calling through the async adapter
_setup_resolver_context(app, API_REST_EVENT)
app.append_context(_route_args={})

result = asyncio.run(
_registered_api_adapter_async(app, get_lambda),
)

# THEN _to_response normalizes the dict into a Response object
assert result.status_code == 200
assert result.body is not None


def test_adapter_converts_tuple_response_from_async_handler():
# GIVEN an async handler that returns a (dict, status_code) tuple
app = ApiGatewayResolver(proxy_type=ProxyEventType.APIGatewayProxyEvent)

async def get_lambda():

Check warning on line 182 in tests/functional/event_handler/required_dependencies/test_registered_api_adapter_async.py

View check run for this annotation

SonarQubeCloud / SonarCloud Code Analysis

Use asynchronous features in this function or remove the `async` keyword.

See more on https://sonarcloud.io/project/issues?id=aws-powertools_powertools-lambda-python&issues=AZ2lAXhqWySueEBar__f&open=AZ2lAXhqWySueEBar__f&pullRequest=8157
return {"created": True}, 201

# WHEN calling through the async adapter
_setup_resolver_context(app, API_REST_EVENT)
app.append_context(_route_args={})

result = asyncio.run(
_registered_api_adapter_async(app, get_lambda),
)

# THEN _to_response normalizes the tuple into a Response object
assert result.status_code == 201


def test_adapter_with_no_route_in_context():
# GIVEN a handler and no _route in context
app = ApiGatewayResolver(proxy_type=ProxyEventType.APIGatewayProxyEvent)

async def get_lambda():

Check warning on line 201 in tests/functional/event_handler/required_dependencies/test_registered_api_adapter_async.py

View check run for this annotation

SonarQubeCloud / SonarCloud Code Analysis

Use asynchronous features in this function or remove the `async` keyword.

See more on https://sonarcloud.io/project/issues?id=aws-powertools_powertools-lambda-python&issues=AZ2lAXhqWySueEBar__g&open=AZ2lAXhqWySueEBar__g&pullRequest=8157
return Response(200, content_types.TEXT_HTML, "no route")

# WHEN _route is None in context (default)
_setup_resolver_context(app, API_REST_EVENT)
app.append_context(_route_args={})

result = asyncio.run(
_registered_api_adapter_async(app, get_lambda),
)

# THEN the adapter skips request injection and dependency resolution
assert result.status_code == 200
assert result.body == "no route"
Loading