From a0b931307ee6932ea27f96cfb76ea0297ce8f2e7 Mon Sep 17 00:00:00 2001 From: Aaron Hill Date: Fri, 13 Mar 2026 12:30:18 -0400 Subject: [PATCH] Prevent attempting to suspend multiple times in the same run The durable sql itself already checks that a task is currently running when we try to suspend it, which makes it very difficult to write tests for. However, checking in the durable sql is insufficient, since we might have an execution that looks like: * Tokio task A - calls `await_event("foo")`, and doesn't propagate the `ControlFlow::Suspend` error. * Tokio task B (possibly in another process) - calls `emit_event("foo")` * Tokio task C - picks up the now-ready task that was previously suspended * Tokio task A - calls `await_event("bar")`, which succeeds, since the task is now running. By adding a check rust-side, we can be sure that we catch incorrect usage of durable --- src/context.rs | 42 +++++++++++++++++++++++++++++++++++++++--- src/error.rs | 21 ++++++++++++++++++++- src/worker.rs | 2 +- 3 files changed, 60 insertions(+), 5 deletions(-) diff --git a/src/context.rs b/src/context.rs index cd4bb40..50c14bd 100644 --- a/src/context.rs +++ b/src/context.rs @@ -6,6 +6,7 @@ use std::time::Duration; use uuid::Uuid; use crate::Durable; +use crate::error::suspend_handle::SuspendMarker; use crate::error::{ControlFlow, TaskError, TaskResult}; use std::sync::Arc; @@ -69,6 +70,24 @@ where /// Cloneable heartbeat handle for use in step closures. heartbeat_handle: HeartbeatHandle, + + /// Whether or not we've suspended the task + /// This is set to `true` when we construct a `ControlFlow::Suspend` error type, + /// which enforces that we cannot suspend again for this particular execution + /// (e.g. until the task is woken up and re-run by a durable worker). 
+ /// This blocks incorrect patterns like: + /// ```rust + /// // Note the lack of '.await' and propagation of the error with `?` + /// let fut1 = ctx.sleep_for("first_sleep", Duration::from_secs(1)); + /// let fut2 = ctx.sleep_for("second_sleep", Duration::from_secs(1)); + /// + /// tokio::join!(fut1, fut2); + /// ``` + /// + /// Producing `ControlFlow::Suspend` means that we've updated our task suspend + /// state in durable, so trying to construct `ControlFlow::Suspend` again during the same + /// execution will overwrite state in Durable. + has_suspended: bool, } /// Validate that a user-provided step name doesn't use reserved prefix. @@ -85,6 +104,16 @@ impl<State> TaskContext<State> where State: Clone + Send + Sync + 'static, { + pub(crate) fn mark_suspended(&mut self) -> TaskResult<()> { + if self.has_suspended { + return Err(TaskError::Validation { + message: "Task has already been suspended during this execution".to_string(), + }); + } + self.has_suspended = true; + Ok(()) + } + /// Create a new TaskContext. Called by the worker before executing a task. /// Loads all existing checkpoints into the cache. 
#[allow(clippy::too_many_arguments)] @@ -128,6 +157,7 @@ where step_counters: HashMap::new(), lease_extender, heartbeat_handle, + has_suspended: false, }) } @@ -335,7 +365,9 @@ where .map_err(TaskError::from_sqlx_error)?; if needs_suspend { - return Err(TaskError::Control(ControlFlow::Suspend)); + return Err(TaskError::Control(ControlFlow::Suspend( + SuspendMarker::new(self)?, + ))); } Ok(()) } @@ -414,7 +446,9 @@ where .map_err(TaskError::from_sqlx_error)?; if result.should_suspend { - return Err(TaskError::Control(ControlFlow::Suspend)); + return Err(TaskError::Control(ControlFlow::Suspend( + SuspendMarker::new(self)?, + ))); } // Event arrived - cache and return @@ -768,7 +802,9 @@ where .map_err(TaskError::from_sqlx_error)?; if result.should_suspend { - return Err(TaskError::Control(ControlFlow::Suspend)); + return Err(TaskError::Control(ControlFlow::Suspend( + SuspendMarker::new(self)?, + ))); } // Event arrived - parse and return diff --git a/src/error.rs b/src/error.rs index fcb5192..b93a069 100644 --- a/src/error.rs +++ b/src/error.rs @@ -1,6 +1,8 @@ use serde_json::Value as JsonValue; use thiserror::Error; +use crate::error::suspend_handle::SuspendMarker; + /// Signals that interrupt task execution without indicating failure. /// /// These are not errors - they represent intentional control flow that the worker @@ -13,7 +15,7 @@ pub enum ControlFlow { /// Returned by [`TaskContext::sleep_for`](crate::TaskContext::sleep_for) /// and [`TaskContext::await_event`](crate::TaskContext::await_event) /// when the task needs to wait. - Suspend, + Suspend(SuspendMarker), /// Task was cancelled. /// /// Detected when database operations return error code AB001, indicating @@ -27,6 +29,23 @@ pub enum ControlFlow { LeaseExpired, } +pub mod suspend_handle { + use crate::{TaskContext, TaskResult}; + + // An internal marker type that helps prevent us from constructing `ControlFlow::Suspend` errors + // without calling `task_context.mark_suspended()` first. 
+ #[derive(Debug, Clone, Copy, PartialEq, Eq)] + pub struct SuspendMarker { + _private: (), + } + impl SuspendMarker { + pub fn new<State: Clone + Send + Sync + 'static>(task_context: &mut TaskContext<State>) -> TaskResult<Self> { + task_context.mark_suspended()?; + Ok(Self { _private: () }) + } + } +} + /// Error type for task execution. /// /// This enum distinguishes between control flow signals (suspension, cancellation) diff --git a/src/worker.rs b/src/worker.rs index bef8bc1..42df8ea 100644 --- a/src/worker.rs +++ b/src/worker.rs @@ -481,7 +481,7 @@ impl Worker { #[cfg(feature = "telemetry")] crate::telemetry::record_task_completed(&queue_name_for_metrics, &task_name); } - Err(TaskError::Control(ControlFlow::Suspend)) => { + Err(TaskError::Control(ControlFlow::Suspend(_))) => { // Task suspended - do nothing, scheduler will resume it #[cfg(feature = "telemetry")] {