Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 5 additions & 6 deletions ext/AdaptiveArrayPoolsCUDAExt/task_local_pool.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,13 @@ a dictionary of pools (one per device) in task-local storage, ensuring that:
- Switching devices (`CUDA.device!(n)`) gets the correct pool

## Implementation
Uses `Dict{Int, CuAdaptiveArrayPool}` in task-local storage, keyed by device ID.
Values are `CuAdaptiveArrayPool{S}` where S is determined by `RUNTIME_CHECK`.
Uses `Dict{Int, CuAdaptiveArrayPool{RUNTIME_CHECK}}` in task-local storage, keyed by device ID.
"""
@inline function AdaptiveArrayPools.get_task_local_cuda_pool()
# 1. Get or create the pools dictionary
pools = get(task_local_storage(), _CU_POOL_KEY, nothing)
if pools === nothing
pools = Dict{Int, CuAdaptiveArrayPool}()
pools = Dict{Int, CuAdaptiveArrayPool{RUNTIME_CHECK}}()
task_local_storage(_CU_POOL_KEY, pools)
end

Expand All @@ -39,19 +38,19 @@ Values are `CuAdaptiveArrayPool{S}` where S is determined by `RUNTIME_CHECK`.
pools[dev_id] = pool
end

return pool::CuAdaptiveArrayPool
return pool::CuAdaptiveArrayPool{RUNTIME_CHECK}
end

"""
get_task_local_cuda_pools() -> Dict{Int, CuAdaptiveArrayPool}
get_task_local_cuda_pools() -> Dict{Int, CuAdaptiveArrayPool{RUNTIME_CHECK}}

Returns the dictionary of all CUDA pools for the current task (one per device).
Useful for diagnostics or bulk operations across all devices.
"""
@inline function AdaptiveArrayPools.get_task_local_cuda_pools()
pools = get(task_local_storage(), _CU_POOL_KEY, nothing)
if pools === nothing
pools = Dict{Int, CuAdaptiveArrayPool}()
pools = Dict{Int, CuAdaptiveArrayPool{RUNTIME_CHECK}}()
task_local_storage(_CU_POOL_KEY, pools)
end
return pools
Expand Down
11 changes: 5 additions & 6 deletions ext/AdaptiveArrayPoolsMetalExt/task_local_pool.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,13 @@ a dictionary of pools (one per device) in task-local storage, ensuring that:
- Switching devices gets the correct pool

## Implementation
Uses `Dict{UInt64, MetalAdaptiveArrayPool}` in task-local storage, keyed by device hash.
Values are `MetalAdaptiveArrayPool{R,S}` where R is determined by `RUNTIME_CHECK`.
Uses `Dict{UInt64, MetalAdaptiveArrayPool{RUNTIME_CHECK, METAL_STORAGE}}` in task-local storage, keyed by device hash.
"""
@inline function AdaptiveArrayPools.get_task_local_metal_pool()
# 1. Get or create the pools dictionary
pools = get(task_local_storage(), _METAL_POOL_KEY, nothing)
if pools === nothing
pools = Dict{UInt64, MetalAdaptiveArrayPool}()
pools = Dict{UInt64, MetalAdaptiveArrayPool{RUNTIME_CHECK, METAL_STORAGE}}()
task_local_storage(_METAL_POOL_KEY, pools)
end

Expand All @@ -40,19 +39,19 @@ Values are `MetalAdaptiveArrayPool{R,S}` where R is determined by `RUNTIME_CHECK
pools[dev_key] = pool
end

return pool::MetalAdaptiveArrayPool
return pool::MetalAdaptiveArrayPool{RUNTIME_CHECK, METAL_STORAGE}
end

"""
get_task_local_metal_pools() -> Dict{UInt64, MetalAdaptiveArrayPool}
get_task_local_metal_pools() -> Dict{UInt64, MetalAdaptiveArrayPool{RUNTIME_CHECK, METAL_STORAGE}}

Returns the dictionary of all Metal pools for the current task (one per device).
Useful for diagnostics or bulk operations across all devices.
"""
@inline function AdaptiveArrayPools.get_task_local_metal_pools()
pools = get(task_local_storage(), _METAL_POOL_KEY, nothing)
if pools === nothing
pools = Dict{UInt64, MetalAdaptiveArrayPool}()
pools = Dict{UInt64, MetalAdaptiveArrayPool{RUNTIME_CHECK, METAL_STORAGE}}()
task_local_storage(_METAL_POOL_KEY, pools)
end
return pools
Expand Down
4 changes: 2 additions & 2 deletions src/task_local_pool.jl
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ const MAYBE_POOLING_ENABLED = MAYBE_POOLING
const _POOL_KEY = :ADAPTIVE_ARRAY_POOL

"""
get_task_local_pool() -> AdaptiveArrayPool
get_task_local_pool() -> AdaptiveArrayPool{RUNTIME_CHECK}

Retrieves (or creates) the `AdaptiveArrayPool` for the current Task.

Expand All @@ -98,7 +98,7 @@ type-asserts directly to `AdaptiveArrayPool{RUNTIME_CHECK}`.
task_local_storage(_POOL_KEY, pool)
end

return pool::AdaptiveArrayPool
return pool::AdaptiveArrayPool{RUNTIME_CHECK}
end

# ==============================================================================
Expand Down
19 changes: 17 additions & 2 deletions test/cuda/test_extension.jl
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,21 @@ end
end

@testset "Task-Local Pool" begin
@testset "get_task_local_cuda_pool type stability" begin
# Fast path: pool already exists in task-local storage
pool = @inferred get_task_local_cuda_pool()
@test pool isa CuAdaptiveArrayPool{RUNTIME_CHECK}

# Slow path: fresh task creates a new pool
result = fetch(
Threads.@spawn begin
p = @inferred get_task_local_cuda_pool()
p isa CuAdaptiveArrayPool{RUNTIME_CHECK}
end
)
@test result == true
end

@testset "get_task_local_cuda_pool" begin
pool1 = get_task_local_cuda_pool()
@test pool1 isa CuAdaptiveArrayPool
Expand All @@ -103,7 +118,7 @@ end

@testset "get_task_local_cuda_pools" begin
pools_dict = get_task_local_cuda_pools()
@test pools_dict isa Dict{Int, CuAdaptiveArrayPool}
@test pools_dict isa Dict{Int, CuAdaptiveArrayPool{RUNTIME_CHECK}}
pool = get_task_local_cuda_pool()
@test haskey(pools_dict, pool.device_id)
end
Expand All @@ -114,7 +129,7 @@ end
Threads.@spawn begin
# Call get_task_local_cuda_pools() FIRST (before get_task_local_cuda_pool)
pools = get_task_local_cuda_pools()
@test pools isa Dict{Int, CuAdaptiveArrayPool}
@test pools isa Dict{Int, CuAdaptiveArrayPool{RUNTIME_CHECK}}
@test isempty(pools) # No pools created yet
true
end
Expand Down
1 change: 1 addition & 0 deletions test/metal/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ else
const MetalTypedPool = ext.MetalTypedPool
const MetalAdaptiveArrayPool = ext.MetalAdaptiveArrayPool
const METAL_FIXED_SLOT_FIELDS = ext.METAL_FIXED_SLOT_FIELDS
const METAL_STORAGE = ext.METAL_STORAGE

# Include all Metal test files
include("test_extension.jl")
Expand Down
19 changes: 17 additions & 2 deletions test/metal/test_task_local_pool.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,21 @@

@testset "Metal Task-Local Pool" begin

@testset "get_task_local_metal_pool type stability" begin
# Fast path: pool already exists
pool = @inferred get_task_local_metal_pool()
@test pool isa MetalAdaptiveArrayPool{RUNTIME_CHECK, METAL_STORAGE}

# Slow path: fresh task
result = fetch(
Threads.@spawn begin
p = @inferred get_task_local_metal_pool()
p isa MetalAdaptiveArrayPool{RUNTIME_CHECK, METAL_STORAGE}
end
)
@test result == true
end

@testset "get_task_local_metal_pool" begin
pool1 = get_task_local_metal_pool()
@test pool1 isa MetalAdaptiveArrayPool
Expand All @@ -13,7 +28,7 @@

@testset "get_task_local_metal_pools" begin
pools_dict = get_task_local_metal_pools()
@test pools_dict isa Dict{UInt64, MetalAdaptiveArrayPool}
@test pools_dict isa Dict{UInt64, MetalAdaptiveArrayPool{RUNTIME_CHECK, METAL_STORAGE}}
pool = get_task_local_metal_pool()
dev_key = objectid(Metal.device())
@test haskey(pools_dict, dev_key)
Expand All @@ -23,7 +38,7 @@
result = fetch(
Threads.@spawn begin
pools = get_task_local_metal_pools()
@test pools isa Dict{UInt64, MetalAdaptiveArrayPool}
@test pools isa Dict{UInt64, MetalAdaptiveArrayPool{RUNTIME_CHECK, METAL_STORAGE}}
@test isempty(pools)
true
end
Expand Down
15 changes: 15 additions & 0 deletions test/test_task_local_pool.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,21 @@

_test_accumulator = Ref(0.0)

@testset "get_task_local_pool type stability" begin
# Fast path: pool already exists in task-local storage
pool = @inferred get_task_local_pool()
@test pool isa AdaptiveArrayPool{RUNTIME_CHECK}

# Slow path: fresh task creates a new pool
result = fetch(
Threads.@spawn begin
p = @inferred get_task_local_pool()
p isa AdaptiveArrayPool{RUNTIME_CHECK}
end
)
@test result == true
end

@testset "@with_pool" begin
# Define a function that takes pool as argument
function global_test(n, pool)
Expand Down
Loading