diff --git a/Test/DurableTask.AzureStorage.Tests/OrchestrationSessionTests.cs b/Test/DurableTask.AzureStorage.Tests/OrchestrationSessionTests.cs
new file mode 100644
index 000000000..126c4b9bc
--- /dev/null
+++ b/Test/DurableTask.AzureStorage.Tests/OrchestrationSessionTests.cs
@@ -0,0 +1,227 @@
+// ----------------------------------------------------------------------------------
+// Copyright Microsoft Corporation
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+// http://www.apache.org/licenses/LICENSE-2.0
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// ----------------------------------------------------------------------------------
+
+namespace DurableTask.AzureStorage.Tests
+{
+ using System;
+ using System.Collections.Generic;
+ using System.Diagnostics;
+ using System.Linq;
+ using System.Reflection;
+ using System.Threading;
+ using System.Threading.Tasks;
+ using DurableTask.AzureStorage.Messaging;
+ using DurableTask.AzureStorage.Monitoring;
+ using DurableTask.AzureStorage.Tracking;
+ using Microsoft.VisualStudio.TestTools.UnitTesting;
+ using Moq;
+
+    /// <summary>
+    /// Tests for shutdown cancellation behavior with extended sessions.
+    /// </summary>
+ [TestClass]
+ public class OrchestrationSessionTests
+ {
+        /// <summary>
+        /// Verifies that <see cref="AsyncAutoResetEvent.WaitAsync(TimeSpan, CancellationToken)"/>
+        /// exits immediately when the cancellation token is cancelled.
+        /// </summary>
+        [TestMethod]
+        public async Task WaitAsync_CancellationToken_ExitsImmediately()
+        {
+            var resetEvent = new AsyncAutoResetEvent(signaled: false);
+            using var cts = new CancellationTokenSource();
+
+            // The 30s timeout is far above the 5s bound asserted below, so only cancellation can end the wait in time.
+            TimeSpan longTimeout = TimeSpan.FromSeconds(30);
+            Task<bool> waitTask = resetEvent.WaitAsync(longTimeout, cts.Token);
+            Assert.IsFalse(waitTask.IsCompleted, "Wait should not complete immediately");
+
+            var stopwatch = Stopwatch.StartNew();
+            cts.Cancel();
+
+            bool result = await waitTask;
+            stopwatch.Stop();
+
+            Assert.IsFalse(result, "Cancellation should return false (no signal received)");
+            Assert.IsTrue(
+                stopwatch.ElapsedMilliseconds < 5000,
+                $"Cancellation should complete in under 5s, but took {stopwatch.ElapsedMilliseconds}ms");
+        }
+
+        /// <summary>
+        /// Verifies that signaling still returns true when a cancellation token is provided.
+        /// </summary>
+        [TestMethod]
+        public async Task WaitAsync_WithCancellationToken_SignalStillWorks()
+        {
+            var resetEvent = new AsyncAutoResetEvent(signaled: false);
+            using var cts = new CancellationTokenSource();
+
+            Task<bool> waitTask = resetEvent.WaitAsync(TimeSpan.FromSeconds(30), cts.Token);
+            Assert.IsFalse(waitTask.IsCompleted);
+
+            resetEvent.Set();
+            // Race against a 5s delay so a missed signal fails the test instead of waiting out the 30s timeout.
+            Task winner = await Task.WhenAny(waitTask, Task.Delay(TimeSpan.FromSeconds(5)));
+            Assert.AreSame(waitTask, winner, "Signal should wake the waiter");
+            Assert.IsTrue(await waitTask, "Wait result should be true when signaled");
+        }
+
+        /// <summary>
+        /// Verifies that a supplied-but-never-cancelled token leaves the timeout path intact: the wait reports false once the timeout elapses.
+        /// </summary>
+        [TestMethod]
+        public async Task WaitAsync_WithCancellationToken_TimeoutStillWorks()
+        {
+            using var cancellation = new CancellationTokenSource();
+            var autoResetEvent = new AsyncAutoResetEvent(signaled: false);
+            TimeSpan shortTimeout = TimeSpan.FromMilliseconds(100);
+
+            bool wasSignaled = await autoResetEvent.WaitAsync(shortTimeout, cancellation.Token);
+            Assert.IsFalse(wasSignaled, "Wait should return false on timeout");
+        }
+
+        /// <summary>
+        /// Verifies that all queued waiters return false when the token is cancelled.
+        /// </summary>
+        [TestMethod]
+        public async Task WaitAsync_CancellationToken_MultipleWaiters()
+        {
+            var resetEvent = new AsyncAutoResetEvent(signaled: false);
+            using var cts = new CancellationTokenSource();
+
+            var waiters = new List<Task<bool>>();
+            for (int i = 0; i < 5; i++)
+            {
+                waiters.Add(resetEvent.WaitAsync(TimeSpan.FromSeconds(30), cts.Token));
+            }
+
+            foreach (Task<bool> waiter in waiters)
+            {
+                Assert.IsFalse(waiter.IsCompleted);
+            }
+
+            var stopwatch = Stopwatch.StartNew();
+            cts.Cancel();
+
+            // All waiters should return false (cancelled = not signaled)
+            await Task.WhenAll(
+                waiters.Select(
+                    async waiter =>
+                    {
+                        bool result = await waiter;
+                        Assert.IsFalse(result, "Cancelled waiter should return false");
+                    }));
+
+            stopwatch.Stop();
+
+            Assert.IsTrue(
+                stopwatch.ElapsedMilliseconds < 5000,
+                $"All waiters should complete in under 5s, but took {stopwatch.ElapsedMilliseconds}ms");
+        }
+
+        /// <summary>
+        /// Verifies that waiting on an unsignaled event with a token cancelled beforehand yields false without waiting out the timeout.
+        /// </summary>
+        [TestMethod]
+        public async Task WaitAsync_AlreadyCancelledToken_ReturnsFalseImmediately()
+        {
+            using var cancelledSource = new CancellationTokenSource();
+            cancelledSource.Cancel(); // The token is already cancelled before the wait starts.
+            var unsignaledEvent = new AsyncAutoResetEvent(signaled: false);
+
+            Stopwatch timer = Stopwatch.StartNew();
+            bool wasSignaled = await unsignaledEvent.WaitAsync(TimeSpan.FromSeconds(30), cancelledSource.Token);
+            timer.Stop();
+
+            long elapsedMs = timer.ElapsedMilliseconds;
+            Assert.IsFalse(wasSignaled, "Pre-cancelled token should cause immediate false return");
+            Assert.IsTrue(
+                elapsedMs < 5000, $"Should complete immediately, but took {elapsedMs}ms");
+        }
+
+        /// <summary>
+        /// Verifies that an already-signaled event still reports true even when the supplied token was cancelled beforehand.
+        /// </summary>
+        [TestMethod]
+        public async Task WaitAsync_AlreadySignaledAndCancelled_ReturnsTrue()
+        {
+            using var cancelledSource = new CancellationTokenSource();
+            cancelledSource.Cancel();
+            var signaledEvent = new AsyncAutoResetEvent(signaled: true);
+
+            bool wasSignaled = await signaledEvent.WaitAsync(TimeSpan.FromSeconds(30), cancelledSource.Token);
+            Assert.IsTrue(wasSignaled, "Already signaled event should return true even with cancelled token");
+        }
+
+        /// <summary>
+        /// Verifies that <see cref="OrchestrationSessionManager.AbortAllSessions"/> clears all active sessions.
+        /// </summary>
+        [TestMethod]
+        public void AbortAllSessions_ClearsActiveSessions()
+        {
+            var settings = new AzureStorageOrchestrationServiceSettings();
+            var stats = new AzureStorageOrchestrationServiceStats();
+            var trackingStore = new Mock<ITrackingStore>();
+
+            using var manager = new OrchestrationSessionManager(
+                "testaccount",
+                settings,
+                stats,
+                trackingStore.Object);
+
+            // Use reflection to reach the non-public session table; null entries suffice because the assertions below only read the count.
+            var sessionsField = typeof(OrchestrationSessionManager)
+                .GetField("activeOrchestrationSessions", BindingFlags.NonPublic | BindingFlags.Instance);
+            Assert.IsNotNull(sessionsField, "Expected a non-public instance field named activeOrchestrationSessions");
+            var sessions = (Dictionary<string, OrchestrationSession>)sessionsField.GetValue(manager);
+
+            manager.GetStats(out _, out _, out int initialCount);
+            Assert.AreEqual(0, initialCount, "Should start with no active sessions");
+
+            sessions["instance1"] = null;
+            sessions["instance2"] = null;
+            sessions["instance3"] = null;
+            manager.GetStats(out _, out _, out int activeCount);
+            Assert.AreEqual(3, activeCount, "Should have 3 active sessions");
+
+            manager.AbortAllSessions();
+
+            manager.GetStats(out _, out _, out int afterAbortCount);
+            Assert.AreEqual(0, afterAbortCount, "AbortAllSessions should clear all active sessions");
+        }
+
+        /// <summary>
+        /// Verifies that <see cref="OrchestrationSessionManager.AbortAllSessions"/> is safe to call with no active sessions.
+        /// </summary>
+        [TestMethod]
+        public void AbortAllSessions_NoSessions_DoesNotThrow()
+        {
+            var settings = new AzureStorageOrchestrationServiceSettings();
+            var stats = new AzureStorageOrchestrationServiceStats();
+            var trackingStore = new Mock<ITrackingStore>();
+
+            using var manager = new OrchestrationSessionManager(
+                "testaccount",
+                settings,
+                stats,
+                trackingStore.Object);
+
+            manager.AbortAllSessions();
+
+            manager.GetStats(out _, out _, out int count);
+            Assert.AreEqual(0, count, "Should still have no active sessions");
+        }
+ }
+}
diff --git a/src/DurableTask.AzureStorage/AzureStorageOrchestrationService.cs b/src/DurableTask.AzureStorage/AzureStorageOrchestrationService.cs
index 38cb7fac3..9c3471978 100644
--- a/src/DurableTask.AzureStorage/AzureStorageOrchestrationService.cs
+++ b/src/DurableTask.AzureStorage/AzureStorageOrchestrationService.cs
@@ -478,6 +478,15 @@ public async Task StopAsync(bool isForced)
{
this.shutdownSource.Cancel();
await this.statsLoop;
+
+ if (isForced)
+ {
+ // When forced, immediately remove all active sessions so that
+ // partition draining completes without waiting for sessions to
+ // finish their idle timeout or in-flight work.
+ this.orchestrationSessionManager.AbortAllSessions();
+ }
+
await this.appLeaseManager.StopAsync();
this.isStarted = false;
}
diff --git a/src/DurableTask.AzureStorage/Messaging/OrchestrationSession.cs b/src/DurableTask.AzureStorage/Messaging/OrchestrationSession.cs
index 1b2e4a20e..d30a1b40f 100644
--- a/src/DurableTask.AzureStorage/Messaging/OrchestrationSession.cs
+++ b/src/DurableTask.AzureStorage/Messaging/OrchestrationSession.cs
@@ -17,6 +17,7 @@ namespace DurableTask.AzureStorage.Messaging
using System.Collections.Generic;
using System.IO;
using System.Linq;
+ using System.Threading;
using System.Threading.Tasks;
using Azure;
using DurableTask.Core;
@@ -26,6 +27,7 @@ namespace DurableTask.AzureStorage.Messaging
sealed class OrchestrationSession : SessionBase, IOrchestrationSession
{
readonly TimeSpan idleTimeout;
+ readonly CancellationToken shutdownToken;
readonly AsyncAutoResetEvent messagesAvailableEvent;
readonly MessageCollection nextMessageBatch;
@@ -41,10 +43,12 @@ public OrchestrationSession(
DateTime lastCheckpointTime,
object trackingStoreContext,
TimeSpan idleTimeout,
+ CancellationToken shutdownToken,
Guid traceActivityId)
: base(settings, storageAccountName, orchestrationInstance, traceActivityId)
{
this.idleTimeout = idleTimeout;
+ this.shutdownToken = shutdownToken;
this.ControlQueue = controlQueue ?? throw new ArgumentNullException(nameof(controlQueue));
this.CurrentMessageBatch = initialMessageBatch ?? throw new ArgumentNullException(nameof(initialMessageBatch));
this.RuntimeState = runtimeState ?? throw new ArgumentNullException(nameof(runtimeState));
@@ -98,9 +102,9 @@ public void AddOrReplaceMessages(IEnumerable messages)
public async Task> FetchNewOrchestrationMessagesAsync(
TaskOrchestrationWorkItem workItem)
{
- if (!await this.messagesAvailableEvent.WaitAsync(this.idleTimeout))
+ if (!await this.messagesAvailableEvent.WaitAsync(this.idleTimeout, this.shutdownToken))
{
- return null; // timed-out
+ return null; // timed-out or shutting down
}
this.StartNewLogicalTraceScope();
diff --git a/src/DurableTask.AzureStorage/OrchestrationSessionManager.cs b/src/DurableTask.AzureStorage/OrchestrationSessionManager.cs
index 7ecae232d..abf7a58b2 100644
--- a/src/DurableTask.AzureStorage/OrchestrationSessionManager.cs
+++ b/src/DurableTask.AzureStorage/OrchestrationSessionManager.cs
@@ -43,6 +43,8 @@ class OrchestrationSessionManager : IDisposable
readonly ITrackingStore trackingStore;
readonly DispatchQueue fetchRuntimeStateQueue;
+ CancellationToken shutdownToken;
+
public OrchestrationSessionManager(
string queueAccountName,
AzureStorageOrchestrationServiceSettings settings,
@@ -61,6 +63,8 @@ public OrchestrationSessionManager(
public void AddQueue(string partitionId, ControlQueue controlQueue, CancellationToken cancellationToken)
{
+ this.shutdownToken = cancellationToken;
+
if (this.ownedControlQueues.TryAdd(partitionId, controlQueue))
{
_ = Task.Run(() => this.DequeueLoop(partitionId, controlQueue, cancellationToken));
@@ -613,6 +617,7 @@ async Task ScheduleOrchestrationStatePrefetch(
nextBatch.LastCheckpointTime,
nextBatch.TrackingStoreContext,
this.settings.ExtendedSessionIdleTimeout,
+ this.shutdownToken,
traceActivityId);
this.activeOrchestrationSessions.Add(instance.InstanceId, session);
@@ -656,6 +661,19 @@ async Task ScheduleOrchestrationStatePrefetch(
return null;
}
+ /// <summary>
+ /// Immediately removes all active sessions by clearing the active-session table while holding
+ /// the message-and-session lock. Intended for forced shutdown, so that partition draining
+ /// can complete without waiting for sessions to drain naturally.
+ /// </summary>
+ public void AbortAllSessions()
+ {
+ lock (this.messageAndSessionLock)
+ {
+ this.activeOrchestrationSessions.Clear();
+ }
+ }
+
public bool TryGetExistingSession(string instanceId, out OrchestrationSession session)
{
lock (this.messageAndSessionLock)