diff --git a/src/AdaptiveArrayPools.jl b/src/AdaptiveArrayPools.jl index d5d49fcf..0ffb3e5b 100644 --- a/src/AdaptiveArrayPools.jl +++ b/src/AdaptiveArrayPools.jl @@ -9,7 +9,8 @@ export zeros!, ones!, trues!, falses!, similar!, reshape!, default_eltype # Con export unsafe_zeros!, unsafe_ones!, unsafe_similar! # Unsafe convenience functions export Bit # Sentinel type for BitArray (use with acquire!, trues!, falses!) export @with_pool, @maybe_with_pool -export STATIC_POOLING, MAYBE_POOLING, POOL_DEBUG +export STATIC_POOLING, MAYBE_POOLING, POOL_DEBUG, POOL_SAFETY_LV, STATIC_POOL_CHECKS +export PoolEscapeError, EscapePoint export USE_POOLING, MAYBE_POOLING_ENABLED # Deprecated aliases (backward compat) export checkpoint!, rewind!, reset! export get_task_local_cuda_pool, get_task_local_cuda_pools # CUDA (stubs, overridden by extension) @@ -28,6 +29,7 @@ export DisabledPool, DISABLED_CPU, pooling_enabled # Disabled pool support include("convenience.jl") include("state.jl") include("task_local_pool.jl") + include("debug.jl") include("macros.jl") else export CACHE_WAYS, set_cache_ways! # N-way cache configuration (legacy only) @@ -38,6 +40,7 @@ else include("convenience.jl") include("legacy/state.jl") include("task_local_pool.jl") + include("debug.jl") include("macros.jl") end diff --git a/src/acquire.jl b/src/acquire.jl index f2d4f9ff..5647aacc 100644 --- a/src/acquire.jl +++ b/src/acquire.jl @@ -327,12 +327,16 @@ Internal implementation of acquire!. Called directly by macro-transformed code """ @inline function _acquire_impl!(pool::AbstractArrayPool, ::Type{T}, n::Int) where {T} tp = get_typed_pool!(pool, T) - return get_view!(tp, n) + result = get_view!(tp, n) + _maybe_record_borrow!(pool, tp) + return result end @inline function _acquire_impl!(pool::AbstractArrayPool, ::Type{T}, dims::Vararg{Int, N}) where {T, N} tp = get_typed_pool!(pool, T) - return get_view!(tp, dims) + result = get_view!(tp, dims) + _maybe_record_borrow!(pool, tp) + return result end @inline function _acquire_impl!(pool::AbstractArrayPool, ::Type{T}, dims::NTuple{N, Int}) where {T, N} @@ -349,17 +353,23 @@ Internal implementation of unsafe_acquire!. Called directly by macro-transformed """ @inline function _unsafe_acquire_impl!(pool::AbstractArrayPool, ::Type{T}, n::Int) where {T} tp = get_typed_pool!(pool, T) - return get_array!(tp, (n,)) + result = get_array!(tp, (n,)) + _maybe_record_borrow!(pool, tp) + return result end @inline function _unsafe_acquire_impl!(pool::AbstractArrayPool, ::Type{T}, dims::Vararg{Int, N}) where {T, N} tp = get_typed_pool!(pool, T) - return get_array!(tp, dims) + result = get_array!(tp, dims) + _maybe_record_borrow!(pool, tp) + return result end @inline function _unsafe_acquire_impl!(pool::AbstractArrayPool, ::Type{T}, dims::NTuple{N, Int}) where {T, N} tp = get_typed_pool!(pool, T) - return get_array!(tp, dims) + result = get_array!(tp, dims) + _maybe_record_borrow!(pool, tp) + return result end # Similar-style @@ -403,18 +413,21 @@ See also: [`unsafe_acquire!`](@ref) for native array access. """ @inline function acquire!(pool::AbstractArrayPool, ::Type{T}, n::Int) where {T} _record_type_touch!(pool, T) + _set_pending_callsite!(pool, "") return _acquire_impl!(pool, T, n) end # Multi-dimensional support (zero-allocation with N-D cache) @inline function acquire!(pool::AbstractArrayPool, ::Type{T}, dims::Vararg{Int, N}) where {T, N} _record_type_touch!(pool, T) + _set_pending_callsite!(pool, "") return _acquire_impl!(pool, T, dims...) end # Tuple support: allows acquire!(pool, T, size(A)) where size(A) returns NTuple{N,Int} @inline function acquire!(pool::AbstractArrayPool, ::Type{T}, dims::NTuple{N, Int}) where {T, N} _record_type_touch!(pool, T) + _set_pending_callsite!(pool, "") return _acquire_impl!(pool, T, dims...) end @@ -435,6 +448,7 @@ end """ @inline function acquire!(pool::AbstractArrayPool, x::AbstractArray) _record_type_touch!(pool, eltype(x)) + _set_pending_callsite!(pool, "") return _acquire_impl!(pool, eltype(x), size(x)) end @@ -490,17 +504,20 @@ See also: [`acquire!`](@ref) for view-based access. """ @inline function unsafe_acquire!(pool::AbstractArrayPool, ::Type{T}, n::Int) where {T} _record_type_touch!(pool, T) + _set_pending_callsite!(pool, "") return _unsafe_acquire_impl!(pool, T, n) end @inline function unsafe_acquire!(pool::AbstractArrayPool, ::Type{T}, dims::Vararg{Int, N}) where {T, N} _record_type_touch!(pool, T) + _set_pending_callsite!(pool, "") return _unsafe_acquire_impl!(pool, T, dims...) end # Tuple support @inline function unsafe_acquire!(pool::AbstractArrayPool, ::Type{T}, dims::NTuple{N, Int}) where {T, N} _record_type_touch!(pool, T) + _set_pending_callsite!(pool, "") return _unsafe_acquire_impl!(pool, T, dims) end @@ -521,6 +538,7 @@ end """ @inline function unsafe_acquire!(pool::AbstractArrayPool, x::AbstractArray) _record_type_touch!(pool, eltype(x)) + _set_pending_callsite!(pool, "") return _unsafe_acquire_impl!(pool, eltype(x), size(x)) end diff --git a/src/bitarray.jl b/src/bitarray.jl index 8ec7c92f..387f35ac 100644 --- a/src/bitarray.jl +++ b/src/bitarray.jl @@ -139,17 +139,23 @@ end # Bit type: returns BitArray{N} with shared chunks (SIMD optimized, N-D cached) @inline function _unsafe_acquire_impl!(pool::AbstractArrayPool, ::Type{Bit}, n::Int) tp = get_typed_pool!(pool, Bit)::BitTypedPool - return get_bitarray!(tp, n) + result = get_bitarray!(tp, n) + _maybe_record_borrow!(pool, tp) + return result end @inline function _unsafe_acquire_impl!(pool::AbstractArrayPool, ::Type{Bit}, dims::Vararg{Int, N}) where {N} tp = get_typed_pool!(pool, Bit)::BitTypedPool - return get_bitarray!(tp, dims) + result = get_bitarray!(tp, dims) + _maybe_record_borrow!(pool, tp) + return result end @inline function _unsafe_acquire_impl!(pool::AbstractArrayPool, ::Type{Bit}, dims::NTuple{N, Int}) where {N} tp = get_typed_pool!(pool, Bit)::BitTypedPool - return get_bitarray!(tp, dims) + result = get_bitarray!(tp, dims) + _maybe_record_borrow!(pool, tp) + return result end # ============================================================================== @@ -193,19 +199,24 @@ end # ============================================================================== # Check if BitArray chunks overlap with the pool's BitTypedPool storage -function _check_bitchunks_overlap(arr::BitArray, pool::AdaptiveArrayPool) +function _check_bitchunks_overlap(arr::BitArray, pool::AdaptiveArrayPool, original_val = arr) arr_chunks = arr.chunks arr_ptr = UInt(pointer(arr_chunks)) arr_len = length(arr_chunks) * sizeof(UInt64) arr_end = arr_ptr + arr_len + return_site = let rs = pool._pending_return_site + isempty(rs) ? nothing : rs + end + for v in pool.bits.vectors v_chunks = v.chunks v_ptr = UInt(pointer(v_chunks)) v_len = length(v_chunks) * sizeof(UInt64) v_end = v_ptr + v_len if !(arr_end <= v_ptr || v_end <= arr_ptr) - error("Safety Violation: The function returned a BitArray backed by pool memory. This is unsafe as the memory will be reclaimed. Please return a copy (copy) or a scalar.") + callsite = _lookup_borrow_callsite(pool, v) + _throw_pool_escape_error(original_val, Bit, callsite, return_site) end end return nothing diff --git a/src/convenience.jl b/src/convenience.jl index 33ec099a..00267872 100644 --- a/src/convenience.jl +++ b/src/convenience.jl @@ -44,21 +44,25 @@ See also: [`ones!`](@ref), [`similar!`](@ref), [`acquire!`](@ref) """ @inline function zeros!(pool::AbstractArrayPool, ::Type{T}, dims::Vararg{Int, N}) where {T, N} _record_type_touch!(pool, T) + _set_pending_callsite!(pool, "") return _zeros_impl!(pool, T, dims...) end @inline function zeros!(pool::AbstractArrayPool, dims::Vararg{Int, N}) where {N} _record_type_touch!(pool, default_eltype(pool)) + _set_pending_callsite!(pool, "") return _zeros_impl!(pool, default_eltype(pool), dims...) end @inline function zeros!(pool::AbstractArrayPool, ::Type{T}, dims::NTuple{N, Int}) where {T, N} _record_type_touch!(pool, T) + _set_pending_callsite!(pool, "") return _zeros_impl!(pool, T, dims...) end @inline function zeros!(pool::AbstractArrayPool, dims::NTuple{N, Int}) where {N} _record_type_touch!(pool, default_eltype(pool)) + _set_pending_callsite!(pool, "") return _zeros_impl!(pool, default_eltype(pool), dims...) end @@ -117,21 +121,25 @@ See also: [`zeros!`](@ref), [`similar!`](@ref), [`acquire!`](@ref) """ @inline function ones!(pool::AbstractArrayPool, ::Type{T}, dims::Vararg{Int, N}) where {T, N} _record_type_touch!(pool, T) + _set_pending_callsite!(pool, "") return _ones_impl!(pool, T, dims...) end @inline function ones!(pool::AbstractArrayPool, dims::Vararg{Int, N}) where {N} _record_type_touch!(pool, default_eltype(pool)) + _set_pending_callsite!(pool, "") return _ones_impl!(pool, default_eltype(pool), dims...) end @inline function ones!(pool::AbstractArrayPool, ::Type{T}, dims::NTuple{N, Int}) where {T, N} _record_type_touch!(pool, T) + _set_pending_callsite!(pool, "") return _ones_impl!(pool, T, dims...) end @inline function ones!(pool::AbstractArrayPool, dims::NTuple{N, Int}) where {N} _record_type_touch!(pool, default_eltype(pool)) + _set_pending_callsite!(pool, "") return _ones_impl!(pool, default_eltype(pool), dims...) end @@ -187,10 +195,12 @@ See also: [`falses!`](@ref), [`ones!`](@ref), [`acquire!`](@ref) """ @inline function trues!(pool::AbstractArrayPool, dims::Vararg{Int, N}) where {N} _record_type_touch!(pool, Bit) + _set_pending_callsite!(pool, "") return _trues_impl!(pool, dims...) end @inline function trues!(pool::AbstractArrayPool, dims::NTuple{N, Int}) where {N} _record_type_touch!(pool, Bit) + _set_pending_callsite!(pool, "") return _trues_impl!(pool, dims...) end @@ -227,10 +237,12 @@ See also: [`trues!`](@ref), [`zeros!`](@ref), [`acquire!`](@ref) """ @inline function falses!(pool::AbstractArrayPool, dims::Vararg{Int, N}) where {N} _record_type_touch!(pool, Bit) + _set_pending_callsite!(pool, "") return _falses_impl!(pool, dims...) end @inline function falses!(pool::AbstractArrayPool, dims::NTuple{N, Int}) where {N} _record_type_touch!(pool, Bit) + _set_pending_callsite!(pool, "") return _falses_impl!(pool, dims...) end @@ -274,21 +286,25 @@ See also: [`zeros!`](@ref), [`ones!`](@ref), [`acquire!`](@ref) """ @inline function similar!(pool::AbstractArrayPool, x::AbstractArray) _record_type_touch!(pool, eltype(x)) + _set_pending_callsite!(pool, "") return _similar_impl!(pool, x) end @inline function similar!(pool::AbstractArrayPool, x::AbstractArray, ::Type{T}) where {T} _record_type_touch!(pool, T) + _set_pending_callsite!(pool, "") return _similar_impl!(pool, x, T) end @inline function similar!(pool::AbstractArrayPool, x::AbstractArray, dims::Vararg{Int, N}) where {N} _record_type_touch!(pool, eltype(x)) + _set_pending_callsite!(pool, "") return _similar_impl!(pool, x, dims...) end @inline function similar!(pool::AbstractArrayPool, x::AbstractArray, ::Type{T}, dims::Vararg{Int, N}) where {T, N} _record_type_touch!(pool, T) + _set_pending_callsite!(pool, "") return _similar_impl!(pool, x, T, dims...) end @@ -398,21 +414,25 @@ See also: [`unsafe_ones!`](@ref), [`zeros!`](@ref), [`unsafe_acquire!`](@ref) """ @inline function unsafe_zeros!(pool::AbstractArrayPool, ::Type{T}, dims::Vararg{Int, N}) where {T, N} _record_type_touch!(pool, T) + _set_pending_callsite!(pool, "") return _unsafe_zeros_impl!(pool, T, dims...) end @inline function unsafe_zeros!(pool::AbstractArrayPool, dims::Vararg{Int, N}) where {N} _record_type_touch!(pool, default_eltype(pool)) + _set_pending_callsite!(pool, "") return _unsafe_zeros_impl!(pool, default_eltype(pool), dims...) end @inline function unsafe_zeros!(pool::AbstractArrayPool, ::Type{T}, dims::NTuple{N, Int}) where {T, N} _record_type_touch!(pool, T) + _set_pending_callsite!(pool, "") return _unsafe_zeros_impl!(pool, T, dims...) end @inline function unsafe_zeros!(pool::AbstractArrayPool, dims::NTuple{N, Int}) where {N} _record_type_touch!(pool, default_eltype(pool)) + _set_pending_callsite!(pool, "") return _unsafe_zeros_impl!(pool, default_eltype(pool), dims...) end @@ -465,21 +485,25 @@ See also: [`unsafe_zeros!`](@ref), [`ones!`](@ref), [`unsafe_acquire!`](@ref) """ @inline function unsafe_ones!(pool::AbstractArrayPool, ::Type{T}, dims::Vararg{Int, N}) where {T, N} _record_type_touch!(pool, T) + _set_pending_callsite!(pool, "") return _unsafe_ones_impl!(pool, T, dims...) end @inline function unsafe_ones!(pool::AbstractArrayPool, dims::Vararg{Int, N}) where {N} _record_type_touch!(pool, default_eltype(pool)) + _set_pending_callsite!(pool, "") return _unsafe_ones_impl!(pool, default_eltype(pool), dims...) end @inline function unsafe_ones!(pool::AbstractArrayPool, ::Type{T}, dims::NTuple{N, Int}) where {T, N} _record_type_touch!(pool, T) + _set_pending_callsite!(pool, "") return _unsafe_ones_impl!(pool, T, dims...) end @inline function unsafe_ones!(pool::AbstractArrayPool, dims::NTuple{N, Int}) where {N} _record_type_touch!(pool, default_eltype(pool)) + _set_pending_callsite!(pool, "") return _unsafe_ones_impl!(pool, default_eltype(pool), dims...) end @@ -535,21 +559,25 @@ See also: [`similar!`](@ref), [`unsafe_acquire!`](@ref) """ @inline function unsafe_similar!(pool::AbstractArrayPool, x::AbstractArray) _record_type_touch!(pool, eltype(x)) + _set_pending_callsite!(pool, "") return _unsafe_similar_impl!(pool, x) end @inline function unsafe_similar!(pool::AbstractArrayPool, x::AbstractArray, ::Type{T}) where {T} _record_type_touch!(pool, T) + _set_pending_callsite!(pool, "") return _unsafe_similar_impl!(pool, x, T) end @inline function unsafe_similar!(pool::AbstractArrayPool, x::AbstractArray, dims::Vararg{Int, N}) where {N} _record_type_touch!(pool, eltype(x)) + _set_pending_callsite!(pool, "") return _unsafe_similar_impl!(pool, x, dims...) end @inline function unsafe_similar!(pool::AbstractArrayPool, x::AbstractArray, ::Type{T}, dims::Vararg{Int, N}) where {T, N} _record_type_touch!(pool, T) + _set_pending_callsite!(pool, "") return _unsafe_similar_impl!(pool, x, T, dims...) end diff --git a/src/debug.jl b/src/debug.jl new file mode 100644 index 00000000..9a2761a8 --- /dev/null +++ b/src/debug.jl @@ -0,0 +1,345 @@ +# ============================================================================== +# Debugging & Safety (POOL_DEBUG escape detection) +# ============================================================================== + +""" + POOL_DEBUG + +Legacy flag for escape detection. Superseded by [`POOL_SAFETY_LV`](@ref). + +Setting `POOL_DEBUG[] = true` enables escape detection at `@with_pool` scope exit +(equivalent to `POOL_SAFETY_LV[] >= 2` behavior). Both flags are checked independently. + +For new code, prefer `POOL_SAFETY_LV[] = 2`. + +Default: `false` +""" +const POOL_DEBUG = Ref(false) + +function _validate_pool_return(val, pool::AdaptiveArrayPool) + # 0. Check BitArray / BitVector (bit-packed storage) + if val isa BitArray + _check_bitchunks_overlap(val, pool) + return + end + + # 1. Check SubArray + if val isa SubArray + p = parent(val) + # Use pointer overlap check for ALL Array parents (Vector <: Array) + # This catches both: + # - acquire!() 1D returns: SubArray backed by pool's internal Vector + # - view(unsafe_acquire!()): SubArray backed by unsafe_wrap'd Array + if p isa Array + _check_pointer_overlap(p, pool, val) + elseif p isa BitArray + _check_bitchunks_overlap(p, pool, val) + end + return + end + + # 2. Check ReshapedArray (from acquire! N-D, wraps SubArray of pool Vector) + if val isa Base.ReshapedArray + p = parent(val) + # ReshapedArray wraps SubArray{T,1,Vector{T},...} + if p isa SubArray + pp = parent(p) + if pp isa Array + _check_pointer_overlap(pp, pool, val) + elseif pp isa BitArray + _check_bitchunks_overlap(pp, pool, val) + end + end + return + end + + # 3. Check raw Array (from unsafe_acquire!) + element recursion + return if val isa Array + # Pool vectors always have concrete eltypes — skip overlap check for abstract + if isconcretetype(eltype(val)) + _check_pointer_overlap(val, pool) + end + # Recurse into elements for containers like Vector{SubArray} + if _eltype_may_contain_arrays(eltype(val)) + for x in val + _validate_pool_return(x, pool) + end + end + end +end + +# Eltype guard: skip element iteration for leaf types (perf optimization in debug mode) +_eltype_may_contain_arrays(::Type{<:Number}) = false +_eltype_may_contain_arrays(::Type{<:AbstractString}) = false +_eltype_may_contain_arrays(::Type{Symbol}) = false +_eltype_may_contain_arrays(::Type{Char}) = false +_eltype_may_contain_arrays(::Type) = true + +# Check if array memory overlaps with any pool vector. +# `original_val` is the user-visible value (e.g., SubArray) for error reporting; +# `arr` may be its parent Array used for the actual pointer comparison. +function _check_pointer_overlap(arr::Array, pool::AdaptiveArrayPool, original_val = arr) + arr_ptr = UInt(pointer(arr)) + arr_len = length(arr) * sizeof(eltype(arr)) + arr_end = arr_ptr + arr_len + + return_site = let rs = pool._pending_return_site + isempty(rs) ? nothing : rs + end + + check_overlap = function (tp) + for v in tp.vectors + v isa Array || continue # Skip BitVector (no pointer(); checked via _check_bitchunks_overlap) + v_ptr = UInt(pointer(v)) + v_len = length(v) * sizeof(eltype(v)) + v_end = v_ptr + v_len + if !(arr_end <= v_ptr || v_end <= arr_ptr) + callsite = _lookup_borrow_callsite(pool, v) + _throw_pool_escape_error(original_val, eltype(v), callsite, return_site) + end + end + return + end + + # Check fixed slots + foreach_fixed_slot(pool) do tp + check_overlap(tp) + end + + # Check others + for tp in values(pool.others) + check_overlap(tp) + end + return +end + +""" + PoolRuntimeEscapeError <: Exception + +Thrown at runtime when `_validate_pool_return` detects a pool-backed array +escaping from an `@with_pool` scope (requires `POOL_SAFETY_LV[] >= 2`). + +This is the runtime counterpart of [`PoolEscapeError`](@ref) (compile-time). +""" +struct PoolRuntimeEscapeError <: Exception + val_summary::String + pool_eltype::String + callsite::Union{Nothing, String} # acquire location (LV ≥ 3) + return_site::Union{Nothing, String} # return location (LV ≥ 3) +end + +function Base.showerror(io::IO, e::PoolRuntimeEscapeError) + has_callsite = e.callsite !== nothing + lv_label = has_callsite ? "POOL_SAFETY_LV ≥ 3" : "POOL_SAFETY_LV ≥ 2" + + printstyled(io, "PoolEscapeError"; color = :red, bold = true) + printstyled(io, " (runtime, ", lv_label, ")"; color = :light_black) + println(io) + + println(io) + printstyled(io, " "; color = :normal) + printstyled(io, e.val_summary; color = :red, bold = true) + println(io) + printstyled(io, " ← backed by "; color = :light_black) + printstyled(io, e.pool_eltype; color = :yellow) + printstyled(io, " pool memory, will be reclaimed at scope exit\n"; color = :light_black) + + if has_callsite + # Parse callsite: "file:line" or "file:line\nexpr" + parts = split(e.callsite, '\n'; limit = 2) + location = String(parts[1]) + expr_text = length(parts) >= 2 ? String(parts[2]) : nothing + + # Shorten the file path (shorter of relpath vs ~/…-contracted) + location = _shorten_location(location) + + printstyled(io, " ← acquired at "; color = :light_black) + printstyled(io, location; color = :cyan, bold = true) + println(io) + + if expr_text !== nothing + printstyled(io, " "; color = :normal) + printstyled(io, expr_text; color = :cyan) + println(io) + end + end + + has_return_site = e.return_site !== nothing + if has_return_site + parts = split(e.return_site, '\n'; limit = 2) + location = _shorten_location(String(parts[1])) + expr_text = length(parts) >= 2 ? String(parts[2]) : nothing + + printstyled(io, " ← escapes at "; color = :light_black) + printstyled(io, location; color = :magenta, bold = true) + println(io) + + if expr_text !== nothing + printstyled(io, " "; color = :normal) + printstyled(io, expr_text; color = :magenta) + println(io) + end + end + + println(io) + printstyled(io, " Fix: "; bold = true) + printstyled(io, "Wrap with "; color = :light_black) + printstyled(io, "collect()"; bold = true) + printstyled(io, " to return an owned copy, or compute a scalar result.\n"; color = :light_black) + + return if !has_callsite + println(io) + printstyled(io, " Tip: "; bold = true) + printstyled(io, "set "; color = :light_black) + printstyled(io, "POOL_SAFETY_LV[] = 3"; bold = true) + printstyled(io, " for acquire!() call-site tracking.\n"; color = :light_black) + end +end + +Base.showerror(io::IO, e::PoolRuntimeEscapeError, ::Any; backtrace = true) = showerror(io, e) + +@noinline function _throw_pool_escape_error(val, pool_eltype, callsite::Union{Nothing, String} = nothing, return_site::Union{Nothing, String} = nothing) + throw(PoolRuntimeEscapeError(summary(val), string(pool_eltype), callsite, return_site)) +end + +# Recursive inspection of container types (Tuple, NamedTuple, Pair, Dict, Set). +# These are common wrapper types in Julia through which pool-backed arrays +# can escape undetected when hidden inside return values. +# Note: Array element recursion is handled in the main function via _eltype_may_contain_arrays. + +function _validate_pool_return(val::Tuple, pool::AdaptiveArrayPool) + for x in val + _validate_pool_return(x, pool) + end + return +end + +function _validate_pool_return(val::NamedTuple, pool::AdaptiveArrayPool) + for x in values(val) + _validate_pool_return(x, pool) + end + return +end + +function _validate_pool_return(val::Pair, pool::AdaptiveArrayPool) + _validate_pool_return(val.first, pool) + return _validate_pool_return(val.second, pool) +end + +function _validate_pool_return(val::AbstractDict, pool::AdaptiveArrayPool) + for p in val # each p is a Pair — reuses Pair dispatch + _validate_pool_return(p, pool) + end + return +end + +function _validate_pool_return(val::AbstractSet, pool::AdaptiveArrayPool) + for x in val + _validate_pool_return(x, pool) + end + return +end + +_validate_pool_return(val, ::DisabledPool) = nothing + +# ============================================================================== +# Poisoning: Fill released vectors with sentinel values (POOL_SAFETY_LV >= 2) +# ============================================================================== +# +# Poisons backing vectors with detectable values (NaN, typemax) before +# structural invalidation. This ensures stale references read obviously wrong +# data instead of silently valid old values — especially useful for +# unsafe_acquire! Array wrappers on Julia 1.10 where setfield!(:size) is +# unavailable and structural invalidation can't catch stale access. + +_poison_value(::Type{T}) where {T <: AbstractFloat} = T(NaN) +_poison_value(::Type{T}) where {T <: Integer} = typemax(T) +_poison_value(::Type{Complex{T}}) where {T} = Complex{T}(_poison_value(T), _poison_value(T)) +_poison_value(::Type{T}) where {T} = zero(T) # generic fallback + +_poison_fill!(v::Vector{T}) where {T} = fill!(v, _poison_value(T)) +_poison_fill!(v::BitVector) = fill!(v, true) + +""" + _poison_released_vectors!(tp::AbstractTypedPool, old_n_active) + +Fill released backing vectors (indices `n_active+1:old_n_active`) with sentinel +values. Called from `_invalidate_released_slots!` when `POOL_SAFETY_LV[] >= 2`, +before `resize!` zeroes the lengths. +""" +@noinline function _poison_released_vectors!(tp::AbstractTypedPool, old_n_active::Int) + new_n = tp.n_active + for i in (new_n + 1):old_n_active + _poison_fill!(@inbounds tp.vectors[i]) + end + return nothing +end + +# ============================================================================== +# Path Shortening (for readable callsite display) +# ============================================================================== +# +# Picks the shortest human-readable representation of a file path: +# relative to pwd, ~/…-contracted, or the original absolute path. +# Adapted from Infiltrator.jl (src/breakpoints.jl). + +function _short_path(f::String) + contracted = Base.contractuser(f) + try + rel = relpath(f) + return length(rel) < length(contracted) ? rel : contracted + catch + return contracted + end +end + +# Shorten "file:line" location string using _short_path +function _shorten_location(location::String) + colon_idx = findlast(':', location) + if colon_idx !== nothing + file = location[1:prevind(location, colon_idx)] + line_part = location[colon_idx:end] + return _short_path(file) * line_part + end + return location +end + +# ============================================================================== +# Borrow Registry: Call-site tracking for acquire! (POOL_SAFETY_LV >= 3) +# ============================================================================== +# +# Records where each acquire! call originated (file:line) so escape errors +# can point to the exact source location. The macro sets `_pending_callsite` +# before each acquire call; the _*_impl! functions call _record_borrow_from_pending! +# after claiming a slot. + +""" + _record_borrow_from_pending!(pool, tp) + +Record the pending callsite for the most recently claimed slot in `tp`. +Called from `_acquire_impl!` / `_unsafe_acquire_impl!` when `POOL_SAFETY_LV[] >= 3`. +""" +@noinline function _record_borrow_from_pending!(pool::AdaptiveArrayPool, tp::AbstractTypedPool) + callsite = pool._pending_callsite + isempty(callsite) && return nothing + log = pool._borrow_log + if log === nothing + log = IdDict{Any, String}() + pool._borrow_log = log + end + @inbounds log[tp.vectors[tp.n_active]] = callsite + pool._pending_callsite = "" # Clear so next _set_pending_callsite! can set a fresh value + return nothing +end + +""" + _lookup_borrow_callsite(pool, v) -> Union{Nothing, String} + +Look up the callsite string for a pool backing vector. Returns `nothing` if +no borrow was recorded (LV < 3 or non-macro path without callsite info). +""" +@noinline function _lookup_borrow_callsite(pool::AdaptiveArrayPool, v)::Union{Nothing, String} + log = pool._borrow_log + log === nothing && return nothing + return get(log, v, nothing) +end diff --git a/src/legacy/bitarray.jl b/src/legacy/bitarray.jl index a2bdbfa6..dcf3efb7 100644 --- a/src/legacy/bitarray.jl +++ b/src/legacy/bitarray.jl @@ -238,7 +238,7 @@ end # ============================================================================== # Check if BitArray chunks overlap with the pool's BitTypedPool storage -function _check_bitchunks_overlap(arr::BitArray, pool::AdaptiveArrayPool) +function _check_bitchunks_overlap(arr::BitArray, pool::AdaptiveArrayPool, original_val = arr) arr_chunks = arr.chunks arr_ptr = UInt(pointer(arr_chunks)) arr_len = length(arr_chunks) * sizeof(UInt64) @@ -250,7 +250,7 @@ function _check_bitchunks_overlap(arr::BitArray, pool::AdaptiveArrayPool) v_len = length(v_chunks) * sizeof(UInt64) v_end = v_ptr + v_len if !(arr_end <= v_ptr || v_end <= arr_ptr) - error("Safety Violation: The function returned a BitArray backed by pool memory. This is unsafe as the memory will be reclaimed. Please return a copy (copy) or a scalar.") + _throw_pool_escape_error(original_val, Bit) end end return nothing diff --git a/src/legacy/state.jl b/src/legacy/state.jl index 71848e26..9079d597 100644 --- a/src/legacy/state.jl +++ b/src/legacy/state.jl @@ -236,6 +236,79 @@ Decrements _current_depth once after all types are rewound. end end +# ============================================================================== +# Safety: Structural Invalidation on Rewind (POOL_SAFETY_LV >= 1) +# ============================================================================== +# +# When released, backing vectors are resize!'d to 0 and cached Array/BitArray +# wrappers have their size set to (0,...). This makes stale SubArrays and Arrays +# throw BoundsError on access instead of silently returning corrupted data. +# +# @noinline keeps invalidation code off the inlined hot path of _rewind_typed_pool!. + +# No-op fallback for extension types (e.g. CuTypedPool) +_invalidate_released_slots!(::AbstractTypedPool, ::Int) = nothing + +@noinline function _invalidate_released_slots!(tp::TypedPool{T}, old_n_active::Int) where {T} + new_n = tp.n_active + # Level 2+: poison vectors with NaN/sentinel before structural invalidation. + # Especially useful on legacy (1.10) where unsafe_acquire! Array wrappers + # can't be structurally invalidated (Array is a C struct, no setfield!). + if POOL_SAFETY_LV[] >= 2 + _poison_released_vectors!(tp, old_n_active) + end + # Level 1+: resize backing vectors to length 0 (invalidates SubArrays from acquire!) + # Note: Array wrapper invalidation (setfield! :size) requires Julia 1.11+. + # On legacy (1.10), only SubArray invalidation via resize! is available. + for i in (new_n + 1):old_n_active + @inbounds resize!(tp.vectors[i], 0) + end + # Invalidate view cache: reset view_lengths so get_view! forces cache miss + # instead of returning stale cached SubArrays pointing to empty vectors. + for i in (new_n + 1):min(old_n_active, length(tp.view_lengths)) + @inbounds tp.view_lengths[i] = 0 + end + return nothing +end + +@noinline function _invalidate_released_slots!(tp::BitTypedPool, old_n_active::Int) + new_n = tp.n_active + # Level 2+: poison BitVectors (all bits set to true) + if POOL_SAFETY_LV[] >= 2 + _poison_released_vectors!(tp, old_n_active) + end + # Level 1+: resize backing BitVectors to length 0 + for i in (new_n + 1):old_n_active + @inbounds resize!(tp.vectors[i], 0) + end + # Invalidate N-D BitArray wrappers (N-way cache layout) + # Also zero nd_ptrs to force cache misses on re-acquire: + # After resize!(bv, 0) → resize!(bv, n), the pointer may stay the same + # (capacity preserved), causing a stale cache hit that returns a + # BitArray with len=0/dims=(0,...) instead of creating a fresh one. + ways = CACHE_WAYS + for i in (new_n + 1):old_n_active + base = (i - 1) * ways + for w in 1:ways + idx = base + w + idx > length(tp.nd_arrays) && break + @inbounds tp.nd_ptrs[idx] = UInt(0) # force cache miss + ba = @inbounds tp.nd_arrays[idx] + ba === nothing && continue + dims = @inbounds tp.nd_dims[idx] + dims === nothing && continue + N = length(dims::Tuple) + setfield!(ba::BitArray, :len, 0) + setfield!(ba::BitArray, :dims, ntuple(_ -> 0, N)) + end + end + return nothing +end + +# ============================================================================== +# Internal: Rewind with Orphan Cleanup +# ============================================================================== + # Internal helper for rewind with orphan cleanup (works for any AbstractTypedPool) # Uses 1-based sentinel pattern: no isempty checks needed (sentinel [0] guarantees non-empty) @inline function _rewind_typed_pool!(tp::AbstractTypedPool, current_depth::Int) @@ -249,6 +322,11 @@ end pop!(tp._checkpoint_n_active) end + # Capture n_active before restore (for safety invalidation) + @static if STATIC_POOL_CHECKS + _old_n_active = tp.n_active + end + # 2. Normal Rewind Logic (Sentinel Pattern) # Now the stack top is guaranteed to be at depth <= current depth. if @inbounds tp._checkpoint_depths[end] == current_depth @@ -262,6 +340,14 @@ end # - If sentinel (_checkpoint_n_active=[0]), restores to n_active=0 tp.n_active = @inbounds tp._checkpoint_n_active[end] end + + # 3. Safety: invalidate released slots (Level 1+) + @static if STATIC_POOL_CHECKS + if POOL_SAFETY_LV[] >= 1 && _old_n_active > tp.n_active + _invalidate_released_slots!(tp, _old_n_active) + end + end + return nothing end @@ -459,12 +545,20 @@ Reset state without clearing allocated storage. Sets `n_active = 0` and restores checkpoint stacks to sentinel state. """ function reset!(tp::AbstractTypedPool) + @static if STATIC_POOL_CHECKS + _old_n_active = tp.n_active + end tp.n_active = 0 # Restore sentinel values (1-based sentinel pattern) empty!(tp._checkpoint_n_active) push!(tp._checkpoint_n_active, 0) # Sentinel: n_active=0 at depth=0 empty!(tp._checkpoint_depths) push!(tp._checkpoint_depths, 0) # Sentinel: depth=0 = no checkpoint + @static if STATIC_POOL_CHECKS + if POOL_SAFETY_LV[] >= 1 && _old_n_active > 0 + _invalidate_released_slots!(tp, _old_n_active) + end + end return tp end @@ -526,6 +620,11 @@ function reset!(pool::AdaptiveArrayPool) empty!(pool._touched_has_others) push!(pool._touched_has_others, false) # Sentinel: no others + # Clear borrow registry and return-site tracking + pool._pending_callsite = "" + pool._pending_return_site = "" + pool._borrow_log = nothing + return pool end diff --git a/src/legacy/types.jl b/src/legacy/types.jl index 23adf285..404c306b 100644 --- a/src/legacy/types.jl +++ b/src/legacy/types.jl @@ -61,6 +61,28 @@ function set_cache_ways!(n::Int) return n end +# ============================================================================== +# Safety Configuration (2-Tier Toggle) +# ============================================================================== +# Tier 1: STATIC_POOL_CHECKS (compile-time const) +# Tier 2: POOL_SAFETY_LV (runtime Ref{Int}, levels 0/1/2) + +const STATIC_POOL_CHECKS = @load_preference("pool_checks", true)::Bool + +""" + POOL_SAFETY_LV + +Runtime safety level for pool operations. Only effective when `STATIC_POOL_CHECKS` is `true`. + +- `0`: Off — no safety checks (Ref read only, ~1ns) +- `1`: Guard — structural invalidation on rewind (resize + setfield!, ~1ns/slot) +- `2`: Full — guard + escape detection on scope exit + poisoning +- `3`: Debug — full + borrow registry (acquire call-site tracking in error messages) + +Default: `1` (guard mode) +""" +const POOL_SAFETY_LV = Ref(1) + # ============================================================================== # Abstract Type Hierarchy (for extensibility) # ============================================================================== @@ -376,6 +398,11 @@ mutable struct AdaptiveArrayPool <: AbstractArrayPool _current_depth::Int # Current scope depth (1 = global scope) _touched_type_masks::Vector{UInt16} # Per-depth: which fixed slots were touched + mode flags _touched_has_others::Vector{Bool} # Per-depth: any non-fixed-slot type touched? + + # Borrow registry (POOL_SAFETY_LV >= 3 only, modern path only) + _pending_callsite::String # "" = no pending + _pending_return_site::String # "" = no pending + _borrow_log::Union{Nothing, IdDict{Any, String}} # vector_obj => callsite string end function AdaptiveArrayPool() @@ -391,7 +418,10 @@ function AdaptiveArrayPool() IdDict{DataType, Any}(), 1, # _current_depth: 1 = global scope (sentinel) [UInt16(0)], # _touched_type_masks: sentinel (no bits set) - [false] # _touched_has_others: sentinel (no others) + [false], # _touched_has_others: sentinel (no others) + "", # _pending_callsite: no pending + "", # _pending_return_site: no pending + nothing # _borrow_log: lazily created at LV >= 3 ) end @@ -445,3 +475,44 @@ Apply `f` to each fixed slot TypedPool. Zero allocation via compile-time unrolli nothing end end + +# ============================================================================== +# Safety Tag Dispatch (compile-time, zero-cost when STATIC_POOL_CHECKS=false) +# ============================================================================== +# +# Instead of `@static if STATIC_POOL_CHECKS` at every call site, we dispatch on +# a singleton tag. The compiler resolves `const _POOL_CHECK_TAG` at compile time, +# monomorphizes the call, and dead-code-eliminates the unused path entirely. + +"""Singleton tag: pool safety checks enabled.""" +struct _CheckOn end + +"""Singleton tag: pool safety checks disabled (all safety helpers become no-ops).""" +struct _CheckOff end + +"""Compile-time tag selected by `STATIC_POOL_CHECKS`.""" +const _POOL_CHECK_TAG = STATIC_POOL_CHECKS ? _CheckOn() : _CheckOff() + +# --- Active implementations (_CheckOn) --- + +@inline function _set_pending_callsite!(::_CheckOn, pool::AbstractArrayPool, msg::String) + POOL_SAFETY_LV[] >= 3 && isempty(pool._pending_callsite) && (pool._pending_callsite = msg) + return nothing +end + +@inline function _maybe_record_borrow!(::_CheckOn, pool::AbstractArrayPool, tp::AbstractTypedPool) + POOL_SAFETY_LV[] >= 3 && _record_borrow_from_pending!(pool, tp) + return nothing +end + +# --- No-op implementations (_CheckOff) --- + +@inline _set_pending_callsite!(::_CheckOff, ::AbstractArrayPool, ::String) = nothing +@inline _maybe_record_borrow!(::_CheckOff, ::AbstractArrayPool, ::AbstractTypedPool) = nothing + +# --- Convenience wrappers (auto-dispatch via const tag) --- + +@inline _set_pending_callsite!(pool::AbstractArrayPool, msg::String) = + _set_pending_callsite!(_POOL_CHECK_TAG, pool, msg) +@inline _maybe_record_borrow!(pool::AbstractArrayPool, tp::AbstractTypedPool) = + _maybe_record_borrow!(_POOL_CHECK_TAG, pool, tp) diff --git a/src/macros.jl b/src/macros.jl index 0946a28f..c6866859 100644 --- a/src/macros.jl +++ b/src/macros.jl @@ -2,6 +2,225 @@ # Macros for AdaptiveArrayPools # ============================================================================== +# ============================================================================== +# PoolEscapeError — Compile-time escape detection error +# ============================================================================== + +"""Per-return-point escape detail: which expression, at which line, leaks which vars.""" +struct EscapePoint + expr::Any + line::Union{Int, Nothing} + vars::Vector{Symbol} +end + +"""Per-variable declaration site: where an escaping variable was assigned.""" +struct DeclarationSite + var::Symbol + expr::Any + line::Union{Int, Nothing} + file::Union{Symbol, Nothing} +end + +""" + PoolEscapeError <: Exception + +Thrown at macro expansion time when pool-backed variables are detected in +return position of `@with_pool` / `@maybe_with_pool` blocks. + +This is a compile-time check with zero runtime cost. +""" +struct PoolEscapeError <: Exception + vars::Vector{Symbol} + file::Union{String, Nothing} + line::Union{Int, Nothing} + points::Vector{EscapePoint} + var_info::Dict{Symbol, Tuple{Symbol, Vector{Symbol}}} # var => (kind, source_vars) + declarations::Vector{DeclarationSite} +end + +PoolEscapeError(vars, file, line, points) = + PoolEscapeError(vars, file, line, points, Dict{Symbol, Tuple{Symbol, Vector{Symbol}}}(), DeclarationSite[]) + +PoolEscapeError(vars, file, line, points, var_info) = + PoolEscapeError(vars, file, line, points, var_info, DeclarationSite[]) + +"""Render an expression with escaped variable names highlighted in red. +Handles return, tuple, NamedTuple, array literal; falls back to print for others.""" +function _render_return_expr(io::IO, expr, escaped::Set{Symbol}) + return if expr isa Symbol + if expr in escaped + printstyled(io, string(expr); color = :red, bold = true) + else + print(io, expr) + end + elseif expr isa Expr + if expr.head == :return && !isempty(expr.args) + printstyled(io, "return "; color = :light_black) + _render_return_expr(io, expr.args[1], escaped) + elseif expr.head == :tuple + print(io, "(") + for (i, arg) in enumerate(expr.args) + i > 1 && print(io, ", ") + _render_return_expr(io, arg, escaped) + end + print(io, ")") + elseif expr.head == :(=) && length(expr.args) >= 2 + # NamedTuple key = value — only highlight value + print(io, expr.args[1], " = ") + _render_return_expr(io, expr.args[2], escaped) + elseif expr.head == :vect + print(io, "[") + for (i, arg) in enumerate(expr.args) + i > 1 && print(io, ", ") + _render_return_expr(io, arg, escaped) + end + print(io, "]") + else + print(io, expr) + end + else + print(io, expr) + end +end + +function Base.showerror(io::IO, e::PoolEscapeError) + # Header + printstyled(io, "PoolEscapeError"; color = :red, bold = true) + printstyled(io, " (compile-time)"; color = :light_black) + println(io) + + # Descriptive message + println(io) + n = length(e.vars) + if n == 1 + printstyled(io, " The following variable escapes the @with_pool scope:\n"; color = :light_black) + else + printstyled(io, " The following ", n, " variables escape the @with_pool scope:\n"; color = :light_black) + end + + # Escaped variables — one per line with classification + println(io) + for v in e.vars + printstyled(io, " "; color = :normal) + printstyled(io, string(v); color = :red, bold = true) + kind, sources = get(e.var_info, v, (:pool_buffer, Symbol[])) + if kind === :container + src_str = join(string.(sources), ", ") + printstyled(io, " ← wraps pool variable"; color = :light_black) + length(sources) > 1 && printstyled(io, "s"; color = :light_black) + printstyled(io, " (", src_str, ")\n"; color = :light_black) + elseif kind === :alias + printstyled(io, " ← alias of pool variable (", string(sources[1]), ")\n"; color = :light_black) + elseif kind === :pool_array + printstyled(io, " ← pool-acquired array\n"; color = :light_black) + elseif kind === :pool_bitarray + printstyled(io, " ← pool-acquired BitArray\n"; color = :light_black) + elseif kind === :pool_view + printstyled(io, " ← pool-acquired view\n"; color = :light_black) + else + printstyled(io, " ← pool-backed temporary\n"; color = :light_black) + end + end + + # Declaration sites — where each escaping variable was assigned + if !isempty(e.declarations) + println(io) + printstyled(io, " Declarations:\n"; bold = true) + for (idx, decl) in enumerate(e.declarations) + printstyled(io, " [", idx, "] "; color = :light_black) + printstyled(io, string(decl.expr); color = :cyan) + # Fall back to macro source file when body LineNumberNode has :none (REPL/eval) + decl_file = (decl.file !== nothing && decl.file !== :none) ? decl.file : e.file + loc = _format_location_str(decl_file, decl.line) + if loc !== nothing + printstyled(io, " ["; color = :cyan, bold = true) + printstyled(io, loc; color = :cyan, bold = true) + printstyled(io, "] "; color = :cyan, bold = true) + end + println(io) + end + end + + # Escaping return points with highlighted expressions + if !isempty(e.points) + println(io) + label = length(e.points) == 1 ? " Escaping return:" : " Escaping returns:" + printstyled(io, label, "\n"; bold = true) + escaped_set = Set{Symbol}(e.vars) + for (idx, pt) in enumerate(e.points) + printstyled(io, " [", idx, "] "; color = :light_black) + _render_return_expr(io, pt.expr, escaped_set) + loc = _format_point_location(e.file, pt.line) + if loc !== nothing + printstyled(io, " ["; color = :magenta, bold = true) + printstyled(io, loc; color = :magenta, bold = true) + printstyled(io, "] "; color = :magenta, bold = true) + end + println(io) + end + end + + # Suggestion 1: fix — collect targets are direct pool vars + container sources + println(io) + collect_targets = Symbol[] + has_containers = false + for v in e.vars + vkind, vsources = get(e.var_info, v, (:pool_buffer, Symbol[])) + if vkind === :container + append!(collect_targets, vsources) + has_containers = true + else + push!(collect_targets, v) + end + end + unique!(collect_targets) + sort!(collect_targets) + collects_str = join(["collect($v)" for v in collect_targets], ", ") + printstyled(io, " Fix: "; bold = true) + printstyled(io, "Use "; color = :light_black) + printstyled(io, collects_str; bold = true) + printstyled(io, " to return owned copies.\n"; color = :light_black) + if has_containers + printstyled(io, " Copy pool variables before wrapping in containers.\n"; color = :light_black) + end + printstyled(io, " Or use a regular Julia array (zeros()/Array{T}()) if it must outlive the pool scope.\n"; color = :light_black) + + # Suggestion 2: false positive → file issue + println(io) + printstyled(io, " False positive?\n"; bold = true) + printstyled(io, " Please file an issue at "; color = :light_black) + printstyled(io, "https://github.com/ProjectTorreyPines/AdaptiveArrayPools.jl/issues"; bold = true) + return printstyled(io, "\n with a minimal reproducer so we can improve the escape detector.\n"; color = :light_black) +end + +# Location formatting helpers (uses _short_path from debug.jl) +function _format_location_str(file, line) + file_str = file !== nothing ? string(file) : nothing + # Skip "none" — Julia's placeholder for REPL/eval contexts + if file_str !== nothing && file_str != "none" + short = _short_path(file_str) + return line !== nothing ? short * ":" * string(line) : short + elseif line !== nothing + return "line " * string(line) + end + return nothing +end + +function _format_point_location(file::Union{String, Nothing}, line::Union{Int, Nothing}) + # Skip "none" — Julia's placeholder for REPL/eval contexts + if file !== nothing && file != "none" + short = _short_path(file) + return line !== nothing ? short * ":" * string(line) : short + elseif line !== nothing + return "line " * string(line) + end + return nothing +end + +# Suppress stacktrace — LoadError delegates to this via showerror(io, ex.error, bt) +Base.showerror(io::IO, e::PoolEscapeError, ::Any; backtrace = true) = showerror(io, e) + + # ============================================================================== # Backend Dispatch (for extensibility) # ============================================================================== @@ -328,6 +547,10 @@ function _generate_pool_code(pool_name, expr, force_enable; source::Union{LineNu return _generate_function_pool_code(pool_name, expr, force_enable, false; source) end + # Compile-time escape detection (zero runtime cost) + _esc = _check_compile_time_escape(expr, pool_name, source) + _esc !== nothing && return :(throw($_esc)) + # Block logic # Extract types from acquire! calls for optimized checkpoint/rewind # Only extract types for calls to the target pool (pool_name) @@ -342,6 +565,12 @@ function _generate_pool_code(pool_name, expr, force_enable; source::Union{LineNu # For dynamic path: keep acquire! untransformed so _record_type_touch! is called transformed_expr = use_typed ? _transform_acquire_calls(expr, pool_name) : expr + # Inject borrow callsite recording (LV≥3 at runtime; gated by STATIC_POOL_CHECKS at compile time) + if STATIC_POOL_CHECKS + transformed_expr = _inject_pending_callsite(transformed_expr, pool_name, expr) + transformed_expr = _transform_return_stmts(transformed_expr, pool_name) + end + if use_typed checkpoint_call = _generate_typed_checkpoint_call(esc(pool_name), static_types) else @@ -360,7 +589,7 @@ function _generate_pool_code(pool_name, expr, force_enable; source::Union{LineNu $checkpoint_call try local _result = $(esc(transformed_expr)) - if $POOL_DEBUG[] + if ($POOL_SAFETY_LV[] >= 2 || $POOL_DEBUG[]) $_validate_pool_return(_result, $(esc(pool_name))) end _result @@ -376,7 +605,7 @@ function _generate_pool_code(pool_name, expr, force_enable; source::Union{LineNu $checkpoint_call try local _result = $(esc(transformed_expr)) - if $POOL_DEBUG[] + if ($POOL_SAFETY_LV[] >= 2 || $POOL_DEBUG[]) $_validate_pool_return(_result, $(esc(pool_name))) end _result @@ -425,6 +654,10 @@ function _generate_pool_code_with_backend(backend::Symbol, pool_name, expr, forc return _generate_function_pool_code_with_backend(backend, pool_name, expr, false, false; source) end + # Compile-time escape detection (zero runtime cost) + _esc = _check_compile_time_escape(expr, pool_name, source) + _esc !== nothing && return :(throw($_esc)) + # Block logic with runtime check all_types = _extract_acquire_types(expr, pool_name) local_vars = _extract_local_assignments(expr) @@ -433,6 +666,10 @@ function _generate_pool_code_with_backend(backend::Symbol, pool_name, expr, forc # For typed path: transform acquire! → _acquire_impl! (bypasses type touch recording) # For dynamic path: keep acquire! untransformed so _record_type_touch! is called transformed_expr = use_typed ? _transform_acquire_calls(expr, pool_name) : expr + if STATIC_POOL_CHECKS + transformed_expr = _inject_pending_callsite(transformed_expr, pool_name, expr) + transformed_expr = _transform_return_stmts(transformed_expr, pool_name) + end pool_getter = :($_get_pool_for_backend($(Val{backend}()))) if use_typed @@ -449,7 +686,7 @@ function _generate_pool_code_with_backend(backend::Symbol, pool_name, expr, forc $checkpoint_call try local _result = $(esc(transformed_expr)) - if $POOL_DEBUG[] + if ($POOL_SAFETY_LV[] >= 2 || $POOL_DEBUG[]) $_validate_pool_return(_result, $(esc(pool_name))) end _result @@ -468,6 +705,10 @@ function _generate_pool_code_with_backend(backend::Symbol, pool_name, expr, forc return _generate_function_pool_code_with_backend(backend, pool_name, expr, true, false; source) end + # Compile-time escape detection (zero runtime cost) + _esc = _check_compile_time_escape(expr, pool_name, source) + _esc !== nothing && return :(throw($_esc)) + # Block logic: Extract types from acquire! calls for optimized checkpoint/rewind all_types = _extract_acquire_types(expr, pool_name) local_vars = _extract_local_assignments(expr) @@ -479,6 +720,10 @@ function _generate_pool_code_with_backend(backend::Symbol, pool_name, expr, forc # For typed path: transform acquire! → _acquire_impl! (bypasses type touch recording) # For dynamic path: keep acquire! untransformed so _record_type_touch! is called transformed_expr = use_typed ? _transform_acquire_calls(expr, pool_name) : expr + if STATIC_POOL_CHECKS + transformed_expr = _inject_pending_callsite(transformed_expr, pool_name, expr) + transformed_expr = _transform_return_stmts(transformed_expr, pool_name) + end # Use Val{backend}() for compile-time dispatch - fully inlinable pool_getter = :($_get_pool_for_backend($(Val{backend}()))) @@ -500,7 +745,7 @@ function _generate_pool_code_with_backend(backend::Symbol, pool_name, expr, forc $checkpoint_call try local _result = $(esc(transformed_expr)) - if $POOL_DEBUG[] + if ($POOL_SAFETY_LV[] >= 2 || $POOL_DEBUG[]) $_validate_pool_return(_result, $(esc(pool_name))) end _result @@ -536,6 +781,10 @@ function _generate_function_pool_code_with_backend(backend::Symbol, pool_name, f return Expr(def_head, esc(call_expr), new_body) end + # Compile-time escape detection (zero runtime cost) + _esc = _check_compile_time_escape(body, pool_name, source) + _esc !== nothing && return :(throw($_esc)) + # Analyze body for types all_types = _extract_acquire_types(body, pool_name) local_vars = _extract_local_assignments(body) @@ -545,6 +794,10 @@ function _generate_function_pool_code_with_backend(backend::Symbol, pool_name, f # For typed path: transform acquire! → _acquire_impl! (bypasses type touch recording) # For dynamic path: keep acquire! untransformed so _record_type_touch! is called transformed_body = use_typed ? _transform_acquire_calls(body, pool_name) : body + if STATIC_POOL_CHECKS + transformed_body = _inject_pending_callsite(transformed_body, pool_name, body) + transformed_body = _transform_return_stmts(transformed_body, pool_name) + end # Use Val{backend}() for compile-time dispatch pool_getter = :($_get_pool_for_backend($(Val{backend}()))) @@ -562,7 +815,13 @@ function _generate_function_pool_code_with_backend(backend::Symbol, pool_name, f local $(esc(pool_name)) = $pool_getter $checkpoint_call try - $(esc(transformed_body)) + local _result = begin + $(esc(transformed_body)) + end + if ($POOL_SAFETY_LV[] >= 2 || $POOL_DEBUG[]) + $_validate_pool_return(_result, $(esc(pool_name))) + end + _result finally $rewind_call end @@ -574,7 +833,13 @@ function _generate_function_pool_code_with_backend(backend::Symbol, pool_name, f local $(esc(pool_name)) = $pool_getter $checkpoint_call try - $(esc(transformed_body)) + local _result = begin + $(esc(transformed_body)) + end + if ($POOL_SAFETY_LV[] >= 2 || $POOL_DEBUG[]) + $_validate_pool_return(_result, $(esc(pool_name))) + end + _result finally $rewind_call end @@ -607,6 +872,10 @@ function _generate_function_pool_code(pool_name, func_def, force_enable, disable return Expr(def_head, esc(call_expr), new_body) end + # Compile-time escape detection (zero runtime cost) + _esc = _check_compile_time_escape(body, pool_name, source) + _esc !== nothing && return :(throw($_esc)) + # Analyze body for types all_types = _extract_acquire_types(body, pool_name) local_vars = _extract_local_assignments(body) @@ -616,6 +885,10 @@ function _generate_function_pool_code(pool_name, func_def, force_enable, disable # For typed path: transform acquire! → _acquire_impl! (bypasses type touch recording) # For dynamic path: keep acquire! untransformed so _record_type_touch! is called transformed_body = use_typed ? _transform_acquire_calls(body, pool_name) : body + if STATIC_POOL_CHECKS + transformed_body = _inject_pending_callsite(transformed_body, pool_name, body) + transformed_body = _transform_return_stmts(transformed_body, pool_name) + end if use_typed checkpoint_call = _generate_typed_checkpoint_call(esc(pool_name), static_types) @@ -634,7 +907,13 @@ function _generate_function_pool_code(pool_name, func_def, force_enable, disable local $(esc(pool_name)) = get_task_local_pool() $checkpoint_call try - $(esc(transformed_body)) + local _result = begin + $(esc(transformed_body)) + end + if ($POOL_SAFETY_LV[] >= 2 || $POOL_DEBUG[]) + $_validate_pool_return(_result, $(esc(pool_name))) + end + _result finally $rewind_call end @@ -646,7 +925,13 @@ function _generate_function_pool_code(pool_name, func_def, force_enable, disable local $(esc(pool_name)) = get_task_local_pool() $checkpoint_call try - $(esc(transformed_body)) + local _result = begin + $(esc(transformed_body)) + end + if ($POOL_SAFETY_LV[] >= 2 || $POOL_DEBUG[]) + $_validate_pool_return(_result, $(esc(pool_name))) + end + _result finally $rewind_call end @@ -1115,3 +1400,796 @@ function _transform_acquire_calls(expr, pool_name) end return expr end + +# ============================================================================== +# Internal: Borrow Callsite Injection (POOL_SAFETY_LV >= 3) +# ============================================================================== +# +# Second-pass AST transformation that inserts `pool._pending_callsite = "file:line"` +# before each statement containing an acquire call. This enables borrow registry +# error messages to show WHERE the problematic acquire originated. +# +# Works with both typed path (_*_impl! GlobalRefs) and dynamic path (original +# acquire!/zeros!/etc. calls). The injection is gated behind STATIC_POOL_CHECKS +# at macro expansion time, and POOL_SAFETY_LV[] >= 3 at runtime. + +const _POOL_SAFETY_LV_REF = GlobalRef(@__MODULE__, :POOL_SAFETY_LV) + +"""Set of all transformed `_*_impl!` function names (GlobalRef targets).""" +const _IMPL_FUNC_NAMES = Set{Symbol}( + [ + :_acquire_impl!, :_unsafe_acquire_impl!, + :_zeros_impl!, :_ones_impl!, :_trues_impl!, :_falses_impl!, + :_similar_impl!, :_reshape_impl!, + :_unsafe_zeros_impl!, :_unsafe_ones_impl!, :_unsafe_similar_impl!, + ] +) + +""" + _contains_acquire_call(expr, pool_name) -> Bool + +Detect if `expr` (or any sub-expression) contains a pool acquire call. +Matches both transformed (`GlobalRef`-based `_*_impl!`) and original +(`acquire!`, `zeros!`, etc.) call forms. +""" +function _contains_acquire_call(expr, pool_name) + expr isa Expr || return false + if expr.head == :call && length(expr.args) >= 2 + fn = expr.args[1] + # Transformed _*_impl! calls (GlobalRef from typed path) + if fn isa GlobalRef && fn.name in _IMPL_FUNC_NAMES + return true + end + # Original acquire calls (dynamic path, or pre-transform) + if _is_acquire_call(expr, pool_name) + return true + end + end + return any(arg -> _contains_acquire_call(arg, pool_name), expr.args) +end + +""" + _find_acquire_call_expr(expr, pool_name) -> Union{Expr, Nothing} + +Find the first acquire call expression in `expr` targeting `pool_name`. +Returns the original call Expr (e.g., `:(zeros!(pool, Float64, 10))`) or `nothing`. +Used to capture the user's source expression for debug display. +""" +function _find_acquire_call_expr(expr, pool_name) + expr isa Expr || return nothing + if _is_acquire_call(expr, pool_name) + return expr + end + for arg in expr.args + result = _find_acquire_call_expr(arg, pool_name) + result !== nothing && return result + end + return nothing +end + +""" + _inject_pending_callsite(expr, pool_name, original_expr=expr) -> Expr + +Walk block-level statements, track `LineNumberNode`s, and insert +`POOL_SAFETY_LV[] >= 3 && (pool._pending_callsite = "file:line\\nexpr")` +before each statement containing a pool acquire call. + +When `original_expr` differs from `expr` (i.e., after `_transform_acquire_calls`), +the original untransformed AST is used to extract the user's source expression +(e.g., `zeros!(pool, Float64, 10)` instead of `_zeros_impl!(pool, Float64, 10)`). + +Only processes `:block` expressions. Non-block expressions are recursed +into to find nested blocks. +""" +function _inject_pending_callsite(expr, pool_name, original_expr = expr) + expr isa Expr || return expr + if expr.head == :block + new_args = Any[] + current_lnn = nothing + orig_args = (original_expr isa Expr && original_expr.head == :block) ? original_expr.args : nothing + for (i, arg) in enumerate(expr.args) + if arg isa LineNumberNode + current_lnn = arg + push!(new_args, arg) + else + orig_arg = (orig_args !== nothing && i <= length(orig_args)) ? orig_args[i] : arg + processed = _inject_pending_callsite(arg, pool_name, orig_arg) + if current_lnn !== nothing && _contains_acquire_call(processed, pool_name) + # Use the full original statement for debug display + expr_text = string(orig_arg) + callsite_str = isempty(expr_text) ? + "$(current_lnn.file):$(current_lnn.line)" : + "$(current_lnn.file):$(current_lnn.line)\n$(expr_text)" + inject = Expr( + :&&, + Expr(:call, :>=, Expr(:ref, _POOL_SAFETY_LV_REF), 3), + Expr( + :(=), + Expr(:., pool_name, QuoteNode(:_pending_callsite)), + callsite_str + ) + ) + push!(new_args, inject) + end + push!(new_args, processed) + end + end + return Expr(:block, new_args...) + else + orig_expr_args = (original_expr isa Expr) ? original_expr.args : nothing + new_args = Any[] + for (i, arg) in enumerate(expr.args) + orig_arg = (orig_expr_args !== nothing && i <= length(orig_expr_args)) ? orig_expr_args[i] : arg + push!(new_args, _inject_pending_callsite(arg, pool_name, orig_arg)) + end + return Expr(expr.head, new_args...) + end +end + +# ============================================================================== +# Internal: Return Statement Validation (POOL_SAFETY_LV >= 2) +# ============================================================================== +# +# Transforms `return expr` → `begin local _ret = expr; validate(_ret); return _ret end` +# so that explicit `return` statements in @with_pool function bodies are validated +# before exiting. Without this, `return` bypasses the post-body _validate_pool_return +# check because it exits the function before that line is reached. +# +# Stops recursion at :function and :-> boundaries (nested function return statements +# belong to the inner function, not the @with_pool scope). + +const _POOL_DEBUG_REF = GlobalRef(@__MODULE__, :POOL_DEBUG) +const _VALIDATE_POOL_RETURN_REF = GlobalRef(@__MODULE__, :_validate_pool_return) + +""" + _transform_return_stmts(expr, pool_name) -> Expr + +Walk AST and wrap explicit `return value` statements with escape validation. +Generates: `local _ret = value; if (LV≥2 || DEBUG) validate(_ret, pool); end; return _ret` + +Does NOT recurse into nested `:function` or `:->` expressions (inner functions +have their own `return` semantics). +""" +function _transform_return_stmts(expr, pool_name, current_lnn = nothing) + expr isa Expr || return expr + + # Don't recurse into nested function definitions (return belongs to inner function) + if expr.head in (:function, :->) + return expr + end + + if expr.head == :return && length(expr.args) >= 1 + value_expr = expr.args[1] + # Bare return (return nothing) — skip validation + if value_expr === nothing + return expr + end + # Recurse into the value expression first (may contain nested returns in ternary etc.) + value_expr = _transform_return_stmts(value_expr, pool_name, current_lnn) + retvar = gensym(:_pool_ret) + + # Build return-site string for LV ≥ 3 display (e.g. "file:line\nreturn v") + return_site_str = if current_lnn !== nothing + "$(current_lnn.file):$(current_lnn.line)\n$(string(expr))" + else + "" + end + + # Conditionally set _pending_return_site before validation + validate_expr = if !isempty(return_site_str) + Expr( + :block, + Expr( + :&&, + Expr(:call, :>=, Expr(:ref, _POOL_SAFETY_LV_REF), 3), + Expr( + :(=), + Expr(:., pool_name, QuoteNode(:_pending_return_site)), + return_site_str + ) + ), + Expr(:call, _VALIDATE_POOL_RETURN_REF, retvar, pool_name) + ) + else + Expr(:call, _VALIDATE_POOL_RETURN_REF, retvar, pool_name) + end + + return Expr( + :block, + Expr(:local, Expr(:(=), retvar, value_expr)), + Expr( + :if, + Expr( + :||, + Expr(:call, :>=, Expr(:ref, _POOL_SAFETY_LV_REF), 2), + Expr(:ref, _POOL_DEBUG_REF) + ), + validate_expr + ), + Expr(:return, retvar) + ) + end + + # For blocks, track LineNumberNodes + if expr.head == :block + new_args = Any[] + lnn = current_lnn + for arg in expr.args + if arg isa LineNumberNode + lnn = arg + push!(new_args, arg) + else + push!(new_args, _transform_return_stmts(arg, pool_name, lnn)) + end + end + return Expr(:block, new_args...) + end + + # Other expressions: recurse with current_lnn + new_args = Any[_transform_return_stmts(arg, pool_name, current_lnn) for arg in expr.args] + return Expr(expr.head, new_args...) +end + +# ============================================================================== +# Internal: Compile-Time Escape Detection +# ============================================================================== +# +# Detects common pool escape patterns at macro expansion time (zero runtime cost). +# - Error: bare acquired variable as last expression (100% escape) +# - Warning: acquired variable inside tuple/array literal (likely escape) +# +# This catches the most common beginner mistake — returning a pool-backed array +# from @with_pool — before the code even runs. + +""" + _ALL_ACQUIRE_NAMES + +Set of all function names that return pool-backed arrays. +Used by `_extract_acquired_vars` to identify assignments like `v = acquire!(pool, ...)`. +""" +const _ALL_ACQUIRE_NAMES = Set{Symbol}( + [ + :acquire!, :unsafe_acquire!, :acquire_view!, :acquire_array!, + :zeros!, :ones!, :similar!, :reshape!, + :unsafe_zeros!, :unsafe_ones!, :unsafe_similar!, + :trues!, :falses!, + ] +) + +"""Function names that return views (SubArray) from pool memory.""" +const _VIEW_ACQUIRE_NAMES = Set{Symbol}( + [ + :acquire!, :acquire_view!, + :zeros!, :ones!, :similar!, + :unsafe_zeros!, :unsafe_ones!, :unsafe_similar!, + :reshape!, + ] +) + +"""Function names that return raw Arrays backed by pool memory (unsafe_wrap).""" +const _ARRAY_ACQUIRE_NAMES = Set{Symbol}( + [ + :unsafe_acquire!, :acquire_array!, + ] +) + +"""Function names that return BitArrays from pool memory.""" +const _BITARRAY_ACQUIRE_NAMES = Set{Symbol}( + [ + :trues!, :falses!, + ] +) + +""" + _is_acquire_call(expr, target_pool) -> Bool + +Check if an expression is a call to any pool acquire/convenience function +targeting `target_pool`. +""" +function _is_acquire_call(expr, target_pool) + if !(expr isa Expr && expr.head == :call && length(expr.args) >= 2) + return false + end + fn = expr.args[1] + pool_arg = expr.args[2] + pool_arg == target_pool || return false + + # Direct name + if fn isa Symbol + return fn in _ALL_ACQUIRE_NAMES + end + # Qualified name: Module.acquire! + if fn isa Expr && fn.head == :. && length(fn.args) >= 2 + qn = fn.args[end] + if qn isa QuoteNode && qn.value isa Symbol + return qn.value in _ALL_ACQUIRE_NAMES + end + end + return false +end + +""" + _acquire_call_kind(expr, target_pool) -> Union{Symbol, Nothing} + +Return the classification of an acquire call: `:pool_view`, `:pool_array`, `:pool_bitarray`, +or `nothing` if not an acquire call. +""" +function _acquire_call_kind(expr, target_pool) + if !(expr isa Expr && expr.head == :call && length(expr.args) >= 2) + return nothing + end + fn = expr.args[1] + pool_arg = expr.args[2] + pool_arg == target_pool || return nothing + + fname = nothing + if fn isa Symbol + fname = fn + elseif fn isa Expr && fn.head == :. && length(fn.args) >= 2 + qn = fn.args[end] + if qn isa QuoteNode && qn.value isa Symbol + fname = qn.value + end + end + fname === nothing && return nothing + + fname in _VIEW_ACQUIRE_NAMES && return :pool_view + fname in _ARRAY_ACQUIRE_NAMES && return :pool_array + fname in _BITARRAY_ACQUIRE_NAMES && return :pool_bitarray + return nothing +end + +""" + _extract_acquired_vars(expr, target_pool) -> Set{Symbol} + +Walk AST to find variable names assigned from acquire/convenience calls. +Returns the set of symbols that hold pool-backed arrays. + +Only top-level assignments in a block are tracked (not inside branches). +Handles both simple assignment (`v = acquire!(...)`) and tuple destructuring +(`(v, w) = (acquire!(...), expr)`). +""" +function _extract_acquired_vars(expr, target_pool, vars = Set{Symbol}()) + if expr isa Expr + if expr.head == :block + # Walk top-level statements only (for flat reassignment tracking) + for arg in expr.args + _extract_acquired_vars(arg, target_pool, vars) + end + elseif expr.head == :(=) && length(expr.args) >= 2 + lhs = expr.args[1] + rhs = expr.args[2] + if lhs isa Symbol && _is_acquire_call(rhs, target_pool) + push!(vars, lhs) + elseif lhs isa Symbol && rhs isa Symbol && rhs in vars + # Simple alias: d = z where z is acquired + push!(vars, lhs) + elseif lhs isa Symbol && _literal_contains_acquired(rhs, vars) + # Container wrapping: d = (z,), d = [z, w], etc. + push!(vars, lhs) + elseif Meta.isexpr(lhs, :tuple) && Meta.isexpr(rhs, :tuple) + # Destructuring with tuple literal RHS: (v, w) = (acquire!(...), expr) + for (l, r) in zip(lhs.args, rhs.args) + if l isa Symbol && _is_acquire_call(r, target_pool) + push!(vars, l) + elseif l isa Symbol && r isa Symbol && r in vars + # Destructuring alias: (a, d) = (..., z) + push!(vars, l) + end + end + end + # Recurse into RHS (for nested blocks with acquire calls) + _extract_acquired_vars(rhs, target_pool, vars) + else + for arg in expr.args + _extract_acquired_vars(arg, target_pool, vars) + end + end + end + return vars +end + +""" + _get_last_expression(expr) -> Any + +Return the last non-LineNumberNode expression from a block. +For non-block expressions, returns the expression itself. +""" +function _get_last_expression(expr) + if expr isa Expr && expr.head == :block + for i in length(expr.args):-1:1 + arg = expr.args[i] + arg isa LineNumberNode && continue + return _get_last_expression(arg) + end + return nothing + end + return expr +end + +""" + _collect_all_return_values(expr) -> Vector{Tuple{Any, Union{Int,Nothing}}} + +Collect all (expression, line) pairs that could be returned from a block/function body: +- Explicit `return expr` statements anywhere in the body (recursive, skips nested functions) +- Implicit returns: the last expression, recursing into if/else/elseif branches +""" +function _collect_all_return_values(expr) + values = Tuple{Any, Union{Int, Nothing}}[] + _collect_explicit_returns!(values, expr, nothing) + last_expr, last_line = _get_last_expression_with_line(expr) + if last_expr !== nothing + _collect_implicit_return_values!(values, last_expr, last_line) + end + return values +end + +"""Walk AST to find all explicit `return expr` statements with line numbers. +Tracks LineNumberNodes through blocks. Skips nested function definitions.""" +function _collect_explicit_returns!(values, expr, current_line::Union{Int, Nothing}) + expr isa Expr || return + expr.head in (:function, :(->)) && return + if expr.head == :return + push!(values, (expr, current_line)) + return + end + return if expr.head == :block + line = current_line + for arg in expr.args + if arg isa LineNumberNode + line = arg.line + else + _collect_explicit_returns!(values, arg, line) + end + end + else + for arg in expr.args + _collect_explicit_returns!(values, arg, current_line) + end + end +end + +"""Return the last non-LineNumberNode expression from a block, together with its line. +Recurses into nested blocks.""" +function _get_last_expression_with_line(expr, default_line::Union{Int, Nothing} = nothing) + if !(expr isa Expr && expr.head == :block) + return (expr, default_line) + end + for i in length(expr.args):-1:1 + arg = expr.args[i] + arg isa LineNumberNode && continue + # Find the LineNumberNode preceding this expression + line = default_line + for j in (i - 1):-1:1 + if expr.args[j] isa LineNumberNode + line = expr.args[j].line + break + end + end + return _get_last_expression_with_line(arg, line) + end + return (nothing, default_line) +end + +"""Expand implicit return values by recursing into if/elseif/else branches. +Non-branch expressions are collected as (expr, line) pairs.""" +function _collect_implicit_return_values!(values, expr, current_line::Union{Int, Nothing}) + return if expr isa Expr && expr.head in (:if, :elseif) + for i in 2:length(expr.args) + branch = expr.args[i] + if branch isa Expr && branch.head in (:if, :elseif) + _collect_implicit_return_values!(values, branch, current_line) + else + last_expr, last_line = _get_last_expression_with_line(branch, current_line) + if last_expr !== nothing + _collect_implicit_return_values!(values, last_expr, last_line) + end + end + end + else + push!(values, (expr, current_line)) + end +end + +""" + _remove_flat_reassigned!(expr, acquired, target_pool) + +Walk top-level statements in order and remove variables from `acquired` +if they are reassigned to a non-acquire call. Handles both simple assignment +(`v = expr`) and tuple destructuring (`(a, v) = expr`). +Only handles flat (non-branching) reassignment — conditional is conservatively kept. +""" +function _remove_flat_reassigned!(expr, acquired, target_pool) + if !(expr isa Expr && expr.head == :block) + return + end + for arg in expr.args + arg isa LineNumberNode && continue + if arg isa Expr && arg.head == :(=) && length(arg.args) >= 2 + lhs = arg.args[1] + rhs = arg.args[2] + if lhs isa Symbol && lhs in acquired && !_is_acquire_call(rhs, target_pool) && + !(rhs isa Symbol && rhs in acquired) && # keep aliases + !_literal_contains_acquired(rhs, acquired) # keep container wrapping + delete!(acquired, lhs) + elseif Meta.isexpr(lhs, :tuple) + # Destructuring: (a, v, b) = expr + if Meta.isexpr(rhs, :tuple) && length(rhs.args) == length(lhs.args) + # RHS is tuple literal — check each element pair + for (l, r) in zip(lhs.args, rhs.args) + if l isa Symbol && l in acquired && !_is_acquire_call(r, target_pool) && + !(r isa Symbol && r in acquired) # keep aliases + delete!(acquired, l) + end + end + else + # RHS is function call or opaque expression — + # acquired var is now reassigned to unknown value, remove it + for l in lhs.args + if l isa Symbol && l in acquired + delete!(acquired, l) + end + end + end + end + end + end + return +end + +""" + _find_direct_exposure(expr, acquired) -> Set{Symbol} + +Check if the expression directly exposes any acquired variable. +Only catches high-confidence patterns: +- Bare Symbol: `v` +- Explicit return: `return v` +- Tuple/array literal containing a var: `(v, w)`, `[v, w]` +- NamedTuple-style kw: `(a=v,)` + +Does NOT recurse into function calls (can't know what `f(v)` returns). +""" + +"""Check if a literal expression (symbol, identity call, tuple/vect) transitively contains any acquired var.""" +function _literal_contains_acquired(expr, acquired) + expr isa Symbol && return expr in acquired + if expr isa Expr + # identity(x) — transparent, look through + if expr.head == :call && length(expr.args) >= 2 && expr.args[1] === :identity + return _literal_contains_acquired(expr.args[2], acquired) + end + if expr.head in (:tuple, :vect) + for arg in expr.args + if Meta.isexpr(arg, :(=)) && length(arg.args) >= 2 + _literal_contains_acquired(arg.args[2], acquired) && return true + elseif Meta.isexpr(arg, :kw) && length(arg.args) >= 2 + _literal_contains_acquired(arg.args[2], acquired) && return true + else + _literal_contains_acquired(arg, acquired) && return true + end + end + end + end + return false +end + +"""Check if a call target is `identity` or `Base.identity`.""" +_is_identity_call(x) = x === :identity || + (x isa Expr && x.head == :. && x.args == [:Base, QuoteNode(:identity)]) + +function _find_direct_exposure(expr, acquired) + found = Set{Symbol}() + if expr isa Symbol + # Bare variable: v + if expr in acquired + push!(found, expr) + end + elseif expr isa Expr + if expr.head == :return + # return v, return (v, w), etc. + for arg in expr.args + union!(found, _find_direct_exposure(arg, acquired)) + end + elseif expr.head in (:tuple, :vect) + # (v, w), [v, w], nested (a, (b, v)), (key=(a, v),) + for arg in expr.args + if Meta.isexpr(arg, :(=)) && length(arg.args) >= 2 + # NamedTuple: (a=v,) or (a=(b,v),) + union!(found, _find_direct_exposure(arg.args[2], acquired)) + else + union!(found, _find_direct_exposure(arg, acquired)) + end + end + elseif expr.head == :parameters + # (; a=v) style named tuple parameters + for arg in expr.args + if Meta.isexpr(arg, :kw) && length(arg.args) >= 2 + union!(found, _find_direct_exposure(arg.args[2], acquired)) + end + end + elseif expr.head == :call && length(expr.args) >= 2 && _is_identity_call(expr.args[1]) + # identity(x) / Base.identity(x) — transparent, look through + union!(found, _find_direct_exposure(expr.args[2], acquired)) + end + end + return found +end + + +"""Collect acquired variable names contained in a literal expression (symbol, tuple, vect).""" +function _collect_acquired_in_literal(expr, acquired_keys::Set{Symbol}) + found = Symbol[] + _collect_acquired_in_literal!(found, expr, acquired_keys) + return found +end + +function _collect_acquired_in_literal!(found, expr, acquired_keys) + return if expr isa Symbol + expr in acquired_keys && push!(found, expr) + elseif expr isa Expr + if expr.head == :call && length(expr.args) >= 2 && expr.args[1] === :identity + _collect_acquired_in_literal!(found, expr.args[2], acquired_keys) + elseif expr.head in (:tuple, :vect) + for arg in expr.args + if Meta.isexpr(arg, :(=)) && length(arg.args) >= 2 + _collect_acquired_in_literal!(found, arg.args[2], acquired_keys) + elseif Meta.isexpr(arg, :kw) && length(arg.args) >= 2 + _collect_acquired_in_literal!(found, arg.args[2], acquired_keys) + else + _collect_acquired_in_literal!(found, arg, acquired_keys) + end + end + end + end +end + +""" + _classify_escaped_vars(expr, target_pool, escaped, acquired) + +Classify each escaped variable by its origin for better error messages: +- `:pool_view` — from acquire!, zeros!, etc. (returns SubArray) +- `:pool_array` — from unsafe_acquire! (returns Array via unsafe_wrap) +- `:pool_bitarray` — from trues!, falses! (returns BitArray) +- `:alias` — alias of another acquired variable (e.g., `d = v`) +- `:container` — wraps acquired variables in a literal (e.g., `d = [v, 1]`) + +Returns `Dict{Symbol, Tuple{Symbol, Vector{Symbol}}}` mapping var → (kind, source_vars). +""" +function _classify_escaped_vars(expr, target_pool, escaped::Vector{Symbol}, acquired::Set{Symbol}) + info = Dict{Symbol, Tuple{Symbol, Vector{Symbol}}}() + escaped_set = Set(escaped) + _classify_walk!(info, expr, target_pool, escaped_set, acquired) + return info +end + +function _classify_walk!(info, expr, target_pool, escaped_set, acquired) + expr isa Expr || return + return if expr.head == :block + for arg in expr.args + _classify_walk!(info, arg, target_pool, escaped_set, acquired) + end + elseif expr.head == :(=) && length(expr.args) >= 2 + lhs = expr.args[1] + rhs = expr.args[2] + if lhs isa Symbol && lhs in escaped_set + kind = _acquire_call_kind(rhs, target_pool) + if kind !== nothing + info[lhs] = (kind, Symbol[]) + elseif rhs isa Symbol && rhs in acquired + info[lhs] = (:alias, [rhs]) + else + sources = _collect_acquired_in_literal(rhs, acquired) + if !isempty(sources) + info[lhs] = (:container, sort!(sources)) + end + end + end + # Recurse into RHS for nested blocks + _classify_walk!(info, rhs, target_pool, escaped_set, acquired) + else + for arg in expr.args + _classify_walk!(info, arg, target_pool, escaped_set, acquired) + end + end +end + +""" + _extract_declaration_sites(expr, escaped) + +Walk the AST to find assignment sites for escaped variables. +Returns a `Vector{DeclarationSite}` sorted by line number. +""" +function _extract_declaration_sites(expr, escaped::Set{Symbol}) + sites = DeclarationSite[] + seen = Set{Symbol}() + _collect_declaration_sites!(sites, seen, expr, escaped, nothing, nothing) + sort!(sites; by = s -> something(s.line, typemax(Int))) + return sites +end + +function _collect_declaration_sites!(sites, seen, expr, escaped, current_line, current_file) + expr isa Expr || return + return if expr.head == :block + line = current_line + file = current_file + for arg in expr.args + if arg isa LineNumberNode + line = arg.line + file = arg.file + else + _collect_declaration_sites!(sites, seen, arg, escaped, line, file) + end + end + elseif expr.head == :(=) && length(expr.args) >= 2 + lhs = expr.args[1] + if lhs isa Symbol && lhs in escaped && lhs ∉ seen + push!(sites, DeclarationSite(lhs, expr, current_line, current_file)) + push!(seen, lhs) + elseif Meta.isexpr(lhs, :tuple) + for l in lhs.args + if l isa Symbol && l in escaped && l ∉ seen + push!(sites, DeclarationSite(l, expr, current_line, current_file)) + push!(seen, l) + end + end + end + _collect_declaration_sites!(sites, seen, expr.args[2], escaped, current_line, current_file) + else + for arg in expr.args + _collect_declaration_sites!(sites, seen, arg, escaped, current_line, current_file) + end + end +end + +""" + _check_compile_time_escape(expr, pool_name, source) + +Compile-time (macro expansion time) escape detection. + +Checks if the block/function body's return expression directly contains +a pool-backed variable. This catches the most common beginner mistake +at zero runtime cost. + +All detected escapes are errors — bare symbol (`v`), `return v`, and +container patterns (`(v, w)`, `[v]`, `(key=v,)`). + +Skipped when `STATIC_POOLING = false` (pooling disabled, acquire returns normal arrays). +""" +function _check_compile_time_escape(expr, pool_name, source::Union{LineNumberNode, Nothing}) + # Extract variables assigned from acquire calls + acquired = _extract_acquired_vars(expr, pool_name) + isempty(acquired) && return + + # Remove vars that were unconditionally reassigned to non-acquire values + _remove_flat_reassigned!(expr, acquired, pool_name) + isempty(acquired) && return + + # Collect ALL return points: explicit returns + implicit (last expr / if-else branches) + return_values = _collect_all_return_values(expr) + isempty(return_values) && return + + # Check each return point for direct exposure of acquired vars + all_escaped = Set{Symbol}() + points = EscapePoint[] + seen_lines = Set{Int}() + for (ret_expr, ret_line) in return_values + # Deduplicate: explicit + implicit scanners can find the same return + if ret_line !== nothing && ret_line in seen_lines + continue + end + point_escaped = _find_direct_exposure(ret_expr, acquired) + if !isempty(point_escaped) + push!(points, EscapePoint(ret_expr, ret_line, sort!(collect(point_escaped)))) + union!(all_escaped, point_escaped) + ret_line !== nothing && push!(seen_lines, ret_line) + end + end + isempty(all_escaped) && return + + sorted = sort!(collect(all_escaped)) + var_info = _classify_escaped_vars(expr, pool_name, sorted, acquired) + declarations = _extract_declaration_sites(expr, all_escaped) + file = source !== nothing ? string(source.file) : nothing + line = source !== nothing ? source.line : nothing + throw(PoolEscapeError(sorted, file, line, points, var_info, declarations)) +end diff --git a/src/state.jl b/src/state.jl index f777ad45..60b3a84a 100644 --- a/src/state.jl +++ b/src/state.jl @@ -230,6 +230,73 @@ Decrements _current_depth once after all types are rewound. end end +# ============================================================================== +# Safety: Structural Invalidation on Rewind (POOL_SAFETY_LV >= 1) +# ============================================================================== +# +# When released, backing vectors are resize!'d to 0 and cached Array/BitArray +# wrappers have their size set to (0,...). This makes stale SubArrays and Arrays +# throw BoundsError on access instead of silently returning corrupted data. +# +# @noinline keeps invalidation code off the inlined hot path of _rewind_typed_pool!. + +# No-op fallback for extension types (e.g. CuTypedPool) +_invalidate_released_slots!(::AbstractTypedPool, ::Int) = nothing + +@noinline function _invalidate_released_slots!(tp::TypedPool{T}, old_n_active::Int) where {T} + new_n = tp.n_active + # Level 2+: poison vectors with NaN/sentinel before structural invalidation + if POOL_SAFETY_LV[] >= 2 + _poison_released_vectors!(tp, old_n_active) + end + # Level 1+: resize backing vectors to length 0 (invalidates SubArrays from acquire!) + for i in (new_n + 1):old_n_active + @inbounds resize!(tp.vectors[i], 0) + end + # Invalidate N-D Array wrappers from unsafe_acquire! (setfield! size to zeros) + for N_idx in 1:length(tp.arr_wrappers) + wrappers_for_N = @inbounds tp.arr_wrappers[N_idx] + wrappers_for_N === nothing && continue + wrappers = wrappers_for_N::Vector{Any} + for i in (new_n + 1):min(old_n_active, length(wrappers)) + wrapper = @inbounds wrappers[i] + wrapper === nothing && continue + setfield!(wrapper::Array, :size, ntuple(_ -> 0, N_idx)) + end + end + return nothing +end + +@noinline function _invalidate_released_slots!(tp::BitTypedPool, old_n_active::Int) + new_n = tp.n_active + # Level 2+: poison BitVectors (all bits set to true) + if POOL_SAFETY_LV[] >= 2 + _poison_released_vectors!(tp, old_n_active) + end + # Level 1+: resize backing BitVectors to length 0 (invalidates chunks) + for i in (new_n + 1):old_n_active + @inbounds resize!(tp.vectors[i], 0) + end + # Invalidate N-D BitArray wrappers (setfield! len and dims to zeros) + for N_idx in 1:length(tp.arr_wrappers) + wrappers_for_N = @inbounds tp.arr_wrappers[N_idx] + wrappers_for_N === nothing && continue + wrappers = wrappers_for_N::Vector{Any} + for i in (new_n + 1):min(old_n_active, length(wrappers)) + wrapper = @inbounds wrappers[i] + wrapper === nothing && continue + ba = wrapper::BitArray + setfield!(ba, :len, 0) + setfield!(ba, :dims, ntuple(_ -> 0, N_idx)) + end + end + return nothing +end + +# ============================================================================== +# Internal: Rewind with Orphan Cleanup +# ============================================================================== + # Internal helper for rewind with orphan cleanup (works for any AbstractTypedPool) # Uses 1-based sentinel pattern: no isempty checks needed (sentinel [0] guarantees non-empty) @inline function _rewind_typed_pool!(tp::AbstractTypedPool, current_depth::Int) @@ -243,6 +310,11 @@ end pop!(tp._checkpoint_n_active) end + # Capture n_active before restore (for safety invalidation) + @static if STATIC_POOL_CHECKS + _old_n_active = tp.n_active + end + # 2. Normal Rewind Logic (Sentinel Pattern) # Now the stack top is guaranteed to be at depth <= current depth. if @inbounds tp._checkpoint_depths[end] == current_depth @@ -256,6 +328,14 @@ end # - If sentinel (_checkpoint_n_active=[0]), restores to n_active=0 tp.n_active = @inbounds tp._checkpoint_n_active[end] end + + # 3. Safety: invalidate released slots (Level 1+) + @static if STATIC_POOL_CHECKS + if POOL_SAFETY_LV[] >= 1 && _old_n_active > tp.n_active + _invalidate_released_slots!(tp, _old_n_active) + end + end + return nothing end @@ -450,12 +530,20 @@ Reset state without clearing allocated storage. Sets `n_active = 0` and restores checkpoint stacks to sentinel state. """ function reset!(tp::AbstractTypedPool) + @static if STATIC_POOL_CHECKS + _old_n_active = tp.n_active + end tp.n_active = 0 # Restore sentinel values (1-based sentinel pattern) empty!(tp._checkpoint_n_active) push!(tp._checkpoint_n_active, 0) # Sentinel: n_active=0 at depth=0 empty!(tp._checkpoint_depths) push!(tp._checkpoint_depths, 0) # Sentinel: depth=0 = no checkpoint + @static if STATIC_POOL_CHECKS + if POOL_SAFETY_LV[] >= 1 && _old_n_active > 0 + _invalidate_released_slots!(tp, _old_n_active) + end + end return tp end @@ -517,6 +605,11 @@ function reset!(pool::AdaptiveArrayPool) empty!(pool._touched_has_others) push!(pool._touched_has_others, false) # Sentinel: no others + # Clear borrow registry and return-site tracking + pool._pending_callsite = "" + pool._pending_return_site = "" + pool._borrow_log = nothing + return pool end diff --git a/src/types.jl b/src/types.jl index 3a5494b1..e6c7d792 100644 --- a/src/types.jl +++ b/src/types.jl @@ -291,6 +291,36 @@ const _TYPE_BITS_MASK = UInt16(0x00FF) # bits 0-7: fixed-slot type bits # Check whether a type's bit is set in a bitmask (e.g. _touched_type_masks or combined). @inline _has_bit(mask::UInt16, ::Type{T}) where {T} = (mask & _fixed_slot_bit(T)) != 0 +# ============================================================================== +# Safety Configuration (2-Tier Toggle) +# ============================================================================== +# +# Tier 1: STATIC_POOL_CHECKS (compile-time const) +# Set via LocalPreferences.toml: pool_checks = true/false +# When false: all safety code elided at compile time (zero overhead) +# +# Tier 2: POOL_SAFETY_LV (runtime Ref{Int}, levels 0/1/2/3) +# 0 = off, 1 = guard, 2 = full, 3 = debug (borrow registry) +# Default: 1 (guard mode — safe by default) + +using Preferences: @load_preference + +const STATIC_POOL_CHECKS = @load_preference("pool_checks", true)::Bool + +""" + POOL_SAFETY_LV + +Runtime safety level for pool operations. Only effective when `STATIC_POOL_CHECKS` is `true`. + +- `0`: Off — no safety checks (Ref read only, ~1ns) +- `1`: Guard — structural invalidation on rewind (resize + setfield!, ~1ns/slot) +- `2`: Full — guard + escape detection on scope exit + poisoning +- `3`: Debug — full + borrow registry (acquire call-site tracking in error messages) + +Default: `1` (guard mode) +""" +const POOL_SAFETY_LV = Ref(1) + # ============================================================================== # AdaptiveArrayPool # ============================================================================== @@ -319,6 +349,11 @@ mutable struct AdaptiveArrayPool <: AbstractArrayPool _current_depth::Int # Current scope depth (1 = global scope) _touched_type_masks::Vector{UInt16} # Per-depth: which fixed slots were touched + mode flags _touched_has_others::Vector{Bool} # Per-depth: any non-fixed-slot type touched? + + # Borrow registry (POOL_SAFETY_LV >= 3 only) + _pending_callsite::String # "" = no pending; set by macro before acquire + _pending_return_site::String # "" = no pending; set by macro before validate + _borrow_log::Union{Nothing, IdDict{Any, String}} # vector_obj => callsite string end function AdaptiveArrayPool() @@ -334,7 +369,10 @@ function AdaptiveArrayPool() IdDict{DataType, Any}(), 1, # _current_depth: 1 = global scope (sentinel) [UInt16(0)], # _touched_type_masks: sentinel (no bits set) - [false] # _touched_has_others: sentinel (no others) + [false], # _touched_has_others: sentinel (no others) + "", # _pending_callsite: no pending + "", # _pending_return_site: no pending + nothing # _borrow_log: lazily created at LV >= 3 ) end @@ -388,3 +426,58 @@ Apply `f` to each fixed slot TypedPool. Zero allocation via compile-time unrolli nothing end end + +# ============================================================================== +# Safety Tag Dispatch (compile-time, zero-cost when STATIC_POOL_CHECKS=false) +# ============================================================================== +# +# Instead of `@static if STATIC_POOL_CHECKS` at every call site, we dispatch on +# a singleton tag. The compiler resolves `const _POOL_CHECK_TAG` at compile time, +# monomorphizes the call, and dead-code-eliminates the unused path entirely. + +"""Singleton tag: pool safety checks enabled.""" +struct _CheckOn end + +"""Singleton tag: pool safety checks disabled (all safety helpers become no-ops).""" +struct _CheckOff end + +"""Compile-time tag selected by `STATIC_POOL_CHECKS`.""" +const _POOL_CHECK_TAG = STATIC_POOL_CHECKS ? _CheckOn() : _CheckOff() + +# --- Active implementations (_CheckOn) --- + +""" + _set_pending_callsite!(pool, msg::String) + +Record a pending callsite string for borrow tracking (safety level ≥ 3). +Only sets the callsite if no prior callsite is pending (macro-injected ones take priority). +Dispatches to no-op when `STATIC_POOL_CHECKS` is `false`. +""" +@inline function _set_pending_callsite!(::_CheckOn, pool::AbstractArrayPool, msg::String) + POOL_SAFETY_LV[] >= 3 && isempty(pool._pending_callsite) && (pool._pending_callsite = msg) + return nothing +end + +""" + _maybe_record_borrow!(pool, tp::AbstractTypedPool) + +Flush the pending callsite into the borrow log (safety level ≥ 3). +Delegates to `_record_borrow_from_pending!` (defined in `debug.jl`). +Dispatches to no-op when `STATIC_POOL_CHECKS` is `false`. +""" +@inline function _maybe_record_borrow!(::_CheckOn, pool::AbstractArrayPool, tp::AbstractTypedPool) + POOL_SAFETY_LV[] >= 3 && _record_borrow_from_pending!(pool, tp) + return nothing +end + +# --- No-op implementations (_CheckOff) --- + +@inline _set_pending_callsite!(::_CheckOff, ::AbstractArrayPool, ::String) = nothing +@inline _maybe_record_borrow!(::_CheckOff, ::AbstractArrayPool, ::AbstractTypedPool) = nothing + +# --- Convenience wrappers (auto-dispatch via const tag) --- + +@inline _set_pending_callsite!(pool::AbstractArrayPool, msg::String) = + _set_pending_callsite!(_POOL_CHECK_TAG, pool, msg) +@inline _maybe_record_borrow!(pool::AbstractArrayPool, tp::AbstractTypedPool) = + _maybe_record_borrow!(_POOL_CHECK_TAG, pool, tp) diff --git a/src/utils.jl b/src/utils.jl index 4b77d21a..d9ee173d 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -1,93 +1,3 @@ -# ============================================================================== -# Debugging & Safety -# ============================================================================== - -""" - POOL_DEBUG - -When `true`, `@with_pool` macros validate that returned values don't -reference pool memory (which would be unsafe). - -Default: `false` -""" -const POOL_DEBUG = Ref(false) - -function _validate_pool_return(val, pool::AdaptiveArrayPool) - # 0. Check BitArray / BitVector (bit-packed storage) - # Note: _check_bitchunks_overlap is defined in bitarray.jl / legacy/bitarray.jl (included after utils.jl) - if val isa BitArray - _check_bitchunks_overlap(val, pool) - return - end - - # 1. Check SubArray - if val isa SubArray - p = parent(val) - # Use pointer overlap check for ALL Array parents (Vector <: Array) - # This catches both: - # - acquire!() 1D returns: SubArray backed by pool's internal Vector - # - view(unsafe_acquire!()): SubArray backed by unsafe_wrap'd Array - if p isa Array - _check_pointer_overlap(p, pool) - elseif p isa BitArray - _check_bitchunks_overlap(p, pool) - end - return - end - - # 2. Check ReshapedArray (from acquire! N-D, wraps SubArray of pool Vector) - if val isa Base.ReshapedArray - p = parent(val) - # ReshapedArray wraps SubArray{T,1,Vector{T},...} - if p isa SubArray - pp = parent(p) - if pp isa Array - _check_pointer_overlap(pp, pool) - elseif pp isa BitArray - _check_bitchunks_overlap(pp, pool) - end - end - return - end - - # 3. Check raw Array (from unsafe_acquire!) - return if val isa Array - _check_pointer_overlap(val, pool) - end -end - -# Check if array memory overlaps with any pool vector -function _check_pointer_overlap(arr::Array, pool::AdaptiveArrayPool) - arr_ptr = UInt(pointer(arr)) - arr_len = length(arr) * sizeof(eltype(arr)) - arr_end = arr_ptr + arr_len - - check_overlap = function (tp) - for v in tp.vectors - v_ptr = UInt(pointer(v)) - v_len = length(v) * sizeof(eltype(v)) - v_end = v_ptr + v_len - if !(arr_end <= v_ptr || v_end <= arr_ptr) - error("Safety Violation: The function returned an Array backed by pool memory. This is unsafe as the memory will be reclaimed. Please return a copy (collect) or a scalar.") - end - end - return - end - - # Check fixed slots - foreach_fixed_slot(pool) do tp - check_overlap(tp) - end - - # Check others - for tp in values(pool.others) - check_overlap(tp) - end - return -end - -_validate_pool_return(val, ::DisabledPool) = nothing - # ============================================================================== # Statistics & Pretty Printing # ============================================================================== diff --git a/test/runtests.jl b/test/runtests.jl index ba9dbb0a..7b5c653a 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -26,6 +26,10 @@ else include("test_macros.jl") include("test_task_local_pool.jl") include("test_utils.jl") + include("test_debug.jl") + include("test_borrow_registry.jl") + include("test_safety.jl") + include("test_compile_escape.jl") include("test_macro_expansion.jl") include("test_macro_internals.jl") include("test_zero_allocation.jl") @@ -48,6 +52,9 @@ else include("test_macros.jl") include("test_task_local_pool.jl") include("test_utils.jl") + include("test_debug.jl") + include("test_safety.jl") + include("test_compile_escape.jl") include("test_macro_expansion.jl") include("test_macro_internals.jl") include("test_zero_allocation.jl") diff --git a/test/test_allocation.jl b/test/test_allocation.jl index 3666da80..0617adeb 100644 --- a/test/test_allocation.jl +++ b/test/test_allocation.jl @@ -15,10 +15,15 @@ ff2 = zeros!(pool, Bit, 100) C = similar!(pool, tt1) + nothing # avoid compile-time escape error (C is pool-backed) end @testset "zero allocation on reuse" begin + # Disable safety invalidation: rewind-time resize!/setfield! forces cache misses + # (new SubArray views on legacy, new BitArray wrappers), breaking zero-alloc invariant. + old_safety = POOL_SAFETY_LV[] + POOL_SAFETY_LV[] = 0 # First call: JIT + initial cache miss (pool arrays + N-way bitarray cache) alloc1 = @allocated foo() @@ -35,4 +40,6 @@ end alloc3 = @allocated foo() @test alloc2 == 0 @test alloc3 == 0 + + POOL_SAFETY_LV[] = old_safety end diff --git a/test/test_backend_macro_expansion.jl b/test/test_backend_macro_expansion.jl index df0bbde8..ad292cc7 100644 --- a/test/test_backend_macro_expansion.jl +++ b/test/test_backend_macro_expansion.jl @@ -483,14 +483,16 @@ @test occursin("Val{:cuda}", expr_str) end - @testset "vs @with_pool :backend — no runtime toggle" begin + @testset "vs @with_pool :backend — no MAYBE_POOLING toggle" begin expr = @macroexpand @with_pool :cuda pool function with_backend_func(n) v = acquire!(pool, Float64, n) return sum(v) end body_str = string(expr.args[2]) - @test !occursin(refvalue_pattern, body_str) + # @with_pool has POOL_DEBUG check but NOT MAYBE_POOLING runtime toggle + @test occursin("_validate_pool_return", body_str) # POOL_DEBUG present + @test !occursin("DisabledPool", body_str) # No MAYBE_POOLING branch @test occursin("_get_pool_for_backend", body_str) end end diff --git a/test/test_basic.jl b/test/test_basic.jl index de1b5c72..091781c6 100644 --- a/test/test_basic.jl +++ b/test/test_basic.jl @@ -81,7 +81,10 @@ # Verify independence: writing to v1_big doesn't corrupt v2_reuse v1_big .= 99.0 @test all(v2_reuse .== 3.0) + old_safety = POOL_SAFETY_LV[] + POOL_SAFETY_LV[] = 0 # disable invalidation so backing vector length is preserved rewind!(pool) + POOL_SAFETY_LV[] = old_safety # Re-acquire slot 1 with SMALLER size — no resize needed, backing vec stays large checkpoint!(pool) diff --git a/test/test_borrow_registry.jl b/test/test_borrow_registry.jl new file mode 100644 index 00000000..a50462bb --- /dev/null +++ b/test/test_borrow_registry.jl @@ -0,0 +1,430 @@ +import AdaptiveArrayPools: _validate_pool_return, _lookup_borrow_callsite, + PoolRuntimeEscapeError, Bit + +_test_leak(x) = x + +@testset "Borrow Registry (POOL_SAFETY_LV=3)" begin + + # ============================================================================== + # Basic recording: LV=3 macro path → callsite in escape error + # ============================================================================== + + @testset "Macro path: escape error includes callsite" begin + old_lv = POOL_SAFETY_LV[] + POOL_SAFETY_LV[] = 3 + + err = try + @with_pool pool begin + v = acquire!(pool, Float64, 10) + _test_leak(v) + end + nothing + catch e + e + end + + @test err isa PoolRuntimeEscapeError + @test err.callsite !== nothing + @test contains(err.callsite, ":") # "file:line" format + + POOL_SAFETY_LV[] = old_lv + end + + @testset "Macro path: unsafe_acquire! escape includes callsite" begin + old_lv = POOL_SAFETY_LV[] + POOL_SAFETY_LV[] = 3 + + err = try + @with_pool pool begin + v = unsafe_acquire!(pool, Float64, 10) + _test_leak(v) + end + nothing + catch e + e + end + + @test err isa PoolRuntimeEscapeError + @test err.callsite !== nothing + @test contains(err.callsite, ":") + + POOL_SAFETY_LV[] = old_lv + end + + # ============================================================================== + # Non-macro path: direct acquire! → generic label + # ============================================================================== + + @testset "Direct acquire! shows generic callsite label" begin + old_lv = POOL_SAFETY_LV[] + POOL_SAFETY_LV[] = 3 + + pool = AdaptiveArrayPool() + checkpoint!(pool) + + v = acquire!(pool, Float64, 10) + err = try + _validate_pool_return(v, pool) + nothing + catch e + e + end + + @test err isa PoolRuntimeEscapeError + @test err.callsite == "" + + rewind!(pool) + POOL_SAFETY_LV[] = old_lv + end + + @testset "Direct unsafe_acquire! shows generic callsite label" begin + old_lv = POOL_SAFETY_LV[] + POOL_SAFETY_LV[] = 3 + + pool = AdaptiveArrayPool() + checkpoint!(pool) + + v = unsafe_acquire!(pool, Float64, 10) + err = try + _validate_pool_return(v, pool) + nothing + catch e + e + end + + @test err isa PoolRuntimeEscapeError + @test err.callsite == "" + + rewind!(pool) + POOL_SAFETY_LV[] = old_lv + end + + # ============================================================================== + # Convenience functions via macro → callsite + # ============================================================================== + + @testset "Macro path: zeros! escape includes callsite" begin + old_lv = POOL_SAFETY_LV[] + POOL_SAFETY_LV[] = 3 + + err = try + @with_pool pool begin + v = zeros!(pool, Float64, 10) + _test_leak(v) + end + nothing + catch e + e + end + + @test err isa PoolRuntimeEscapeError + @test err.callsite !== nothing + @test contains(err.callsite, ":") + + POOL_SAFETY_LV[] = old_lv + end + + @testset "Macro path: callsite includes expression text" begin + old_lv = POOL_SAFETY_LV[] + POOL_SAFETY_LV[] = 3 + + err = try + @with_pool pool begin + v = zeros!(pool, Float64, 10) + _test_leak(v) + end + nothing + catch e + e + end + + @test err isa PoolRuntimeEscapeError + @test err.callsite !== nothing + # Callsite should contain expression text after \n + @test contains(err.callsite, "\n") + @test contains(err.callsite, "zeros!(pool, Float64, 10)") + + POOL_SAFETY_LV[] = old_lv + end + + @testset "Direct zeros! shows generic callsite label" begin + old_lv = POOL_SAFETY_LV[] + POOL_SAFETY_LV[] = 3 + + pool = AdaptiveArrayPool() + checkpoint!(pool) + + v = zeros!(pool, Float64, 10) + err = try + _validate_pool_return(v, pool) + nothing + catch e + e + end + + @test err isa PoolRuntimeEscapeError + @test err.callsite == "" + + rewind!(pool) + POOL_SAFETY_LV[] = old_lv + end + + # ============================================================================== + # BitArray path → callsite + # ============================================================================== + + @testset "Macro path: BitArray acquire escape includes callsite" begin + old_lv = POOL_SAFETY_LV[] + POOL_SAFETY_LV[] = 3 + + err = try + @with_pool pool begin + v = acquire!(pool, Bit, 100) + _test_leak(v) + end + nothing + catch e + e + end + + @test err isa PoolRuntimeEscapeError + @test err.callsite !== nothing + @test contains(err.callsite, ":") + + POOL_SAFETY_LV[] = old_lv + end + + # ============================================================================== + # LV<3: no borrow log overhead + # ============================================================================== + + @testset "LV<3 does not create borrow log" begin + for lv in (0, 1, 2) + old_lv = POOL_SAFETY_LV[] + POOL_SAFETY_LV[] = lv + + pool = AdaptiveArrayPool() + checkpoint!(pool) + _ = acquire!(pool, Float64, 10) + @test pool._borrow_log === nothing + rewind!(pool) + + POOL_SAFETY_LV[] = old_lv + end + end + + # ============================================================================== + # LV=3: borrow log IS created + # ============================================================================== + + @testset "LV=3 creates borrow log on acquire" begin + old_lv = POOL_SAFETY_LV[] + POOL_SAFETY_LV[] = 3 + + pool = AdaptiveArrayPool() + checkpoint!(pool) + _ = acquire!(pool, Float64, 10) + @test pool._borrow_log !== nothing + @test pool._borrow_log isa IdDict + rewind!(pool) + + POOL_SAFETY_LV[] = old_lv + end + + # ============================================================================== + # reset! clears borrow log + # ============================================================================== + + @testset "reset! clears borrow log and pending callsite" begin + old_lv = POOL_SAFETY_LV[] + POOL_SAFETY_LV[] = 3 + + pool = AdaptiveArrayPool() + checkpoint!(pool) + _ = acquire!(pool, Float64, 10) + @test pool._borrow_log !== nothing + + reset!(pool) + @test pool._borrow_log === nothing + @test pool._pending_callsite == "" + + POOL_SAFETY_LV[] = old_lv + end + + # ============================================================================== + # Error message format: showerror output + # ============================================================================== + + @testset "showerror: 'acquired at' shown when callsite present (LV≥3)" begin + err = PoolRuntimeEscapeError("SubArray{Float64, 1}", "Float64", "test.jl:42", nothing) + io = IOBuffer() + showerror(io, err) + msg = String(take!(io)) + + @test contains(msg, "acquired at") + @test contains(msg, "test.jl:42") + @test contains(msg, "POOL_SAFETY_LV ≥ 3") + @test !contains(msg, "Tip:") + end + + @testset "showerror: expression text shown when present in callsite" begin + err = PoolRuntimeEscapeError( + "SubArray{Float64, 1}", "Float64", + "test.jl:42\nzeros!(pool, Float64, 10)", nothing + ) + io = IOBuffer() + showerror(io, err) + msg = String(take!(io)) + + @test contains(msg, "acquired at") + @test contains(msg, "test.jl:42") + @test contains(msg, "zeros!(pool, Float64, 10)") + end + + @testset "showerror: short path used for absolute paths" begin + err = PoolRuntimeEscapeError( + "SubArray{Float64, 1}", "Float64", + "$(homedir())/.julia/dev/Foo/src/bar.jl:99\nacquire!(pool, Float64, 5)", nothing + ) + io = IOBuffer() + showerror(io, err) + msg = String(take!(io)) + + @test contains(msg, "acquired at") + # Should NOT contain the full absolute homedir path + @test !contains(msg, homedir()) + @test contains(msg, "bar.jl:99") + @test contains(msg, "acquire!(pool, Float64, 5)") + end + + @testset "showerror: 'Tip: set LV=3' shown when no callsite (LV=2)" begin + err = PoolRuntimeEscapeError("SubArray{Float64, 1}", "Float64", nothing, nothing) + io = IOBuffer() + showerror(io, err) + msg = String(take!(io)) + + @test !contains(msg, "acquired at") + @test contains(msg, "POOL_SAFETY_LV ≥ 2") + @test contains(msg, "Tip:") + @test contains(msg, "POOL_SAFETY_LV[] = 3") + end + + # ============================================================================== + # Multiple types: each gets correct callsite + # ============================================================================== + + @testset "Multiple types record independent callsites" begin + old_lv = POOL_SAFETY_LV[] + POOL_SAFETY_LV[] = 3 + + pool = AdaptiveArrayPool() + checkpoint!(pool) + + v_f64 = acquire!(pool, Float64, 10) + v_i32 = acquire!(pool, Int32, 5) + + # Both should have callsite recorded + tp_f64 = get_typed_pool!(pool, Float64) + tp_i32 = get_typed_pool!(pool, Int32) + + cs_f64 = _lookup_borrow_callsite(pool, tp_f64.vectors[1]) + cs_i32 = _lookup_borrow_callsite(pool, tp_i32.vectors[1]) + + @test cs_f64 !== nothing + @test cs_i32 !== nothing + # Both should be generic labels (direct calls) + @test cs_f64 == "" + @test cs_i32 == "" + + rewind!(pool) + POOL_SAFETY_LV[] = old_lv + end + + # ============================================================================== + # Return statement validation: explicit return in function form + # ============================================================================== + + @testset "Function form: explicit return triggers escape detection" begin + old_lv = POOL_SAFETY_LV[] + POOL_SAFETY_LV[] = 2 + + # Function with explicit return of pool-backed array should throw + @with_pool pool function _test_return_escape() + v = acquire!(pool, Float64, 10) + return _test_leak(v) + end + + @test_throws PoolRuntimeEscapeError _test_return_escape() + + POOL_SAFETY_LV[] = old_lv + end + + @testset "Function form: safe return passes validation" begin + old_lv = POOL_SAFETY_LV[] + POOL_SAFETY_LV[] = 2 + + @with_pool pool function _test_safe_return() + v = acquire!(pool, Float64, 5) + v .= 3.0 + return sum(v) # scalar — safe + end + + @test _test_safe_return() == 15.0 + + POOL_SAFETY_LV[] = old_lv + end + + @testset "Function form: bare return (nothing) passes" begin + old_lv = POOL_SAFETY_LV[] + POOL_SAFETY_LV[] = 2 + + @with_pool pool function _test_bare_return() + _ = acquire!(pool, Float64, 10) + return + end + + @test _test_bare_return() === nothing + + POOL_SAFETY_LV[] = old_lv + end + + @testset "Function form: return with callsite at LV=3" begin + old_lv = POOL_SAFETY_LV[] + POOL_SAFETY_LV[] = 3 + + @with_pool pool function _test_return_callsite() + v = acquire!(pool, Float64, 10) + return _test_leak(v) + end + + err = try + _test_return_callsite() + nothing + catch e + e + end + + @test err isa PoolRuntimeEscapeError + @test err.callsite !== nothing + @test contains(err.callsite, ":") + + POOL_SAFETY_LV[] = old_lv + end + + @testset "Block form: return in enclosing function triggers validation" begin + old_lv = POOL_SAFETY_LV[] + POOL_SAFETY_LV[] = 2 + + function _test_block_return_escape() + @with_pool pool begin + v = acquire!(pool, Float64, 10) + return _test_leak(v) + end + end + + @test_throws PoolRuntimeEscapeError _test_block_return_escape() + + POOL_SAFETY_LV[] = old_lv + end + +end diff --git a/test/test_compile_escape.jl b/test/test_compile_escape.jl new file mode 100644 index 00000000..45f9687b --- /dev/null +++ b/test/test_compile_escape.jl @@ -0,0 +1,2046 @@ +import AdaptiveArrayPools: _extract_acquired_vars, _get_last_expression, + _find_direct_exposure, _remove_flat_reassigned!, _check_compile_time_escape, + _collect_all_return_values, _collect_explicit_returns!, _collect_implicit_return_values!, + _get_last_expression_with_line, _render_return_expr, + _acquire_call_kind, _classify_escaped_vars, _is_acquire_call, + DeclarationSite, _extract_declaration_sites, + _format_location_str, _format_point_location, + _find_acquire_call_expr, _literal_contains_acquired, + _collect_acquired_in_literal, + _find_first_lnn_index, _ensure_body_has_toplevel_lnn + +@testset "Compile-Time Escape Detection" begin + + # ============================================================================== + # _extract_acquired_vars: find variables assigned from acquire calls + # ============================================================================== + + @testset "_extract_acquired_vars" begin + # Basic acquire! + vars = _extract_acquired_vars( + :(v = acquire!(pool, Float64, 10)), + :pool + ) + @test :v in vars + + # Multiple acquire functions + vars = _extract_acquired_vars( + quote + v = acquire!(pool, Float64, 10) + w = zeros!(pool, 5) + x = ones!(pool, Int64, 3) + y = similar!(pool, some_array) + z = unsafe_acquire!(pool, Float64, 20) + bv = trues!(pool, 100) + bf = falses!(pool, 50) + r = reshape!(pool, some_array, 3, 4) + end, + :pool + ) + @test :v in vars + @test :w in vars + @test :x in vars + @test :y in vars + @test :z in vars + @test :bv in vars + @test :bf in vars + @test :r in vars + + # Non-acquire call → not tracked + vars = _extract_acquired_vars( + :(v = sum(data)), + :pool + ) + @test isempty(vars) + + # Different pool → not tracked + vars = _extract_acquired_vars( + :(v = acquire!(other_pool, Float64, 10)), + :pool + ) + @test isempty(vars) + + # Mixed: only acquire calls tracked + vars = _extract_acquired_vars( + quote + v = acquire!(pool, Float64, 10) + w = sum(v) + x = zeros!(pool, 5) + end, + :pool + ) + @test :v in vars + @test :x in vars + @test !(:w in vars) + + # Destructuring with tuple RHS: only acquire elements tracked + vars = _extract_acquired_vars( + :((v, w) = (acquire!(pool, Float64, 10), safe_func())), + :pool + ) + @test :v in vars + @test !(:w in vars) + + # Destructuring: both acquire calls tracked + vars = _extract_acquired_vars( + :((v, w) = (acquire!(pool, Float64, 10), zeros!(pool, 5))), + :pool + ) + @test :v in vars + @test :w in vars + + # Destructuring with function call RHS: can't determine → nothing tracked + vars = _extract_acquired_vars( + :((v, w) = foo()), + :pool + ) + @test isempty(vars) + + # Destructuring: mixed with regular assignment + vars = _extract_acquired_vars( + quote + (a, b) = (acquire!(pool, Float64, 10), safe()) + c = zeros!(pool, 5) + end, + :pool + ) + @test :a in vars + @test !(:b in vars) + @test :c in vars + end + + # ============================================================================== + # _get_last_expression: find the block's return value expression + # ============================================================================== + + @testset "_get_last_expression" begin + # Simple block + @test _get_last_expression( + quote + a = 1 + b = 2 + c + end + ) == :c + + # Block ending with LineNumberNode → skip it + block = Expr(:block, :a, LineNumberNode(1, :test)) + @test _get_last_expression(block) == :a + + # Non-block expression + @test _get_last_expression(:v) == :v + @test _get_last_expression(42) == 42 + + # Nested block + @test _get_last_expression( + quote + begin + x + end + end + ) == :x + + # Empty block + @test _get_last_expression(Expr(:block)) === nothing + end + + # ============================================================================== + # _find_direct_exposure: detect acquired vars in return expression + # ============================================================================== + + @testset "_find_direct_exposure" begin + acquired = Set{Symbol}([:v, :w]) + + # Bare symbol → detected + @test :v in _find_direct_exposure(:v, acquired) + @test :w in _find_direct_exposure(:w, acquired) + + # Non-acquired symbol → not detected + @test isempty(_find_direct_exposure(:x, acquired)) + + # Function call → NOT detected (can't know return type) + @test isempty(_find_direct_exposure(:(sum(v)), acquired)) + @test isempty(_find_direct_exposure(:(f(v, w)), acquired)) + + # Indexing → NOT detected (element access) + @test isempty(_find_direct_exposure(:(v[1]), acquired)) + + # Tuple containing acquired vars → detected + found = _find_direct_exposure(:(v, w), acquired) + @test :v in found + @test :w in found + + # Tuple with mix of acquired and non-acquired + found = _find_direct_exposure(:(v, 42, x), acquired) + @test :v in found + @test !(:x in found) + + # Array literal + found = _find_direct_exposure(:([v, w]), acquired) + @test :v in found + @test :w in found + + # NamedTuple with = syntax: (a=v,) + found = _find_direct_exposure( + Expr(:tuple, Expr(:(=), :a, :v)), + acquired + ) + @test :v in found + + # NamedTuple with kw syntax: (a=v,) + found = _find_direct_exposure( + Expr(:parameters, Expr(:kw, :a, :v)), + acquired + ) + @test :v in found + + # return v + found = _find_direct_exposure(Expr(:return, :v), acquired) + @test :v in found + + # return (v, w) + found = _find_direct_exposure( + Expr(:return, Expr(:tuple, :v, :w)), + acquired + ) + @test :v in found + @test :w in found + + # Scalar literal → not detected + @test isempty(_find_direct_exposure(42, acquired)) + @test isempty(_find_direct_exposure(3.14, acquired)) + end + + # ============================================================================== + # _remove_flat_reassigned!: handle v=acquire!() then v=other pattern + # ============================================================================== + + @testset "_remove_flat_reassigned!" begin + # Simple reassignment removes from set + acquired = Set{Symbol}([:v]) + _remove_flat_reassigned!( + quote + v = acquire!(pool, Float64, 10) + v = zeros(10) + end, + acquired, :pool + ) + @test isempty(acquired) + + # Reassignment to another acquire keeps it + acquired = Set{Symbol}([:v]) + _remove_flat_reassigned!( + quote + v = acquire!(pool, Float64, 10) + v = zeros!(pool, 10) + end, + acquired, :pool + ) + @test :v in acquired + + # Different variable reassigned → original stays + acquired = Set{Symbol}([:v]) + _remove_flat_reassigned!( + quote + v = acquire!(pool, Float64, 10) + w = zeros(10) + end, + acquired, :pool + ) + @test :v in acquired + + # Non-block expression → no change + acquired = Set{Symbol}([:v]) + _remove_flat_reassigned!(:(v = zeros(10)), acquired, :pool) + @test :v in acquired # only processes :block heads + + # Destructuring with function call RHS: v removed (reassigned to unknown) + acquired = Set{Symbol}([:v]) + _remove_flat_reassigned!( + quote + v = acquire!(pool, Float64, 10) + (result, v) = process(v) + end, + acquired, :pool + ) + @test isempty(acquired) + + # Destructuring with tuple RHS: element-wise check + acquired = Set{Symbol}([:v, :w]) + _remove_flat_reassigned!( + quote + v = acquire!(pool, Float64, 10) + w = acquire!(pool, Float64, 5) + (v, w) = (safe_func(), acquire!(pool, Float64, 3)) + end, + acquired, :pool + ) + @test !(:v in acquired) # reassigned to safe_func() → removed + @test :w in acquired # reassigned to acquire!() → stays + + # Comma destructuring (same AST as tuple): v removed + acquired = Set{Symbol}([:v]) + _remove_flat_reassigned!( + quote + v = acquire!(pool, Float64, 10) + a, v = foo() + end, + acquired, :pool + ) + @test isempty(acquired) + + # Destructuring doesn't affect vars not in acquired + acquired = Set{Symbol}([:v]) + _remove_flat_reassigned!( + quote + v = acquire!(pool, Float64, 10) + (x, y) = foo() + end, + acquired, :pool + ) + @test :v in acquired # v not in destructuring → untouched + end + + # ============================================================================== + # _check_compile_time_escape: integration tests + # ============================================================================== + + @testset "_check_compile_time_escape" begin + src = LineNumberNode(1, :test) + + # Bare variable return → error "Pool escape" + @test_throws PoolEscapeError _check_compile_time_escape( + quote + v = acquire!(pool, Float64, 10) + v + end, + :pool, src + ) + + # Tuple containing acquired var → error + @test_throws PoolEscapeError _check_compile_time_escape( + quote + v = acquire!(pool, Float64, 10) + w = acquire!(pool, Float64, 5) + (sum(v), w) + end, + :pool, src + ) + + # Safe: scalar return → no warning + @test_nowarn _check_compile_time_escape( + quote + v = acquire!(pool, Float64, 10) + sum(v) + end, + :pool, src + ) + + # Safe: collect return → no warning + @test_nowarn _check_compile_time_escape( + quote + v = acquire!(pool, Float64, 10) + collect(v) + end, + :pool, src + ) + + # Safe: literal return → no warning + @test_nowarn _check_compile_time_escape( + quote + v = acquire!(pool, Float64, 10) + v .= 1.0 + 42 + end, + :pool, src + ) + + # Safe: reassigned then returned → no warning + @test_nowarn _check_compile_time_escape( + quote + v = acquire!(pool, Float64, 10) + v .= data + v = collect(v) + v + end, + :pool, src + ) + + # Safe: no acquire calls → no warning + @test_nowarn _check_compile_time_escape( + quote + x = sum(data) + x + end, + :pool, src + ) + + # zeros!/ones!/similar! also detected + @test_throws PoolEscapeError _check_compile_time_escape( + quote + v = zeros!(pool, 10) + v + end, + :pool, src + ) + + @test_throws PoolEscapeError _check_compile_time_escape( + quote + v = ones!(pool, Float32, 10) + v + end, + :pool, src + ) + + @test_throws PoolEscapeError _check_compile_time_escape( + quote + v = similar!(pool, some_array) + v + end, + :pool, src + ) + + # unsafe_acquire! also detected + @test_throws PoolEscapeError _check_compile_time_escape( + quote + v = unsafe_acquire!(pool, Float64, 10) + v + end, + :pool, src + ) + + # trues!/falses! also detected + @test_throws PoolEscapeError _check_compile_time_escape( + quote + bv = trues!(pool, 100) + bv + end, + :pool, src + ) + + # Different pool name → no warning (not our pool) + @test_nowarn _check_compile_time_escape( + quote + v = acquire!(other_pool, Float64, 10) + v + end, + :pool, src + ) + + # source=nothing also works + @test_throws PoolEscapeError _check_compile_time_escape( + quote + v = acquire!(pool, Float64, 10) + v + end, + :pool, nothing + ) + + # `return v` is also a definite escape (error) + @test_throws PoolEscapeError _check_compile_time_escape( + quote + v = acquire!(pool, Float64, 10) + return v + end, + :pool, src + ) + + # `return (v, w)` is an escape (error) + @test_throws PoolEscapeError _check_compile_time_escape( + quote + v = acquire!(pool, Float64, 10) + return (v, sum(v)) + end, + :pool, src + ) + end + + # ============================================================================== + # _collect_all_return_values: explicit returns + implicit if/else branches + # ============================================================================== + + @testset "_collect_all_return_values (expr, line) pairs" begin + # Simple block: implicit return only + vals = _collect_all_return_values( + quote + x = 1 + x + end + ) + exprs = first.(vals) + @test any(e -> e isa Symbol && e == :x, exprs) + + # Explicit return inside if branch + vals = _collect_all_return_values( + quote + v = acquire!(pool, Float64, 10) + if cond + return v + end + sum(v) + end + ) + exprs = first.(vals) + @test any(e -> e isa Expr && e.head == :return, exprs) + @test any(e -> e isa Expr && e.head == :call, exprs) + + # Both branches have explicit return + vals = _collect_all_return_values( + quote + if a > 0.5 + return (v = 0.5, data = a) + else + return (v = v, data = z) + end + end + ) + returns = filter(((e, _),) -> e isa Expr && e.head == :return, vals) + @test length(returns) >= 2 + + # Implicit return from if/else branches (no explicit return keyword) + vals = _collect_all_return_values( + quote + if cond + v + else + sum(v) + end + end + ) + exprs = first.(vals) + @test any(e -> e isa Symbol && e == :v, exprs) + @test any(e -> e isa Expr && e.head == :call, exprs) + + # elseif branches + vals = _collect_all_return_values( + quote + if cond1 + v + elseif cond2 + w + else + sum(v) + end + end + ) + exprs = first.(vals) + @test any(e -> e isa Symbol && e == :v, exprs) + @test any(e -> e isa Symbol && e == :w, exprs) + + # Nested if inside branch + vals = _collect_all_return_values( + quote + if outer + if inner + v + else + w + end + else + sum(v) + end + end + ) + exprs = first.(vals) + @test any(e -> e isa Symbol && e == :v, exprs) + @test any(e -> e isa Symbol && e == :w, exprs) + + # Skips nested function definitions + vals = _collect_all_return_values( + quote + f = function () + return v # belongs to inner function, not our scope + end + sum(v) + end + ) + returns = filter(((e, _),) -> e isa Expr && e.head == :return, vals) + @test isempty(returns) + + # Ternary operator (same AST as if/else) + vals = _collect_all_return_values( + quote + cond ? v : sum(v) + end + ) + exprs = first.(vals) + @test any(e -> e isa Symbol && e == :v, exprs) + + # Line numbers are tracked + vals = _collect_all_return_values( + quote + if cond + return v # this will have a line number + end + sum(v) + end + ) + explicit_returns = filter(((e, _),) -> e isa Expr && e.head == :return, vals) + @test !isempty(explicit_returns) + @test last(explicit_returns[1]) !== nothing # line is captured + end + + # ============================================================================== + # _check_compile_time_escape: branch return detection + # ============================================================================== + + @testset "Escape detection through branches" begin + src = LineNumberNode(1, :test) + + # Explicit return inside if branch — caught + @test_throws PoolEscapeError _check_compile_time_escape( + quote + v = acquire!(pool, Float64, 10) + if cond + return v + end + sum(v) + end, + :pool, src + ) + + # Both branches return, one unsafe — caught + @test_throws PoolEscapeError _check_compile_time_escape( + quote + v = acquire!(pool, Float64, 10) + if a > 0.5 + return sum(v) + else + return v + end + end, + :pool, src + ) + + # NamedTuple in branch — caught + @test_throws PoolEscapeError _check_compile_time_escape( + quote + v = acquire!(pool, Float64, 10) + z = similar!(pool, v) + if a > 0.5 + return (v = 0.5, data = a) + else + return (v = v, data = z) + end + end, + :pool, src + ) + + # Both branches safe — no error + @test_nowarn _check_compile_time_escape( + quote + v = acquire!(pool, Float64, 10) + if cond + return sum(v) + else + return length(v) + end + end, + :pool, src + ) + + # Implicit return from if/else branches — caught + @test_throws PoolEscapeError _check_compile_time_escape( + quote + v = acquire!(pool, Float64, 10) + if cond + sum(v) + else + v + end + end, + :pool, src + ) + + # elseif branch — caught + @test_throws PoolEscapeError _check_compile_time_escape( + quote + v = acquire!(pool, Float64, 10) + if cond1 + sum(v) + elseif cond2 + v + else + length(v) + end + end, + :pool, src + ) + + # Ternary with escape — caught + @test_throws PoolEscapeError _check_compile_time_escape( + quote + v = acquire!(pool, Float64, 10) + cond ? v : sum(v) + end, + :pool, src + ) + + # Early return in loop — caught + @test_throws PoolEscapeError _check_compile_time_escape( + quote + v = acquire!(pool, Float64, 10) + for i in 1:10 + if cond + return v + end + end + sum(v) + end, + :pool, src + ) + + # Multi-variable across branches: reports all + err = try + _check_compile_time_escape( + quote + v = acquire!(pool, Float64, 10) + w = acquire!(pool, Float64, 5) + if cond + return v + else + return w + end + end, + :pool, src + ) + catch e + e + end + @test err isa PoolEscapeError + @test :v in err.vars + @test :w in err.vars + end + + # ============================================================================== + # Integration: compile-time error via @macroexpand (branch scenarios) + # ============================================================================== + + @testset "Branch escape detection through macro pipeline" begin + # if/else with explicit return in one branch — caught + @test_throws PoolEscapeError @macroexpand @with_pool pool begin + v = acquire!(pool, Float64, 10) + if rand() > 0.5 + return sum(v) + else + return v + end + end + + # Function form with branch returns — caught + @test_throws PoolEscapeError @macroexpand @with_pool pool function branch_fn(n) + v = acquire!(pool, Float64, n) + z = similar!(pool, v) + if n > 0 + return (v = 0.5, data = 1.0) + else + return (v = v, data = z) + end + end + + # Ternary — caught + @test_throws PoolEscapeError @macroexpand @with_pool pool begin + v = acquire!(pool, Float64, 10) + rand() > 0.5 ? sum(v) : v + end + + # All branches safe — no error + expanded = @macroexpand @with_pool pool begin + v = acquire!(pool, Float64, 10) + if rand() > 0.5 + return sum(v) + else + return length(v) + end + end + @test expanded isa Expr + + # Function form, all branches safe + expanded = @macroexpand @with_pool pool function safe_branch(n) + v = acquire!(pool, Float64, n) + if n > 0 + sum(v) + else + 0.0 + end + end + @test expanded isa Expr + end + + # ============================================================================== + # Integration: compile-time error via @macroexpand + # ============================================================================== + + @testset "Compile-time error through macro pipeline" begin + # Bare variable: macro expansion itself throws + @test_throws PoolEscapeError @macroexpand @with_pool pool begin + v = acquire!(pool, Float64, 10) + v + end + + # Safe return: macro expansion succeeds + expanded = @macroexpand @with_pool pool begin + v = acquire!(pool, Float64, 10) + sum(v) + end + @test expanded isa Expr + + # Function form: bare return also caught + @test_throws PoolEscapeError @macroexpand @with_pool pool function test_fn(n) + v = acquire!(pool, Float64, n) + v + end + end + + # ============================================================================== + # Integration: block form — false positive prevention + # ============================================================================== + + @testset "Block form: safe patterns (no false positives)" begin + # Function call on acquired var → safe (can't determine return type) + @test_nowarn @macroexpand @with_pool pool begin + v = acquire!(pool, Float64, 10) + sum(v) + end + + # Indexing → safe (element access, not array itself) + @test_nowarn @macroexpand @with_pool pool begin + v = acquire!(pool, Float64, 10) + v[1] + end + + # Arithmetic expression → safe + @test_nowarn @macroexpand @with_pool pool begin + v = acquire!(pool, Float64, 10) + sum(v) + 1.0 + end + + # Reassigned from collect, then returned → safe + @test_nowarn @macroexpand @with_pool pool begin + v = acquire!(pool, Float64, 10) + v = collect(v) + v + end + + # Non-acquired variable returned → safe + @test_nowarn @macroexpand @with_pool pool begin + v = acquire!(pool, Float64, 10) + w = sum(v) + w + end + + # nothing return → safe + @test_nowarn @macroexpand @with_pool pool begin + v = acquire!(pool, Float64, 10) + v .= 1.0 + nothing + end + + # Multiple acquires, safe return + @test_nowarn @macroexpand @with_pool pool begin + v = acquire!(pool, Float64, 10) + w = zeros!(pool, 5) + sum(v) + sum(w) + end + + # No acquire calls at all → safe + @test_nowarn @macroexpand @with_pool pool begin + x = [1, 2, 3] + sum(x) + end + + # Boolean comparison → safe + @test_nowarn @macroexpand @with_pool pool begin + v = acquire!(pool, Float64, 10) + length(v) > 0 + end + + # NamedTuple: key name matches acquired var, but VALUE is safe + @test_nowarn @macroexpand @with_pool pool begin + v = acquire!(pool, Float64, 10) + v .= 1.0 + (v = collect(v), total = sum(v)) + end + + # NamedTuple: key name coincidentally same, value is unrelated + @test_nowarn @macroexpand @with_pool pool begin + v = acquire!(pool, Float64, 10) + v .= data + (v = zeros(10), u = sum(v)) + end + + # --- Tricky-but-safe edge cases (true negatives) --- + # Each pattern looks like it might escape pool memory, but is genuinely safe. + # The checker correctly does NOT flag these. + + # Safe: v is reassigned to non-pool arrays twice; final v is a fresh collect'd copy + @test_nowarn @macroexpand @with_pool pool begin + v = acquire!(pool, Float64, 10) + v .= 1.0 + v = v .+ 1.0 + v = collect(v) + v + end + + # Safe: w is a plain view (not from acquire!), and return is scalar sum + @test_nowarn @macroexpand @with_pool pool begin + v = acquire!(pool, Float64, 10) + w = view(v, 1:5) + sum(w) + end + + # Safe: similar() (no !) allocates a fresh independent array — w is not pool-backed + @test_nowarn @macroexpand @with_pool pool begin + v = acquire!(pool, Float64, 10) + w = similar(v) + w .= 1.0 + w + end + + # Safe: copy() returns an independent deep copy — no pool memory escapes + @test_nowarn @macroexpand @with_pool pool begin + v = acquire!(pool, Float64, 10) + v .= 1.0 + copy(v) + end + + # Safe: comprehension allocates a fresh Array from element values + @test_nowarn @macroexpand @with_pool pool begin + v = acquire!(pool, Float64, 10) + v .= 1.0 + [v[i]^2 for i in 1:10] + end + + # Safe: result is a String, not an array + @test_nowarn @macroexpand @with_pool pool begin + v = acquire!(pool, Float64, 10) + "total = $(sum(v))" + end + + # Safe: ternary returns scalar from both branches + @test_nowarn @macroexpand @with_pool pool begin + v = acquire!(pool, Float64, 10) + length(v) > 5 ? sum(v) : 0.0 + end + + # Safe: pipe evaluates to sum(v) — a scalar, not the array + @test_nowarn @macroexpand @with_pool pool begin + v = acquire!(pool, Float64, 10) + v .= 1.0 + v |> sum + end + + # Safe: broadcast allocates a fresh result array — neither v nor w escapes + @test_nowarn @macroexpand @with_pool pool begin + v = acquire!(pool, Float64, 10) + w = acquire!(pool, Float64, 10) + v .+ w + end + + # Safe: Dict holds only scalars (sum, length) — no pool array reference + @test_nowarn @macroexpand @with_pool pool begin + v = acquire!(pool, Float64, 10) + Dict(:sum => sum(v), :len => length(v)) + end + + # Safe: let block returns scalar s*2 — pool array v stays local + @test_nowarn @macroexpand @with_pool pool begin + v = acquire!(pool, Float64, 10) + v .= 2.0 + let s = sum(v) + s * 2 + end + end + + # Safe: map() allocates a fresh Array with transformed elements + @test_nowarn @macroexpand @with_pool pool begin + v = acquire!(pool, Float64, 10) + v .= 1.0 + map(v) do x + x^2 + end + end + + # Safe: destructuring reassigns v to a non-pool value + @test_nowarn @macroexpand @with_pool pool begin + v = acquire!(pool, Float64, 10) + v .= 1.0 + (result, v) = process(v) + v + end + + # Safe: comma destructuring reassigns v + @test_nowarn @macroexpand @with_pool pool begin + v = acquire!(pool, Float64, 10) + idx, v = findmax(some_array) + v + end + + # Safe: destructuring with tuple RHS — v gets safe value, w stays tracked + # but only v is returned + @test_nowarn @macroexpand @with_pool pool begin + v = acquire!(pool, Float64, 10) + w = acquire!(pool, Float64, 5) + (v, w) = (collect(v), acquire!(pool, Float64, 3)) + sum(w) + sum(v) + end + + # Safe: swap pattern — v gets non-pool value after swap + @test_nowarn @macroexpand @with_pool pool begin + v = acquire!(pool, Float64, 10) + x = zeros(10) + v, x = x, v + v + end + end + + @testset "Block form: additional escape scenarios" begin + # zeros! — definite escape + @test_throws PoolEscapeError @macroexpand @with_pool pool begin + v = zeros!(pool, 10) + v + end + + # trues! — definite escape + @test_throws PoolEscapeError @macroexpand @with_pool pool begin + bv = trues!(pool, 100) + bv + end + + # Explicit return — definite escape + @test_throws PoolEscapeError @macroexpand @with_pool pool begin + v = acquire!(pool, Float64, 10) + return v + end + + # Tuple with acquired var → escape + @test_throws PoolEscapeError @macroexpand @with_pool pool begin + v = acquire!(pool, Float64, 10) + (v, 42) + end + + # Array literal → escape + @test_throws PoolEscapeError @macroexpand @with_pool pool begin + v = acquire!(pool, Float64, 10) + [v, nothing] + end + + # return (v, scalar) → escape + @test_throws PoolEscapeError @macroexpand @with_pool pool begin + v = acquire!(pool, Float64, 10) + return (v, sum(v)) + end + + # Re-acquire reassignment: v still tracked after v = zeros!(pool, ...) + @test_throws PoolEscapeError @macroexpand @with_pool pool begin + v = acquire!(pool, Float64, 10) + v = zeros!(pool, 20) + v + end + + # NamedTuple with acquired var as VALUE → escape + @test_throws PoolEscapeError @macroexpand @with_pool pool begin + v = acquire!(pool, Float64, 10) + (result = v, n = 42) + end + + # NamedTuple shorthand (v = v) → value IS acquired → escape + # (key name coincidentally matches, but VALUE is the acquired var) + @test_throws PoolEscapeError @macroexpand @with_pool pool begin + v = acquire!(pool, Float64, 10) + (v = v,) + end + + # Destructuring with acquire RHS: v still tracked → escape + @test_throws PoolEscapeError @macroexpand @with_pool pool begin + (v, w) = (acquire!(pool, Float64, 10), safe()) + v + end + + # Destructuring doesn't protect if RHS element IS acquire + @test_throws PoolEscapeError @macroexpand @with_pool pool begin + v = acquire!(pool, Float64, 10) + (v, w) = (zeros!(pool, 5), safe()) + v + end + end + + # ============================================================================== + # Integration: @with_pool function definition form + # ============================================================================== + + @testset "Function form: escape detection" begin + # Definite — bare variable return + @test_throws PoolEscapeError @macroexpand @with_pool pool function fn_esc1(n) + v = acquire!(pool, Float64, n) + v + end + + # Definite — explicit return + @test_throws PoolEscapeError @macroexpand @with_pool pool function fn_esc2(n) + v = zeros!(pool, n) + v .= 1.0 + return v + end + + # Definite — trues! + @test_throws PoolEscapeError @macroexpand @with_pool pool function fn_esc3(n) + bv = trues!(pool, n) + bv + end + + # Tuple return → escape + @test_throws PoolEscapeError @macroexpand @with_pool pool function fn_warn1(n) + v = acquire!(pool, Float64, n) + (v, sum(v)) + end + + # return tuple → escape + @test_throws PoolEscapeError @macroexpand @with_pool pool function fn_warn2(n) + v = acquire!(pool, Float64, n) + return (v, n) + end + end + + @testset "Function form: safe patterns (no false positives)" begin + # Scalar return + @test_nowarn @macroexpand @with_pool pool function fn_safe1(n) + v = acquire!(pool, Float64, n) + sum(v) + end + + # collect() return + @test_nowarn @macroexpand @with_pool pool function fn_safe2(n) + v = acquire!(pool, Float64, n) + v .= 1.0 + collect(v) + end + + # nothing return (side-effect function) + @test_nowarn @macroexpand @with_pool pool function fn_safe3!(out, n) + v = acquire!(pool, Float64, n) + out .= v + nothing + end + + # Reassigned then returned + @test_nowarn @macroexpand @with_pool pool function fn_safe4(n) + v = acquire!(pool, Float64, n) + v = collect(v) + v + end + + # Function call wrapping acquired var + @test_nowarn @macroexpand @with_pool pool function fn_safe5(n) + v = acquire!(pool, Float64, n) + sum(v) + end + + # Multiple acquires, safe return + @test_nowarn @macroexpand @with_pool pool function fn_safe6(m, n) + v = acquire!(pool, Float64, m) + w = zeros!(pool, n) + sum(v) + sum(w) + end + + # Safe: pipe evaluates to scalar sum(v) + @test_nowarn @macroexpand @with_pool pool function fn_safe_pipe(n) + v = acquire!(pool, Float64, n) + v .= 1.0 + v |> sum + end + + # Safe: comprehension allocates fresh array from element values + @test_nowarn @macroexpand @with_pool pool function fn_safe_comp(n) + v = acquire!(pool, Float64, n) + v .= 1.0 + [v[i] for i in 1:n] + end + + # Safe: ternary returns scalar from both branches + @test_nowarn @macroexpand @with_pool pool function fn_safe_ternary(n) + v = acquire!(pool, Float64, n) + n > 5 ? sum(v) : 0.0 + end + + # Safe: v reassigned to fresh broadcast result — no longer pool-backed + @test_nowarn @macroexpand @with_pool pool function fn_safe_bcast(n) + v = acquire!(pool, Float64, n) + v = v .+ 1.0 + v + end + + # Safe: copy() returns independent deep copy + @test_nowarn @macroexpand @with_pool pool function fn_safe_copy(n) + v = acquire!(pool, Float64, n) + v .= 1.0 + copy(v) + end + + # Safe: destructuring reassigns v in function form + @test_nowarn @macroexpand @with_pool pool function fn_safe_destruct(n) + v = acquire!(pool, Float64, n) + (result, v) = process(v) + v + end + end + + # ============================================================================== + # Integration: @maybe_with_pool forms + # ============================================================================== + + @testset "@maybe_with_pool block form" begin + # Definite escape → error + @test_throws PoolEscapeError @macroexpand @maybe_with_pool pool begin + v = acquire!(pool, Float64, 10) + v + end + + # Container escape → error + @test_throws PoolEscapeError @macroexpand @maybe_with_pool pool begin + v = acquire!(pool, Float64, 10) + (v, sum(v)) + end + + # Safe → no warning + @test_nowarn @macroexpand @maybe_with_pool pool begin + v = acquire!(pool, Float64, 10) + sum(v) + end + end + + @testset "@maybe_with_pool function form" begin + # Definite escape → error + @test_throws PoolEscapeError @macroexpand @maybe_with_pool pool function mwp_esc(n) + v = acquire!(pool, Float64, n) + v + end + + # Safe → no warning + @test_nowarn @macroexpand @maybe_with_pool pool function mwp_safe(n) + v = acquire!(pool, Float64, n) + sum(v) + end + end + + # ============================================================================== + # Integration: backend forms (@with_pool :cpu, @maybe_with_pool :cpu) + # ============================================================================== + + @testset "@with_pool :cpu block form" begin + @test_throws PoolEscapeError @macroexpand @with_pool :cpu pool begin + v = acquire!(pool, Float64, 10) + v + end + + @test_nowarn @macroexpand @with_pool :cpu pool begin + v = acquire!(pool, Float64, 10) + sum(v) + end + end + + @testset "@with_pool :cpu function form" begin + @test_throws PoolEscapeError @macroexpand @with_pool :cpu pool function cpu_esc(n) + v = acquire!(pool, Float64, n) + v + end + + @test_nowarn @macroexpand @with_pool :cpu pool function cpu_safe(n) + v = acquire!(pool, Float64, n) + sum(v) + end + end + + @testset "@maybe_with_pool :cpu forms" begin + # Block — error + @test_throws PoolEscapeError @macroexpand @maybe_with_pool :cpu pool begin + v = acquire!(pool, Float64, 10) + v + end + + # Block — safe + @test_nowarn @macroexpand @maybe_with_pool :cpu pool begin + v = acquire!(pool, Float64, 10) + sum(v) + end + + # Function — error + @test_throws PoolEscapeError @macroexpand @maybe_with_pool :cpu pool function mcpu_esc(n) + v = acquire!(pool, Float64, n) + v + end + + # Function — safe + @test_nowarn @macroexpand @maybe_with_pool :cpu pool function mcpu_safe(n) + v = acquire!(pool, Float64, n) + sum(v) + end + end + + # ============================================================================== + # Integration: nested @with_pool scopes + # ============================================================================== + + @testset "Nested @with_pool scopes" begin + # Note: @macroexpand expands only the outermost macro; inner @with_pool + # inside esc() boundaries are not recursively expanded by macroexpand(). + # Inner scope escape detection runs when code is actually compiled. + + # Both scopes safe → no error + @test_nowarn @macroexpand @with_pool pool begin + v = acquire!(pool, Float64, 10) + inner = @with_pool pool begin + w = acquire!(pool, Float64, 5) + sum(w) + end + sum(v) + inner + end + + # Outer scope escape → error from outer macro check + @test_throws PoolEscapeError @macroexpand @with_pool pool begin + v = acquire!(pool, Float64, 10) + @with_pool pool begin + w = acquire!(pool, Float64, 5) + w .= 1.0 + nothing + end + v # ← outer definite escape + end + end + + # ============================================================================== + # Error/warning messages: verify variable names and suggestions + # ============================================================================== + + @testset "PoolEscapeError carries variable names, points, and formatted message" begin + # Single variable: bare return + err = try + @macroexpand( + @with_pool pool begin + v = acquire!(pool, Float64, 10) + v + end + ) + catch e + e + end + @test err isa PoolEscapeError + @test err.vars == [:v] + @test !isempty(err.points) + @test :v in err.points[1].vars + msg = sprint(showerror, err) + @test occursin("collect(v)", msg) + @test occursin("False positive?", msg) + @test occursin("Escaping return", msg) + @test occursin("escapes the @with_pool scope", msg) + # Declaration sites populated + @test !isempty(err.declarations) + @test any(d -> d.var === :v, err.declarations) + @test occursin("Declarations:", msg) + @test occursin("acquire!(pool, Float64, 10)", msg) + + # Different variable name + err = try + @macroexpand( + @with_pool pool begin + data = zeros!(pool, 10) + data + end + ) + catch e + e + end + @test err isa PoolEscapeError + @test err.vars == [:data] + @test any(d -> d.var === :data, err.declarations) + + # Function form + err = try + @macroexpand( + @with_pool pool function msg_fn(n) + result = acquire!(pool, Float64, n) + result + end + ) + catch e + e + end + @test err isa PoolEscapeError + @test err.vars == [:result] + + # Container: only w escapes, not sum(v) + err = try + @macroexpand( + @with_pool pool begin + v = acquire!(pool, Float64, 10) + w = acquire!(pool, Float64, 5) + (sum(v), w) + end + ) + catch e + e + end + @test err isa PoolEscapeError + @test err.vars == [:w] + @test :v ∉ err.vars + + # Multi-variable: both appear, sorted + err = try + @macroexpand( + @with_pool pool begin + v = acquire!(pool, Float64, 10) + w = acquire!(pool, Float64, 5) + (v, w) + end + ) + catch e + e + end + @test err isa PoolEscapeError + @test err.vars == [:v, :w] + msg = sprint(showerror, err) + @test occursin("collect(v)", msg) + @test occursin("collect(w)", msg) + # Multi-variable: both declarations present + @test length(err.declarations) == 2 + @test err.declarations[1].var === :v + @test err.declarations[2].var === :w + @test occursin("2 variables escape", msg) + + # Source location captured + @test err.file !== nothing + + # Branch escapes: points track per-return-point info + err = try + @macroexpand( + @with_pool pool function branch_msg(n) + v = acquire!(pool, Float64, n) + if n > 0 + return sum(v) + else + return v + end + end + ) + catch e + e + end + @test err isa PoolEscapeError + @test err.vars == [:v] + @test length(err.points) == 1 # only the unsafe return + @test :v in err.points[1].vars + @test err.points[1].line !== nothing + + # Multiple escape points across branches + err = try + @macroexpand( + @with_pool pool function multi_pt(n) + v = acquire!(pool, Float64, n) + w = acquire!(pool, Float64, n) + if n > 0 + return v + else + return w + end + end + ) + catch e + e + end + @test err isa PoolEscapeError + @test err.vars == [:v, :w] + @test length(err.points) == 2 + msg = sprint(showerror, err) + @test occursin("[1]", msg) + @test occursin("[2]", msg) + + # Rendered expression shows highlighted vars + err = try + @macroexpand( + @with_pool pool begin + v = acquire!(pool, Float64, 10) + z = similar!(pool, v) + if rand() > 0.5 + return (v = 0.5, data = 1.0) + else + return (v = v, data = z) + end + end + ) + catch e + e + end + @test err isa PoolEscapeError + @test err.vars == [:v, :z] + msg = sprint(showerror, err) + # The rendered return expression appears in the message + @test occursin("return", msg) + @test occursin("data", msg) + end + + # ============================================================================== + # Variable classification: view vs array vs bitarray vs container vs alias + # ============================================================================== + + @testset "_acquire_call_kind classification" begin + # View-returning functions + @test _acquire_call_kind(:(acquire!(pool, Float64, 10)), :pool) === :pool_view + @test _acquire_call_kind(:(zeros!(pool, 10)), :pool) === :pool_view + @test _acquire_call_kind(:(ones!(pool, Int64, 3)), :pool) === :pool_view + @test _acquire_call_kind(:(similar!(pool, arr)), :pool) === :pool_view + @test _acquire_call_kind(:(reshape!(pool, arr, 3, 4)), :pool) === :pool_view + + # Array-returning functions (unsafe_wrap) + @test _acquire_call_kind(:(unsafe_acquire!(pool, Float64, 10)), :pool) === :pool_array + @test _acquire_call_kind(:(acquire_array!(pool, Float64, 10)), :pool) === :pool_array + + # BitArray-returning functions + @test _acquire_call_kind(:(trues!(pool, 100)), :pool) === :pool_bitarray + @test _acquire_call_kind(:(falses!(pool, 50)), :pool) === :pool_bitarray + + # Non-acquire → nothing + @test _acquire_call_kind(:(sum(data)), :pool) === nothing + @test _acquire_call_kind(:(rand(10)), :pool) === nothing + + # Wrong pool → nothing + @test _acquire_call_kind(:(acquire!(other_pool, Float64, 10)), :pool) === nothing + end + + @testset "var_info classification in PoolEscapeError" begin + # Direct pool view + err = try + @macroexpand( + @with_pool pool begin + v = acquire!(pool, Float64, 10) + v + end + ) + catch e + e + end + @test err.var_info[:v] == (:pool_view, Symbol[]) + msg = sprint(showerror, err) + @test occursin("pool-acquired view", msg) + + # Direct pool array (unsafe_acquire!) + err = try + @macroexpand( + @with_pool pool begin + v = unsafe_acquire!(pool, Float64, 10) + v + end + ) + catch e + e + end + @test err.var_info[:v] == (:pool_array, Symbol[]) + msg = sprint(showerror, err) + @test occursin("pool-acquired array", msg) + + # Direct pool BitArray + err = try + @macroexpand( + @with_pool pool begin + bv = trues!(pool, 100) + bv + end + ) + catch e + e + end + @test err.var_info[:bv] == (:pool_bitarray, Symbol[]) + msg = sprint(showerror, err) + @test occursin("pool-acquired BitArray", msg) + + # Container wrapping pool variable + err = try + @macroexpand( + @with_pool pool begin + v = acquire!(pool, Float64, 10) + a = [v, 1] + a + end + ) + catch e + e + end + @test err.var_info[:a] == (:container, [:v]) + msg = sprint(showerror, err) + @test occursin("wraps pool variable (v)", msg) + # Fix suggests collect(v), not collect(a) + @test occursin("collect(v)", msg) + @test occursin("Copy pool variables before wrapping", msg) + + # Container with multiple pool vars + err = try + @macroexpand( + @with_pool pool begin + v = acquire!(pool, Float64, 10) + w = acquire!(pool, Float64, 5) + a = [v, w] + a + end + ) + catch e + e + end + @test err.var_info[:a] == (:container, [:v, :w]) + msg = sprint(showerror, err) + @test occursin("wraps pool variables (v, w)", msg) + + # Alias of pool variable + err = try + @macroexpand( + @with_pool pool begin + v = acquire!(pool, Float64, 10) + d = v + d + end + ) + catch e + e + end + @test err.var_info[:d] == (:alias, [:v]) + msg = sprint(showerror, err) + @test occursin("alias of pool variable (v)", msg) + + # Mixed: direct pool var + container in same return + err = try + @macroexpand( + @with_pool pool begin + v = acquire!(pool, Float64, 10) + a = [v, 1] + return (v, a) + end + ) + catch e + e + end + @test err.var_info[:v] == (:pool_view, Symbol[]) + @test err.var_info[:a] == (:container, [:v]) + msg = sprint(showerror, err) + @test occursin("pool-acquired view", msg) + @test occursin("wraps pool variable (v)", msg) + # Fix section deduplicates: only collect(v), not collect(a) + @test occursin("collect(v)", msg) + @test !occursin("collect(a)", msg) + + # zeros! classified as view + err = try + @macroexpand( + @with_pool pool begin + data = zeros!(pool, 10) + data + end + ) + catch e + e + end + @test err.var_info[:data] == (:pool_view, Symbol[]) + + # Tuple container + err = try + @macroexpand( + @with_pool pool begin + v = acquire!(pool, Float64, 10) + t = (v, 42) + t + end + ) + catch e + e + end + @test err.var_info[:t] == (:container, [:v]) + end + + # ============================================================================== + # Declaration site extraction + # ============================================================================== + + @testset "_extract_declaration_sites" begin + # Single acquire: captures var, expr, and line + sites = _extract_declaration_sites( + quote + v = acquire!(pool, Float64, 10) + v + end, + Set([:v]) + ) + @test length(sites) == 1 + @test sites[1].var === :v + @test sites[1].line !== nothing + @test string(sites[1].expr) == "v = acquire!(pool, Float64, 10)" + + # Multiple declarations sorted by line + sites = _extract_declaration_sites( + quote + v = acquire!(pool, Float64, 10) + w = zeros!(pool, 5) + (v, w) + end, + Set([:v, :w]) + ) + @test length(sites) == 2 + @test sites[1].var === :v + @test sites[2].var === :w + @test sites[1].line < sites[2].line + + # Container declaration captured + sites = _extract_declaration_sites( + quote + v = acquire!(pool, Float64, 10) + a = [v, 1] + a + end, + Set([:v, :a]) + ) + @test length(sites) == 2 + @test sites[1].var === :v + @test sites[2].var === :a + + # Alias declaration captured + sites = _extract_declaration_sites( + quote + v = acquire!(pool, Float64, 10) + d = v + d + end, + Set([:v, :d]) + ) + @test length(sites) == 2 + @test sites[1].var === :v + @test sites[2].var === :d + + # Only escaped vars captured (non-escaped ignored) + sites = _extract_declaration_sites( + quote + v = acquire!(pool, Float64, 10) + w = acquire!(pool, Float64, 5) + v + end, + Set([:v]) # only v escapes + ) + @test length(sites) == 1 + @test sites[1].var === :v + end + + # ============================================================================== + # Formatted message: declarations and escape points with locations + # ============================================================================== + + @testset "showerror shows declarations and escape locations" begin + # Container: declarations show both v and a + err = try + @macroexpand( + @with_pool pool begin + v = acquire!(pool, Float64, 10) + a = [v, 1] + return (v, a) + end + ) + catch e + e + end + msg = sprint(showerror, err) + @test occursin("Declarations:", msg) + @test occursin("Escaping return:", msg) + @test occursin("acquire!(pool, Float64, 10)", msg) + @test occursin("[v, 1]", msg) + @test occursin("return", msg) + end + + # ============================================================================== + # Coverage: PoolEscapeError convenience constructors + # ============================================================================== + + @testset "PoolEscapeError convenience constructors" begin + pt = EscapePoint(:v, 1, [:v]) + + # 4-arg constructor (no var_info, no declarations) + err4 = PoolEscapeError([:v], "test.jl", 1, [pt]) + @test isempty(err4.var_info) + @test isempty(err4.declarations) + + # 5-arg constructor (no declarations) + vi = Dict{Symbol, Tuple{Symbol, Vector{Symbol}}}(:v => (:pool_view, Symbol[])) + err5 = PoolEscapeError([:v], "test.jl", 1, [pt], vi) + @test err5.var_info[:v][1] === :pool_view + @test isempty(err5.declarations) + end + + # ============================================================================== + # Coverage: _render_return_expr branches + # ============================================================================== + + @testset "_render_return_expr branch coverage" begin + escaped = Set([:v]) + + # Non-escaped symbol → plain print (line 54) + buf = sprint() do io + _render_return_expr(io, :x, escaped) + end + @test buf == "x" + + # Array literal :vect (lines 71-77) + buf = sprint() do io + _render_return_expr(io, Expr(:vect, :v, :x), escaped) + end + @test occursin("[", buf) + @test occursin("]", buf) + @test occursin("x", buf) + + # Fallback Expr (line 79) — e.g. a :ref expression + buf = sprint() do io + _render_return_expr(io, :(v[1]), escaped) + end + @test occursin("v", buf) + + # Non-Expr, non-Symbol literal (line 82) + buf = sprint() do io + _render_return_expr(io, 42, escaped) + end + @test buf == "42" + end + + # ============================================================================== + # Coverage: showerror with unknown var_info kind (line 121) + # ============================================================================== + + @testset "showerror pool-backed temporary fallback" begin + vi = Dict{Symbol, Tuple{Symbol, Vector{Symbol}}}(:v => (:something_else, Symbol[])) + err = PoolEscapeError( + [:v], "test.jl", 1, + [EscapePoint(:v, 1, [:v])], + vi, DeclarationSite[] + ) + msg = sprint(showerror, err) + @test occursin("pool-backed temporary", msg) + end + + # ============================================================================== + # Coverage: _format_location_str / _format_point_location edge cases + # ============================================================================== + + @testset "location formatting edge cases" begin + # file="none", line present → "line N" (lines 203-204) + @test _format_location_str("none", 42) == "line 42" + # file=nothing, line=nothing → nothing (line 206) + @test _format_location_str(nothing, nothing) === nothing + # file=nothing, line present → "line N" + @test _format_location_str(nothing, 7) == "line 7" + + # Same for _format_point_location (lines 214-217) + @test _format_point_location("none", 42) == "line 42" + @test _format_point_location(nothing, nothing) === nothing + @test _format_point_location(nothing, 7) == "line 7" + end + + # ============================================================================== + # Coverage: showerror 3-arg (backtrace suppression, line 221) + # ============================================================================== + + @testset "showerror backtrace suppression" begin + err = try + @macroexpand( + @with_pool pool begin + v = acquire!(pool, Float64, 10) + v + end + ) + catch e + e + end + @test err isa PoolEscapeError + # 3-arg showerror should produce same output as 2-arg + msg2 = sprint(showerror, err) + msg3 = sprint() do io + showerror(io, err, nothing) + end + @test msg2 == msg3 + end + + # ============================================================================== + # Coverage: _is_acquire_call / _acquire_call_kind with qualified names + # ============================================================================== + + @testset "qualified acquire call detection" begin + # Module.acquire!(pool, ...) — qualified name (lines 1703-1705) + @test _is_acquire_call( + :(AdaptiveArrayPools.acquire!(pool, Float64, 10)), :pool + ) + @test _is_acquire_call( + :(SomeModule.zeros!(pool, 10)), :pool + ) + # Non-acquire qualified call + @test !_is_acquire_call( + :(Base.sum(pool, data)), :pool + ) + + # _acquire_call_kind with qualified names (lines 1728-1731) + @test _acquire_call_kind( + :(M.acquire!(pool, Float64, 10)), :pool + ) === :pool_view + @test _acquire_call_kind( + :(M.unsafe_acquire!(pool, Float64, 10)), :pool + ) === :pool_array + @test _acquire_call_kind( + :(M.trues!(pool, 10)), :pool + ) === :pool_bitarray + + # Qualified non-acquire → nothing (line 1739) + @test _acquire_call_kind( + :(M.sum(pool, data)), :pool + ) === nothing + end + + # ============================================================================== + # Coverage: _find_acquire_call_expr (lines 1458-1467) + # ============================================================================== + + @testset "_find_acquire_call_expr" begin + # Direct acquire call + expr = :(acquire!(pool, Float64, 10)) + @test _find_acquire_call_expr(expr, :pool) === expr + + # Nested inside assignment + outer = :(v = acquire!(pool, Float64, 10)) + result = _find_acquire_call_expr(outer, :pool) + @test result !== nothing + @test result.args[1] === :acquire! + + # No acquire call → nothing + @test _find_acquire_call_expr(:(sum(data)), :pool) === nothing + + # Non-Expr → nothing + @test _find_acquire_call_expr(:x, :pool) === nothing + @test _find_acquire_call_expr(42, :pool) === nothing + end + + # ============================================================================== + # Coverage: _literal_contains_acquired — identity / named tuple / kw + # ============================================================================== + + @testset "_literal_contains_acquired edge cases" begin + acquired = Set([:v, :w]) + + # identity(v) → detected (line 1959) + @test _literal_contains_acquired(:(identity(v)), acquired) + + # identity(x) → not detected + @test !_literal_contains_acquired(:(identity(x)), acquired) + + # Named tuple with = syntax: (a=v,) (line 1963-1964) + @test _literal_contains_acquired( + Expr(:tuple, Expr(:(=), :a, :v)), acquired + ) + + # Named tuple with kw syntax (line 1965-1966) + @test _literal_contains_acquired( + Expr(:tuple, Expr(:kw, :a, :v)), acquired + ) + + # Non-acquired kw + @test !_literal_contains_acquired( + Expr(:tuple, Expr(:kw, :a, :x)), acquired + ) + end + + # ============================================================================== + # Coverage: _collect_acquired_in_literal — identity / kw (lines 2031-2037) + # ============================================================================== + + @testset "_collect_acquired_in_literal edge cases" begin + acquired = Set([:v, :w]) + + # identity(v) + found = _collect_acquired_in_literal(:(identity(v)), acquired) + @test :v in found + + # Named tuple (a=v,) + found = _collect_acquired_in_literal( + Expr(:tuple, Expr(:(=), :a, :v)), acquired + ) + @test :v in found + + # kw syntax + found = _collect_acquired_in_literal( + Expr(:tuple, Expr(:kw, :a, :w)), acquired + ) + @test :w in found + + # Non-Expr/non-Symbol → empty + found = _collect_acquired_in_literal(42, acquired) + @test isempty(found) + end + + # ============================================================================== + # Coverage: _find_direct_exposure — identity (line 2012) + # ============================================================================== + + @testset "_find_direct_exposure identity" begin + acquired = Set([:v]) + + # identity(v) → detected + found = _find_direct_exposure(:(identity(v)), acquired) + @test :v in found + + # Base.identity(v) + found = _find_direct_exposure(:(Base.identity(v)), acquired) + @test :v in found + end + + # ============================================================================== + # Coverage: _find_first_lnn_index / _ensure_body_has_toplevel_lnn + # ============================================================================== + + @testset "LNN handling edge cases" begin + # _find_first_lnn_index: :meta then LNN (lines 434-435) + args = Any[Expr(:meta, :inline), LineNumberNode(1, :test)] + @test _find_first_lnn_index(args) == 2 + + # _find_first_lnn_index: non-meta before LNN → nothing (lines 437-440) + args2 = Any[:(x = 1), LineNumberNode(1, :test)] + @test _find_first_lnn_index(args2) === nothing + + # _find_first_lnn_index: empty → nothing + @test _find_first_lnn_index(Any[]) === nothing + + # _ensure_body_has_toplevel_lnn: source=nothing → identity + body = Expr(:block, :(x = 1)) + @test _ensure_body_has_toplevel_lnn(body, nothing) === body + + # source.file=:none → identity (line 458) + @test _ensure_body_has_toplevel_lnn(body, LineNumberNode(1, :none)) === body + + # LNN already points to user file → identity (line 467) + body_with_lnn = Expr(:block, LineNumberNode(5, :myfile), :(x = 1)) + result = _ensure_body_has_toplevel_lnn(body_with_lnn, LineNumberNode(5, :myfile)) + @test result === body_with_lnn + + # LNN points elsewhere → replaced (line 470-472) + body_wrong_lnn = Expr(:block, LineNumberNode(1, :macros), :(x = 1)) + result = _ensure_body_has_toplevel_lnn(body_wrong_lnn, LineNumberNode(10, :user)) + @test result.args[1] isa LineNumberNode + @test result.args[1].file === :user + + # No LNN in block → prepend (lines 476) + body_no_lnn = Expr(:block, :(x = 1)) + result = _ensure_body_has_toplevel_lnn(body_no_lnn, LineNumberNode(3, :src)) + @test result.args[1] isa LineNumberNode + @test result.args[1].file === :src + + # Empty block (lines 477-479) + empty_block = Expr(:block) + result = _ensure_body_has_toplevel_lnn(empty_block, LineNumberNode(1, :f)) + @test length(result.args) == 1 + @test result.args[1] isa LineNumberNode + + # Non-block body (lines 481-482) + scalar_body = :(x + 1) + result = _ensure_body_has_toplevel_lnn(scalar_body, LineNumberNode(1, :f)) + @test result.head === :block + @test result.args[1] isa LineNumberNode + @test result.args[2] == scalar_body + end + +end # Compile-Time Escape Detection diff --git a/test/test_debug.jl b/test/test_debug.jl new file mode 100644 index 00000000..bdc17181 --- /dev/null +++ b/test/test_debug.jl @@ -0,0 +1,825 @@ +import AdaptiveArrayPools: _validate_pool_return, _check_bitchunks_overlap, _eltype_may_contain_arrays, + PoolRuntimeEscapeError, _poison_value, _shorten_location +_test_leak(x) = x # opaque to compile-time escape checker (only identity() is transparent) + +@testset "POOL_DEBUG and Safety Validation" begin + + # ============================================================================== + # POOL_DEBUG flag toggle + # ============================================================================== + + @testset "POOL_DEBUG flag" begin + old_debug = POOL_DEBUG[] + + # Default is false + POOL_DEBUG[] = false + + # When debug is off, no validation happens even if SubArray escapes + result = @with_pool pool begin + v = acquire!(pool, Float64, 10) + _test_leak(v) # opaque to compile-time checker; runtime LV<2 won't catch + end + @test result isa SubArray # No error when debug is off + + POOL_DEBUG[] = old_debug + end + + @testset "POOL_DEBUG with safety violation" begin + old_debug = POOL_DEBUG[] + POOL_DEBUG[] = true + + # Should throw error when returning SubArray with debug on + @test_throws PoolRuntimeEscapeError @with_pool pool begin + v = acquire!(pool, Float64, 10) + _test_leak(v) # opaque to compile-time checker; caught by runtime LV2 + end + + # Safe returns should work fine + result = @with_pool pool begin + v = acquire!(pool, Float64, 10) + v .= 1.0 + sum(v) # Safe: returning scalar + end + @test result == 10.0 + + # Returning a copy is also safe + result = @with_pool pool begin + v = acquire!(pool, Float64, 5) + v .= 2.0 + collect(v) # Safe: returning a copy + end + @test result == [2.0, 2.0, 2.0, 2.0, 2.0] + + POOL_DEBUG[] = old_debug + end + + # ============================================================================== + # _validate_pool_return — direct tests + # ============================================================================== + + @testset "_validate_pool_return" begin + pool = AdaptiveArrayPool() + checkpoint!(pool) + + # Non-SubArray values pass validation + _validate_pool_return(42, pool) + _validate_pool_return([1, 2, 3], pool) + _validate_pool_return("hello", pool) + _validate_pool_return(nothing, pool) + + # SubArray not from pool passes validation + external_vec = [1.0, 2.0, 3.0] + external_view = view(external_vec, 1:2) + _validate_pool_return(external_view, pool) + + # SubArray from pool fails validation (fixed slot) + pool_view = acquire!(pool, Float64, 10) + @test_throws PoolRuntimeEscapeError _validate_pool_return(pool_view, pool) + + rewind!(pool) + + # Test with fallback type (others) + checkpoint!(pool) + pool_view_uint8 = acquire!(pool, UInt8, 10) + @test_throws PoolRuntimeEscapeError _validate_pool_return(pool_view_uint8, pool) + rewind!(pool) + + # DisabledPool always passes + _validate_pool_return(pool_view, DISABLED_CPU) + _validate_pool_return(42, DISABLED_CPU) + end + + @testset "_validate_pool_return with all fixed slots" begin + pool = AdaptiveArrayPool() + checkpoint!(pool) + + # Test each fixed slot type + v_f64 = acquire!(pool, Float64, 5) + v_f32 = acquire!(pool, Float32, 5) + v_i64 = acquire!(pool, Int64, 5) + v_i32 = acquire!(pool, Int32, 5) + v_c64 = acquire!(pool, ComplexF64, 5) + v_c32 = acquire!(pool, ComplexF32, 5) + v_bool = acquire!(pool, Bool, 5) + + @test_throws PoolRuntimeEscapeError _validate_pool_return(v_f64, pool) + @test_throws PoolRuntimeEscapeError _validate_pool_return(v_f32, pool) + @test_throws PoolRuntimeEscapeError _validate_pool_return(v_i64, pool) + @test_throws PoolRuntimeEscapeError _validate_pool_return(v_i32, pool) + @test_throws PoolRuntimeEscapeError _validate_pool_return(v_c64, pool) + @test_throws PoolRuntimeEscapeError _validate_pool_return(v_c32, pool) + @test_throws PoolRuntimeEscapeError _validate_pool_return(v_bool, pool) + + rewind!(pool) + end + + @testset "_validate_pool_return with N-D arrays" begin + pool = AdaptiveArrayPool() + checkpoint!(pool) + + # N-D ReshapedArray from pool should fail validation (pointer overlap check) + mat = acquire!(pool, Float64, 10, 10) + @test mat isa Base.ReshapedArray{Float64, 2} + @test_throws PoolRuntimeEscapeError _validate_pool_return(mat, pool) + + # 3D ReshapedArray should also fail + tensor = acquire!(pool, Float64, 5, 5, 5) + @test tensor isa Base.ReshapedArray{Float64, 3} + @test_throws PoolRuntimeEscapeError _validate_pool_return(tensor, pool) + + rewind!(pool) + end + + @testset "_validate_pool_return with unsafe_acquire!" begin + pool = AdaptiveArrayPool() + checkpoint!(pool) + + # Raw Vector from unsafe_acquire! should fail validation + v = unsafe_acquire!(pool, Float64, 100) + @test v isa Vector{Float64} + @test_throws PoolRuntimeEscapeError _validate_pool_return(v, pool) + + # Raw Matrix from unsafe_acquire! should fail validation + mat = unsafe_acquire!(pool, Float64, 10, 10) + @test mat isa Matrix{Float64} + @test_throws PoolRuntimeEscapeError _validate_pool_return(mat, pool) + + # Raw 3D Array from unsafe_acquire! should fail validation + tensor = unsafe_acquire!(pool, Float64, 5, 5, 5) + @test tensor isa Array{Float64, 3} + @test_throws PoolRuntimeEscapeError _validate_pool_return(tensor, pool) + + rewind!(pool) + end + + @testset "_validate_pool_return with view(unsafe_acquire!)" begin + # Bug fix test: view() wrapped around unsafe_acquire! result + # The parent Vector/Array is created by unsafe_wrap, not the pool's internal vector + # This requires pointer overlap check, not identity check + pool = AdaptiveArrayPool() + checkpoint!(pool) + + # 1D: view(unsafe_acquire!(...), :) should fail validation + v = unsafe_acquire!(pool, Float64, 100) + v_view = view(v, :) + @test v_view isa SubArray + @test parent(v_view) === v # Parent is unsafe_wrap'd Vector, not pool's internal vector + @test_throws PoolRuntimeEscapeError _validate_pool_return(v_view, pool) + + # Partial view should also fail + v_partial = view(v, 1:50) + @test_throws PoolRuntimeEscapeError _validate_pool_return(v_partial, pool) + + # 2D: view(unsafe_acquire!(...), :, :) should fail validation + mat = unsafe_acquire!(pool, Float64, 10, 10) + mat_view = view(mat, :, :) + @test mat_view isa SubArray + @test_throws PoolRuntimeEscapeError _validate_pool_return(mat_view, pool) + + rewind!(pool) + end + + @testset "_validate_pool_return external arrays pass" begin + pool = AdaptiveArrayPool() + checkpoint!(pool) + + # Acquire some memory to populate the pool + _ = acquire!(pool, Float64, 100) + + # External N-D arrays should pass validation + external_mat = zeros(Float64, 10, 10) + external_view = view(external_mat, :, :) + _validate_pool_return(external_view, pool) + _validate_pool_return(external_mat, pool) + + # External 3D array should pass + external_tensor = zeros(Float64, 5, 5, 5) + _validate_pool_return(external_tensor, pool) + + rewind!(pool) + end + + @testset "POOL_DEBUG with N-D arrays" begin + old_debug = POOL_DEBUG[] + POOL_DEBUG[] = true + + # N-D ReshapedArray should throw error when returned + @test_throws PoolRuntimeEscapeError @with_pool pool begin + mat = acquire!(pool, Float64, 10, 10) + _test_leak(mat) # opaque to compile-time checker; caught by runtime LV2 + end + + # Raw Array from unsafe_acquire! should throw error when returned + @test_throws PoolRuntimeEscapeError @with_pool pool begin + mat = unsafe_acquire!(pool, Float64, 10, 10) + _test_leak(mat) # opaque to compile-time checker; caught by runtime LV2 + end + + # Safe returns should work fine + result = @with_pool pool begin + mat = acquire!(pool, Float64, 10, 10) + mat .= 1.0 + sum(mat) # Safe: returning scalar + end + @test result == 100.0 + + # Returning a copy is also safe + result = @with_pool pool begin + mat = acquire!(pool, Float64, 3, 3) + mat .= 2.0 + collect(mat) # Safe: returning a copy + end + @test result == fill(2.0, 3, 3) + + POOL_DEBUG[] = old_debug + end + + # ============================================================================== + # BitArray overlap detection (_check_bitchunks_overlap) + # ============================================================================== + + @testset "_check_bitchunks_overlap - direct BitArray validation" begin + pool = AdaptiveArrayPool() + checkpoint!(pool) + + # 1D BitVector from pool - should detect overlap + bv = acquire!(pool, Bit, 100) + @test bv isa BitVector + @test_throws PoolRuntimeEscapeError _check_bitchunks_overlap(bv, pool) + + # N-D BitArray from pool - should detect overlap (shares chunks with pool) + ba = acquire!(pool, Bit, 10, 10) + @test ba isa BitMatrix + @test_throws PoolRuntimeEscapeError _check_bitchunks_overlap(ba, pool) + + # 3D BitArray from pool + ba3 = acquire!(pool, Bit, 4, 5, 3) + @test ba3 isa BitArray{3} + @test_throws PoolRuntimeEscapeError _check_bitchunks_overlap(ba3, pool) + + rewind!(pool) + end + + @testset "_check_bitchunks_overlap - external BitArray passes" begin + pool = AdaptiveArrayPool() + checkpoint!(pool) + + # Populate pool with some BitVectors + _ = acquire!(pool, Bit, 100) + _ = acquire!(pool, Bit, 200) + + # External BitVector (not from pool) should pass validation + external_bv = BitVector(undef, 50) + _check_bitchunks_overlap(external_bv, pool) # Should not throw + + # External BitMatrix should pass + external_ba = BitArray(undef, 10, 10) + _check_bitchunks_overlap(external_ba, pool) # Should not throw + + # External 3D BitArray should pass + external_ba3 = BitArray(undef, 5, 5, 5) + _check_bitchunks_overlap(external_ba3, pool) # Should not throw + + rewind!(pool) + end + + @testset "_validate_pool_return with BitArray (via _check_bitchunks_overlap)" begin + pool = AdaptiveArrayPool() + checkpoint!(pool) + + # Direct BitVector from pool fails validation + bv = acquire!(pool, Bit, 100) + @test_throws PoolRuntimeEscapeError _validate_pool_return(bv, pool) + + # Direct BitMatrix from pool fails validation + ba = acquire!(pool, Bit, 10, 10) + @test_throws PoolRuntimeEscapeError _validate_pool_return(ba, pool) + + # External BitArray passes validation + external_bv = BitVector(undef, 50) + _validate_pool_return(external_bv, pool) # Should not throw + + rewind!(pool) + end + + @testset "_validate_pool_return with SubArray{BitArray} parent" begin + pool = AdaptiveArrayPool() + checkpoint!(pool) + + # Create a view of a pool BitVector + bv = acquire!(pool, Bit, 100) + bv_view = view(bv, 1:50) + @test bv_view isa SubArray + @test parent(bv_view) isa BitVector + @test_throws PoolRuntimeEscapeError _validate_pool_return(bv_view, pool) + + # View of external BitVector should pass + external_bv = BitVector(undef, 100) + external_view = view(external_bv, 1:50) + _validate_pool_return(external_view, pool) # Should not throw + + rewind!(pool) + end + + @testset "POOL_DEBUG with BitArray" begin + old_debug = POOL_DEBUG[] + POOL_DEBUG[] = true + + # BitVector from pool should throw error when returned with debug on + @test_throws PoolRuntimeEscapeError @with_pool pool begin + bv = acquire!(pool, Bit, 100) + _test_leak(bv) # opaque to compile-time checker; caught by runtime LV2 + end + + # BitMatrix from pool should throw error when returned + @test_throws PoolRuntimeEscapeError @with_pool pool begin + ba = acquire!(pool, Bit, 10, 10) + _test_leak(ba) # opaque to compile-time checker; caught by runtime LV2 + end + + # Safe returns should work fine + result = @with_pool pool begin + bv = acquire!(pool, Bit, 100) + bv .= true + count(bv) # Safe: returning scalar + end + @test result == 100 + + # Returning a copy is also safe + result = @with_pool pool begin + bv = acquire!(pool, Bit, 5) + bv .= true + copy(bv) # Safe: returning a copy + end + @test result == trues(5) + + POOL_DEBUG[] = old_debug + end + + # ============================================================================== + # POOL_DEBUG with function definition forms + # ============================================================================== + + # ============================================================================== + # _validate_pool_return — recursive container inspection (Tuple, NamedTuple, Pair) + # ============================================================================== + + @testset "_validate_pool_return with Tuple" begin + pool = AdaptiveArrayPool() + checkpoint!(pool) + + v = acquire!(pool, Float64, 10) + + # Pool array inside tuple → caught + @test_throws PoolRuntimeEscapeError _validate_pool_return((42, v), pool) + @test_throws PoolRuntimeEscapeError _validate_pool_return((v,), pool) + + # Nested tuple: pool array deep inside → caught + @test_throws PoolRuntimeEscapeError _validate_pool_return((1, (2, v)), pool) + + # Safe tuple (no pool arrays) → passes + _validate_pool_return((1, 2, 3), pool) + _validate_pool_return((1, "hello", [1, 2, 3]), pool) + + rewind!(pool) + end + + @testset "_validate_pool_return with NamedTuple" begin + pool = AdaptiveArrayPool() + checkpoint!(pool) + + v = acquire!(pool, Float64, 10) + + # Pool array inside NamedTuple → caught + @test_throws PoolRuntimeEscapeError _validate_pool_return((data = v, n = 10), pool) + @test_throws PoolRuntimeEscapeError _validate_pool_return((result = 42, buffer = v), pool) + + # Nested: NamedTuple containing tuple with pool array + @test_throws PoolRuntimeEscapeError _validate_pool_return((meta = (v, 1),), pool) + + # Safe NamedTuple → passes + _validate_pool_return((a = 1, b = "hello"), pool) + + rewind!(pool) + end + + @testset "_validate_pool_return with Pair" begin + pool = AdaptiveArrayPool() + checkpoint!(pool) + + v = acquire!(pool, Float64, 10) + + # Pool array as Pair value → caught + @test_throws PoolRuntimeEscapeError _validate_pool_return(:data => v, pool) + + # Pool array as Pair key (unusual but possible) → caught + @test_throws PoolRuntimeEscapeError _validate_pool_return(v => :data, pool) + + # Safe Pair → passes + _validate_pool_return(:a => 42, pool) + + rewind!(pool) + end + + @testset "_validate_pool_return recursive with mixed containers" begin + pool = AdaptiveArrayPool() + checkpoint!(pool) + + v = acquire!(pool, Float64, 10) + bv = acquire!(pool, Bit, 50) + + # Tuple containing NamedTuple with pool array + @test_throws PoolRuntimeEscapeError _validate_pool_return((1, (data = v,)), pool) + + # Pair inside tuple + @test_throws PoolRuntimeEscapeError _validate_pool_return((:key => v, 42), pool) + + # BitVector inside tuple + @test_throws PoolRuntimeEscapeError _validate_pool_return((bv, 1), pool) + + # Multiple pool arrays in different container positions + @test_throws PoolRuntimeEscapeError _validate_pool_return((v, bv), pool) + + # N-D ReshapedArray inside NamedTuple + mat = acquire!(pool, Float64, 5, 5) + @test_throws PoolRuntimeEscapeError _validate_pool_return((matrix = mat, size = (5, 5)), pool) + + rewind!(pool) + end + + # ============================================================================== + # _validate_pool_return — Dict, Set, and Vector-of-arrays container inspection + # ============================================================================== + + @testset "_eltype_may_contain_arrays guard" begin + @test _eltype_may_contain_arrays(Float64) == false + @test _eltype_may_contain_arrays(Int32) == false + @test _eltype_may_contain_arrays(ComplexF64) == false + @test _eltype_may_contain_arrays(String) == false + @test _eltype_may_contain_arrays(Symbol) == false + @test _eltype_may_contain_arrays(Char) == false + @test _eltype_may_contain_arrays(Any) == true + @test _eltype_may_contain_arrays(SubArray) == true + @test _eltype_may_contain_arrays(AbstractArray) == true + @test _eltype_may_contain_arrays(Vector{Float64}) == true + end + + @testset "_validate_pool_return with Dict" begin + pool = AdaptiveArrayPool() + checkpoint!(pool) + + v = acquire!(pool, Float64, 10) + + # Pool array as Dict value → caught + @test_throws PoolRuntimeEscapeError _validate_pool_return(Dict(:data => v), pool) + + # Pool array as Dict key (unusual but possible) → caught + @test_throws PoolRuntimeEscapeError _validate_pool_return(Dict(v => :data), pool) + + # Multiple pool arrays in Dict values + w = acquire!(pool, Int64, 5) + @test_throws PoolRuntimeEscapeError _validate_pool_return(Dict(:a => v, :b => w), pool) + + # Safe Dict → passes + _validate_pool_return(Dict(:a => 1, :b => 2), pool) + _validate_pool_return(Dict{String, Float64}("x" => 1.0), pool) + + rewind!(pool) + end + + @testset "_validate_pool_return with nested Dict" begin + pool = AdaptiveArrayPool() + checkpoint!(pool) + + v = acquire!(pool, Float64, 10) + + # Dict inside Tuple → caught + @test_throws PoolRuntimeEscapeError _validate_pool_return((1, Dict(:data => v)), pool) + + # Dict inside NamedTuple → caught + @test_throws PoolRuntimeEscapeError _validate_pool_return((result = Dict(:buf => v),), pool) + + # Nested Dict (Dict of Dict) → caught + @test_throws PoolRuntimeEscapeError _validate_pool_return(Dict(:outer => Dict(:inner => v)), pool) + + rewind!(pool) + end + + @testset "_validate_pool_return with Set" begin + pool = AdaptiveArrayPool() + checkpoint!(pool) + + v = acquire!(pool, Float64, 10) + + # Pool array inside Set → caught + @test_throws PoolRuntimeEscapeError _validate_pool_return(Set([v]), pool) + + # Safe Set → passes + _validate_pool_return(Set([1, 2, 3]), pool) + + rewind!(pool) + end + + @testset "_validate_pool_return with Vector-of-arrays (element recursion)" begin + pool = AdaptiveArrayPool() + checkpoint!(pool) + + v = acquire!(pool, Float64, 10) + + # Vector{SubArray} — pool array as element → caught + external_container = Any[v] + @test_throws PoolRuntimeEscapeError _validate_pool_return(external_container, pool) + + # Multiple pool arrays in Vector + w = acquire!(pool, Int64, 5) + @test_throws PoolRuntimeEscapeError _validate_pool_return(Any[v, w], pool) + + # Nested: Vector inside Tuple + @test_throws PoolRuntimeEscapeError _validate_pool_return((42, Any[v]), pool) + + # Safe Vector{Float64} — passes (eltype guard skips element iteration) + _validate_pool_return([1.0, 2.0, 3.0], pool) + _validate_pool_return(zeros(1000), pool) # large but still fast (eltype guard) + + # Vector{Any} with safe values — passes + _validate_pool_return(Any[1, "hello", :sym], pool) + + rewind!(pool) + end + + @testset "_validate_pool_return Vector-of-arrays with unsafe_acquire!" begin + pool = AdaptiveArrayPool() + checkpoint!(pool) + + # unsafe_acquire! Array inside Vector → caught + raw = unsafe_acquire!(pool, Float64, 100) + @test_throws PoolRuntimeEscapeError _validate_pool_return(Any[raw], pool) + + # BitVector inside Vector → caught + bv = acquire!(pool, Bit, 50) + @test_throws PoolRuntimeEscapeError _validate_pool_return(Any[bv], pool) + + # ReshapedArray inside Vector → caught + mat = acquire!(pool, Float64, 5, 5) + @test_throws PoolRuntimeEscapeError _validate_pool_return(Any[mat], pool) + + rewind!(pool) + end + + @testset "_validate_pool_return containers via @with_pool macro (LV2)" begin + old_safety = POOL_SAFETY_LV[] + POOL_SAFETY_LV[] = 2 + + # Tuple containing pool array — caught at runtime + @test_throws PoolRuntimeEscapeError @with_pool pool begin + v = acquire!(pool, Float64, 10) + _test_leak((sum(v), v)) # opaque to compile-time checker; runtime LV2 catches v inside tuple + end + + # NamedTuple containing pool array — caught at runtime + @test_throws PoolRuntimeEscapeError @with_pool pool begin + v = acquire!(pool, Float64, 10) + _test_leak((data = v, n = 10)) # opaque to compile-time checker; runtime LV2 catches v inside NamedTuple + end + + # Safe containers pass + result = @with_pool pool begin + v = acquire!(pool, Float64, 10) + v .= 3.0 + (sum(v), length(v)) + end + @test result == (30.0, 10) + + POOL_SAFETY_LV[] = old_safety + end + + # ============================================================================== + # Runtime LV2 escape detection through opaque function calls + # ============================================================================== + + @testset "Runtime LV2 catches escapes through opaque function calls" begin + old_safety = POOL_SAFETY_LV[] + POOL_SAFETY_LV[] = 2 + + # Opaque function call bypasses compile-time PoolEscapeError, + # but runtime _validate_pool_return at LV2 still catches the escape. + @test_throws PoolRuntimeEscapeError @with_pool pool begin + v = acquire!(pool, Float64, 10) + _test_leak(v) # opaque to compile-time checker; runtime LV2 catches + end + + # Multiple vars: opaque call still caught at runtime + @test_throws PoolRuntimeEscapeError @with_pool pool begin + v = acquire!(pool, Float64, 10) + w = acquire!(pool, Float64, 5) + _test_leak(v) + end + + # Safe return works fine + result = @with_pool pool begin + v = acquire!(pool, Float64, 10) + v .= 1.0 + sum(v) # scalar — safe + end + @test result == 10.0 + + POOL_SAFETY_LV[] = old_safety + end + + @testset "LV1 does not perform runtime escape check" begin + old_safety = POOL_SAFETY_LV[] + POOL_SAFETY_LV[] = 1 + + # At LV1, opaque call bypasses compile-time and runtime doesn't check escapes + # (only structural invalidation), so the SubArray escapes silently. + result = @with_pool pool begin + v = acquire!(pool, Float64, 10) + _test_leak(v) + end + @test result isa SubArray # Escapes — no runtime check at LV1 + + POOL_SAFETY_LV[] = old_safety + end + + # ============================================================================== + # POOL_DEBUG with function definition forms + # ============================================================================== + + @testset "POOL_DEBUG with @with_pool function definition" begin + old_debug = POOL_DEBUG[] + POOL_DEBUG[] = true + + # Unsafe: function returns pool-backed SubArray + @with_pool pool function _test_debug_func_unsafe(n) + v = acquire!(pool, Float64, n) + v .= 1.0 + _test_leak(v) # opaque to compile-time checker; caught by runtime LV2 + end + @test_throws PoolRuntimeEscapeError _test_debug_func_unsafe(10) + + # Safe: function returns scalar + @with_pool pool function _test_debug_func_safe(n) + v = acquire!(pool, Float64, n) + v .= 1.0 + sum(v) + end + @test _test_debug_func_safe(10) == 10.0 + + # Safe: function returns a copy + @with_pool pool function _test_debug_func_copy(n) + v = acquire!(pool, Float64, n) + v .= 2.0 + collect(v) + end + @test _test_debug_func_copy(5) == fill(2.0, 5) + + # Unsafe: N-D ReshapedArray from function + @with_pool pool function _test_debug_func_nd(m, n) + mat = acquire!(pool, Float64, m, n) + mat .= 1.0 + _test_leak(mat) # opaque to compile-time checker; caught by runtime LV2 + end + @test_throws PoolRuntimeEscapeError _test_debug_func_nd(3, 4) + + # Unsafe: BitVector from function + @with_pool pool function _test_debug_func_bit(n) + bv = acquire!(pool, Bit, n) + bv .= true + _test_leak(bv) # opaque to compile-time checker; caught by runtime LV2 + end + @test_throws PoolRuntimeEscapeError _test_debug_func_bit(100) + + POOL_DEBUG[] = old_debug + end + + @testset "POOL_DEBUG with @maybe_with_pool function definition" begin + old_debug = POOL_DEBUG[] + old_maybe = MAYBE_POOLING[] + POOL_DEBUG[] = true + MAYBE_POOLING[] = true + + # Unsafe: function returns pool-backed array + @maybe_with_pool pool function _test_maybe_debug_unsafe(n) + v = acquire!(pool, Float64, n) + v .= 1.0 + _test_leak(v) # opaque to compile-time checker; caught by runtime LV2 + end + @test_throws PoolRuntimeEscapeError _test_maybe_debug_unsafe(10) + + # Safe: function returns scalar + @maybe_with_pool pool function _test_maybe_debug_safe(n) + v = acquire!(pool, Float64, n) + v .= 1.0 + sum(v) + end + @test _test_maybe_debug_safe(10) == 10.0 + + # When pooling disabled, no validation needed (DisabledPool returns fresh arrays) + MAYBE_POOLING[] = false + @maybe_with_pool pool function _test_maybe_debug_disabled(n) + v = zeros!(pool, n) + _test_leak(v) # opaque to compile-time checker; disabled pool returns fresh arrays + end + result = _test_maybe_debug_disabled(5) + @test result == zeros(5) + + POOL_DEBUG[] = old_debug + MAYBE_POOLING[] = old_maybe + end + + @testset "POOL_DEBUG with @with_pool :cpu function definition" begin + old_debug = POOL_DEBUG[] + POOL_DEBUG[] = true + + # Unsafe: backend function returns pool-backed array + @with_pool :cpu pool function _test_backend_debug_unsafe(n) + v = acquire!(pool, Float64, n) + v .= 1.0 + _test_leak(v) # opaque to compile-time checker; caught by runtime LV2 + end + @test_throws PoolRuntimeEscapeError _test_backend_debug_unsafe(10) + + # Safe: returns scalar + @with_pool :cpu pool function _test_backend_debug_safe(n) + v = acquire!(pool, Float64, n) + v .= 1.0 + sum(v) + end + @test _test_backend_debug_safe(10) == 10.0 + + POOL_DEBUG[] = old_debug + end + + # ============================================================================== + # Coverage: PoolRuntimeEscapeError showerror with return_site (LV3) + # ============================================================================== + + @testset "PoolRuntimeEscapeError showerror with return_site" begin + # Construct error with both callsite and return_site to cover lines 169-180 + err = PoolRuntimeEscapeError( + "SubArray{Float64,1}", + "Float64", + "test.jl:10\nacquire!(pool, Float64, 10)", + "test.jl:15\nreturn v" + ) + msg = sprint(showerror, err) + @test occursin("escapes at", msg) + @test occursin("return v", msg) + @test occursin("acquired at", msg) + + # Return site without expression text (no \n) + err2 = PoolRuntimeEscapeError( + "SubArray{Float64,1}", + "Float64", + "test.jl:10", + "test.jl:15" + ) + msg2 = sprint(showerror, err2) + @test occursin("escapes at", msg2) + end + + # ============================================================================== + # Coverage: PoolRuntimeEscapeError 3-arg showerror (backtrace suppression) + # ============================================================================== + + @testset "PoolRuntimeEscapeError 3-arg showerror" begin + err = PoolRuntimeEscapeError("Vector{Float64}", "Float64", nothing, nothing) + msg2 = sprint(showerror, err) + msg3 = sprint() do io + showerror(io, err, nothing) + end + @test msg2 == msg3 + end + + # ============================================================================== + # Coverage: _poison_value generic fallback (line 258) + # ============================================================================== + + @testset "_poison_value generic fallback" begin + # Rational is not AbstractFloat, Integer, or Complex → hits generic fallback + @test _poison_value(Rational{Int}) == zero(Rational{Int}) + + # Exercise through actual pool rewind at LV≥2 with a non-fixed-slot type + old_lv = POOL_SAFETY_LV[] + POOL_SAFETY_LV[] = 2 + pool = AdaptiveArrayPool() + checkpoint!(pool) + v = acquire!(pool, Rational{Int}, 5) + v .= 1 // 3 + rewind!(pool) # triggers _poison_fill! → _poison_value(Rational{Int}) → zero(Rational) + POOL_SAFETY_LV[] = old_lv + end + + # ============================================================================== + # Coverage: _shorten_location no-colon fallback (line 304) + # ============================================================================== + + @testset "_shorten_location edge cases" begin + # Location without colon → returned as-is + @test _shorten_location("nocolon") == "nocolon" + # Location with colon → shortened + loc = _shorten_location("somefile.jl:42") + @test occursin("42", loc) + end + +end # POOL_DEBUG and Safety Validation diff --git a/test/test_safety.jl b/test/test_safety.jl new file mode 100644 index 00000000..8fe6354e --- /dev/null +++ b/test/test_safety.jl @@ -0,0 +1,471 @@ +import AdaptiveArrayPools: _invalidate_released_slots!, PoolRuntimeEscapeError + +# Opaque identity — defeats compile-time escape analysis without @skip_check_vars +_test_leak(x) = x + +@testset "POOL_SAFETY_LV Guard-Level Invalidation" begin + + # ============================================================================== + # Default values + # ============================================================================== + + @testset "Default configuration" begin + @test STATIC_POOL_CHECKS == true + @test POOL_SAFETY_LV[] == 1 + end + + # ============================================================================== + # Level 1: acquire! SubArray invalidation + # ============================================================================== + + @testset "acquire! SubArray invalidated on rewind" begin + old_safety = POOL_SAFETY_LV[] + POOL_SAFETY_LV[] = 1 + + pool = AdaptiveArrayPool() + checkpoint!(pool) + v = acquire!(pool, Float64, 10) + v .= 42.0 # write to confirm it's valid before rewind + rewind!(pool) + + # Backing vector resized to 0 -> SubArray parent is length 0 + @test length(parent(v)) == 0 + + # Accessing stale SubArray should throw BoundsError + @test_throws BoundsError v[1] + + POOL_SAFETY_LV[] = old_safety + end + + @testset "acquire! N-D ReshapedArray invalidated on rewind" begin + old_safety = POOL_SAFETY_LV[] + POOL_SAFETY_LV[] = 1 + + pool = AdaptiveArrayPool() + checkpoint!(pool) + mat = acquire!(pool, Float64, 5, 5) + mat .= 1.0 + rewind!(pool) + + # Parent chain: ReshapedArray -> SubArray -> Vector (now length 0) + @test length(parent(parent(mat))) == 0 + @test_throws BoundsError mat[1, 1] + + POOL_SAFETY_LV[] = old_safety + end + + # ============================================================================== + # Level 1: unsafe_acquire! Array wrapper invalidation (Julia 1.11+ only) + # On Julia 1.10, Array is a C struct — setfield!(:size) is not available. + # ============================================================================== + + @static if VERSION >= v"1.11-" + @testset "unsafe_acquire! Array wrapper invalidated on rewind" begin + old_safety = POOL_SAFETY_LV[] + POOL_SAFETY_LV[] = 1 + + pool = AdaptiveArrayPool() + checkpoint!(pool) + arr = unsafe_acquire!(pool, Float64, 10) + arr .= 99.0 + @test size(arr) == (10,) + rewind!(pool) + + # Wrapper size set to (0,) via setfield! + @test size(arr) == (0,) + @test_throws BoundsError arr[1] + + POOL_SAFETY_LV[] = old_safety + end + + @testset "unsafe_acquire! N-D Array wrapper invalidated on rewind" begin + old_safety = POOL_SAFETY_LV[] + POOL_SAFETY_LV[] = 1 + + pool = AdaptiveArrayPool() + checkpoint!(pool) + mat = unsafe_acquire!(pool, Float64, 4, 3) + mat .= 1.0 + @test size(mat) == (4, 3) + rewind!(pool) + + @test size(mat) == (0, 0) + @test_throws BoundsError mat[1, 1] + + POOL_SAFETY_LV[] = old_safety + end + end + + # ============================================================================== + # Level 1: BitArray invalidation + # ============================================================================== + + @testset "acquire! BitVector invalidated on rewind" begin + old_safety = POOL_SAFETY_LV[] + POOL_SAFETY_LV[] = 1 + + pool = AdaptiveArrayPool() + checkpoint!(pool) + bv = acquire!(pool, Bit, 100) + bv .= true + rewind!(pool) + + # BitVector backing resized to 0 + @test length(pool.bits.vectors[1]) == 0 + # Accessing stale BitVector - len was set to 0 via setfield! + @test length(bv) == 0 + + POOL_SAFETY_LV[] = old_safety + end + + @testset "acquire! BitMatrix invalidated on rewind" begin + old_safety = POOL_SAFETY_LV[] + POOL_SAFETY_LV[] = 1 + + pool = AdaptiveArrayPool() + checkpoint!(pool) + ba = acquire!(pool, Bit, 8, 8) + ba .= true + @test size(ba) == (8, 8) + rewind!(pool) + + # BitArray dims set to (0, 0), len set to 0 + @test size(ba) == (0, 0) + @test length(ba) == 0 + + POOL_SAFETY_LV[] = old_safety + end + + # ============================================================================== + # Level 0: No invalidation + # ============================================================================== + + @testset "POOL_SAFETY_LV=0 bypasses invalidation" begin + old_safety = POOL_SAFETY_LV[] + POOL_SAFETY_LV[] = 0 + + pool = AdaptiveArrayPool() + checkpoint!(pool) + v = acquire!(pool, Float64, 10) + v .= 7.0 + rewind!(pool) + + # With safety off, backing vector still has length >= 10 + @test length(parent(v)) >= 10 + # Stale access works (this is the unsafe behavior we're protecting against) + @test v[1] == 7.0 + + POOL_SAFETY_LV[] = old_safety + end + + # ============================================================================== + # Re-acquire after invalidation (zero-alloc round-trip) + # ============================================================================== + + @testset "Re-acquire after invalidation restores vectors" begin + old_safety = POOL_SAFETY_LV[] + POOL_SAFETY_LV[] = 1 + + pool = AdaptiveArrayPool() + + # First cycle: populate pool + checkpoint!(pool) + v1 = acquire!(pool, Float64, 100) + v1 .= 1.0 + rewind!(pool) + + # Vectors invalidated + @test length(pool.float64.vectors[1]) == 0 + + # Second cycle: re-acquire uses same slot, restores capacity + checkpoint!(pool) + v2 = acquire!(pool, Float64, 50) + v2 .= 2.0 + @test length(parent(v2)) >= 50 + @test v2[1] == 2.0 + # Same backing vector object (capacity preserved through resize round-trip) + @test parent(v2) === pool.float64.vectors[1] + rewind!(pool) + + POOL_SAFETY_LV[] = old_safety + end + + @static if VERSION >= v"1.11-" + @testset "Re-acquire unsafe_acquire! after invalidation" begin + old_safety = POOL_SAFETY_LV[] + POOL_SAFETY_LV[] = 1 + + pool = AdaptiveArrayPool() + + # First cycle + checkpoint!(pool) + arr = unsafe_acquire!(pool, Float64, 20) + arr .= 3.0 + rewind!(pool) + @test size(arr) == (0,) + + # Second cycle: wrapper reused, size restored + checkpoint!(pool) + arr2 = unsafe_acquire!(pool, Float64, 15) + @test size(arr2) == (15,) + arr2 .= 4.0 + @test arr2[1] == 4.0 + rewind!(pool) + + POOL_SAFETY_LV[] = old_safety + end + end + + # ============================================================================== + # Nested scopes: inner invalidation doesn't affect outer + # ============================================================================== + + @testset "Nested checkpoint/rewind: inner invalidated, outer valid" begin + old_safety = POOL_SAFETY_LV[] + POOL_SAFETY_LV[] = 1 + + pool = AdaptiveArrayPool() + checkpoint!(pool) + v_outer = acquire!(pool, Float64, 10) + v_outer .= 1.0 + + # Inner scope + checkpoint!(pool) + v_inner = acquire!(pool, Float64, 20) + v_inner .= 2.0 + rewind!(pool) + + # Inner is invalidated (slot 2 released) + @test length(parent(v_inner)) == 0 + + # Outer is still valid (slot 1 not released) + @test length(parent(v_outer)) >= 10 + @test v_outer[1] == 1.0 + + rewind!(pool) + + # Now outer is also invalidated + @test length(parent(v_outer)) == 0 + + POOL_SAFETY_LV[] = old_safety + end + + # ============================================================================== + # reset! invalidation + # ============================================================================== + + @testset "reset! invalidates all active slots" begin + old_safety = POOL_SAFETY_LV[] + POOL_SAFETY_LV[] = 1 + + pool = AdaptiveArrayPool() + checkpoint!(pool) + v1 = acquire!(pool, Float64, 10) + v2 = acquire!(pool, Float64, 20) + v1 .= 1.0 + v2 .= 2.0 + + reset!(pool.float64) + + @test pool.float64.n_active == 0 + @test length(pool.float64.vectors[1]) == 0 + @test length(pool.float64.vectors[2]) == 0 + + POOL_SAFETY_LV[] = old_safety + end + + # ============================================================================== + # Fallback types (pool.others) + # ============================================================================== + + @testset "Fallback type invalidation" begin + old_safety = POOL_SAFETY_LV[] + POOL_SAFETY_LV[] = 1 + + pool = AdaptiveArrayPool() + checkpoint!(pool) + v = acquire!(pool, UInt8, 50) + v .= 0xff + @test length(parent(v)) >= 50 + rewind!(pool) + + # Fallback type also invalidated + tp = pool.others[UInt8] + @test length(tp.vectors[1]) == 0 + @test length(parent(v)) == 0 + + POOL_SAFETY_LV[] = old_safety + end + + # ============================================================================== + # POOL_DEBUG backward compatibility + # ============================================================================== + + @testset "POOL_DEBUG backward compat with POOL_SAFETY_LV" begin + old_debug = POOL_DEBUG[] + old_safety = POOL_SAFETY_LV[] + + # POOL_DEBUG=true still triggers escape detection (regardless of POOL_SAFETY_LV) + POOL_DEBUG[] = true + POOL_SAFETY_LV[] = 0 + @test_throws PoolRuntimeEscapeError @with_pool pool begin + v = acquire!(pool, Float64, 10) + _test_leak(v) # bypasses compile-time check; caught by runtime LV2 + end + + # POOL_SAFETY_LV=2 also triggers escape detection (without POOL_DEBUG) + POOL_DEBUG[] = false + POOL_SAFETY_LV[] = 2 + @test_throws PoolRuntimeEscapeError @with_pool pool begin + v = acquire!(pool, Float64, 10) + _test_leak(v) # bypasses compile-time check; caught by runtime LV2 + end + + # Neither flag -> no escape detection + POOL_DEBUG[] = false + POOL_SAFETY_LV[] = 1 + result = @with_pool pool begin + v = acquire!(pool, Float64, 10) + _test_leak(v) # bypasses compile-time check; runtime LV<2 won't catch + end + @test result isa SubArray + + POOL_DEBUG[] = old_debug + POOL_SAFETY_LV[] = old_safety + end + + # ============================================================================== + # Multiple types in same scope + # ============================================================================== + + @testset "Multiple types invalidated together" begin + old_safety = POOL_SAFETY_LV[] + POOL_SAFETY_LV[] = 1 + + pool = AdaptiveArrayPool() + checkpoint!(pool) + vf = acquire!(pool, Float64, 10) + vi = acquire!(pool, Int64, 20) + vb = acquire!(pool, Bit, 30) + vf .= 1.0 + vi .= 2 + vb .= true + rewind!(pool) + + @test length(parent(vf)) == 0 + @test length(parent(vi)) == 0 + @test length(vb) == 0 + + POOL_SAFETY_LV[] = old_safety + end + + # ============================================================================== + # @with_pool macro integration + # ============================================================================== + + @testset "@with_pool invalidates on scope exit" begin + old_safety = POOL_SAFETY_LV[] + POOL_SAFETY_LV[] = 1 + + pool_ref = Ref{AdaptiveArrayPool}() + stale_ref = Ref{Any}() + + result = @with_pool pool begin + pool_ref[] = pool + v = acquire!(pool, Float64, 10) + v .= 5.0 + stale_ref[] = v + sum(v) # Safe scalar return + end + + @test result == 50.0 + # After @with_pool exits, the pool's vectors should be invalidated + v = stale_ref[] + @test length(parent(v)) == 0 + @test_throws BoundsError v[1] + + POOL_SAFETY_LV[] = old_safety + end + + # ============================================================================== + # Level 2: Poisoning (NaN/sentinel fill before structural invalidation) + # ============================================================================== + + @testset "Level 2: Float64 poisoned with NaN on rewind" begin + old_safety = POOL_SAFETY_LV[] + POOL_SAFETY_LV[] = 2 + + pool = AdaptiveArrayPool() + checkpoint!(pool) + v = acquire!(pool, Float64, 10) + v .= 42.0 + rewind!(pool) + + # Re-acquire: backing vector was poisoned with NaN before resize!(v,0). + # resize! round-trip (0→10) preserves capacity, NaN data survives. + checkpoint!(pool) + v2 = acquire!(pool, Float64, 10) + @test all(isnan, v2) + rewind!(pool) + + POOL_SAFETY_LV[] = old_safety + end + + @testset "Level 2: Int64 poisoned with typemax on rewind" begin + old_safety = POOL_SAFETY_LV[] + POOL_SAFETY_LV[] = 2 + + pool = AdaptiveArrayPool() + checkpoint!(pool) + v = acquire!(pool, Int64, 10) + v .= 42 + rewind!(pool) + + checkpoint!(pool) + v2 = acquire!(pool, Int64, 10) + @test all(==(typemax(Int64)), v2) + rewind!(pool) + + POOL_SAFETY_LV[] = old_safety + end + + @testset "Level 2: ComplexF64 poisoned with NaN+NaN*im on rewind" begin + old_safety = POOL_SAFETY_LV[] + POOL_SAFETY_LV[] = 2 + + pool = AdaptiveArrayPool() + checkpoint!(pool) + v = acquire!(pool, ComplexF64, 8) + v .= 1.0 + 2.0im + rewind!(pool) + + checkpoint!(pool) + v2 = acquire!(pool, ComplexF64, 8) + @test all(z -> isnan(real(z)) && isnan(imag(z)), v2) + rewind!(pool) + + POOL_SAFETY_LV[] = old_safety + end + + @testset "Level 1 does NOT poison" begin + old_safety = POOL_SAFETY_LV[] + POOL_SAFETY_LV[] = 1 + + pool = AdaptiveArrayPool() + checkpoint!(pool) + v = acquire!(pool, Float64, 10) + v .= 42.0 + rewind!(pool) + + # At level 1, only resize (no poison). Re-acquire restores length, + # data is whatever was in memory — should still be 42.0 (not NaN). + checkpoint!(pool) + v2 = acquire!(pool, Float64, 10) + @test !any(isnan, v2) + @test v2[1] == 42.0 + rewind!(pool) + + POOL_SAFETY_LV[] = old_safety + end + +end # POOL_SAFETY_LV Guard-Level Invalidation diff --git a/test/test_state.jl b/test/test_state.jl index 9f528b9f..2393a1ca 100644 --- a/test/test_state.jl +++ b/test/test_state.jl @@ -38,12 +38,16 @@ import AdaptiveArrayPools: _typed_lazy_checkpoint!, _typed_lazy_rewind!, _tracke acquire!(pool, Float64, 30) acquire!(pool, Float64, 7) end + # Disable safety invalidation to check capacity preservation + old_safety = POOL_SAFETY_LV[] + POOL_SAFETY_LV[] = 0 rewind!(pool) # After warm-up, vectors should be properly sized @test length(pool.float64.vectors[1]) >= 101 @test length(pool.float64.vectors[2]) >= 30 @test length(pool.float64.vectors[3]) >= 7 + POOL_SAFETY_LV[] = old_safety end @testset "checkpoint and rewind API" begin @@ -131,7 +135,10 @@ import AdaptiveArrayPools: _typed_lazy_checkpoint!, _typed_lazy_rewind!, _tracke v3 = acquire!(pool, Float64, 200) @test length(v3) == 200 @test length(parent(v3)) >= 200 # Backing vector was resized + old_safety = POOL_SAFETY_LV[] + POOL_SAFETY_LV[] = 0 # disable invalidation so backing vector length is preserved rewind!(pool) + POOL_SAFETY_LV[] = old_safety # Smaller size - cache miss, but no resize needed checkpoint!(pool) diff --git a/test/test_utils.jl b/test/test_utils.jl index 163adf86..b5e1d364 100644 --- a/test/test_utils.jl +++ b/test/test_utils.jl @@ -1,5 +1,3 @@ -import AdaptiveArrayPools: _validate_pool_return - # Helper macro to capture stdout (must be defined before use) macro capture_out(expr) return quote @@ -18,13 +16,12 @@ macro capture_out(expr) end end -@testset "Utilities and Debugging" begin +@testset "Statistics and Display" begin # ============================================================================== - # Tests for utils.jl: POOL_DEBUG, _validate_pool_return, pool_stats + # Tests for utils.jl: pool_stats, Base.show # ============================================================================== - @testset "pool_stats" begin pool = AdaptiveArrayPool() @@ -95,107 +92,6 @@ end rewind!(pool) end - @testset "POOL_DEBUG flag" begin - old_debug = POOL_DEBUG[] - - # Default is false - POOL_DEBUG[] = false - - # When debug is off, no validation happens even if SubArray is returned - result = @with_pool pool begin - v = acquire!(pool, Float64, 10) - v # Returning SubArray - would be unsafe in real code - end - @test result isa SubArray # No error when debug is off - - POOL_DEBUG[] = old_debug - end - - @testset "POOL_DEBUG with safety violation" begin - old_debug = POOL_DEBUG[] - POOL_DEBUG[] = true - - # Should throw error when returning SubArray with debug on - @test_throws ErrorException @with_pool pool begin - v = acquire!(pool, Float64, 10) - v # Unsafe: returning pool-backed SubArray - end - - # Safe returns should work fine - result = @with_pool pool begin - v = acquire!(pool, Float64, 10) - v .= 1.0 - sum(v) # Safe: returning scalar - end - @test result == 10.0 - - # Returning a copy is also safe - result = @with_pool pool begin - v = acquire!(pool, Float64, 5) - v .= 2.0 - collect(v) # Safe: returning a copy - end - @test result == [2.0, 2.0, 2.0, 2.0, 2.0] - - POOL_DEBUG[] = old_debug - end - - @testset "_validate_pool_return" begin - pool = AdaptiveArrayPool() - checkpoint!(pool) - - # Non-SubArray values pass validation - _validate_pool_return(42, pool) - _validate_pool_return([1, 2, 3], pool) - _validate_pool_return("hello", pool) - _validate_pool_return(nothing, pool) - - # SubArray not from pool passes validation - external_vec = [1.0, 2.0, 3.0] - external_view = view(external_vec, 1:2) - _validate_pool_return(external_view, pool) - - # SubArray from pool fails validation (fixed slot) - pool_view = acquire!(pool, Float64, 10) - @test_throws ErrorException _validate_pool_return(pool_view, pool) - - rewind!(pool) - - # Test with fallback type (others) - checkpoint!(pool) - pool_view_uint8 = acquire!(pool, UInt8, 10) - @test_throws ErrorException _validate_pool_return(pool_view_uint8, pool) - rewind!(pool) - - # DisabledPool always passes - _validate_pool_return(pool_view, DISABLED_CPU) - _validate_pool_return(42, DISABLED_CPU) - end - - @testset "_validate_pool_return with all fixed slots" begin - pool = AdaptiveArrayPool() - checkpoint!(pool) - - # Test each fixed slot type - v_f64 = acquire!(pool, Float64, 5) - v_f32 = acquire!(pool, Float32, 5) - v_i64 = acquire!(pool, Int64, 5) - v_i32 = acquire!(pool, Int32, 5) - v_c64 = acquire!(pool, ComplexF64, 5) - v_c32 = acquire!(pool, ComplexF32, 5) - v_bool = acquire!(pool, Bool, 5) - - @test_throws ErrorException _validate_pool_return(v_f64, pool) - @test_throws ErrorException _validate_pool_return(v_f32, pool) - @test_throws ErrorException _validate_pool_return(v_i64, pool) - @test_throws ErrorException _validate_pool_return(v_i32, pool) - @test_throws ErrorException _validate_pool_return(v_c64, pool) - @test_throws ErrorException _validate_pool_return(v_c32, pool) - @test_throws ErrorException _validate_pool_return(v_bool, pool) - - rewind!(pool) - end - @testset "Base.show for TypedPool & BitTypedPool" begin import AdaptiveArrayPools: TypedPool, BitTypedPool @@ -326,251 +222,4 @@ end @test _count_label(BitTypedPool()) == "bits" end - @testset "_validate_pool_return with N-D arrays" begin - pool = AdaptiveArrayPool() - checkpoint!(pool) - - # N-D ReshapedArray from pool should fail validation (pointer overlap check) - mat = acquire!(pool, Float64, 10, 10) - @test mat isa Base.ReshapedArray{Float64, 2} - @test_throws ErrorException _validate_pool_return(mat, pool) - - # 3D ReshapedArray should also fail - tensor = acquire!(pool, Float64, 5, 5, 5) - @test tensor isa Base.ReshapedArray{Float64, 3} - @test_throws ErrorException _validate_pool_return(tensor, pool) - - rewind!(pool) - end - - @testset "_validate_pool_return with unsafe_acquire!" begin - pool = AdaptiveArrayPool() - checkpoint!(pool) - - # Raw Vector from unsafe_acquire! should fail validation - v = unsafe_acquire!(pool, Float64, 100) - @test v isa Vector{Float64} - @test_throws ErrorException _validate_pool_return(v, pool) - - # Raw Matrix from unsafe_acquire! should fail validation - mat = unsafe_acquire!(pool, Float64, 10, 10) - @test mat isa Matrix{Float64} - @test_throws ErrorException _validate_pool_return(mat, pool) - - # Raw 3D Array from unsafe_acquire! should fail validation - tensor = unsafe_acquire!(pool, Float64, 5, 5, 5) - @test tensor isa Array{Float64, 3} - @test_throws ErrorException _validate_pool_return(tensor, pool) - - rewind!(pool) - end - - @testset "_validate_pool_return with view(unsafe_acquire!)" begin - # Bug fix test: view() wrapped around unsafe_acquire! result - # The parent Vector/Array is created by unsafe_wrap, not the pool's internal vector - # This requires pointer overlap check, not identity check - pool = AdaptiveArrayPool() - checkpoint!(pool) - - # 1D: view(unsafe_acquire!(...), :) should fail validation - v = unsafe_acquire!(pool, Float64, 100) - v_view = view(v, :) - @test v_view isa SubArray - @test parent(v_view) === v # Parent is unsafe_wrap'd Vector, not pool's internal vector - @test_throws ErrorException _validate_pool_return(v_view, pool) - - # Partial view should also fail - v_partial = view(v, 1:50) - @test_throws ErrorException _validate_pool_return(v_partial, pool) - - # 2D: view(unsafe_acquire!(...), :, :) should fail validation - mat = unsafe_acquire!(pool, Float64, 10, 10) - mat_view = view(mat, :, :) - @test mat_view isa SubArray - @test_throws ErrorException _validate_pool_return(mat_view, pool) - - rewind!(pool) - end - - @testset "_validate_pool_return external arrays pass" begin - pool = AdaptiveArrayPool() - checkpoint!(pool) - - # Acquire some memory to populate the pool - _ = acquire!(pool, Float64, 100) - - # External N-D arrays should pass validation - external_mat = zeros(Float64, 10, 10) - external_view = view(external_mat, :, :) - _validate_pool_return(external_view, pool) - _validate_pool_return(external_mat, pool) - - # External 3D array should pass - external_tensor = zeros(Float64, 5, 5, 5) - _validate_pool_return(external_tensor, pool) - - rewind!(pool) - end - - @testset "POOL_DEBUG with N-D arrays" begin - old_debug = POOL_DEBUG[] - POOL_DEBUG[] = true - - # N-D ReshapedArray should throw error when returned - @test_throws ErrorException @with_pool pool begin - mat = acquire!(pool, Float64, 10, 10) - mat # Unsafe: returning pool-backed N-D ReshapedArray - end - - # Raw Array from unsafe_acquire! should throw error when returned - @test_throws ErrorException @with_pool pool begin - mat = unsafe_acquire!(pool, Float64, 10, 10) - mat # Unsafe: returning raw Array backed by pool - end - - # Safe returns should work fine - result = @with_pool pool begin - mat = acquire!(pool, Float64, 10, 10) - mat .= 1.0 - sum(mat) # Safe: returning scalar - end - @test result == 100.0 - - # Returning a copy is also safe - result = @with_pool pool begin - mat = acquire!(pool, Float64, 3, 3) - mat .= 2.0 - collect(mat) # Safe: returning a copy - end - @test result == fill(2.0, 3, 3) - - POOL_DEBUG[] = old_debug - end - - # ============================================================================== - # Tests for _check_bitchunks_overlap (BitArray safety validation) - # ============================================================================== - - @testset "_check_bitchunks_overlap - direct BitArray validation" begin - import AdaptiveArrayPools: _check_bitchunks_overlap - - pool = AdaptiveArrayPool() - checkpoint!(pool) - - # 1D BitVector from pool - should detect overlap - bv = acquire!(pool, Bit, 100) - @test bv isa BitVector - @test_throws ErrorException _check_bitchunks_overlap(bv, pool) - - # N-D BitArray from pool - should detect overlap (shares chunks with pool) - ba = acquire!(pool, Bit, 10, 10) - @test ba isa BitMatrix - @test_throws ErrorException _check_bitchunks_overlap(ba, pool) - - # 3D BitArray from pool - ba3 = acquire!(pool, Bit, 4, 5, 3) - @test ba3 isa BitArray{3} - @test_throws ErrorException _check_bitchunks_overlap(ba3, pool) - - rewind!(pool) - end - - @testset "_check_bitchunks_overlap - external BitArray passes" begin - import AdaptiveArrayPools: _check_bitchunks_overlap - - pool = AdaptiveArrayPool() - checkpoint!(pool) - - # Populate pool with some BitVectors - _ = acquire!(pool, Bit, 100) - _ = acquire!(pool, Bit, 200) - - # External BitVector (not from pool) should pass validation - external_bv = BitVector(undef, 50) - _check_bitchunks_overlap(external_bv, pool) # Should not throw - - # External BitMatrix should pass - external_ba = BitArray(undef, 10, 10) - _check_bitchunks_overlap(external_ba, pool) # Should not throw - - # External 3D BitArray should pass - external_ba3 = BitArray(undef, 5, 5, 5) - _check_bitchunks_overlap(external_ba3, pool) # Should not throw - - rewind!(pool) - end - - @testset "_validate_pool_return with BitArray (via _check_bitchunks_overlap)" begin - pool = AdaptiveArrayPool() - checkpoint!(pool) - - # Direct BitVector from pool fails validation - bv = acquire!(pool, Bit, 100) - @test_throws ErrorException _validate_pool_return(bv, pool) - - # Direct BitMatrix from pool fails validation - ba = acquire!(pool, Bit, 10, 10) - @test_throws ErrorException _validate_pool_return(ba, pool) - - # External BitArray passes validation - external_bv = BitVector(undef, 50) - _validate_pool_return(external_bv, pool) # Should not throw - - rewind!(pool) - end - - @testset "_validate_pool_return with SubArray{BitArray} parent" begin - pool = AdaptiveArrayPool() - checkpoint!(pool) - - # Create a view of a pool BitVector - bv = acquire!(pool, Bit, 100) - bv_view = view(bv, 1:50) - @test bv_view isa SubArray - @test parent(bv_view) isa BitVector - @test_throws ErrorException _validate_pool_return(bv_view, pool) - - # View of external BitVector should pass - external_bv = BitVector(undef, 100) - external_view = view(external_bv, 1:50) - _validate_pool_return(external_view, pool) # Should not throw - - rewind!(pool) - end - - @testset "POOL_DEBUG with BitArray" begin - old_debug = POOL_DEBUG[] - POOL_DEBUG[] = true - - # BitVector from pool should throw error when returned with debug on - @test_throws ErrorException @with_pool pool begin - bv = acquire!(pool, Bit, 100) - bv # Unsafe: returning pool-backed BitVector - end - - # BitMatrix from pool should throw error when returned - @test_throws ErrorException @with_pool pool begin - ba = acquire!(pool, Bit, 10, 10) - ba # Unsafe: returning pool-backed BitMatrix - end - - # Safe returns should work fine - result = @with_pool pool begin - bv = acquire!(pool, Bit, 100) - bv .= true - count(bv) # Safe: returning scalar - end - @test result == 100 - - # Returning a copy is also safe - result = @with_pool pool begin - bv = acquire!(pool, Bit, 5) - bv .= true - copy(bv) # Safe: returning a copy - end - @test result == trues(5) - - POOL_DEBUG[] = old_debug - end - -end # Utilities and Debugging +end # Statistics and Display