Skip to content
Merged
6 changes: 5 additions & 1 deletion ext/AdaptiveArrayPoolsCUDAExt/convenience.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ using AdaptiveArrayPools: DisabledPool
DISABLED_CUDA

Singleton instance for disabled CUDA pooling.
Used by macros when `USE_POOLING=false` with `:cuda` backend.
Used by macros when `STATIC_POOLING=false` with `:cuda` backend.
"""
const DISABLED_CUDA = DisabledPool{:cuda}()

Expand Down Expand Up @@ -82,6 +82,10 @@ AdaptiveArrayPools.default_eltype(::DisabledPool{:cuda}) = Float32
@inline AdaptiveArrayPools.unsafe_similar!(::DisabledPool{:cuda}, x::AbstractArray, dims::Vararg{Int, N}) where {N} = CuArray{eltype(x)}(undef, dims)
@inline AdaptiveArrayPools.unsafe_similar!(::DisabledPool{:cuda}, x::AbstractArray, ::Type{T}, dims::Vararg{Int, N}) where {T, N} = CuArray{T}(undef, dims)

# --- reshape! for DisabledPool{:cuda} ---
@inline AdaptiveArrayPools.reshape!(::DisabledPool{:cuda}, A::AbstractArray, dims::Vararg{Int, N}) where {N} = reshape(A, dims...)
@inline AdaptiveArrayPools.reshape!(::DisabledPool{:cuda}, A::AbstractArray, dims::NTuple{N, Int}) where {N} = reshape(A, dims)

# --- acquire! for DisabledPool{:cuda} ---
@inline AdaptiveArrayPools.acquire!(::DisabledPool{:cuda}, ::Type{T}, n::Int) where {T} = CuVector{T}(undef, n)
@inline AdaptiveArrayPools.acquire!(::DisabledPool{:cuda}, ::Type{T}, dims::Vararg{Int, N}) where {T, N} = CuArray{T, N}(undef, dims)
Expand Down
3 changes: 2 additions & 1 deletion src/AdaptiveArrayPools.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@ export zeros!, ones!, trues!, falses!, similar!, reshape!, default_eltype # Con
export unsafe_zeros!, unsafe_ones!, unsafe_similar! # Unsafe convenience functions
export Bit # Sentinel type for BitArray (use with acquire!, trues!, falses!)
export @with_pool, @maybe_with_pool
export USE_POOLING, MAYBE_POOLING_ENABLED, POOL_DEBUG
export STATIC_POOLING, MAYBE_POOLING, POOL_DEBUG
export USE_POOLING, MAYBE_POOLING_ENABLED # Deprecated aliases (backward compat)
export checkpoint!, rewind!, reset!
export get_task_local_cuda_pool, get_task_local_cuda_pools # CUDA (stubs, overridden by extension)

Expand Down
22 changes: 4 additions & 18 deletions src/acquire.jl
Original file line number Diff line number Diff line change
Expand Up @@ -553,9 +553,12 @@ const _acquire_view_impl! = _acquire_impl!
const _acquire_array_impl! = _unsafe_acquire_impl!

# ==============================================================================
# DisabledPool Acquire Fallbacks (pooling disabled with backend context)
# DisabledPool Fallbacks (pooling disabled with backend context)
# ==============================================================================

# DisabledPool has no internal state to track, so type touch is a no-op.
@inline _record_type_touch!(::DisabledPool, ::Type) = nothing

# --- acquire! for DisabledPool{:cpu} ---
@inline acquire!(::DisabledPool{:cpu}, ::Type{T}, n::Int) where {T} = Vector{T}(undef, n)
@inline acquire!(::DisabledPool{:cpu}, ::Type{T}, dims::Vararg{Int, N}) where {T, N} = Array{T, N}(undef, dims)
Expand All @@ -567,20 +570,3 @@ const _acquire_array_impl! = _unsafe_acquire_impl!
@inline unsafe_acquire!(::DisabledPool{:cpu}, ::Type{T}, dims::Vararg{Int, N}) where {T, N} = Array{T, N}(undef, dims)
@inline unsafe_acquire!(::DisabledPool{:cpu}, ::Type{T}, dims::NTuple{N, Int}) where {T, N} = Array{T, N}(undef, dims)
@inline unsafe_acquire!(::DisabledPool{:cpu}, x::AbstractArray) = similar(x)

# --- Generic DisabledPool fallbacks (unknown backend → error) ---
@inline acquire!(::DisabledPool{B}, _args...) where {B} = _throw_backend_not_loaded(B)
@inline unsafe_acquire!(::DisabledPool{B}, _args...) where {B} = _throw_backend_not_loaded(B)

# --- _impl! delegators for DisabledPool (macro transformation support) ---
# Called when: USE_POOLING=true + @maybe_with_pool + MAYBE_POOLING_ENABLED[]=false
# Explicit overloads for proper inlining (especially important for CUDA backend).
@inline _acquire_impl!(p::DisabledPool, ::Type{T}, n::Int) where {T} = acquire!(p, T, n)
@inline _acquire_impl!(p::DisabledPool, ::Type{T}, dims::Vararg{Int, N}) where {T, N} = acquire!(p, T, dims...)
@inline _acquire_impl!(p::DisabledPool, ::Type{T}, dims::NTuple{N, Int}) where {T, N} = acquire!(p, T, dims)
@inline _acquire_impl!(p::DisabledPool, x::AbstractArray) = acquire!(p, x)

@inline _unsafe_acquire_impl!(p::DisabledPool, ::Type{T}, n::Int) where {T} = unsafe_acquire!(p, T, n)
@inline _unsafe_acquire_impl!(p::DisabledPool, ::Type{T}, dims::Vararg{Int, N}) where {T, N} = unsafe_acquire!(p, T, dims...)
@inline _unsafe_acquire_impl!(p::DisabledPool, ::Type{T}, dims::NTuple{N, Int}) where {T, N} = unsafe_acquire!(p, T, dims)
@inline _unsafe_acquire_impl!(p::DisabledPool, x::AbstractArray) = unsafe_acquire!(p, x)
67 changes: 0 additions & 67 deletions src/convenience.jl
Original file line number Diff line number Diff line change
Expand Up @@ -669,70 +669,3 @@ end
@inline unsafe_similar!(::DisabledPool{:cpu}, x::AbstractArray, ::Type{T}) where {T} = similar(x, T)
@inline unsafe_similar!(::DisabledPool{:cpu}, x::AbstractArray, dims::Vararg{Int, N}) where {N} = similar(x, dims...)
@inline unsafe_similar!(::DisabledPool{:cpu}, x::AbstractArray, ::Type{T}, dims::Vararg{Int, N}) where {T, N} = similar(x, T, dims...)

# --- Generic DisabledPool fallbacks (unknown backend → error) ---
@inline zeros!(p::DisabledPool{B}, args...) where {B} = _throw_backend_not_loaded(B)
@inline ones!(p::DisabledPool{B}, args...) where {B} = _throw_backend_not_loaded(B)
@inline trues!(p::DisabledPool{B}, args...) where {B} = _throw_backend_not_loaded(B)
@inline falses!(p::DisabledPool{B}, args...) where {B} = _throw_backend_not_loaded(B)
@inline similar!(p::DisabledPool{B}, args...) where {B} = _throw_backend_not_loaded(B)
@inline unsafe_zeros!(p::DisabledPool{B}, args...) where {B} = _throw_backend_not_loaded(B)
@inline unsafe_ones!(p::DisabledPool{B}, args...) where {B} = _throw_backend_not_loaded(B)
@inline unsafe_similar!(p::DisabledPool{B}, args...) where {B} = _throw_backend_not_loaded(B)
@inline reshape!(p::DisabledPool{B}, args...) where {B} = _throw_backend_not_loaded(B)

# ==============================================================================
# _impl! Delegators for DisabledPool
# ==============================================================================
# When macros transform zeros!(pool, ...) → _zeros_impl!(pool, ...),
# DisabledPool needs to delegate back to the public API.
#
# Called when: USE_POOLING=true + @maybe_with_pool + MAYBE_POOLING_ENABLED[]=false
# Explicit overloads for proper inlining (especially important for CUDA backend).

# --- _zeros_impl! ---
@inline _zeros_impl!(p::DisabledPool, ::Type{T}, dims::Vararg{Int, N}) where {T, N} = zeros!(p, T, dims...)
@inline _zeros_impl!(p::DisabledPool, dims::Vararg{Int, N}) where {N} = zeros!(p, dims...)
@inline _zeros_impl!(p::DisabledPool, ::Type{T}, dims::NTuple{N, Int}) where {T, N} = zeros!(p, T, dims)
@inline _zeros_impl!(p::DisabledPool, dims::NTuple{N, Int}) where {N} = zeros!(p, dims)

# --- _ones_impl! ---
@inline _ones_impl!(p::DisabledPool, ::Type{T}, dims::Vararg{Int, N}) where {T, N} = ones!(p, T, dims...)
@inline _ones_impl!(p::DisabledPool, dims::Vararg{Int, N}) where {N} = ones!(p, dims...)
@inline _ones_impl!(p::DisabledPool, ::Type{T}, dims::NTuple{N, Int}) where {T, N} = ones!(p, T, dims)
@inline _ones_impl!(p::DisabledPool, dims::NTuple{N, Int}) where {N} = ones!(p, dims)

# --- _trues_impl! ---
@inline _trues_impl!(p::DisabledPool, dims::Vararg{Int, N}) where {N} = trues!(p, dims...)
@inline _trues_impl!(p::DisabledPool, dims::NTuple{N, Int}) where {N} = trues!(p, dims)

# --- _falses_impl! ---
@inline _falses_impl!(p::DisabledPool, dims::Vararg{Int, N}) where {N} = falses!(p, dims...)
@inline _falses_impl!(p::DisabledPool, dims::NTuple{N, Int}) where {N} = falses!(p, dims)

# --- _similar_impl! ---
@inline _similar_impl!(p::DisabledPool, x::AbstractArray) = similar!(p, x)
@inline _similar_impl!(p::DisabledPool, x::AbstractArray, ::Type{T}) where {T} = similar!(p, x, T)
@inline _similar_impl!(p::DisabledPool, x::AbstractArray, dims::Vararg{Int, N}) where {N} = similar!(p, x, dims...)
@inline _similar_impl!(p::DisabledPool, x::AbstractArray, ::Type{T}, dims::Vararg{Int, N}) where {T, N} = similar!(p, x, T, dims...)

# --- _reshape_impl! ---
@inline _reshape_impl!(p::DisabledPool, A::AbstractArray, dims::NTuple{N, Int}) where {N} = reshape!(p, A, dims)

# --- _unsafe_zeros_impl! ---
@inline _unsafe_zeros_impl!(p::DisabledPool, ::Type{T}, dims::Vararg{Int, N}) where {T, N} = unsafe_zeros!(p, T, dims...)
@inline _unsafe_zeros_impl!(p::DisabledPool, dims::Vararg{Int, N}) where {N} = unsafe_zeros!(p, dims...)
@inline _unsafe_zeros_impl!(p::DisabledPool, ::Type{T}, dims::NTuple{N, Int}) where {T, N} = unsafe_zeros!(p, T, dims)
@inline _unsafe_zeros_impl!(p::DisabledPool, dims::NTuple{N, Int}) where {N} = unsafe_zeros!(p, dims)

# --- _unsafe_ones_impl! ---
@inline _unsafe_ones_impl!(p::DisabledPool, ::Type{T}, dims::Vararg{Int, N}) where {T, N} = unsafe_ones!(p, T, dims...)
@inline _unsafe_ones_impl!(p::DisabledPool, dims::Vararg{Int, N}) where {N} = unsafe_ones!(p, dims...)
@inline _unsafe_ones_impl!(p::DisabledPool, ::Type{T}, dims::NTuple{N, Int}) where {T, N} = unsafe_ones!(p, T, dims)
@inline _unsafe_ones_impl!(p::DisabledPool, dims::NTuple{N, Int}) where {N} = unsafe_ones!(p, dims)

# --- _unsafe_similar_impl! ---
@inline _unsafe_similar_impl!(p::DisabledPool, x::AbstractArray) = unsafe_similar!(p, x)
@inline _unsafe_similar_impl!(p::DisabledPool, x::AbstractArray, ::Type{T}) where {T} = unsafe_similar!(p, x, T)
@inline _unsafe_similar_impl!(p::DisabledPool, x::AbstractArray, dims::Vararg{Int, N}) where {N} = unsafe_similar!(p, x, dims...)
@inline _unsafe_similar_impl!(p::DisabledPool, x::AbstractArray, ::Type{T}, dims::Vararg{Int, N}) where {T, N} = unsafe_similar!(p, x, T, dims...)
22 changes: 4 additions & 18 deletions src/legacy/acquire.jl
Original file line number Diff line number Diff line change
Expand Up @@ -448,9 +448,12 @@ const _acquire_view_impl! = _acquire_impl!
const _acquire_array_impl! = _unsafe_acquire_impl!

# ==============================================================================
# DisabledPool Acquire Fallbacks (pooling disabled with backend context)
# DisabledPool Fallbacks (pooling disabled with backend context)
# ==============================================================================

# DisabledPool has no internal state to track, so type touch is a no-op.
@inline _record_type_touch!(::DisabledPool, ::Type) = nothing

# --- acquire! for DisabledPool{:cpu} ---
@inline acquire!(::DisabledPool{:cpu}, ::Type{T}, n::Int) where {T} = Vector{T}(undef, n)
@inline acquire!(::DisabledPool{:cpu}, ::Type{T}, dims::Vararg{Int, N}) where {T, N} = Array{T, N}(undef, dims)
Expand All @@ -462,20 +465,3 @@ const _acquire_array_impl! = _unsafe_acquire_impl!
@inline unsafe_acquire!(::DisabledPool{:cpu}, ::Type{T}, dims::Vararg{Int, N}) where {T, N} = Array{T, N}(undef, dims)
@inline unsafe_acquire!(::DisabledPool{:cpu}, ::Type{T}, dims::NTuple{N, Int}) where {T, N} = Array{T, N}(undef, dims)
@inline unsafe_acquire!(::DisabledPool{:cpu}, x::AbstractArray) = similar(x)

# --- Generic DisabledPool fallbacks (unknown backend → error) ---
@inline acquire!(::DisabledPool{B}, _args...) where {B} = _throw_backend_not_loaded(B)
@inline unsafe_acquire!(::DisabledPool{B}, _args...) where {B} = _throw_backend_not_loaded(B)

# --- _impl! delegators for DisabledPool (macro transformation support) ---
# Called when: USE_POOLING=true + @maybe_with_pool + MAYBE_POOLING_ENABLED[]=false
# Explicit overloads for proper inlining (especially important for CUDA backend).
@inline _acquire_impl!(p::DisabledPool, ::Type{T}, n::Int) where {T} = acquire!(p, T, n)
@inline _acquire_impl!(p::DisabledPool, ::Type{T}, dims::Vararg{Int, N}) where {T, N} = acquire!(p, T, dims...)
@inline _acquire_impl!(p::DisabledPool, ::Type{T}, dims::NTuple{N, Int}) where {T, N} = acquire!(p, T, dims)
@inline _acquire_impl!(p::DisabledPool, x::AbstractArray) = acquire!(p, x)

@inline _unsafe_acquire_impl!(p::DisabledPool, ::Type{T}, n::Int) where {T} = unsafe_acquire!(p, T, n)
@inline _unsafe_acquire_impl!(p::DisabledPool, ::Type{T}, dims::Vararg{Int, N}) where {T, N} = unsafe_acquire!(p, T, dims...)
@inline _unsafe_acquire_impl!(p::DisabledPool, ::Type{T}, dims::NTuple{N, Int}) where {T, N} = unsafe_acquire!(p, T, dims)
@inline _unsafe_acquire_impl!(p::DisabledPool, x::AbstractArray) = unsafe_acquire!(p, x)
2 changes: 1 addition & 1 deletion src/legacy/types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ end

See also: [`pooling_enabled`](@ref), [`DISABLED_CPU`](@ref)
"""
struct DisabledPool{Backend} end
struct DisabledPool{Backend} <: AbstractArrayPool end

"""
DISABLED_CPU
Expand Down
69 changes: 44 additions & 25 deletions src/macros.jl
Original file line number Diff line number Diff line change
Expand Up @@ -128,8 +128,9 @@ end
@maybe_with_pool pool_name expr
@maybe_with_pool expr

Conditionally enables pooling based on `MAYBE_POOLING_ENABLED[]`.
If disabled, `pool_name` becomes `nothing`, and `acquire!` falls back to standard allocation.
Conditionally enables pooling based on `MAYBE_POOLING[]`.
If disabled, `pool_name` is bound to a `DisabledPool` sentinel (e.g. `DISABLED_CPU` on CPU),
and `acquire!` falls back to standard allocation.

Useful for libraries that want to let users control pooling behavior at runtime.

Expand All @@ -146,7 +147,7 @@ end

## Block Usage
```julia
MAYBE_POOLING_ENABLED[] = false
MAYBE_POOLING[] = false
@maybe_with_pool pool begin
v = acquire!(pool, Float64, 100) # Falls back to Vector{Float64}(undef, 100)
end
Expand Down Expand Up @@ -309,7 +310,7 @@ end

function _generate_pool_code(pool_name, expr, force_enable; source::Union{LineNumberNode, Nothing} = nothing)
# Compile-time check: if pooling disabled, use DisabledPool to preserve backend context
if !USE_POOLING
if !STATIC_POOLING
disabled_pool = _disabled_pool_expr(:cpu)
if Meta.isexpr(expr, [:function, :(=)]) && _is_function_def(expr)
# Function definition: inject local pool = DisabledPool at start of body
Expand Down Expand Up @@ -370,7 +371,7 @@ function _generate_pool_code(pool_name, expr, force_enable; source::Union{LineNu
else
# Split branches completely to avoid Union boxing
return quote
if $MAYBE_POOLING_ENABLED[]
if $MAYBE_POOLING[]
local $(esc(pool_name)) = get_task_local_pool()
$checkpoint_call
try
Expand Down Expand Up @@ -404,10 +405,10 @@ Includes type-specific checkpoint/rewind optimization (same as regular @with_poo
"""
function _generate_pool_code_with_backend(backend::Symbol, pool_name, expr, force_enable::Bool; source::Union{LineNumberNode, Nothing} = nothing)
# Compile-time check: if pooling disabled, use DisabledPool to preserve backend context
if !USE_POOLING
if !STATIC_POOLING
disabled_pool = _disabled_pool_expr(backend)
if Meta.isexpr(expr, [:function, :(=)]) && _is_function_def(expr)
return _generate_function_pool_code_with_backend(backend, pool_name, expr, true; source)
return _generate_function_pool_code_with_backend(backend, pool_name, expr, force_enable, true; source)
else
return quote
local $(esc(pool_name)) = $disabled_pool
Expand All @@ -421,7 +422,7 @@ function _generate_pool_code_with_backend(backend::Symbol, pool_name, expr, forc
disabled_pool = _disabled_pool_expr(backend)
# Check if function definition
if Meta.isexpr(expr, [:function, :(=)]) && _is_function_def(expr)
return _generate_function_pool_code_with_backend(backend, pool_name, expr, false; source)
return _generate_function_pool_code_with_backend(backend, pool_name, expr, false, false; source)
end

# Block logic with runtime check
Expand All @@ -443,7 +444,7 @@ function _generate_pool_code_with_backend(backend::Symbol, pool_name, expr, forc
end

return quote
if $MAYBE_POOLING_ENABLED[]
if $MAYBE_POOLING[]
local $(esc(pool_name)) = $pool_getter
$checkpoint_call
try
Expand All @@ -464,7 +465,7 @@ function _generate_pool_code_with_backend(backend::Symbol, pool_name, expr, forc

# Check if function definition
if Meta.isexpr(expr, [:function, :(=)]) && _is_function_def(expr)
return _generate_function_pool_code_with_backend(backend, pool_name, expr, false; source)
return _generate_function_pool_code_with_backend(backend, pool_name, expr, true, false; source)
end

# Block logic: Extract types from acquire! calls for optimized checkpoint/rewind
Expand Down Expand Up @@ -510,12 +511,16 @@ function _generate_pool_code_with_backend(backend::Symbol, pool_name, expr, forc
end

"""
_generate_function_pool_code_with_backend(backend, pool_name, func_def, disable_pooling)
_generate_function_pool_code_with_backend(backend, pool_name, func_def, force_enable, disable_pooling)

Generate function code for a specific backend (e.g., :cuda).
Wraps the function body with pool getter, checkpoint, try-finally, rewind.

When `disable_pooling=true` (STATIC_POOLING=false), generates DisabledPool binding.
When `force_enable=true` (@with_pool), always uses the real pool.
When `force_enable=false` (@maybe_with_pool), generates MAYBE_POOLING[] runtime check.
"""
function _generate_function_pool_code_with_backend(backend::Symbol, pool_name, func_def, disable_pooling::Bool; source::Union{LineNumberNode, Nothing} = nothing)
function _generate_function_pool_code_with_backend(backend::Symbol, pool_name, func_def, force_enable::Bool, disable_pooling::Bool; source::Union{LineNumberNode, Nothing} = nothing)
def_head = func_def.head
call_expr = func_def.args[1]
body = func_def.args[2]
Expand Down Expand Up @@ -546,23 +551,37 @@ function _generate_function_pool_code_with_backend(backend::Symbol, pool_name, f

if use_typed
checkpoint_call = _generate_typed_checkpoint_call(esc(pool_name), static_types)
else
checkpoint_call = _generate_lazy_checkpoint_call(esc(pool_name))
end

if use_typed
rewind_call = _generate_typed_rewind_call(esc(pool_name), static_types)
else
checkpoint_call = _generate_lazy_checkpoint_call(esc(pool_name))
rewind_call = _generate_lazy_rewind_call(esc(pool_name))
end

new_body = quote
local $(esc(pool_name)) = $pool_getter
$checkpoint_call
try
$(esc(transformed_body))
finally
$rewind_call
if force_enable
new_body = quote
local $(esc(pool_name)) = $pool_getter
$checkpoint_call
try
$(esc(transformed_body))
finally
$rewind_call
end
end
else
disabled_pool = _disabled_pool_expr(backend)
new_body = quote
if $MAYBE_POOLING[]
local $(esc(pool_name)) = $pool_getter
$checkpoint_call
try
$(esc(transformed_body))
finally
$rewind_call
end
else
local $(esc(pool_name)) = $disabled_pool
$(esc(body))
end
end
end

Expand Down Expand Up @@ -623,7 +642,7 @@ function _generate_function_pool_code(pool_name, func_def, force_enable, disable
else
disabled_pool = _disabled_pool_expr(backend)
new_body = quote
if $MAYBE_POOLING_ENABLED[]
if $MAYBE_POOLING[]
local $(esc(pool_name)) = get_task_local_pool()
$checkpoint_call
try
Expand Down
Loading