diff --git a/src/context.rs b/src/context.rs index cd4bb40..50c14bd 100644 --- a/src/context.rs +++ b/src/context.rs @@ -6,6 +6,7 @@ use std::time::Duration; use uuid::Uuid; use crate::Durable; +use crate::error::suspend_handle::SuspendMarker; use crate::error::{ControlFlow, TaskError, TaskResult}; use std::sync::Arc; @@ -69,6 +70,24 @@ where /// Cloneable heartbeat handle for use in step closures. heartbeat_handle: HeartbeatHandle, + + /// Whether or not we've suspended the task + /// This is set to `true` when we construct a `ControlFlow::Suspend` error type, + /// which enforces that we cannot suspend again for this particular execution + /// (e.g. until the task is woken up and re-run by a durable worker). + /// This blocks incorrect patterns like: + /// ```rust + /// // Note the lack of '.await' and propagation of the error with `?` + /// let fut1 = ctx.sleep_for("first_sleep", Duration::from_secs(1)); + /// let fut2 = ctx.sleep_for("second_sleep", Duration::from_secs(1)); + /// + /// tokio::join!(fut1, fut2).await; + /// ``` + /// + /// Producing `ControlFlow::Suspend` means that we've updated our task suspend + /// state in durable, so trying to call `ControlFlow::Suspend` again during the same + /// execution will overwrite state in Durable. + has_suspended: bool, } /// Validate that a user-provided step name doesn't use reserved prefix. @@ -85,6 +104,16 @@ impl TaskContext where State: Clone + Send + Sync + 'static, { + pub(crate) fn mark_suspended(&mut self) -> TaskResult<()> { + if self.has_suspended { + return Err(TaskError::Validation { + message: "Task has already been suspended during this execution".to_string(), + }); + } + self.has_suspended = true; + Ok(()) + } + /// Create a new TaskContext. Called by the worker before executing a task. /// Loads all existing checkpoints into the cache. #[allow(clippy::too_many_arguments)] @@ -128,6 +157,7 @@ where step_counters: HashMap::new(), lease_extender, heartbeat_handle, + has_suspended: false, }) } @@ -335,7 +365,9 @@ where .map_err(TaskError::from_sqlx_error)?; if needs_suspend { - return Err(TaskError::Control(ControlFlow::Suspend)); + return Err(TaskError::Control(ControlFlow::Suspend( + SuspendMarker::new(self)?, + ))); } Ok(()) } @@ -414,7 +446,9 @@ where .map_err(TaskError::from_sqlx_error)?; if result.should_suspend { - return Err(TaskError::Control(ControlFlow::Suspend)); + return Err(TaskError::Control(ControlFlow::Suspend( + SuspendMarker::new(self)?, + ))); } // Event arrived - cache and return @@ -768,7 +802,9 @@ where .map_err(TaskError::from_sqlx_error)?; if result.should_suspend { - return Err(TaskError::Control(ControlFlow::Suspend)); + return Err(TaskError::Control(ControlFlow::Suspend( + SuspendMarker::new(self)?, + ))); } // Event arrived - parse and return diff --git a/src/error.rs b/src/error.rs index fcb5192..b93a069 100644 --- a/src/error.rs +++ b/src/error.rs @@ -1,6 +1,8 @@ use serde_json::Value as JsonValue; use thiserror::Error; +use crate::error::suspend_handle::SuspendMarker; + /// Signals that interrupt task execution without indicating failure. /// /// These are not errors - they represent intentional control flow that the worker @@ -13,7 +15,7 @@ pub enum ControlFlow { /// Returned by [`TaskContext::sleep_for`](crate::TaskContext::sleep_for) /// and [`TaskContext::await_event`](crate::TaskContext::await_event) /// when the task needs to wait. - Suspend, + Suspend(SuspendMarker), /// Task was cancelled. /// /// Detected when database operations return error code AB001, indicating @@ -27,6 +29,23 @@ pub enum ControlFlow { LeaseExpired, } +pub mod suspend_handle { + use crate::{TaskContext, TaskResult}; + + // An internal marker type that helps prevent us from constructing `ControlFlow::Suspend` errors + // without calling `task_context.mark_suspended()` first. + #[derive(Debug, Clone, Copy, PartialEq, Eq)] + pub struct SuspendMarker { + _private: (), + } + impl SuspendMarker { + pub fn new(task_context: &mut TaskContext) -> TaskResult { + task_context.mark_suspended()?; + Ok(Self { _private: () }) + } + } +} + /// Error type for task execution. /// /// This enum distinguishes between control flow signals (suspension, cancellation) diff --git a/src/worker.rs b/src/worker.rs index bef8bc1..42df8ea 100644 --- a/src/worker.rs +++ b/src/worker.rs @@ -481,7 +481,7 @@ impl Worker { #[cfg(feature = "telemetry")] crate::telemetry::record_task_completed(&queue_name_for_metrics, &task_name); } - Err(TaskError::Control(ControlFlow::Suspend)) => { + Err(TaskError::Control(ControlFlow::Suspend(_))) => { // Task suspended - do nothing, scheduler will resume it #[cfg(feature = "telemetry")] {