Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
227 changes: 227 additions & 0 deletions Test/DurableTask.AzureStorage.Tests/OrchestrationSessionTests.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,227 @@
// ----------------------------------------------------------------------------------
// Copyright Microsoft Corporation
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// ----------------------------------------------------------------------------------

namespace DurableTask.AzureStorage.Tests
{
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using System.Reflection;
using System.Threading;
using System.Threading.Tasks;
using DurableTask.AzureStorage.Messaging;
using DurableTask.AzureStorage.Monitoring;
using DurableTask.AzureStorage.Tracking;
using Microsoft.VisualStudio.TestTools.UnitTesting;
using Moq;

/// <summary>
/// Tests for shutdown cancellation behavior with extended sessions.
/// </summary>
[TestClass]
public class OrchestrationSessionTests
{
/// <summary>
/// Verifies that <see cref="AsyncAutoResetEvent.WaitAsync(TimeSpan, CancellationToken)"/>
/// exits immediately when the cancellation token is cancelled.
/// </summary>
[TestMethod]
public async Task WaitAsync_CancellationToken_ExitsImmediately()
{
var resetEvent = new AsyncAutoResetEvent(signaled: false);
using var cts = new CancellationTokenSource();

TimeSpan longTimeout = TimeSpan.FromSeconds(30);
Task<bool> waitTask = resetEvent.WaitAsync(longTimeout, cts.Token);

Assert.IsFalse(waitTask.IsCompleted, "Wait should not complete immediately");

var stopwatch = Stopwatch.StartNew();
cts.Cancel();

bool result = await waitTask;
stopwatch.Stop();

Assert.IsFalse(result, "Cancellation should return false (no signal received)");
Assert.IsTrue(
stopwatch.ElapsedMilliseconds < 5000,
$"Cancellation should complete in under 5s, but took {stopwatch.ElapsedMilliseconds}ms");
}

/// <summary>
/// Verifies that signaling still returns true when a cancellation token is provided.
/// </summary>
[TestMethod]
public async Task WaitAsync_WithCancellationToken_SignalStillWorks()
{
var resetEvent = new AsyncAutoResetEvent(signaled: false);
using var cts = new CancellationTokenSource();

Task<bool> waitTask = resetEvent.WaitAsync(TimeSpan.FromSeconds(30), cts.Token);
Assert.IsFalse(waitTask.IsCompleted);

resetEvent.Set();

Task winner = await Task.WhenAny(waitTask, Task.Delay(TimeSpan.FromSeconds(5)));
Assert.IsTrue(winner == waitTask, "Signal should wake the waiter");
Assert.IsTrue(waitTask.Result, "Wait result should be true when signaled");
}

/// <summary>
/// Verifies that the wait returns false on timeout when a cancellation token is provided but not cancelled.
/// </summary>
[TestMethod]
public async Task WaitAsync_WithCancellationToken_TimeoutStillWorks()
{
var resetEvent = new AsyncAutoResetEvent(signaled: false);
using var cts = new CancellationTokenSource();

bool result = await resetEvent.WaitAsync(TimeSpan.FromMilliseconds(100), cts.Token);

Assert.IsFalse(result, "Wait should return false on timeout");
}

/// <summary>
/// Verifies that all queued waiters return false when the token is cancelled.
/// </summary>
[TestMethod]
public async Task WaitAsync_CancellationToken_MultipleWaiters()
{
var resetEvent = new AsyncAutoResetEvent(signaled: false);
using var cts = new CancellationTokenSource();

var waiters = new List<Task<bool>>();
for (int i = 0; i < 5; i++)
{
waiters.Add(resetEvent.WaitAsync(TimeSpan.FromSeconds(30), cts.Token));
}

foreach (var waiter in waiters)
{
Assert.IsFalse(waiter.IsCompleted);
}

var stopwatch = Stopwatch.StartNew();
cts.Cancel();

// All waiters should return false (cancelled = not signaled)
await Task.WhenAll(
waiters.Select(
async waiter =>
{
bool result = await waiter;
Assert.IsFalse(result, "Cancelled waiter should return false");
}));

stopwatch.Stop();

Assert.IsTrue(
stopwatch.ElapsedMilliseconds < 5000,
$"All waiters should complete in under 5s, but took {stopwatch.ElapsedMilliseconds}ms");
}

/// <summary>
/// Verifies that a pre-cancelled token causes WaitAsync to return false immediately.
/// </summary>
[TestMethod]
public async Task WaitAsync_AlreadyCancelledToken_ReturnsFalseImmediately()
{
var resetEvent = new AsyncAutoResetEvent(signaled: false);
using var cts = new CancellationTokenSource();
cts.Cancel(); // Pre-cancel

var stopwatch = Stopwatch.StartNew();
bool result = await resetEvent.WaitAsync(TimeSpan.FromSeconds(30), cts.Token);
stopwatch.Stop();

Assert.IsFalse(result, "Pre-cancelled token should cause immediate false return");
Assert.IsTrue(
stopwatch.ElapsedMilliseconds < 5000,
$"Should complete immediately, but took {stopwatch.ElapsedMilliseconds}ms");
}

/// <summary>
/// Verifies that a pre-cancelled token still returns true if the event is already signaled.
/// </summary>
[TestMethod]
public async Task WaitAsync_AlreadySignaledAndCancelled_ReturnsTrue()
{
var resetEvent = new AsyncAutoResetEvent(signaled: true);
using var cts = new CancellationTokenSource();
cts.Cancel();

bool result = await resetEvent.WaitAsync(TimeSpan.FromSeconds(30), cts.Token);
Assert.IsTrue(result, "Already signaled event should return true even with cancelled token");
}

/// <summary>
/// Verifies that <see cref="OrchestrationSessionManager.AbortAllSessions"/> clears all active sessions.
/// </summary>
[TestMethod]
public void AbortAllSessions_ClearsActiveSessions()
{
var settings = new AzureStorageOrchestrationServiceSettings();
var stats = new AzureStorageOrchestrationServiceStats();
var trackingStore = new Mock<ITrackingStore>();

using var manager = new OrchestrationSessionManager(
"testaccount",
settings,
stats,
trackingStore.Object);

// Use reflection to access the internal sessions dictionary.
var sessionsField = typeof(OrchestrationSessionManager)
.GetField("activeOrchestrationSessions", BindingFlags.NonPublic | BindingFlags.Instance);
var sessions = (Dictionary<string, OrchestrationSession>)sessionsField.GetValue(manager);
Copy link

Copilot AI Mar 10, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This test uses reflection to reach into the private activeOrchestrationSessions field and assumes both the exact field name and concrete type. That makes the test brittle (a simple rename/type change becomes a runtime NullReferenceException/InvalidCastException). Consider asserting sessionsField != null before GetValue, and prefer exercising the manager via exposed/internal APIs (or an internal test hook) to avoid hard-coding private implementation details.

Suggested change
var sessions = (Dictionary<string, OrchestrationSession>)sessionsField.GetValue(manager);
Assert.IsNotNull(
sessionsField,
"Expected OrchestrationSessionManager to have a non-public instance field named 'activeOrchestrationSessions'.");
var sessions = sessionsField.GetValue(manager) as IDictionary<string, OrchestrationSession>;
Assert.IsNotNull(
sessions,
"Expected 'activeOrchestrationSessions' field to be assignable to IDictionary<string, OrchestrationSession>.");

Copilot uses AI. Check for mistakes.

manager.GetStats(out _, out _, out int initialCount);
Assert.AreEqual(0, initialCount, "Should start with no active sessions");

sessions["instance1"] = null;
sessions["instance2"] = null;
sessions["instance3"] = null;

manager.GetStats(out _, out _, out int activeCount);
Assert.AreEqual(3, activeCount, "Should have 3 active sessions");

manager.AbortAllSessions();

manager.GetStats(out _, out _, out int afterAbortCount);
Assert.AreEqual(0, afterAbortCount, "AbortAllSessions should clear all active sessions");
}

/// <summary>
/// Verifies that <see cref="OrchestrationSessionManager.AbortAllSessions"/> is safe to call with no active sessions.
/// </summary>
[TestMethod]
public void AbortAllSessions_NoSessions_DoesNotThrow()
{
var settings = new AzureStorageOrchestrationServiceSettings();
var stats = new AzureStorageOrchestrationServiceStats();
var trackingStore = new Mock<ITrackingStore>();

using var manager = new OrchestrationSessionManager(
"testaccount",
settings,
stats,
trackingStore.Object);

manager.AbortAllSessions();

manager.GetStats(out _, out _, out int count);
Assert.AreEqual(0, count, "Should still have no active sessions");
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -478,6 +478,15 @@ public async Task StopAsync(bool isForced)
{
this.shutdownSource.Cancel();
await this.statsLoop;

if (isForced)
{
// When forced, immediately remove all active sessions so that
// partition draining completes without waiting for sessions to
// finish their idle timeout or in-flight work.
this.orchestrationSessionManager.AbortAllSessions();
}

await this.appLeaseManager.StopAsync();
this.isStarted = false;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ namespace DurableTask.AzureStorage.Messaging
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using Azure;
using DurableTask.Core;
Expand All @@ -26,6 +27,7 @@ namespace DurableTask.AzureStorage.Messaging
sealed class OrchestrationSession : SessionBase, IOrchestrationSession
{
readonly TimeSpan idleTimeout;
readonly CancellationToken shutdownToken;

readonly AsyncAutoResetEvent messagesAvailableEvent;
readonly MessageCollection nextMessageBatch;
Expand All @@ -41,10 +43,12 @@ public OrchestrationSession(
DateTime lastCheckpointTime,
object trackingStoreContext,
TimeSpan idleTimeout,
CancellationToken shutdownToken,
Guid traceActivityId)
: base(settings, storageAccountName, orchestrationInstance, traceActivityId)
{
this.idleTimeout = idleTimeout;
this.shutdownToken = shutdownToken;
this.ControlQueue = controlQueue ?? throw new ArgumentNullException(nameof(controlQueue));
this.CurrentMessageBatch = initialMessageBatch ?? throw new ArgumentNullException(nameof(initialMessageBatch));
this.RuntimeState = runtimeState ?? throw new ArgumentNullException(nameof(runtimeState));
Expand Down Expand Up @@ -98,9 +102,9 @@ public void AddOrReplaceMessages(IEnumerable<MessageData> messages)
public async Task<IList<TaskMessage>> FetchNewOrchestrationMessagesAsync(
TaskOrchestrationWorkItem workItem)
{
if (!await this.messagesAvailableEvent.WaitAsync(this.idleTimeout))
if (!await this.messagesAvailableEvent.WaitAsync(this.idleTimeout, this.shutdownToken))
{
return null; // timed-out
return null; // timed-out or shutting down
}

this.StartNewLogicalTraceScope();
Expand Down
18 changes: 18 additions & 0 deletions src/DurableTask.AzureStorage/OrchestrationSessionManager.cs
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@ class OrchestrationSessionManager : IDisposable
readonly ITrackingStore trackingStore;
readonly DispatchQueue fetchRuntimeStateQueue;

CancellationToken shutdownToken;

public OrchestrationSessionManager(
string queueAccountName,
AzureStorageOrchestrationServiceSettings settings,
Expand All @@ -61,6 +63,8 @@ public OrchestrationSessionManager(

public void AddQueue(string partitionId, ControlQueue controlQueue, CancellationToken cancellationToken)
{
this.shutdownToken = cancellationToken;

if (this.ownedControlQueues.TryAdd(partitionId, controlQueue))
{
_ = Task.Run(() => this.DequeueLoop(partitionId, controlQueue, cancellationToken));
Expand Down Expand Up @@ -613,6 +617,7 @@ async Task ScheduleOrchestrationStatePrefetch(
nextBatch.LastCheckpointTime,
nextBatch.TrackingStoreContext,
this.settings.ExtendedSessionIdleTimeout,
this.shutdownToken,
traceActivityId);

this.activeOrchestrationSessions.Add(instance.InstanceId, session);
Expand Down Expand Up @@ -656,6 +661,19 @@ async Task ScheduleOrchestrationStatePrefetch(
return null;
}

/// <summary>
/// Immediately removes all active sessions, causing <see cref="IsControlQueueProcessingMessages"/>
/// to return <c>false</c> for all partitions. This unblocks <see cref="DrainAsync"/> so that
/// a forced shutdown can complete without waiting for sessions to drain naturally.
/// </summary>
public void AbortAllSessions()
{
lock (this.messageAndSessionLock)
{
this.activeOrchestrationSessions.Clear();
Comment on lines +672 to +673
Copy link

Copilot AI Mar 10, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

AbortAllSessions() clears only activeOrchestrationSessions, but leaves pendingOrchestrationMessageBatches and the ready-for-processing AsyncQueue buffers intact. Because OrchestrationSessionManager is reused across StartAsync/StopAsync, a forced stop could leave stale in-memory batches that may get processed after a subsequent StartAsync (or just retain memory). Consider also clearing the pending/ready buffers (or recreating the OrchestrationSessionManager on restart) when aborting sessions.

Suggested change
{
this.activeOrchestrationSessions.Clear();
{
// Clear any active sessions as well as all pending in-memory message batches
// so that no stale work is processed after a subsequent StartAsync.
this.activeOrchestrationSessions.Clear();
this.pendingOrchestrationMessageBatches.Clear();

Copilot uses AI. Check for mistakes.
}
}

public bool TryGetExistingSession(string instanceId, out OrchestrationSession session)
{
lock (this.messageAndSessionLock)
Expand Down
Loading