diff --git a/reflex/app_mixins/lifespan.py b/reflex/app_mixins/lifespan.py index a62195469c2..bac00517bb0 100644 --- a/reflex/app_mixins/lifespan.py +++ b/reflex/app_mixins/lifespan.py @@ -9,7 +9,7 @@ import inspect import time from collections.abc import Callable, Coroutine -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, TypeVar, overload from reflex_base.utils import console from reflex_base.utils.exceptions import InvalidLifespanTaskTypeError @@ -20,6 +20,8 @@ if TYPE_CHECKING: from typing_extensions import deprecated +_LifespanTaskT = TypeVar("_LifespanTaskT", bound="Callable | asyncio.Task") + def _get_task_name(task: asyncio.Task | Callable) -> str: """Get a display name for a lifespan task. @@ -137,17 +139,43 @@ async def _run_lifespan_tasks(self, app: Starlette): else: await state_manager.close() - def register_lifespan_task(self, task: Callable | asyncio.Task, **task_kwargs): + @overload + def register_lifespan_task( + self, task: _LifespanTaskT, **task_kwargs + ) -> _LifespanTaskT: ... + + @overload + def register_lifespan_task( + self, task: None = None, **task_kwargs + ) -> Callable[[_LifespanTaskT], _LifespanTaskT]: ... + + def register_lifespan_task( + self, + task: Callable | asyncio.Task | None = None, + **task_kwargs, + ): """Register a task to run during the lifespan of the app. + Supports three call shapes: + - `app.register_lifespan_task(fn, **kwargs)` — direct call. + - `@app.register_lifespan_task` — bare decorator. + - `@app.register_lifespan_task(**kwargs)` — parameterized decorator. + Args: - task: The task to register. + task: The task to register, or None to return a decorator. **task_kwargs: The kwargs of the task. + Returns: + The original task when called with a task, or a decorator when + called without one. + Raises: InvalidLifespanTaskTypeError: If the task is a generator function. RuntimeError: If lifespan tasks are already running. """ + if task is None: + return functools.partial(self.register_lifespan_task, **task_kwargs) + if self._lifespan_tasks_started: msg = ( f"Cannot register lifespan task {_get_task_name(task)!r} after " @@ -159,9 +187,13 @@ def register_lifespan_task(self, task: Callable | asyncio.Task, **task_kwargs): raise InvalidLifespanTaskTypeError(msg) task_name = _get_task_name(task) + registered_task = task if task_kwargs: - original_task = task - task = functools.partial(task, **task_kwargs) # pyright: ignore [reportArgumentType] - functools.update_wrapper(task, original_task) # pyright: ignore [reportArgumentType] - self._lifespan_tasks[task] = None + if isinstance(task, asyncio.Task): + msg = f"Task {task_name!r} of type asyncio.Task cannot be registered with kwargs." + raise InvalidLifespanTaskTypeError(msg) + registered_task = functools.partial(task, **task_kwargs) + functools.update_wrapper(registered_task, task) + self._lifespan_tasks[registered_task] = None console.debug(f"Registered lifespan task: {task_name}") + return task diff --git a/tests/units/app_mixins/test_lifespan.py b/tests/units/app_mixins/test_lifespan.py new file mode 100644 index 00000000000..d1d2f38bdd4 --- /dev/null +++ b/tests/units/app_mixins/test_lifespan.py @@ -0,0 +1,55 @@ +"""Unit tests for lifespan app mixin behavior.""" + +from __future__ import annotations + +import asyncio +import contextlib + +import pytest +from reflex_base.utils.exceptions import InvalidLifespanTaskTypeError + +from reflex.app_mixins.lifespan import LifespanMixin + + +def test_register_lifespan_task_can_be_used_as_decorator(): + """Bare decorator registers the task and preserves the name binding.""" + mixin = LifespanMixin() + + @mixin.register_lifespan_task + def polling_task() -> str: + return "ok" + + assert polling_task() == "ok" + assert polling_task in mixin.get_lifespan_tasks() + + +def test_register_lifespan_task_with_kwargs_can_be_used_as_decorator(): + """Decorator-with-kwargs registers a partial that applies the kwargs.""" + mixin = LifespanMixin() + + @mixin.register_lifespan_task(timeout=10) + def check_for_updates(timeout: int) -> int: + return timeout + + assert check_for_updates(timeout=4) == 4 + + (registered_task,) = mixin.get_lifespan_tasks() + assert not isinstance(registered_task, asyncio.Task) + assert registered_task() == 10 + + +async def test_register_lifespan_task_rejects_kwargs_for_asyncio_task(): + """Registering kwargs against an asyncio.Task raises a clear error.""" + mixin = LifespanMixin() + task = asyncio.create_task(asyncio.sleep(0), name="scheduled-lifespan-task") + + try: + with pytest.raises( + InvalidLifespanTaskTypeError, + match=r"of type asyncio\.Task cannot be registered with kwargs", + ): + mixin.register_lifespan_task(task, timeout=10) + finally: + task.cancel() + with contextlib.suppress(asyncio.CancelledError): + await task