From 258a96d3b0cfe243d94177ed1181fa1a0f4925e8 Mon Sep 17 00:00:00 2001 From: Min-Gu Yoo Date: Wed, 1 Apr 2026 20:20:08 -0700 Subject: [PATCH 1/2] (fix): Type-stable task-local pool accessors for CPU/CUDA/Metal Return type assertions on get_task_local_*_pool() were missing type parameters, causing type instability when retrieved from task-local storage (IdDict{Any,Any}). Added concrete parametric types to all return assertions and Dict value types. Added @inferred regression tests for all three backends covering both fast path (existing pool) and slow path (fresh task creation). --- .../task_local_pool.jl | 6 +++--- .../task_local_pool.jl | 6 +++--- src/task_local_pool.jl | 2 +- test/cuda/test_extension.jl | 19 +++++++++++++++++-- test/metal/runtests.jl | 1 + test/metal/test_task_local_pool.jl | 19 +++++++++++++++++-- test/test_task_local_pool.jl | 15 +++++++++++++++ 7 files changed, 57 insertions(+), 11 deletions(-) diff --git a/ext/AdaptiveArrayPoolsCUDAExt/task_local_pool.jl b/ext/AdaptiveArrayPoolsCUDAExt/task_local_pool.jl index 7b11c230..098fa2eb 100644 --- a/ext/AdaptiveArrayPoolsCUDAExt/task_local_pool.jl +++ b/ext/AdaptiveArrayPoolsCUDAExt/task_local_pool.jl @@ -25,7 +25,7 @@ Values are `CuAdaptiveArrayPool{S}` where S is determined by `RUNTIME_CHECK`. # 1. Get or create the pools dictionary pools = get(task_local_storage(), _CU_POOL_KEY, nothing) if pools === nothing - pools = Dict{Int, CuAdaptiveArrayPool}() + pools = Dict{Int, CuAdaptiveArrayPool{RUNTIME_CHECK}}() task_local_storage(_CU_POOL_KEY, pools) end @@ -39,7 +39,7 @@ Values are `CuAdaptiveArrayPool{S}` where S is determined by `RUNTIME_CHECK`. pools[dev_id] = pool end - return pool::CuAdaptiveArrayPool + return pool::CuAdaptiveArrayPool{RUNTIME_CHECK} end """ @@ -51,7 +51,7 @@ Useful for diagnostics or bulk operations across all devices. @inline function AdaptiveArrayPools.get_task_local_cuda_pools() pools = get(task_local_storage(), _CU_POOL_KEY, nothing) if pools === nothing - pools = Dict{Int, CuAdaptiveArrayPool}() + pools = Dict{Int, CuAdaptiveArrayPool{RUNTIME_CHECK}}() task_local_storage(_CU_POOL_KEY, pools) end return pools diff --git a/ext/AdaptiveArrayPoolsMetalExt/task_local_pool.jl b/ext/AdaptiveArrayPoolsMetalExt/task_local_pool.jl index c60c9430..92a1e90e 100644 --- a/ext/AdaptiveArrayPoolsMetalExt/task_local_pool.jl +++ b/ext/AdaptiveArrayPoolsMetalExt/task_local_pool.jl @@ -25,7 +25,7 @@ Values are `MetalAdaptiveArrayPool{R,S}` where R is determined by `RUNTIME_CHECK # 1. Get or create the pools dictionary pools = get(task_local_storage(), _METAL_POOL_KEY, nothing) if pools === nothing - pools = Dict{UInt64, MetalAdaptiveArrayPool}() + pools = Dict{UInt64, MetalAdaptiveArrayPool{RUNTIME_CHECK, METAL_STORAGE}}() task_local_storage(_METAL_POOL_KEY, pools) end @@ -40,7 +40,7 @@ Values are `MetalAdaptiveArrayPool{R,S}` where R is determined by `RUNTIME_CHECK pools[dev_key] = pool end - return pool::MetalAdaptiveArrayPool + return pool::MetalAdaptiveArrayPool{RUNTIME_CHECK, METAL_STORAGE} end """ @@ -52,7 +52,7 @@ Useful for diagnostics or bulk operations across all devices. @inline function AdaptiveArrayPools.get_task_local_metal_pools() pools = get(task_local_storage(), _METAL_POOL_KEY, nothing) if pools === nothing - pools = Dict{UInt64, MetalAdaptiveArrayPool}() + pools = Dict{UInt64, MetalAdaptiveArrayPool{RUNTIME_CHECK, METAL_STORAGE}}() task_local_storage(_METAL_POOL_KEY, pools) end return pools diff --git a/src/task_local_pool.jl b/src/task_local_pool.jl index 3f4f14c9..2c659bb3 100644 --- a/src/task_local_pool.jl +++ b/src/task_local_pool.jl @@ -98,7 +98,7 @@ type-asserts directly to `AdaptiveArrayPool{RUNTIME_CHECK}`. task_local_storage(_POOL_KEY, pool) end - return pool::AdaptiveArrayPool + return pool::AdaptiveArrayPool{RUNTIME_CHECK} end # ============================================================================== diff --git a/test/cuda/test_extension.jl b/test/cuda/test_extension.jl index 0638956b..a8cfd2bd 100644 --- a/test/cuda/test_extension.jl +++ b/test/cuda/test_extension.jl @@ -92,6 +92,21 @@ end end @testset "Task-Local Pool" begin + @testset "get_task_local_cuda_pool type stability" begin + # Fast path: pool already exists in task-local storage + pool = @inferred get_task_local_cuda_pool() + @test pool isa CuAdaptiveArrayPool{RUNTIME_CHECK} + + # Slow path: fresh task creates a new pool + result = fetch( + Threads.@spawn begin + p = @inferred get_task_local_cuda_pool() + p isa CuAdaptiveArrayPool{RUNTIME_CHECK} + end + ) + @test result == true + end + @testset "get_task_local_cuda_pool" begin pool1 = get_task_local_cuda_pool() @test pool1 isa CuAdaptiveArrayPool @@ -103,7 +118,7 @@ end @testset "get_task_local_cuda_pools" begin pools_dict = get_task_local_cuda_pools() - @test pools_dict isa Dict{Int, CuAdaptiveArrayPool} + @test pools_dict isa Dict{Int, CuAdaptiveArrayPool{RUNTIME_CHECK}} pool = get_task_local_cuda_pool() @test haskey(pools_dict, pool.device_id) end @@ -114,7 +129,7 @@ end Threads.@spawn begin # Call get_task_local_cuda_pools() FIRST (before get_task_local_cuda_pool) pools = get_task_local_cuda_pools() - @test pools isa Dict{Int, CuAdaptiveArrayPool} + @test pools isa Dict{Int, CuAdaptiveArrayPool{RUNTIME_CHECK}} @test isempty(pools) # No pools created yet true end diff --git a/test/metal/runtests.jl b/test/metal/runtests.jl index 8984dfa6..a327908f 100644 --- a/test/metal/runtests.jl +++ b/test/metal/runtests.jl @@ -37,6 +37,7 @@ else const MetalTypedPool = ext.MetalTypedPool const MetalAdaptiveArrayPool = ext.MetalAdaptiveArrayPool const METAL_FIXED_SLOT_FIELDS = ext.METAL_FIXED_SLOT_FIELDS + const METAL_STORAGE = ext.METAL_STORAGE # Include all Metal test files include("test_extension.jl") diff --git a/test/metal/test_task_local_pool.jl b/test/metal/test_task_local_pool.jl index 9a8c104d..c0c05627 100644 --- a/test/metal/test_task_local_pool.jl +++ b/test/metal/test_task_local_pool.jl @@ -2,6 +2,21 @@ @testset "Metal Task-Local Pool" begin + @testset "get_task_local_metal_pool type stability" begin + # Fast path: pool already exists + pool = @inferred get_task_local_metal_pool() + @test pool isa MetalAdaptiveArrayPool{RUNTIME_CHECK, METAL_STORAGE} + + # Slow path: fresh task + result = fetch( + Threads.@spawn begin + p = @inferred get_task_local_metal_pool() + p isa MetalAdaptiveArrayPool{RUNTIME_CHECK, METAL_STORAGE} + end + ) + @test result == true + end + @testset "get_task_local_metal_pool" begin pool1 = get_task_local_metal_pool() @test pool1 isa MetalAdaptiveArrayPool @@ -13,7 +28,7 @@ @testset "get_task_local_metal_pools" begin pools_dict = get_task_local_metal_pools() - @test pools_dict isa Dict{UInt64, MetalAdaptiveArrayPool} + @test pools_dict isa Dict{UInt64, MetalAdaptiveArrayPool{RUNTIME_CHECK, METAL_STORAGE}} pool = get_task_local_metal_pool() dev_key = objectid(Metal.device()) @test haskey(pools_dict, dev_key) @@ -23,7 +38,7 @@ result = fetch( Threads.@spawn begin pools = get_task_local_metal_pools() - @test pools isa Dict{UInt64, MetalAdaptiveArrayPool} + @test pools isa Dict{UInt64, MetalAdaptiveArrayPool{RUNTIME_CHECK, METAL_STORAGE}} @test isempty(pools) true end diff --git a/test/test_task_local_pool.jl b/test/test_task_local_pool.jl index d4ef765b..31d5b193 100644 --- a/test/test_task_local_pool.jl +++ b/test/test_task_local_pool.jl @@ -2,6 +2,21 @@ _test_accumulator = Ref(0.0) + @testset "get_task_local_pool type stability" begin + # Fast path: pool already exists in task-local storage + pool = @inferred get_task_local_pool() + @test pool isa AdaptiveArrayPool{RUNTIME_CHECK} + + # Slow path: fresh task creates a new pool + result = fetch( + Threads.@spawn begin + p = @inferred get_task_local_pool() + p isa AdaptiveArrayPool{RUNTIME_CHECK} + end + ) + @test result == true + end + @testset "@with_pool" begin # Define a function that takes pool as argument function global_test(n, pool) From b6330d7a5148ac1855bdee7da582202cbce65174 Mon Sep 17 00:00:00 2001 From: Min-Gu Yoo Date: Wed, 1 Apr 2026 20:46:50 -0700 Subject: [PATCH 2/2] (docs): Update docstring return types to match concrete parametric types --- ext/AdaptiveArrayPoolsCUDAExt/task_local_pool.jl | 5 ++--- ext/AdaptiveArrayPoolsMetalExt/task_local_pool.jl | 5 ++--- src/task_local_pool.jl | 2 +- 3 files changed, 5 insertions(+), 7 deletions(-) diff --git a/ext/AdaptiveArrayPoolsCUDAExt/task_local_pool.jl b/ext/AdaptiveArrayPoolsCUDAExt/task_local_pool.jl index 098fa2eb..2306652a 100644 --- a/ext/AdaptiveArrayPoolsCUDAExt/task_local_pool.jl +++ b/ext/AdaptiveArrayPoolsCUDAExt/task_local_pool.jl @@ -18,8 +18,7 @@ a dictionary of pools (one per device) in task-local storage, ensuring that: - Switching devices (`CUDA.device!(n)`) gets the correct pool ## Implementation -Uses `Dict{Int, CuAdaptiveArrayPool}` in task-local storage, keyed by device ID. -Values are `CuAdaptiveArrayPool{S}` where S is determined by `RUNTIME_CHECK`. +Uses `Dict{Int, CuAdaptiveArrayPool{RUNTIME_CHECK}}` in task-local storage, keyed by device ID. """ @inline function AdaptiveArrayPools.get_task_local_cuda_pool() # 1. Get or create the pools dictionary @@ -43,7 +42,7 @@ Values are `CuAdaptiveArrayPool{S}` where S is determined by `RUNTIME_CHECK`. end """ - get_task_local_cuda_pools() -> Dict{Int, CuAdaptiveArrayPool} + get_task_local_cuda_pools() -> Dict{Int, CuAdaptiveArrayPool{RUNTIME_CHECK}} Returns the dictionary of all CUDA pools for the current task (one per device). Useful for diagnostics or bulk operations across all devices. diff --git a/ext/AdaptiveArrayPoolsMetalExt/task_local_pool.jl b/ext/AdaptiveArrayPoolsMetalExt/task_local_pool.jl index 92a1e90e..2e3e5796 100644 --- a/ext/AdaptiveArrayPoolsMetalExt/task_local_pool.jl +++ b/ext/AdaptiveArrayPoolsMetalExt/task_local_pool.jl @@ -18,8 +18,7 @@ a dictionary of pools (one per device) in task-local storage, ensuring that: - Switching devices gets the correct pool ## Implementation -Uses `Dict{UInt64, MetalAdaptiveArrayPool}` in task-local storage, keyed by device hash. -Values are `MetalAdaptiveArrayPool{R,S}` where R is determined by `RUNTIME_CHECK`. +Uses `Dict{UInt64, MetalAdaptiveArrayPool{RUNTIME_CHECK, METAL_STORAGE}}` in task-local storage, keyed by device hash. """ @inline function AdaptiveArrayPools.get_task_local_metal_pool() # 1. Get or create the pools dictionary @@ -44,7 +43,7 @@ Values are `MetalAdaptiveArrayPool{R,S}` where R is determined by `RUNTIME_CHECK end """ - get_task_local_metal_pools() -> Dict{UInt64, MetalAdaptiveArrayPool} + get_task_local_metal_pools() -> Dict{UInt64, MetalAdaptiveArrayPool{RUNTIME_CHECK, METAL_STORAGE}} Returns the dictionary of all Metal pools for the current task (one per device). Useful for diagnostics or bulk operations across all devices. diff --git a/src/task_local_pool.jl b/src/task_local_pool.jl index 2c659bb3..f1ec8f84 100644 --- a/src/task_local_pool.jl +++ b/src/task_local_pool.jl @@ -75,7 +75,7 @@ const MAYBE_POOLING_ENABLED = MAYBE_POOLING const _POOL_KEY = :ADAPTIVE_ARRAY_POOL """ - get_task_local_pool() -> AdaptiveArrayPool + get_task_local_pool() -> AdaptiveArrayPool{RUNTIME_CHECK} Retrieves (or creates) the `AdaptiveArrayPool` for the current Task.