Rewrite Triton normalization backward kernel_1 (#499)#546
Open
jlamypoirier wants to merge 3 commits into
Open
Conversation
The backward of `layer_norm`/`rms_norm` trailed apex and torch-compiled by 1.1-1.6x at most hidden sizes, worst on tall-narrow shapes. kernel_1 was the bottleneck and over-produced grad_weight/grad_bias partials. kernel_1: - Decouple the register tile from `n_cols`: a `block_size_row x block_size_col` tile grid-strides the columns, so occupancy no longer collapses as hidden size grows. Rows wider than one chunk use a two-pass scheme (reduce per-row corrections, then re-read to write grad_input and the partials); narrower rows stay single pass. - Bound the partial-reduction work like apex: single pass grid-strides the rows with a program count fixed at `multi_processor_count x 2`, folding many row tiles into one fp32-accumulated partial. The partial buffer kernel_2 reduces is then independent of the row count instead of growing with it (e.g. 4096 -> ~260 rows at 32768x1024), which was the dominant remaining cost. Two waves per SM is the measured knee: one starves grad_input latency-hiding, more only re-inflates kernel_2. Parameter-grad partials reduce in fp32 (the store casts to the buffer dtype); reducing in bf16 degraded the parameter gradients. Result (H100, bf16): tall-narrow shapes go from ~1.3-1.6x behind to parity or better against the fastest alternative (apex_fast / torch-compiled-max), and apex's general fused path is beaten across the board. Wide hidden sizes (two-pass) remain ~1.1-1.3x behind, bounded by the column re-read. Benchmark harness (tools/benchmark/triton_kernels): - Measure backward in isolation with a cold L2 (forward untimed, L2 flushed, then the backward timed), which is training-representative. The prior fwd_bwd-minus-fwd number had a warm-L2 confound: the forward left the saved output partly resident, flattering the backward in a way real training never sees. - Add a per-kernel device-time breakdown so kernel_1 and kernel_2 can be attributed separately. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Offline-sweep the kernel_1 config space (single-pass threshold, block_size_col, block_size_row, num_warps, num_stages) per shape, validated against the prior config, and fold the result into the launch heuristic: - Extend single-pass to wide rows. A wide row can stay single-pass when a warp-saturated one-row tile spanning the whole row fits in registers, which avoids the two-pass column re-read. It fits up to the block-size cap without bias and half of it with bias (bias roughly doubles live registers per element), and wins once there are enough rows to fill the SMs. - Tune the remaining two-pass path (wider column chunk, more warps). - Thread num_stages through the launch. Result (H100, bf16): wide hidden sizes improve up to 1.4x (e.g. rms_norm 4096x8192 121->87us) with no regression elsewhere, by removing the re-read where it can be afforded. The remaining sub-parity shapes are now the narrow ones, where apex's per-hidden-size kernel is hard to match. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Now that the Triton backward is competitive with or faster than apex, drop apex from the `auto` resolution: it picks Triton when Triton is enabled (or required, for zero-centered weights) and falls back to PyTorch otherwise. The apex `fast` and `fused` implementations remain available via explicit selection. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
Closes the backward-pass gap in
layer_norm/rms_norm(issue #499). On H100 the Triton backward trailed apex and torch-compiled by ~1.1–1.6× at most hidden sizes — worst on tall-narrow shapes.kernel_1was both the bottleneck and the source of an oversized partial-reduction buffer forkernel_2.kernel_1 rewrite
n_cols. Ablock_size_row × block_size_coltile grid-strides the columns, so occupancy no longer collapses as the hidden size grows. Rows wider than one chunk use a two-pass scheme (reduce the per-row corrections, then re-read to writegrad_inputand the partials); narrower rows stay single-pass with no re-read.fusedpath and the hand-tunedfastpath) does not avoid a second reduction kernel — it bounds the number of partial rows to a small constant via row grid-striding. Previouslykernel_1emitted one partial row perblock_size_rowinput rows, so the bufferkernel_2reduces grew with the row count (4096 rows at 32768×1024 →kernel_2ran at ~10% of bandwidth). Single-pass now grid-strides the rows with a program count fixed atmulti_processor_count × 2, folding many row tiles into one fp32-accumulated partial. The buffer is then independent of the row count. Two waves per SM is the measured knee — one wave starvesgrad_inputlatency-hiding; more only re-inflateskernel_2.Config tuning
An offline sweep of the
kernel_1config space (single-pass threshold,block_size_col,block_size_row,num_warps,num_stages), validated per shape against the prior config, drives the launch heuristic:num_stagesis threaded through the launch.Parameter-grad partials reduce in fp32 (the store casts to the buffer dtype) — reducing in bf16 measurably degrades the parameter gradients.
Results (H100, bf16)
Backward µs vs. the fastest competitor (
apex_fastfor LN,torch-compiled-maxotherwise). Bold = match-or-beat. Triton match-or-beats the best alternative on 9/15 LN and 10/15 RMS shapes; apex's general fused path is beaten everywhere (1.5–3× slower on backward).The wide hidden sizes improved up to 1.4× from the config tuning (e.g. rms_norm 4096×8192 121→87 µs) by removing the re-read where it can be afforded, with no regression elsewhere.
kernel_2is no longer a factor (2–5 µs on single-pass shapes, was up to 47 µs).Remaining sub-parity shapes are a mix of (a) narrow shapes (n_cols ≤ 4096) where
apex_fast/compiled are very tight, and (b)layer_normat the widest hidden size (16384), where the bias term spills the wide single-pass tile so the kernel must fall back to the two-pass re-read. Closing (b) would need a bias-aware shared-memory single-pass; it is the natural follow-up.Forward is at parity across implementations and is unchanged.
Default implementation
Now that the Triton backward is competitive with or faster than apex,
NormalizationConfig.implementation = autoresolves to Triton when Triton is enabled (or required, for zero-centered weights) and falls back to PyTorch otherwise — apex is dropped from theautopath. The apexfast/fusedimplementations stay available via explicit selection.Benchmark harness
tools/benchmark/triton_kernels:fwd_bwd − fwdnumber had a warm-L2 confound: the forward left the savedoutputpartly resident in L2, flattering the backward in a way real training never sees.kernel_1andkernel_2can be attributed separately.Validation
tests/layers/andtests/tools/test_triton_benchmark.py: 733 passed, 27 skipped (H100). Parameter-grad precision is bit-equivalent to the previous kernel (grad_weightrel-rms ≈ 2.8–2.9e-3).Authored by Claude Opus 4.8 (Claude Code).
🤖 Generated with Claude Code