From 649d68b13412fda746f271990f0aa2992bfd2e32 Mon Sep 17 00:00:00 2001 From: Min-Gu Yoo Date: Fri, 13 Mar 2026 23:22:00 -0700 Subject: [PATCH 1/4] feat: Metal.jl (Apple Silicon GPU) backend support MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Standalone Metal extension with full parity to CUDA: checkpoint/rewind, lazy mode, typed-lazy mode, safety (poisoning/escape detection/borrow tracking), @with_pool :metal, DisabledPool{:metal}, and task-local multi-device pools. No shared GPU common layer — each backend is self-contained to preserve type stability and zero-alloc guarantees. --- Project.toml | 3 + .../AdaptiveArrayPoolsMetalExt.jl | 41 ++ ext/AdaptiveArrayPoolsMetalExt/acquire.jl | 338 ++++++++++++ ext/AdaptiveArrayPoolsMetalExt/convenience.jl | 78 +++ ext/AdaptiveArrayPoolsMetalExt/debug.jl | 263 +++++++++ ext/AdaptiveArrayPoolsMetalExt/dispatch.jl | 49 ++ ext/AdaptiveArrayPoolsMetalExt/macros.jl | 27 + ext/AdaptiveArrayPoolsMetalExt/state.jl | 336 ++++++++++++ .../task_local_pool.jl | 59 ++ ext/AdaptiveArrayPoolsMetalExt/types.jl | 171 ++++++ ext/AdaptiveArrayPoolsMetalExt/utils.jl | 158 ++++++ src/AdaptiveArrayPools.jl | 1 + src/task_local_pool.jl | 24 + src/utils.jl | 14 +- test/Project.toml | 1 + test/metal/runtests.jl | 52 ++ test/metal/test_allocation.jl | 261 +++++++++ test/metal/test_convenience.jl | 124 +++++ test/metal/test_disabled_pool.jl | 192 +++++++ test/metal/test_display.jl | 204 +++++++ test/metal/test_extension.jl | 504 +++++++++++++++++ test/metal/test_metal_safety.jl | 509 ++++++++++++++++++ test/metal/test_reshape.jl | 88 +++ test/metal/test_task_local_pool.jl | 47 ++ test/runtests.jl | 7 + 25 files changed, 3545 insertions(+), 6 deletions(-) create mode 100644 ext/AdaptiveArrayPoolsMetalExt/AdaptiveArrayPoolsMetalExt.jl create mode 100644 ext/AdaptiveArrayPoolsMetalExt/acquire.jl create mode 100644 ext/AdaptiveArrayPoolsMetalExt/convenience.jl create mode 100644 ext/AdaptiveArrayPoolsMetalExt/debug.jl create mode 100644 ext/AdaptiveArrayPoolsMetalExt/dispatch.jl create mode 100644 ext/AdaptiveArrayPoolsMetalExt/macros.jl create mode 100644 ext/AdaptiveArrayPoolsMetalExt/state.jl create mode 100644 ext/AdaptiveArrayPoolsMetalExt/task_local_pool.jl create mode 100644 ext/AdaptiveArrayPoolsMetalExt/types.jl create mode 100644 ext/AdaptiveArrayPoolsMetalExt/utils.jl create mode 100644 test/metal/runtests.jl create mode 100644 test/metal/test_allocation.jl create mode 100644 test/metal/test_convenience.jl create mode 100644 test/metal/test_disabled_pool.jl create mode 100644 test/metal/test_display.jl create mode 100644 test/metal/test_extension.jl create mode 100644 test/metal/test_metal_safety.jl create mode 100644 test/metal/test_reshape.jl create mode 100644 test/metal/test_task_local_pool.jl diff --git a/Project.toml b/Project.toml index 4d43720b..093e1bee 100644 --- a/Project.toml +++ b/Project.toml @@ -9,12 +9,15 @@ Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7" [weakdeps] CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" +Metal = "dde4c033-4e86-420c-a63e-0dd931031962" [extensions] AdaptiveArrayPoolsCUDAExt = "CUDA" +AdaptiveArrayPoolsMetalExt = "Metal" [compat] CUDA = "5" +Metal = "1" Preferences = "1" Printf = "1" julia = "1.10" diff --git a/ext/AdaptiveArrayPoolsMetalExt/AdaptiveArrayPoolsMetalExt.jl b/ext/AdaptiveArrayPoolsMetalExt/AdaptiveArrayPoolsMetalExt.jl new file mode 100644 index 00000000..7486f3be --- /dev/null +++ b/ext/AdaptiveArrayPoolsMetalExt/AdaptiveArrayPoolsMetalExt.jl @@ -0,0 +1,41 @@ +""" + AdaptiveArrayPoolsMetalExt + +Metal extension for AdaptiveArrayPools.jl. Provides GPU memory pooling +with the same checkpoint/rewind semantics as CPU pools. + +Loaded automatically when `using Metal` with AdaptiveArrayPools. + +Supports Metal.PrivateStorage only. Default element type is Float32. +Explicitly unsupported: Float64, ComplexF64. +""" +module AdaptiveArrayPoolsMetalExt + +using AdaptiveArrayPools +using Metal + +# GPU pooling requires Julia 1.11+ (setfield!-based Array, arr_wrappers cache). +# On older Julia, the extension loads but provides no functionality. +@static if VERSION >= v"1.11-" + + using AdaptiveArrayPools: AbstractTypedPool, AbstractArrayPool + using Metal.GPUArrays + + include("types.jl") + include("dispatch.jl") + include("acquire.jl") + include("task_local_pool.jl") + include("state.jl") + include("debug.jl") + include("utils.jl") + include("macros.jl") + include("convenience.jl") + + export MetalTypedPool, MetalAdaptiveArrayPool + export METAL_FIXED_SLOT_FIELDS + +else + @warn "AdaptiveArrayPoolsMetalExt requires Julia 1.11+. GPU pooling is disabled." maxlog = 1 +end # @static if + +end # module diff --git a/ext/AdaptiveArrayPoolsMetalExt/acquire.jl b/ext/AdaptiveArrayPoolsMetalExt/acquire.jl new file mode 100644 index 00000000..a1f1909a --- /dev/null +++ b/ext/AdaptiveArrayPoolsMetalExt/acquire.jl @@ -0,0 +1,338 @@ +# ============================================================================== +# Metal-Specific Acquire Implementation (arr_wrappers + setfield!) +# ============================================================================== +# Mirrors CUDA acquire.jl — fully self-contained, no shared GPU common layer. +# +# Key differences from CUDA: +# - MtlArray uses sizeof(T) for capacity (no aligned_sizeof) +# - MtlArray{T,N,S} carries storage mode S +# - DataRef access via getfield (same .data.rc identity check as CUDA) +# +# ⚠ Depends on MtlArray internal fields (:data, :maxsize, :offset, :dims). +# Tested with Metal.jl v1.x. +# ============================================================================== + +using AdaptiveArrayPools: get_view!, get_array!, allocate_vector, safe_prod, + _record_type_touch!, _fixed_slot_bit, _checkpoint_typed_pool!, + _store_arr_wrapper!, _check_pool_growth, _reshape_impl!, + _acquire_impl!, _acquire_view_impl!, _maybe_record_borrow!, + _MODE_BITS_MASK + +using Metal.GPUArrays: unsafe_free! + +# Guard against Metal.jl internal API changes +@static if !( + ismutabletype(MtlArray) && + hasfield(MtlArray, :data) && + hasfield(MtlArray, :maxsize) && + hasfield(MtlArray, :offset) && + hasfield(MtlArray, :dims) + ) + error("Unsupported Metal.jl version: MtlArray must be mutable with :data, :maxsize, :offset, :dims.") +end + +# Verify DataRef has .rc field +let DataRefT = fieldtype(MtlArray{Float32, 1, Metal.PrivateStorage}, :data) + if !hasfield(DataRefT, :rc) + error("Unsupported Metal.jl version: DataRef must have :rc field for storage identity check.") + end +end + +# ============================================================================== +# _resize_to_fit! — Capacity-Aware Resize for Metal +# ============================================================================== + +""" + _resize_to_fit!(A::MtlArray{T,1,S}, n::Integer) -> MtlArray{T,1,S} + +Resize a MtlVector's logical length, using `setfield!(:dims)` when within capacity. + +- `n > capacity`: delegates to `resize!(A, n)` (may grow GPU allocation) +- `n <= capacity, n != length(A)`: `setfield!(:dims)` only — no GPU operation +- `n == length(A)`: no-op + +Capacity = `A.maxsize / sizeof(T)`. Since `setfield!(:dims)` preserves +`maxsize`, capacity information is naturally retained across shrink/grow cycles. +""" +@inline function _resize_to_fit!(A::MtlArray{T, 1, S}, n::Integer) where {T, S} + cap = getfield(A, :maxsize) ÷ sizeof(T) + if n > cap + resize!(A, n) + elseif n != length(A) + setfield!(A, :dims, (Int(n),)) + end + return A +end + +# ============================================================================== +# _metal_claim_slot! — Capacity-Based Slot Claim +# ============================================================================== + +""" + _metal_claim_slot!(tp::MetalTypedPool{T,S}, total_len::Int) -> Int + +Claim the next slot, ensuring the backing vector's GPU buffer has capacity >= `total_len`. +Uses maxsize-based capacity check instead of length check to avoid triggering +Metal.jl's resize! unnecessarily (especially after safety invalidation sets dims=(0,)). +""" +@inline function _metal_claim_slot!(tp::MetalTypedPool{T, S}, total_len::Int) where {T, S} + tp.n_active += 1 + idx = tp.n_active + if idx > length(tp.vectors) + push!(tp.vectors, allocate_vector(tp, total_len)) + _check_pool_growth(tp, idx) + else + _resize_to_fit!(@inbounds(tp.vectors[idx]), total_len) + end + return idx +end + +""" + _metal_claim_slot!(tp::MetalTypedPool{T,S}) -> Int + +Claim the next slot without provisioning memory (zero-length backing vector). +Used by `_reshape_impl!` which only needs the slot index for wrapper caching — +the wrapper points to a different array's memory via `setfield!(:data)`. +""" +@inline function _metal_claim_slot!(tp::MetalTypedPool{T, S}) where {T, S} + tp.n_active += 1 + idx = tp.n_active + if idx > length(tp.vectors) + push!(tp.vectors, MtlArray{T, 1, S}(undef, 0)) + _check_pool_growth(tp, idx) + end + return idx +end + +# ============================================================================== +# _update_metal_wrapper_data! — DataRef Refcount Management +# ============================================================================== + +""" + _update_metal_wrapper_data!(wrapper::MtlArray, source::MtlArray) + +Update wrapper's GPU data reference when the source's buffer has changed. +Decrements old refcount, increments new. @noinline: rare path (only on grow +beyond capacity), keep off the hot inlined acquire path. +""" +@noinline function _update_metal_wrapper_data!(wrapper::MtlArray, source::MtlArray) + unsafe_free!(getfield(wrapper, :data)) + setfield!(wrapper, :data, copy(getfield(source, :data))) + setfield!(wrapper, :maxsize, getfield(source, :maxsize)) + setfield!(wrapper, :offset, getfield(source, :offset)) + return nothing +end + +# ============================================================================== +# _acquire_impl! / _acquire_view_impl! — Direct get_array! Dispatch +# ============================================================================== +# On Metal, both acquire! and acquire_view! go through get_array! directly. +# No view/array distinction — MtlArray is always returned. + +""" + _acquire_impl!(pool::MetalAdaptiveArrayPool, T, n) -> MtlArray{T,1,S} + _acquire_impl!(pool::MetalAdaptiveArrayPool, T, dims...) -> MtlArray{T,N,S} + +Metal override: routes directly to `get_array!` (no view indirection). +""" +@inline function AdaptiveArrayPools._acquire_impl!(pool::MetalAdaptiveArrayPool, ::Type{T}, n::Int) where {T} + tp = get_typed_pool!(pool, T) + result = get_array!(tp, (n,)) + _maybe_record_borrow!(pool, tp) + return result +end + +@inline function AdaptiveArrayPools._acquire_impl!(pool::MetalAdaptiveArrayPool, ::Type{T}, dims::Vararg{Int, N}) where {T, N} + tp = get_typed_pool!(pool, T) + result = get_array!(tp, dims) + _maybe_record_borrow!(pool, tp) + return result +end + +@inline function AdaptiveArrayPools._acquire_impl!(pool::MetalAdaptiveArrayPool, ::Type{T}, dims::NTuple{N, Int}) where {T, N} + return _acquire_impl!(pool, T, dims...) +end + +""" + _acquire_view_impl!(pool::MetalAdaptiveArrayPool, T, dims...) -> MtlArray{T,N,S} + +Metal override: same as `_acquire_impl!` — Metal has no view/array distinction. +""" +@inline function AdaptiveArrayPools._acquire_view_impl!(pool::MetalAdaptiveArrayPool, ::Type{T}, n::Int) where {T} + return _acquire_impl!(pool, T, n) +end + +@inline function AdaptiveArrayPools._acquire_view_impl!(pool::MetalAdaptiveArrayPool, ::Type{T}, dims::Vararg{Int, N}) where {T, N} + return _acquire_impl!(pool, T, dims...) +end + +@inline function AdaptiveArrayPools._acquire_view_impl!(pool::MetalAdaptiveArrayPool, ::Type{T}, dims::NTuple{N, Int}) where {T, N} + return _acquire_impl!(pool, T, dims...) +end + +# ============================================================================== +# get_view! / get_array! — arr_wrappers + setfield! Based Zero-Alloc +# ============================================================================== +# get_view! delegates to get_array! for backward compat (e.g., direct get_view! calls). +# The main acquire path now bypasses get_view! entirely via _acquire_impl! above. + +@inline function AdaptiveArrayPools.get_view!(tp::MetalTypedPool{T, S}, n::Int) where {T, S} + return get_array!(tp, (n,)) +end + +@inline function AdaptiveArrayPools.get_view!(tp::MetalTypedPool{T, S}, dims::NTuple{N, Int}) where {T, S, N} + return get_array!(tp, dims) +end + +""" + get_array!(tp::MetalTypedPool{T,S}, dims::NTuple{N,Int}) -> MtlArray{T,N,S} + +Get an N-dimensional `MtlArray` from the pool with `setfield!`-based wrapper reuse. + +## Cache Hit (common case, 0-alloc) +1. Look up `arr_wrappers[N][slot]` +2. Check `wrapper.data.rc !== vec.data.rc` — if same GPU buffer, just `setfield!(:dims)` +3. If different (rare: only after grow beyond capacity), update `:data` via refcount management + +## Cache Miss (first call per (slot, N)) +Creates MtlArray wrapper sharing backing vector's GPU memory via `copy(vec.data)`, +stores in `arr_wrappers[N][slot]` via `_store_arr_wrapper!` (reuses base module helper). +""" +@inline function AdaptiveArrayPools.get_array!(tp::MetalTypedPool{T, S}, dims::NTuple{N, Int}) where {T, S, N} + total_len = safe_prod(dims) + slot = _metal_claim_slot!(tp, total_len) + @inbounds vec = tp.vectors[slot] + + # arr_wrappers lookup (direct index, no hash — same as CPU/CUDA path) + wrappers = N <= length(tp.arr_wrappers) ? (@inbounds tp.arr_wrappers[N]) : nothing + if wrappers !== nothing && slot <= length(wrappers) + wrapper = @inbounds wrappers[slot] + if wrapper !== nothing + mtl = wrapper::MtlArray{T, N, S} + # Check if backing vec's GPU buffer changed (rare: only on grow beyond capacity) + if getfield(mtl, :data).rc !== getfield(vec, :data).rc + _update_metal_wrapper_data!(mtl, vec) + end + setfield!(mtl, :dims, dims) + return mtl + end + end + + # Cache miss: create wrapper sharing vec's GPU memory + mtl = MtlArray{T, N, S}( + copy(getfield(vec, :data)), dims; + maxsize = getfield(vec, :maxsize), + offset = getfield(vec, :offset), + ) + _store_arr_wrapper!(tp, N, slot, mtl) + return mtl +end + +# ============================================================================== +# _reshape_impl! for MtlArray — Zero-Alloc Reshape +# ============================================================================== + +""" + _reshape_impl!(pool::MetalAdaptiveArrayPool, A::MtlArray{T,M,S}, dims::NTuple{N,Int}) -> MtlArray{T,N,S} + +Zero-allocation reshape for MtlArray using `setfield!`-based wrapper reuse. + +- **Same dimensionality (M == N)**: `setfield!(A, :dims, dims)` — no pool interaction +- **Different dimensionality (M != N)**: Claims a pool slot, reuses cached `MtlArray{T,N,S}` + wrapper with `setfield!(:dims)` pointing to `A`'s GPU memory. +""" +@inline function AdaptiveArrayPools._reshape_impl!( + pool::MetalAdaptiveArrayPool, A::MtlArray{T, M, S}, dims::NTuple{N, Int} + ) where {T, M, S, N} + for d in dims + d < 0 && throw(ArgumentError("invalid MtlArray dimensions")) + end + total_len = safe_prod(dims) + length(A) == total_len || throw( + DimensionMismatch( + "new dimensions $(dims) must be consistent with array length $(length(A))" + ) + ) + + # 0-D reshape: rare edge case, delegate to Base (arr_wrappers is 1-indexed by N) + N == 0 && return reshape(A, dims) + + # Same dimensionality: just update dims in-place, no pool interaction + if M == N + setfield!(A, :dims, dims) + return A + end + + # Different dimensionality: claim slot + reuse cached N-D wrapper + tp = AdaptiveArrayPools.get_typed_pool!(pool, T) + _record_type_touch!(pool, T) + slot = _metal_claim_slot!(tp) + + # Look up cached wrapper (direct index, no hash) + wrappers = N <= length(tp.arr_wrappers) ? (@inbounds tp.arr_wrappers[N]) : nothing + if wrappers !== nothing && slot <= length(wrappers) + wrapper = @inbounds wrappers[slot] + if wrapper !== nothing + mtl = wrapper::MtlArray{T, N, S} + if getfield(mtl, :data).rc !== getfield(A, :data).rc + _update_metal_wrapper_data!(mtl, A) + end + setfield!(mtl, :dims, dims) + setfield!(mtl, :offset, getfield(A, :offset)) + return mtl + end + end + + # Cache miss (first call per slot+N): create wrapper, cache forever + mtl = MtlArray{T, N, S}( + copy(getfield(A, :data)), dims; + maxsize = getfield(A, :maxsize), + offset = getfield(A, :offset), + ) + _store_arr_wrapper!(tp, N, slot, mtl) + return mtl +end + +# ============================================================================== +# Metal _record_type_touch! override +# ============================================================================== +# Float16 on Metal: direct struct field with _fixed_slot_bit(Float16)=0. +# We track Float16 via bit 7 (Metal reassignment; CPU uses bit 7 for Bit type, +# absent on GPU). This gives Float16 lazy first-touch checkpointing in bit-14 +# (typed lazy) and bit-15 (dynamic) modes, ensuring Case A (not Case B) fires +# at rewind and parent n_active is preserved. + +@inline function AdaptiveArrayPools._record_type_touch!(pool::MetalAdaptiveArrayPool, ::Type{T}) where {T} + depth = pool._current_depth + b = _fixed_slot_bit(T) + if b == UInt16(0) + if T === Float16 + # Float16: Metal direct field tracked via bit 7 (not in pool.others dict). + b16 = UInt16(1) << 7 + current_mask = @inbounds pool._touched_type_masks[depth] + # Lazy first-touch checkpoint: bit 14 (typed lazy) OR bit 15 (dynamic), first touch only. + # Guard: skip if already checkpointed at this depth (prevents double-push). + if (current_mask & _MODE_BITS_MASK) != 0 && (current_mask & b16) == 0 + if @inbounds(pool.float16._checkpoint_depths[end]) != depth + _checkpoint_typed_pool!(pool.float16, depth) + end + end + @inbounds pool._touched_type_masks[depth] = current_mask | b16 + else + # Genuine others type (UInt8, Int8, etc.) — eagerly snapshotted at scope entry. + @inbounds pool._touched_has_others[depth] = true + end + else + current_mask = @inbounds pool._touched_type_masks[depth] + # Lazy first-touch checkpoint for fixed-slot types in bit 14/15 modes. + # Guard: skip if already checkpointed at this depth (prevents double-push). + if (current_mask & _MODE_BITS_MASK) != 0 && (current_mask & b) == 0 + tp = AdaptiveArrayPools.get_typed_pool!(pool, T) + if @inbounds(tp._checkpoint_depths[end]) != depth + _checkpoint_typed_pool!(tp, depth) + end + end + @inbounds pool._touched_type_masks[depth] = current_mask | b + end + return nothing +end diff --git a/ext/AdaptiveArrayPoolsMetalExt/convenience.jl b/ext/AdaptiveArrayPoolsMetalExt/convenience.jl new file mode 100644 index 00000000..852a8b22 --- /dev/null +++ b/ext/AdaptiveArrayPoolsMetalExt/convenience.jl @@ -0,0 +1,78 @@ +# ============================================================================== +# Metal Default Element Type +# ============================================================================== +# Metal pools default to Float32 (matching Metal.zeros() behavior). +# All convenience functions (zeros!, ones!, etc.) dispatch through _*_impl! +# which calls default_eltype(pool) for the default type. + +""" + default_eltype(::MetalAdaptiveArrayPool) -> Type + +Returns `Float32` as the default element type for Metal pools. +This matches Metal GPU convention. +""" +AdaptiveArrayPools.default_eltype(::MetalAdaptiveArrayPool) = Float32 + +# ============================================================================== +# DisabledPool{:metal} Fallbacks +# ============================================================================== +# When pooling is disabled but :metal backend is specified, these methods ensure +# proper MtlArray allocation instead of falling back to CPU arrays. + +using AdaptiveArrayPools: DisabledPool + +""" + DISABLED_METAL + +Singleton instance for disabled Metal pooling. +Used by macros when `STATIC_POOLING=false` with `:metal` backend. +""" +const DISABLED_METAL = DisabledPool{:metal}() + +""" + default_eltype(::DisabledPool{:metal}) -> Float32 + +Default element type for disabled Metal pools (matches Metal convention). +""" +AdaptiveArrayPools.default_eltype(::DisabledPool{:metal}) = Float32 + +# --- zeros! for DisabledPool{:metal} --- +@inline AdaptiveArrayPools.zeros!(::DisabledPool{:metal}, ::Type{T}, dims::Vararg{Int, N}) where {T, N} = MtlArray(zeros(T, dims...)) +@inline AdaptiveArrayPools.zeros!(p::DisabledPool{:metal}, dims::Vararg{Int, N}) where {N} = MtlArray(zeros(AdaptiveArrayPools.default_eltype(p), dims...)) +@inline AdaptiveArrayPools.zeros!(::DisabledPool{:metal}, ::Type{T}, dims::NTuple{N, Int}) where {T, N} = MtlArray(zeros(T, dims...)) +@inline AdaptiveArrayPools.zeros!(p::DisabledPool{:metal}, dims::NTuple{N, Int}) where {N} = MtlArray(zeros(AdaptiveArrayPools.default_eltype(p), dims...)) + +# --- ones! for DisabledPool{:metal} --- +@inline AdaptiveArrayPools.ones!(::DisabledPool{:metal}, ::Type{T}, dims::Vararg{Int, N}) where {T, N} = MtlArray(ones(T, dims...)) +@inline AdaptiveArrayPools.ones!(p::DisabledPool{:metal}, dims::Vararg{Int, N}) where {N} = MtlArray(ones(AdaptiveArrayPools.default_eltype(p), dims...)) +@inline AdaptiveArrayPools.ones!(::DisabledPool{:metal}, ::Type{T}, dims::NTuple{N, Int}) where {T, N} = MtlArray(ones(T, dims...)) +@inline AdaptiveArrayPools.ones!(p::DisabledPool{:metal}, dims::NTuple{N, Int}) where {N} = MtlArray(ones(AdaptiveArrayPools.default_eltype(p), dims...)) + +# --- similar! for DisabledPool{:metal} --- +@inline AdaptiveArrayPools.similar!(::DisabledPool{:metal}, x::MtlArray) = Metal.similar(x) +@inline AdaptiveArrayPools.similar!(::DisabledPool{:metal}, x::MtlArray, ::Type{T}) where {T} = Metal.similar(x, T) +@inline AdaptiveArrayPools.similar!(::DisabledPool{:metal}, x::MtlArray, dims::Vararg{Int, N}) where {N} = Metal.similar(x, dims...) +@inline AdaptiveArrayPools.similar!(::DisabledPool{:metal}, x::MtlArray, ::Type{T}, dims::Vararg{Int, N}) where {T, N} = Metal.similar(x, T, dims...) +# Fallback for non-MtlArray inputs (creates MtlArray from AbstractArray) +@inline AdaptiveArrayPools.similar!(::DisabledPool{:metal}, x::AbstractArray) = MtlArray{eltype(x)}(undef, size(x)) +@inline AdaptiveArrayPools.similar!(::DisabledPool{:metal}, x::AbstractArray, ::Type{T}) where {T} = MtlArray{T}(undef, size(x)) +@inline AdaptiveArrayPools.similar!(::DisabledPool{:metal}, x::AbstractArray, dims::Vararg{Int, N}) where {N} = MtlArray{eltype(x)}(undef, dims) +@inline AdaptiveArrayPools.similar!(::DisabledPool{:metal}, x::AbstractArray, ::Type{T}, dims::Vararg{Int, N}) where {T, N} = MtlArray{T}(undef, dims) + +# --- reshape! for DisabledPool{:metal} --- +@inline AdaptiveArrayPools.reshape!(::DisabledPool{:metal}, A::AbstractArray, dims::Vararg{Int, N}) where {N} = reshape(A, dims...) +@inline AdaptiveArrayPools.reshape!(::DisabledPool{:metal}, A::AbstractArray, dims::NTuple{N, Int}) where {N} = reshape(A, dims) + +# --- acquire! for DisabledPool{:metal} --- +@inline AdaptiveArrayPools.acquire!(::DisabledPool{:metal}, ::Type{T}, n::Int) where {T} = MtlVector{T}(undef, n) +@inline AdaptiveArrayPools.acquire!(::DisabledPool{:metal}, ::Type{T}, dims::Vararg{Int, N}) where {T, N} = MtlArray{T, N}(undef, dims) +@inline AdaptiveArrayPools.acquire!(::DisabledPool{:metal}, ::Type{T}, dims::NTuple{N, Int}) where {T, N} = MtlArray{T, N}(undef, dims) +@inline AdaptiveArrayPools.acquire!(::DisabledPool{:metal}, x::MtlArray) = Metal.similar(x) +@inline AdaptiveArrayPools.acquire!(::DisabledPool{:metal}, x::AbstractArray) = MtlArray{eltype(x)}(undef, size(x)) + +# --- acquire_view! for DisabledPool{:metal} (no view distinction on GPU) --- +@inline AdaptiveArrayPools.acquire_view!(::DisabledPool{:metal}, ::Type{T}, n::Int) where {T} = MtlVector{T}(undef, n) +@inline AdaptiveArrayPools.acquire_view!(::DisabledPool{:metal}, ::Type{T}, dims::Vararg{Int, N}) where {T, N} = MtlArray{T, N}(undef, dims) +@inline AdaptiveArrayPools.acquire_view!(::DisabledPool{:metal}, ::Type{T}, dims::NTuple{N, Int}) where {T, N} = MtlArray{T, N}(undef, dims) +@inline AdaptiveArrayPools.acquire_view!(::DisabledPool{:metal}, x::MtlArray) = Metal.similar(x) +@inline AdaptiveArrayPools.acquire_view!(::DisabledPool{:metal}, x::AbstractArray) = MtlArray{eltype(x)}(undef, size(x)) diff --git a/ext/AdaptiveArrayPoolsMetalExt/debug.jl b/ext/AdaptiveArrayPoolsMetalExt/debug.jl new file mode 100644 index 00000000..934e2736 --- /dev/null +++ b/ext/AdaptiveArrayPoolsMetalExt/debug.jl @@ -0,0 +1,263 @@ +# ============================================================================== +# Metal Safety: Poisoning, Escape Detection, Borrow Tracking +# ============================================================================== +# Metal-specific safety implementations for MetalAdaptiveArrayPool{R,S}. +# +# Binary safety system (R=0 off, R=1 all checks): +# - R=0: Zero overhead (all branches dead-code-eliminated) +# - R=1: Poisoning + structural invalidation + escape detection + borrow tracking +# +# Key difference: CPU uses resize!(v, 0) at Level 1 to invalidate stale SubArrays. +# On Metal, resize!(MtlVector, 0) would free GPU memory, so we use +# _resize_to_fit!(vec, 0) instead — sets dims to (0,) while preserving +# the GPU allocation (maxsize). Poisoning fills sentinel data before the shrink. +# arr_wrappers are invalidated by setting wrapper dims to zeros (matches CPU pattern). + +using AdaptiveArrayPools: _runtime_check, _validate_pool_return, + _set_pending_callsite!, _maybe_record_borrow!, + _invalidate_released_slots!, _zero_dims_tuple, + _throw_pool_escape_error, + PoolRuntimeEscapeError + +# ============================================================================== +# Poisoning: Fill released MtlArrays with sentinel values (R=1) +# ============================================================================== + +_metal_poison_value(::Type{T}) where {T <: AbstractFloat} = T(NaN) +_metal_poison_value(::Type{T}) where {T <: Integer} = typemax(T) +_metal_poison_value(::Type{Complex{T}}) where {T} = Complex{T}(_metal_poison_value(T), _metal_poison_value(T)) +_metal_poison_value(::Type{Bool}) = true +_metal_poison_value(::Type{T}) where {T} = zero(T) # generic fallback + +""" + _metal_poison_fill!(v::MtlArray{T,1}) + +Fill a MtlArray with a detectable sentinel value (NaN for floats, typemax for ints). +@noinline to avoid inlining GPU kernel launch overhead into hot rewind paths. +""" +@noinline function _metal_poison_fill!(v::MtlArray{T, 1}) where {T} + length(v) > 0 && Metal.fill!(v, _metal_poison_value(T)) + return nothing +end + +# ============================================================================== +# _invalidate_released_slots! for MetalTypedPool (R=1) +# ============================================================================== +# +# Overrides the no-op fallback in base. On Metal: +# - R=0: no-op (base _rewind_typed_pool! gates with S >= 1, so never called) +# - R=1: poison released MtlArrays + invalidate arr_wrappers +# - NO resize!(mtl, 0) — would free GPU memory; use _resize_to_fit! instead + +@noinline function AdaptiveArrayPools._invalidate_released_slots!( + tp::MetalTypedPool{T, S}, old_n_active::Int, safety::Int + ) where {T, S} + new_n = tp.n_active + # Poison released MtlArrays + shrink logical length to 0 + for i in (new_n + 1):old_n_active + _metal_poison_fill!(@inbounds tp.vectors[i]) + # Shrink logical length to 0 (GPU memory preserved via _resize_to_fit!). + # Matches CPU behavior where resize!(vec, 0) invalidates SubArray references. + _resize_to_fit!(@inbounds(tp.vectors[i]), 0) + end + # Invalidate arr_wrappers for released slots (matches CPU pattern from src/state.jl) + for N_idx in 1:length(tp.arr_wrappers) + wrappers_for_N = @inbounds tp.arr_wrappers[N_idx] + wrappers_for_N === nothing && continue + wrappers = wrappers_for_N::Vector{Any} + for i in (new_n + 1):min(old_n_active, length(wrappers)) + wrapper = @inbounds wrappers[i] + wrapper === nothing && continue + setfield!(wrapper::MtlArray, :dims, _zero_dims_tuple(N_idx)) + end + end + return nothing +end + +# ============================================================================== +# Borrow Tracking: Call-site recording (R=1) +# ============================================================================== +# +# Overrides the no-op AbstractArrayPool fallbacks. +# The macro injects pool._pending_callsite = "file:line\nexpr" before acquire calls. +# These functions flush that pending info into the borrow log. + +"""Record pending callsite for borrow tracking (compiles to no-op when R=0).""" +@inline function AdaptiveArrayPools._set_pending_callsite!(pool::MetalAdaptiveArrayPool{R, S}, msg::String) where {R, S} + R >= 1 && isempty(pool._pending_callsite) && (pool._pending_callsite = msg) + return nothing +end + +"""Flush pending callsite into borrow log (compiles to no-op when R=0).""" +@inline function AdaptiveArrayPools._maybe_record_borrow!(pool::MetalAdaptiveArrayPool{R, S}, tp::AbstractTypedPool) where {R, S} + R >= 1 && _metal_record_borrow_from_pending!(pool, tp) + return nothing +end + +@noinline function _metal_record_borrow_from_pending!(pool::MetalAdaptiveArrayPool, tp::AbstractTypedPool) + callsite = pool._pending_callsite + isempty(callsite) && return nothing + log = pool._borrow_log + if log === nothing + log = IdDict{Any, String}() + pool._borrow_log = log + end + @inbounds log[tp.vectors[tp.n_active]] = callsite + pool._pending_callsite = "" # Clear for next acquire + return nothing +end + +@noinline function _metal_lookup_borrow_callsite(pool::MetalAdaptiveArrayPool, v)::Union{Nothing, String} + log = pool._borrow_log + log === nothing && return nothing + return get(log, v, nothing) +end + +# ============================================================================== +# Escape Detection: _validate_pool_return for MtlArrays (R=1) +# ============================================================================== +# +# MtlArray views share the same device buffer, so device pointer overlap +# detection works correctly. pointer(::MtlArray) returns MtlPointer{T}. + +function AdaptiveArrayPools._validate_pool_return(val, pool::MetalAdaptiveArrayPool{R, S}) where {R, S} + R >= 1 || return nothing + _validate_metal_return(val, pool) + return nothing +end + +function _validate_metal_return(val, pool::MetalAdaptiveArrayPool) + # Note: Container recursion (Tuple, NamedTuple, Pair, Dict, Set, AbstractArray) + # is duplicated from CPU's _validate_pool_return dispatch chain (src/debug.jl). + # CPU uses multiple dispatch on pool::AdaptiveArrayPool for each container type, + # which doesn't cover MetalAdaptiveArrayPool. We could add MetalAdaptiveArrayPool methods + # for each container, but that creates 6+ method definitions vs. this single function. + + # MtlArray (MtlVector, MtlMatrix, etc.) + if val isa MtlArray + _check_metal_overlap(val, pool) + return + end + + # SubArray / ReshapedArray of MtlArray — defensive code. + # Current Metal.jl: view(MtlVector, 1:n) returns MtlArray via GPUArrays derive(), + # NOT SubArray. These branches guard against future Metal.jl behavior changes. + if val isa SubArray + p = parent(val) + if p isa MtlArray + _check_metal_overlap(p, pool, val) + end + return + end + + if val isa Base.ReshapedArray + p = parent(val) + if p isa MtlArray + _check_metal_overlap(p, pool, val) + elseif p isa SubArray + pp = parent(p) + if pp isa MtlArray + _check_metal_overlap(pp, pool, val) + end + end + return + end + + # Tuple + if val isa Tuple + for x in val + _validate_metal_return(x, pool) + end + return + end + + # NamedTuple + if val isa NamedTuple + for x in values(val) + _validate_metal_return(x, pool) + end + return + end + + # Pair + if val isa Pair + _validate_metal_return(val.first, pool) + _validate_metal_return(val.second, pool) + return + end + + # AbstractDict + if val isa AbstractDict + for p in val + _validate_metal_return(p, pool) + end + return + end + + # AbstractSet + if val isa AbstractSet + for x in val + _validate_metal_return(x, pool) + end + return + end + + # Array of MtlArrays (element recursion for containers) + if val isa AbstractArray + ET = eltype(val) + if !(ET <: Number) && !(ET <: AbstractString) && ET !== Symbol && ET !== Char + for x in val + _validate_metal_return(x, pool) + end + end + end + + return +end + +""" + _check_metal_overlap(arr::MtlArray, pool, original_val=arr) + +Check if a MtlArray's device memory overlaps with any pool backing MtlArray. +Throws `PoolRuntimeEscapeError` on overlap. +""" +function _check_metal_overlap(arr::MtlArray, pool::MetalAdaptiveArrayPool, original_val = arr) + arr_ptr = pointer(arr) + arr_buf = arr_ptr.buffer + arr_off = Int(arr_ptr.offset) + arr_sz = length(arr) * sizeof(eltype(arr)) + arr_end = arr_off + arr_sz + + return_site = let rs = pool._pending_return_site + isempty(rs) ? nothing : rs + end + + # Check fixed slots + AdaptiveArrayPools.foreach_fixed_slot(pool) do tp + _check_tp_metal_overlap(tp, arr_buf, arr_off, arr_end, pool, return_site, original_val) + end + + # Check others + for tp in values(pool.others) + _check_tp_metal_overlap(tp, arr_buf, arr_off, arr_end, pool, return_site, original_val) + end + return +end + +@noinline function _check_tp_metal_overlap( + tp::AbstractTypedPool, abuf, aoff::Int, aend::Int, + pool::MetalAdaptiveArrayPool, return_site, original_val + ) + for v in tp.vectors + vptr = pointer(v) + vbuf = vptr.buffer + voff = Int(vptr.offset) + vsz = length(v) * sizeof(eltype(v)) + vend = voff + vsz + if abuf === vbuf && !(aend <= voff || vend <= aoff) + callsite = _metal_lookup_borrow_callsite(pool, v) + _throw_pool_escape_error(original_val, eltype(v), callsite, return_site) + end + end + return +end diff --git a/ext/AdaptiveArrayPoolsMetalExt/dispatch.jl b/ext/AdaptiveArrayPoolsMetalExt/dispatch.jl new file mode 100644 index 00000000..0c938f6a --- /dev/null +++ b/ext/AdaptiveArrayPoolsMetalExt/dispatch.jl @@ -0,0 +1,49 @@ +# ============================================================================== +# Metal Dispatch Methods +# ============================================================================== +# Key dispatch points for Metal-specific allocation and type routing. + +using AdaptiveArrayPools: allocate_vector, get_typed_pool! + +# ============================================================================== +# Allocation Dispatch +# ============================================================================== + +@inline function AdaptiveArrayPools.allocate_vector( + ::AbstractTypedPool{T, MtlArray{T, 1, S}}, n::Int + ) where {T, S} + return MtlArray{T, 1, S}(undef, n) +end + +# ============================================================================== +# get_typed_pool! Dispatches for MetalAdaptiveArrayPool +# ============================================================================== + +# Fast path: compile-time dispatch for fixed slots +@inline AdaptiveArrayPools.get_typed_pool!(p::MetalAdaptiveArrayPool, ::Type{Float32}) = p.float32 +@inline AdaptiveArrayPools.get_typed_pool!(p::MetalAdaptiveArrayPool, ::Type{Float16}) = p.float16 +@inline AdaptiveArrayPools.get_typed_pool!(p::MetalAdaptiveArrayPool, ::Type{Int32}) = p.int32 +@inline AdaptiveArrayPools.get_typed_pool!(p::MetalAdaptiveArrayPool, ::Type{Int64}) = p.int64 +@inline AdaptiveArrayPools.get_typed_pool!(p::MetalAdaptiveArrayPool, ::Type{ComplexF32}) = p.complexf32 +@inline AdaptiveArrayPools.get_typed_pool!(p::MetalAdaptiveArrayPool, ::Type{Bool}) = p.bool + +# Slow path: rare types via IdDict (with checkpoint correction!) +# Explicitly reject Float64 and ComplexF64 (unsupported by Metal hardware). +@inline function AdaptiveArrayPools.get_typed_pool!(p::MetalAdaptiveArrayPool, ::Type{T}) where {T} + if T === Float64 || T === ComplexF64 + throw(ArgumentError("Metal backend does not support $T")) + end + return get!(p.others, T) do + tp = MetalTypedPool{T, Metal.PrivateStorage}() + # CRITICAL: Match CPU behavior - auto-checkpoint new pool if inside @with_pool scope + # Without this, rewind! would corrupt state for dynamically-created pools + if p._current_depth > 1 + push!(tp._checkpoint_n_active, 0) # n_active starts at 0 + push!(tp._checkpoint_depths, p._current_depth) + # Signal that a fallback type was touched so lazy/typed-lazy rewind + # iterates pool.others (same fix as CPU get_typed_pool!) + @inbounds p._touched_has_others[p._current_depth] = true + end + tp + end::MetalTypedPool{T, Metal.PrivateStorage} +end diff --git a/ext/AdaptiveArrayPoolsMetalExt/macros.jl b/ext/AdaptiveArrayPoolsMetalExt/macros.jl new file mode 100644 index 00000000..ae04a75e --- /dev/null +++ b/ext/AdaptiveArrayPoolsMetalExt/macros.jl @@ -0,0 +1,27 @@ +# ============================================================================== +# Metal Macro Support +# ============================================================================== +# Enables @with_pool :metal syntax for GPU memory pooling. + +using AdaptiveArrayPools: _get_pool_for_backend, _pool_type_for_backend + +# ============================================================================== +# Backend Registration (Val dispatch - zero overhead) +# ============================================================================== + +""" +Register :metal backend for `@with_pool :metal` syntax. +Uses Val dispatch for compile-time resolution and full inlining. +""" +@inline AdaptiveArrayPools._get_pool_for_backend(::Val{:metal}) = get_task_local_metal_pool() + +# ============================================================================== +# Pool Type Registration for Compile-Time Type Assertion +# ============================================================================== +# +# `_pool_type_for_backend` is called at macro expansion time to determine the +# concrete pool type for direct type assertion in macro-generated code. +# This enables `@with_pool :metal` to generate `pool::MetalAdaptiveArrayPool{R,S}` +# where R is determined by the compile-time const `RUNTIME_CHECK`. + +AdaptiveArrayPools._pool_type_for_backend(::Val{:metal}) = MetalAdaptiveArrayPool diff --git a/ext/AdaptiveArrayPoolsMetalExt/state.jl b/ext/AdaptiveArrayPoolsMetalExt/state.jl new file mode 100644 index 00000000..0d8dffb2 --- /dev/null +++ b/ext/AdaptiveArrayPoolsMetalExt/state.jl @@ -0,0 +1,336 @@ +# ============================================================================== +# State Management for Metal Pools +# ============================================================================== +# checkpoint!, rewind!, reset!, empty! implementations for MetalAdaptiveArrayPool{R,S}. +# Note: _checkpoint_typed_pool! and _rewind_typed_pool! already work with +# AbstractTypedPool, so they work for MetalTypedPool automatically. +# R parameter is threaded through rewind paths for compile-time safety dispatch. + +using AdaptiveArrayPools: checkpoint!, rewind!, reset!, + _checkpoint_typed_pool!, _rewind_typed_pool!, _has_bit, + _LAZY_MODE_BIT, _TYPED_LAZY_BIT, _TYPE_BITS_MASK + +# ============================================================================== +# Metal Fixed Slot Iteration +# ============================================================================== + +""" + foreach_fixed_slot(f, pool::MetalAdaptiveArrayPool) + +Apply `f` to each fixed slot MetalTypedPool. Zero allocation via compile-time unrolling. +""" +@generated function AdaptiveArrayPools.foreach_fixed_slot(f::F, pool::MetalAdaptiveArrayPool{R, S}) where {F, R, S} + exprs = [:(f(getfield(pool, $(QuoteNode(field))))) for field in METAL_FIXED_SLOT_FIELDS] + return quote + Base.@_inline_meta + $(exprs...) + nothing + end +end + +# ============================================================================== +# checkpoint! for MetalAdaptiveArrayPool +# ============================================================================== + +function AdaptiveArrayPools.checkpoint!(pool::MetalAdaptiveArrayPool) + # Increment depth and initialize type-touch tracking state + pool._current_depth += 1 + push!(pool._touched_type_masks, UInt16(0)) + push!(pool._touched_has_others, false) + depth = pool._current_depth + + # Fixed slots - zero allocation via @generated iteration + AdaptiveArrayPools.foreach_fixed_slot(pool) do tp + _checkpoint_typed_pool!(tp, depth) + end + + # Others - iterate without allocation + for p in values(pool.others) + _checkpoint_typed_pool!(p, depth) + end + + return nothing +end + +# Type-specific checkpoint (single type) +@inline function AdaptiveArrayPools.checkpoint!(pool::MetalAdaptiveArrayPool, ::Type{T}) where {T} + pool._current_depth += 1 + push!(pool._touched_type_masks, UInt16(0)) + push!(pool._touched_has_others, AdaptiveArrayPools._fixed_slot_bit(T) == UInt16(0)) + _checkpoint_typed_pool!(AdaptiveArrayPools.get_typed_pool!(pool, T), pool._current_depth) + return nothing +end + +# Type-specific checkpoint (multiple types) +@generated function AdaptiveArrayPools.checkpoint!(pool::MetalAdaptiveArrayPool{R, S}, types::Type...) where {R, S} + seen = Set{Any}() + unique_indices = Int[] + for i in eachindex(types) + if !(types[i] in seen) + push!(seen, types[i]) + push!(unique_indices, i) + end + end + has_any_fallback = any(i -> AdaptiveArrayPools._fixed_slot_bit(types[i].parameters[1]) == UInt16(0), unique_indices) + checkpoint_exprs = [:(_checkpoint_typed_pool!(AdaptiveArrayPools.get_typed_pool!(pool, types[$i]), pool._current_depth)) for i in unique_indices] + return quote + pool._current_depth += 1 + push!(pool._touched_type_masks, UInt16(0)) + push!(pool._touched_has_others, $has_any_fallback) + $(checkpoint_exprs...) + nothing + end +end + +# ============================================================================== +# rewind! for MetalAdaptiveArrayPool +# ============================================================================== + +function AdaptiveArrayPools.rewind!(pool::MetalAdaptiveArrayPool{R, S}) where {R, S} + cur_depth = pool._current_depth + + # Safety guard: at global scope (depth=1), delegate to reset! + if cur_depth == 1 + reset!(pool) + return nothing + end + + # Fixed slots — pass R for compile-time safety dispatch + AdaptiveArrayPools.foreach_fixed_slot(pool) do tp + _rewind_typed_pool!(tp, cur_depth, R) + end + + # Others + for tp in values(pool.others) + _rewind_typed_pool!(tp, cur_depth, R) + end + + pop!(pool._touched_type_masks) + pop!(pool._touched_has_others) + pool._current_depth -= 1 + + return nothing +end + +# Type-specific rewind (single type) +@inline function AdaptiveArrayPools.rewind!(pool::MetalAdaptiveArrayPool{R, S}, ::Type{T}) where {R, S, T} + if pool._current_depth == 1 + reset!(AdaptiveArrayPools.get_typed_pool!(pool, T), R) + return nothing + end + _rewind_typed_pool!(AdaptiveArrayPools.get_typed_pool!(pool, T), pool._current_depth, R) + pop!(pool._touched_type_masks) + pop!(pool._touched_has_others) + pool._current_depth -= 1 + return nothing +end + +# Type-specific rewind (multiple types) +@generated function AdaptiveArrayPools.rewind!(pool::MetalAdaptiveArrayPool{R, S}, types::Type...) where {R, S} + seen = Set{Any}() + unique_indices = Int[] + for i in eachindex(types) + if !(types[i] in seen) + push!(seen, types[i]) + push!(unique_indices, i) + end + end + rewind_exprs = [:(_rewind_typed_pool!(AdaptiveArrayPools.get_typed_pool!(pool, types[$i]), pool._current_depth, R)) for i in reverse(unique_indices)] + reset_exprs = [:(reset!(AdaptiveArrayPools.get_typed_pool!(pool, types[$i]), R)) for i in unique_indices] + return quote + if pool._current_depth == 1 + $(reset_exprs...) + return nothing + end + $(rewind_exprs...) + pop!(pool._touched_type_masks) + pop!(pool._touched_has_others) + pool._current_depth -= 1 + nothing + end +end + +# ============================================================================== +# Lazy Mode for MetalAdaptiveArrayPool (use_typed=false path) +# ============================================================================== +# Mirrors CPU _lazy_checkpoint! / _lazy_rewind! in src/state.jl. +# +# Float16 on Metal: direct struct field (not in pool.others dict), but _fixed_slot_bit(Float16)=0. +# We reassign Float16 to bit 7 (unused on Metal; CPU uses bit 7 for Bit type which has no GPU equivalent). +# This gives Float16 the same lazy-first-touch checkpoint treatment as other fixed-slot types. + +# Bit 7 on Metal is reserved for Float16 (CPU uses it for Bit; Bit type does not exist on GPU). +@inline _metal_float16_bit() = UInt16(1) << 7 + +@inline function AdaptiveArrayPools._lazy_checkpoint!(pool::MetalAdaptiveArrayPool) + pool._current_depth += 1 + push!(pool._touched_type_masks, _LAZY_MODE_BIT) # lazy mode flag + push!(pool._touched_has_others, false) + depth = pool._current_depth + # Eagerly checkpoint pre-existing others entries — same as CPU _lazy_checkpoint!. + # New types created during the scope start at n_active=0 (sentinel covers them, Case B safe). + # Pre-existing types need their count saved now so Case A fires correctly at rewind. + for p in values(pool.others) + _checkpoint_typed_pool!(p, depth) + @inbounds pool._touched_has_others[depth] = true + end + # Float16 uses lazy first-touch via bit 7 in _record_type_touch! — no eager checkpoint needed. + return nothing +end + +@inline function AdaptiveArrayPools._lazy_rewind!(pool::MetalAdaptiveArrayPool{R, S}) where {R, S} + d = pool._current_depth + mask = @inbounds(pool._touched_type_masks[d]) & _TYPE_BITS_MASK + _has_bit(mask, Float32) && _rewind_typed_pool!(pool.float32, d, R) + _has_bit(mask, Int64) && _rewind_typed_pool!(pool.int64, d, R) + _has_bit(mask, Int32) && _rewind_typed_pool!(pool.int32, d, R) + _has_bit(mask, ComplexF32) && _rewind_typed_pool!(pool.complexf32, d, R) + _has_bit(mask, Bool) && _rewind_typed_pool!(pool.bool, d, R) + # Bit 7: Float16 (Metal reassignment — _fixed_slot_bit(Float16)==0, must use explicit bit check) + mask & _metal_float16_bit() != 0 && _rewind_typed_pool!(pool.float16, d, R) + if @inbounds(pool._touched_has_others[d]) + for tp in values(pool.others) + _rewind_typed_pool!(tp, d, R) + end + end + pop!(pool._touched_type_masks) + pop!(pool._touched_has_others) + pool._current_depth -= 1 + return nothing +end + +# ============================================================================== +# Typed-Fallback Helpers for MetalAdaptiveArrayPool (Phase 5 parity) +# ============================================================================== + +# _typed_lazy_checkpoint!: typed checkpoint + set bit 14 for lazy extra-type tracking. +# Also eagerly snapshots pre-existing others entries (mirrors CPU fix for Issue #3). +@inline function AdaptiveArrayPools._typed_lazy_checkpoint!(pool::MetalAdaptiveArrayPool, types::Type...) + checkpoint!(pool, types...) + d = pool._current_depth + @inbounds pool._touched_type_masks[d] |= _TYPED_LAZY_BIT + # Eagerly snapshot pre-existing others entries — same reasoning as _lazy_checkpoint!. + # Skip re-snapshot for entries already checkpointed at d by checkpoint!(pool, types...) + for p in values(pool.others) + if @inbounds(p._checkpoint_depths[end]) != d + _checkpoint_typed_pool!(p, d) + end + @inbounds pool._touched_has_others[d] = true + end + # Float16 uses lazy first-touch via bit 7 in _record_type_touch! — no eager checkpoint needed. + return nothing +end + +# _typed_lazy_rewind!: selective rewind of (tracked | touched) mask. +# Uses direct field access with bit checks — foreach_fixed_slot is single-argument (no bit yield). +# Bit 7: Float16 (Metal-specific; lazy-checkpointed on first touch by _record_type_touch!). +# has_others: genuine others types (UInt8, Int8, etc.) — eagerly checkpointed at scope entry. +@inline function AdaptiveArrayPools._typed_lazy_rewind!(pool::MetalAdaptiveArrayPool{R, S}, tracked_mask::UInt16) where {R, S} + d = pool._current_depth + touched = @inbounds(pool._touched_type_masks[d]) & _TYPE_BITS_MASK + combined = tracked_mask | touched + _has_bit(combined, Float32) && _rewind_typed_pool!(pool.float32, d, R) + _has_bit(combined, Int64) && _rewind_typed_pool!(pool.int64, d, R) + _has_bit(combined, Int32) && _rewind_typed_pool!(pool.int32, d, R) + _has_bit(combined, ComplexF32) && _rewind_typed_pool!(pool.complexf32, d, R) + _has_bit(combined, Bool) && _rewind_typed_pool!(pool.bool, d, R) + # Float16: bit 7 is set by _record_type_touch! on first touch (lazy first-touch). + # Also rewind when Float16 was a *tracked* type in the macro. + if combined & _metal_float16_bit() != 0 || @inbounds(pool.float16._checkpoint_depths[end]) == d + _rewind_typed_pool!(pool.float16, d, R) + end + if @inbounds(pool._touched_has_others[d]) + for tp in values(pool.others) + _rewind_typed_pool!(tp, d, R) + end + end + pop!(pool._touched_type_masks) + pop!(pool._touched_has_others) + pool._current_depth -= 1 + return nothing +end + +# ============================================================================== +# reset! for MetalAdaptiveArrayPool +# ============================================================================== + +function AdaptiveArrayPools.reset!(pool::MetalAdaptiveArrayPool{R, S}) where {R, S} + # Fixed slots + AdaptiveArrayPools.foreach_fixed_slot(pool) do tp + reset!(tp, R) + end + + # Others + for tp in values(pool.others) + reset!(tp, R) + end + + # Reset depth and bitmask sentinel state + pool._current_depth = 1 + empty!(pool._touched_type_masks) + push!(pool._touched_type_masks, UInt16(0)) # Sentinel: no bits set + empty!(pool._touched_has_others) + push!(pool._touched_has_others, false) # Sentinel: no others + + # Reset borrow tracking state + pool._pending_callsite = "" + pool._pending_return_site = "" + pool._borrow_log = nothing + + return pool +end + +# Type-specific reset +@inline function AdaptiveArrayPools.reset!(pool::MetalAdaptiveArrayPool{R, S}, ::Type{T}) where {R, S, T} + reset!(AdaptiveArrayPools.get_typed_pool!(pool, T), R) + return pool +end + +# ============================================================================== +# empty! for MetalTypedPool and MetalAdaptiveArrayPool +# ============================================================================== + +""" + empty!(tp::MetalTypedPool) + +Clear all GPU storage. Note: This removes Julia references to MtlArrays. +Actual VRAM release depends on GC + Metal.jl's memory pool. +""" +function Base.empty!(tp::MetalTypedPool) + empty!(tp.vectors) + empty!(tp.arr_wrappers) + tp.n_active = 0 + # Restore sentinel values + empty!(tp._checkpoint_n_active) + push!(tp._checkpoint_n_active, 0) + empty!(tp._checkpoint_depths) + push!(tp._checkpoint_depths, 0) + return tp +end + +function Base.empty!(pool::MetalAdaptiveArrayPool) + # Fixed slots + AdaptiveArrayPools.foreach_fixed_slot(pool) do tp + empty!(tp) + end + + # Others - clear all then the IdDict + for tp in values(pool.others) + empty!(tp) + end + empty!(pool.others) + + # Reset depth and bitmask sentinel state + pool._current_depth = 1 + empty!(pool._touched_type_masks) + push!(pool._touched_type_masks, UInt16(0)) # Sentinel: no bits set + empty!(pool._touched_has_others) + push!(pool._touched_has_others, false) # Sentinel: no others + + # Reset borrow tracking state + pool._pending_callsite = "" + pool._pending_return_site = "" + pool._borrow_log = nothing + + return pool +end diff --git a/ext/AdaptiveArrayPoolsMetalExt/task_local_pool.jl b/ext/AdaptiveArrayPoolsMetalExt/task_local_pool.jl new file mode 100644 index 00000000..c60c9430 --- /dev/null +++ b/ext/AdaptiveArrayPoolsMetalExt/task_local_pool.jl @@ -0,0 +1,59 @@ +# ============================================================================== +# Task-Local Metal Pool (Multi-Device Aware) +# ============================================================================== +# Each Task gets one pool per Metal device to prevent cross-device memory access. +# Pools are parameterized by R (0=off, 1=checks on) via MetalAdaptiveArrayPool{R,S}. + +const _METAL_POOL_KEY = :ADAPTIVE_ARRAY_POOL_METAL + +""" + get_task_local_metal_pool() -> MetalAdaptiveArrayPool{R,S} + +Retrieves (or creates) the `MetalAdaptiveArrayPool` for the current Task and current Metal device. + +## Multi-Device Safety +Each pool is bound to a specific Metal device. This function automatically manages +a dictionary of pools (one per device) in task-local storage, ensuring that: +- Device A's pool is never used on Device B +- Switching devices gets the correct pool + +## Implementation +Uses `Dict{UInt64, MetalAdaptiveArrayPool}` in task-local storage, keyed by device hash. +Values are `MetalAdaptiveArrayPool{R,S}` where R is determined by `RUNTIME_CHECK`. +""" +@inline function AdaptiveArrayPools.get_task_local_metal_pool() + # 1. Get or create the pools dictionary + pools = get(task_local_storage(), _METAL_POOL_KEY, nothing) + if pools === nothing + pools = Dict{UInt64, MetalAdaptiveArrayPool}() + task_local_storage(_METAL_POOL_KEY, pools) + end + + # 2. Get current device key + dev = Metal.device() + dev_key = objectid(dev) + + # 3. Get or create pool for this device + pool = get(pools, dev_key, nothing) + if pool === nothing + pool = MetalAdaptiveArrayPool() # Uses RUNTIME_CHECK for initial R + pools[dev_key] = pool + end + + return pool::MetalAdaptiveArrayPool +end + +""" + get_task_local_metal_pools() -> Dict{UInt64, MetalAdaptiveArrayPool} + +Returns the dictionary of all Metal pools for the current task (one per device). +Useful for diagnostics or bulk operations across all devices. +""" +@inline function AdaptiveArrayPools.get_task_local_metal_pools() + pools = get(task_local_storage(), _METAL_POOL_KEY, nothing) + if pools === nothing + pools = Dict{UInt64, MetalAdaptiveArrayPool}() + task_local_storage(_METAL_POOL_KEY, pools) + end + return pools +end diff --git a/ext/AdaptiveArrayPoolsMetalExt/types.jl b/ext/AdaptiveArrayPoolsMetalExt/types.jl new file mode 100644 index 00000000..461f91c0 --- /dev/null +++ b/ext/AdaptiveArrayPoolsMetalExt/types.jl @@ -0,0 +1,171 @@ +# ============================================================================== +# Metal Type Definitions +# ============================================================================== +# +# Note: Unlike CPU, view(MtlVector, 1:n) returns MtlVector (via GPUArrays derive()), +# NOT SubArray. GPU view/reshape creation allocates ~80 bytes on CPU heap for the +# MtlArray wrapper. We cache wrappers via arr_wrappers to achieve zero-allocation +# on cache hit (same approach as CPU's setfield!-based Array wrapper reuse). + +using AdaptiveArrayPools: RUNTIME_CHECK + +const METAL_STORAGE = Metal.PrivateStorage + +""" + MetalTypedPool{T, S} <: AbstractTypedPool{T, MtlArray{T, 1, S}} + +GPU memory pool for element type `T` with Metal storage mode `S`. +Uses `arr_wrappers`-based MtlArray reuse for zero-allocation acquire +(same design as CPU TypedPool on Julia 1.11+ and CUDA CuTypedPool). + +## Fields +- `vectors`: Backing `MtlArray{T,1,S}` storage (one per slot) +- `arr_wrappers`: `Vector{Union{Nothing, Vector{Any}}}` — indexed by N (dimensionality), + each entry is a per-slot cached `MtlArray{T,N,S}` wrapper. Uses `setfield!(wrapper, :dims, dims)` + for zero-allocation reuse of unlimited dimension patterns within the same N. + When the backing vector's GPU buffer changes (rare: only on grow beyond capacity), + the wrapper's `:data` field is updated via DataRef refcount management. +- State management fields (same as CPU) +""" +mutable struct MetalTypedPool{T, S} <: AbstractTypedPool{T, MtlArray{T, 1, S}} + # --- Storage --- + vectors::Vector{MtlArray{T, 1, S}} + + # --- N-D Wrapper Cache (setfield!-based reuse, matches CPU TypedPool) --- + arr_wrappers::Vector{Union{Nothing, Vector{Any}}} # index=N (dimensionality), value=per-slot MtlArray{T,N,S} + + # --- State Management (1-based sentinel pattern) --- + n_active::Int + _checkpoint_n_active::Vector{Int} + _checkpoint_depths::Vector{Int} +end + +function MetalTypedPool{T, S}() where {T, S} + return MetalTypedPool{T, S}( + MtlArray{T, 1, S}[], # vectors + Union{Nothing, Vector{Any}}[], # arr_wrappers (indexed by N) + 0, [0], [0], # State (1-based sentinel) + ) +end + +# ============================================================================== +# Metal Fixed Slot Configuration +# ============================================================================== + +""" +Metal-optimized fixed slots. Differs from CUDA: +- No Float64, ComplexF64 (unsupported by Metal hardware) +- Float32 first (GPU-preferred precision) +- Float16 added (ML/inference workloads) +""" +const METAL_FIXED_SLOT_FIELDS = ( + :float32, # Primary GPU type + :float16, # ML inference + :int32, # GPU-preferred indexing + :int64, # Large indices + :complexf32, # FFT, signal processing + :bool, # Masks +) + +# ============================================================================== +# MetalAdaptiveArrayPool +# ============================================================================== + +""" + MetalAdaptiveArrayPool{R, S} <: AbstractArrayPool + +Multi-type Metal GPU memory pool, parameterized by runtime check level `R` (binary: 0 or 1) +and Metal storage mode `S`. + +## Runtime Check Levels +- `R=0`: Zero overhead — all safety branches eliminated by dead-code elimination +- `R=1`: Full checks — poisoning + structural invalidation + escape detection + borrow tracking + +## Device Safety +Each pool is bound to a specific Metal device. Using a pool on the wrong device +causes undefined behavior. The `device_key` field tracks ownership. + +## Fields +- Fixed slots for common GPU types (Float32 priority, includes Float16, no Float64/ComplexF64) +- `others`: IdDict fallback for rare types +- `device_key`: The Metal device this pool belongs to +- Borrow tracking fields (required by macro-injected field access at all R levels) +""" +mutable struct MetalAdaptiveArrayPool{R, S} <: AbstractArrayPool + # Fixed Slots (Metal-optimized order, no Float64/ComplexF64) + float32::MetalTypedPool{Float32, S} + float16::MetalTypedPool{Float16, S} + int32::MetalTypedPool{Int32, S} + int64::MetalTypedPool{Int64, S} + complexf32::MetalTypedPool{ComplexF32, S} + bool::MetalTypedPool{Bool, S} + + # Fallback for rare types + others::IdDict{DataType, Any} + + # State management (same as CPU) + _current_depth::Int + _touched_type_masks::Vector{UInt16} # Per-depth: which fixed slots were touched + mode flags + _touched_has_others::Vector{Bool} # Per-depth: any non-fixed-slot type touched? + + # Device tracking (safety) + device_key::Any + + # Borrow tracking (required: macro injects pool._pending_callsite = "..." as raw AST) + _pending_callsite::String + _pending_return_site::String + _borrow_log::Union{Nothing, IdDict{Any, String}} +end + +function MetalAdaptiveArrayPool{R, S}() where {R, S} + return MetalAdaptiveArrayPool{R, S}( + MetalTypedPool{Float32, S}(), + MetalTypedPool{Float16, S}(), + MetalTypedPool{Int32, S}(), + MetalTypedPool{Int64, S}(), + MetalTypedPool{ComplexF32, S}(), + MetalTypedPool{Bool, S}(), + IdDict{DataType, Any}(), + 1, # _current_depth (1 = global scope) + [UInt16(0)], # _touched_type_masks: sentinel (no bits set) + [false], # _touched_has_others: sentinel (no others) + Metal.device(), + "", # _pending_callsite + "", # _pending_return_site + nothing, # _borrow_log: lazily created when R >= 1 + ) +end + +"""Create pool with the default `RUNTIME_CHECK` level and PrivateStorage.""" +MetalAdaptiveArrayPool() = MetalAdaptiveArrayPool{RUNTIME_CHECK, METAL_STORAGE}() + +# ============================================================================== +# Runtime Check Dispatch +# ============================================================================== + +""" + _runtime_check(pool::MetalAdaptiveArrayPool) -> Bool + +Return compile-time constant indicating whether runtime safety checks are enabled. +`R >= 1` enables checks; `R == 0` disables (dead-code-eliminated). +""" +@inline AdaptiveArrayPools._runtime_check(::MetalAdaptiveArrayPool{0}) = false +@inline AdaptiveArrayPools._runtime_check(::MetalAdaptiveArrayPool) = true # R >= 1 + +""" + _make_metal_pool(level) -> MetalAdaptiveArrayPool + +Function barrier: converts runtime check level to concrete `MetalAdaptiveArrayPool{R,S}`. +Accepts `Bool` (`true`->1, `false`->0) or `Int` (used directly as R). +""" +_make_metal_pool(runtime_check::Bool) = _make_metal_pool(Int(runtime_check)) +@noinline function _make_metal_pool(R::Int) + R == 0 && return MetalAdaptiveArrayPool{0, METAL_STORAGE}() + return MetalAdaptiveArrayPool{1, METAL_STORAGE}() +end + +"""Human-readable runtime check label.""" +function _metal_check_label(r::Int) + r <= 0 && return "off" + return "on" +end diff --git a/ext/AdaptiveArrayPoolsMetalExt/utils.jl b/ext/AdaptiveArrayPoolsMetalExt/utils.jl new file mode 100644 index 00000000..a4f8f442 --- /dev/null +++ b/ext/AdaptiveArrayPoolsMetalExt/utils.jl @@ -0,0 +1,158 @@ +# ============================================================================== +# Metal Pool Display & Statistics +# ============================================================================== + +using AdaptiveArrayPools: pool_stats, foreach_fixed_slot + +# ============================================================================== +# pool_stats for MetalTypedPool +# ============================================================================== + +""" + pool_stats(tp::MetalTypedPool{T,S}; io::IO=stdout, indent::Int=0, name::String="") + +Print statistics for a Metal typed pool. +""" +function AdaptiveArrayPools.pool_stats(tp::MetalTypedPool{T, S}; io::IO = stdout, indent::Int = 0, name::String = "") where {T, S} + prefix = " "^indent + type_name = isempty(name) ? string(T) : name + + n_arrays = length(tp.vectors) + if n_arrays == 0 + printstyled(io, prefix, type_name, color = :cyan) + printstyled(io, " (empty)\n", color = :dark_gray) + return + end + + # Calculate total elements and memory + total_elements = sum(length(v) for v in tp.vectors) + gpu_bytes = sum(sizeof(v) for v in tp.vectors) # sizeof(MtlArray) returns GPU data size + cpu_bytes = sum(Base.summarysize(v) for v in tp.vectors) + gpu_str = Base.format_bytes(gpu_bytes) + cpu_str = Base.format_bytes(cpu_bytes) + + # Header + printstyled(io, prefix, type_name, color = :cyan) + printstyled(io, " [Metal]", color = :magenta) + println(io) + + # Stats + printstyled(io, prefix, " slots: ", color = :dark_gray) + printstyled(io, n_arrays, color = :blue) + printstyled(io, " (active: ", color = :dark_gray) + printstyled(io, tp.n_active, color = :blue) + printstyled(io, ")\n", color = :dark_gray) + + printstyled(io, prefix, " elements: ", color = :dark_gray) + printstyled(io, total_elements, color = :blue) + printstyled(io, " ($gpu_str GPU + $cpu_str CPU)\n", color = :dark_gray) + return nothing +end + +# ============================================================================== +# pool_stats for MetalAdaptiveArrayPool +# ============================================================================== + +""" + pool_stats(pool::MetalAdaptiveArrayPool; io::IO=stdout) + +Print statistics for a Metal adaptive array pool. +""" +function AdaptiveArrayPools.pool_stats(pool::MetalAdaptiveArrayPool{R, S}; io::IO = stdout) where {R, S} + # Header with device info and runtime check level + printstyled(io, "MetalAdaptiveArrayPool", bold = true, color = :magenta) + printstyled(io, "{$R,$S}", color = :yellow) + dev_name = try + string(pool.device_key.name) + catch + string(nameof(typeof(pool.device_key))) + end + printstyled(io, " (device ", color = :dark_gray) + printstyled(io, dev_name, color = :blue) + printstyled(io, ", check=", color = :dark_gray) + printstyled(io, _metal_check_label(R), color = :yellow) + printstyled(io, ")\n", color = :dark_gray) + + has_content = false + + # Fixed slots + foreach_fixed_slot(pool) do tp + if !isempty(tp.vectors) + has_content = true + T = typeof(tp).parameters[1] + pool_stats(tp; io, indent = 2, name = "$T (fixed)") + end + end + + # Fallback types + for (T, tp) in pool.others + has_content = true + pool_stats(tp; io, indent = 2, name = "$T (fallback)") + end + + if !has_content + printstyled(io, " (empty)\n", color = :dark_gray) + end + return nothing +end + +# Backend dispatch +function AdaptiveArrayPools.pool_stats(::Val{:metal}; io::IO = stdout) + pools = get_task_local_metal_pools() + for pool in values(pools) + pool_stats(pool; io) + end + return nothing +end + +# ============================================================================== +# Base.show for MetalTypedPool +# ============================================================================== + +# Compact one-line show +function Base.show(io::IO, tp::MetalTypedPool{T, S}) where {T, S} + n_vectors = length(tp.vectors) + return if n_vectors == 0 + print(io, "MetalTypedPool{$T,$S}(empty)") + else + total = sum(length(v) for v in tp.vectors) + print(io, "MetalTypedPool{$T,$S}(slots=$n_vectors, active=$(tp.n_active), elements=$total)") + end +end + +# Multi-line show +function Base.show(io::IO, ::MIME"text/plain", tp::MetalTypedPool{T, S}) where {T, S} + return pool_stats(tp; io, name = "MetalTypedPool{$T,$S}") +end + +# ============================================================================== +# Base.show for MetalAdaptiveArrayPool +# ============================================================================== + +# Compact one-line show +function Base.show(io::IO, pool::MetalAdaptiveArrayPool{R, S}) where {R, S} + n_types = Ref(0) + total_vectors = Ref(0) + total_active = Ref(0) + + foreach_fixed_slot(pool) do tp + if !isempty(tp.vectors) + n_types[] += 1 + end + total_vectors[] += length(tp.vectors) + total_active[] += tp.n_active + end + + n_types[] += length(pool.others) + for tp in values(pool.others) + total_vectors[] += length(tp.vectors) + total_active[] += tp.n_active + end + + return print(io, "MetalAdaptiveArrayPool{$R,$S}(check=$(_metal_check_label(R)), types=$(n_types[]), slots=$(total_vectors[]), active=$(total_active[]))") +end + +# Multi-line show +function Base.show(io::IO, ::MIME"text/plain", pool::MetalAdaptiveArrayPool) + return pool_stats(pool; io) +end diff --git a/src/AdaptiveArrayPools.jl b/src/AdaptiveArrayPools.jl index 82f18828..32ebcd5c 100644 --- a/src/AdaptiveArrayPools.jl +++ b/src/AdaptiveArrayPools.jl @@ -12,6 +12,7 @@ export STATIC_POOLING, MAYBE_POOLING, RUNTIME_CHECK export PoolEscapeError, EscapePoint export checkpoint!, rewind!, reset! export get_task_local_cuda_pool, get_task_local_cuda_pools # CUDA (stubs, overridden by extension) +export get_task_local_metal_pool, get_task_local_metal_pools # Metal (stubs, overridden by extension) # Extension API (for GPU backends) export AbstractTypedPool, AbstractArrayPool # For subtyping diff --git a/src/task_local_pool.jl b/src/task_local_pool.jl index 6dc406e7..3f4f14c9 100644 --- a/src/task_local_pool.jl +++ b/src/task_local_pool.jl @@ -124,3 +124,27 @@ Returns the dictionary of all CUDA pools for the current task (one per device). Requires CUDA.jl to be loaded. Throws an error if CUDA extension is not available. """ function get_task_local_cuda_pools end + +# ============================================================================== +# Metal Pool Stubs (overridden by extension when Metal is loaded) +# ============================================================================== + +""" + get_task_local_metal_pool() -> MetalAdaptiveArrayPool + +Retrieves (or creates) the Metal pool for the current Task and current Metal device. + +Requires Metal.jl to be loaded. Throws an error if Metal extension is not available. + +See also: [`get_task_local_pool`](@ref) for CPU pools. +""" +function get_task_local_metal_pool end + +""" + get_task_local_metal_pools() -> Dict{UInt64, MetalAdaptiveArrayPool} + +Returns the dictionary of all Metal pools for the current task (one per device). + +Requires Metal.jl to be loaded. Throws an error if Metal extension is not available. +""" +function get_task_local_metal_pools end diff --git a/src/utils.jl b/src/utils.jl index a64de665..0110bc2f 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -113,12 +113,14 @@ end """ function pool_stats(; io::IO = stdout) pool_stats(:cpu; io) - # Show CUDA pools if extension is loaded and pools exist - try - pool_stats(Val(:cuda); io) - catch e - e isa MethodError || rethrow() - # CUDA extension not loaded - silently skip + # Show GPU pools if extensions are loaded + for backend in (:cuda, :metal) + try + pool_stats(Val(backend); io) + catch e + e isa MethodError || rethrow() + # Extension not loaded - silently skip + end end return nothing end diff --git a/test/Project.toml b/test/Project.toml index 6c2b8828..fcdde041 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -1,4 +1,5 @@ [deps] Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" +Metal = "dde4c033-4e86-420c-a63e-0dd931031962" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" diff --git a/test/metal/runtests.jl b/test/metal/runtests.jl new file mode 100644 index 00000000..fb034dbc --- /dev/null +++ b/test/metal/runtests.jl @@ -0,0 +1,52 @@ +# Metal Extension Test Suite +# ========================= +# Entry point for all Metal-related tests. +# +# Usage: +# - From main test suite: automatically included when Metal is available +# - Direct execution: julia --project test/metal/runtests.jl +# - Skip Metal tests: TEST_METAL=false julia --project -e 'using Pkg; Pkg.test()' + +using Test + +# GPU pooling requires Julia 1.11+ +@static if VERSION < v"1.11-" + @info "Metal extension tests skipped (requires Julia 1.11+)" + @testset "Metal (skipped — Julia < 1.11)" begin end + return +end + +# Check Metal availability (requires macOS + Apple Silicon) +const METAL_AVAILABLE = try + Sys.isapple() || error("Not macOS") + using Metal + Metal.functional() +catch + false +end + +if !METAL_AVAILABLE + @info "Metal not available or not functional, skipping Metal tests" +else + @info "Running Metal extension tests on device: $(Metal.device())" + + # Load dependencies + using AdaptiveArrayPools + using AdaptiveArrayPools: checkpoint!, rewind!, get_typed_pool!, get_view!, foreach_fixed_slot + + # Extension types (only needed for type checks in tests) + const ext = Base.get_extension(AdaptiveArrayPools, :AdaptiveArrayPoolsMetalExt) + const MetalTypedPool = ext.MetalTypedPool + const MetalAdaptiveArrayPool = ext.MetalAdaptiveArrayPool + const METAL_FIXED_SLOT_FIELDS = ext.METAL_FIXED_SLOT_FIELDS + + # Include all Metal test files + include("test_extension.jl") + include("test_allocation.jl") + include("test_display.jl") + include("test_convenience.jl") + include("test_disabled_pool.jl") + include("test_metal_safety.jl") + include("test_reshape.jl") + include("test_task_local_pool.jl") +end diff --git a/test/metal/test_allocation.jl b/test/metal/test_allocation.jl new file mode 100644 index 00000000..8db866ef --- /dev/null +++ b/test/metal/test_allocation.jl @@ -0,0 +1,261 @@ +# Metal Allocation Tests +# Verifies zero-allocation pooling behavior and GPU memory reuse + +@testset "GPU Allocation" begin + + @testset "Memory reuse (same size)" begin + pool = get_task_local_metal_pool() + reset!(pool) + + # First acquire - populates pool + @with_pool :metal p begin + v = acquire!(p, Float32, 100) + v .= 1.0f0 + end + + # Second acquire (same size) - should reuse GPU memory + alloc = Metal.@allocated begin + @with_pool :metal p begin + v = acquire!(p, Float32, 100) + v .= 2.0f0 + end + end + + # GPU allocation should be minimal on cache hit (kernel launch overhead only) + @test alloc < 200 + end + + @testset "Memory reuse (multiple arrays)" begin + pool = get_task_local_metal_pool() + reset!(pool) + + # Warmup with 3 arrays + @with_pool :metal p begin + acquire!(p, Float32, 100) + acquire!(p, Float32, 200) + acquire!(p, Float32, 300) + end + + # Second pass should reuse all GPU memory + alloc = Metal.@allocated begin + @with_pool :metal p begin + v1 = acquire!(p, Float32, 100) + v2 = acquire!(p, Float32, 200) + v3 = acquire!(p, Float32, 300) + v1 .= 1.0f0; v2 .= 2.0f0; v3 .= 3.0f0 + end + end + + # 3 kernel launches × ~56 bytes each ≈ ~168 bytes overhead + @test alloc < 200 + end + + @testset "Memory reuse (N-D arrays)" begin + pool = get_task_local_metal_pool() + reset!(pool) + + # Warmup with 2D array + @with_pool :metal p begin + A = acquire!(p, Float32, 10, 10) + A .= 1.0f0 + end + + # Reuse check — GPU allocation only + alloc = Metal.@allocated begin + @with_pool :metal p begin + A = acquire!(p, Float32, 10, 10) + A .= 2.0f0 + end + end + + @test alloc < 200 + end + + @testset "Memory reuse (3D arrays)" begin + pool = get_task_local_metal_pool() + reset!(pool) + + # Warmup with 3D array + @with_pool :metal p begin + T = acquire!(p, Float32, 5, 5, 4) + T .= 1.0f0 + end + + alloc = Metal.@allocated begin + @with_pool :metal p begin + T = acquire!(p, Float32, 5, 5, 4) + T .= 2.0f0 + end + end + + @test alloc < 200 + end + + @testset "Pointer reuse verification" begin + pool = get_task_local_metal_pool() + reset!(pool) + + ptr1 = Ref{UInt}(0) + ptr2 = Ref{UInt}(0) + + @with_pool :metal p begin + v = acquire!(p, Float32, 1000) + ptr1[] = UInt(pointer(v).offset) + end + + @with_pool :metal p begin + v = acquire!(p, Float32, 1000) + ptr2[] = UInt(pointer(v).offset) + end + + # Same GPU memory offset should be reused + @test ptr1[] == ptr2[] + end + +end + +@testset "_resize_to_fit! Metal memory preservation" begin + _resize_to_fit! = ext._resize_to_fit! + + @testset "Shrink preserves Metal pointer" begin + v = MtlArray(zeros(Float32, 1000)) + ptr = UInt(pointer(v).offset) + _resize_to_fit!(v, 100) + @test length(v) == 100 + @test UInt(pointer(v).offset) == ptr + end + + @testset "Grow-back within capacity: no realloc" begin + v = MtlArray(zeros(Float32, 1000)) + ptr = UInt(pointer(v).offset) + _resize_to_fit!(v, 100) + @test length(v) == 100 + @test UInt(pointer(v).offset) == ptr + _resize_to_fit!(v, 1000) + @test length(v) == 1000 + @test UInt(pointer(v).offset) == ptr + end + + @testset "Shrink to 0, grow back preserves pointer" begin + v = MtlArray(zeros(Float32, 500)) + ptr = UInt(pointer(v).offset) + _resize_to_fit!(v, 0) + @test length(v) == 0 + _resize_to_fit!(v, 500) + @test length(v) == 500 + @test UInt(pointer(v).offset) == ptr + end + + @testset "Grow within capacity after invalidation: no realloc" begin + v = MtlArray(zeros(Float32, 1000)) + ptr = UInt(pointer(v).offset) + _resize_to_fit!(v, 0) + @test length(v) == 0 + _resize_to_fit!(v, 200) + @test length(v) == 200 + @test UInt(pointer(v).offset) == ptr + end + + @testset "No-op when n == length" begin + v = MtlArray(zeros(Float32, 200)) + ptr = UInt(pointer(v).offset) + _resize_to_fit!(v, 200) + @test length(v) == 200 + @test UInt(pointer(v).offset) == ptr + end + + @testset "Grow beyond capacity delegates to resize!" begin + v = MtlArray(zeros(Float32, 100)) + _resize_to_fit!(v, 10_000) + @test length(v) == 10_000 + end +end + +@testset "CPU Allocation (MtlArray wrapper)" begin + + @testset "acquire! N-D has low CPU allocation (cache hit)" begin + pool = get_task_local_metal_pool() + reset!(pool) + + function _test_metal_nd_alloc!() + @with_pool :metal p begin + acquire!(p, Float32, 10, 10) + end + end + + # Warmup (JIT + cache) + _test_metal_nd_alloc!() + _test_metal_nd_alloc!() + + cpu_alloc = @allocated _test_metal_nd_alloc!() + @test cpu_alloc < 100 + end + + @testset "acquire! 1D has low CPU allocation" begin + pool = get_task_local_metal_pool() + reset!(pool) + + function _test_metal_1d_alloc!() + @with_pool :metal p begin + acquire!(p, Float32, 100) + end + end + + # Warmup (JIT + cache) + _test_metal_1d_alloc!() + _test_metal_1d_alloc!() + + cpu_alloc = @allocated _test_metal_1d_alloc!() + @test cpu_alloc < 200 + end + +end + +@testset "Mixed Type Allocation" begin + + @testset "Multiple types maintain separate pools" begin + pool = get_task_local_metal_pool() + reset!(pool) + + # Warmup all types + @with_pool :metal p begin + acquire!(p, Float32, 100) + acquire!(p, Int32, 100) + acquire!(p, Float16, 100) + end + + # Reuse all types — check GPU allocation only + alloc = Metal.@allocated begin + @with_pool :metal p begin + v32 = acquire!(p, Float32, 100) + vi32 = acquire!(p, Int32, 100) + v16 = acquire!(p, Float16, 100) + v32 .= 1.0f0; vi32 .= 3; v16 .= Float16(4.0) + end + end + + # 3 kernel launches overhead + @test alloc < 200 + end + + @testset "Float16 support (GPU ML type)" begin + pool = get_task_local_metal_pool() + reset!(pool) + + # Warmup + @with_pool :metal p begin + v = acquire!(p, Float16, 100) + v .= Float16(1.0) + end + + alloc = Metal.@allocated begin + @with_pool :metal p begin + v = acquire!(p, Float16, 100) + v .= Float16(2.0) + end + end + + @test alloc < 200 + end + +end diff --git a/test/metal/test_convenience.jl b/test/metal/test_convenience.jl new file mode 100644 index 00000000..2ac1b32a --- /dev/null +++ b/test/metal/test_convenience.jl @@ -0,0 +1,124 @@ +@testset "Metal Convenience Functions" begin + pool = get_task_local_metal_pool() + checkpoint!(pool) + + @testset "zeros! default type is Float32" begin + v = zeros!(pool, 10) + @test v isa MtlArray{Float32} + @test length(v) == 10 + @test all(v .== 0.0f0) + + m = zeros!(pool, 3, 4) + @test m isa MtlArray{Float32, 2} + @test size(m) == (3, 4) + @test all(m .== 0.0f0) + + # Tuple form + dims = (2, 3) + t = zeros!(pool, dims) + @test t isa MtlArray{Float32, 2} + @test size(t) == dims + end + + @testset "zeros! explicit type" begin + v16 = zeros!(pool, Float16, 5) + @test v16 isa MtlArray{Float16} + + vi = zeros!(pool, Int32, 8) + @test vi isa MtlArray{Int32} + @test all(vi .== 0) + end + + @testset "ones! default type is Float32" begin + v = ones!(pool, 10) + @test v isa MtlArray{Float32} + @test length(v) == 10 + @test all(v .== 1.0f0) + + m = ones!(pool, 3, 4) + @test m isa MtlArray{Float32, 2} + @test size(m) == (3, 4) + @test all(m .== 1.0f0) + + # Tuple form + dims = (2, 3) + t = ones!(pool, dims) + @test t isa MtlArray{Float32, 2} + @test size(t) == dims + end + + @testset "ones! explicit type" begin + vi = ones!(pool, Int32, 8) + @test vi isa MtlArray{Int32} + @test all(vi .== 1) + end + + @testset "similar!" begin + # Float32 template + template32 = MtlArray(rand(Float32, 5, 5)) + v = similar!(pool, template32) + @test v isa MtlArray{Float32, 2} + @test size(v) == (5, 5) + + # Different type + v_int = similar!(pool, template32, Int32) + @test v_int isa MtlArray{Int32, 2} + @test size(v_int) == (5, 5) + + # Different dims + v_dims = similar!(pool, template32, 10) + @test v_dims isa MtlArray{Float32, 1} + @test length(v_dims) == 10 + + # Different type and dims + v_both = similar!(pool, template32, Int32, 2, 3) + @test v_both isa MtlArray{Int32, 2} + @test size(v_both) == (2, 3) + end + + @testset "zeros! returns Array (not view)" begin + v = zeros!(pool, 10) + @test v isa MtlArray{Float32, 1} + @test !(v isa SubArray) + @test all(v .== 0.0f0) + + m = zeros!(pool, 3, 4) + @test m isa MtlArray{Float32, 2} + @test !(m isa SubArray) + @test all(m .== 0.0f0) + end + + @testset "ones! returns Array (not view)" begin + v = ones!(pool, 10) + @test v isa MtlArray{Float32, 1} + @test !(v isa SubArray) + @test all(v .== 1.0f0) + + m = ones!(pool, 3, 4) + @test m isa MtlArray{Float32, 2} + @test !(m isa SubArray) + @test all(m .== 1.0f0) + end + + @testset "similar! returns Array (not view)" begin + template32 = MtlArray(rand(Float32, 5, 5)) + v = similar!(pool, template32) + @test v isa MtlArray{Float32, 2} + @test !(v isa SubArray) + @test size(v) == (5, 5) + + v_int = similar!(pool, template32, Int32) + @test v_int isa MtlArray{Int32, 2} + @test !(v_int isa SubArray) + + v_dims = similar!(pool, template32, 10) + @test v_dims isa MtlArray{Float32, 1} + @test !(v_dims isa SubArray) + + v_both = similar!(pool, template32, Int32, 2, 3) + @test v_both isa MtlArray{Int32, 2} + @test !(v_both isa SubArray) + end + + rewind!(pool) +end diff --git a/test/metal/test_disabled_pool.jl b/test/metal/test_disabled_pool.jl new file mode 100644 index 00000000..8bd72644 --- /dev/null +++ b/test/metal/test_disabled_pool.jl @@ -0,0 +1,192 @@ +# Tests for DisabledPool{:metal} dispatch methods + +using AdaptiveArrayPools: DisabledPool, DISABLED_CPU, pooling_enabled, default_eltype + +@testset "DisabledPool{:metal}" begin + DISABLED_METAL = ext.DISABLED_METAL + + @testset "DISABLED_METAL singleton" begin + @test DISABLED_METAL isa DisabledPool{:metal} + @test !pooling_enabled(DISABLED_METAL) + end + + @testset "default_eltype" begin + @test default_eltype(DISABLED_METAL) === Float32 + end + + @testset "zeros!" begin + v1 = zeros!(DISABLED_METAL, Float32, 10) + @test v1 isa MtlVector{Float32} + @test length(v1) == 10 + @test all(v1 .== 0.0f0) + + v2 = zeros!(DISABLED_METAL, Float32, 5, 5) + @test v2 isa MtlArray{Float32, 2} + @test size(v2) == (5, 5) + @test all(v2 .== 0.0f0) + + # Without type (default Float32) + v3 = zeros!(DISABLED_METAL, 8) + @test v3 isa MtlVector{Float32} + @test length(v3) == 8 + + v4 = zeros!(DISABLED_METAL, 3, 4) + @test v4 isa MtlArray{Float32, 2} + @test size(v4) == (3, 4) + + # Tuple dims + v5 = zeros!(DISABLED_METAL, Float32, (2, 3, 4)) + @test v5 isa MtlArray{Float32, 3} + @test size(v5) == (2, 3, 4) + + v6 = zeros!(DISABLED_METAL, (5, 6)) + @test v6 isa MtlArray{Float32, 2} + @test size(v6) == (5, 6) + end + + @testset "ones!" begin + v1 = ones!(DISABLED_METAL, Float32, 10) + @test v1 isa MtlVector{Float32} + @test length(v1) == 10 + @test all(v1 .== 1.0f0) + + v2 = ones!(DISABLED_METAL, Float32, 5, 5) + @test v2 isa MtlArray{Float32, 2} + @test size(v2) == (5, 5) + @test all(v2 .== 1.0f0) + + # Without type (default Float32) + v3 = ones!(DISABLED_METAL, 8) + @test v3 isa MtlVector{Float32} + @test all(v3 .== 1.0f0) + + v4 = ones!(DISABLED_METAL, 3, 4) + @test v4 isa MtlArray{Float32, 2} + @test size(v4) == (3, 4) + + # Tuple dims + v5 = ones!(DISABLED_METAL, Float32, (2, 3)) + @test v5 isa MtlArray{Float32, 2} + @test size(v5) == (2, 3) + + v6 = ones!(DISABLED_METAL, (4, 5)) + @test v6 isa MtlArray{Float32, 2} + @test size(v6) == (4, 5) + end + + @testset "similar! with MtlArray input" begin + template = Metal.zeros(Float32, 10) + + v1 = similar!(DISABLED_METAL, template) + @test v1 isa MtlVector{Float32} + @test length(v1) == 10 + + v2 = similar!(DISABLED_METAL, template, Int32) + @test v2 isa MtlVector{Int32} + @test length(v2) == 10 + + v3 = similar!(DISABLED_METAL, template, 5, 5) + @test v3 isa MtlArray{Float32, 2} + @test size(v3) == (5, 5) + + v4 = similar!(DISABLED_METAL, template, Int32, 3, 4) + @test v4 isa MtlArray{Int32, 2} + @test size(v4) == (3, 4) + end + + @testset "similar! with AbstractArray input (CPU->GPU)" begin + cpu_template = zeros(Float32, 8) + + v1 = similar!(DISABLED_METAL, cpu_template) + @test v1 isa MtlArray{Float32} + @test length(v1) == 8 + + v2 = similar!(DISABLED_METAL, cpu_template, Int32) + @test v2 isa MtlArray{Int32} + @test length(v2) == 8 + + v3 = similar!(DISABLED_METAL, cpu_template, 4, 4) + @test v3 isa MtlArray{Float32, 2} + @test size(v3) == (4, 4) + + v4 = similar!(DISABLED_METAL, cpu_template, Int32, 2, 3) + @test v4 isa MtlArray{Int32, 2} + @test size(v4) == (2, 3) + end + + @testset "reshape!" begin + a = acquire!(DISABLED_METAL, Float32, 12) + r1 = reshape!(DISABLED_METAL, a, 3, 4) + @test r1 isa MtlArray{Float32, 2} + @test size(r1) == (3, 4) + + r2 = reshape!(DISABLED_METAL, a, (4, 3)) + @test r2 isa MtlArray{Float32, 2} + @test size(r2) == (4, 3) + end + + @testset "Sub-function passing" begin + function _metal_helper(pool, n) + return zeros!(pool, Float32, n) + end + + function _metal_helper_typed(pool::AbstractArrayPool, n) + return acquire!(pool, Float32, n) + end + + function _metal_outer(pool, n) + return _metal_inner(pool, n) + end + function _metal_inner(pool, n) + return ones!(pool, Float32, n) + end + + v1 = _metal_helper(DISABLED_METAL, 5) + @test v1 isa MtlVector{Float32} + @test all(v1 .== 0.0f0) + + v2 = _metal_helper_typed(DISABLED_METAL, 5) + @test v2 isa MtlVector{Float32} + + v3 = _metal_outer(DISABLED_METAL, 3) + @test v3 isa MtlVector{Float32} + @test all(v3 .== 1.0f0) + end + + @testset "State management no-ops" begin + @test checkpoint!(DISABLED_METAL) === nothing + @test rewind!(DISABLED_METAL) === nothing + @test reset!(DISABLED_METAL) === nothing + @test empty!(DISABLED_METAL) === nothing + end + + @testset "acquire!" begin + # Type + single dim + v1 = acquire!(DISABLED_METAL, Float32, 10) + @test v1 isa MtlVector{Float32} + @test length(v1) == 10 + + # Type + vararg dims + v2 = acquire!(DISABLED_METAL, Int32, 5, 5) + @test v2 isa MtlArray{Int32, 2} + @test size(v2) == (5, 5) + + # Type + tuple dims + v3 = acquire!(DISABLED_METAL, Float32, (3, 4, 5)) + @test v3 isa MtlArray{Float32, 3} + @test size(v3) == (3, 4, 5) + + # MtlArray template + template = Metal.zeros(Float32, 8) + v4 = acquire!(DISABLED_METAL, template) + @test v4 isa MtlVector{Float32} + @test length(v4) == 8 + + # AbstractArray template (CPU->GPU) + cpu_template = zeros(Float32, 6) + v5 = acquire!(DISABLED_METAL, cpu_template) + @test v5 isa MtlArray{Float32} + @test length(v5) == 6 + end + +end diff --git a/test/metal/test_display.jl b/test/metal/test_display.jl new file mode 100644 index 00000000..51a4b2b1 --- /dev/null +++ b/test/metal/test_display.jl @@ -0,0 +1,204 @@ +# Metal Display Tests +# Tests for pool_stats and Base.show methods for MetalTypedPool and MetalAdaptiveArrayPool + +# Helper macro to capture stdout +macro capture_out(expr) + return quote + local old_stdout = stdout + local rd, wr = redirect_stdout() + try + $(esc(expr)) + redirect_stdout(old_stdout) + close(wr) + read(rd, String) + catch e + redirect_stdout(old_stdout) + close(wr) + rethrow(e) + end + end +end + +@testset "Metal Display" begin + + @testset "pool_stats for MetalAdaptiveArrayPool" begin + pool = get_task_local_metal_pool() + empty!(pool) + + # Empty pool stats + output = @capture_out pool_stats(pool) + @test occursin("MetalAdaptiveArrayPool", output) + @test occursin("device", output) + @test occursin("empty", output) + + # Add some arrays + checkpoint!(pool) + acquire!(pool, Float32, 100) + acquire!(pool, Int32, 50) + acquire!(pool, Float16, 25) + + output = @capture_out pool_stats(pool) + @test occursin("Float32 (fixed)", output) + @test occursin("Int32 (fixed)", output) + @test occursin("Float16 (fixed)", output) + @test occursin("Metal", output) + @test occursin("slots:", output) + @test occursin("active:", output) + + rewind!(pool) + end + + @testset "pool_stats(:metal) dispatch" begin + pool = get_task_local_metal_pool() + reset!(pool) + + checkpoint!(pool) + acquire!(pool, Float32, 100) + + output = @capture_out pool_stats(:metal) + @test occursin("MetalAdaptiveArrayPool", output) + @test occursin("Float32", output) + + rewind!(pool) + end + + @testset "pool_stats output format" begin + pool = get_task_local_metal_pool() + reset!(pool) + + checkpoint!(pool) + acquire!(pool, Float32, 100) + + output = @capture_out pool_stats(pool) + + @test occursin("slots:", output) + @test occursin("elements:", output) + @test occursin("bytes", output) + + rewind!(pool) + end + + @testset "pool_stats for MetalTypedPool" begin + pool = get_task_local_metal_pool() + empty!(pool) + + # Empty MetalTypedPool + output = @capture_out pool_stats(pool.float32) + @test occursin("Float32", output) + @test occursin("empty", output) + + # Non-empty MetalTypedPool + checkpoint!(pool) + acquire!(pool, Float32, 100) + acquire!(pool, Float32, 200) + + output = @capture_out pool_stats(pool.float32) + @test occursin("Float32", output) + @test occursin("Metal", output) + @test occursin("slots:", output) + @test occursin("elements:", output) + + rewind!(pool) + end + + @testset "pool_stats with fallback types" begin + pool = get_task_local_metal_pool() + reset!(pool) + + checkpoint!(pool) + acquire!(pool, UInt8, 200) + + output = @capture_out pool_stats(pool) + @test occursin("UInt8 (fallback)", output) + @test occursin("elements: 200", output) + + rewind!(pool) + end + + @testset "Base.show for MetalTypedPool" begin + pool = get_task_local_metal_pool() + empty!(pool) + + # Empty - compact show + output = sprint(show, pool.float32) + @test occursin("MetalTypedPool", output) + @test occursin("empty", output) + + # Non-empty - compact show + checkpoint!(pool) + acquire!(pool, Float32, 100) + acquire!(pool, Float32, 50) + + output = sprint(show, pool.float32) + @test occursin("MetalTypedPool", output) + @test occursin("slots=2", output) + @test occursin("active=2", output) + @test occursin("elements=150", output) + + # Multi-line show (MIME"text/plain") + output = sprint(show, MIME("text/plain"), pool.float32) + @test occursin("MetalTypedPool", output) + @test occursin("slots:", output) + @test occursin("Metal", output) + + rewind!(pool) + end + + @testset "Base.show for MetalAdaptiveArrayPool" begin + pool = get_task_local_metal_pool() + empty!(pool) + + # Empty pool - compact show + output = sprint(show, pool) + @test occursin("MetalAdaptiveArrayPool", output) + @test occursin("types=0", output) + @test occursin("slots=0", output) + + # Non-empty pool - compact show + checkpoint!(pool) + acquire!(pool, Float32, 100) + acquire!(pool, Int32, 50) + acquire!(pool, UInt8, 25) # fallback + + output = sprint(show, pool) + @test occursin("MetalAdaptiveArrayPool", output) + @test occursin("types=3", output) + @test occursin("slots=3", output) + @test occursin("active=3", output) + + # Multi-line show (MIME"text/plain") + output = sprint(show, MIME("text/plain"), pool) + @test occursin("MetalAdaptiveArrayPool", output) + @test occursin("Float32 (fixed)", output) + @test occursin("Int32 (fixed)", output) + @test occursin("UInt8 (fallback)", output) + + rewind!(pool) + end + + @testset "pool_stats returns nothing" begin + pool = get_task_local_metal_pool() + reset!(pool) + + result = pool_stats(pool; io = devnull) + @test result === nothing + + result = pool_stats(:metal; io = devnull) + @test result === nothing + end + + @testset "Float16 display (GPU ML type)" begin + pool = get_task_local_metal_pool() + reset!(pool) + + checkpoint!(pool) + acquire!(pool, Float16, 100) + + output = @capture_out pool_stats(pool) + @test occursin("Float16 (fixed)", output) + @test occursin("Metal", output) + + rewind!(pool) + end + +end diff --git a/test/metal/test_extension.jl b/test/metal/test_extension.jl new file mode 100644 index 00000000..633835f2 --- /dev/null +++ b/test/metal/test_extension.jl @@ -0,0 +1,504 @@ +# Metal Extension Core Tests +# Tests for MetalTypedPool, MetalAdaptiveArrayPool, state management, and macros + +@testset "Extension Types" begin + @testset "MetalTypedPool structure" begin + tp_fields = fieldnames(MetalTypedPool) + @test :vectors in tp_fields + @test :n_active in tp_fields + @test :arr_wrappers in tp_fields + @test :_checkpoint_n_active in tp_fields + @test :_checkpoint_depths in tp_fields + end + + @testset "MetalAdaptiveArrayPool structure" begin + pool_fields = fieldnames(MetalAdaptiveArrayPool) + @test :float16 in pool_fields + @test :device_key in pool_fields + @test :others in pool_fields + # Metal does NOT have float64/complexf64 + @test !(:float64 in pool_fields) + @test !(:complexf64 in pool_fields) + end + + @testset "Type hierarchy" begin + @test MetalTypedPool <: AbstractTypedPool + @test MetalAdaptiveArrayPool <: AbstractArrayPool + end + + @testset "Instance creation" begin + tp = MetalTypedPool{Float32, Metal.PrivateStorage}() + @test tp.n_active == 0 + @test length(tp.vectors) == 0 + + pool = MetalAdaptiveArrayPool() + @test pool.device_key == Metal.device() + @test pool._current_depth == 1 + end + + @testset "METAL_FIXED_SLOT_FIELDS" begin + @test :float16 in METAL_FIXED_SLOT_FIELDS + @test first(METAL_FIXED_SLOT_FIELDS) == :float32 + @test length(METAL_FIXED_SLOT_FIELDS) == 6 + # No Float64/ComplexF64 + @test !(:float64 in METAL_FIXED_SLOT_FIELDS) + @test !(:complexf64 in METAL_FIXED_SLOT_FIELDS) + end +end + +@testset "Dispatch Methods" begin + @testset "allocate_vector" begin + tp = MetalTypedPool{Float32, Metal.PrivateStorage}() + vec = AdaptiveArrayPools.allocate_vector(tp, 100) + @test vec isa MtlVector{Float32} + @test length(vec) == 100 + end + + @testset "get_typed_pool! fixed slots" begin + pool = MetalAdaptiveArrayPool() + test_types = [Float32, Float16, Int32, Int64, ComplexF32, Bool] + for T in test_types + tp = get_typed_pool!(pool, T) + @test tp isa MetalTypedPool{T} + end + end + + @testset "get_typed_pool! rejects Float64/ComplexF64" begin + pool = MetalAdaptiveArrayPool() + @test_throws ArgumentError get_typed_pool!(pool, Float64) + @test_throws ArgumentError get_typed_pool!(pool, ComplexF64) + end + + @testset "get_typed_pool! fallback (rare types)" begin + pool = MetalAdaptiveArrayPool() + tp = get_typed_pool!(pool, UInt8) + @test tp isa MetalTypedPool{UInt8} + @test haskey(pool.others, UInt8) + end + + @testset "get_view!" begin + tp = MetalTypedPool{Float32, Metal.PrivateStorage}() + @test tp.n_active == 0 + + v1 = get_view!(tp, 100) + @test v1 isa MtlArray + @test length(v1) == 100 + @test tp.n_active == 1 + + v2 = get_view!(tp, 200) + @test v2 isa MtlArray + @test length(v2) == 200 + @test tp.n_active == 2 + end + + @testset "Checkpoint auto-init for dynamic types" begin + pool = MetalAdaptiveArrayPool() + checkpoint!(pool) # Properly enter depth 2 + + tp = get_typed_pool!(pool, UInt16) + @test tp._checkpoint_n_active == [0, 0] + @test tp._checkpoint_depths == [0, 2] + end +end + +@testset "State Management" begin + @testset "Basic checkpoint/rewind" begin + pool = get_task_local_metal_pool() + reset!(pool) + + @test pool._current_depth == 1 + @test pool.float32.n_active == 0 + + checkpoint!(pool) + @test pool._current_depth == 2 + + get_view!(pool.float32, 100) + get_view!(pool.float32, 200) + @test pool.float32.n_active == 2 + + rewind!(pool) + @test pool._current_depth == 1 + @test pool.float32.n_active == 0 + @test length(pool.float32.vectors) >= 2 # Memory preserved + end + + @testset "Nested checkpoint/rewind" begin + pool = get_task_local_metal_pool() + reset!(pool) + + # Outer + checkpoint!(pool) + @test pool._current_depth == 2 + get_view!(pool.float32, 50) + @test pool.float32.n_active == 1 + + # Inner + checkpoint!(pool) + @test pool._current_depth == 3 + get_view!(pool.float32, 100) + get_view!(pool.float32, 150) + @test pool.float32.n_active == 3 + + # Inner rewind + rewind!(pool) + @test pool._current_depth == 2 + @test pool.float32.n_active == 1 + + # Outer rewind + rewind!(pool) + @test pool._current_depth == 1 + @test pool.float32.n_active == 0 + end + + @testset "reset!" begin + pool = get_task_local_metal_pool() + get_view!(pool.float32, 100) + get_view!(pool.int32, 200) + vectors_count = length(pool.float32.vectors) + + reset!(pool) + @test pool.float32.n_active == 0 + @test pool.int32.n_active == 0 + @test pool._current_depth == 1 + @test length(pool.float32.vectors) == vectors_count # Memory preserved + end + + @testset "empty!" begin + pool = get_task_local_metal_pool() + get_view!(pool.float32, 100) + @test length(pool.float32.vectors) >= 1 + + empty!(pool) + @test pool.float32.n_active == 0 + @test length(pool.float32.vectors) == 0 # Memory cleared + end + + @testset "foreach_fixed_slot" begin + pool = get_task_local_metal_pool() + slot_count = Ref(0) + foreach_fixed_slot(pool) do tp + slot_count[] += 1 + end + @test slot_count[] == 6 + end + + @testset "Type-specific checkpoint/rewind" begin + pool = get_task_local_metal_pool() + reset!(pool) + + checkpoint!(pool, Float32) + get_view!(pool.float32, 100) + get_view!(pool.int32, 200) + @test pool.float32.n_active == 1 + @test pool.int32.n_active == 1 + + rewind!(pool, Float32) + @test pool.float32.n_active == 0 + end + + @testset "Multi-type checkpoint/rewind" begin + pool = get_task_local_metal_pool() + reset!(pool) + + checkpoint!(pool, Float32, Int32) + @test pool._current_depth == 2 + + get_view!(pool.float32, 100) + get_view!(pool.int32, 200) + @test pool.float32.n_active == 1 + @test pool.int32.n_active == 1 + + rewind!(pool, Float32, Int32) + @test pool._current_depth == 1 + @test pool.float32.n_active == 0 + @test pool.int32.n_active == 0 + end + + @testset "Type-specific reset" begin + pool = get_task_local_metal_pool() + reset!(pool) + + get_view!(pool.float32, 100) + get_view!(pool.int32, 200) + @test pool.float32.n_active == 1 + @test pool.int32.n_active == 1 + + reset!(pool, Float32) + @test pool.float32.n_active == 0 + @test pool.int32.n_active == 1 # Not affected + end + + @testset "Rewind at depth=1 (edge case)" begin + pool = get_task_local_metal_pool() + reset!(pool) + + @test pool._current_depth == 1 + get_view!(pool.float32, 100) + @test pool.float32.n_active == 1 + + rewind!(pool) + @test pool._current_depth == 1 + @test pool.float32.n_active == 0 + end + + @testset "Type-specific rewind at depth=1" begin + pool = get_task_local_metal_pool() + reset!(pool) + + @test pool._current_depth == 1 + get_view!(pool.float32, 100) + @test pool.float32.n_active == 1 + + rewind!(pool, Float32) + @test pool.float32.n_active == 0 + end + + @testset "Multi-type rewind at depth=1" begin + pool = get_task_local_metal_pool() + reset!(pool) + + @test pool._current_depth == 1 + get_view!(pool.float32, 100) + get_view!(pool.int32, 200) + + rewind!(pool, Float32, Int32) + @test pool.float32.n_active == 0 + @test pool.int32.n_active == 0 + end + + @testset "State operations with rare types (pool.others)" begin + pool = get_task_local_metal_pool() + reset!(pool) + + tp_uint8 = get_typed_pool!(pool, UInt8) + @test haskey(pool.others, UInt8) + + checkpoint!(pool) + get_view!(tp_uint8, 50) + @test tp_uint8.n_active == 1 + + rewind!(pool) + @test tp_uint8.n_active == 0 + + get_view!(tp_uint8, 100) + @test tp_uint8.n_active == 1 + reset!(pool) + @test tp_uint8.n_active == 0 + + get_view!(tp_uint8, 100) + @test length(tp_uint8.vectors) >= 1 + empty!(pool) + @test tp_uint8.n_active == 0 + @test length(tp_uint8.vectors) == 0 + end +end + +@testset "Macro Integration" begin + @testset "@with_pool :metal basic" begin + result = @with_pool :metal pool begin + @test pool isa MetalAdaptiveArrayPool + v = acquire!(pool, Float32, 100) + v .= 1.0f0 + sum(v) + end + @test result == 100.0f0 + @test get_task_local_metal_pool().float32.n_active == 0 + end + + @testset "@with_pool :metal without pool name" begin + result = @with_pool :metal begin + pool = get_task_local_metal_pool() + v = acquire!(pool, Float32, 50) + v .= 2.0f0 + sum(v) + end + @test result == 100.0f0 + end + + @testset "Nested CPU/Metal pools" begin + result = @with_pool cpu_pool begin + cpu_v = acquire!(cpu_pool, Float64, 10) + cpu_v .= 1.0 + + gpu_result = @with_pool :metal gpu_pool begin + gpu_v = acquire!(gpu_pool, Float32, 10) + gpu_v .= 2.0f0 + sum(gpu_v) + end + + sum(cpu_v) + gpu_result + end + @test result == 30.0 + end + + @testset "Rewind on normal exit" begin + pool = get_task_local_metal_pool() + reset!(pool) + + @with_pool :metal p begin + acquire!(p, Float32, 100) + acquire!(p, Float32, 200) + @test p.float32.n_active == 2 + end + + @test pool.float32.n_active == 0 + end + + @testset "Rewind on error" begin + pool = get_task_local_metal_pool() + reset!(pool) + + try + @safe_with_pool :metal p begin + acquire!(p, Float32, 100) + @test p.float32.n_active == 1 + error("Intentional error") + end + catch e + @test e isa ErrorException + end + + @test pool.float32.n_active == 0 + end + + @testset "Multi-dimensional acquire" begin + result = @with_pool :metal pool begin + A = acquire!(pool, Float32, 10, 10) + @test size(A) == (10, 10) + A .= 1.0f0 + sum(A) + end + @test result == 100.0f0 + end + + @testset "acquire! returns MtlArray" begin + result = @with_pool :metal pool begin + A = acquire!(pool, Float32, 100) + @test A isa MtlArray{Float32, 1} + A .= 2.0f0 + sum(A) + end + @test result == 200.0f0 + end + + @testset "Direct rewind: explicit return" begin + @with_pool :metal pool function metal_early_return(flag) + v = acquire!(pool, Float32, 10) + v .= 1.0f0 + if flag + return sum(v) + end + v .= 2.0f0 + sum(v) + end + + @test metal_early_return(true) == 10.0f0 + @test metal_early_return(false) == 20.0f0 + @test get_task_local_metal_pool()._current_depth == 1 + end + + @testset "Direct rewind: break/continue in loop" begin + pool = get_task_local_metal_pool() + reset!(pool) + + total = 0.0f0 + for i in 1:5 + @with_pool :metal p begin + v = acquire!(p, Float32, 3) + v .= Float32(i) + if i == 3 + continue + end + total += sum(v) + end + end + @test total == 3.0f0 * (1 + 2 + 4 + 5) + @test pool._current_depth == 1 + end + + @testset "Direct rewind: nested catch recovery (entry depth guard)" begin + reset!(get_task_local_metal_pool()) + + @with_pool :metal pool function metal_outer_catches() + v = acquire!(pool, Float32, 10) + v .= 1.0f0 + result = try + @with_pool :metal pool begin + acquire!(pool, Int32, 5) + error("boom") + end + catch + 42 + end + sum(v) + result + end + + @test metal_outer_catches() == 52.0f0 + @test get_task_local_metal_pool()._current_depth == 1 + end + + @testset "Uncaught exception corrupts Metal pool (documented)" begin + pool = get_task_local_metal_pool() + reset!(pool) + + try + @with_pool :metal p begin + acquire!(p, Float32, 10) + error("uncaught!") + end + catch + end + + @test pool._current_depth > 1 # corrupted — expected + reset!(pool) + @test pool._current_depth == 1 + end +end + +@testset "Acquire API" begin + @testset "acquire! with MetalAdaptiveArrayPool" begin + pool = MetalAdaptiveArrayPool() + v = acquire!(pool, Float32, 100) + @test v isa MtlArray + @test length(v) == 100 + end + + @testset "acquire! multi-dim" begin + pool = MetalAdaptiveArrayPool() + A = acquire!(pool, Float32, 10, 10) + @test size(A) == (10, 10) + end + + @testset "acquire! tuple dims" begin + pool = MetalAdaptiveArrayPool() + dims = (5, 5, 5) + A = acquire!(pool, Float32, dims) + @test size(A) == dims + end + + @testset "acquire! similar-style" begin + pool = MetalAdaptiveArrayPool() + original = MtlArray(rand(Float32, 10, 10)) + A = acquire!(pool, original) + @test size(A) == size(original) + @test eltype(A) == eltype(original) + end + + @testset "acquire! all dimensionalities" begin + pool = MetalAdaptiveArrayPool() + + v = acquire!(pool, Float32, 100) + @test v isa MtlArray{Float32, 1} + + A = acquire!(pool, Int32, 10, 10) + @test A isa MtlArray{Int32, 2} + + B = acquire!(pool, Int32, (5, 5)) + @test B isa MtlArray{Int32, 2} + end + + @testset "acquire! rejects Float64" begin + pool = MetalAdaptiveArrayPool() + @test_throws ArgumentError acquire!(pool, Float64, 10) + end +end diff --git a/test/metal/test_metal_safety.jl b/test/metal/test_metal_safety.jl new file mode 100644 index 00000000..8c4ae86d --- /dev/null +++ b/test/metal/test_metal_safety.jl @@ -0,0 +1,509 @@ +import AdaptiveArrayPools: PoolRuntimeEscapeError, PoolEscapeError, _runtime_check, + _validate_pool_return, _lazy_checkpoint!, _lazy_rewind! + +const _make_metal_pool = ext._make_metal_pool + +# Opaque identity — defeats compile-time escape analysis +_metal_test_leak(x) = x + +@testset "Metal Safety (MetalAdaptiveArrayPool{R}, Binary R=0/1)" begin + + # ============================================================================== + # Type parameterization basics + # ============================================================================== + + @testset "MetalAdaptiveArrayPool{R} construction and _runtime_check" begin + p0 = _make_metal_pool(0) + p1 = _make_metal_pool(1) + + @test p0 isa MetalAdaptiveArrayPool{0} + @test p1 isa MetalAdaptiveArrayPool{1} + + @test _runtime_check(p0) == false + @test _runtime_check(p1) == true + + # Borrow fields exist at all levels + @test hasfield(typeof(p0), :_pending_callsite) + @test hasfield(typeof(p0), :_pending_return_site) + @test hasfield(typeof(p0), :_borrow_log) + end + + # ============================================================================== + # R=0: No poisoning, no validation + # ============================================================================== + + @testset "R=0: no poisoning on rewind" begin + pool = _make_metal_pool(0) + checkpoint!(pool) + v = acquire!(pool, Float32, 10) + Metal.fill!(v, 42.0f0) + rewind!(pool) + + @test length(pool.float32.vectors[1]) >= 10 + checkpoint!(pool) + v2 = acquire!(pool, Float32, 10) + @test all(x -> x == 42.0f0, Array(v2)) + rewind!(pool) + end + + @testset "R=0: no poisoning (verify data survives rewind)" begin + pool = _make_metal_pool(0) + checkpoint!(pool) + v = acquire!(pool, Float32, 10) + Metal.fill!(v, 42.0f0) + rewind!(pool) + + cpu_data = Array(pool.float32.vectors[1]) + @test all(x -> x == 42.0f0, cpu_data[1:10]) + end + + @testset "R=0: no escape detection" begin + pool = _make_metal_pool(0) + checkpoint!(pool) + try + v = acquire!(pool, Float32, 10) + _validate_pool_return(_metal_test_leak(v), pool) + finally + rewind!(pool) + end + end + + # ============================================================================== + # R=1: Poisoning + structural invalidation + escape detection + borrow tracking + # ============================================================================== + + @testset "R=1: released vectors have length 0 after rewind" begin + pool = _make_metal_pool(1) + checkpoint!(pool) + v = acquire!(pool, Float32, 100) + Metal.fill!(v, 42.0f0) + rewind!(pool) + + @test length(pool.float32.vectors[1]) == 0 + end + + @testset "R=1: Float32 poisoned with NaN on rewind" begin + pool = _make_metal_pool(1) + checkpoint!(pool) + v = acquire!(pool, Float32, 10) + Metal.fill!(v, 42.0f0) + rewind!(pool) + + @test length(pool.float32.vectors[1]) == 0 + + checkpoint!(pool) + v2 = acquire!(pool, Float32, 10) + @test all(isnan, Array(v2)) + rewind!(pool) + end + + @testset "R=1: Int32 poisoned with typemax on rewind" begin + pool = _make_metal_pool(1) + checkpoint!(pool) + v = acquire!(pool, Int32, 8) + Metal.fill!(v, Int32(42)) + rewind!(pool) + + checkpoint!(pool) + v2 = acquire!(pool, Int32, 8) + @test all(==(typemax(Int32)), Array(v2)) + rewind!(pool) + end + + @testset "R=1: ComplexF32 poisoned with NaN on rewind" begin + pool = _make_metal_pool(1) + checkpoint!(pool) + v = acquire!(pool, ComplexF32, 8) + Metal.fill!(v, ComplexF32(1.0f0 + 2.0f0im)) + rewind!(pool) + + checkpoint!(pool) + v2 = acquire!(pool, ComplexF32, 8) + @test all(z -> isnan(real(z)) && isnan(imag(z)), Array(v2)) + rewind!(pool) + end + + @testset "R=1: Bool poisoned with true on rewind" begin + pool = _make_metal_pool(1) + checkpoint!(pool) + v = acquire!(pool, Bool, 16) + Metal.fill!(v, false) + rewind!(pool) + + checkpoint!(pool) + v2 = acquire!(pool, Bool, 16) + @test all(==(true), Array(v2)) + rewind!(pool) + end + + @testset "R=1: Float16 poisoned with NaN on rewind" begin + pool = _make_metal_pool(1) + checkpoint!(pool) + v = acquire!(pool, Float16, 10) + Metal.fill!(v, Float16(42.0)) + rewind!(pool) + + checkpoint!(pool) + v2 = acquire!(pool, Float16, 10) + @test all(isnan, Array(v2)) + rewind!(pool) + end + + @testset "R=1: arr_wrappers invalidated on rewind" begin + pool = _make_metal_pool(1) + checkpoint!(pool) + v = acquire!(pool, Float32, 10) + Metal.fill!(v, 1.0f0) + rewind!(pool) + + tp = pool.float32 + for N_idx in 1:length(tp.arr_wrappers) + wrappers_for_N = tp.arr_wrappers[N_idx] + wrappers_for_N === nothing && continue + for wrapper in wrappers_for_N + wrapper === nothing && continue + @test all(==(0), size(wrapper)) + end + end + end + + # ============================================================================== + # R=1: Escape detection + # ============================================================================== + + @testset "R=1: escape detection catches MtlArray leak" begin + pool = _make_metal_pool(1) + @test_throws PoolRuntimeEscapeError begin + checkpoint!(pool) + try + v = acquire!(pool, Float32, 10) + _validate_pool_return(_metal_test_leak(v), pool) + finally + rewind!(pool) + end + end + end + + @testset "R=1: safe scalar return does not throw" begin + pool = _make_metal_pool(1) + checkpoint!(pool) + try + v = acquire!(pool, Float32, 10) + Metal.fill!(v, 3.0f0) + result = sum(Array(v)) + _validate_pool_return(result, pool) + @test result == 30.0f0 + finally + rewind!(pool) + end + end + + @testset "R=1: escape detection with Tuple containing MtlArray" begin + pool = _make_metal_pool(1) + @test_throws PoolRuntimeEscapeError begin + checkpoint!(pool) + try + v = acquire!(pool, Float32, 10) + val = (42, _metal_test_leak(v)) + _validate_pool_return(val, pool) + finally + rewind!(pool) + end + end + end + + @testset "R=1: escape detection with Dict containing MtlArray" begin + pool = _make_metal_pool(1) + @test_throws PoolRuntimeEscapeError begin + checkpoint!(pool) + try + v = acquire!(pool, Float32, 10) + val = Dict(:data => _metal_test_leak(v)) + _validate_pool_return(val, pool) + finally + rewind!(pool) + end + end + end + + # ============================================================================== + # R=1: Borrow tracking + # ============================================================================== + + @testset "R=1: borrow fields functional" begin + pool = _make_metal_pool(1) + @test pool._pending_callsite == "" + @test pool._pending_return_site == "" + @test pool._borrow_log === nothing + end + + @testset "R=1: _set_pending_callsite! works" begin + pool = _make_metal_pool(1) + AdaptiveArrayPools._set_pending_callsite!(pool, "test.jl:42\nacquire!(pool, Float32, 10)") + @test pool._pending_callsite == "test.jl:42\nacquire!(pool, Float32, 10)" + + pool0 = _make_metal_pool(0) + AdaptiveArrayPools._set_pending_callsite!(pool0, "should not be set") + @test pool0._pending_callsite == "" + end + + @testset "R=1: _maybe_record_borrow! records callsite" begin + pool = _make_metal_pool(1) + checkpoint!(pool) + tp = get_typed_pool!(pool, Float32) + + AdaptiveArrayPools._set_pending_callsite!(pool, "test.jl:99\nacquire!(pool, Float32, 5)") + acquire!(pool, Float32, 5) + + @test pool._borrow_log !== nothing + @test length(pool._borrow_log) >= 1 + + rewind!(pool) + end + + @testset "R=0: does not create borrow log on Metal" begin + pool = _make_metal_pool(0) + checkpoint!(pool) + _ = acquire!(pool, Float32, 10) + @test pool._borrow_log === nothing + rewind!(pool) + end + + @testset "R=1: creates borrow log on Metal acquire" begin + pool = _make_metal_pool(1) + checkpoint!(pool) + _ = acquire!(pool, Float32, 10) + @test pool._borrow_log !== nothing + @test pool._borrow_log isa IdDict + rewind!(pool) + end + + # ============================================================================== + # Nested scopes: inner poisoned, outer valid + # ============================================================================== + + @testset "Nested scopes: inner poisoned, outer still valid" begin + pool = _make_metal_pool(1) + + checkpoint!(pool) + v_outer = acquire!(pool, Float32, 10) + Metal.fill!(v_outer, 1.0f0) + + # Inner scope + checkpoint!(pool) + v_inner = acquire!(pool, Float32, 20) + Metal.fill!(v_inner, 2.0f0) + rewind!(pool) + + # Inner should be invalidated + @test length(pool.float32.vectors[2]) == 0 + checkpoint!(pool) + v_inner2 = acquire!(pool, Float32, 20) + @test all(isnan, Array(v_inner2)) + rewind!(pool) + + # Outer should still be valid + cpu_outer = Array(v_outer) + @test all(x -> x == 1.0f0, cpu_outer) + + rewind!(pool) + @test length(pool.float32.vectors[1]) == 0 + checkpoint!(pool) + v_outer2 = acquire!(pool, Float32, 10) + @test all(isnan, Array(v_outer2)) + rewind!(pool) + end + + # ============================================================================== + # reset! with safety + # ============================================================================== + + @testset "reset! clears borrow tracking state" begin + pool = _make_metal_pool(1) + pool._pending_callsite = "test" + pool._pending_return_site = "test" + pool._borrow_log = IdDict{Any, String}() + + reset!(pool) + + @test pool._pending_callsite == "" + @test pool._pending_return_site == "" + @test pool._borrow_log === nothing + end + + # ============================================================================== + # Fallback types (pool.others) poisoning + # ============================================================================== + + @testset "Fallback type (UInt8) poisoned on rewind" begin + pool = _make_metal_pool(1) + checkpoint!(pool) + v = acquire!(pool, UInt8, 16) + Metal.fill!(v, UInt8(42)) + rewind!(pool) + + tp = pool.others[UInt8] + @test length(tp.vectors[1]) == 0 + + checkpoint!(pool) + v2 = acquire!(pool, UInt8, 16) + @test all(==(typemax(UInt8)), Array(v2)) + rewind!(pool) + end + + # ============================================================================== + # Display includes {R} and check label + # ============================================================================== + + @testset "show includes {R} and check label" begin + pool1 = _make_metal_pool(1) + s1 = sprint(show, pool1) + @test occursin("{1", s1) + @test occursin("check=on", s1) + + pool0 = _make_metal_pool(0) + s0 = sprint(show, pool0) + @test occursin("{0", s0) + @test occursin("check=off", s0) + end + + # ============================================================================== + # Compile-time escape detection (@with_pool :metal) + # ============================================================================== + + @testset "Compile-time: direct MtlArray escape caught at macro expansion" begin + @test_throws PoolEscapeError @macroexpand @with_pool :metal pool begin + v = acquire!(pool, Float32, 10) + v + end + end + + @testset "Compile-time: safe scalar return passes" begin + ex = @macroexpand @with_pool :metal pool begin + v = acquire!(pool, Float32, 10) + sum(Array(v)) + end + @test ex isa Expr + end + + @testset "Compile-time: zeros!/ones! escape caught" begin + @test_throws PoolEscapeError @macroexpand @with_pool :metal pool begin + v = zeros!(pool, Float32, 10) + v + end + end + + # ============================================================================== + # R=1 escape detection via direct checkpoint/validate/rewind + # ============================================================================== + + @testset "Pool{1} escape detection via direct validate" begin + pool = _make_metal_pool(1) + checkpoint!(pool) + err = try + v = acquire!(pool, Float32, 10) + _validate_pool_return(_metal_test_leak(v), pool) + nothing + catch e + e + finally + rewind!(pool) + end + + @test err isa PoolRuntimeEscapeError + end + + @testset "Pool{1} safe scalar via direct validate" begin + pool = _make_metal_pool(1) + checkpoint!(pool) + v = acquire!(pool, Float32, 10) + Metal.fill!(v, 5.0f0) + result = sum(Array(v)) + _validate_pool_return(result, pool) + rewind!(pool) + @test result == 50.0f0 + end + + # ============================================================================== + # R=1 borrow tracking: callsite in escape error + # ============================================================================== + + @testset "Pool{1} escape error includes callsite when set" begin + pool = _make_metal_pool(1) + checkpoint!(pool) + + pool._pending_callsite = "test_metal.jl:42\nacquire!(pool, Float32, 10)" + v = acquire!(pool, Float32, 10) + + err = try + _validate_pool_return(_metal_test_leak(v), pool) + nothing + catch e + e + end + rewind!(pool) + + @test err isa PoolRuntimeEscapeError + @test err.callsite !== nothing + @test contains(err.callsite, "test_metal.jl:42") + @test contains(err.callsite, "acquire!(pool, Float32, 10)") + end + + # ============================================================================== + # Error message content (showerror) + # ============================================================================== + + @testset "showerror: MtlArray escape error message format" begin + err = PoolRuntimeEscapeError("MtlArray{Float32, 1}", "Float32", nothing, nothing) + io = IOBuffer() + showerror(io, err) + msg = String(take!(io)) + + @test contains(msg, "PoolEscapeError") + @test contains(msg, "MtlArray{Float32, 1}") + @test contains(msg, "Float32") + @test contains(msg, "RUNTIME_CHECK") + end + + @testset "showerror: MtlArray with callsite" begin + err = PoolRuntimeEscapeError( + "MtlArray{Float32, 1}", "Float32", + "test_metal.jl:42\nacquire!(pool, Float32, 10)", nothing + ) + io = IOBuffer() + showerror(io, err) + msg = String(take!(io)) + + @test contains(msg, "acquired at") + @test contains(msg, "test_metal.jl:42") + @test contains(msg, "acquire!(pool, Float32, 10)") + end + + # ============================================================================== + # Function form: @with_pool :metal pool function ... + # ============================================================================== + + @testset "Function form: compile-time escape detection" begin + @test_throws PoolEscapeError @macroexpand @with_pool :metal pool function _metal_test_escape_fn() + v = acquire!(pool, Float32, 10) + return v + end + end + + @testset "Function form: safe scalar return compiles" begin + ex = @macroexpand @with_pool :metal pool function _metal_test_safe_fn() + v = acquire!(pool, Float32, 5) + return sum(Array(v)) + end + @test ex isa Expr + end + + @testset "Function form: bare return compiles" begin + ex = @macroexpand @with_pool :metal pool function _metal_test_bare_fn() + _ = acquire!(pool, Float32, 10) + return + end + @test ex isa Expr + end + +end # Metal Safety diff --git a/test/metal/test_reshape.jl b/test/metal/test_reshape.jl new file mode 100644 index 00000000..92343906 --- /dev/null +++ b/test/metal/test_reshape.jl @@ -0,0 +1,88 @@ +# Metal Reshape Tests +# Tests for reshape! with MtlArray + +@testset "Metal Reshape" begin + + @testset "Basic reshape 1D → 2D" begin + pool = get_task_local_metal_pool() + reset!(pool) + + result = @with_pool :metal p begin + v = acquire!(p, Float32, 12) + v .= 1.0f0 + A = reshape!(p, v, 3, 4) + @test size(A) == (3, 4) + @test A isa MtlArray{Float32, 2} + sum(A) + end + @test result == 12.0f0 + end + + @testset "Reshape 1D → 3D" begin + pool = get_task_local_metal_pool() + reset!(pool) + + result = @with_pool :metal p begin + v = acquire!(p, Float32, 24) + v .= 2.0f0 + T = reshape!(p, v, 2, 3, 4) + @test size(T) == (2, 3, 4) + sum(T) + end + @test result == 48.0f0 + end + + @testset "Reshape with tuple dims" begin + pool = get_task_local_metal_pool() + reset!(pool) + + @with_pool :metal p begin + v = acquire!(p, Float32, 20) + A = reshape!(p, v, (4, 5)) + @test size(A) == (4, 5) + end + end + + @testset "Same-dim reshape (no cross-dim)" begin + pool = get_task_local_metal_pool() + reset!(pool) + + @with_pool :metal p begin + A = acquire!(p, Float32, 3, 4) + B = reshape!(p, A, 4, 3) + @test size(B) == (4, 3) + # Same dimensionality: in-place setfield! + @test B === A + end + end + + @testset "DimensionMismatch on wrong element count" begin + pool = get_task_local_metal_pool() + reset!(pool) + + @with_pool :metal p begin + v = acquire!(p, Float32, 10) + @test_throws DimensionMismatch reshape!(p, v, 3, 4) + end + end + + @testset "Reshape reuse across scopes" begin + pool = get_task_local_metal_pool() + reset!(pool) + + # First scope: create reshape wrapper + @with_pool :metal p begin + v = acquire!(p, Float32, 12) + A = reshape!(p, v, 3, 4) + @test size(A) == (3, 4) + end + + # Second scope: should reuse cached wrapper + @with_pool :metal p begin + v = acquire!(p, Float32, 12) + A = reshape!(p, v, 3, 4) + @test size(A) == (3, 4) + end + end + +end diff --git a/test/metal/test_task_local_pool.jl b/test/metal/test_task_local_pool.jl new file mode 100644 index 00000000..9a8c104d --- /dev/null +++ b/test/metal/test_task_local_pool.jl @@ -0,0 +1,47 @@ +# Metal Task-Local Pool Tests + +@testset "Metal Task-Local Pool" begin + + @testset "get_task_local_metal_pool" begin + pool1 = get_task_local_metal_pool() + @test pool1 isa MetalAdaptiveArrayPool + @test pool1.device_key == Metal.device() + + pool2 = get_task_local_metal_pool() + @test pool1 === pool2 # Same pool on second call + end + + @testset "get_task_local_metal_pools" begin + pools_dict = get_task_local_metal_pools() + @test pools_dict isa Dict{UInt64, MetalAdaptiveArrayPool} + pool = get_task_local_metal_pool() + dev_key = objectid(Metal.device()) + @test haskey(pools_dict, dev_key) + end + + @testset "get_task_local_metal_pools before pool creation" begin + result = fetch( + Threads.@spawn begin + pools = get_task_local_metal_pools() + @test pools isa Dict{UInt64, MetalAdaptiveArrayPool} + @test isempty(pools) + true + end + ) + @test result == true + end + + @testset "Device key verification" begin + pool = get_task_local_metal_pool() + current_dev = Metal.device() + @test pool.device_key == current_dev + + pools = get_task_local_metal_pools() + dev_key = objectid(current_dev) + @test haskey(pools, dev_key) + @test pools[dev_key] === pool + + @test get_task_local_metal_pool() === pool + end + +end diff --git a/test/runtests.jl b/test/runtests.jl index 7b5c653a..c45b7e49 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -77,4 +77,11 @@ else else @info "CUDA tests disabled via TEST_CUDA=false" end + + # Metal extension tests (auto-detect, skip with TEST_METAL=false) + if get(ENV, "TEST_METAL", "true") != "false" + include("metal/runtests.jl") + else + @info "Metal tests disabled via TEST_METAL=false" + end end From 488d2d5023c2b7f78812193f24ab3ae8088f28fa Mon Sep 17 00:00:00 2001 From: Min-Gu Yoo Date: Fri, 13 Mar 2026 23:29:08 -0700 Subject: [PATCH 2/4] gate CUDAExt behind Julia 1.11+ (matches MetalExt pattern) CUDA extension imports modern-path-only functions (_store_arr_wrapper!, _reshape_impl!, _check_pool_growth) that don't exist on Julia 1.10 LTS. Wrap all includes in @static if VERSION >= v"1.11-" to prevent precompilation failure on older Julia. --- .../AdaptiveArrayPoolsCUDAExt.jl | 55 +++++++++++-------- 1 file changed, 32 insertions(+), 23 deletions(-) diff --git a/ext/AdaptiveArrayPoolsCUDAExt/AdaptiveArrayPoolsCUDAExt.jl b/ext/AdaptiveArrayPoolsCUDAExt/AdaptiveArrayPoolsCUDAExt.jl index 03c2ecf7..49882dd0 100644 --- a/ext/AdaptiveArrayPoolsCUDAExt/AdaptiveArrayPoolsCUDAExt.jl +++ b/ext/AdaptiveArrayPoolsCUDAExt/AdaptiveArrayPoolsCUDAExt.jl @@ -9,39 +9,48 @@ Loaded automatically when `using CUDA` with AdaptiveArrayPools. module AdaptiveArrayPoolsCUDAExt using AdaptiveArrayPools -using AdaptiveArrayPools: AbstractTypedPool, AbstractArrayPool using CUDA -# Type definitions -include("types.jl") +# GPU pooling requires Julia 1.11+ (setfield!-based Array, arr_wrappers cache). +# On older Julia, the extension loads but provides no functionality. +@static if VERSION >= v"1.11-" -# Dispatch methods (allocate_vector, wrap_array, get_typed_pool!) -include("dispatch.jl") + using AdaptiveArrayPools: AbstractTypedPool, AbstractArrayPool -# GPU-specific acquire (arr_wrappers + setfield!, _resize_to_fit!, _reshape_impl!) -include("acquire.jl") + # Type definitions + include("types.jl") -# Task-local pool (multi-device aware) -include("task_local_pool.jl") + # Dispatch methods (allocate_vector, wrap_array, get_typed_pool!) + include("dispatch.jl") -# State management (checkpoint!, rewind!, reset!, empty!) -include("state.jl") + # GPU-specific acquire (arr_wrappers + setfield!, _resize_to_fit!, _reshape_impl!) + include("acquire.jl") -# Safety: poisoning, escape detection, borrow tracking -include("debug.jl") + # Task-local pool (multi-device aware) + include("task_local_pool.jl") -# Display & statistics (pool_stats, show) -include("utils.jl") + # State management (checkpoint!, rewind!, reset!, empty!) + include("state.jl") -# Macro support (@with_pool :cuda) -include("macros.jl") + # Safety: poisoning, escape detection, borrow tracking + include("debug.jl") -# Convenience functions (Float32 default for zeros!/ones!) -include("convenience.jl") + # Display & statistics (pool_stats, show) + include("utils.jl") -# Exports (types only - functions are exported from main module) -export CuTypedPool, CuAdaptiveArrayPool -export GPU_FIXED_SLOT_FIELDS -# get_task_local_cuda_pool, get_task_local_cuda_pools are exported from AdaptiveArrayPools + # Macro support (@with_pool :cuda) + include("macros.jl") + + # Convenience functions (Float32 default for zeros!/ones!) + include("convenience.jl") + + # Exports (types only - functions are exported from main module) + export CuTypedPool, CuAdaptiveArrayPool + export GPU_FIXED_SLOT_FIELDS + # get_task_local_cuda_pool, get_task_local_cuda_pools are exported from AdaptiveArrayPools + +else + @warn "AdaptiveArrayPoolsCUDAExt requires Julia 1.11+. GPU pooling is disabled." maxlog = 1 +end # @static if end # module From 733a41f53d6a61cb1ac1c40e4a840fd7aab46320 Mon Sep 17 00:00:00 2001 From: Min-Gu Yoo Date: Fri, 13 Mar 2026 23:33:50 -0700 Subject: [PATCH 3/4] fix: gate GPU test suites behind Julia 1.11+ with @static if-else Replace `return`-based early exit (unreliable inside @static if on 1.10) with if-else branching to properly skip GPU tests on Julia LTS. --- test/cuda/runtests.jl | 63 +++++++++++++++++++++++------------------- test/metal/runtests.jl | 63 +++++++++++++++++++++--------------------- 2 files changed, 65 insertions(+), 61 deletions(-) diff --git a/test/cuda/runtests.jl b/test/cuda/runtests.jl index 2c016a50..f4e3ce58 100644 --- a/test/cuda/runtests.jl +++ b/test/cuda/runtests.jl @@ -9,37 +9,42 @@ using Test -# Check CUDA availability (separate from test execution) -const CUDA_AVAILABLE = try - using CUDA - CUDA.functional() -catch - false -end - -if !CUDA_AVAILABLE - @info "CUDA not available or not functional, skipping CUDA tests" - # Return early - no tests to run +# GPU pooling requires Julia 1.11+ +@static if VERSION < v"1.11-" + @info "CUDA extension tests skipped (requires Julia 1.11+)" + @testset "CUDA (skipped — Julia < 1.11)" begin end else - @info "Running CUDA extension tests on device: $(CUDA.name(CUDA.device()))" + # Check CUDA availability (separate from test execution) + const CUDA_AVAILABLE = try + using CUDA + CUDA.functional() + catch + false + end + + if !CUDA_AVAILABLE + @info "CUDA not available or not functional, skipping CUDA tests" + else + @info "Running CUDA extension tests on device: $(CUDA.name(CUDA.device()))" - # Load dependencies - functions work via dispatch, no need to access extension directly - using AdaptiveArrayPools - using AdaptiveArrayPools: checkpoint!, rewind!, get_typed_pool!, get_view!, foreach_fixed_slot + # Load dependencies - functions work via dispatch, no need to access extension directly + using AdaptiveArrayPools + using AdaptiveArrayPools: checkpoint!, rewind!, get_typed_pool!, get_view!, foreach_fixed_slot - # Extension types (only needed for type checks in tests) - const ext = Base.get_extension(AdaptiveArrayPools, :AdaptiveArrayPoolsCUDAExt) - const CuTypedPool = ext.CuTypedPool - const CuAdaptiveArrayPool = ext.CuAdaptiveArrayPool - const GPU_FIXED_SLOT_FIELDS = ext.GPU_FIXED_SLOT_FIELDS - # get_task_local_cuda_pool, get_task_local_cuda_pools are exported from AdaptiveArrayPools + # Extension types (only needed for type checks in tests) + const ext = Base.get_extension(AdaptiveArrayPools, :AdaptiveArrayPoolsCUDAExt) + const CuTypedPool = ext.CuTypedPool + const CuAdaptiveArrayPool = ext.CuAdaptiveArrayPool + const GPU_FIXED_SLOT_FIELDS = ext.GPU_FIXED_SLOT_FIELDS + # get_task_local_cuda_pool, get_task_local_cuda_pools are exported from AdaptiveArrayPools - # Include all CUDA test files - include("test_extension.jl") - include("test_allocation.jl") - include("test_nway_cache.jl") - include("test_display.jl") - include("test_convenience.jl") - include("test_disabled_pool.jl") - include("test_cuda_safety.jl") + # Include all CUDA test files + include("test_extension.jl") + include("test_allocation.jl") + include("test_nway_cache.jl") + include("test_display.jl") + include("test_convenience.jl") + include("test_disabled_pool.jl") + include("test_cuda_safety.jl") + end end diff --git a/test/metal/runtests.jl b/test/metal/runtests.jl index fb034dbc..5d9932fe 100644 --- a/test/metal/runtests.jl +++ b/test/metal/runtests.jl @@ -13,40 +13,39 @@ using Test @static if VERSION < v"1.11-" @info "Metal extension tests skipped (requires Julia 1.11+)" @testset "Metal (skipped — Julia < 1.11)" begin end - return -end - -# Check Metal availability (requires macOS + Apple Silicon) -const METAL_AVAILABLE = try - Sys.isapple() || error("Not macOS") - using Metal - Metal.functional() -catch - false -end - -if !METAL_AVAILABLE - @info "Metal not available or not functional, skipping Metal tests" else - @info "Running Metal extension tests on device: $(Metal.device())" + # Check Metal availability (requires macOS + Apple Silicon) + const METAL_AVAILABLE = try + Sys.isapple() || error("Not macOS") + using Metal + Metal.functional() + catch + false + end + + if !METAL_AVAILABLE + @info "Metal not available or not functional, skipping Metal tests" + else + @info "Running Metal extension tests on device: $(Metal.device())" - # Load dependencies - using AdaptiveArrayPools - using AdaptiveArrayPools: checkpoint!, rewind!, get_typed_pool!, get_view!, foreach_fixed_slot + # Load dependencies + using AdaptiveArrayPools + using AdaptiveArrayPools: checkpoint!, rewind!, get_typed_pool!, get_view!, foreach_fixed_slot - # Extension types (only needed for type checks in tests) - const ext = Base.get_extension(AdaptiveArrayPools, :AdaptiveArrayPoolsMetalExt) - const MetalTypedPool = ext.MetalTypedPool - const MetalAdaptiveArrayPool = ext.MetalAdaptiveArrayPool - const METAL_FIXED_SLOT_FIELDS = ext.METAL_FIXED_SLOT_FIELDS + # Extension types (only needed for type checks in tests) + const ext = Base.get_extension(AdaptiveArrayPools, :AdaptiveArrayPoolsMetalExt) + const MetalTypedPool = ext.MetalTypedPool + const MetalAdaptiveArrayPool = ext.MetalAdaptiveArrayPool + const METAL_FIXED_SLOT_FIELDS = ext.METAL_FIXED_SLOT_FIELDS - # Include all Metal test files - include("test_extension.jl") - include("test_allocation.jl") - include("test_display.jl") - include("test_convenience.jl") - include("test_disabled_pool.jl") - include("test_metal_safety.jl") - include("test_reshape.jl") - include("test_task_local_pool.jl") + # Include all Metal test files + include("test_extension.jl") + include("test_allocation.jl") + include("test_display.jl") + include("test_convenience.jl") + include("test_disabled_pool.jl") + include("test_metal_safety.jl") + include("test_reshape.jl") + include("test_task_local_pool.jl") + end end From 13c2cf29feba4c3f392f63f3bdd7a411b27ee2fd Mon Sep 17 00:00:00 2001 From: Min-Gu Yoo Date: Fri, 13 Mar 2026 23:41:31 -0700 Subject: [PATCH 4/4] docs: add Metal backend documentation, fix GPU caching descriptions - Add docs/src/features/metal-support.md with full Metal API reference - Fix cuda-support.md: replace incorrect N-way cache description with actual arr_wrappers-based direct-index caching mechanism - Add Metal backend to README supported backends and doc links - Add Metal page to docs/make.jl navigation and path mappings --- README.md | 14 ++-- docs/make.jl | 2 + docs/src/features/cuda-support.md | 23 ++--- docs/src/features/metal-support.md | 129 +++++++++++++++++++++++++++++ 4 files changed, 144 insertions(+), 24 deletions(-) create mode 100644 docs/src/features/metal-support.md diff --git a/README.md b/README.md index 3a694313..7e02a254 100644 --- a/README.md +++ b/README.md @@ -14,6 +14,7 @@ A lightweight library that lets you write natural, allocation-style code while a **Supported backends:** - **CPU** — `Array`, works out of the box - **CUDA** — `CuArray`, loads automatically when [CUDA.jl](https://github.com/JuliaGPU/CUDA.jl) is available +- **Metal** — `MtlArray`, loads automatically when [Metal.jl](https://github.com/JuliaGPU/Metal.jl) is available (Apple Silicon) ## The Problem @@ -68,7 +69,7 @@ end | Allocations | ⚠️ 90,000 (2.75 GiB) | ✅ **0** | 100% eliminated | | GC Time | ⚠️ 31% | ✅ **0%** | No GC pauses | -> **CUDA support**: Same API—just use `@with_pool :cuda pool`. See [CUDA Backend](https://projecttorreypines.github.io/AdaptiveArrayPools.jl/stable/usage/cuda). +> **GPU support**: Same API—just use `@with_pool :cuda pool` or `@with_pool :metal pool`. See [CUDA Backend](https://projecttorreypines.github.io/AdaptiveArrayPools.jl/stable/features/cuda-support) and [Metal Backend](https://projecttorreypines.github.io/AdaptiveArrayPools.jl/stable/features/metal-support). ## How It Works @@ -111,11 +112,12 @@ Pkg.add("AdaptiveArrayPools") | Guide | Description | |-------|-------------| -| [API Reference](https://projecttorreypines.github.io/AdaptiveArrayPools.jl/stable/usage/api) | Complete function and macro reference | -| [CUDA Backend](https://projecttorreypines.github.io/AdaptiveArrayPools.jl/stable/usage/cuda) | GPU-specific usage and examples | -| [Safety Guide](https://projecttorreypines.github.io/AdaptiveArrayPools.jl/stable/guide/safety) | Scope rules and best practices | -| [Multi-Threading](https://projecttorreypines.github.io/AdaptiveArrayPools.jl/stable/advanced/multi-threading) | Task/thread safety patterns | -| [Configuration](https://projecttorreypines.github.io/AdaptiveArrayPools.jl/stable/usage/configuration) | Preferences and cache tuning | +| [API Reference](https://projecttorreypines.github.io/AdaptiveArrayPools.jl/stable/reference/api) | Complete function and macro reference | +| [CUDA Backend](https://projecttorreypines.github.io/AdaptiveArrayPools.jl/stable/features/cuda-support) | NVIDIA GPU usage and examples | +| [Metal Backend](https://projecttorreypines.github.io/AdaptiveArrayPools.jl/stable/features/metal-support) | Apple Silicon GPU usage and examples | +| [Safety Guide](https://projecttorreypines.github.io/AdaptiveArrayPools.jl/stable/basics/safety-rules) | Scope rules and best practices | +| [Multi-Threading](https://projecttorreypines.github.io/AdaptiveArrayPools.jl/stable/features/multi-threading) | Task/thread safety patterns | +| [Configuration](https://projecttorreypines.github.io/AdaptiveArrayPools.jl/stable/features/configuration) | Preferences and cache tuning | ## License diff --git a/docs/make.jl b/docs/make.jl index 54b79e41..06d093e4 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -33,6 +33,7 @@ const README_PATH_MAPPINGS = [ (r"\(docs/configuration\.md(#[^)]+)?\)", s"(features/configuration.md\1)"), (r"\(docs/maybe_with_pool\.md(#[^)]+)?\)", s"(features/maybe-with-pool.md\1)"), (r"\(docs/multi-threading\.md(#[^)]+)?\)", s"(features/multi-threading.md\1)"), + (r"\(docs/metal\.md(#[^)]+)?\)", s"(features/metal-support.md\1)"), # Basics (r"\(docs/safety\.md(#[^)]+)?\)", s"(basics/safety-rules.md\1)"), @@ -130,6 +131,7 @@ makedocs( "`@maybe_with_pool`" => "features/maybe-with-pool.md", "Bit Arrays" => "features/bit-arrays.md", "CUDA Support" => "features/cuda-support.md", + "Metal Support" => "features/metal-support.md", "Configuration" => "features/configuration.md", ], "Reference" => [ diff --git a/docs/src/features/cuda-support.md b/docs/src/features/cuda-support.md index 95e374df..f97ed1fd 100644 --- a/docs/src/features/cuda-support.md +++ b/docs/src/features/cuda-support.md @@ -48,24 +48,10 @@ The CUDA backend uses the same API as CPU, with `:cuda` backend specifier: **GPU Memory**: Always 0 bytes allocation after warmup. The underlying `CuVector` is resized as needed and reused. **CPU-side Wrapper Memory** (for `acquire!` N-D on CUDA): -- The CUDA backend uses an N-way set-associative cache for `CuArray` wrapper reuse -- Cache hit (≤`CACHE_WAYS` dimension patterns per slot): 0 bytes -- Cache miss (>`CACHE_WAYS` patterns): ~100 bytes for wrapper metadata -- See [Configuration](configuration.md) for `CACHE_WAYS` tuning - -!!! note "CPU vs CUDA caching" - On CPU (Julia 1.11+), `acquire!` uses `setfield!`-based wrapper reuse with **zero allocation for any number of dimension patterns**. The CUDA backend does not yet support this optimization and still uses the N-way cache. - -```julia -# Example: 4 patterns fit in default 4-way cache → zero CPU-side allocation -dims_list = ((10, 10), (5, 20), (20, 5), (4, 25)) -for dims in dims_list - @with_pool :cuda p begin - A = acquire!(p, Float64, dims...) - # Use A... - end -end -``` +- The CUDA backend uses `arr_wrappers`-based direct-index caching for `CuArray` wrapper reuse +- Each dimensionality `N` has one cached wrapper per slot, reused via `setfield!(:dims)` +- After warmup: **zero CPU-side allocation for any number of dimension patterns** (same `N`) +- Different `N` values each get their own cached wrapper (also zero-alloc after first use) ## Fixed Slot Types @@ -86,6 +72,7 @@ Other types use the fallback dictionary (`.others`). ## Limitations +- **Julia 1.11+**: Required for `setfield!`-based Array internals used by GPU extensions - **No `@maybe_with_pool :cuda`**: Runtime toggle not supported for CUDA backend - **Task-local only**: Each Task gets its own CUDA pool, same as CPU - **Same device**: All arrays in a pool use the same CUDA device diff --git a/docs/src/features/metal-support.md b/docs/src/features/metal-support.md new file mode 100644 index 00000000..85d4492f --- /dev/null +++ b/docs/src/features/metal-support.md @@ -0,0 +1,129 @@ +# Metal Backend + +AdaptiveArrayPools provides native Apple Silicon GPU support through a package extension that loads automatically when [Metal.jl](https://github.com/JuliaGPU/Metal.jl) is available. Requires Julia 1.11+. + +## Quick Start + +```julia +using AdaptiveArrayPools, Metal + +# Use :metal backend for Apple Silicon GPU arrays +@with_pool :metal pool function gpu_computation(n) + A = acquire!(pool, Float32, n, n) # MtlArray + B = acquire!(pool, Float32, n, n) # MtlArray + + fill!(A, 1.0f0) + fill!(B, 2.0f0) + + return sum(A .+ B) +end + +# Zero GPU allocation in hot loops +for i in 1:1000 + gpu_computation(100) # GPU memory reused from pool +end +``` + +## API + +The Metal backend uses the same API as CPU and CUDA, with `:metal` backend specifier: + +| Macro/Function | Description | +|----------------|-------------| +| `@with_pool :metal pool expr` | GPU pool with automatic checkpoint/rewind | +| `acquire!(pool, T, dims...)` | Returns `MtlArray` (always 0 bytes GPU alloc) | +| `acquire_view!(pool, T, dims...)` | Returns `MtlArray` (same as `acquire!` on Metal) | +| `get_task_local_metal_pool()` | Returns the task-local Metal pool | +| `pool_stats(:metal)` | Print Metal pool statistics | + +## Return Types + +| Function | 1D Return | N-D Return | +|----------|-----------|------------| +| `acquire!` | `MtlArray{T,1}` | `MtlArray{T,N}` | +| `acquire_view!` | `MtlArray{T,1}` | `MtlArray{T,N}` | + +## Allocation Behavior + +**GPU Memory**: Always 0 bytes allocation after warmup. The underlying `MtlVector` is resized as needed and reused. + +**CPU-side Wrapper Memory** (for `acquire!` N-D on Metal): +- The Metal backend uses `arr_wrappers`-based direct-index caching for `MtlArray` wrapper reuse +- Each dimensionality `N` has one cached wrapper per slot, reused via `setfield!(:dims)` +- After warmup: **zero CPU-side allocation for any number of dimension patterns** (same `N`) +- Different `N` values each get their own cached wrapper (also zero-alloc after first use) + +## Fixed Slot Types + +Metal hardware does not support Float64 or ComplexF64. The following types have optimized pre-allocated slots: + +| Type | Field | +|------|-------| +| `Float32` | `.float32` | +| `Float16` | `.float16` | +| `Int64` | `.int64` | +| `Int32` | `.int32` | +| `ComplexF32` | `.complexf32` | +| `Bool` | `.bool` | + +Other types use the fallback dictionary (`.others`). + +!!! note "No Float64/ComplexF64" + Apple Silicon GPUs do not natively support 64-bit floating point. Use `Float32` or `Float16` instead. + +## Limitations + +- **No Float64/ComplexF64**: Apple Silicon GPUs do not natively support 64-bit floating point +- **No `@maybe_with_pool :metal`**: Runtime toggle not supported for Metal backend +- **Single-device only**: Tested on single Apple GPU (multi-device untested) +- **Julia 1.11+**: Required for `setfield!`-based Array internals used by GPU extensions +- **Task-local only**: Each Task gets its own Metal pool, same as CPU + +## Example: Matrix Computation + +```julia +using AdaptiveArrayPools, Metal + +@with_pool :metal pool function gpu_compute(n) + A = acquire!(pool, Float32, n, n) + B = acquire!(pool, Float32, n, n) + C = acquire!(pool, Float32, n, n) + + fill!(A, 1.0f0); fill!(B, 2.0f0) + C .= A .+ B + + return sum(C) +end + +# Warmup +gpu_compute(100) + +# Benchmark - zero GPU allocation +using BenchmarkTools +@benchmark gpu_compute(1000) +``` + +## Debugging + +```julia +# Check pool state +pool_stats(:metal) + +# Output: +# MetalAdaptiveArrayPool +# Float32 (fixed) [Metal] +# slots: 3 (active: 0) +# elements: 30000 (117.188 KiB) +``` + +## CUDA vs Metal + +| Feature | CUDA | Metal | +|---------|------|-------| +| Backend symbol | `:cuda` | `:metal` | +| Array type | `CuArray` | `MtlArray` | +| Float64 support | Yes | No | +| ComplexF64 support | Yes | No | +| Julia requirement | 1.11+ | 1.11+ | +| Safety features | Full | Full | +| Lazy mode | Yes | Yes |