diff --git a/Test/DurableTask.AzureStorage.Tests/OrchestrationSessionTests.cs b/Test/DurableTask.AzureStorage.Tests/OrchestrationSessionTests.cs new file mode 100644 index 000000000..126c4b9bc --- /dev/null +++ b/Test/DurableTask.AzureStorage.Tests/OrchestrationSessionTests.cs @@ -0,0 +1,227 @@ +// ---------------------------------------------------------------------------------- +// Copyright Microsoft Corporation +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ---------------------------------------------------------------------------------- + +namespace DurableTask.AzureStorage.Tests +{ + using System; + using System.Collections.Generic; + using System.Diagnostics; + using System.Linq; + using System.Reflection; + using System.Threading; + using System.Threading.Tasks; + using DurableTask.AzureStorage.Messaging; + using DurableTask.AzureStorage.Monitoring; + using DurableTask.AzureStorage.Tracking; + using Microsoft.VisualStudio.TestTools.UnitTesting; + using Moq; + + /// + /// Tests for shutdown cancellation behavior with extended sessions. + /// + [TestClass] + public class OrchestrationSessionTests + { + /// + /// Verifies that + /// exits immediately when the cancellation token is cancelled. + /// + [TestMethod] + public async Task WaitAsync_CancellationToken_ExitsImmediately() + { + var resetEvent = new AsyncAutoResetEvent(signaled: false); + using var cts = new CancellationTokenSource(); + + TimeSpan longTimeout = TimeSpan.FromSeconds(30); + Task waitTask = resetEvent.WaitAsync(longTimeout, cts.Token); + + Assert.IsFalse(waitTask.IsCompleted, "Wait should not complete immediately"); + + var stopwatch = Stopwatch.StartNew(); + cts.Cancel(); + + bool result = await waitTask; + stopwatch.Stop(); + + Assert.IsFalse(result, "Cancellation should return false (no signal received)"); + Assert.IsTrue( + stopwatch.ElapsedMilliseconds < 5000, + $"Cancellation should complete in under 5s, but took {stopwatch.ElapsedMilliseconds}ms"); + } + + /// + /// Verifies that signaling still returns true when a cancellation token is provided. + /// + [TestMethod] + public async Task WaitAsync_WithCancellationToken_SignalStillWorks() + { + var resetEvent = new AsyncAutoResetEvent(signaled: false); + using var cts = new CancellationTokenSource(); + + Task waitTask = resetEvent.WaitAsync(TimeSpan.FromSeconds(30), cts.Token); + Assert.IsFalse(waitTask.IsCompleted); + + resetEvent.Set(); + + Task winner = await Task.WhenAny(waitTask, Task.Delay(TimeSpan.FromSeconds(5))); + Assert.IsTrue(winner == waitTask, "Signal should wake the waiter"); + Assert.IsTrue(waitTask.Result, "Wait result should be true when signaled"); + } + + /// + /// Verifies that the wait returns false on timeout when a cancellation token is provided but not cancelled. + /// + [TestMethod] + public async Task WaitAsync_WithCancellationToken_TimeoutStillWorks() + { + var resetEvent = new AsyncAutoResetEvent(signaled: false); + using var cts = new CancellationTokenSource(); + + bool result = await resetEvent.WaitAsync(TimeSpan.FromMilliseconds(100), cts.Token); + + Assert.IsFalse(result, "Wait should return false on timeout"); + } + + /// + /// Verifies that all queued waiters return false when the token is cancelled. + /// + [TestMethod] + public async Task WaitAsync_CancellationToken_MultipleWaiters() + { + var resetEvent = new AsyncAutoResetEvent(signaled: false); + using var cts = new CancellationTokenSource(); + + var waiters = new List>(); + for (int i = 0; i < 5; i++) + { + waiters.Add(resetEvent.WaitAsync(TimeSpan.FromSeconds(30), cts.Token)); + } + + foreach (var waiter in waiters) + { + Assert.IsFalse(waiter.IsCompleted); + } + + var stopwatch = Stopwatch.StartNew(); + cts.Cancel(); + + // All waiters should return false (cancelled = not signaled) + await Task.WhenAll( + waiters.Select( + async waiter => + { + bool result = await waiter; + Assert.IsFalse(result, "Cancelled waiter should return false"); + })); + + stopwatch.Stop(); + + Assert.IsTrue( + stopwatch.ElapsedMilliseconds < 5000, + $"All waiters should complete in under 5s, but took {stopwatch.ElapsedMilliseconds}ms"); + } + + /// + /// Verifies that a pre-cancelled token causes WaitAsync to return false immediately. + /// + [TestMethod] + public async Task WaitAsync_AlreadyCancelledToken_ReturnsFalseImmediately() + { + var resetEvent = new AsyncAutoResetEvent(signaled: false); + using var cts = new CancellationTokenSource(); + cts.Cancel(); // Pre-cancel + + var stopwatch = Stopwatch.StartNew(); + bool result = await resetEvent.WaitAsync(TimeSpan.FromSeconds(30), cts.Token); + stopwatch.Stop(); + + Assert.IsFalse(result, "Pre-cancelled token should cause immediate false return"); + Assert.IsTrue( + stopwatch.ElapsedMilliseconds < 5000, + $"Should complete immediately, but took {stopwatch.ElapsedMilliseconds}ms"); + } + + /// + /// Verifies that a pre-cancelled token still returns true if the event is already signaled. + /// + [TestMethod] + public async Task WaitAsync_AlreadySignaledAndCancelled_ReturnsTrue() + { + var resetEvent = new AsyncAutoResetEvent(signaled: true); + using var cts = new CancellationTokenSource(); + cts.Cancel(); + + bool result = await resetEvent.WaitAsync(TimeSpan.FromSeconds(30), cts.Token); + Assert.IsTrue(result, "Already signaled event should return true even with cancelled token"); + } + + /// + /// Verifies that clears all active sessions. + /// + [TestMethod] + public void AbortAllSessions_ClearsActiveSessions() + { + var settings = new AzureStorageOrchestrationServiceSettings(); + var stats = new AzureStorageOrchestrationServiceStats(); + var trackingStore = new Mock(); + + using var manager = new OrchestrationSessionManager( + "testaccount", + settings, + stats, + trackingStore.Object); + + // Use reflection to access the internal sessions dictionary. + var sessionsField = typeof(OrchestrationSessionManager) + .GetField("activeOrchestrationSessions", BindingFlags.NonPublic | BindingFlags.Instance); + var sessions = (Dictionary)sessionsField.GetValue(manager); + + manager.GetStats(out _, out _, out int initialCount); + Assert.AreEqual(0, initialCount, "Should start with no active sessions"); + + sessions["instance1"] = null; + sessions["instance2"] = null; + sessions["instance3"] = null; + + manager.GetStats(out _, out _, out int activeCount); + Assert.AreEqual(3, activeCount, "Should have 3 active sessions"); + + manager.AbortAllSessions(); + + manager.GetStats(out _, out _, out int afterAbortCount); + Assert.AreEqual(0, afterAbortCount, "AbortAllSessions should clear all active sessions"); + } + + /// + /// Verifies that is safe to call with no active sessions. + /// + [TestMethod] + public void AbortAllSessions_NoSessions_DoesNotThrow() + { + var settings = new AzureStorageOrchestrationServiceSettings(); + var stats = new AzureStorageOrchestrationServiceStats(); + var trackingStore = new Mock(); + + using var manager = new OrchestrationSessionManager( + "testaccount", + settings, + stats, + trackingStore.Object); + + manager.AbortAllSessions(); + + manager.GetStats(out _, out _, out int count); + Assert.AreEqual(0, count, "Should still have no active sessions"); + } + } +} diff --git a/src/DurableTask.AzureStorage/AzureStorageOrchestrationService.cs b/src/DurableTask.AzureStorage/AzureStorageOrchestrationService.cs index 38cb7fac3..9c3471978 100644 --- a/src/DurableTask.AzureStorage/AzureStorageOrchestrationService.cs +++ b/src/DurableTask.AzureStorage/AzureStorageOrchestrationService.cs @@ -478,6 +478,15 @@ public async Task StopAsync(bool isForced) { this.shutdownSource.Cancel(); await this.statsLoop; + + if (isForced) + { + // When forced, immediately remove all active sessions so that + // partition draining completes without waiting for sessions to + // finish their idle timeout or in-flight work. + this.orchestrationSessionManager.AbortAllSessions(); + } + await this.appLeaseManager.StopAsync(); this.isStarted = false; } diff --git a/src/DurableTask.AzureStorage/Messaging/OrchestrationSession.cs b/src/DurableTask.AzureStorage/Messaging/OrchestrationSession.cs index 1b2e4a20e..d30a1b40f 100644 --- a/src/DurableTask.AzureStorage/Messaging/OrchestrationSession.cs +++ b/src/DurableTask.AzureStorage/Messaging/OrchestrationSession.cs @@ -17,6 +17,7 @@ namespace DurableTask.AzureStorage.Messaging using System.Collections.Generic; using System.IO; using System.Linq; + using System.Threading; using System.Threading.Tasks; using Azure; using DurableTask.Core; @@ -26,6 +27,7 @@ namespace DurableTask.AzureStorage.Messaging sealed class OrchestrationSession : SessionBase, IOrchestrationSession { readonly TimeSpan idleTimeout; + readonly CancellationToken shutdownToken; readonly AsyncAutoResetEvent messagesAvailableEvent; readonly MessageCollection nextMessageBatch; @@ -41,10 +43,12 @@ public OrchestrationSession( DateTime lastCheckpointTime, object trackingStoreContext, TimeSpan idleTimeout, + CancellationToken shutdownToken, Guid traceActivityId) : base(settings, storageAccountName, orchestrationInstance, traceActivityId) { this.idleTimeout = idleTimeout; + this.shutdownToken = shutdownToken; this.ControlQueue = controlQueue ?? throw new ArgumentNullException(nameof(controlQueue)); this.CurrentMessageBatch = initialMessageBatch ?? throw new ArgumentNullException(nameof(initialMessageBatch)); this.RuntimeState = runtimeState ?? throw new ArgumentNullException(nameof(runtimeState)); @@ -98,9 +102,9 @@ public void AddOrReplaceMessages(IEnumerable messages) public async Task> FetchNewOrchestrationMessagesAsync( TaskOrchestrationWorkItem workItem) { - if (!await this.messagesAvailableEvent.WaitAsync(this.idleTimeout)) + if (!await this.messagesAvailableEvent.WaitAsync(this.idleTimeout, this.shutdownToken)) { - return null; // timed-out + return null; // timed-out or shutting down } this.StartNewLogicalTraceScope(); diff --git a/src/DurableTask.AzureStorage/OrchestrationSessionManager.cs b/src/DurableTask.AzureStorage/OrchestrationSessionManager.cs index 7ecae232d..abf7a58b2 100644 --- a/src/DurableTask.AzureStorage/OrchestrationSessionManager.cs +++ b/src/DurableTask.AzureStorage/OrchestrationSessionManager.cs @@ -43,6 +43,8 @@ class OrchestrationSessionManager : IDisposable readonly ITrackingStore trackingStore; readonly DispatchQueue fetchRuntimeStateQueue; + CancellationToken shutdownToken; + public OrchestrationSessionManager( string queueAccountName, AzureStorageOrchestrationServiceSettings settings, @@ -61,6 +63,8 @@ public OrchestrationSessionManager( public void AddQueue(string partitionId, ControlQueue controlQueue, CancellationToken cancellationToken) { + this.shutdownToken = cancellationToken; + if (this.ownedControlQueues.TryAdd(partitionId, controlQueue)) { _ = Task.Run(() => this.DequeueLoop(partitionId, controlQueue, cancellationToken)); @@ -613,6 +617,7 @@ async Task ScheduleOrchestrationStatePrefetch( nextBatch.LastCheckpointTime, nextBatch.TrackingStoreContext, this.settings.ExtendedSessionIdleTimeout, + this.shutdownToken, traceActivityId); this.activeOrchestrationSessions.Add(instance.InstanceId, session); @@ -656,6 +661,19 @@ async Task ScheduleOrchestrationStatePrefetch( return null; } + /// + /// Immediately removes all active sessions, causing + /// to return false for all partitions. This unblocks so that + /// a forced shutdown can complete without waiting for sessions to drain naturally. + /// + public void AbortAllSessions() + { + lock (this.messageAndSessionLock) + { + this.activeOrchestrationSessions.Clear(); + } + } + public bool TryGetExistingSession(string instanceId, out OrchestrationSession session) { lock (this.messageAndSessionLock)