Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 39 additions & 3 deletions src/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ use std::time::Duration;
use uuid::Uuid;

use crate::Durable;
use crate::error::suspend_handle::SuspendMarker;
use crate::error::{ControlFlow, TaskError, TaskResult};
use std::sync::Arc;

Expand Down Expand Up @@ -69,6 +70,24 @@ where

/// Cloneable heartbeat handle for use in step closures.
heartbeat_handle: HeartbeatHandle,

/// Whether we have already suspended the task during this execution.
/// This is set to `true` when we construct a `ControlFlow::Suspend` error type,
/// which enforces that we cannot suspend again for this particular execution
/// (e.g. until the task is woken up and re-run by a durable worker).
/// This blocks incorrect patterns like:
/// ```rust
/// // Note that neither future is `.await`ed and neither error is propagated with `?`.
/// let fut1 = ctx.sleep_for("first_sleep", Duration::from_secs(1));
/// let fut2 = ctx.sleep_for("second_sleep", Duration::from_secs(1));
///
/// tokio::join!(fut1, fut2);
/// ```
///
/// Producing `ControlFlow::Suspend` means that we've already updated the task's suspend
/// state in Durable, so producing `ControlFlow::Suspend` again during the same
/// execution would overwrite that state.
has_suspended: bool,
}

/// Validate that a user-provided step name doesn't use reserved prefix.
Expand All @@ -85,6 +104,16 @@ impl<State> TaskContext<State>
where
State: Clone + Send + Sync + 'static,
{
/// Record that this execution has produced a suspend signal.
///
/// The first call flips `has_suspended` to `true` and succeeds. Any later
/// call on the same context leaves the flag set and is rejected.
///
/// # Errors
///
/// Returns [`TaskError::Validation`] when `has_suspended` is already `true`.
pub(crate) fn mark_suspended(&mut self) -> TaskResult<()> {
    // Happy path first: nothing has suspended yet, so claim the flag.
    if !self.has_suspended {
        self.has_suspended = true;
        return Ok(());
    }

    // The flag was already set: a second suspend in one execution is invalid.
    Err(TaskError::Validation {
        message: "Task has already been suspended during this execution".to_string(),
    })
}

/// Create a new TaskContext. Called by the worker before executing a task.
/// Loads all existing checkpoints into the cache.
#[allow(clippy::too_many_arguments)]
Expand Down Expand Up @@ -128,6 +157,7 @@ where
step_counters: HashMap::new(),
lease_extender,
heartbeat_handle,
has_suspended: false,
})
}

Expand Down Expand Up @@ -335,7 +365,9 @@ where
.map_err(TaskError::from_sqlx_error)?;

if needs_suspend {
return Err(TaskError::Control(ControlFlow::Suspend));
return Err(TaskError::Control(ControlFlow::Suspend(
SuspendMarker::new(self)?,
)));
}
Ok(())
}
Expand Down Expand Up @@ -414,7 +446,9 @@ where
.map_err(TaskError::from_sqlx_error)?;

if result.should_suspend {
return Err(TaskError::Control(ControlFlow::Suspend));
return Err(TaskError::Control(ControlFlow::Suspend(
SuspendMarker::new(self)?,
)));
}

// Event arrived - cache and return
Expand Down Expand Up @@ -768,7 +802,9 @@ where
.map_err(TaskError::from_sqlx_error)?;

if result.should_suspend {
return Err(TaskError::Control(ControlFlow::Suspend));
return Err(TaskError::Control(ControlFlow::Suspend(
SuspendMarker::new(self)?,
)));
}

// Event arrived - parse and return
Expand Down
21 changes: 20 additions & 1 deletion src/error.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
use serde_json::Value as JsonValue;
use thiserror::Error;

use crate::error::suspend_handle::SuspendMarker;

/// Signals that interrupt task execution without indicating failure.
///
/// These are not errors - they represent intentional control flow that the worker
Expand All @@ -13,7 +15,7 @@ pub enum ControlFlow {
/// Returned by [`TaskContext::sleep_for`](crate::TaskContext::sleep_for)
/// and [`TaskContext::await_event`](crate::TaskContext::await_event)
/// when the task needs to wait.
Suspend,
Suspend(SuspendMarker),
/// Task was cancelled.
///
/// Detected when database operations return error code AB001, indicating
Expand All @@ -27,6 +29,23 @@ pub enum ControlFlow {
LeaseExpired,
}

pub mod suspend_handle {
    use crate::{TaskContext, TaskResult};

    /// Internal marker type that helps prevent constructing
    /// `ControlFlow::Suspend` errors without first calling
    /// `task_context.mark_suspended()`.
    ///
    /// The only field is private to this module, so code outside it has to
    /// obtain a value through [`SuspendMarker::new`].
    #[derive(Debug, Clone, Copy, PartialEq, Eq)]
    pub struct SuspendMarker {
        _private: (),
    }

    impl SuspendMarker {
        /// Marks `task_context` as suspended and, if that succeeds, returns a marker.
        ///
        /// # Errors
        ///
        /// Forwards whatever error `task_context.mark_suspended()` returns.
        pub fn new<S>(task_context: &mut TaskContext<S>) -> TaskResult<Self>
        where
            S: Clone + Send + Sync,
        {
            task_context
                .mark_suspended()
                .map(|()| Self { _private: () })
        }
    }
}

/// Error type for task execution.
///
/// This enum distinguishes between control flow signals (suspension, cancellation)
Expand Down
2 changes: 1 addition & 1 deletion src/worker.rs
Original file line number Diff line number Diff line change
Expand Up @@ -481,7 +481,7 @@ impl Worker {
#[cfg(feature = "telemetry")]
crate::telemetry::record_task_completed(&queue_name_for_metrics, &task_name);
}
Err(TaskError::Control(ControlFlow::Suspend)) => {
Err(TaskError::Control(ControlFlow::Suspend(_))) => {
// Task suspended - do nothing, scheduler will resume it
#[cfg(feature = "telemetry")]
{
Expand Down
Loading