[Feat]:Support DPace#1724
Conversation
|
Auto-sync is disabled for draft pull requests in this repository. Workflows must be run manually. Contributors can view more details about this message here. |
|
Note Reviews pausedIt looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the Use the following commands to manage reviews:
Use the checkboxes below for quick actions:
📝 WalkthroughWalkthroughAdds D-PACE (Dynamic Position-Aware Cross-Entropy) as a new DFlash training loss objective. Two new ChangesD-PACE Loss Objective for DFlash
Estimated code review effort🎯 3 (Moderate) | ⏱️ ~25 minutes Suggested reviewers
🚥 Pre-merge checks | ✅ 5 | ❌ 1❌ Failed checks (1 inconclusive)
✅ Passed checks (5 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings. ✨ Finishing Touches🧪 Generate unit tests (beta)
Comment |
|
Codecov Report❌ Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## main #1724 +/- ##
===========================================
- Coverage 77.12% 57.58% -19.55%
===========================================
Files 511 510 -1
Lines 56247 57433 +1186
===========================================
- Hits 43381 33070 -10311
- Misses 12866 24363 +11497
Flags with carried forward coverage won't be shown. Click here to find out more. ☔ View full report in Codecov by Harness. 🚀 New features to boost your workflow:
|
|
/claude review |
|
/claude review |
1 similar comment
|
/claude review |
There was a problem hiding this comment.
Claude review passed — no blocking issues found.
Summary (CRITICAL: 0, IMPORTANT: 0, SUGGESTION: 3)
The D-PACE implementation is correct and well-scoped:
- Algorithm matches paper Eq.7-8: smoothing
q~_i = (1-α)q_i + α, prefix-product, suffix-sum (reverse-cumsum-reverse), all underno_gradand explicitly detached. - Opt-in (default
dflash_loss_objective='decay'preserves prior behavior) — no backward-compat concern. - Validation lives in
DFlashModel.modify()so bad configs fail at convert-time, not deep in the training loop. - Tests cover formula correctness, detachment, monotonicity, smoothing floor, error paths, and
mtsp.convertwiring. - Position-0 (anchor) is correctly excluded from D-PACE weights via
[..., 1:]slicing. - No mode-state schema or export-path changes — purely a training-loss feature.
Inline suggestions (non-blocking):
- CE on the predicted positions is computed twice in the non-KD path (once for
confidences, once for the loss). Reusing one computation would close most of the documented ~2.3% overhead. _dpace_position_weightsacceptsalpha=0(silently zero weights) while its docstring claims(0, 1]— tighten the runtime check to match.- When
dflash_loss_objective='dpace'anddflash_loss_decay_factoris non-default, the decay value is silently ignored. Alogger.warningwould surface the misconfiguration since the default recipe already setsdecay_factor: 4.0.
LGTM.
|
/claude review |
There was a problem hiding this comment.
Claude review summary
Counts: CRITICAL: 0, IMPORTANT: 0, SUGGESTION: 3
Scope verified
- Algorithm correctness: D-PACE Eq.7 smoothing (
q~ = (1-α)q + α) and Eq.8 weighting (w_j = Σ_{m≥j} ∏_{i≤m} q~) are implemented correctly viaflip(cumsum(flip(cumprod(...)))). Hand-checked: withconf=[0.8, 0.5],α=0.5, the helper returns[1.575, 0.675]as expected.q_i = exp(-CE)is the standard differentiable confidence surrogate. The[..., 1:]slice correctly excludes the anchor at slot 0, and the dpace weights are carved intodpace[..., 1:]so slot 0 retainsweight_masksemantics.block_size > 1guard prevents a degenerate empty cumprod. - Gradient flow: weights are computed under
torch.no_grad()and detached at the helper boundary; the CE used for confidences is.detach()-ed when shared with the loss path. No spurious gradient through the weights (matches paper Eq.9). - fp32 promotion: smoothing and cumprod run in
float()before being cast back, which is the right call for a chain ofblock_size-deep multiplications. - CE reuse: the new
loss_per_token = F.cross_entropy(...)hoist saves one CE pass on the non-KD path. Correctly skipped whenbase_logits is not None(KD path) where a separateno_gradCE is used for confidences only. - Mode/state and backward compat: new fields are vanilla Pydantic fields on
DFlashConfig; oldmodelopt_statecheckpoints round-trip cleanly because Pydantic fills defaults for the missing keys, anddflash_loss_objective='decay'reproduces prior behaviour byte-for-byte.convert_to_dflash_model/restore_dflash_modelalready share a single path, so the new attributes flow through both. - Mutual exclusion: warn-then-ignore (rather than hard error) when
dflash_loss_decay_factor > 0is set alongsidedpaceis reasonable and documented.
Findings (all SUGGESTION, none blocking)
_compute_loss— confidences at loss-masked slots (notablyorig_loss_mask=0on user-token positions underanswer_only_loss=true) still flow into the dpacecumprod, slightly skewing the weights at neighbouring assistant slots. Mask the smoothed confidences before the cumprod to keep the dynamic weighting computed only over slots whose loss actually contributes.test_weights_match_paper_formula— tautological: it computes the expected output with the same expression as the implementation. Replace with a hand-computed example (e.g.[0.8, 0.5]→[1.575, 0.675]) so a sign-flip orcumprod/cumsumswap would actually fail.DFlashModel.modify— alpha range check is gated ondflash_loss_objective == "dpace", so a badalphaset under the defaultdecayobjective passes silently and only bites if the objective is later flipped. Move the range check into a Pydanticmodel_validatoronDFlashConfig(or unconditional inmodify) so the value is validated at config-construction time.
Risk assessment
Low risk. Opt-in feature, default unchanged, weights are detached, save/restore is backward-compatible, and the algorithm matches the paper's closed form. The training curve in the PR description supports correctness in practice. The three suggestions are quality polish; none block approval.
Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com>
Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com>
Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com>
|
/claude review |
There was a problem hiding this comment.
Claude review passed — no blocking issues found. LGTM
Counts: CRITICAL: 0, IMPORTANT: 0, SUGGESTION: 0
Re-reviewed at HEAD after "address comments". All three suggestions from the prior detailed review are resolved:
- Loss-masked confidences leaking into the cumprod — fixed.
_compute_lossnow passesvalid_mask=weight_mask[..., 1:]into_dpace_position_weights; invalid slots are multiplicative no-ops in the cumprod and are zeroed out of the suffix sum.test_mask_makes_invalid_positions_noopsconfirms equivalence to dropping the slot entirely. - Tautological formula test — fixed.
test_weights_match_hand_computedpins the helper to an independent hand-worked example ([0.8, 0.5], α=0.5 →[1.575, 0.675]). - alpha range check gated on objective — fixed. Validation moved to an unconditional Pydantic
model_validatoronDFlashConfig, so a bad alpha is rejected at config construction regardless of the active objective.
Verified this round:
- Algorithm: Eq.7 smoothing
q~ = (1-α)q + αand Eq.8 weightingw_j = Σ_{m≥j} ∏_{i≤m} q~_i, implemented astotal - exclusive_prefix_sum + C_jover acumprod/cumsum— math checks out, hand-verified against the test oracle. - Gradient flow: weights computed under
no_gradand detached (Eq.9); shared CE is.detach()-ed when reused for confidences. fp32 promotion on the cumprod chain. - Mode/state & backward compat: new fields are plain Pydantic fields; old
modelopt_statecheckpoints round-trip via defaults. The default behavior change (decay → dpace) is intentional and documented in the PR body + CHANGELOG. - CE reuse: non-KD path computes per-token CE once and reuses it; KD path keeps a dedicated
no_gradCE for confidences only.
Risk: Low. Self-contained training-loss feature, weights detached, save/restore backward-compatible, algorithm matches the paper's closed form.
What does this PR do?
Type of change: New feature
Adds the D-PACE (Dynamic Position-Aware Cross-Entropy) loss objective for DFlash speculative-decoding training (arXiv:2605.18810). It replaces the static exponential position decay with per-position CE weights derived from the draft's own confidence
q_i = exp(-CE_i): smoothedq̃_i = (1-α)q_i + α(Eq.7) and weighted by the suffix-sum of prefix productsw_j = Σ_{m≥j} ∏_{i≤m} q̃_i(Eq.8), which directly targets expected accepted block length and shifts signal toward whichever positions currently limit acceptance.Selected via
dflash_loss_objective— D-PACE is now the default (dpace); setdflash_loss_objective: decayto restore the previous static schedule. Smoothing viadflash_dpace_alpha(default 0.5). Weights are detached from the gradient — training-only, ~2.3% overhead, no architecture or inference change. Mutually exclusive withdflash_loss_decay_factor.Usage
Testing
CPU unit tests in
tests/unit/torch/speculative/plugins/test_hf_dflash.py: weights match the paper closed form, are detached and non-increasing, the α smoothing floor keeps later weights non-zero, and convert wires/validates the new fields (rejects bad objective and degenerate α). Training validated on Qwen3-8B (curve below).Before your PR is "Ready for review"
dflash_loss_objective=decay(which reproduces the previous static-decay behavior).CONTRIBUTING.md: N/A (no new dependency)Additional Information
Reference: D-PACE, arXiv:2605.18810. See
examples/speculative_decoding/doc/dflash.mdfor the math and tuning notes.Summary by CodeRabbit
Release Notes
dflash_loss_objective: dpace), configurable viadflash_dpace_alpha(default0.5).dflash_loss_decay_factoris ignored with D-PACE.