From 04f31d56054daca7ab4be6735be50023c91eff28 Mon Sep 17 00:00:00 2001 From: Benjamin Barrera-Altuna Date: Sun, 19 Apr 2026 18:16:15 -0400 Subject: [PATCH 1/2] Support register_lifespan_task decorator usage --- reflex/app_mixins/lifespan.py | 30 ++++++++++++--- tests/units/app_mixins/test_lifespan.py | 50 +++++++++++++++++++++++++ 2 files changed, 75 insertions(+), 5 deletions(-) create mode 100644 tests/units/app_mixins/test_lifespan.py diff --git a/reflex/app_mixins/lifespan.py b/reflex/app_mixins/lifespan.py index a62195469c2..0bd9cb17642 100644 --- a/reflex/app_mixins/lifespan.py +++ b/reflex/app_mixins/lifespan.py @@ -137,17 +137,32 @@ async def _run_lifespan_tasks(self, app: Starlette): else: await state_manager.close() - def register_lifespan_task(self, task: Callable | asyncio.Task, **task_kwargs): + def register_lifespan_task( + self, + task: Callable | asyncio.Task | None = None, + **task_kwargs, + ) -> ( + Callable + | asyncio.Task + | Callable[[Callable | asyncio.Task], Callable | asyncio.Task] + ): """Register a task to run during the lifespan of the app. Args: task: The task to register. **task_kwargs: The kwargs of the task. + Returns: + The original task when called directly, or a decorator when called + with kwargs only. + Raises: InvalidLifespanTaskTypeError: If the task is a generator function. RuntimeError: If lifespan tasks are already running. """ + if task is None: + return functools.partial(self.register_lifespan_task, **task_kwargs) + if self._lifespan_tasks_started: msg = ( f"Cannot register lifespan task {_get_task_name(task)!r} after " @@ -159,9 +174,14 @@ def register_lifespan_task(self, task: Callable | asyncio.Task, **task_kwargs): raise InvalidLifespanTaskTypeError(msg) task_name = _get_task_name(task) + registered_task = task if task_kwargs: - original_task = task - task = functools.partial(task, **task_kwargs) # pyright: ignore [reportArgumentType] - functools.update_wrapper(task, original_task) # pyright: ignore [reportArgumentType] - self._lifespan_tasks[task] = None + registered_task = functools.partial( + task, **task_kwargs + ) # pyright: ignore [reportArgumentType] + functools.update_wrapper( + registered_task, task + ) # pyright: ignore [reportArgumentType] + self._lifespan_tasks[registered_task] = None console.debug(f"Registered lifespan task: {task_name}") + return task diff --git a/tests/units/app_mixins/test_lifespan.py b/tests/units/app_mixins/test_lifespan.py new file mode 100644 index 00000000000..1be2c8bb0ad --- /dev/null +++ b/tests/units/app_mixins/test_lifespan.py @@ -0,0 +1,50 @@ +"""Unit tests for lifespan app mixin behavior.""" + +from __future__ import annotations + +import functools + +from reflex.app_mixins.lifespan import LifespanMixin + + +def test_register_lifespan_task_can_be_used_as_decorator(): + """Decorating a task registers it and preserves the task callable.""" + mixin = LifespanMixin() + + @mixin.register_lifespan_task + def polling_task() -> str: + """Return a sentinel value for direct-call verification. + + Returns: + A sentinel string. + """ + return "ok" + + assert polling_task() == "ok" + assert polling_task in mixin.get_lifespan_tasks() + + +def test_register_lifespan_task_with_kwargs_can_be_used_as_decorator(): + """Decorator-with-kwargs preserves function binding and registers partial.""" + mixin = LifespanMixin() + + @mixin.register_lifespan_task(timeout=10) + def check_for_updates(timeout: int) -> int: + """Echo timeout to verify direct function access is preserved. + + Args: + timeout: Timeout value in seconds. + + Returns: + The timeout value passed to the function. + """ + return timeout + + assert check_for_updates(timeout=4) == 4 + + registered_tasks = mixin.get_lifespan_tasks() + assert len(registered_tasks) == 1 + registered_task = registered_tasks[0] + assert isinstance(registered_task, functools.partial) + assert registered_task.func is check_for_updates + assert registered_task.keywords == {"timeout": 10} From 8b97bdd9d286a180f3058bdbb3e00ba14807058f Mon Sep 17 00:00:00 2001 From: Farhan Ali Raza Date: Tue, 21 Apr 2026 01:37:44 +0500 Subject: [PATCH 2/2] refactor(lifespan): tighten register_lifespan_task typing and guard Replace the compound return-type union with @overload stubs + a TypeVar so decorated functions keep their exact signature. Reject asyncio.Task + kwargs early with a clear error (previously would silently build a broken partial), which also eliminates the two pyright ignores. Swap the test's implementation-detail asserts (functools.partial internals) for a behavioral call-the-task check. --- reflex/app_mixins/lifespan.py | 40 +++++++++++++------- tests/units/app_mixins/test_lifespan.py | 49 ++++++++++++++----------- 2 files changed, 53 insertions(+), 36 deletions(-) diff --git a/reflex/app_mixins/lifespan.py b/reflex/app_mixins/lifespan.py index 0bd9cb17642..bac00517bb0 100644 --- a/reflex/app_mixins/lifespan.py +++ b/reflex/app_mixins/lifespan.py @@ -9,7 +9,7 @@ import inspect import time from collections.abc import Callable, Coroutine -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, TypeVar, overload from reflex_base.utils import console from reflex_base.utils.exceptions import InvalidLifespanTaskTypeError @@ -20,6 +20,8 @@ if TYPE_CHECKING: from typing_extensions import deprecated +_LifespanTaskT = TypeVar("_LifespanTaskT", bound="Callable | asyncio.Task") + def _get_task_name(task: asyncio.Task | Callable) -> str: """Get a display name for a lifespan task. @@ -137,24 +139,35 @@ async def _run_lifespan_tasks(self, app: Starlette): else: await state_manager.close() + @overload + def register_lifespan_task( + self, task: _LifespanTaskT, **task_kwargs + ) -> _LifespanTaskT: ... + + @overload + def register_lifespan_task( + self, task: None = None, **task_kwargs + ) -> Callable[[_LifespanTaskT], _LifespanTaskT]: ... + def register_lifespan_task( self, task: Callable | asyncio.Task | None = None, **task_kwargs, - ) -> ( - Callable - | asyncio.Task - | Callable[[Callable | asyncio.Task], Callable | asyncio.Task] ): """Register a task to run during the lifespan of the app. + Supports three call shapes: + - `app.register_lifespan_task(fn, **kwargs)` — direct call. + - `@app.register_lifespan_task` — bare decorator. + - `@app.register_lifespan_task(**kwargs)` — parameterized decorator. + Args: - task: The task to register. + task: The task to register, or None to return a decorator. **task_kwargs: The kwargs of the task. Returns: - The original task when called directly, or a decorator when called - with kwargs only. + The original task when called with a task, or a decorator when + called without one. Raises: InvalidLifespanTaskTypeError: If the task is a generator function. @@ -176,12 +189,11 @@ def register_lifespan_task( task_name = _get_task_name(task) registered_task = task if task_kwargs: - registered_task = functools.partial( - task, **task_kwargs - ) # pyright: ignore [reportArgumentType] - functools.update_wrapper( - registered_task, task - ) # pyright: ignore [reportArgumentType] + if isinstance(task, asyncio.Task): + msg = f"Task {task_name!r} of type asyncio.Task cannot be registered with kwargs." + raise InvalidLifespanTaskTypeError(msg) + registered_task = functools.partial(task, **task_kwargs) + functools.update_wrapper(registered_task, task) self._lifespan_tasks[registered_task] = None console.debug(f"Registered lifespan task: {task_name}") return task diff --git a/tests/units/app_mixins/test_lifespan.py b/tests/units/app_mixins/test_lifespan.py index 1be2c8bb0ad..d1d2f38bdd4 100644 --- a/tests/units/app_mixins/test_lifespan.py +++ b/tests/units/app_mixins/test_lifespan.py @@ -2,22 +2,21 @@ from __future__ import annotations -import functools +import asyncio +import contextlib + +import pytest +from reflex_base.utils.exceptions import InvalidLifespanTaskTypeError from reflex.app_mixins.lifespan import LifespanMixin def test_register_lifespan_task_can_be_used_as_decorator(): - """Decorating a task registers it and preserves the task callable.""" + """Bare decorator registers the task and preserves the name binding.""" mixin = LifespanMixin() @mixin.register_lifespan_task def polling_task() -> str: - """Return a sentinel value for direct-call verification. - - Returns: - A sentinel string. - """ return "ok" assert polling_task() == "ok" @@ -25,26 +24,32 @@ def polling_task() -> str: def test_register_lifespan_task_with_kwargs_can_be_used_as_decorator(): - """Decorator-with-kwargs preserves function binding and registers partial.""" + """Decorator-with-kwargs registers a partial that applies the kwargs.""" mixin = LifespanMixin() @mixin.register_lifespan_task(timeout=10) def check_for_updates(timeout: int) -> int: - """Echo timeout to verify direct function access is preserved. - - Args: - timeout: Timeout value in seconds. - - Returns: - The timeout value passed to the function. - """ return timeout assert check_for_updates(timeout=4) == 4 - registered_tasks = mixin.get_lifespan_tasks() - assert len(registered_tasks) == 1 - registered_task = registered_tasks[0] - assert isinstance(registered_task, functools.partial) - assert registered_task.func is check_for_updates - assert registered_task.keywords == {"timeout": 10} + (registered_task,) = mixin.get_lifespan_tasks() + assert not isinstance(registered_task, asyncio.Task) + assert registered_task() == 10 + + +async def test_register_lifespan_task_rejects_kwargs_for_asyncio_task(): + """Registering kwargs against an asyncio.Task raises a clear error.""" + mixin = LifespanMixin() + task = asyncio.create_task(asyncio.sleep(0), name="scheduled-lifespan-task") + + try: + with pytest.raises( + InvalidLifespanTaskTypeError, + match=r"of type asyncio\.Task cannot be registered with kwargs", + ): + mixin.register_lifespan_task(task, timeout=10) + finally: + task.cancel() + with contextlib.suppress(asyncio.CancelledError): + await task