Skip to content

[Feat]:Support DPace#1724

Open
h-guo18 wants to merge 3 commits into
mainfrom
haoguo/dpace
Open

[Feat]:Support DPace#1724
h-guo18 wants to merge 3 commits into
mainfrom
haoguo/dpace

Conversation

@h-guo18

@h-guo18 h-guo18 commented Jun 15, 2026

Copy link
Copy Markdown
Contributor

What does this PR do?

Type of change: New feature

Adds the D-PACE (Dynamic Position-Aware Cross-Entropy) loss objective for DFlash speculative-decoding training (arXiv:2605.18810). It replaces the static exponential position decay with per-position CE weights derived from the draft's own confidence q_i = exp(-CE_i): smoothed q̃_i = (1-α)q_i + α (Eq.7) and weighted by the suffix-sum of prefix products w_j = Σ_{m≥j} ∏_{i≤m} q̃_i (Eq.8), which directly targets expected accepted block length and shifts signal toward whichever positions currently limit acceptance.

Selected via dflash_loss_objectiveD-PACE is now the default (dpace); set dflash_loss_objective: decay to restore the previous static schedule. Smoothing via dflash_dpace_alpha (default 0.5). Weights are detached from the gradient — training-only, ~2.3% overhead, no architecture or inference change. Mutually exclusive with dflash_loss_decay_factor.

Usage

# DFlash recipe / training config
dflash:
  dflash_loss_objective: dpace   # default: decay
  dflash_dpace_alpha: 0.5        # smoothing in (0, 1]; stable in [0.3, 0.7]

Testing

CPU unit tests in tests/unit/torch/speculative/plugins/test_hf_dflash.py: weights match the paper closed form, are detached and non-increasing, the α smoothing floor keeps later weights non-zero, and convert wires/validates the new fields (rejects bad objective and degenerate α). Training validated on Qwen3-8B (curve below).

image

Before your PR is "Ready for review"

  • Is this change backward compatible?: ⚠️ Behavior change — D-PACE is now the default objective, so DFlash training loss weighting changes unless you set dflash_loss_objective=decay (which reproduces the previous static-decay behavior).
  • If you copied code from any other sources or added a new PIP dependency, did you follow guidance in CONTRIBUTING.md: N/A (no new dependency)
  • Did you write any new necessary tests?: ✅
  • Did you update Changelog?: ✅
  • Did you get Claude approval on this PR?: ❌

Additional Information

Reference: D-PACE, arXiv:2605.18810. See examples/speculative_decoding/doc/dflash.md for the math and tuning notes.

Summary by CodeRabbit

Release Notes

  • New Features
    • Added a D-PACE training loss objective for DFlash speculative decoding (dflash_loss_objective: dpace), configurable via dflash_dpace_alpha (default 0.5).
  • Documentation
    • Documented D-PACE’s confidence-derived, dynamically weighted per-position loss behavior (training-only) and noted that dflash_loss_decay_factor is ignored with D-PACE.
  • Bug Fixes
    • Updated DFlash loss to reuse the precomputed per-token cross-entropy in the non-KD path.
  • Tests
    • Added unit tests for D-PACE weight correctness, masking, gradient detachment, monotonicity, smoothing, and config/validation behavior.

@copy-pr-bot

copy-pr-bot Bot commented Jun 15, 2026

Copy link
Copy Markdown

Auto-sync is disabled for draft pull requests in this repository. Workflows must be run manually.

Contributors can view more details about this message here.

@coderabbitai

coderabbitai Bot commented Jun 15, 2026

Copy link
Copy Markdown
Contributor

Review Change Stack

Note

Reviews paused

It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the reviews.auto_review.auto_pause_after_reviewed_commits setting.

Use the following commands to manage reviews:

  • @coderabbitai resume to resume automatic reviews.
  • @coderabbitai review to trigger a single review.

Use the checkboxes below for quick actions:

  • ▶️ Resume reviews
  • 🔍 Trigger review
📝 Walkthrough

Walkthrough

Adds D-PACE (Dynamic Position-Aware Cross-Entropy) as a new DFlash training loss objective. Two new DFlashConfig fields (dflash_loss_objective, dflash_dpace_alpha) gate the feature. A new helper _dpace_position_weights computes detached per-position weights; _compute_loss applies them when "dpace" is selected. Input validation is added in DFlashModel.modify(). Tests, documentation, and a changelog entry are included.

Changes

D-PACE Loss Objective for DFlash

Layer / File(s) Summary
DFlashConfig new fields
modelopt/torch/speculative/config.py
Adds dflash_loss_objective (str, default "dpace") and dflash_dpace_alpha (float, default 0.5) to DFlashConfig; extends dflash_loss_decay_factor docstring to note it is only used with the "decay" objective. Imports Literal for loss objective typing.
Validation in DFlashModel.modify()
modelopt/torch/speculative/dflash/dflash_model.py
Adds module-level logging. DFlashModel.modify() reads both new config fields, validates dflash_dpace_alpha range (0, 1] when objective is "dpace", raises ValueError for invalid values, and logs a warning when dflash_loss_decay_factor > 0 is set under the "dpace" objective.
D-PACE position weight computation
modelopt/torch/speculative/plugins/hf_dflash.py
Introduces _dpace_position_weights helper: validates alpha in (0, 1], computes detached per-position weights from draft confidences via asymmetric smoothing, cumulative products, and suffix-sum using cumsum-based formulation. Supports optional valid_mask to zero invalid positions.
Loss computation integration
modelopt/torch/speculative/plugins/hf_dflash.py
_compute_loss reorganizes to compute per-token CE once when base_logits is None. Adds a conditional branch for dflash_loss_objective == "dpace" and block_size > 1: derives draft confidences via exp(-CE) for non-anchor positions, calls _dpace_position_weights under no_grad, and multiplies detached weights into weight_mask. Falls back to exponential decay for "decay". Non-KD loss path reuses precomputed loss_per_token.
Tests, documentation, and changelog
tests/unit/torch/speculative/plugins/test_hf_dflash.py, examples/speculative_decoding/doc/dflash.md, CHANGELOG.rst
TestDPaceWeights class covers formula equivalence, gradient detachment, monotonic non-increasing ordering, valid_mask semantics, smoothing behavior, invalid-alpha errors, and mtsp.convert wiring. Documentation adds parameter table entries and a D-PACE subsection with weight derivation, smoothing equations, detached credit assignment, and dflash_loss_decay_factor interaction. Changelog records the new feature.

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~25 minutes

Suggested reviewers

  • chadvoegele
  • ChenhanYu
🚥 Pre-merge checks | ✅ 5 | ❌ 1

❌ Failed checks (1 inconclusive)

Check name Status Explanation Resolution
Title check ❓ Inconclusive The title '[Feat]:Support DPace' is vague and generic, using a non-descriptive abbreviation without conveying what D-PACE is or its purpose. Improve the title to be more descriptive, such as 'Add D-PACE dynamic loss objective for DFlash speculative decoding' or 'Implement D-PACE weighted loss for improved DFlash training'.
✅ Passed checks (5 passed)
Check name Status Explanation
Description Check ✅ Passed Check skipped - CodeRabbit’s high-level summary is enabled.
Docstring Coverage ✅ Passed Docstring coverage is 80.00% which is sufficient. The required threshold is 80.00%.
Linked Issues check ✅ Passed Check skipped because no linked issues were found for this pull request.
Out of Scope Changes check ✅ Passed Check skipped because no linked issues were found for this pull request.
Security Anti-Patterns ✅ Passed No security anti-patterns detected. PR adds D-PACE loss objective for DFlash training with safe PyTorch operations, proper input validation, no unsafe deserialization, and no new dependencies.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.

✨ Finishing Touches
🧪 Generate unit tests (beta)
  • Create PR with unit tests
  • Commit unit tests in branch haoguo/dpace

Comment @coderabbitai help to get the list of available commands and usage tips.

@github-actions

github-actions Bot commented Jun 15, 2026

Copy link
Copy Markdown
Contributor
PR Preview Action v1.8.1

QR code for preview link

🚀 View preview at
https://NVIDIA.github.io/Model-Optimizer/pr-preview/pr-1724/

Built to branch gh-pages at 2026-06-19 22:33 UTC.
Preview will be ready when the GitHub Pages deployment is complete.

@codecov

codecov Bot commented Jun 15, 2026

Copy link
Copy Markdown

Codecov Report

❌ Patch coverage is 65.11628% with 15 lines in your changes missing coverage. Please review.
✅ Project coverage is 57.58%. Comparing base (9f6e8fd) to head (1a79fb7).
⚠️ Report is 29 commits behind head on main.

Files with missing lines Patch % Lines
modelopt/torch/speculative/plugins/hf_dflash.py 57.57% 14 Missing ⚠️
modelopt/torch/speculative/dflash/dflash_model.py 87.50% 1 Missing ⚠️
Additional details and impacted files
@@             Coverage Diff             @@
##             main    #1724       +/-   ##
===========================================
- Coverage   77.12%   57.58%   -19.55%     
===========================================
  Files         511      510        -1     
  Lines       56247    57433     +1186     
===========================================
- Hits        43381    33070    -10311     
- Misses      12866    24363    +11497     
Flag Coverage Δ
examples 21.05% <13.88%> (-20.91%) ⬇️
gpu 20.59% <13.88%> (-37.78%) ⬇️
regression 14.71% <44.44%> (+0.08%) ⬆️
unit 54.35% <59.45%> (-0.06%) ⬇️

Flags with carried forward coverage won't be shown. Click here to find out more.

☔ View full report in Codecov by Harness.
📢 Have feedback on the report? Share it here.

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

@h-guo18 h-guo18 marked this pull request as ready for review June 15, 2026 18:59
@h-guo18 h-guo18 requested a review from a team as a code owner June 15, 2026 18:59
@h-guo18 h-guo18 requested a review from yeyu-nvidia June 15, 2026 18:59
@h-guo18

h-guo18 commented Jun 15, 2026

Copy link
Copy Markdown
Contributor Author

/claude review

@ChenhanYu

Copy link
Copy Markdown
Collaborator

/claude review

1 similar comment
@h-guo18

h-guo18 commented Jun 16, 2026

Copy link
Copy Markdown
Contributor Author

/claude review

Comment thread modelopt/torch/speculative/plugins/hf_dflash.py
Comment thread modelopt/torch/speculative/plugins/hf_dflash.py Outdated
Comment thread modelopt/torch/speculative/dflash/dflash_model.py Outdated

@claude claude Bot left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Claude review passed — no blocking issues found.

Summary (CRITICAL: 0, IMPORTANT: 0, SUGGESTION: 3)

The D-PACE implementation is correct and well-scoped:

  • Algorithm matches paper Eq.7-8: smoothing q~_i = (1-α)q_i + α, prefix-product, suffix-sum (reverse-cumsum-reverse), all under no_grad and explicitly detached.
  • Opt-in (default dflash_loss_objective='decay' preserves prior behavior) — no backward-compat concern.
  • Validation lives in DFlashModel.modify() so bad configs fail at convert-time, not deep in the training loop.
  • Tests cover formula correctness, detachment, monotonicity, smoothing floor, error paths, and mtsp.convert wiring.
  • Position-0 (anchor) is correctly excluded from D-PACE weights via [..., 1:] slicing.
  • No mode-state schema or export-path changes — purely a training-loss feature.

Inline suggestions (non-blocking):

  1. CE on the predicted positions is computed twice in the non-KD path (once for confidences, once for the loss). Reusing one computation would close most of the documented ~2.3% overhead.
  2. _dpace_position_weights accepts alpha=0 (silently zero weights) while its docstring claims (0, 1] — tighten the runtime check to match.
  3. When dflash_loss_objective='dpace' and dflash_loss_decay_factor is non-default, the decay value is silently ignored. A logger.warning would surface the misconfiguration since the default recipe already sets decay_factor: 4.0.

LGTM.

@h-guo18 h-guo18 requested a review from a team as a code owner June 16, 2026 18:15
@h-guo18

h-guo18 commented Jun 16, 2026

Copy link
Copy Markdown
Contributor Author

/claude review

@claude claude Bot left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Claude review summary

Counts: CRITICAL: 0, IMPORTANT: 0, SUGGESTION: 3

Scope verified

  • Algorithm correctness: D-PACE Eq.7 smoothing (q~ = (1-α)q + α) and Eq.8 weighting (w_j = Σ_{m≥j} ∏_{i≤m} q~) are implemented correctly via flip(cumsum(flip(cumprod(...)))). Hand-checked: with conf=[0.8, 0.5], α=0.5, the helper returns [1.575, 0.675] as expected. q_i = exp(-CE) is the standard differentiable confidence surrogate. The [..., 1:] slice correctly excludes the anchor at slot 0, and the dpace weights are carved into dpace[..., 1:] so slot 0 retains weight_mask semantics. block_size > 1 guard prevents a degenerate empty cumprod.
  • Gradient flow: weights are computed under torch.no_grad() and detached at the helper boundary; the CE used for confidences is .detach()-ed when shared with the loss path. No spurious gradient through the weights (matches paper Eq.9).
  • fp32 promotion: smoothing and cumprod run in float() before being cast back, which is the right call for a chain of block_size-deep multiplications.
  • CE reuse: the new loss_per_token = F.cross_entropy(...) hoist saves one CE pass on the non-KD path. Correctly skipped when base_logits is not None (KD path) where a separate no_grad CE is used for confidences only.
  • Mode/state and backward compat: new fields are vanilla Pydantic fields on DFlashConfig; old modelopt_state checkpoints round-trip cleanly because Pydantic fills defaults for the missing keys, and dflash_loss_objective='decay' reproduces prior behaviour byte-for-byte. convert_to_dflash_model / restore_dflash_model already share a single path, so the new attributes flow through both.
  • Mutual exclusion: warn-then-ignore (rather than hard error) when dflash_loss_decay_factor > 0 is set alongside dpace is reasonable and documented.

Findings (all SUGGESTION, none blocking)

  1. _compute_loss — confidences at loss-masked slots (notably orig_loss_mask=0 on user-token positions under answer_only_loss=true) still flow into the dpace cumprod, slightly skewing the weights at neighbouring assistant slots. Mask the smoothed confidences before the cumprod to keep the dynamic weighting computed only over slots whose loss actually contributes.
  2. test_weights_match_paper_formula — tautological: it computes the expected output with the same expression as the implementation. Replace with a hand-computed example (e.g. [0.8, 0.5][1.575, 0.675]) so a sign-flip or cumprod/cumsum swap would actually fail.
  3. DFlashModel.modify — alpha range check is gated on dflash_loss_objective == "dpace", so a bad alpha set under the default decay objective passes silently and only bites if the objective is later flipped. Move the range check into a Pydantic model_validator on DFlashConfig (or unconditional in modify) so the value is validated at config-construction time.

Risk assessment

Low risk. Opt-in feature, default unchanged, weights are detached, save/restore is backward-compatible, and the algorithm matches the paper's closed form. The training curve in the PR description supports correctness in practice. The three suggestions are quality polish; none block approval.

@claude claude Bot left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Claude review passed — no blocking issues found. LGTM

Comment thread modelopt/torch/speculative/plugins/hf_dflash.py
Comment thread tests/unit/torch/speculative/plugins/test_hf_dflash.py Outdated
Comment thread modelopt/torch/speculative/dflash/dflash_model.py Outdated
h-guo18 added 2 commits June 19, 2026 22:19
Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com>
Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com>
Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com>
@h-guo18 h-guo18 requested a review from a team as a code owner June 19, 2026 22:30
@h-guo18 h-guo18 requested a review from cjluo-nv June 19, 2026 22:30
@h-guo18

h-guo18 commented Jun 19, 2026

Copy link
Copy Markdown
Contributor Author

/claude review

@claude claude Bot left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Claude review passed — no blocking issues found. LGTM

Counts: CRITICAL: 0, IMPORTANT: 0, SUGGESTION: 0

Re-reviewed at HEAD after "address comments". All three suggestions from the prior detailed review are resolved:

  1. Loss-masked confidences leaking into the cumprod — fixed. _compute_loss now passes valid_mask=weight_mask[..., 1:] into _dpace_position_weights; invalid slots are multiplicative no-ops in the cumprod and are zeroed out of the suffix sum. test_mask_makes_invalid_positions_noops confirms equivalence to dropping the slot entirely.
  2. Tautological formula test — fixed. test_weights_match_hand_computed pins the helper to an independent hand-worked example ([0.8, 0.5], α=0.5 → [1.575, 0.675]).
  3. alpha range check gated on objective — fixed. Validation moved to an unconditional Pydantic model_validator on DFlashConfig, so a bad alpha is rejected at config construction regardless of the active objective.

Verified this round:

  • Algorithm: Eq.7 smoothing q~ = (1-α)q + α and Eq.8 weighting w_j = Σ_{m≥j} ∏_{i≤m} q~_i, implemented as total - exclusive_prefix_sum + C_j over a cumprod/cumsum — math checks out, hand-verified against the test oracle.
  • Gradient flow: weights computed under no_grad and detached (Eq.9); shared CE is .detach()-ed when reused for confidences. fp32 promotion on the cumprod chain.
  • Mode/state & backward compat: new fields are plain Pydantic fields; old modelopt_state checkpoints round-trip via defaults. The default behavior change (decay → dpace) is intentional and documented in the PR body + CHANGELOG.
  • CE reuse: non-KD path computes per-token CE once and reuses it; KD path keeps a dedicated no_grad CE for confidences only.

Risk: Low. Self-contained training-loss feature, weights detached, save/restore backward-compatible, algorithm matches the paper's closed form.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants