Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 39 additions & 7 deletions reflex/app_mixins/lifespan.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import inspect
import time
from collections.abc import Callable, Coroutine
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, TypeVar, overload

from reflex_base.utils import console
from reflex_base.utils.exceptions import InvalidLifespanTaskTypeError
Expand All @@ -20,6 +20,8 @@
if TYPE_CHECKING:
from typing_extensions import deprecated

_LifespanTaskT = TypeVar("_LifespanTaskT", bound="Callable | asyncio.Task")


def _get_task_name(task: asyncio.Task | Callable) -> str:
"""Get a display name for a lifespan task.
Expand Down Expand Up @@ -137,17 +139,43 @@ async def _run_lifespan_tasks(self, app: Starlette):
else:
await state_manager.close()

def register_lifespan_task(self, task: Callable | asyncio.Task, **task_kwargs):
@overload
def register_lifespan_task(
self, task: _LifespanTaskT, **task_kwargs
) -> _LifespanTaskT: ...

@overload
def register_lifespan_task(
self, task: None = None, **task_kwargs
) -> Callable[[_LifespanTaskT], _LifespanTaskT]: ...

def register_lifespan_task(
self,
task: Callable | asyncio.Task | None = None,
**task_kwargs,
):
"""Register a task to run during the lifespan of the app.

Supports three call shapes:
- `app.register_lifespan_task(fn, **kwargs)` — direct call.
- `@app.register_lifespan_task` — bare decorator.
- `@app.register_lifespan_task(**kwargs)` — parameterized decorator.

Args:
task: The task to register.
task: The task to register, or None to return a decorator.
**task_kwargs: The kwargs of the task.

Returns:
The original task when called with a task, or a decorator when
called without one.

Raises:
InvalidLifespanTaskTypeError: If the task is a generator function.
RuntimeError: If lifespan tasks are already running.
"""
if task is None:
return functools.partial(self.register_lifespan_task, **task_kwargs)

if self._lifespan_tasks_started:
msg = (
f"Cannot register lifespan task {_get_task_name(task)!r} after "
Expand All @@ -159,9 +187,13 @@ def register_lifespan_task(self, task: Callable | asyncio.Task, **task_kwargs):
raise InvalidLifespanTaskTypeError(msg)

task_name = _get_task_name(task)
registered_task = task
if task_kwargs:
original_task = task
task = functools.partial(task, **task_kwargs) # pyright: ignore [reportArgumentType]
functools.update_wrapper(task, original_task) # pyright: ignore [reportArgumentType]
self._lifespan_tasks[task] = None
if isinstance(task, asyncio.Task):
msg = f"Task {task_name!r} of type asyncio.Task cannot be registered with kwargs."
raise InvalidLifespanTaskTypeError(msg)
registered_task = functools.partial(task, **task_kwargs)
functools.update_wrapper(registered_task, task)
self._lifespan_tasks[registered_task] = None
console.debug(f"Registered lifespan task: {task_name}")
return task
55 changes: 55 additions & 0 deletions tests/units/app_mixins/test_lifespan.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
"""Unit tests for lifespan app mixin behavior."""

from __future__ import annotations

import asyncio
import contextlib

import pytest
from reflex_base.utils.exceptions import InvalidLifespanTaskTypeError

from reflex.app_mixins.lifespan import LifespanMixin


def test_register_lifespan_task_can_be_used_as_decorator():
"""Bare decorator registers the task and preserves the name binding."""
mixin = LifespanMixin()

@mixin.register_lifespan_task
def polling_task() -> str:
return "ok"

assert polling_task() == "ok"
assert polling_task in mixin.get_lifespan_tasks()


def test_register_lifespan_task_with_kwargs_can_be_used_as_decorator():
"""Decorator-with-kwargs registers a partial that applies the kwargs."""
mixin = LifespanMixin()

@mixin.register_lifespan_task(timeout=10)
def check_for_updates(timeout: int) -> int:
return timeout

assert check_for_updates(timeout=4) == 4

(registered_task,) = mixin.get_lifespan_tasks()
assert not isinstance(registered_task, asyncio.Task)
assert registered_task() == 10


async def test_register_lifespan_task_rejects_kwargs_for_asyncio_task():
"""Registering kwargs against an asyncio.Task raises a clear error."""
mixin = LifespanMixin()
task = asyncio.create_task(asyncio.sleep(0), name="scheduled-lifespan-task")

try:
with pytest.raises(
InvalidLifespanTaskTypeError,
match=r"of type asyncio\.Task cannot be registered with kwargs",
):
mixin.register_lifespan_task(task, timeout=10)
finally:
task.cancel()
with contextlib.suppress(asyncio.CancelledError):
await task
Loading