Merged
1 change: 1 addition & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ Changelog

**New Features**

- Add the **D-PACE** loss objective for DFlash speculative-decoding training (`arXiv:2605.18810 <https://arxiv.org/abs/2605.18810>`_) and make it the default (``dflash_loss_objective: dpace``). It replaces the static exponential position decay with dynamic, confidence-derived per-position weights that adapt to whichever block positions currently limit acceptance. Smoothing is controlled by ``dflash_dpace_alpha`` (default 0.5); set ``dflash_loss_objective: decay`` to restore the previous static schedule. The weights are detached from the gradient and affect training only (no architecture or inference change).
- Add the ``day0-release`` agent skill (``.agents/skills/day0-release/``), a deterministic end-to-end driver that chains the PTQ → evaluation → comparison skills (the evaluation stage deploys the checkpoint itself) with an enforced gate after each stage and returns a publish decision (ACCEPT / REGRESSION / ANOMALOUS / INFEASIBLE). Ships three GPU-free, unit-tested gate scripts (``gate_ptq.py``, ``gate_run.py``, ``gate_compare.py``) that validate checkpoint coverage, evaluation-run completeness, and baseline-vs-candidate accuracy threshold. v1 reports and stops on regression; the recipe-search loop is deferred.
- Add **streaming** speculative-decoding training (EAGLE3 / DFlash): the draft trains on base-model hidden states produced on the fly by a co-located ``vllm serve`` (no disk dump), moved trainer-side over NIXL RDMA, scaling to multi-node (dedicated serve replicas + DDP trainers). New launcher examples for NVFP4 Kimi-K2.5 / K2.6 on GB200/aarch64 under ``tools/launcher/examples/moonshotai/``.
- Add a fused Triton fast path for ``local_hessian`` NVFP4 weight-scale search (the Hessian-weighted FP8-E4M3 scale sweep). For each NVFP4 block it minimizes ``dwᵀ H dw`` over the 126 candidate scales using the per-cin-block local Hessian on tensor cores, replacing the per-weight Python reference sweep — roughly **34x** faster on a single 8192x4096 weight and bit-exact with the reference for fp32/fp16 weights. Used automatically during ``local_hessian`` calibration for both dense and fused-MoE expert weights; falls back to the reference sweep on CPU, when Triton is unavailable, or via ``MODELOPT_NVFP4_TRITON_SWEEP=0``.
Expand Down
31 changes: 31 additions & 0 deletions examples/speculative_decoding/doc/dflash.md
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,8 @@ See [`modelopt_recipes/general/speculative_decoding/dflash.yaml`](../../../model
| `dflash.dflash_block_size` | 8 | Block size for parallel prediction |
| `dflash.dflash_num_anchors` | 512 | Random anchor positions per sample (see below) |
| `dflash.dflash_loss_decay_factor` | 4.0 | Exponential decay gamma (0 disables, see below) |
| `dflash.dflash_loss_objective` | `dpace` | Position weighting: `decay` (static) or `dpace` (dynamic, see below) |
| `dflash.dflash_dpace_alpha` | 0.5 | D-PACE smoothing factor in (0, 1]; only used when objective is `dpace` |
| `dflash.dflash_self_logit_distillation` | true | Use target model logits as soft labels (vs hard CE) |
| `dflash.dflash_mask_token_id` | auto | Token ID for masked positions (see note below) |
| `dflash.dflash_architecture_config.num_hidden_layers` | 5 | Draft decoder layers |
Expand Down Expand Up @@ -244,6 +246,35 @@ Note: this is different from EAGLE3's `eagle_loss_decay_factor` which multiplies
`alpha^step` across TTT steps. DFlash decay operates within a single block, weighting
early positions higher because they gate acceptance of all later positions.
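
As a rough illustration of the static schedule, the sketch below evaluates the per-position weights `exp(-max(k - 1, 0) / gamma)` that the `decay` branch applies inside a block, using the values from the table above (block size 8, gamma 4.0). The helper name `static_decay_weights` is hypothetical and not part of the library; in the real loss, slot 0 (the anchor) is additionally masked out.

```python
import math


def static_decay_weights(block_size: int, gamma: float) -> list[float]:
    """Static weights exp(-max(k - 1, 0) / gamma) for k in [0, block_size)."""
    if gamma <= 0:
        # Assumption for this sketch: reject gamma <= 0 (the real config treats 0 as "disabled").
        raise ValueError("gamma must be positive")
    return [math.exp(-max(k - 1, 0) / gamma) for k in range(block_size)]


weights = static_decay_weights(block_size=8, gamma=4.0)
print([round(w, 3) for w in weights])
```

The schedule depends only on the position index, which is the property D-PACE relaxes.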

### D-PACE (Dynamic Position-Aware Cross-Entropy)

**D-PACE** ([arXiv:2605.18810](https://arxiv.org/abs/2605.18810)) is the default position-weighting
objective (`dflash_loss_objective: dpace`). It derives per-position weights from a differentiable
surrogate of expected accepted block length. Where the static decay above uses a fixed schedule,
D-PACE adapts to the draft's own per-position confidence and shifts training signal toward
whichever positions currently limit acceptance as the drafter improves. Set
`dflash_loss_objective: decay` to fall back to the static schedule.

For each block, let `q_i = exp(-CE_i)` be the draft confidence on the target token at
predicted position `i`. D-PACE smooths it (Eq.7) and weights each position by the suffix-sum
of prefix products (Eq.8):

```text
q~_i = (1 - alpha) * q_i + alpha
w_j = sum_{m >= j} prod_{i <= m} q~_i # detached; multiplies the per-token CE
```

The weight factors into the prefix-acceptance probability (`prod_{i<=j} q~_i`) times the
remaining accepted-length value, so it directly targets expected accepted length. The
weights are detached from the gradient, so D-PACE only reshapes credit assignment. It adds
~2.3% training overhead and leaves the draft architecture and inference unchanged.

- `dflash_dpace_alpha` is the asymmetric smoothing floor (`q~_i >= alpha`) that keeps later
  weights from vanishing. Stable in `[0.3, 0.7]`; `alpha=0` is rejected (cumulative product
  collapses), and `alpha → 1` removes the adaptive signal: every `q~_i` approaches 1, so the
  weights approach a fixed, confidence-independent schedule. Default `0.5`.
- D-PACE is mutually exclusive with `dflash_loss_decay_factor`; when the objective is `dpace`,
  the decay factor is ignored and a warning is logged if it is greater than 0.
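
To make the two formulas concrete, here is a minimal pure-Python sketch for a single block of three predicted positions. It is not the library implementation (which operates on tensors); the function name `dpace_weights` and the toy confidences are assumptions chosen for illustration.

```python
def dpace_weights(confidences, alpha):
    """Suffix sums of prefix products of the smoothed confidences (Eq. 7 and Eq. 8)."""
    if not 0.0 < alpha <= 1.0:
        raise ValueError(f"alpha must be in (0, 1], got {alpha}")
    smoothed = [(1.0 - alpha) * q + alpha for q in confidences]  # Eq. 7
    prefix, running = [], 1.0
    for s in smoothed:
        running *= s
        prefix.append(running)  # C_m = prod_{i<=m} q~_i
    weights, suffix = [], 0.0
    for c in reversed(prefix):
        suffix += c
        weights.append(suffix)  # w_j = sum_{m>=j} C_m, built from the last position backwards
    return weights[::-1]


confident_early = dpace_weights([0.9, 0.5, 0.1], alpha=0.5)
weak_early = dpace_weights([0.1, 0.5, 0.9], alpha=0.5)
static_limit = dpace_weights([0.9, 0.5, 0.1], alpha=1.0)
print(confident_early, weak_early, static_limit)
```

In this toy case the last position receives the same weight in both orderings (the full product is order-independent), while earlier positions are weighted by how much accepted length still depends on them. With `alpha=1.0` every smoothed confidence is exactly 1, so the weights reduce to the count of remaining positions regardless of the confidences.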

### Checkpoint Resume

DFlash supports checkpoint resume transparently. Rotary embeddings are lazily
Expand Down
27 changes: 26 additions & 1 deletion modelopt/torch/speculative/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
"""Configurations for speculative decoding modes."""

from copy import deepcopy
from typing import Literal

from pydantic import model_validator

Expand Down Expand Up @@ -103,7 +104,23 @@ class DFlashConfig(ModeloptBaseConfig):
dflash_loss_decay_factor: float = ModeloptField(
default=0.0,
description="Gamma for exponential loss decay weighting (paper Eq.4). "
"Suggested: 7 for block_size=16, 5 for 10, 4 for 8. 0 disables.",
"Suggested: 7 for block_size=16, 5 for 10, 4 for 8. 0 disables. "
"Only used when dflash_loss_objective='decay'.",
)

    dflash_loss_objective: Literal["decay", "dpace"] = ModeloptField(
        default="dpace",
        description="Block-position loss weighting objective. 'decay' uses the static "
        "exponential decay of dflash_loss_decay_factor (DFlash, arXiv:2602.06036 Eq.4). "
        "'dpace' uses dynamic, confidence-derived per-position weights "
        "(D-PACE, arXiv:2605.18810 Eq.8).",
    )

    dflash_dpace_alpha: float = ModeloptField(
        default=0.5,
        description="D-PACE asymmetric smoothing factor alpha in (0, 1] (paper Eq.7). Used only "
        "when dflash_loss_objective='dpace'. Stable in [0.3, 0.7]; alpha=0 is degenerate "
        "(cumulative product vanishes) and alpha->1 removes the adaptive signal.",
    )

dflash_num_anchors: int = ModeloptField(
Expand Down Expand Up @@ -146,6 +163,14 @@ class DFlashConfig(ModeloptBaseConfig):
),
)

    @model_validator(mode="after")
    def _check_dpace_alpha(self) -> "DFlashConfig":
        # Validate at construction regardless of the active objective, so a bad alpha
        # is rejected even if it only becomes active after a later objective override.
        if not 0.0 < self.dflash_dpace_alpha <= 1.0:
            raise ValueError(f"dflash_dpace_alpha must be in (0, 1], got {self.dflash_dpace_alpha}")
        return self


class MedusaConfig(ModeloptBaseConfig):
"""Medusa config."""
Expand Down
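
The validation behavior in this file can be illustrated without pydantic. The sketch below uses a plain dataclass as a hypothetical stand-in (`DFlashLossSettings` is not a real class in the library) to show the same design choice: the alpha range is checked at construction even when the active objective is `decay`.

```python
from dataclasses import dataclass


@dataclass
class DFlashLossSettings:
    """Hypothetical stand-in for the two new DFlashConfig fields (illustration only)."""

    dflash_loss_objective: str = "dpace"
    dflash_dpace_alpha: float = 0.5

    def __post_init__(self):
        if self.dflash_loss_objective not in ("decay", "dpace"):
            raise ValueError(f"unknown objective: {self.dflash_loss_objective}")
        # Checked regardless of the objective, so a bad alpha cannot become
        # active later when the objective is switched to "dpace".
        if not 0.0 < self.dflash_dpace_alpha <= 1.0:
            raise ValueError(
                f"dflash_dpace_alpha must be in (0, 1], got {self.dflash_dpace_alpha}"
            )


defaults = DFlashLossSettings()
try:
    DFlashLossSettings(dflash_loss_objective="decay", dflash_dpace_alpha=0.0)
    rejected = False
except ValueError:
    rejected = True
print(defaults, rejected)
```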
13 changes: 13 additions & 0 deletions modelopt/torch/speculative/dflash/dflash_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,12 @@

"""DFlash model to support block-wise parallel speculative decoding."""

import logging

from modelopt.torch.opt.dynamic import DynamicModule

logger = logging.getLogger(__name__)


class DFlashModel(DynamicModule):
"""Base DFlash Model."""
Expand All @@ -31,6 +35,15 @@ def modify(self, config):
        self.dflash_block_size = config.dflash_block_size
        self.dflash_freeze_base_model = config.dflash_freeze_base_model
        self.dflash_loss_decay_factor = config.dflash_loss_decay_factor
        self.dflash_loss_objective = config.dflash_loss_objective
        self.dflash_dpace_alpha = config.dflash_dpace_alpha
        # dflash_dpace_alpha range is validated on DFlashConfig at construction time.
        if self.dflash_loss_objective == "dpace" and self.dflash_loss_decay_factor > 0:
            logger.warning(
                "dflash_loss_decay_factor=%s is ignored when dflash_loss_objective='dpace'; "
                "D-PACE derives per-position weights dynamically from draft confidence.",
                self.dflash_loss_decay_factor,
            )
        self.dflash_self_logit_distillation = config.dflash_self_logit_distillation
        self.dflash_num_anchors = config.dflash_num_anchors
        self.dflash_report_acc = config.dflash_report_acc
Expand Down
84 changes: 79 additions & 5 deletions modelopt/torch/speculative/plugins/hf_dflash.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,55 @@
__all__ = ["HFDFlashModel"]


def _dpace_position_weights(
    confidences: torch.Tensor, alpha: float, valid_mask: torch.Tensor | None = None
) -> torch.Tensor:
    """Compute detached D-PACE per-position weights from draft confidences.

    Derived from D-PACE (arXiv:2605.18810). The paper factorizes the per-position
    weight (Fig. 2 / Eq. 8) into a *cumulative confidence* times a *continuation
    value*, which is equivalently the suffix sum of the cumulative confidences::

        C_j = prod_{i<=j} q~_i      # cumulative confidence (Eq. 8)
        w_j = sum_{m>=j} C_m        # = C_j * continuation value f~_j

    Each confidence is asymmetrically smoothed toward 1 (Eq. 7)::

        q~_i = (1 - alpha) * q_i + alpha,   alpha in (0, 1],

    so the floor ``q~_i >= alpha`` keeps every cumulative product (hence every
    weight) strictly positive. We evaluate the suffix sum from its definition as
    ``total - exclusive_prefix_sum`` of ``C`` rather than reversing the tensor.
    Positions with ``valid_mask == 0`` are multiplicative no-ops in ``C`` and
    contribute nothing to the sum, matching the per-token loss mask. Weights are
    detached (Eq. 9): they reweight the cross-entropy without adding gradient.

    Args:
        confidences: ``[..., L]`` draft confidence ``q_i = exp(-CE_i)`` per position.
        alpha: smoothing factor in (0, 1]; raises if outside that range.
        valid_mask: optional ``[..., L]`` 0/1 mask; ``None`` treats all positions valid.

    Returns:
        Detached weights with the same shape and dtype as ``confidences``.
    """
    if not 0.0 < alpha <= 1.0:
        raise ValueError(f"dflash_dpace_alpha must be in (0, 1], got {alpha}")

    with torch.no_grad():
        smoothed = alpha + (1.0 - alpha) * confidences.float()  # Eq. 7
        if valid_mask is not None:
            keep = valid_mask.to(torch.bool)
            smoothed = torch.where(keep, smoothed, torch.ones_like(smoothed))
        cum_conf = torch.cumprod(smoothed, dim=-1)  # Eq. 8 cumulative confidence C_j
        if valid_mask is not None:
            cum_conf = cum_conf * keep.to(cum_conf.dtype)
        # Suffix sum w_j = sum_{m>=j} C_m, written as total minus the exclusive
        # prefix sum so no axis reversal is needed (Eq. 8).
        inclusive = torch.cumsum(cum_conf, dim=-1)
        weights = inclusive[..., -1:] - inclusive + cum_conf
    return weights.to(dtype=confidences.dtype)
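
The "total minus exclusive prefix sum" identity and the mask handling can be checked in isolation. The following is a pure-Python mirror of the tensor logic for one block, compared against a direct double-loop evaluation of the definition. Both helper names are hypothetical, and the sketch assumes a single 1-D block rather than batched tensors.

```python
from itertools import accumulate
from operator import mul


def dpace_weights_masked(confidences, alpha, valid_mask=None):
    """Single-block mirror of the tensor function: cumprod, mask, then total - prefix + self."""
    if not 0.0 < alpha <= 1.0:
        raise ValueError(f"alpha must be in (0, 1], got {alpha}")
    if valid_mask is None:
        valid_mask = [1] * len(confidences)
    # Masked positions are set to 1.0 so they leave the running product unchanged.
    smoothed = [alpha + (1.0 - alpha) * q if keep else 1.0
                for q, keep in zip(confidences, valid_mask)]
    cum_conf = list(accumulate(smoothed, mul))
    # Masked positions are zeroed so they add nothing to the suffix sum.
    cum_conf = [c if keep else 0.0 for c, keep in zip(cum_conf, valid_mask)]
    inclusive = list(accumulate(cum_conf))
    total = inclusive[-1] if inclusive else 0.0
    return [total - inc + c for inc, c in zip(inclusive, cum_conf)]


def dpace_weights_bruteforce(confidences, alpha, valid_mask):
    """Direct evaluation of w_j = sum_{m>=j} prod_{i<=m} q~_i over valid positions."""
    n = len(confidences)
    weights = []
    for j in range(n):
        w = 0.0
        for m in range(j, n):
            if not valid_mask[m]:
                continue
            prod = 1.0
            for i in range(m + 1):
                if valid_mask[i]:
                    prod *= alpha + (1.0 - alpha) * confidences[i]
            w += prod
        weights.append(w)
    return weights


q, mask = [0.9, 0.5, 0.1, 0.7], [1, 0, 1, 1]
fast = dpace_weights_masked(q, 0.5, mask)
slow = dpace_weights_bruteforce(q, 0.5, mask)
print(fast, slow)
```

Note that the masked slot still receives a nonzero value here; in the loss it is multiplied by the existing `weight_mask`, so it contributes nothing.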


@DFlashDMRegistry.register({PreTrainedModel: "hf.PreTrainedModel"})
class HFDFlashModel(DFlashModel):
"""DFlash Model for HuggingFace transformers."""
Expand Down Expand Up @@ -368,14 +417,40 @@ def _compute_loss(

binary_eval_mask = weight_mask.view(-1)

# Optional loss decay
if self.dflash_loss_decay_factor > 0:
        flat_logits = logits.view(-1, logits.size(-1))
        flat_targets = target_ids.view(-1)

        # Non-KD loss is per-token cross-entropy; compute it once (grad enabled) so the
        # D-PACE confidences below can reuse it instead of a second CE pass. The KD path
        # (base_logits is not None) optimizes KL, so its confidences need a dedicated
        # no_grad CE pass.
        loss_per_token = None
        if base_logits is None:
            loss_per_token = F.cross_entropy(flat_logits, flat_targets, reduction="none")

        # Block-position loss weighting: dynamic D-PACE weights or static exponential decay.
        if self.dflash_loss_objective == "dpace" and block_size > 1:
            # Draft confidence q_i = exp(-CE) on the target-selected token, over the
            # predicted positions (slot 0 is the given anchor, already masked above).
            # Weights are detached (paper Eq.9): they rescale the per-token loss but add
            # no gradient path of their own. This is the documented ~2.3% training overhead.
            with torch.no_grad():
                conf_ce = (
                    loss_per_token.detach()
                    if loss_per_token is not None
                    else F.cross_entropy(flat_logits, flat_targets, reduction="none")
                ).view(bsz, n_blocks, block_size)
                confidences = torch.exp(-conf_ce[..., 1:].float())
                dpace = torch.ones_like(weight_mask)
                dpace[..., 1:] = _dpace_position_weights(
                    confidences, self.dflash_dpace_alpha, valid_mask=weight_mask[..., 1:]
                )
            weight_mask = weight_mask * dpace
elif self.dflash_loss_decay_factor > 0:
k = torch.arange(block_size, device=device).view(1, 1, -1)
decay = torch.exp(-(k - 1).clamp(min=0).float() / self.dflash_loss_decay_factor)
weight_mask = weight_mask * decay

flat_logits = logits.view(-1, logits.size(-1))
flat_targets = target_ids.view(-1)
flat_weights = weight_mask.view(-1)
valid_count = flat_weights.sum() + 1e-6

Expand All @@ -394,7 +469,6 @@ def _compute_loss(
kd_loss = -(target_soft * draft_logsoft).sum(dim=-1)
loss = (kd_loss * flat_weights).sum() / valid_count
else:
loss_per_token = F.cross_entropy(flat_logits, flat_targets, reduction="none")
loss = (loss_per_token * flat_weights).sum() / valid_count

with torch.no_grad():
Expand Down
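
One consequence of this hunk is worth spelling out: because `valid_count` is the sum of `flat_weights`, under D-PACE the loss is normalized by the total weight rather than by a token count. The sketch below illustrates that normalization on toy numbers. The helper `weighted_mean_loss` and all values are assumptions for illustration, not code from the PR.

```python
def weighted_mean_loss(per_token_loss, weights, eps=1e-6):
    """Weighted loss normalized by the weight sum: sum(loss * w) / (sum(w) + eps)."""
    weighted_total = sum(l * w for l, w in zip(per_token_loss, weights))
    return weighted_total / (sum(weights) + eps)


per_token_ce = [0.1, 0.7, 2.3]  # toy cross-entropy values for one block
uniform = weighted_mean_loss(per_token_ce, [1.0, 1.0, 1.0])
front_loaded = weighted_mean_loss(per_token_ce, [2.0, 1.0, 0.4])
rescaled = weighted_mean_loss(per_token_ce, [20.0, 10.0, 4.0])
print(uniform, front_loaded, rescaled)
```

Scaling all weights by a constant leaves the loss essentially unchanged (up to `eps`), so only the relative weighting across positions matters for the objective.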