diff --git a/docs/src/architecture/macro-internals.md b/docs/src/architecture/macro-internals.md index d9832374..ebdf4e2b 100644 --- a/docs/src/architecture/macro-internals.md +++ b/docs/src/architecture/macro-internals.md @@ -6,10 +6,23 @@ This page explains the internal mechanics of the `@with_pool` macro for advanced The `@with_pool` macro provides automatic lifecycle management with three key optimizations: -1. **Try-Finally Safety** — Guarantees cleanup even on exceptions +1. **Direct Rewind (no `try-finally`)** — Enables compiler inlining for ~35-73% less overhead 2. **Typed Checkpoint/Rewind** — Only saves/restores used types (~77% faster) 3. **Untracked Acquire Detection** — Safely handles `acquire!` calls outside macro visibility +## Why No `try-finally`? + +Julia's compiler cannot inline functions containing `try-finally`. For `@inline @with_pool` +functions called in hot loops, this means every call pays an exception handler frame cost +(~20-40ns on modern hardware, worse on Julia 1.10 LTS). 
+ +`@with_pool` avoids this by inserting `rewind!` directly at every exit point instead: + +| Macro | Strategy | Inlinable | Use case | +|-------|----------|:---------:|----------| +| `@with_pool` | Direct rewind at each exit point | Yes | Default — hot paths | +| `@safe_with_pool` | `try-finally` wrapper | No | Exception safety required | + ## Basic Lifecycle Flow ``` @@ -28,24 +41,101 @@ The `@with_pool` macro provides automatic lifecycle management with three key op ┌─────────────────────────────────────────────────────────────┐ │ function foo(x) │ │ pool = get_task_local_pool() │ +│ _entry_depth = pool._current_depth │ │ checkpoint!(pool, Float64) # ← Type-specific │ -│ try │ -│ A = _acquire_impl!(pool, Float64, 100) │ -│ B = _similar_impl!(pool, A) │ -│ return sum(A) + sum(B) │ -│ finally │ -│ rewind!(pool, Float64) # ← Type-specific │ +│ │ +│ A = _acquire_impl!(pool, Float64, 100) │ +│ B = _similar_impl!(pool, A) │ +│ _result = sum(A) + sum(B) │ +│ │ +│ # Entry depth guard (cleans up leaked inner scopes) │ +│ while pool._current_depth > _entry_depth + 1 │ +│ rewind!(pool) │ │ end │ +│ rewind!(pool, Float64) # ← Type-specific │ +│ return _result │ │ end │ └─────────────────────────────────────────────────────────────┘ ``` ### Key Points -- **`try-finally`** ensures `rewind!` executes even if an exception occurs +- `rewind!` is inserted at **every exit point**: implicit return, explicit `return`, `break`, `continue` - `acquire!` → `_acquire_impl!` transformation bypasses untracked marking overhead - Type-specific `checkpoint!(pool, Float64)` is ~77% faster than full checkpoint +### Exit Point Coverage + +| Exit type | Handling | +|-----------|----------| +| Implicit return (end of body) | `rewind!` appended before result | +| Explicit `return` | `rewind!` inserted before each `return` statement | +| `break` / `continue` | `rewind!` inserted before each (block form only) | +| `@goto` (internal) | Allowed — stays within pool scope | +| `@goto` (external) | Hard 
error at macro expansion time | +| Uncaught exception | **Not handled** — use `@safe_with_pool` or `reset!(pool)` | + +## Exception Behavior + +### `@with_pool` (direct rewind) + +Without `try-finally`, uncaught exceptions skip `rewind!`. This is an intentional trade-off: + +```julia +# Uncaught exception → pool state invalid +try + @with_pool pool begin + acquire!(pool, Float64, 10) + error("boom") # rewind! never called + end +catch +end +# pool._current_depth is wrong here → call reset!(pool) +``` + +### Entry Depth Guard (nested catch recovery) + +When an inner `@with_pool` throws and the outer scope catches, the outer's exit +automatically cleans up leaked inner scopes: + +```julia +@with_pool pool function outer() + v = acquire!(pool, Float64, 10) + result = try + @with_pool pool begin + acquire!(pool, UInt8, 5) + error("inner boom") # inner rewind! skipped + end + catch + 42 # pool depth is wrong HERE + end + sum(v) + result + # Entry depth guard runs here → cleans up leaked inner scope + # Own rewind! runs → outer scope cleaned up +end +``` + +!!! warning "Catch block limitation" + Between the inner throw and the outer scope's exit, pool depth is incorrect. + Do not use pool operations inside the `catch` block. + +### `@safe_with_pool` (try-finally) + +For code that may throw and needs guaranteed cleanup: + +```julia +@safe_with_pool pool begin + acquire!(pool, Float64, 10) + risky_operation() # if this throws, rewind! still runs +end +``` + +This prevents inlining but guarantees pool cleanup regardless of exceptions. 
+Use it when: +- The pool body calls functions that may throw +- You need the pool to remain valid after a caught exception +- A custom macro inside the body might generate hidden `return`/`break`/`continue` + ## Type Extraction: Static Analysis at Compile Time The macro analyzes the AST to extract types used in `acquire!` calls: @@ -60,11 +150,8 @@ end # Generated code uses typed checkpoint/rewind: checkpoint!(pool, Float64, ComplexF64) -try - ... -finally - rewind!(pool, Float64, ComplexF64) -end +# ... body with rewind! at each exit ... +rewind!(pool, Float64, ComplexF64) ``` ### Type Extraction Rules @@ -227,6 +314,7 @@ end # OUTPUT (simplified) function compute(data) pool = get_task_local_pool() + _entry_depth = pool._current_depth # Bitmask subset check: can typed path handle any untracked acquires? if _can_use_typed_path(pool, _tracked_mask_for_types(Float64)) @@ -235,17 +323,22 @@ function compute(data) checkpoint!(pool) # Full checkpoint (safe) end - try - A = _acquire_impl!(pool, Float64, length(data)) - result = helper!(pool, A) - return result - finally - if _can_use_typed_path(pool, _tracked_mask_for_types(Float64)) - rewind!(pool, Float64) # Typed rewind (fast) - else - rewind!(pool) # Full rewind (safe) - end + A = _acquire_impl!(pool, Float64, length(data)) + _result = helper!(pool, A) + + # Entry depth guard: clean up any leaked inner scopes + while pool._current_depth > _entry_depth + 1 + rewind!(pool) end + + # Own scope rewind + if _can_use_typed_path(pool, _tracked_mask_for_types(Float64)) + rewind!(pool, Float64) # Typed rewind (fast) + else + rewind!(pool) # Full rewind (safe) + end + + return _result end ``` @@ -256,6 +349,9 @@ end | `_extract_acquire_types(expr, pool_name)` | AST walk to find types | | `_filter_static_types(types, local_vars)` | Filter out locally-defined types | | `_transform_acquire_calls(expr, pool_name)` | Replace `acquire!` → `_acquire_impl!` | +| `_transform_return_stmts(expr, ...)` | Insert `rewind!` before each 
`return` | +| `_transform_break_continue(expr, ...)` | Insert `rewind!` before `break`/`continue` | +| `_check_unsafe_goto(expr)` | Hard error on `@goto` that exits pool scope | | `_record_type_touch!(pool, T)` | Record type touch in bitmask for current depth | | `_can_use_typed_path(pool, mask)` | Bitmask subset check for typed vs full path | | `_tracked_mask_for_types(T...)` | Compile-time bitmask for tracked types | diff --git a/src/AdaptiveArrayPools.jl b/src/AdaptiveArrayPools.jl index 575f4cc4..34ad90d5 100644 --- a/src/AdaptiveArrayPools.jl +++ b/src/AdaptiveArrayPools.jl @@ -8,7 +8,7 @@ export acquire_view!, acquire_array! # Explicit naming aliases export zeros!, ones!, trues!, falses!, similar!, reshape!, default_eltype # Convenience functions export unsafe_zeros!, unsafe_ones!, unsafe_similar! # Unsafe convenience functions export Bit # Sentinel type for BitArray (use with acquire!, trues!, falses!) -export @with_pool, @maybe_with_pool +export @with_pool, @maybe_with_pool, @safe_with_pool, @safe_maybe_with_pool export STATIC_POOLING, MAYBE_POOLING, RUNTIME_CHECK export PoolEscapeError, EscapePoint export checkpoint!, rewind!, reset! diff --git a/src/debug.jl b/src/debug.jl index 4469c164..956184d5 100644 --- a/src/debug.jl +++ b/src/debug.jl @@ -223,6 +223,25 @@ _validate_pool_return(val, ::DisabledPool) = nothing # No-op fallback for pool types without specific validation (overridden by CUDA extension) _validate_pool_return(val, ::AbstractArrayPool) = nothing +# ============================================================================== +# Leaked Scope Warning (direct-rewind path, RUNTIME_CHECK >= 1) +# ============================================================================== +# +# Detects when entry depth guard fires (inner scope didn't rewind properly). +# @noinline to keep it out of the inlined hot path — only called on error. 
+ +@noinline function _warn_leaked_scope(pool::AbstractArrayPool, entry_depth::Int) + return @error( + "Leaked @with_pool scope detected! " * + "Pool depth is $(pool._current_depth), expected $(entry_depth + 1). " * + "A macro inside @with_pool may have generated an unseen `return`/`break`, " * + "or an inner scope threw without try-finally protection. " * + "Consider using @safe_with_pool for exception safety.", + current_depth = pool._current_depth, + expected_depth = entry_depth + 1, + ) +end + # ============================================================================== # Poisoning: Fill released vectors with sentinel values (S >= 1) # ============================================================================== diff --git a/src/macros.jl b/src/macros.jl index 9d3efed6..e2233b28 100644 --- a/src/macros.jl +++ b/src/macros.jl @@ -257,7 +257,8 @@ end @with_pool :backend expr Executes code within a pooling scope with automatic lifecycle management. -Calls `checkpoint!` on entry and `rewind!` on exit (even if errors occur). +Calls `checkpoint!` on entry and inserts `rewind!` at every exit point +(implicit return, explicit `return`, `break`, `continue`). If `pool_name` is omitted, a hidden variable is used (useful when you don't need to reference the pool directly). @@ -323,6 +324,19 @@ Nested `@with_pool` blocks work correctly - each maintains its own checkpoint. sum(v1) + inner end ``` + +## Exception Behavior + +`@with_pool` does **not** use `try-finally` (for inlining performance). Implications: + +1. **Uncaught exceptions**: If an exception propagates out of all `@with_pool` scopes, + pool state is invalid. Call `reset!(pool)` or use a fresh pool. +2. **Caught exceptions (nested)**: If an inner `@with_pool` throws and an outer scope + catches, the outer scope's exit will clean up leaked inner scopes automatically + (deferred recovery). Do not use pool operations inside the catch block. +3. 
**`PoolRuntimeEscapeError`**: After this error fires, the pool is poisoned. + Fix the bug in your code and restart. +4. For full exception safety (`try-finally` guarantee), use [`@safe_with_pool`](@ref). """ macro with_pool(pool_name, expr) return _generate_pool_code(pool_name, expr, true; source = __source__) @@ -353,6 +367,9 @@ and `acquire!` falls back to standard allocation. Useful for libraries that want to let users control pooling behavior at runtime. +Like `@with_pool`, does **not** use `try-finally` — see `@with_pool` for exception +behavior details. For exception safety, use [`@safe_maybe_with_pool`](@ref). + ## Function Definition Like `@with_pool`, wrap function definitions: @@ -391,6 +408,74 @@ macro maybe_with_pool(backend::QuoteNode, expr) return _generate_pool_code_with_backend(backend.value, pool_name, expr, false; source = __source__) end +# ============================================================================== +# @safe_with_pool / @safe_maybe_with_pool — Exception-Safe Variants +# ============================================================================== + +""" + @safe_with_pool pool_name expr + @safe_with_pool expr + @safe_with_pool :backend pool_name expr + @safe_with_pool :backend expr + +Like [`@with_pool`](@ref) but uses `try-finally` to guarantee pool cleanup even +when exceptions are thrown. Use this when code inside the pool scope may throw +and you need the pool to remain in a valid state afterward. + +Performance note: `try-finally` prevents Julia's compiler from inlining the pool +scope, so `@with_pool` has ~35-73% less overhead by comparison. Prefer `@with_pool` +for hot paths and use `@safe_with_pool` only when exception safety is required. 
+ +See also: [`@with_pool`](@ref), [`@safe_maybe_with_pool`](@ref) +""" +macro safe_with_pool(pool_name, expr) + return _generate_pool_code(pool_name, expr, true; safe = true, source = __source__) +end + +macro safe_with_pool(expr) + pool_name = gensym(:pool) + return _generate_pool_code(pool_name, expr, true; safe = true, source = __source__) +end + +macro safe_with_pool(backend::QuoteNode, pool_name, expr) + return _generate_pool_code_with_backend(backend.value, pool_name, expr, true; safe = true, source = __source__) +end + +macro safe_with_pool(backend::QuoteNode, expr) + pool_name = gensym(:pool) + return _generate_pool_code_with_backend(backend.value, pool_name, expr, true; safe = true, source = __source__) +end + +""" + @safe_maybe_with_pool pool_name expr + @safe_maybe_with_pool expr + @safe_maybe_with_pool :backend pool_name expr + @safe_maybe_with_pool :backend expr + +Like [`@maybe_with_pool`](@ref) but uses `try-finally` for exception safety. +Combines the runtime pooling toggle of `@maybe_with_pool` with the exception +guarantees of `@safe_with_pool`. 
+ +See also: [`@maybe_with_pool`](@ref), [`@safe_with_pool`](@ref) +""" +macro safe_maybe_with_pool(pool_name, expr) + return _generate_pool_code(pool_name, expr, false; safe = true, source = __source__) +end + +macro safe_maybe_with_pool(expr) + pool_name = gensym(:pool) + return _generate_pool_code(pool_name, expr, false; safe = true, source = __source__) +end + +macro safe_maybe_with_pool(backend::QuoteNode, pool_name, expr) + return _generate_pool_code_with_backend(backend.value, pool_name, expr, false; safe = true, source = __source__) +end + +macro safe_maybe_with_pool(backend::QuoteNode, expr) + pool_name = gensym(:pool) + return _generate_pool_code_with_backend(backend.value, pool_name, expr, false; safe = true, source = __source__) +end + # ============================================================================== # Internal: DisabledPool Expression Generator # ============================================================================== @@ -561,13 +646,13 @@ end # Internal: Code Generation # ============================================================================== -function _generate_pool_code(pool_name, expr, force_enable; source::Union{LineNumberNode, Nothing} = nothing) +function _generate_pool_code(pool_name, expr, force_enable; safe::Bool = false, source::Union{LineNumberNode, Nothing} = nothing) # Compile-time check: if pooling disabled, use DisabledPool to preserve backend context if !STATIC_POOLING disabled_pool = _disabled_pool_expr(:cpu) if Meta.isexpr(expr, [:function, :(=)]) && _is_function_def(expr) # Function definition: inject local pool = DisabledPool at start of body - return _generate_function_pool_code(pool_name, expr, force_enable, true, :cpu; source) + return _generate_function_pool_code(pool_name, expr, force_enable, true, :cpu; safe, source) else return quote local $(esc(pool_name)) = $disabled_pool @@ -578,74 +663,191 @@ function _generate_pool_code(pool_name, expr, force_enable; source::Union{LineNu # Check if function 
definition if Meta.isexpr(expr, [:function, :(=)]) && _is_function_def(expr) - return _generate_function_pool_code(pool_name, expr, force_enable, false; source) + return _generate_function_pool_code(pool_name, expr, force_enable, false; safe, source) end # Compile-time escape detection (zero runtime cost) _esc = _check_compile_time_escape(expr, pool_name, source) _esc !== nothing && return :(throw($_esc)) - # Block logic - # Extract types from acquire! calls for optimized checkpoint/rewind - # Only extract types for calls to the target pool (pool_name) + # Block logic — shared with backend-specific code generation + inner = _generate_block_inner(pool_name, expr, safe, source) + + if force_enable + return _wrap_with_dispatch(esc(pool_name), :(get_task_local_pool()), inner) + else + # Split branches completely to avoid Union boxing + enabled_branch = _wrap_with_dispatch(esc(pool_name), :(get_task_local_pool()), inner) + return quote + if $MAYBE_POOLING[] + $enabled_branch + else + # let block isolates scope — prevents user variables from being + # captured by the dispatch closure in the if-branch (Core.Box) + let $(esc(pool_name)) = $DISABLED_CPU + $(esc(expr)) + end + end + end + end +end + +# ============================================================================== +# Internal: Shared Block-Form Inner Body Generator +# ============================================================================== +# +# Shared between _generate_pool_code (CPU) and _generate_pool_code_with_backend. +# Produces the `inner` quote block containing checkpoint → body → validate → rewind. + +""" + _generate_block_inner(pool_name, expr, safe, source) -> Expr + +Generate the inner body for block-form `@with_pool`. Handles both safe (try-finally) +and direct-rewind paths. Used by both CPU and backend-specific code generators. + +Does NOT handle the outer dispatch wrapper or MAYBE_POOLING branching — callers +handle those after receiving the inner body. 
+""" +function _generate_block_inner(pool_name, expr, safe::Bool, source) + # @goto safety check (direct-rewind path only) + if !safe + _check_unsafe_goto(expr) + end + all_types = _extract_acquire_types(expr, pool_name) local_vars = _extract_local_assignments(expr) static_types, has_dynamic = _filter_static_types(all_types, local_vars) - - # Use typed checkpoint/rewind if all types are static, otherwise fallback to full use_typed = !has_dynamic && !isempty(static_types) - # For typed path: transform acquire! → _acquire_impl! (bypasses type touch recording) - # For dynamic path: keep acquire! untransformed so _record_type_touch! is called - transformed_expr = use_typed ? _transform_acquire_calls(expr, pool_name) : expr - - # Inject borrow callsite recording + return validation. - # Always injected — _runtime_check(pool) gates at runtime (dead-code-eliminated when false). - transformed_expr = _inject_pending_callsite(transformed_expr, pool_name, expr) - transformed_expr = _transform_return_stmts(transformed_expr, pool_name) - + # Generate checkpoint/rewind calls (esc'd, for inner body template) if use_typed checkpoint_call = _generate_typed_checkpoint_call(esc(pool_name), static_types) + rewind_call = _generate_typed_rewind_call(esc(pool_name), static_types) else checkpoint_call = _generate_lazy_checkpoint_call(esc(pool_name)) + rewind_call = _generate_lazy_rewind_call(esc(pool_name)) end - if use_typed - rewind_call = _generate_typed_rewind_call(esc(pool_name), static_types) + transformed_expr = use_typed ? 
_transform_acquire_calls(expr, pool_name) : expr + transformed_expr = _inject_pending_callsite(transformed_expr, pool_name, expr) + + if safe + transformed_expr = _transform_return_stmts(transformed_expr, pool_name) + return quote + $checkpoint_call + try + local _result = $(esc(transformed_expr)) + if $_RUNTIME_CHECK_REF($(esc(pool_name))) + $_validate_pool_return(_result, $(esc(pool_name))) + end + _result + finally + $rewind_call + end + end else - rewind_call = _generate_lazy_rewind_call(esc(pool_name)) - end + entry_depth_var = gensym(:_entry_depth) + raw_rewind = _generate_raw_rewind_call(pool_name, use_typed, static_types) + raw_guard = _generate_raw_entry_depth_guard(pool_name, entry_depth_var) + + transformed_expr = _transform_return_stmts( + transformed_expr, pool_name; + rewind_call = raw_rewind, + entry_depth_guard = raw_guard + ) + transformed_expr = _transform_break_continue(transformed_expr, raw_rewind, raw_guard) - # Build the inner body (runs inside let-block where pool has concrete type) - inner = quote - $checkpoint_call - try + return quote + local $(esc(entry_depth_var)) = $(esc(pool_name))._current_depth + $checkpoint_call local _result = $(esc(transformed_expr)) if $_RUNTIME_CHECK_REF($(esc(pool_name))) $_validate_pool_return(_result, $(esc(pool_name))) end - _result - finally + if $_RUNTIME_CHECK_REF($(esc(pool_name))) && $(esc(pool_name))._current_depth > $(esc(entry_depth_var)) + 1 + $_WARN_LEAKED_SCOPE_REF($(esc(pool_name)), $(esc(entry_depth_var))) + end + while $(esc(pool_name))._current_depth > $(esc(entry_depth_var)) + 1 + $_REWIND_REF($(esc(pool_name))) + end $rewind_call + _result end end +end - if force_enable - return _wrap_with_dispatch(esc(pool_name), :(get_task_local_pool()), inner) +""" + _generate_function_inner(pool_name, expr, safe, source) + +Shared helper for function-form code generation (both CPU and backend variants). 
+Like `_generate_block_inner` but does NOT apply `_transform_break_continue` — +`break`/`continue` cannot exit a function scope. +""" +function _generate_function_inner(pool_name, expr, safe::Bool, source) + # @goto safety check (direct-rewind path only) + if !safe + _check_unsafe_goto(expr) + end + + all_types = _extract_acquire_types(expr, pool_name) + local_vars = _extract_local_assignments(expr) + static_types, has_dynamic = _filter_static_types(all_types, local_vars) + use_typed = !has_dynamic && !isempty(static_types) + + # Generate checkpoint/rewind calls (esc'd, for inner body template) + if use_typed + checkpoint_call = _generate_typed_checkpoint_call(esc(pool_name), static_types) + rewind_call = _generate_typed_rewind_call(esc(pool_name), static_types) else - # Split branches completely to avoid Union boxing - enabled_branch = _wrap_with_dispatch(esc(pool_name), :(get_task_local_pool()), inner) + checkpoint_call = _generate_lazy_checkpoint_call(esc(pool_name)) + rewind_call = _generate_lazy_rewind_call(esc(pool_name)) + end + + transformed_expr = use_typed ? 
_transform_acquire_calls(expr, pool_name) : expr + transformed_expr = _inject_pending_callsite(transformed_expr, pool_name, expr) + + if safe + transformed_expr = _transform_return_stmts(transformed_expr, pool_name) return quote - if $MAYBE_POOLING[] - $enabled_branch - else - # let block isolates scope — prevents user variables from being - # captured by the dispatch closure in the if-branch (Core.Box) - let $(esc(pool_name)) = $DISABLED_CPU - $(esc(expr)) + $checkpoint_call + try + local _result = $(esc(transformed_expr)) + if $_RUNTIME_CHECK_REF($(esc(pool_name))) + $_validate_pool_return(_result, $(esc(pool_name))) end + _result + finally + $rewind_call end end + else + entry_depth_var = gensym(:_entry_depth) + raw_rewind = _generate_raw_rewind_call(pool_name, use_typed, static_types) + raw_guard = _generate_raw_entry_depth_guard(pool_name, entry_depth_var) + + # Function form: transform returns with rewind, but NO break/continue transform + transformed_expr = _transform_return_stmts( + transformed_expr, pool_name; + rewind_call = raw_rewind, + entry_depth_guard = raw_guard + ) + + return quote + local $(esc(entry_depth_var)) = $(esc(pool_name))._current_depth + $checkpoint_call + local _result = $(esc(transformed_expr)) + if $_RUNTIME_CHECK_REF($(esc(pool_name))) + $_validate_pool_return(_result, $(esc(pool_name))) + end + if $_RUNTIME_CHECK_REF($(esc(pool_name))) && $(esc(pool_name))._current_depth > $(esc(entry_depth_var)) + 1 + $_WARN_LEAKED_SCOPE_REF($(esc(pool_name)), $(esc(entry_depth_var))) + end + while $(esc(pool_name))._current_depth > $(esc(entry_depth_var)) + 1 + $_REWIND_REF($(esc(pool_name))) + end + $rewind_call + _result + end end end @@ -661,12 +863,12 @@ Uses `_get_pool_for_backend(Val{backend}())` for zero-overhead dispatch. Includes type-specific checkpoint/rewind optimization (same as regular @with_pool). 
""" -function _generate_pool_code_with_backend(backend::Symbol, pool_name, expr, force_enable::Bool; source::Union{LineNumberNode, Nothing} = nothing) +function _generate_pool_code_with_backend(backend::Symbol, pool_name, expr, force_enable::Bool; safe::Bool = false, source::Union{LineNumberNode, Nothing} = nothing) # Compile-time check: if pooling disabled, use DisabledPool to preserve backend context if !STATIC_POOLING disabled_pool = _disabled_pool_expr(backend) if Meta.isexpr(expr, [:function, :(=)]) && _is_function_def(expr) - return _generate_function_pool_code_with_backend(backend, pool_name, expr, force_enable, true; source) + return _generate_function_pool_code_with_backend(backend, pool_name, expr, force_enable, true; safe, source) else return quote local $(esc(pool_name)) = $disabled_pool @@ -680,7 +882,7 @@ function _generate_pool_code_with_backend(backend::Symbol, pool_name, expr, forc disabled_pool = _disabled_pool_expr(backend) # Check if function definition if Meta.isexpr(expr, [:function, :(=)]) && _is_function_def(expr) - return _generate_function_pool_code_with_backend(backend, pool_name, expr, false, false; source) + return _generate_function_pool_code_with_backend(backend, pool_name, expr, false, false; safe, source) end # Compile-time escape detection (zero runtime cost) @@ -688,37 +890,8 @@ function _generate_pool_code_with_backend(backend::Symbol, pool_name, expr, forc _esc !== nothing && return :(throw($_esc)) # Block logic with runtime check - all_types = _extract_acquire_types(expr, pool_name) - local_vars = _extract_local_assignments(expr) - static_types, has_dynamic = _filter_static_types(all_types, local_vars) - use_typed = !has_dynamic && !isempty(static_types) - # For typed path: transform acquire! → _acquire_impl! (bypasses type touch recording) - # For dynamic path: keep acquire! untransformed so _record_type_touch! is called - transformed_expr = use_typed ? 
_transform_acquire_calls(expr, pool_name) : expr - transformed_expr = _inject_pending_callsite(transformed_expr, pool_name, expr) - transformed_expr = _transform_return_stmts(transformed_expr, pool_name) + inner = _generate_block_inner(pool_name, expr, safe, source) pool_getter = :($_get_pool_for_backend($(Val{backend}()))) - - if use_typed - checkpoint_call = _generate_typed_checkpoint_call(esc(pool_name), static_types) - rewind_call = _generate_typed_rewind_call(esc(pool_name), static_types) - else - checkpoint_call = _generate_lazy_checkpoint_call(esc(pool_name)) - rewind_call = _generate_lazy_rewind_call(esc(pool_name)) - end - - inner = quote - $checkpoint_call - try - local _result = $(esc(transformed_expr)) - if $_RUNTIME_CHECK_REF($(esc(pool_name))) - $_validate_pool_return(_result, $(esc(pool_name))) - end - _result - finally - $rewind_call - end - end enabled_branch = _wrap_with_dispatch(esc(pool_name), pool_getter, inner; backend) return quote if $MAYBE_POOLING[] @@ -733,54 +906,16 @@ function _generate_pool_code_with_backend(backend::Symbol, pool_name, expr, forc # Check if function definition if Meta.isexpr(expr, [:function, :(=)]) && _is_function_def(expr) - return _generate_function_pool_code_with_backend(backend, pool_name, expr, true, false; source) + return _generate_function_pool_code_with_backend(backend, pool_name, expr, true, false; safe, source) end # Compile-time escape detection (zero runtime cost) _esc = _check_compile_time_escape(expr, pool_name, source) _esc !== nothing && return :(throw($_esc)) - # Block logic: Extract types from acquire! calls for optimized checkpoint/rewind - all_types = _extract_acquire_types(expr, pool_name) - local_vars = _extract_local_assignments(expr) - static_types, has_dynamic = _filter_static_types(all_types, local_vars) - - # Use typed checkpoint/rewind if all types are static, otherwise fallback to full - use_typed = !has_dynamic && !isempty(static_types) - - # For typed path: transform acquire! 
→ _acquire_impl! (bypasses type touch recording) - # For dynamic path: keep acquire! untransformed so _record_type_touch! is called - transformed_expr = use_typed ? _transform_acquire_calls(expr, pool_name) : expr - transformed_expr = _inject_pending_callsite(transformed_expr, pool_name, expr) - transformed_expr = _transform_return_stmts(transformed_expr, pool_name) - - # Use Val{backend}() for compile-time dispatch - fully inlinable + # Block logic (force_enable=true path) + inner = _generate_block_inner(pool_name, expr, safe, source) pool_getter = :($_get_pool_for_backend($(Val{backend}()))) - - if use_typed - checkpoint_call = _generate_typed_checkpoint_call(esc(pool_name), static_types) - else - checkpoint_call = _generate_lazy_checkpoint_call(esc(pool_name)) - end - - if use_typed - rewind_call = _generate_typed_rewind_call(esc(pool_name), static_types) - else - rewind_call = _generate_lazy_rewind_call(esc(pool_name)) - end - - inner = quote - $checkpoint_call - try - local _result = $(esc(transformed_expr)) - if $_RUNTIME_CHECK_REF($(esc(pool_name))) - $_validate_pool_return(_result, $(esc(pool_name))) - end - _result - finally - $rewind_call - end - end return _wrap_with_dispatch(esc(pool_name), pool_getter, inner; backend) end @@ -788,13 +923,13 @@ end _generate_function_pool_code_with_backend(backend, pool_name, func_def, force_enable, disable_pooling) Generate function code for a specific backend (e.g., :cuda). -Wraps the function body with pool getter, checkpoint, try-finally, rewind. +Wraps the function body with pool getter, checkpoint, and rewind. When `disable_pooling=true` (STATIC_POOLING=false), generates DisabledPool binding. When `force_enable=true` (@with_pool), always uses the real pool. When `force_enable=false` (@maybe_with_pool), generates MAYBE_POOLING[] runtime check. 
""" -function _generate_function_pool_code_with_backend(backend::Symbol, pool_name, func_def, force_enable::Bool, disable_pooling::Bool; source::Union{LineNumberNode, Nothing} = nothing) +function _generate_function_pool_code_with_backend(backend::Symbol, pool_name, func_def, force_enable::Bool, disable_pooling::Bool; safe::Bool = false, source::Union{LineNumberNode, Nothing} = nothing) def_head = func_def.head call_expr = func_def.args[1] body = func_def.args[2] @@ -814,44 +949,12 @@ function _generate_function_pool_code_with_backend(backend::Symbol, pool_name, f _esc = _check_compile_time_escape(body, pool_name, source) _esc !== nothing && return :(throw($_esc)) - # Analyze body for types - all_types = _extract_acquire_types(body, pool_name) - local_vars = _extract_local_assignments(body) - static_types, has_dynamic = _filter_static_types(all_types, local_vars) - use_typed = !has_dynamic && !isempty(static_types) - - # For typed path: transform acquire! → _acquire_impl! (bypasses type touch recording) - # For dynamic path: keep acquire! untransformed so _record_type_touch! is called - transformed_body = use_typed ? 
_transform_acquire_calls(body, pool_name) : body - transformed_body = _inject_pending_callsite(transformed_body, pool_name, body) - transformed_body = _transform_return_stmts(transformed_body, pool_name) + # Function body inner — no break/continue transform (can't break out of a function) + inner = _generate_function_inner(pool_name, body, safe, source) # Use Val{backend}() for compile-time dispatch pool_getter = :($_get_pool_for_backend($(Val{backend}()))) - if use_typed - checkpoint_call = _generate_typed_checkpoint_call(esc(pool_name), static_types) - rewind_call = _generate_typed_rewind_call(esc(pool_name), static_types) - else - checkpoint_call = _generate_lazy_checkpoint_call(esc(pool_name)) - rewind_call = _generate_lazy_rewind_call(esc(pool_name)) - end - - inner = quote - $checkpoint_call - try - local _result = begin - $(esc(transformed_body)) - end - if $_RUNTIME_CHECK_REF($(esc(pool_name))) - $_validate_pool_return(_result, $(esc(pool_name))) - end - _result - finally - $rewind_call - end - end - if force_enable new_body = quote $(_wrap_with_dispatch(esc(pool_name), pool_getter, inner; backend)) @@ -876,7 +979,7 @@ function _generate_function_pool_code_with_backend(backend::Symbol, pool_name, f return Expr(def_head, esc(call_expr), new_body) end -function _generate_function_pool_code(pool_name, func_def, force_enable, disable_pooling, backend::Symbol = :cpu; source::Union{LineNumberNode, Nothing} = nothing) +function _generate_function_pool_code(pool_name, func_def, force_enable, disable_pooling, backend::Symbol = :cpu; safe::Bool = false, source::Union{LineNumberNode, Nothing} = nothing) def_head = func_def.head call_expr = func_def.args[1] body = func_def.args[2] @@ -896,63 +999,13 @@ function _generate_function_pool_code(pool_name, func_def, force_enable, disable _esc = _check_compile_time_escape(body, pool_name, source) _esc !== nothing && return :(throw($_esc)) - # Analyze body for types - all_types = _extract_acquire_types(body, pool_name) - 
local_vars = _extract_local_assignments(body) - static_types, has_dynamic = _filter_static_types(all_types, local_vars) - use_typed = !has_dynamic && !isempty(static_types) - - # For typed path: transform acquire! → _acquire_impl! (bypasses type touch recording) - # For dynamic path: keep acquire! untransformed so _record_type_touch! is called - transformed_body = use_typed ? _transform_acquire_calls(body, pool_name) : body - # Safety transforms — always inject; dead-code-eliminated at S=0 inside dispatch closure - transformed_body = _inject_pending_callsite(transformed_body, pool_name, body) - transformed_body = _transform_return_stmts(transformed_body, pool_name) - - if use_typed - checkpoint_call = _generate_typed_checkpoint_call(esc(pool_name), static_types) - else - checkpoint_call = _generate_lazy_checkpoint_call(esc(pool_name)) - end - - if use_typed - rewind_call = _generate_typed_rewind_call(esc(pool_name), static_types) - else - rewind_call = _generate_lazy_rewind_call(esc(pool_name)) - end + # Function body inner — no break/continue transform (can't break out of a function) + inner = _generate_function_inner(pool_name, body, safe, source) if force_enable - inner = quote - $checkpoint_call - try - local _result = begin - $(esc(transformed_body)) - end - if $_RUNTIME_CHECK_REF($(esc(pool_name))) - $_validate_pool_return(_result, $(esc(pool_name))) - end - _result - finally - $rewind_call - end - end new_body = _wrap_with_dispatch(esc(pool_name), :(get_task_local_pool()), inner) else disabled_pool = _disabled_pool_expr(backend) - inner = quote - $checkpoint_call - try - local _result = begin - $(esc(transformed_body)) - end - if $_RUNTIME_CHECK_REF($(esc(pool_name))) - $_validate_pool_return(_result, $(esc(pool_name))) - end - _result - finally - $rewind_call - end - end enabled_branch = _wrap_with_dispatch(esc(pool_name), :(get_task_local_pool()), inner) new_body = quote if $MAYBE_POOLING[] @@ -1317,6 +1370,48 @@ function 
_generate_lazy_rewind_call(pool_expr) return :($_lazy_rewind!($pool_expr)) end +# ============================================================================== +# Internal: Raw (Un-Escaped) Rewind/Guard Generators for Direct-Rewind Path +# ============================================================================== +# +# These generate Expr nodes using raw pool_name symbols (NOT esc'd) and GlobalRef +# function references. They are embedded inside the un-escaped AST processed by +# _transform_return_stmts and _transform_break_continue. The outer esc() applied +# to the full transformed_expr handles escaping for all embedded nodes at once. + +""" + _generate_raw_rewind_call(pool_name, use_typed, static_types) -> Expr + +Generate un-escaped rewind call for embedding in AST transforms. +Uses GlobalRef function references and raw pool_name symbol. +""" +function _generate_raw_rewind_call(pool_name, use_typed::Bool, static_types) + if !use_typed || isempty(static_types) + return Expr(:call, _LAZY_REWIND_REF, pool_name) + else + typed_call = Expr(:call, _REWIND_REF, pool_name, static_types...) + mask_call = Expr(:call, _TRACKED_MASK_REF, static_types...) + selective_call = Expr(:call, _TYPED_LAZY_REWIND_REF, pool_name, mask_call) + condition = Expr(:call, _CAN_USE_TYPED_PATH_REF, pool_name, mask_call) + return Expr(:if, condition, typed_call, selective_call) + end +end + +""" + _generate_raw_entry_depth_guard(pool_name, entry_depth_var) -> Expr + +Generate un-escaped entry depth guard for cleaning up leaked inner scopes. + +Produces: `while pool._current_depth > _entry_depth + 1; rewind!(pool); end` +Uses full `rewind!(pool)` (not typed/lazy) because leaked inner scope may have +touched types outside this scope's static type set. 
+""" +function _generate_raw_entry_depth_guard(pool_name, entry_depth_var) + depth_access = Expr(:., pool_name, QuoteNode(:_current_depth)) + condition = Expr(:call, :>, depth_access, Expr(:call, :+, entry_depth_var, 1)) + body = Expr(:call, _REWIND_REF, pool_name) + return Expr(:while, condition, body) +end # ============================================================================== # Internal: Acquire Call Transformation @@ -1438,6 +1533,16 @@ end const _RUNTIME_CHECK_REF = GlobalRef(@__MODULE__, :_runtime_check) +# GlobalRefs for direct-rewind path (no try-finally): +# Used by _transform_return_stmts and _transform_break_continue to inject +# rewind calls into the un-escaped AST (outer esc() handles escaping). +const _WARN_LEAKED_SCOPE_REF = GlobalRef(@__MODULE__, :_warn_leaked_scope) +const _REWIND_REF = GlobalRef(@__MODULE__, :rewind!) +const _LAZY_REWIND_REF = GlobalRef(@__MODULE__, :_lazy_rewind!) +const _TYPED_LAZY_REWIND_REF = GlobalRef(@__MODULE__, :_typed_lazy_rewind!) +const _CAN_USE_TYPED_PATH_REF = GlobalRef(@__MODULE__, :_can_use_typed_path) +const _TRACKED_MASK_REF = GlobalRef(@__MODULE__, :_tracked_mask_for_types) + """Set of all transformed `_*_impl!` function names (GlobalRef targets).""" const _IMPL_FUNC_NAMES = Set{Symbol}( [ @@ -1564,30 +1669,47 @@ end const _VALIDATE_POOL_RETURN_REF = GlobalRef(@__MODULE__, :_validate_pool_return) """ - _transform_return_stmts(expr, pool_name) -> Expr + _transform_return_stmts(expr, pool_name; rewind_call=nothing, entry_depth_guard=nothing) -> Expr Walk AST and wrap explicit `return value` statements with escape validation. 
Generates: `local _ret = value; if _runtime_check(pool) validate(_ret, pool); end; return _ret` +When `rewind_call` and `entry_depth_guard` are provided (direct-rewind path, +`safe=false`), they are inserted after validation but before `return`: + `local _ret = value; validate; entry_depth_guard; rewind_call; return _ret` + +When `nothing` (safe path / try-finally), behavior is unchanged — rewind +happens in the `finally` clause instead. + Does NOT recurse into nested `:function` or `:->` expressions (inner functions have their own `return` semantics). """ -function _transform_return_stmts(expr, pool_name, current_lnn = nothing) +function _transform_return_stmts( + expr, pool_name, current_lnn = nothing; + rewind_call = nothing, + entry_depth_guard = nothing + ) expr isa Expr || return expr - # Don't recurse into nested function definitions (return belongs to inner function) - if expr.head in (:function, :->) + # Don't recurse into nested function definitions or quoted AST + if expr.head in (:function, :->, :quote) return expr end if expr.head == :return && length(expr.args) >= 1 value_expr = expr.args[1] - # Bare return (return nothing) — skip validation + # Bare return (return nothing) — skip validation but still need rewind if value_expr === nothing + if rewind_call !== nothing + return Expr(:block, entry_depth_guard, rewind_call, expr) + end return expr end # Recurse into the value expression first (may contain nested returns in ternary etc.) - value_expr = _transform_return_stmts(value_expr, pool_name, current_lnn) + value_expr = _transform_return_stmts( + value_expr, pool_name, current_lnn; + rewind_call, entry_depth_guard + ) retvar = gensym(:_pool_ret) # Build return-site string for S=1 display (e.g. 
"file:line\nreturn v") @@ -1616,16 +1738,17 @@ function _transform_return_stmts(expr, pool_name, current_lnn = nothing) Expr(:call, _VALIDATE_POOL_RETURN_REF, retvar, pool_name) end - return Expr( - :block, + # Build statement list: validate → [guard → rewind] → return + stmts = Any[ Expr(:local, Expr(:(=), retvar, value_expr)), - Expr( - :if, - Expr(:call, _RUNTIME_CHECK_REF, pool_name), - validate_expr - ), - Expr(:return, retvar) - ) + Expr(:if, Expr(:call, _RUNTIME_CHECK_REF, pool_name), validate_expr), + ] + if rewind_call !== nothing + push!(stmts, entry_depth_guard) + push!(stmts, rewind_call) + end + push!(stmts, Expr(:return, retvar)) + return Expr(:block, stmts...) end # For blocks, track LineNumberNodes @@ -1637,17 +1760,133 @@ function _transform_return_stmts(expr, pool_name, current_lnn = nothing) lnn = arg push!(new_args, arg) else - push!(new_args, _transform_return_stmts(arg, pool_name, lnn)) + push!( + new_args, _transform_return_stmts( + arg, pool_name, lnn; + rewind_call, entry_depth_guard + ) + ) end end return Expr(:block, new_args...) end # Other expressions: recurse with current_lnn - new_args = Any[_transform_return_stmts(arg, pool_name, current_lnn) for arg in expr.args] + new_args = Any[ + _transform_return_stmts( + arg, pool_name, current_lnn; + rewind_call, entry_depth_guard + ) for arg in expr.args + ] + return Expr(expr.head, new_args...) +end + +# ============================================================================== +# Internal: Break/Continue Transformation (Direct-Rewind Path) +# ============================================================================== +# +# For block-form @with_pool (NOT function form), `break` and `continue` at the +# pool scope level exit the pool scope (the block is inside a loop). Without +# try-finally, we must insert rewind before these statements. +# +# The walker SKIPS :for/:while bodies — break/continue inside nested loops +# belong to those loops, not the pool scope. 
Also skips :function/:-> bodies. + +""" + _transform_break_continue(expr, rewind_call, entry_depth_guard) -> Expr + +Walk AST and insert entry depth guard + rewind before `break`/`continue` statements +that would exit the pool scope. Only used for block-form `@with_pool` (not function form). + +Skips `:for`, `:while` bodies (break/continue there are for those loops). +Skips `:function`, `:->` bodies (inner function scope boundary). +""" +function _transform_break_continue(expr, rewind_call, entry_depth_guard) + expr isa Expr || return expr + + # Don't recurse into nested functions or quoted AST + expr.head in (:function, :->, :quote) && return expr + + # Don't recurse into loop bodies — break/continue there are for those loops + expr.head in (:for, :while) && return expr + + # Transform bare break/continue at pool-block level + if expr.head in (:break, :continue) + return Expr(:block, entry_depth_guard, rewind_call, expr) + end + + # Recurse into other expressions (if, try, let, block, etc.) + new_args = Any[ + _transform_break_continue(arg, rewind_call, entry_depth_guard) + for arg in expr.args + ] return Expr(expr.head, new_args...) end +# ============================================================================== +# Internal: @goto Safety Check (Direct-Rewind Path) +# ============================================================================== + +""" + _collect_local_gotos_and_labels(expr) -> (gotos::Set{Symbol}, labels::Set{Symbol}) + +Walk the body AST and collect all `@goto` target symbols and `@label` names. +Skips `:function`/`:->` bodies (inner functions have their own scope). + +At macro expansion time, `@goto`/`@label` are `:macrocall` nodes, not `:symbolicgoto`. 
+""" +function _collect_local_gotos_and_labels(expr) + gotos = Set{Symbol}() + labels = Set{Symbol}() + + function walk(node) + node isa Expr || return + + if node.head === :macrocall && length(node.args) >= 3 + name = node.args[1] + target = node.args[3] + if name === Symbol("@goto") && target isa Symbol + push!(gotos, target) + elseif name === Symbol("@label") && target isa Symbol + push!(labels, target) + end + end + + # Skip nested function bodies (separate scope) and quoted AST (not executable here) + node.head in (:function, :->, :quote) && return + + for arg in node.args + walk(arg) + end + return + end + + walk(expr) + return gotos, labels +end + +""" + _check_unsafe_goto(expr) + +Hard error if the body contains any `@goto` that targets a label NOT defined +within the same body. Such jumps would bypass `rewind!` insertion. + +Internal jumps (`@goto label` where `@label label` exists in the body) are safe +and allowed — they don't exit the pool scope. +""" +function _check_unsafe_goto(expr) + gotos, labels = _collect_local_gotos_and_labels(expr) + unsafe = setdiff(gotos, labels) + return if !isempty(unsafe) + targets = join(unsafe, ", ") + error( + "Pool scope: @goto to external label(s) ($targets) detected. " * + "This would bypass rewind! and corrupt pool state. " * + "Use the @safe_* variant (e.g., @safe_with_pool) for @goto across pool boundaries." 
+ ) + end +end + # ============================================================================== # Internal: Compile-Time Escape Detection # ============================================================================== diff --git a/test/cuda/test_extension.jl b/test/cuda/test_extension.jl index ba6c8df4..8b1b78a9 100644 --- a/test/cuda/test_extension.jl +++ b/test/cuda/test_extension.jl @@ -415,7 +415,7 @@ end reset!(pool) try - @with_pool :cuda p begin + @safe_with_pool :cuda p begin acquire!(p, Float32, 100) @test p.float32.n_active == 1 error("Intentional error") @@ -446,6 +446,84 @@ end end @test result == 200.0f0 end + + # ================================================================== + # Direct-rewind path: CUDA pool runtime verification + # (Macro AST logic tested on CPU; here we verify CUDA rewind!/checkpoint!) + # ================================================================== + + @testset "Direct rewind: explicit return" begin + @with_pool :cuda pool function cuda_early_return(flag) + v = acquire!(pool, Float32, 10) + v .= 1.0f0 + if flag + return sum(v) + end + v .= 2.0f0 + sum(v) + end + + @test cuda_early_return(true) == 10.0f0 + @test cuda_early_return(false) == 20.0f0 + @test get_task_local_cuda_pool()._current_depth == 1 + end + + @testset "Direct rewind: break/continue in loop" begin + pool = get_task_local_cuda_pool() + reset!(pool) + + total = 0.0f0 + for i in 1:5 + @with_pool :cuda p begin + v = acquire!(p, Float32, 3) + v .= Float32(i) + if i == 3 + continue + end + total += sum(v) + end + end + @test total == 3.0f0 * (1 + 2 + 4 + 5) # skip i=3 + @test pool._current_depth == 1 + end + + @testset "Direct rewind: nested catch recovery (entry depth guard)" begin + reset!(get_task_local_cuda_pool()) + + @with_pool :cuda pool function cuda_outer_catches() + v = acquire!(pool, Float32, 10) + v .= 1.0f0 + result = try + @with_pool :cuda pool begin + acquire!(pool, Float64, 5) + error("boom") + end + catch + 42 + end + sum(v) + result + end + 
+ @test cuda_outer_catches() == 52.0f0 + @test get_task_local_cuda_pool()._current_depth == 1 + end + + @testset "Uncaught exception corrupts CUDA pool (documented)" begin + pool = get_task_local_cuda_pool() + reset!(pool) + + try + @with_pool :cuda p begin + acquire!(p, Float32, 10) + error("uncaught!") + end + catch + end + + @test pool._current_depth > 1 # corrupted — expected + reset!(pool) + @test pool._current_depth == 1 + end end @testset "Acquire API" begin diff --git a/test/test_backend_macro_expansion.jl b/test/test_backend_macro_expansion.jl index 4440761f..74e1431f 100644 --- a/test/test_backend_macro_expansion.jl +++ b/test/test_backend_macro_expansion.jl @@ -31,9 +31,9 @@ @test occursin("checkpoint!", expr_str) @test occursin("rewind!", expr_str) - # Should have try-finally - @test occursin("try", expr_str) - @test occursin("finally", expr_str) + # Direct-rewind path: NO try-finally, uses entry depth guard + @test !occursin("finally", expr_str) + @test occursin("_current_depth", expr_str) end @testset "Different backends" begin @@ -157,8 +157,9 @@ @test occursin("_get_pool_for_backend", body_str) @test occursin("checkpoint!", body_str) - @test occursin("try", body_str) - @test occursin("finally", body_str) + # Direct-rewind path: no try-finally, uses entry depth guard + @test !occursin("finally", body_str) + @test occursin("_current_depth", body_str) @test occursin("rewind!", body_str) end @@ -408,13 +409,13 @@ v = acquire!(pool, Float64, 10) end - # Both should have checkpoint/rewind/try-finally + # Both should have checkpoint/rewind with direct-rewind path (no try-finally) for expr in [expr_regular, expr_backend] expr_str = string(expr) @test occursin("checkpoint!", expr_str) @test occursin("rewind!", expr_str) - @test occursin("try", expr_str) - @test occursin("finally", expr_str) + @test !occursin("finally", expr_str) + @test occursin("_current_depth", expr_str) end end diff --git a/test/test_fallback_reclamation.jl 
b/test/test_fallback_reclamation.jl index 36e959e7..905b6a48 100644 --- a/test/test_fallback_reclamation.jl +++ b/test/test_fallback_reclamation.jl @@ -665,9 +665,9 @@ const Dual_f2_11 = FakeDual{FakeTag{:f2}, Float64, 11} @test others_n_active(pool, UInt8) == 0 end - @testset "13b. @with_pool exception safety with fallback types" begin + @testset "13b. @safe_with_pool exception safety with fallback types" begin try - @with_pool pool begin + @safe_with_pool pool begin acquire!(pool, UInt8, 10) acquire!(pool, Float16, 20) error("simulated failure") @@ -675,7 +675,7 @@ const Dual_f2_11 = FakeDual{FakeTag{:f2}, Float64, 11} catch end - # After exception + rewind via finally, pool should be clean + # After exception + rewind via try-finally, pool should be clean pool = AdaptiveArrayPools.get_task_local_pool() @test others_n_active(pool, UInt8) == 0 @test others_n_active(pool, Float16) == 0 diff --git a/test/test_macro_expansion.jl b/test/test_macro_expansion.jl index 3e1b7f78..f8b74bf7 100644 --- a/test/test_macro_expansion.jl +++ b/test/test_macro_expansion.jl @@ -26,9 +26,9 @@ @test occursin("checkpoint!", expr_str) @test occursin("rewind!", expr_str) - # Should have try-finally structure - @test occursin("try", expr_str) - @test occursin("finally", expr_str) + # Direct-rewind path: NO try-finally, uses entry depth guard instead + @test !occursin("finally", expr_str) + @test occursin("_current_depth", expr_str) end # Test @maybe_with_pool expansion (has MAYBE_POOLING branch) @@ -324,6 +324,19 @@ end + @testset "@safe_with_pool expansion retains try-finally" begin + expr = @macroexpand @safe_with_pool pool begin + v = acquire!(pool, Float64, 10) + sum(v) + end + + expr_str = string(expr) + + # Safe path must use try-finally (unlike @with_pool which uses direct rewind) + @test occursin("finally", expr_str) + @test !occursin("_current_depth", expr_str) # no entry depth guard + end + end # Macro Expansion Details # 
============================================================================== @@ -866,8 +879,10 @@ end expr_str = string(expr) @test occursin("_lazy_rewind!", expr_str) - # Full rewind must NOT appear; selective rewind is the only rewind call - @test !occursin("AdaptiveArrayPools.rewind!", expr_str) + # Entry depth guard uses full rewind! (cold path for leaked inner scopes), + # but the hot-path own-scope rewind uses _lazy_rewind! + # Verify _lazy_rewind! is the primary rewind mechanism + @test count("_lazy_rewind!", expr_str) >= 1 end # ========================================================================= @@ -899,8 +914,8 @@ end # Phase 5: else-branch uses selective rewind @test occursin("_typed_lazy_rewind!", expr_str) - # Full no-arg rewind!(pool) must NOT appear - @test !occursin("AdaptiveArrayPools.rewind!(pool)", expr_str) + # Full rewind!(pool) appears ONLY in the entry depth guard, not as the main rewind path + @test count("AdaptiveArrayPools.rewind!(pool)", expr_str) == 1 # entry depth guard only end end # Dynamic selective mode expansion diff --git a/test/test_macros.jl b/test/test_macros.jl index 6f664580..b2cf0a20 100644 --- a/test/test_macros.jl +++ b/test/test_macros.jl @@ -256,4 +256,372 @@ import AdaptiveArrayPools: checkpoint!, rewind! 
MAYBE_POOLING[] = true end + # ============================================================================== + # Direct-rewind path tests (no try-finally) + # ============================================================================== + + @testset "Direct rewind: explicit return in @with_pool function" begin + @with_pool pool function early_return_test(flag) + v = acquire!(pool, Float64, 10) + v .= 1.0 + if flag + return sum(v) # rewind should happen before return + end + v .= 2.0 + sum(v) + end + + @test early_return_test(true) == 10.0 + @test early_return_test(false) == 20.0 + + # Pool should be clean after both paths + pool = get_task_local_pool() + @test pool._current_depth == 1 + end + + @testset "Direct rewind: break inside @with_pool block in loop" begin + result = 0.0 + for i in 1:10 + @with_pool pool begin + v = acquire!(pool, Float64, 5) + v .= Float64(i) + result = sum(v) + if i == 3 + break # rewind should happen before break + end + end + end + + @test result == 15.0 # 3 * 5 + pool = get_task_local_pool() + @test pool._current_depth == 1 + end + + @testset "Direct rewind: continue inside @with_pool block in loop" begin + total = 0.0 + for i in 1:5 + @with_pool pool begin + v = acquire!(pool, Float64, 3) + v .= Float64(i) + if i == 3 + continue # rewind should happen before continue + end + total += sum(v) + end + end + + # sum for i=1,2,4,5 → 3*(1+2+4+5) = 36 + @test total == 36.0 + pool = get_task_local_pool() + @test pool._current_depth == 1 + end + + @testset "Direct rewind: nested catch recovery (entry depth guard)" begin + @with_pool pool function outer_catches() + v = acquire!(pool, Float64, 10) + v .= 1.0 + result = try + @with_pool pool begin + w = acquire!(pool, UInt8, 5) + error("boom") # inner scope leaks + end + catch + 42 + end + sum(v) + result + end + + @test outer_catches() == 52.0 # 10.0 + 42 + pool = get_task_local_pool() + @test pool._current_depth == 1 + end + + @testset "@safe_with_pool preserves try-finally behavior" begin + 
reset!(get_task_local_pool()) # ensure clean state + try + @safe_with_pool pool begin + acquire!(pool, Float64, 10) + error("simulated failure") + end + catch + end + + # try-finally guarantees cleanup even after exception + pool = get_task_local_pool() + @test pool._current_depth == 1 + end + + @testset "@safe_maybe_with_pool preserves try-finally behavior" begin + reset!(get_task_local_pool()) # ensure clean state + try + @safe_maybe_with_pool pool begin + acquire!(pool, Float64, 10) + error("simulated failure") + end + catch + end + + pool = get_task_local_pool() + @test pool._current_depth == 1 + end + + # ============================================================================== + # @goto safety checks + # ============================================================================== + + @testset "@goto safety in @with_pool" begin + # Internal @goto/@label: allowed (both are inside the pool body) + @testset "Internal @goto is allowed" begin + result = @with_pool pool begin + x = acquire!(pool, Float64, 10) + x .= 1.0 + s = sum(x) + if s < 100 + @goto done + end + s *= 2 + @label done + s + end + @test result == 10.0 + pool = get_task_local_pool() + @test pool._current_depth == 1 + end + + # External @goto: hard error at macro expansion time + @testset "External @goto is a hard error" begin + @test_throws ErrorException @macroexpand @with_pool pool begin + v = acquire!(pool, Float64, 10) + @goto outside + end + end + + # @safe_with_pool allows any @goto (try-finally protects) + @testset "@safe_with_pool allows @goto" begin + expr = @macroexpand @safe_with_pool pool begin + v = acquire!(pool, Float64, 10) + @goto outside + end + @test expr isa Expr # no error thrown + end + + # Multiple internal @goto to different labels + @testset "Multiple internal @goto targets" begin + result = @with_pool pool begin + v = acquire!(pool, Float64, 5) + v .= 1.0 + x = sum(v) + if x > 10.0 + @goto big + elseif x > 0.0 + @goto small + end + @label big + x *= 100 + @label 
small + x + end + @test result == 5.0 # falls through to @label small + @test get_task_local_pool()._current_depth == 1 + end + + # @goto in function form (not just block) + @testset "External @goto error in function form" begin + @test_throws ErrorException @macroexpand @with_pool pool function goto_func() + v = acquire!(pool, Float64, 10) + @goto escape + end + end + + # @goto inside inner lambda is ignored (separate scope) + @testset "@goto inside inner function is ignored" begin + expr = @macroexpand @with_pool pool begin + v = acquire!(pool, Float64, 10) + f = () -> @goto somewhere # inner function — not our scope + sum(v) + end + @test expr isa Expr # no error — inner lambda @goto is skipped + end + + # Mix of internal and external @goto: external wins → error + @testset "Mixed internal+external @goto errors on external" begin + @test_throws ErrorException @macroexpand @with_pool pool begin + v = acquire!(pool, Float64, 10) + @goto internal_label + @label internal_label + @goto external_label # this one has no matching @label + end + end + + # Quoted @label must NOT mask real external @goto + @testset "Quoted @label does not mask external @goto" begin + @test_throws ErrorException @macroexpand @with_pool pool begin + q = quote + @label escape # just AST data, not a real label + end + @goto escape # real goto — should be caught as external + end + end + + # Quoted @goto should NOT trigger false-positive error + @testset "Quoted @goto does not trigger false error" begin + expr = @macroexpand @with_pool pool begin + v = acquire!(pool, Float64, 10) + q = quote + @goto somewhere # just AST data, harmless + end + sum(v) + end + @test expr isa Expr # no error — quoted @goto is ignored + end + end + + # ============================================================================== + # Exception edge cases (deferred recovery) + # ============================================================================== + + @testset "Exception edge cases" begin + # Multi-level 
nested throw: 2 inner scopes leak, outer catches + @testset "Multi-level nested leak recovery" begin + reset!(get_task_local_pool()) + @with_pool pool function multi_level_leak() + v = acquire!(pool, Float64, 10) + v .= 1.0 + result = try + @with_pool pool begin + acquire!(pool, UInt8, 5) + @with_pool pool begin + acquire!(pool, Int32, 3) + error("deep boom") # 2 inner scopes leak + end + end + catch + 99 + end + sum(v) + result + end + + @test multi_level_leak() == 109.0 # 10.0 + 99 + @test get_task_local_pool()._current_depth == 1 + end + + # Multi-type cross-scope throw: inner uses different types than outer + @testset "Cross-type throw recovery" begin + reset!(get_task_local_pool()) + @with_pool pool function cross_type_throw() + v = acquire!(pool, Float64, 10) + v .= 2.0 + result = try + @with_pool pool begin + w = acquire!(pool, Int64, 5) # different type from outer + w .= 1 + error("type mismatch boom") + end + catch + 0 + end + sum(v) + result + end + + @test cross_type_throw() == 20.0 # sum(v)=20 + 0 + pool = get_task_local_pool() + @test pool._current_depth == 1 + @test pool.float64.n_active == 0 + end + + # Uncaught exception → pool state is corrupted (documented limitation) + @testset "Uncaught exception corrupts pool (documented)" begin + reset!(get_task_local_pool()) + try + @with_pool pool begin + acquire!(pool, Float64, 10) + error("uncaught!") + end + catch + end + # Without try-finally, rewind! was never called + pool = get_task_local_pool() + @test pool._current_depth > 1 # corrupted — this is expected behavior + + # reset! 
recovers + reset!(pool) + @test pool._current_depth == 1 + end + + # @safe_with_pool handles uncaught exception correctly + @testset "@safe_with_pool handles uncaught exception" begin + reset!(get_task_local_pool()) + try + @safe_with_pool pool begin + acquire!(pool, Float64, 10) + error("caught by safe!") + end + catch + end + # try-finally guarantees cleanup + @test get_task_local_pool()._current_depth == 1 + end + end + + # ============================================================================== + # Leaked scope warning (RUNTIME_CHECK-gated) + # ============================================================================== + + @testset "Leaked scope warning" begin + import AdaptiveArrayPools: _warn_leaked_scope, _runtime_check + + # 1. Macro expansion includes _warn_leaked_scope call + @testset "Warning present in macro expansion" begin + expr = @macroexpand @with_pool pool begin + v = acquire!(pool, Float64, 10) + sum(v) + end + expr_str = string(expr) + @test occursin("_warn_leaked_scope", expr_str) + end + + # 2. Warning is gated by _runtime_check (returns false at RUNTIME_CHECK=0) + @testset "Warning gated by _runtime_check" begin + pool_s0 = AdaptiveArrayPool{0}() + @test _runtime_check(pool_s0) == false # guard is false → warning never fires + + pool_s1 = AdaptiveArrayPool{1}() + @test _runtime_check(pool_s1) == true # guard is true → warning can fire + end + + # 3. 
Warning fires on RUNTIME_CHECK=1 pool with simulated leak + @testset "Warning fires on leaked scope (RUNTIME_CHECK=1)" begin + pool = AdaptiveArrayPool{1}() + @test _runtime_check(pool) == true + + # Simulate: checkpoint without matching rewind (leak) + checkpoint!(pool) # depth 1→2 (outer scope) + checkpoint!(pool) # depth 2→3 (inner scope, will "leak") + # skip inner rewind — simulates leaked @with_pool + + entry_depth = 1 # outer scope's entry depth + @test pool._current_depth > entry_depth + 1 # guard condition is true + + # Verify _warn_leaked_scope fires @error log + @test_logs (:error, r"Leaked @with_pool scope") _warn_leaked_scope(pool, entry_depth) + + # Cleanup + reset!(pool) + end + + # 4. Warning does NOT fire when depth is correct + @testset "No warning on normal depth" begin + pool = AdaptiveArrayPool{1}() + checkpoint!(pool) # depth 1→2 + # No leak — depth is entry_depth + 1 + + entry_depth = 1 + @test pool._current_depth == entry_depth + 1 # guard condition is false + # _warn_leaked_scope would NOT be called (the if guard prevents it) + + rewind!(pool) + @test pool._current_depth == 1 + end + end + end # Macro System