Skip to content

Commit 2bfe262

Browse files
authored
Prevent attempting to suspend multiple times in the same run (#82)
The durable SQL itself already checks that a task is currently running when we try to suspend it, which makes it very difficult to write tests for. However, checking in the durable SQL is insufficient, since we might have an execution that looks like:
* Tokio task A - calls `await_event("foo")`, and doesn't propagate the `ControlFlow::Suspend` error.
* Tokio task B (possibly in another process) - calls `emit_event("foo")`
* Tokio task C - picks up the now-ready task that was previously suspended
* Tokio task A - calls `await_event("bar")`, which succeeds, since the task is now running.
By adding a check Rust-side, we can be sure that we catch incorrect usage of durable.
1 parent 1eb6aea commit 2bfe262

3 files changed

Lines changed: 60 additions & 5 deletions

File tree

src/context.rs

Lines changed: 39 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ use std::time::Duration;
66
use uuid::Uuid;
77

88
use crate::Durable;
9+
use crate::error::suspend_handle::SuspendMarker;
910
use crate::error::{ControlFlow, TaskError, TaskResult};
1011
use std::sync::Arc;
1112

@@ -69,6 +70,24 @@ where
6970

7071
/// Cloneable heartbeat handle for use in step closures.
7172
heartbeat_handle: HeartbeatHandle,
73+
74+
/// Whether or not we've suspended the task
75+
/// This is set to `true` when we construct a `ControlFlow::Suspend` error type,
76+
/// which enforces that we cannot suspend again for this particular execution
77+
/// (e.g. until the task is woken up and re-run by a durable worker).
78+
/// This blocks incorrect patterns like:
79+
/// ```rust
80+
/// // Note the lack of '.await' and propagation of the error with `?`
81+
/// let fut1 = ctx.sleep_for("first_sleep", Duration::from_secs(1));
82+
/// let fut2 = ctx.sleep_for("second_sleep", Duration::from_secs(1));
83+
///
84+
/// tokio::join!(fut1, fut2);
85+
/// ```
86+
///
87+
/// Producing `ControlFlow::Suspend` means that we've updated our task suspend
88+
/// state in durable, so trying to produce `ControlFlow::Suspend` again during the same
89+
/// execution will overwrite state in Durable.
90+
has_suspended: bool,
7291
}
7392

7493
/// Validate that a user-provided step name doesn't use reserved prefix.
@@ -85,6 +104,16 @@ impl<State> TaskContext<State>
85104
where
86105
State: Clone + Send + Sync + 'static,
87106
{
107+
pub(crate) fn mark_suspended(&mut self) -> TaskResult<()> {
108+
if self.has_suspended {
109+
return Err(TaskError::Validation {
110+
message: "Task has already been suspended during this execution".to_string(),
111+
});
112+
}
113+
self.has_suspended = true;
114+
Ok(())
115+
}
116+
88117
/// Create a new TaskContext. Called by the worker before executing a task.
89118
/// Loads all existing checkpoints into the cache.
90119
#[allow(clippy::too_many_arguments)]
@@ -128,6 +157,7 @@ where
128157
step_counters: HashMap::new(),
129158
lease_extender,
130159
heartbeat_handle,
160+
has_suspended: false,
131161
})
132162
}
133163

@@ -335,7 +365,9 @@ where
335365
.map_err(TaskError::from_sqlx_error)?;
336366

337367
if needs_suspend {
338-
return Err(TaskError::Control(ControlFlow::Suspend));
368+
return Err(TaskError::Control(ControlFlow::Suspend(
369+
SuspendMarker::new(self)?,
370+
)));
339371
}
340372
Ok(())
341373
}
@@ -414,7 +446,9 @@ where
414446
.map_err(TaskError::from_sqlx_error)?;
415447

416448
if result.should_suspend {
417-
return Err(TaskError::Control(ControlFlow::Suspend));
449+
return Err(TaskError::Control(ControlFlow::Suspend(
450+
SuspendMarker::new(self)?,
451+
)));
418452
}
419453

420454
// Event arrived - cache and return
@@ -768,7 +802,9 @@ where
768802
.map_err(TaskError::from_sqlx_error)?;
769803

770804
if result.should_suspend {
771-
return Err(TaskError::Control(ControlFlow::Suspend));
805+
return Err(TaskError::Control(ControlFlow::Suspend(
806+
SuspendMarker::new(self)?,
807+
)));
772808
}
773809

774810
// Event arrived - parse and return

src/error.rs

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
use serde_json::Value as JsonValue;
22
use thiserror::Error;
33

4+
use crate::error::suspend_handle::SuspendMarker;
5+
46
/// Signals that interrupt task execution without indicating failure.
57
///
68
/// These are not errors - they represent intentional control flow that the worker
@@ -13,7 +15,7 @@ pub enum ControlFlow {
1315
/// Returned by [`TaskContext::sleep_for`](crate::TaskContext::sleep_for)
1416
/// and [`TaskContext::await_event`](crate::TaskContext::await_event)
1517
/// when the task needs to wait.
16-
Suspend,
18+
Suspend(SuspendMarker),
1719
/// Task was cancelled.
1820
///
1921
/// Detected when database operations return error code AB001, indicating
@@ -27,6 +29,23 @@ pub enum ControlFlow {
2729
LeaseExpired,
2830
}
2931

32+
pub mod suspend_handle {
33+
use crate::{TaskContext, TaskResult};
34+
35+
// An internal marker type that helps prevent us from constructing `ControlFlow::Suspend` errors
36+
// without calling `task_context.mark_suspended()` first.
37+
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
38+
pub struct SuspendMarker {
39+
_private: (),
40+
}
41+
impl SuspendMarker {
42+
pub fn new<S: Clone + Send + Sync>(task_context: &mut TaskContext<S>) -> TaskResult<Self> {
43+
task_context.mark_suspended()?;
44+
Ok(Self { _private: () })
45+
}
46+
}
47+
}
48+
3049
/// Error type for task execution.
3150
///
3251
/// This enum distinguishes between control flow signals (suspension, cancellation)

src/worker.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -481,7 +481,7 @@ impl Worker {
481481
#[cfg(feature = "telemetry")]
482482
crate::telemetry::record_task_completed(&queue_name_for_metrics, &task_name);
483483
}
484-
Err(TaskError::Control(ControlFlow::Suspend)) => {
484+
Err(TaskError::Control(ControlFlow::Suspend(_))) => {
485485
// Task suspended - do nothing, scheduler will resume it
486486
#[cfg(feature = "telemetry")]
487487
{

0 commit comments

Comments
 (0)