From 36910b24ab5004c0188df883585eac04b613497c Mon Sep 17 00:00:00 2001 From: Andy Staples Date: Tue, 10 Mar 2026 12:09:46 -0600 Subject: [PATCH 1/5] Add cancellable tasks --- durabletask/task.py | 52 +++++++++++- durabletask/worker.py | 42 ++++++++-- .../test_orchestration_executor.py | 83 ++++++++++++++++++- 3 files changed, 167 insertions(+), 10 deletions(-) diff --git a/durabletask/task.py b/durabletask/task.py index 0ef03da..6dfe789 100644 --- a/durabletask/task.py +++ b/durabletask/task.py @@ -98,7 +98,7 @@ def set_custom_status(self, custom_status: Any) -> None: pass @abstractmethod - def create_timer(self, fire_at: Union[datetime, timedelta]) -> Task: + def create_timer(self, fire_at: Union[datetime, timedelta]) -> CancellableTask: """Create a Timer Task to fire after at the specified deadline. Parameters @@ -231,7 +231,7 @@ def call_sub_orchestrator(self, orchestrator: Union[Orchestrator[TInput, TOutput # TOOD: Add a timeout parameter, which allows the task to be canceled if the event is # not received within the specified timeout. This requires support for task cancellation. @abstractmethod - def wait_for_external_event(self, name: str) -> CompletableTask: + def wait_for_external_event(self, name: str) -> CancellableTask: """Wait asynchronously for an event to be raised with the name `name`. 
Parameters @@ -324,6 +324,10 @@ class OrchestrationStateError(Exception): pass +class TaskCanceledError(Exception): + """Exception type for canceled orchestration tasks.""" + + class Task(ABC, Generic[T]): """Abstract base class for asynchronous tasks in a durable orchestration.""" _result: T @@ -435,6 +439,48 @@ def fail(self, message: str, details: Union[Exception, pb.TaskFailureDetails]): self._parent.on_child_completed(self) +class CancellableTask(CompletableTask[T]): + """A completable task that can be canceled before it finishes.""" + + def __init__(self) -> None: + super().__init__() + self._is_cancelled = False + self._cancel_handler: Optional[Callable[[], None]] = None + + @property + def is_cancelled(self) -> bool: + """Returns True if the task was canceled, False otherwise.""" + return self._is_cancelled + + def get_result(self) -> T: + if self._is_cancelled: + raise TaskCanceledError('The task was canceled.') + return super().get_result() + + def set_cancel_handler(self, cancel_handler: Callable[[], None]) -> None: + self._cancel_handler = cancel_handler + + def cancel(self) -> bool: + """Attempts to cancel this task. + + Returns + ------- + bool + True if cancellation was applied, False if the task had already completed. 
+ """ + if self._is_complete: + return False + + if self._cancel_handler is not None: + self._cancel_handler() + + self._is_cancelled = True + self._is_complete = True + if self._parent is not None: + self._parent.on_child_completed(self) + return True + + class RetryableTask(CompletableTask[T]): """A task that can be retried according to a retry policy.""" @@ -474,7 +520,7 @@ def compute_next_delay(self) -> Optional[timedelta]: return None -class TimerTask(CompletableTask[T]): +class TimerTask(CancellableTask[T]): def __init__(self) -> None: super().__init__() diff --git a/durabletask/worker.py b/durabletask/worker.py index 442165d..9edf2ac 100644 --- a/durabletask/worker.py +++ b/durabletask/worker.py @@ -307,7 +307,7 @@ class TaskHubGrpcWorker: activity function. """ - _response_stream: Optional[grpc.Future] = None + _response_stream: Optional[Any] = None _interceptors: Optional[list[shared.ClientInterceptor]] = None def __init__( @@ -512,7 +512,11 @@ def should_invalidate_connection(rpc_error): def stream_reader(): try: - for work_item in self._response_stream: + response_stream = self._response_stream + if response_stream is None: + return + + for work_item in response_stream: work_item_queue.put(work_item) except Exception as e: work_item_queue.put(e) @@ -843,7 +847,7 @@ def __init__(self, instance_id: str, registry: _Registry): self._version: Optional[str] = None self._completion_status: Optional[pb.OrchestrationStatus] = None self._received_events: dict[str, list[Any]] = {} - self._pending_events: dict[str, list[task.CompletableTask]] = {} + self._pending_events: dict[str, list[task.CancellableTask]] = {} self._new_input: Optional[Any] = None self._save_events = False self._encoded_custom_status: Optional[str] = None @@ -1026,7 +1030,13 @@ def create_timer_internal( action = ph.new_create_timer_action(id, fire_at) self._pending_actions[id] = action - timer_task: task.TimerTask = task.TimerTask() + timer_task = task.TimerTask() + + def _cancel_timer() -> 
None: + self._pending_actions.pop(id, None) + self._pending_tasks.pop(id, None) + + timer_task.set_cancel_handler(_cancel_timer) if retryable_task is not None: timer_task.set_retryable_parent(retryable_task) self._pending_tasks[id] = timer_task @@ -1234,13 +1244,13 @@ def _exit_critical_section(self) -> None: action = pb.OrchestratorAction(id=task_id, sendEntityMessage=entity_unlock_message) self._pending_actions[task_id] = action - def wait_for_external_event(self, name: str) -> task.CompletableTask: + def wait_for_external_event(self, name: str) -> task.CancellableTask: # Check to see if this event has already been received, in which case we # can return it immediately. Otherwise, record out intent to receive an # event with the given name so that we can resume the generator when it # arrives. If there are multiple events with the same name, we return # them in the order they were received. - external_event_task: task.CompletableTask = task.CompletableTask() + external_event_task: task.CancellableTask = task.CancellableTask() event_name = name.casefold() event_list = self._received_events.get(event_name, None) if event_list: @@ -1254,6 +1264,19 @@ def wait_for_external_event(self, name: str) -> task.CompletableTask: task_list = [] self._pending_events[event_name] = task_list task_list.append(external_event_task) + + def _cancel_wait() -> None: + waiting_tasks = self._pending_events.get(event_name) + if waiting_tasks is None: + return + try: + waiting_tasks.remove(external_event_task) + except ValueError: + return + if not waiting_tasks: + del self._pending_events[event_name] + + external_event_task.set_cancel_handler(_cancel_wait) return external_event_task def continue_as_new(self, new_input, *, save_events: bool = False) -> None: @@ -1450,6 +1473,13 @@ def process_event( f"{ctx.instance_id}: Ignoring unexpected timerFired event with ID = {timer_id}." 
) return + if not isinstance(timer_task, task.TimerTask): + if not ctx._is_replaying: + self._logger.warning( + f"{ctx.instance_id}: Ignoring timerFired event with non-timer task ID = {timer_id}." + ) + return + timer_task.complete(None) if timer_task._retryable_parent is not None: activity_action = timer_task._retryable_parent._action diff --git a/tests/durabletask/test_orchestration_executor.py b/tests/durabletask/test_orchestration_executor.py index 14d5e14..6300d21 100644 --- a/tests/durabletask/test_orchestration_executor.py +++ b/tests/durabletask/test_orchestration_executor.py @@ -143,6 +143,87 @@ def delay_orchestrator(ctx: task.OrchestrationContext, _): assert complete_action.result.value == '"done"' # results are JSON-encoded +def test_timer_can_be_cancelled_after_when_any_winner(): + """Tests cancellation of an outstanding timer task after another task wins when_any.""" + + def orchestrator(ctx: task.OrchestrationContext, _): + approval = ctx.wait_for_external_event("approval") + timeout = ctx.create_timer(timedelta(hours=1)) + winner = yield task.when_any([approval, timeout]) + if winner == approval: + timeout.cancel() + return "approved" + return "timed out" + + registry = worker._Registry() + name = registry.add_orchestrator(orchestrator) + executor = worker._OrchestrationExecutor(registry, TEST_LOGGER) + + start_time = datetime(2020, 1, 1, 12, 0, 0) + timeout_fire_at = start_time + timedelta(hours=1) + + result = executor.execute( + TEST_INSTANCE_ID, + [], + [ + helpers.new_orchestrator_started_event(start_time), + helpers.new_execution_started_event(name, TEST_INSTANCE_ID, encoded_input=None), + ], + ) + assert len(result.actions) == 1 + assert result.actions[0].HasField("createTimer") + assert result.actions[0].createTimer.fireAt.ToDatetime() == timeout_fire_at + + old_events = [ + helpers.new_orchestrator_started_event(start_time), + helpers.new_execution_started_event(name, TEST_INSTANCE_ID, encoded_input=None), + 
helpers.new_timer_created_event(1, timeout_fire_at), + ] + result = executor.execute( + TEST_INSTANCE_ID, + old_events, + [helpers.new_event_raised_event("approval", json.dumps(True))], + ) + complete_action = get_and_validate_complete_orchestration_action_list(1, result.actions) + assert complete_action.orchestrationStatus == pb.ORCHESTRATION_STATUS_COMPLETED + assert complete_action.result.value == '"approved"' + + +def test_only_cancellable_tasks_expose_cancel(): + """Tests that only timer and external-event tasks expose cancellation state and operations.""" + + def dummy_activity(ctx, _): + pass + + ctx = worker._RuntimeOrchestrationContext(TEST_INSTANCE_ID, worker._Registry()) + + timer_task = ctx.create_timer(timedelta(minutes=5)) + external_event_task = ctx.wait_for_external_event("approval") + activity_task = ctx.call_activity(dummy_activity) + + assert isinstance(timer_task, task.CancellableTask) + assert isinstance(external_event_task, task.CancellableTask) + assert not isinstance(activity_task, task.CancellableTask) + assert hasattr(timer_task, "cancel") + assert hasattr(external_event_task, "cancel") + assert not hasattr(activity_task, "cancel") + assert hasattr(timer_task, "is_cancelled") + assert hasattr(external_event_task, "is_cancelled") + assert not hasattr(activity_task, "is_cancelled") + + +def test_cancelled_task_get_result_raises_task_canceled_error(): + """Tests that canceled cancellable tasks raise TaskCanceledError from get_result.""" + + cancellable_task = task.CancellableTask() + + assert cancellable_task.cancel() is True + assert cancellable_task.is_cancelled is True + + with pytest.raises(task.TaskCanceledError): + cancellable_task.get_result() + + def test_schedule_activity_actions(): """Test the actions output for the call_activity orchestrator method""" def dummy_activity(ctx, _): @@ -1313,7 +1394,7 @@ def orchestrator(ctx: task.OrchestrationContext, _): encoded_output = json.dumps(dummy_activity(None, "Seattle")) old_events = 
old_events + new_events new_events = [helpers.new_task_completed_event(2, encoded_output), - helpers.new_timer_fired_event(4, current_timestamp)] + helpers.new_timer_fired_event(4, expected_fire_at)] executor = worker._OrchestrationExecutor(registry, TEST_LOGGER) result = executor.execute(TEST_INSTANCE_ID, old_events, new_events) actions = result.actions From 57d8b7df1b63dbe0ce115eded4aa4c9de55730e5 Mon Sep 17 00:00:00 2001 From: Andy Staples Date: Tue, 10 Mar 2026 12:39:04 -0600 Subject: [PATCH 2/5] Implement long timer support --- durabletask/task.py | 27 +- durabletask/worker.py | 96 +++++-- .../test_orchestration_executor.py | 268 ++++++++++++++++++ 3 files changed, 360 insertions(+), 31 deletions(-) diff --git a/durabletask/task.py b/durabletask/task.py index 6dfe789..e44de2e 100644 --- a/durabletask/task.py +++ b/durabletask/task.py @@ -520,13 +520,32 @@ def compute_next_delay(self) -> Optional[timedelta]: return None -class TimerTask(CancellableTask[T]): +class TimerTask(CancellableTask[None]): + def set_retryable_parent(self, retryable_task: RetryableTask): + self._retryable_parent = retryable_task - def __init__(self) -> None: + def complete(self, *args, **kwargs): + super().complete(None) + + +class LongTimerTask(TimerTask): + def __init__(self, final_fire_at: datetime, maximum_timer_duration: timedelta): super().__init__() + self._final_fire_at = final_fire_at + self._maximum_timer_duration = maximum_timer_duration - def set_retryable_parent(self, retryable_task: RetryableTask): - self._retryable_parent = retryable_task + def start(self, current_utc_datetime: datetime) -> datetime: + return self._get_next_fire_at(current_utc_datetime) + + def complete(self, current_utc_datetime: datetime): + if current_utc_datetime < self._final_fire_at: + return self._get_next_fire_at(current_utc_datetime) + super().complete(None) + + def _get_next_fire_at(self, current_utc_datetime: datetime) -> datetime: + if current_utc_datetime + self._maximum_timer_duration < 
self._final_fire_at: + return current_utc_datetime + self._maximum_timer_duration + return self._final_fire_at class WhenAnyTask(CompositeTask[Task]): diff --git a/durabletask/worker.py b/durabletask/worker.py index 9edf2ac..f473b38 100644 --- a/durabletask/worker.py +++ b/durabletask/worker.py @@ -38,6 +38,7 @@ TInput = TypeVar("TInput") TOutput = TypeVar("TOutput") DATETIME_STRING_FORMAT = '%Y-%m-%dT%H:%M:%S.%fZ' +DEFAULT_MAXIMUM_TIMER_INTERVAL = timedelta(days=3) class ConcurrencyOptions: @@ -320,6 +321,7 @@ def __init__( secure_channel: bool = False, interceptors: Optional[Sequence[shared.ClientInterceptor]] = None, concurrency_options: Optional[ConcurrencyOptions] = None, + maximum_timer_interval: Optional[timedelta] = DEFAULT_MAXIMUM_TIMER_INTERVAL ): self._registry = _Registry() self._host_address = ( @@ -348,12 +350,18 @@ def __init__( self._interceptors = None self._async_worker_manager = _AsyncWorkerManager(self._concurrency_options, self._logger) + self._maximum_timer_interval = maximum_timer_interval @property def concurrency_options(self) -> ConcurrencyOptions: """Get the current concurrency options for this worker.""" return self._concurrency_options + @property + def maximum_timer_interval(self) -> Optional[timedelta]: + """Get the configured maximum timer interval for long timer chunking.""" + return self._maximum_timer_interval + def __enter__(self): return self @@ -826,7 +834,11 @@ class _RuntimeOrchestrationContext(task.OrchestrationContext): _generator: Optional[Generator[task.Task, Any, Any]] _previous_task: Optional[task.Task] - def __init__(self, instance_id: str, registry: _Registry): + def __init__(self, + instance_id: str, + registry: _Registry, + maximum_timer_interval: Optional[timedelta] = DEFAULT_MAXIMUM_TIMER_INTERVAL, + ): self._generator = None self._is_replaying = True self._is_complete = False @@ -851,6 +863,7 @@ def __init__(self, instance_id: str, registry: _Registry): self._new_input: Optional[Any] = None self._save_events = 
False self._encoded_custom_status: Optional[str] = None + self._maximum_timer_interval = maximum_timer_interval def run(self, generator: Generator[task.Task, Any, Any]): self._generator = generator @@ -1026,11 +1039,20 @@ def create_timer_internal( ) -> task.TimerTask: id = self.next_sequence_number() if isinstance(fire_at, timedelta): - fire_at = self.current_utc_datetime + fire_at - action = ph.new_create_timer_action(id, fire_at) - self._pending_actions[id] = action + final_fire_at = self.current_utc_datetime + fire_at + else: + final_fire_at = fire_at - timer_task = task.TimerTask() + next_fire_at: datetime = final_fire_at + + if self._maximum_timer_interval is not None and self.current_utc_datetime + self._maximum_timer_interval < final_fire_at: + timer_task = task.LongTimerTask(final_fire_at, self._maximum_timer_interval) + next_fire_at = timer_task.start(self.current_utc_datetime) + else: + timer_task = task.TimerTask() + + action = ph.new_create_timer_action(id, next_fire_at) + self._pending_actions[id] = action def _cancel_timer() -> None: self._pending_actions.pop(id, None) @@ -1311,9 +1333,13 @@ def __init__( class _OrchestrationExecutor: _generator: Optional[task.Orchestrator] = None - def __init__(self, registry: _Registry, logger: logging.Logger): + def __init__(self, + registry: _Registry, + logger: logging.Logger, + maximum_timer_interval: Optional[timedelta] = DEFAULT_MAXIMUM_TIMER_INTERVAL): self._registry = registry self._logger = logger + self._maximum_timer_interval = maximum_timer_interval self._is_suspended = False self._suspended_events: list[pb.HistoryEvent] = [] @@ -1337,7 +1363,11 @@ def execute( "The new history event list must have at least one event in it." 
) - ctx = _RuntimeOrchestrationContext(instance_id, self._registry) + ctx = _RuntimeOrchestrationContext( + instance_id, + self._registry, + maximum_timer_interval=self._maximum_timer_interval, + ) try: # Rebuild local state by replaying old history into the orchestrator function self._logger.debug( @@ -1473,34 +1503,46 @@ def process_event( f"{ctx.instance_id}: Ignoring unexpected timerFired event with ID = {timer_id}." ) return - if not isinstance(timer_task, task.TimerTask): + if not (isinstance(timer_task, task.TimerTask) or isinstance(timer_task, task.LongTimerTask)): if not ctx._is_replaying: self._logger.warning( f"{ctx.instance_id}: Ignoring timerFired event with non-timer task ID = {timer_id}." ) return - timer_task.complete(None) - if timer_task._retryable_parent is not None: - activity_action = timer_task._retryable_parent._action + next_fire_at = timer_task.complete(event.timerFired.fireAt.ToDatetime()) + if next_fire_at is not None: + id = ctx.next_sequence_number() + new_action = ph.new_create_timer_action(id, next_fire_at) + ctx._pending_tasks[id] = timer_task + ctx._pending_actions[id] = new_action - if not timer_task._retryable_parent._is_sub_orch: - cur_task = activity_action.scheduleTask - instance_id = None - else: - cur_task = activity_action.createSubOrchestration - instance_id = cur_task.instanceId - ctx.call_activity_function_helper( - id=activity_action.id, - activity_function=cur_task.name, - input=cur_task.input.value, - retry_policy=timer_task._retryable_parent._retry_policy, - is_sub_orch=timer_task._retryable_parent._is_sub_orch, - instance_id=instance_id, - fn_task=timer_task._retryable_parent, - ) + def _cancel_timer() -> None: + ctx._pending_actions.pop(id, None) + ctx._pending_tasks.pop(id, None) + + timer_task.set_cancel_handler(_cancel_timer) else: - ctx.resume() + if timer_task._retryable_parent is not None: + activity_action = timer_task._retryable_parent._action + + if not timer_task._retryable_parent._is_sub_orch: + cur_task 
= activity_action.scheduleTask + instance_id = None + else: + cur_task = activity_action.createSubOrchestration + instance_id = cur_task.instanceId + ctx.call_activity_function_helper( + id=activity_action.id, + activity_function=cur_task.name, + input=cur_task.input.value, + retry_policy=timer_task._retryable_parent._retry_policy, + is_sub_orch=timer_task._retryable_parent._is_sub_orch, + instance_id=instance_id, + fn_task=timer_task._retryable_parent, + ) + else: + ctx.resume() elif event.HasField("taskScheduled"): # This history event confirms that the activity execution was successfully scheduled. # Remove the taskScheduled event from the pending action list so we don't schedule it again. diff --git a/tests/durabletask/test_orchestration_executor.py b/tests/durabletask/test_orchestration_executor.py index 6300d21..3e2cb02 100644 --- a/tests/durabletask/test_orchestration_executor.py +++ b/tests/durabletask/test_orchestration_executor.py @@ -143,6 +143,174 @@ def delay_orchestrator(ctx: task.OrchestrationContext, _): assert complete_action.result.value == '"done"' # results are JSON-encoded +def test_long_timer_is_chunked_by_maximum_timer_interval(): + """Tests that long timers are scheduled in chunks when exceeding max timer interval.""" + + def orchestrator(ctx: task.OrchestrationContext, _): + due_time = ctx.current_utc_datetime + timedelta(days=10) + yield ctx.create_timer(due_time) + return "done" + + registry = worker._Registry() + name = registry.add_orchestrator(orchestrator) + + start_time = datetime(2020, 1, 1, 12, 0, 0) + first_chunk_fire_at = start_time + timedelta(days=3) + + new_events = [ + helpers.new_orchestrator_started_event(start_time), + helpers.new_execution_started_event(name, TEST_INSTANCE_ID, encoded_input=None)] + executor = worker._OrchestrationExecutor(registry, TEST_LOGGER) + result = executor.execute(TEST_INSTANCE_ID, [], new_events) + actions = result.actions + + assert len(actions) == 1 + assert actions[0].HasField("createTimer") 
+ assert actions[0].id == 1 + assert actions[0].createTimer.fireAt.ToDatetime() == first_chunk_fire_at + + +def test_long_timer_progresses_and_completes_on_final_chunk(): + """Tests that long timers schedule intermediate chunks and complete on the final timerFired.""" + + def orchestrator(ctx: task.OrchestrationContext, _): + due_time = ctx.current_utc_datetime + timedelta(days=10) + yield ctx.create_timer(due_time) + return "done" + + registry = worker._Registry() + name = registry.add_orchestrator(orchestrator) + executor = worker._OrchestrationExecutor(registry, TEST_LOGGER) + + start_time = datetime(2020, 1, 1, 12, 0, 0) + t1 = start_time + timedelta(days=3) + t2 = start_time + timedelta(days=6) + t3 = start_time + timedelta(days=9) + t4 = start_time + timedelta(days=10) + + # 1) Initial execution schedules first chunk. + first = executor.execute( + TEST_INSTANCE_ID, + [], + [ + helpers.new_orchestrator_started_event(start_time), + helpers.new_execution_started_event(name, TEST_INSTANCE_ID, encoded_input=None), + ], + ) + assert len(first.actions) == 1 + assert first.actions[0].HasField("createTimer") + assert first.actions[0].id == 1 + assert first.actions[0].createTimer.fireAt.ToDatetime() == t1 + + # 2) First chunk fires -> schedule second chunk. + second_old_events = [ + helpers.new_orchestrator_started_event(start_time), + helpers.new_execution_started_event(name, TEST_INSTANCE_ID, encoded_input=None), + helpers.new_timer_created_event(1, t1), + ] + second = executor.execute( + TEST_INSTANCE_ID, + second_old_events, + [helpers.new_timer_fired_event(1, t1)], + ) + assert len(second.actions) == 1 + assert second.actions[0].HasField("createTimer") + assert second.actions[0].id == 2 + assert second.actions[0].createTimer.fireAt.ToDatetime() == t2 + + # 3) Second chunk fires -> schedule third chunk. 
+ third_old_events = second_old_events + [ + helpers.new_timer_fired_event(1, t1), + helpers.new_timer_created_event(2, t2), + ] + third = executor.execute( + TEST_INSTANCE_ID, + third_old_events, + [helpers.new_timer_fired_event(2, t2)], + ) + assert len(third.actions) == 1 + assert third.actions[0].HasField("createTimer") + assert third.actions[0].id == 3 + assert third.actions[0].createTimer.fireAt.ToDatetime() == t3 + + # 4) Third chunk fires -> schedule final short chunk. + fourth_old_events = third_old_events + [ + helpers.new_timer_fired_event(2, t2), + helpers.new_timer_created_event(3, t3), + ] + fourth = executor.execute( + TEST_INSTANCE_ID, + fourth_old_events, + [helpers.new_timer_fired_event(3, t3)], + ) + assert len(fourth.actions) == 1 + assert fourth.actions[0].HasField("createTimer") + assert fourth.actions[0].id == 4 + assert fourth.actions[0].createTimer.fireAt.ToDatetime() == t4 + + # 5) Final chunk fires -> orchestration completes. + fifth_old_events = fourth_old_events + [ + helpers.new_timer_fired_event(3, t3), + helpers.new_timer_created_event(4, t4), + ] + fifth = executor.execute( + TEST_INSTANCE_ID, + fifth_old_events, + [helpers.new_timer_fired_event(4, t4)], + ) + complete_action = get_and_validate_complete_orchestration_action_list(1, fifth.actions) + assert complete_action.orchestrationStatus == pb.ORCHESTRATION_STATUS_COMPLETED + assert complete_action.result.value == '"done"' + + +def test_long_timer_can_be_cancelled_after_when_any_winner(): + """Tests cancellation of a long timer after an external event wins when_any.""" + + def orchestrator(ctx: task.OrchestrationContext, _): + approval = ctx.wait_for_external_event("approval") + timeout = ctx.create_timer(timedelta(days=10)) + winner = yield task.when_any([approval, timeout]) + if winner == approval: + timeout.cancel() + return "approved" + return "timed out" + + registry = worker._Registry() + name = registry.add_orchestrator(orchestrator) + executor = 
worker._OrchestrationExecutor(registry, TEST_LOGGER) + + start_time = datetime(2020, 1, 1, 12, 0, 0) + first_chunk_fire_at = start_time + timedelta(days=3) + + # Initial execution schedules first long-timer chunk. + first = executor.execute( + TEST_INSTANCE_ID, + [], + [ + helpers.new_orchestrator_started_event(start_time), + helpers.new_execution_started_event(name, TEST_INSTANCE_ID, encoded_input=None), + ], + ) + assert len(first.actions) == 1 + assert first.actions[0].HasField("createTimer") + assert first.actions[0].createTimer.fireAt.ToDatetime() == first_chunk_fire_at + + # External event arrives before timeout -> long timer is cancelled and orchestration completes. + old_events = [ + helpers.new_orchestrator_started_event(start_time), + helpers.new_execution_started_event(name, TEST_INSTANCE_ID, encoded_input=None), + helpers.new_timer_created_event(1, first_chunk_fire_at), + ] + second = executor.execute( + TEST_INSTANCE_ID, + old_events, + [helpers.new_event_raised_event("approval", json.dumps(True))], + ) + complete_action = get_and_validate_complete_orchestration_action_list(1, second.actions) + assert complete_action.orchestrationStatus == pb.ORCHESTRATION_STATUS_COMPLETED + assert complete_action.result.value == '"approved"' + + def test_timer_can_be_cancelled_after_when_any_winner(): """Tests cancellation of an outstanding timer task after another task wins when_any.""" @@ -645,6 +813,106 @@ def orchestrator(ctx: task.OrchestrationContext, _): assert actions[2].createTimer.fireAt.ToDatetime() == expected_fire_at +def test_activity_retry_with_long_timer_preserves_retryable_parent(): + """Tests that long retry timers keep retryable parent state until the final chunk fires.""" + + def dummy_activity(ctx, _): + raise ValueError("Kah-BOOOOM!!!") + + def orchestrator(ctx: task.OrchestrationContext, orchestrator_input): + result = yield ctx.call_activity( + dummy_activity, + retry_policy=task.RetryPolicy( + first_retry_interval=timedelta(days=10), + 
max_number_of_attempts=2, + backoff_coefficient=1, + ), + input=orchestrator_input, + ) + return result + + registry = worker._Registry() + name = registry.add_orchestrator(orchestrator) + executor = worker._OrchestrationExecutor(registry, TEST_LOGGER) + + start = datetime.utcnow() + t1 = start + timedelta(days=3) + t2 = start + timedelta(days=6) + t3 = start + timedelta(days=9) + t4 = start + timedelta(days=10) + + old_events = [ + helpers.new_orchestrator_started_event(timestamp=start), + helpers.new_execution_started_event(name, TEST_INSTANCE_ID, encoded_input=None), + helpers.new_task_scheduled_event(1, task.get_name(dummy_activity)), + ] + + # First activity failure should create the first long-timer chunk. + new_events = [ + helpers.new_orchestrator_started_event(timestamp=start), + helpers.new_task_failed_event(1, ValueError("Kah-BOOOOM!!!")), + ] + result = executor.execute(TEST_INSTANCE_ID, old_events, new_events) + actions = result.actions + assert actions[-1].HasField("createTimer") + assert actions[-1].id == 2 + assert actions[-1].createTimer.fireAt.ToDatetime() == t1 + + old_events = old_events + new_events + + # Intermediate chunk 1 fires -> schedule next chunk, not activity retry yet. + new_events = [ + helpers.new_orchestrator_started_event(t1), + helpers.new_timer_fired_event(2, t1), + ] + result = executor.execute(TEST_INSTANCE_ID, old_events, new_events) + actions = result.actions + assert actions[-1].HasField("createTimer") + assert actions[-1].id == 3 + assert actions[-1].createTimer.fireAt.ToDatetime() == t2 + assert not actions[-1].HasField("scheduleTask") + + old_events = old_events + new_events + + # Intermediate chunk 2 fires -> schedule next chunk, still no activity retry. 
+ new_events = [ + helpers.new_orchestrator_started_event(t2), + helpers.new_timer_fired_event(3, t2), + ] + result = executor.execute(TEST_INSTANCE_ID, old_events, new_events) + actions = result.actions + assert actions[-1].HasField("createTimer") + assert actions[-1].id == 4 + assert actions[-1].createTimer.fireAt.ToDatetime() == t3 + assert not actions[-1].HasField("scheduleTask") + + old_events = old_events + new_events + + # Intermediate chunk 3 fires -> schedule final chunk, still no activity retry. + new_events = [ + helpers.new_orchestrator_started_event(t3), + helpers.new_timer_fired_event(4, t3), + ] + result = executor.execute(TEST_INSTANCE_ID, old_events, new_events) + actions = result.actions + assert actions[-1].HasField("createTimer") + assert actions[-1].id == 5 + assert actions[-1].createTimer.fireAt.ToDatetime() == t4 + assert not actions[-1].HasField("scheduleTask") + + old_events = old_events + new_events + + # Final chunk fires -> retry activity should be rescheduled with original task ID. 
+ new_events = [ + helpers.new_orchestrator_started_event(t4), + helpers.new_timer_fired_event(5, t4), + ] + result = executor.execute(TEST_INSTANCE_ID, old_events, new_events) + actions = result.actions + assert actions[-1].HasField("scheduleTask") + assert actions[-1].id == 1 + + def test_nondeterminism_expected_timer(): """Tests the non-determinism detection logic when call_timer is expected but some other method (call_activity) is called instead""" def dummy_activity(ctx, _): From 9777a365f4c25162d06c63bc97386211dffbbc75 Mon Sep 17 00:00:00 2001 From: Andy Staples Date: Tue, 10 Mar 2026 12:50:56 -0600 Subject: [PATCH 3/5] Add DTS default, e2e test --- .../durabletask/azuremanaged/worker.py | 4 +- tests/durabletask/test_orchestration_e2e.py | 68 +++++++++++++++++++ 2 files changed, 71 insertions(+), 1 deletion(-) diff --git a/durabletask-azuremanaged/durabletask/azuremanaged/worker.py b/durabletask-azuremanaged/durabletask/azuremanaged/worker.py index 4816223..b270c93 100644 --- a/durabletask-azuremanaged/durabletask/azuremanaged/worker.py +++ b/durabletask-azuremanaged/durabletask/azuremanaged/worker.py @@ -80,4 +80,6 @@ def __init__(self, *, log_handler=log_handler, log_formatter=log_formatter, interceptors=interceptors, - concurrency_options=concurrency_options) + concurrency_options=concurrency_options, + maximum_timer_interval=None # DTS allows timers of indefinite length + ) diff --git a/tests/durabletask/test_orchestration_e2e.py b/tests/durabletask/test_orchestration_e2e.py index a6f670c..0a74651 100644 --- a/tests/durabletask/test_orchestration_e2e.py +++ b/tests/durabletask/test_orchestration_e2e.py @@ -586,3 +586,71 @@ def empty_orchestrator(ctx: task.OrchestrationContext, _): assert uuid.UUID(results[0]) != uuid.UUID(results[1]) assert uuid.UUID(results[0]) != uuid.UUID(results[2]) assert uuid.UUID(results[1]) != uuid.UUID(results[2]) + + +@pytest.mark.parametrize("raise_event", [True, False]) +def 
test_when_any_cancels_timer_when_event_wins(raise_event: bool): + """Verify that the losing timer in a when_any race can be explicitly + cancelled without causing errors or affecting the orchestration result.""" + + def orchestrator(ctx: task.OrchestrationContext, _): + approval: task.Task[bool] = ctx.wait_for_external_event('Approval') + timeout = ctx.create_timer(timedelta(seconds=3)) + winner = yield task.when_any([approval, timeout]) + if winner == approval: + # Explicitly cancel the timer so it does not linger + timeout.cancel() + return "approved" + else: + return "timed out" + + with worker.TaskHubGrpcWorker(host_address=HOST) as w: + w.add_orchestrator(orchestrator) + w.start() + + task_hub_client = client.TaskHubGrpcClient(host_address=HOST) + id = task_hub_client.schedule_new_orchestration(orchestrator) + if raise_event: + task_hub_client.raise_orchestration_event(id, 'Approval') + state = task_hub_client.wait_for_orchestration_completion(id, timeout=30) + + assert state is not None + assert state.runtime_status == client.OrchestrationStatus.COMPLETED + assert state.failure_details is None + if raise_event: + assert state.serialized_output == json.dumps("approved") + else: + assert state.serialized_output == json.dumps("timed out") + + +@pytest.mark.parametrize("winning_event", ["Approve", "Reject"]) +def test_when_any_cancels_competing_external_event(winning_event: str): + """Verify that the losing external-event task in a when_any race is + explicitly cancelled, preventing it from consuming a late-arriving event + and leaving the orchestration in a clean state.""" + + def orchestrator(ctx: task.OrchestrationContext, _): + approve: task.Task = ctx.wait_for_external_event('Approve') + reject: task.Task = ctx.wait_for_external_event('Reject') + winner = yield task.when_any([approve, reject]) + if winner == approve: + reject.cancel() + return "approved" + else: + approve.cancel() + return "rejected" + + with worker.TaskHubGrpcWorker(host_address=HOST) as w: 
+ w.add_orchestrator(orchestrator) + w.start() + + task_hub_client = client.TaskHubGrpcClient(host_address=HOST) + id = task_hub_client.schedule_new_orchestration(orchestrator) + task_hub_client.raise_orchestration_event(id, winning_event) + state = task_hub_client.wait_for_orchestration_completion(id, timeout=30) + + assert state is not None + assert state.runtime_status == client.OrchestrationStatus.COMPLETED + assert state.failure_details is None + expected = "approved" if winning_event == "Approve" else "rejected" + assert state.serialized_output == json.dumps(expected) From 87dc5f733d5b5745be426f2ff0ded231894c0e8f Mon Sep 17 00:00:00 2001 From: Andy Staples Date: Tue, 10 Mar 2026 13:17:58 -0600 Subject: [PATCH 4/5] PR Feedback --- .flake8 | 2 +- docs/supported-patterns.md | 4 +- durabletask/task.py | 27 ++++++------ durabletask/worker.py | 12 ++++-- examples/human_interaction.py | 2 +- tests/durabletask/test_orchestration_e2e.py | 42 +++++++++++++++++++ .../test_orchestration_executor.py | 6 +-- 7 files changed, 71 insertions(+), 24 deletions(-) diff --git a/.flake8 b/.flake8 index ecc399c..239bf20 100644 --- a/.flake8 +++ b/.flake8 @@ -1,5 +1,5 @@ [flake8] -ignore = E501,C901 +ignore = E501,C901,W503 exclude = .git *_pb2* diff --git a/docs/supported-patterns.md b/docs/supported-patterns.md index 612678a..98f6f3c 100644 --- a/docs/supported-patterns.md +++ b/docs/supported-patterns.md @@ -64,12 +64,12 @@ def purchase_order_workflow(ctx: task.OrchestrationContext, order: Order): # Orders of $1000 or more require manager approval yield ctx.call_activity(send_approval_request, input=order) - # Approvals must be received within 24 hours or they will be canceled. + # Approvals must be received within 24 hours or they will be cancelled. 
approval_event = ctx.wait_for_external_event("approval_received") timeout_event = ctx.create_timer(timedelta(hours=24)) winner = yield task.when_any([approval_event, timeout_event]) if winner == timeout_event: - return "Canceled" + return "Cancelled" # The order was approved yield ctx.call_activity(place_order, input=order) diff --git a/durabletask/task.py b/durabletask/task.py index e44de2e..d0529b8 100644 --- a/durabletask/task.py +++ b/durabletask/task.py @@ -228,7 +228,7 @@ def call_sub_orchestrator(self, orchestrator: Union[Orchestrator[TInput, TOutput """ pass - # TOOD: Add a timeout parameter, which allows the task to be canceled if the event is + # TODO: Add a timeout parameter, which allows the task to be cancelled if the event is # not received within the specified timeout. This requires support for task cancellation. @abstractmethod def wait_for_external_event(self, name: str) -> CancellableTask: @@ -324,8 +324,8 @@ class OrchestrationStateError(Exception): pass -class TaskCanceledError(Exception): - """Exception type for canceled orchestration tasks.""" +class TaskCancelledError(Exception): + """Exception type for cancelled orchestration tasks.""" class Task(ABC, Generic[T]): @@ -440,7 +440,7 @@ def fail(self, message: str, details: Union[Exception, pb.TaskFailureDetails]): class CancellableTask(CompletableTask[T]): - """A completable task that can be canceled before it finishes.""" + """A completable task that can be cancelled before it finishes.""" def __init__(self) -> None: super().__init__() @@ -449,12 +449,12 @@ def __init__(self) -> None: @property def is_cancelled(self) -> bool: - """Returns True if the task was canceled, False otherwise.""" + """Returns True if the task was cancelled, False otherwise.""" return self._is_cancelled def get_result(self) -> T: if self._is_cancelled: - raise TaskCanceledError('The task was canceled.') + raise TaskCancelledError('The task was cancelled.') return super().get_result() def set_cancel_handler(self, 
cancel_handler: Callable[[], None]) -> None: @@ -524,27 +524,26 @@ class TimerTask(CancellableTask[None]): def set_retryable_parent(self, retryable_task: RetryableTask): self._retryable_parent = retryable_task - def complete(self, *args, **kwargs): + def complete(self, _: datetime) -> None: super().complete(None) - class LongTimerTask(TimerTask): - def __init__(self, final_fire_at: datetime, maximum_timer_duration: timedelta): + def __init__(self, final_fire_at: datetime, maximum_timer_interval: timedelta): super().__init__() self._final_fire_at = final_fire_at - self._maximum_timer_duration = maximum_timer_duration + self._maximum_timer_interval = maximum_timer_interval def start(self, current_utc_datetime: datetime) -> datetime: return self._get_next_fire_at(current_utc_datetime) - def complete(self, current_utc_datetime: datetime): + def complete(self, current_utc_datetime: datetime) -> Optional[datetime]: if current_utc_datetime < self._final_fire_at: return self._get_next_fire_at(current_utc_datetime) - super().complete(None) + return super().complete(current_utc_datetime) def _get_next_fire_at(self, current_utc_datetime: datetime) -> datetime: - if current_utc_datetime + self._maximum_timer_duration < self._final_fire_at: - return current_utc_datetime + self._maximum_timer_duration + if current_utc_datetime + self._maximum_timer_interval < self._final_fire_at: + return current_utc_datetime + self._maximum_timer_interval return self._final_fire_at diff --git a/durabletask/worker.py b/durabletask/worker.py index f473b38..9d45dc9 100644 --- a/durabletask/worker.py +++ b/durabletask/worker.py @@ -646,7 +646,9 @@ def _execute_orchestrator( completionToken, ): try: - executor = _OrchestrationExecutor(self._registry, self._logger) + executor = _OrchestrationExecutor(self._registry, + self._logger, + self.maximum_timer_interval) result = executor.execute(req.instanceId, req.pastEvents, req.newEvents) res = pb.OrchestratorResponse( instanceId=req.instanceId, @@ 
-1029,7 +1031,7 @@ def set_custom_status(self, custom_status: Any) -> None: shared.to_json(custom_status) if custom_status is not None else None ) - def create_timer(self, fire_at: Union[datetime, timedelta]) -> task.Task: + def create_timer(self, fire_at: Union[datetime, timedelta]) -> task.CancellableTask: return self.create_timer_internal(fire_at) def create_timer_internal( @@ -1045,7 +1047,11 @@ def create_timer_internal( next_fire_at: datetime = final_fire_at - if self._maximum_timer_interval is not None and self.current_utc_datetime + self._maximum_timer_interval < final_fire_at: + if ( + self._maximum_timer_interval is not None + and self._maximum_timer_interval > timedelta(0) + and self.current_utc_datetime + self._maximum_timer_interval < final_fire_at + ): timer_task = task.LongTimerTask(final_fire_at, self._maximum_timer_interval) next_fire_at = timer_task.start(self.current_utc_datetime) else: diff --git a/examples/human_interaction.py b/examples/human_interaction.py index 9d60758..b43336d 100644 --- a/examples/human_interaction.py +++ b/examples/human_interaction.py @@ -48,7 +48,7 @@ def purchase_order_workflow(ctx: task.OrchestrationContext, order: Order): # Orders of $1000 or more require manager approval yield ctx.call_activity(send_approval_request, input=order) - # Approvals must be received within 24 hours or they will be canceled. + # Approvals must be received within 24 hours or they will be cancelled. 
approval_event = ctx.wait_for_external_event("approval_received") timeout_event = ctx.create_timer(timedelta(hours=24)) winner = yield task.when_any([approval_event, timeout_event]) diff --git a/tests/durabletask/test_orchestration_e2e.py b/tests/durabletask/test_orchestration_e2e.py index 0a74651..103f14a 100644 --- a/tests/durabletask/test_orchestration_e2e.py +++ b/tests/durabletask/test_orchestration_e2e.py @@ -654,3 +654,45 @@ def orchestrator(ctx: task.OrchestrationContext, _): assert state.failure_details is None expected = "approved" if winning_event == "Approve" else "rejected" assert state.serialized_output == json.dumps(expected) + + +def test_long_timer_chunking(): + """Verify that a timer longer than maximum_timer_interval is broken into + intermediate chunks and that the orchestration completes correctly. + + The worker is configured with a 2-second maximum_timer_interval. The + orchestrator requests a 5-second timer, which requires 3 chunks + (0→2s, 2→4s, 4→5s). Each chunk causes a full orchestrator replay, so + the orchestrator function is invoked once for the initial scheduling and + once more for each timerFired event — 4 invocations in total. Asserting + invocation_count >= 4 confirms that intermediate chunks actually fired + rather than the timer being scheduled as a single unit. 
+ """ + + invocation_count = 0 + + def orchestrator(ctx: task.OrchestrationContext, _): + nonlocal invocation_count + invocation_count += 1 + yield ctx.create_timer(timedelta(seconds=5)) + return "done" + + with worker.TaskHubGrpcWorker( + host_address=HOST, + maximum_timer_interval=timedelta(seconds=2), + ) as w: + w.add_orchestrator(orchestrator) + w.start() + + task_hub_client = client.TaskHubGrpcClient(host_address=HOST) + id = task_hub_client.schedule_new_orchestration(orchestrator) + state = task_hub_client.wait_for_orchestration_completion(id, timeout=30) + + assert state is not None + assert state.runtime_status == client.OrchestrationStatus.COMPLETED + assert state.failure_details is None + assert state.serialized_output == json.dumps("done") + # 3 chunks (0→2s, 2→4s, 4→5s) produce 4 total orchestrator invocations + # (initial scheduling + one replay per timerFired). >= 4 proves that at + # least two intermediate chunk timers fired rather than one direct timer. + assert invocation_count >= 4 diff --git a/tests/durabletask/test_orchestration_executor.py b/tests/durabletask/test_orchestration_executor.py index 3e2cb02..ee4c0f2 100644 --- a/tests/durabletask/test_orchestration_executor.py +++ b/tests/durabletask/test_orchestration_executor.py @@ -380,15 +380,15 @@ def dummy_activity(ctx, _): assert not hasattr(activity_task, "is_cancelled") -def test_cancelled_task_get_result_raises_task_canceled_error(): - """Tests that canceled cancellable tasks raise TaskCanceledError from get_result.""" +def test_cancelled_task_get_result_raises_task_cancelled_error(): + """Tests that cancelled cancellable tasks raise TaskCancelledError from get_result.""" cancellable_task = task.CancellableTask() assert cancellable_task.cancel() is True assert cancellable_task.is_cancelled is True - with pytest.raises(task.TaskCanceledError): + with pytest.raises(task.TaskCancelledError): cancellable_task.get_result() From ceb9c0dc1a93d9ed9d6ccfabdb08cc3ded21ef6f Mon Sep 17 00:00:00 2001 
From: Andy Staples Date: Tue, 10 Mar 2026 13:19:09 -0600 Subject: [PATCH 5/5] Lint --- durabletask/task.py | 1 + 1 file changed, 1 insertion(+) diff --git a/durabletask/task.py b/durabletask/task.py index d0529b8..d5bcc41 100644 --- a/durabletask/task.py +++ b/durabletask/task.py @@ -527,6 +527,7 @@ def set_retryable_parent(self, retryable_task: RetryableTask): def complete(self, _: datetime) -> None: super().complete(None) + class LongTimerTask(TimerTask): def __init__(self, final_fire_at: datetime, maximum_timer_interval: timedelta): super().__init__()