diff --git a/ext/AdaptiveArrayPoolsCUDAExt/task_local_pool.jl b/ext/AdaptiveArrayPoolsCUDAExt/task_local_pool.jl index 7b11c23..2306652 100644 --- a/ext/AdaptiveArrayPoolsCUDAExt/task_local_pool.jl +++ b/ext/AdaptiveArrayPoolsCUDAExt/task_local_pool.jl @@ -18,14 +18,13 @@ a dictionary of pools (one per device) in task-local storage, ensuring that: - Switching devices (`CUDA.device!(n)`) gets the correct pool ## Implementation -Uses `Dict{Int, CuAdaptiveArrayPool}` in task-local storage, keyed by device ID. -Values are `CuAdaptiveArrayPool{S}` where S is determined by `RUNTIME_CHECK`. +Uses `Dict{Int, CuAdaptiveArrayPool{RUNTIME_CHECK}}` in task-local storage, keyed by device ID. """ @inline function AdaptiveArrayPools.get_task_local_cuda_pool() # 1. Get or create the pools dictionary pools = get(task_local_storage(), _CU_POOL_KEY, nothing) if pools === nothing - pools = Dict{Int, CuAdaptiveArrayPool}() + pools = Dict{Int, CuAdaptiveArrayPool{RUNTIME_CHECK}}() task_local_storage(_CU_POOL_KEY, pools) end @@ -39,11 +38,11 @@ Values are `CuAdaptiveArrayPool{S}` where S is determined by `RUNTIME_CHECK`. pools[dev_id] = pool end - return pool::CuAdaptiveArrayPool + return pool::CuAdaptiveArrayPool{RUNTIME_CHECK} end """ - get_task_local_cuda_pools() -> Dict{Int, CuAdaptiveArrayPool} + get_task_local_cuda_pools() -> Dict{Int, CuAdaptiveArrayPool{RUNTIME_CHECK}} Returns the dictionary of all CUDA pools for the current task (one per device). Useful for diagnostics or bulk operations across all devices. @@ -51,7 +50,7 @@ Useful for diagnostics or bulk operations across all devices. @inline function AdaptiveArrayPools.get_task_local_cuda_pools() pools = get(task_local_storage(), _CU_POOL_KEY, nothing) if pools === nothing - pools = Dict{Int, CuAdaptiveArrayPool}() + pools = Dict{Int, CuAdaptiveArrayPool{RUNTIME_CHECK}}() task_local_storage(_CU_POOL_KEY, pools) end return pools diff --git a/ext/AdaptiveArrayPoolsMetalExt/task_local_pool.jl b/ext/AdaptiveArrayPoolsMetalExt/task_local_pool.jl index c60c943..2e3e579 100644 --- a/ext/AdaptiveArrayPoolsMetalExt/task_local_pool.jl +++ b/ext/AdaptiveArrayPoolsMetalExt/task_local_pool.jl @@ -18,14 +18,13 @@ a dictionary of pools (one per device) in task-local storage, ensuring that: - Switching devices gets the correct pool ## Implementation -Uses `Dict{UInt64, MetalAdaptiveArrayPool}` in task-local storage, keyed by device hash. -Values are `MetalAdaptiveArrayPool{R,S}` where R is determined by `RUNTIME_CHECK`. +Uses `Dict{UInt64, MetalAdaptiveArrayPool{RUNTIME_CHECK, METAL_STORAGE}}` in task-local storage, keyed by device hash. """ @inline function AdaptiveArrayPools.get_task_local_metal_pool() # 1. Get or create the pools dictionary pools = get(task_local_storage(), _METAL_POOL_KEY, nothing) if pools === nothing - pools = Dict{UInt64, MetalAdaptiveArrayPool}() + pools = Dict{UInt64, MetalAdaptiveArrayPool{RUNTIME_CHECK, METAL_STORAGE}}() task_local_storage(_METAL_POOL_KEY, pools) end @@ -40,11 +39,11 @@ Values are `MetalAdaptiveArrayPool{R,S}` where R is determined by `RUNTIME_CHECK pools[dev_key] = pool end - return pool::MetalAdaptiveArrayPool + return pool::MetalAdaptiveArrayPool{RUNTIME_CHECK, METAL_STORAGE} end """ - get_task_local_metal_pools() -> Dict{UInt64, MetalAdaptiveArrayPool} + get_task_local_metal_pools() -> Dict{UInt64, MetalAdaptiveArrayPool{RUNTIME_CHECK, METAL_STORAGE}} Returns the dictionary of all Metal pools for the current task (one per device). Useful for diagnostics or bulk operations across all devices. @@ -52,7 +51,7 @@ Useful for diagnostics or bulk operations across all devices. @inline function AdaptiveArrayPools.get_task_local_metal_pools() pools = get(task_local_storage(), _METAL_POOL_KEY, nothing) if pools === nothing - pools = Dict{UInt64, MetalAdaptiveArrayPool}() + pools = Dict{UInt64, MetalAdaptiveArrayPool{RUNTIME_CHECK, METAL_STORAGE}}() task_local_storage(_METAL_POOL_KEY, pools) end return pools diff --git a/src/task_local_pool.jl b/src/task_local_pool.jl index 3f4f14c..f1ec8f8 100644 --- a/src/task_local_pool.jl +++ b/src/task_local_pool.jl @@ -75,7 +75,7 @@ const MAYBE_POOLING_ENABLED = MAYBE_POOLING const _POOL_KEY = :ADAPTIVE_ARRAY_POOL """ - get_task_local_pool() -> AdaptiveArrayPool + get_task_local_pool() -> AdaptiveArrayPool{RUNTIME_CHECK} Retrieves (or creates) the `AdaptiveArrayPool` for the current Task. @@ -98,7 +98,7 @@ type-asserts directly to `AdaptiveArrayPool{RUNTIME_CHECK}`. task_local_storage(_POOL_KEY, pool) end - return pool::AdaptiveArrayPool + return pool::AdaptiveArrayPool{RUNTIME_CHECK} end # ============================================================================== diff --git a/test/cuda/test_extension.jl b/test/cuda/test_extension.jl index 0638956..a8cfd2b 100644 --- a/test/cuda/test_extension.jl +++ b/test/cuda/test_extension.jl @@ -92,6 +92,21 @@ end end @testset "Task-Local Pool" begin + @testset "get_task_local_cuda_pool type stability" begin + # Fast path: pool already exists in task-local storage + pool = @inferred get_task_local_cuda_pool() + @test pool isa CuAdaptiveArrayPool{RUNTIME_CHECK} + + # Slow path: fresh task creates a new pool + result = fetch( + Threads.@spawn begin + p = @inferred get_task_local_cuda_pool() + p isa CuAdaptiveArrayPool{RUNTIME_CHECK} + end + ) + @test result == true + end + @testset "get_task_local_cuda_pool" begin pool1 = get_task_local_cuda_pool() @test pool1 isa CuAdaptiveArrayPool @@ -103,7 +118,7 @@ end @testset "get_task_local_cuda_pools" begin pools_dict = get_task_local_cuda_pools() - @test pools_dict isa Dict{Int, CuAdaptiveArrayPool} + @test pools_dict isa Dict{Int, CuAdaptiveArrayPool{RUNTIME_CHECK}} pool = get_task_local_cuda_pool() @test haskey(pools_dict, pool.device_id) end @@ -114,7 +129,7 @@ end Threads.@spawn begin # Call get_task_local_cuda_pools() FIRST (before get_task_local_cuda_pool) pools = get_task_local_cuda_pools() - @test pools isa Dict{Int, CuAdaptiveArrayPool} + @test pools isa Dict{Int, CuAdaptiveArrayPool{RUNTIME_CHECK}} @test isempty(pools) # No pools created yet true end diff --git a/test/metal/runtests.jl b/test/metal/runtests.jl index 8984dfa..a327908 100644 --- a/test/metal/runtests.jl +++ b/test/metal/runtests.jl @@ -37,6 +37,7 @@ else const MetalTypedPool = ext.MetalTypedPool const MetalAdaptiveArrayPool = ext.MetalAdaptiveArrayPool const METAL_FIXED_SLOT_FIELDS = ext.METAL_FIXED_SLOT_FIELDS + const METAL_STORAGE = ext.METAL_STORAGE # Include all Metal test files include("test_extension.jl") diff --git a/test/metal/test_task_local_pool.jl b/test/metal/test_task_local_pool.jl index 9a8c104..c0c0562 100644 --- a/test/metal/test_task_local_pool.jl +++ b/test/metal/test_task_local_pool.jl @@ -2,6 +2,21 @@ @testset "Metal Task-Local Pool" begin + @testset "get_task_local_metal_pool type stability" begin + # Fast path: pool already exists + pool = @inferred get_task_local_metal_pool() + @test pool isa MetalAdaptiveArrayPool{RUNTIME_CHECK, METAL_STORAGE} + + # Slow path: fresh task + result = fetch( + Threads.@spawn begin + p = @inferred get_task_local_metal_pool() + p isa MetalAdaptiveArrayPool{RUNTIME_CHECK, METAL_STORAGE} + end + ) + @test result == true + end + @testset "get_task_local_metal_pool" begin pool1 = get_task_local_metal_pool() @test pool1 isa MetalAdaptiveArrayPool @@ -13,7 +28,7 @@ @testset "get_task_local_metal_pools" begin pools_dict = get_task_local_metal_pools() - @test pools_dict isa Dict{UInt64, MetalAdaptiveArrayPool} + @test pools_dict isa Dict{UInt64, MetalAdaptiveArrayPool{RUNTIME_CHECK, METAL_STORAGE}} pool = get_task_local_metal_pool() dev_key = objectid(Metal.device()) @test haskey(pools_dict, dev_key) @@ -23,7 +38,7 @@ result = fetch( Threads.@spawn begin pools = get_task_local_metal_pools() - @test pools isa Dict{UInt64, MetalAdaptiveArrayPool} + @test pools isa Dict{UInt64, MetalAdaptiveArrayPool{RUNTIME_CHECK, METAL_STORAGE}} @test isempty(pools) true end diff --git a/test/test_task_local_pool.jl b/test/test_task_local_pool.jl index d4ef765..31d5b19 100644 --- a/test/test_task_local_pool.jl +++ b/test/test_task_local_pool.jl @@ -2,6 +2,21 @@ _test_accumulator = Ref(0.0) + @testset "get_task_local_pool type stability" begin + # Fast path: pool already exists in task-local storage + pool = @inferred get_task_local_pool() + @test pool isa AdaptiveArrayPool{RUNTIME_CHECK} + + # Slow path: fresh task creates a new pool + result = fetch( + Threads.@spawn begin + p = @inferred get_task_local_pool() + p isa AdaptiveArrayPool{RUNTIME_CHECK} + end + ) + @test result == true + end + @testset "@with_pool" begin # Define a function that takes pool as argument function global_test(n, pool)