diff --git a/CHANGELOG.rst b/CHANGELOG.rst index f50b0f6e848..41288dbdcbd 100755 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -111,6 +111,7 @@ Changelog **Bug Fixes** +- Support non-gated fused MoE experts in unified HuggingFace export. Nemotron-H MoE models (transformers 5.x ``NemotronHExperts``) store experts as fused 3-D ``up_proj`` / ``down_proj`` parameters with no ``gate_up_proj``; the fused-experts detection previously keyed on ``gate_up_proj``, so these were never wrapped as ``_QuantFusedExperts`` and export raised ``NotImplementedError: MoE model with experts type 'NemotronHExperts' is not supported``. The fused-experts path now also recognizes the non-gated layout (new ``_QuantNonGatedFusedExperts``) and exports a single ``up_proj`` per expert; the gated path is unchanged. - In Megatron-Core only do EP amax sync for routed expert weights if ``sync_expert_weight_amax=True``. Previously EP amax sync would sync routed expert weights across EP ranks even when ``sync_expert_weight_amax`` was False. - Fix Megatron-Core HF importer to load fused ``TELayerNormColumnParallelLinear.layer_norm_weight`` from HF for GPT-family models (Qwen3 etc.) under ``--export-default-te-spec``. Importer now prefers per-context keys ``fused_input_layernorm`` / ``fused_pre_mlp_layernorm`` (fallback ``fused_norm`` for Nemotron-H backward compatibility); ``mcore_qwen.py`` provides the new rules. Without this fix, post-prune MMLU sat at chance. - Fix ONNX AutoCast ``keep_io_types=True`` sanity-check failure (``Unexpected type in I/O tensor ...``) when a network input/output is an empty tensor (a dimension of size 0). Such tensors were "fake-cast" (retyped in place) to the low precision type; because the value-info map aliases the ``graph.input``/``graph.output`` ``ValueInfoProto``, this silently changed the model's I/O type. AutoCast now inserts a real ``Cast`` for protected I/O tensors instead. diff --git a/modelopt/torch/export/layer_utils.py b/modelopt/torch/export/layer_utils.py index e0c78f42def..1ae21e659d2 100755 --- a/modelopt/torch/export/layer_utils.py +++ b/modelopt/torch/export/layer_utils.py @@ -987,8 +987,10 @@ def module_match_name_list(module, name_list): # Structural detection: after _export_fused_experts, fused expert modules # have per-expert submodules with gate_proj/up_proj/down_proj. # Also handles models that originally used this naming (Qwen, DeepSeek, etc.). - if hasattr(module, "experts") and hasattr(module.experts, "gate_up_proj_weight_quantizers"): - return ["gate_up_proj", "down_proj"] + if hasattr(module, "experts"): + first_proj_attr = getattr(module.experts, "_first_proj_attr", "gate_up_proj") + if hasattr(module.experts, f"{first_proj_attr}_weight_quantizers"): + return [first_proj_attr, "down_proj"] if module_match_name_list( module, @@ -1004,7 +1006,7 @@ def module_match_name_list(module, name_list): elif module_match_name_list(module, ["MixtralSparseMoeBlock"]): # Old-style Mixtral (iterable experts) uses w1/w2/w3. # Fused Mixtral (transformers 5.0+) is already handled by the - # structural gate_up_proj_weight_quantizers check above. + # structural first-projection quantizer check above. return ["w1", "w2", "w3"] elif module_match_name_list(module, ["MixtralMoeSparseMoeBlock"]): # Older transformers naming for Mixtral diff --git a/modelopt/torch/export/moe_utils.py b/modelopt/torch/export/moe_utils.py index 574691bebd9..787e173959e 100644 --- a/modelopt/torch/export/moe_utils.py +++ b/modelopt/torch/export/moe_utils.py @@ -59,11 +59,14 @@ def _delete_fused_moe_source_attrs(module: nn.Module) -> None: aliases or via the full unpack/pack path) so the redundant fused form doesn't appear in the exported state_dict alongside the per-expert form. """ + first_proj_attr = getattr(module, "_first_proj_attr", "gate_up_proj") + first_proj_weight_quantizers_attr = f"{first_proj_attr}_weight_quantizers" + first_proj_input_quantizer_attr = f"{first_proj_attr}_input_quantizer" for attr in ( - "gate_up_proj", + first_proj_attr, + first_proj_weight_quantizers_attr, + first_proj_input_quantizer_attr, "down_proj", - "gate_up_proj_weight_quantizers", - "gate_up_proj_input_quantizer", "down_proj_weight_quantizers", "down_proj_input_quantizer", ): @@ -79,20 +82,21 @@ def _export_fused_experts( ) -> None: """Split fused MoE expert weights and export per-expert quantization scales. - Works with any module wrapped by ``_QuantFusedExperts`` — i.e. any HF - transformers 5.0+ fused expert container that stores ``gate_up_proj`` and - ``down_proj`` as 3-D ``nn.Parameter`` tensors with per-expert quantizer - ``nn.ModuleList`` s. + Works with any module wrapped by ``_QuantFusedExperts`` (gated, with a fused + ``gate_up_proj``) or ``_QuantNonGatedFusedExperts`` (non-gated, with a single + ``up_proj`` — e.g. NemotronH). Both store their projections as 3-D + ``nn.Parameter`` tensors with per-expert quantizer ``nn.ModuleList`` s. Steps: - 1. Handle amax fallback for uncalibrated expert input quantizers. - 2. Split fused 3-D weights into per-expert 2-D projections - (``gate_proj``, ``up_proj``, ``down_proj``). + 1. Handle amax fallback for uncalibrated expert weight quantizers. + 2. Split fused 3-D weights into per-expert 2-D projections — gated: + (``gate_proj``, ``up_proj``, ``down_proj``); non-gated: (``up_proj``, + ``down_proj``). 3. Call ``_export_quantized_weight`` on each projection. 4. Register results under the standard naming convention:: - {E}.gate_proj.weight, {E}.gate_proj.weight_scale, ... + {E}.gate_proj.weight, {E}.gate_proj.weight_scale, ... # gated only {E}.up_proj.weight, {E}.up_proj.weight_scale, ... {E}.down_proj.weight, {E}.down_proj.weight_scale, ... @@ -100,7 +104,7 @@ def _export_fused_experts( fused-expert modules share their 3-D source params via HF ``_tied_weights_keys``, the unpacking creates fresh per-expert tensors that break the tie. With ``_moe_tied_cache`` provided (tuple-keyed by - ``(gate_up_proj.data_ptr(), down_proj.data_ptr())``), the alias step + ``(.data_ptr(), down_proj.data_ptr())``), the alias step at the end re-points the per-expert ``weight`` / ``weight_scale`` / ``weight_scale_2`` / ``input_scale`` buffers at a previously-processed module sharing the same source memory. ``_tied_cache`` (int-keyed) is @@ -114,13 +118,20 @@ def _export_fused_experts( from modelopt.torch.quantization.plugins.huggingface import _get_fused_expert_intermediate_dim n = module.num_experts - expert_dim = _get_fused_expert_intermediate_dim(module) + # Gated experts fuse gate+up into ``gate_up_proj`` and must be split on export; + is_gated = getattr(module, "_is_gated", True) + first_proj_attr = getattr(module, "_first_proj_attr", "gate_up_proj") + # Only the gated split needs the per-expert intermediate dim (gate|up boundary). + expert_dim = _get_fused_expert_intermediate_dim(module) if is_gated else None # Capture source tensor identities BEFORE unpacking (the source # attrs are deleted at the end of this function). - _source_key = (module.gate_up_proj.data_ptr(), module.down_proj.data_ptr()) + _source_key = ( + getattr(module, first_proj_attr).data_ptr(), + module.down_proj.data_ptr(), + ) - # Tied-experts fast path: if this exact (gate_up, down) source-tensor pair + # Tied-experts fast path: if this exact (first_proj, down) source-tensor pair # has been processed before, alias all per-expert buffers directly from the # prior module — no unpacking, no per-expert packing, no transient buffers # thrown away. Cache miss falls through to the full unpack/pack below and @@ -133,14 +144,15 @@ def _export_fused_experts( return # 1. Shared input quantizers — one per projection type, shared across all experts. - gate_up_input_q = module.gate_up_proj_input_quantizer + first_proj_input_q = getattr(module, f"{first_proj_attr}_input_quantizer") + first_proj_weight_quantizers = getattr(module, f"{first_proj_attr}_weight_quantizers") down_input_q = module.down_proj_input_quantizer - gate_up = module.gate_up_proj.data + first_proj = getattr(module, first_proj_attr).data # gate_up_proj or up_proj down = module.down_proj.data # 2-3. Split + export each per-expert projection. - fused_dim0 = gate_up.shape[1] # 2 * expert_dim + fused_dim0 = first_proj.shape[1] # gated: 2 * expert_dim; non-gated: expert_dim for idx in range(n): expert = nn.Module() @@ -153,13 +165,19 @@ def _export_fused_experts( # fallback further down would otherwise compute amax independently from # each half — gate's max and up's max generally differ — producing # mismatched weight_scale_2 and garbled MoE output at inference. - gate_up_q = module.gate_up_proj_weight_quantizers[idx] - if getattr(gate_up_q, "is_enabled", False) and ( - not hasattr(gate_up_q, "_amax") - or gate_up_q._amax is None - or torch.all(gate_up_q._amax == 0) + # Non-gated experts have no gate/up fusion, so this shared-amax step is + # skipped — their single up_proj uses the generic per-projection fallback. + first_proj_q = first_proj_weight_quantizers[idx] + if ( + is_gated + and getattr(first_proj_q, "is_enabled", False) + and ( + not hasattr(first_proj_q, "_amax") + or first_proj_q._amax is None + or torch.all(first_proj_q._amax == 0) + ) ): - gate_up_q.amax = gate_up[idx].abs().amax().to(torch.float32) + first_proj_q.amax = first_proj[idx].abs().amax().to(torch.float32) warnings.warn( f"Expert {idx} gate_up_proj weight quantizer was not calibrated " f"(amax missing or zero). Using fused-tensor amax as fallback " @@ -168,22 +186,38 @@ def _export_fused_experts( stacklevel=2, ) - projections = [ - ("gate_proj", gate_up[idx, :expert_dim, :], 0, fused_dim0, True), - ("up_proj", gate_up[idx, expert_dim:, :], expert_dim, fused_dim0, True), - ("down_proj", down[idx], 0, down.shape[1], False), - ] - - for proj_name, weight_slice, fused_start, fused_total, is_gate_up in projections: + if is_gated: + projections = [ + ("gate_proj", first_proj[idx, :expert_dim, :], 0, fused_dim0, True), + ("up_proj", first_proj[idx, expert_dim:, :], expert_dim, fused_dim0, True), + ("down_proj", down[idx], 0, down.shape[1], False), + ] + else: + # Non-gated: the single up_proj maps 1:1 to its weight quantizer, so it + # is exported whole (no dim-0 split, no shared gate/up weight_scale_2). + projections = [ + ("up_proj", first_proj[idx], 0, fused_dim0, True), + ("down_proj", down[idx], 0, down.shape[1], False), + ] + + for ( + proj_name, + weight_slice, + fused_start, + fused_total, + uses_first_proj_quantizers, + ) in projections: w_quantizer_src = ( - module.gate_up_proj_weight_quantizers[idx] - if is_gate_up + first_proj_weight_quantizers[idx] + if uses_first_proj_quantizers else module.down_proj_weight_quantizers[idx] ) - i_quantizer = gate_up_input_q if is_gate_up else down_input_q + i_quantizer = first_proj_input_q if uses_first_proj_quantizers else down_input_q # gate/up share a weight quantizer — clone so each gets independent amax. - w_quantizer = copy.deepcopy(w_quantizer_src) if is_gate_up else w_quantizer_src + w_quantizer = ( + copy.deepcopy(w_quantizer_src) if uses_first_proj_quantizers else w_quantizer_src + ) # For per-channel amax (dim >= 1), proportionally slice dim-0 # to match the split weight. diff --git a/modelopt/torch/export/plugins/vllm_fakequant_hf.py b/modelopt/torch/export/plugins/vllm_fakequant_hf.py index ad0b88f2f78..8883964daac 100644 --- a/modelopt/torch/export/plugins/vllm_fakequant_hf.py +++ b/modelopt/torch/export/plugins/vllm_fakequant_hf.py @@ -161,8 +161,9 @@ def _fakequant_fused_experts_weights( expert) that the base loop skips, leaving the fused 3-D weight unquantized in the export and breaking weight-fold round-trips. """ + first_proj_attr = getattr(module, "_first_proj_attr", "gate_up_proj") for w_attr, q_attr in ( - ("gate_up_proj", "gate_up_proj_weight_quantizers"), + (first_proj_attr, f"{first_proj_attr}_weight_quantizers"), ("down_proj", "down_proj_weight_quantizers"), ): quantizers = getattr(module, q_attr, None) diff --git a/modelopt/torch/export/quant_utils.py b/modelopt/torch/export/quant_utils.py index ba74480dd7d..2af5f6eab0b 100755 --- a/modelopt/torch/export/quant_utils.py +++ b/modelopt/torch/export/quant_utils.py @@ -1495,7 +1495,7 @@ def sync_tied_input_amax(model: nn.Module) -> int: YOCO-style models). Must run BEFORE per-module export so the merged amax flows into ``input_scale`` derivation. Handles both dense Linears (keyed by ``weight.data_ptr()``) and fused MoE (keyed by - ``(gate_up_proj, down_proj)`` data_ptr tuple). Returns the number of + ``(, down_proj)`` data_ptr tuple). Returns the number of tied groups merged. """ from collections import defaultdict @@ -1503,13 +1503,16 @@ def sync_tied_input_amax(model: nn.Module) -> int: by_dp: dict = defaultdict(list) for _, m in model.named_modules(): # Fused MoE: 3-D source tensors with shared input quantizers + first_proj_attr = getattr(m, "_first_proj_attr", "gate_up_proj") + first_proj = getattr(m, first_proj_attr, None) + first_proj_input_quantizer_attr = f"{first_proj_attr}_input_quantizer" if ( - hasattr(m, "gate_up_proj_input_quantizer") - and hasattr(m, "gate_up_proj") + hasattr(m, first_proj_input_quantizer_attr) + and first_proj is not None and hasattr(m, "down_proj") - and m.gate_up_proj.dim() == 3 + and first_proj.dim() == 3 ): - key = ("moe", m.gate_up_proj.data_ptr(), m.down_proj.data_ptr()) + key = ("moe", first_proj.data_ptr(), m.down_proj.data_ptr()) by_dp[key].append(m) # Dense quantized Linear with an input_quantizer elif ( @@ -1549,7 +1552,8 @@ def _merge(quantizers: list) -> bool: if len(modules) < 2: continue if key[0] == "moe": - for q_name in ("gate_up_proj_input_quantizer", "down_proj_input_quantizer"): + first_proj_attr = getattr(modules[0], "_first_proj_attr", "gate_up_proj") + for q_name in (f"{first_proj_attr}_input_quantizer", "down_proj_input_quantizer"): if _merge([getattr(m, q_name, None) for m in modules]): synced += 1 elif _merge([m.input_quantizer for m in modules]): diff --git a/modelopt/torch/export/unified_export_hf.py b/modelopt/torch/export/unified_export_hf.py index 527c92e93f8..8bc92ed5eb9 100644 --- a/modelopt/torch/export/unified_export_hf.py +++ b/modelopt/torch/export/unified_export_hf.py @@ -830,10 +830,12 @@ def _process_quantized_modules( ): sub_module.unpack_weight() - if hasattr(sub_module, "gate_up_proj_weight_quantizers"): - # _QuantFusedExperts uses plural `gate_up_proj_weight_quantizers` (ModuleList), - # which get_quantization_format's singular-weight_quantizer check misses. Handle - # it explicitly before the format gate so fused-experts get split + quantized. + first_proj_attr = getattr(sub_module, "_first_proj_attr", "gate_up_proj") + if hasattr(sub_module, f"{first_proj_attr}_weight_quantizers"): + # _QuantFusedExperts uses plural `_weight_quantizers` + # (ModuleList), which get_quantization_format's singular-weight_quantizer + # check misses. Handle it explicitly before the format gate so fused-experts + # get split + quantized. with fsdp2_aware_weight_update(model, sub_module, reshard=False): _export_fused_experts( sub_module, @@ -937,6 +939,10 @@ def _export_transformers_checkpoint( for _, sub_module in model.named_modules(): if is_moe(sub_module) and hasattr(sub_module, "experts"): expert_linear_names = get_expert_linear_names(sub_module) + first_proj_attr = getattr(sub_module.experts, "_first_proj_attr", "gate_up_proj") + has_fused_experts_quantizers = hasattr( + sub_module.experts, f"{first_proj_attr}_weight_quantizers" + ) for linear_name in expert_linear_names: # Handle DBRX experts specifically if "QuantDbrxExperts" in type(sub_module.experts).__name__: @@ -949,7 +955,7 @@ def _export_transformers_checkpoint( modules=list(linear_modulelist), quantizer_attrs=["input_quantizer"], ) - elif hasattr(sub_module.experts, "gate_up_proj_weight_quantizers"): + elif has_fused_experts_quantizers: # _QuantFusedExperts: amax fallback is handled in _export_fused_experts break elif ( diff --git a/modelopt/torch/quantization/algorithms.py b/modelopt/torch/quantization/algorithms.py index 03bd801387c..b5284c87249 100644 --- a/modelopt/torch/quantization/algorithms.py +++ b/modelopt/torch/quantization/algorithms.py @@ -80,6 +80,12 @@ def _is_hf_quant_fused_experts_module(module: nn.Module) -> bool: "down_proj_input_quantizer", "down_proj_weight_quantizer", ) +_NON_GATED_FUSED_EXPERTS_REPLAY_QUANTIZER_ATTRS = ( + "up_proj_input_quantizer", + "up_proj_weight_quantizer", + "down_proj_input_quantizer", + "down_proj_weight_quantizer", +) def _get_replay_quantizer_attr(attr_name: str) -> str: @@ -97,7 +103,11 @@ def _get_quantizer_attrs(module: nn.Module) -> tuple[str, ...]: For standard Linear-derived QuantModules, returns the canonical trio. """ if _is_hf_quant_fused_experts_module(module): - return _FUSED_EXPERTS_QUANTIZER_ATTRS + try: + from .plugins.huggingface import _get_fused_experts_quantizer_attr_names + except ImportError: + return _FUSED_EXPERTS_QUANTIZER_ATTRS + return _get_fused_experts_quantizer_attr_names(module) return _STD_QUANTIZER_ATTRS @@ -1524,7 +1534,11 @@ def _cfg_to_dict(v): for pattern in _as_list(search_state.get("disabled_layers")) ) per_module_entries: list[dict] = [] - _per_module_attrs = (*_STD_QUANTIZER_ATTRS, *_FUSED_EXPERTS_REPLAY_QUANTIZER_ATTRS) + _per_module_attrs = ( + *_STD_QUANTIZER_ATTRS, + *_FUSED_EXPERTS_REPLAY_QUANTIZER_ATTRS, + *_NON_GATED_FUSED_EXPERTS_REPLAY_QUANTIZER_ATTRS, + ) # Track global (non per-module) recipe entries. Last recipe wins for each pattern. global_entries: dict[str, dict] = {} diff --git a/modelopt/torch/quantization/model_calib.py b/modelopt/torch/quantization/model_calib.py index 53d497d43c9..24984179791 100644 --- a/modelopt/torch/quantization/model_calib.py +++ b/modelopt/torch/quantization/model_calib.py @@ -720,8 +720,9 @@ def _warn_local_hessian_fallback(name, weight, weight_quantizer, block_size, war def _is_quant_fused_experts(module: nn.Module) -> bool: """Whether ``module`` is a converted HF fused-MoE-experts wrapper with per-expert quantizers.""" + first_proj_attr = getattr(module, "_first_proj_attr", "gate_up_proj") return hasattr(module, "_current_expert_idx") and hasattr( - module, "gate_up_proj_weight_quantizers" + module, f"{first_proj_attr}_weight_quantizers" ) @@ -765,11 +766,12 @@ def _dense_hook(linear, args): handles.append(module.register_forward_pre_hook(_dense_hook)) elif _is_quant_fused_experts(module): with enable_weight_access_and_writeback(module, model, name_to_module): + first_proj_attr = getattr(module, "_first_proj_attr", "gate_up_proj") for weight_name, quantizers_name, input_q_name in ( ( - "gate_up_proj", - "gate_up_proj_weight_quantizers", - "gate_up_proj_input_quantizer", + first_proj_attr, + f"{first_proj_attr}_weight_quantizers", + f"{first_proj_attr}_input_quantizer", ), ("down_proj", "down_proj_weight_quantizers", "down_proj_input_quantizer"), ): diff --git a/modelopt/torch/quantization/plugins/huggingface.py b/modelopt/torch/quantization/plugins/huggingface.py index 97e13f419f9..a6a4c292e68 100644 --- a/modelopt/torch/quantization/plugins/huggingface.py +++ b/modelopt/torch/quantization/plugins/huggingface.py @@ -867,20 +867,41 @@ class _QuantFusedExperts(_QuantFunctionalMixin): Limitation: only works when ``experts_implementation="eager"`` (default). ``batched_mm`` / ``grouped_mm`` backends use ``torch.bmm`` / ``torch._grouped_mm`` instead of ``F.linear`` and are not intercepted. + + The non-gated variant (``up_proj`` instead of ``gate_up_proj``, used by + NemotronH) is handled by :class:`_QuantNonGatedFusedExperts` via the + ``_first_proj_attr`` / ``_is_gated`` hooks below; each layout names the + first-projection quantizers after its backing parameter. """ - def _get_expert_idx_from_gate_up(self, weight: torch.Tensor) -> int: - """Recover expert index from a ``gate_up_proj`` weight slice's storage offset. + # Name of the 3-D weight parameter feeding the first ``F.linear`` per expert. + # Gated experts fuse gate+up into ``gate_up_proj``; non-gated experts use a + # single ``up_proj`` (see _QuantNonGatedFusedExperts). + _first_proj_attr = "gate_up_proj" + # Whether the first projection packs a gate half that must be split on export. + _is_gated = True + + @property + def _first_proj_input_quantizer_attr(self) -> str: + return f"{self._first_proj_attr}_input_quantizer" + + @property + def _first_proj_weight_quantizers_attr(self) -> str: + return f"{self._first_proj_attr}_weight_quantizers" + + def _get_expert_idx_from_first_proj(self, weight: torch.Tensor) -> int: + """Recover expert index from a first-projection weight slice's storage offset. - When HF indexes ``gate_up_proj[idx]``, the result is a view sharing the + When HF indexes ``[idx]``, the result is a view sharing the same underlying storage. The offset delta divided by the stride along dim-0 gives the expert index. The invariant breaks if the tensor is ``.contiguous()``-copied or redistributed by certain distributed wrappers (FSDP2, tensor parallel). """ - base_offset = self.gate_up_proj.storage_offset() - stride = self.gate_up_proj.stride(0) + first_proj = getattr(self, self._first_proj_attr) + base_offset = first_proj.storage_offset() + stride = first_proj.stride(0) if stride == 0: return 0 idx = (weight.storage_offset() - base_offset) // stride @@ -892,8 +913,12 @@ def _get_expert_idx_from_gate_up(self, weight: torch.Tensor) -> int: def _setup(self): n = self.num_experts - self.gate_up_proj_input_quantizer = TensorQuantizer() - self.gate_up_proj_weight_quantizers = nn.ModuleList([TensorQuantizer() for _ in range(n)]) + setattr(self, self._first_proj_input_quantizer_attr, TensorQuantizer()) + setattr( + self, + self._first_proj_weight_quantizers_attr, + nn.ModuleList([TensorQuantizer() for _ in range(n)]), + ) self.down_proj_input_quantizer = TensorQuantizer() self.down_proj_weight_quantizers = nn.ModuleList([TensorQuantizer() for _ in range(n)]) @@ -913,10 +938,10 @@ def _quantized_linear(input, weight, bias=None): input = self.down_proj_input_quantizer(input) weight = self.down_proj_weight_quantizers[idx](weight) else: - idx = self._get_expert_idx_from_gate_up(weight) + idx = self._get_expert_idx_from_first_proj(weight) self._current_expert_idx = idx - input = self.gate_up_proj_input_quantizer(input) - weight = self.gate_up_proj_weight_quantizers[idx](weight) + input = getattr(self, self._first_proj_input_quantizer_attr)(input) + weight = getattr(self, self._first_proj_weight_quantizers_attr)[idx](weight) self._down_proj_linear = not self._down_proj_linear return _orig_linear(input, weight, bias) @@ -936,7 +961,7 @@ def iter_weights_for_calibration(self): quantizers without this override. """ for weight_name, quantizers_name in ( - ("gate_up_proj", "gate_up_proj_weight_quantizers"), + (self._first_proj_attr, self._first_proj_weight_quantizers_attr), ("down_proj", "down_proj_weight_quantizers"), ): weight = getattr(self, weight_name, None) @@ -951,11 +976,11 @@ def fold_weight(self, keep_attrs: bool = False): The base ``fold_weight`` only handles singular ``*_weight_quantizer`` attributes. Fused experts use ``nn.ModuleList`` of per-expert quantizers - (``gate_up_proj_weight_quantizers``, ``down_proj_weight_quantizers``), + (``_weight_quantizers``, ``down_proj_weight_quantizers``), which would otherwise be skipped, leaving ``_amax`` on every quantizer. """ for weight_name, quantizers_name in ( - ("gate_up_proj", "gate_up_proj_weight_quantizers"), + (self._first_proj_attr, self._first_proj_weight_quantizers_attr), ("down_proj", "down_proj_weight_quantizers"), ): weight = getattr(self, weight_name, None) @@ -974,6 +999,29 @@ def fold_weight(self, keep_attrs: bool = False): delattr(q, attr_name) +class _QuantNonGatedFusedExperts(_QuantFusedExperts): + """Quantized wrapper for non-gated fused MoE experts. + + Used by NemotronH (transformers 5.5+ ``NemotronHExperts``), whose experts + are a *non-gated* MLP: a single ``up_proj`` (no gate half) and a ``down_proj``, + both stored as 3-D ``nn.Parameter`` s indexed per expert. + """ + + _first_proj_attr = "up_proj" + _is_gated = False + + +def _get_fused_experts_quantizer_attr_names(module): + """Return quantizer attribute names for a converted fused-experts module.""" + first_proj_attr = getattr(module, "_first_proj_attr", "gate_up_proj") + return ( + f"{first_proj_attr}_input_quantizer", + f"{first_proj_attr}_weight_quantizers", + "down_proj_input_quantizer", + "down_proj_weight_quantizers", + ) + + def _is_quant_fused_experts_module(module): """Return True for a converted HF fused-MoE-experts quantization wrapper.""" return isinstance(module, _QuantFusedExperts) @@ -1466,27 +1514,45 @@ def register_sparse_moe_on_the_fly(model): ) -def _is_fused_experts_module(module): - """Check if a module is a fused MoE expert container compatible with _QuantFusedExperts. +def _fused_experts_wrapper_class(module): + """Return the _QuantFusedExperts subclass for a fused MoE expert container, or None. - Detects the standardized HuggingFace transformers 5.0+ fused expert pattern: - ``gate_up_proj`` (3-D parameter), ``down_proj`` (3-D parameter), ``num_experts``, - and ``act_fn``. Matches ``MixtralExperts``, ``Qwen2MoeExperts``, - ``Qwen3MoeExperts``, ``Qwen3_5MoeExperts``, ``DeepseekV3NaiveMoe``, - ``JambaExperts``, ``OlmoeExperts``, etc. + Two 3-D fused layouts are recognized, both requiring ``num_experts`` + ``act_fn`` + and a 3-D ``down_proj`` parameter: - Returns ``False`` for non-standard layouts (DBRX, GptOss, GraniteMoE, + * gated (``_QuantFusedExperts``): a 3-D ``gate_up_proj`` fusing gate+up. Matches + ``MixtralExperts``, ``Qwen2MoeExperts``, ``Qwen3MoeExperts``, + ``Qwen3_5MoeExperts``, ``DeepseekV3NaiveMoe``, ``JambaExperts``, + ``OlmoeExperts``, etc. + * non-gated (``_QuantNonGatedFusedExperts``): a 3-D ``up_proj`` with no + ``gate_proj`` and no ``gate_up_proj``. Matches NemotronH ``NemotronHExperts``. + + Returns ``None`` for non-standard layouts (DBRX, GptOss, GraniteMoE, Llama4TextExperts) which have their own explicit registrations. """ - if not hasattr(module, "gate_up_proj") or not hasattr(module, "down_proj"): - return False if not hasattr(module, "num_experts") or not hasattr(module, "act_fn"): - return False - gate_up = getattr(module, "gate_up_proj") - down = getattr(module, "down_proj") - if not isinstance(gate_up, (nn.Parameter, Tensor)) or gate_up.dim() != 3: - return False - return isinstance(down, (nn.Parameter, Tensor)) and down.dim() == 3 + return None + down = getattr(module, "down_proj", None) + if not isinstance(down, (nn.Parameter, Tensor)) or down.dim() != 3: + return None + gate_up = getattr(module, "gate_up_proj", None) + if isinstance(gate_up, (nn.Parameter, Tensor)) and gate_up.dim() == 3: + return _QuantFusedExperts + up = getattr(module, "up_proj", None) + if isinstance(up, (nn.Parameter, Tensor)) and up.dim() == 3: + # Only claim non-gated experts that alternate up_proj then down_proj. + if getattr(module, "gate_proj", None) is None and gate_up is None: + return _QuantNonGatedFusedExperts + return None + + +def _is_fused_experts_module(module): + """Check if a module is a fused MoE expert container compatible with _QuantFusedExperts. + + See :func:`_fused_experts_wrapper_class` for the recognized layouts (gated + ``gate_up_proj`` and non-gated ``up_proj``). + """ + return _fused_experts_wrapper_class(module) is not None def register_fused_experts_on_the_fly(model): @@ -1508,12 +1574,13 @@ def register_fused_experts_on_the_fly(model): visited_types.add(mod_type) - if _is_fused_experts_module(module): + wrapper_cls = _fused_experts_wrapper_class(module) + if wrapper_cls is not None: print( f"\033[1mDetected fused MoE experts '{name}' of type {mod_type.__name__}, " - f"registering with _QuantFusedExperts.\033[0m" + f"registering with {wrapper_cls.__name__}.\033[0m" ) - QuantModuleRegistry.register({mod_type: f"hf.{mod_type.__name__}"})(_QuantFusedExperts) + QuantModuleRegistry.register({mod_type: f"hf.{mod_type.__name__}"})(wrapper_cls) def force_eager_experts_impl_on_the_fly(model): diff --git a/modelopt/torch/quantization/utils/core_utils.py b/modelopt/torch/quantization/utils/core_utils.py index f9ed19816e5..b0049b5a08d 100644 --- a/modelopt/torch/quantization/utils/core_utils.py +++ b/modelopt/torch/quantization/utils/core_utils.py @@ -238,7 +238,7 @@ def weight_attr_names(module: nn.Module) -> "Generator[str, None, None]": - custom per-weight quantizer (e.g. ``Llama4TextExperts`` with ``gate_up_proj`` + ``gate_up_proj_weight_quantizer``). - fused-experts ``nn.ModuleList`` quantizers (``_QuantFusedExperts`` with - ``gate_up_proj`` + ``gate_up_proj_weight_quantizers`` plural list). + ```` + ``_weight_quantizers`` plural list). """ # standard: "weight" + "weight_quantizer" (singular) or "weight_quantizers" (plural) if getattr(module, "weight", None) is not None: @@ -250,10 +250,17 @@ def weight_attr_names(module: nn.Module) -> "Generator[str, None, None]": if name == "weight": continue weight = getattr(module, name, None) - if ( - isinstance(weight, nn.Parameter) - and representative_weight_quantizer(module, name) is not None + if not isinstance(weight, nn.Parameter): + continue + if representative_weight_quantizer(module, name) is not None: + yield name + elif ( + name == getattr(module, "_first_proj_attr", None) + and name != "gate_up_proj" + and isinstance(getattr(module, "gate_up_proj_weight_quantizers", None), nn.ModuleList) ): + # Backward compatibility for older non-gated fused-experts wrappers that + # kept first-projection quantizers under the gate_up_proj sentinel name. yield name diff --git a/tests/unit/torch/quantization/plugins/test_fused_experts.py b/tests/unit/torch/quantization/plugins/test_fused_experts.py index 550c27c46fd..89267fb06e6 100644 --- a/tests/unit/torch/quantization/plugins/test_fused_experts.py +++ b/tests/unit/torch/quantization/plugins/test_fused_experts.py @@ -26,14 +26,16 @@ import modelopt.torch.quantization as mtq from modelopt.torch.export.moe_utils import _export_fused_experts -from modelopt.torch.export.quant_utils import get_quant_config +from modelopt.torch.export.quant_utils import get_quant_config, get_quantization_format from modelopt.torch.quantization.conversion import _normalize_fused_experts_quantizer_name from modelopt.torch.quantization.model_calib import local_hessian_calibrate -from modelopt.torch.quantization.nn import QuantModuleRegistry +from modelopt.torch.quantization.nn import QuantModuleRegistry, TensorQuantizer from modelopt.torch.quantization.plugins.huggingface import ( + _fused_experts_wrapper_class, _is_fused_experts_module, _is_sparse_sequaential_moe_block, _QuantFusedExperts, + _QuantNonGatedFusedExperts, force_eager_experts_impl_on_the_fly, register_fused_experts_on_the_fly, register_sparse_moe_on_the_fly, @@ -86,6 +88,45 @@ def forward(self, hidden_states, top_k_index, top_k_weights): return final_hidden_states +class _SyntheticNonGatedFusedExperts(nn.Module): + """Mimics NemotronHExperts (transformers 5.5+): non-gated fused experts. + + A single ``up_proj`` (no gate half) + ``down_proj``, both 3-D ``nn.Parameter`` s, + with the forward calling ``F.linear`` exactly twice per expert (up then down). + """ + + def __init__(self): + super().__init__() + self.num_experts = NUM_EXPERTS + self.hidden_dim = HIDDEN_DIM + self.intermediate_dim = INTERMEDIATE_DIM + self.up_proj = nn.Parameter(torch.randn(NUM_EXPERTS, INTERMEDIATE_DIM, HIDDEN_DIM) * 0.02) + self.down_proj = nn.Parameter(torch.randn(NUM_EXPERTS, HIDDEN_DIM, INTERMEDIATE_DIM) * 0.02) + self.act_fn = nn.SiLU() + + def forward(self, hidden_states, top_k_index, top_k_weights): + final_hidden_states = torch.zeros_like(hidden_states) + with torch.no_grad(): + expert_mask = F.one_hot(top_k_index, num_classes=self.num_experts).permute(2, 1, 0) + expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero() + for expert_idx in expert_hit: + expert_idx = expert_idx[0] + if expert_idx == self.num_experts: + continue + top_k_pos, token_idx = torch.where(expert_mask[expert_idx]) + current_state = hidden_states[token_idx] + current_hidden_states = F.linear(current_state, self.up_proj[expert_idx]) + current_hidden_states = self.act_fn(current_hidden_states) + current_hidden_states = F.linear(current_hidden_states, self.down_proj[expert_idx]) + current_hidden_states = ( + current_hidden_states * top_k_weights[token_idx, top_k_pos, None] + ) + final_hidden_states.index_add_( + 0, token_idx, current_hidden_states.to(final_hidden_states.dtype) + ) + return final_hidden_states + + class _SyntheticTopKRouter(nn.Module): def __init__(self): super().__init__() @@ -128,6 +169,43 @@ def forward(self, x): return self.moe(x) +class _SyntheticNonGatedSparseMoeBlock(nn.Module): + """Mimics NemotronHMoE: a router + non-gated fused experts.""" + + def __init__(self): + super().__init__() + self.gate = _SyntheticTopKRouter() + self.experts = _SyntheticNonGatedFusedExperts() + + def forward(self, hidden_states): + batch_size, sequence_length, hidden_dim = hidden_states.shape + hidden_states = hidden_states.view(-1, hidden_dim) + _, top_k_weights, top_k_index = self.gate(hidden_states) + hidden_states = self.experts(hidden_states, top_k_index, top_k_weights) + return hidden_states.reshape(batch_size, sequence_length, hidden_dim) + + +class _TinyNonGatedMoEModel(nn.Module): + """Minimal model containing a single non-gated MoE block.""" + + def __init__(self): + super().__init__() + self.moe = _SyntheticNonGatedSparseMoeBlock() + + def forward(self, x): + return self.moe(x) + + +def _route_once_to_each_expert(model): + """Call fused experts directly with deterministic routing that covers every expert.""" + assert NUM_EXPERTS % TOP_K == 0 + seq_len = NUM_EXPERTS // TOP_K + hidden_states = torch.randn(seq_len, HIDDEN_DIM) + top_k_index = torch.arange(NUM_EXPERTS, dtype=torch.long).reshape(seq_len, TOP_K) + top_k_weights = torch.ones(seq_len, TOP_K) / TOP_K + model.moe.experts(hidden_states, top_k_index, top_k_weights) + + # --------------------------------------------------------------------------- # Tests for _is_fused_experts_module # --------------------------------------------------------------------------- @@ -254,7 +332,7 @@ def test_expert_index_recovery(self): for idx in range(NUM_EXPERTS): weight_slice = converted.gate_up_proj[idx] - recovered_idx = converted._get_expert_idx_from_gate_up(weight_slice) + recovered_idx = converted._get_expert_idx_from_first_proj(weight_slice) assert recovered_idx == idx, f"Expected {idx}, got {recovered_idx}" self._cleanup_registry(expert_type) @@ -299,9 +377,7 @@ def test_export_creates_per_expert_submodules(self): def forward_loop(m): torch.manual_seed(0) - for _ in range(2): - x = torch.randn(1, 4, HIDDEN_DIM) - m(x) + _route_once_to_each_expert(m) mtq.quantize(model, quant_cfg, forward_loop=forward_loop) converted = model.moe.experts @@ -755,9 +831,7 @@ def test_calibration_populates_all_expert_quantizers(self): def forward_loop(m): torch.manual_seed(0) - for _ in range(2): - x = torch.randn(1, 4, HIDDEN_DIM) - m(x) + _route_once_to_each_expert(m) mtq.quantize(model, quant_cfg, forward_loop=forward_loop) @@ -1077,3 +1151,222 @@ def test_unrelated_dotted_number_unchanged(self): _normalize_fused_experts_quantizer_name("moe.layers.3.gate.weight") == "moe.layers.3.gate.weight" ) + + +# --------------------------------------------------------------------------- +# Tests for the non-gated fused-experts path (NemotronH NemotronHExperts): +# single up_proj (no gate half) + down_proj. Quantizers are named after the +# backing weights: up_proj_* and down_proj_*. +# --------------------------------------------------------------------------- +class TestNonGatedFusedExperts: + @staticmethod + def _cleanup_registry(mod_type): + if QuantModuleRegistry.get(mod_type) is not None: + QuantModuleRegistry.unregister(mod_type) + + def test_detected_and_picks_nongated_wrapper(self): + module = _SyntheticNonGatedFusedExperts() + assert _is_fused_experts_module(module) is True + assert _fused_experts_wrapper_class(module) is _QuantNonGatedFusedExperts + + def test_gated_still_picks_base_wrapper(self): + assert _fused_experts_wrapper_class(_SyntheticFusedExperts()) is _QuantFusedExperts + + def test_register_uses_nongated_wrapper(self): + model = _TinyNonGatedMoEModel() + expert_type = type(model.moe.experts) + self._cleanup_registry(expert_type) + register_fused_experts_on_the_fly(model) + try: + converted = QuantModuleRegistry.convert(model.moe.experts) + assert isinstance(converted, _QuantNonGatedFusedExperts) + assert converted._first_proj_attr == "up_proj" + assert converted._is_gated is False + assert hasattr(converted, "up_proj_input_quantizer") + assert hasattr(converted, "up_proj_weight_quantizers") + assert not hasattr(converted, "gate_up_proj_input_quantizer") + assert not hasattr(converted, "gate_up_proj_weight_quantizers") + assert len(converted.up_proj_weight_quantizers) == NUM_EXPERTS + assert len(converted.down_proj_weight_quantizers) == NUM_EXPERTS + finally: + self._cleanup_registry(expert_type) + + def test_forward_passthrough_matches(self): + model = _TinyNonGatedMoEModel() + expert_type = type(model.moe.experts) + self._cleanup_registry(expert_type) + + ref_experts = _SyntheticNonGatedFusedExperts() + ref_experts.load_state_dict(model.moe.experts.state_dict()) + + register_fused_experts_on_the_fly(model) + try: + converted = QuantModuleRegistry.convert(model.moe.experts) + # Disable quantizers to isolate the wrapper's structural forward + # (the F.linear interception / per-expert index routing) from + # dynamic-quant noise — this is a passthrough equivalence check. + for q in converted.modules(): + if isinstance(q, TensorQuantizer): + q.disable() + seq_len = 8 + hidden_states = torch.randn(seq_len, HIDDEN_DIM) + top_k_index = torch.randint(0, NUM_EXPERTS, (seq_len, TOP_K)) + top_k_weights = torch.softmax(torch.randn(seq_len, TOP_K), dim=-1) + with torch.no_grad(): + out_ref = ref_experts(hidden_states, top_k_index, top_k_weights) + out_test = converted(hidden_states, top_k_index, top_k_weights) + assert torch.allclose(out_ref, out_test, atol=1e-5), ( + f"Max diff: {(out_ref - out_test).abs().max().item()}" + ) + finally: + self._cleanup_registry(expert_type) + + def test_expert_index_recovery(self): + experts = _SyntheticNonGatedFusedExperts() + expert_type = type(experts) + self._cleanup_registry(expert_type) + register_fused_experts_on_the_fly(_TinyNonGatedMoEModel()) + try: + converted = QuantModuleRegistry.convert(experts) + for idx in range(NUM_EXPERTS): + weight_slice = converted.up_proj[idx] + assert converted._get_expert_idx_from_first_proj(weight_slice) == idx + finally: + self._cleanup_registry(expert_type) + + def _nongated_fp8_cfg(self): + return { + "quant_cfg": [ + {"quantizer_name": "*", "enable": False}, + { + "quantizer_name": "*up_proj_input_quantizer", + "cfg": {"num_bits": 8, "axis": None}, + }, + { + "quantizer_name": "*down_proj_input_quantizer", + "cfg": {"num_bits": 8, "axis": None}, + }, + { + "quantizer_name": "*up_proj_weight_quantizer", + "cfg": {"num_bits": 8, "axis": 0}, + }, + { + "quantizer_name": "*down_proj_weight_quantizer", + "cfg": {"num_bits": 8, "axis": 0}, + }, + ], + "algorithm": "max", + } + + def test_calibration_populates_all_expert_quantizers(self): + model = _TinyNonGatedMoEModel() + expert_type = type(model.moe.experts) + self._cleanup_registry(expert_type) + + def forward_loop(m): + torch.manual_seed(0) + _route_once_to_each_expert(m) + + try: + mtq.quantize(model, self._nongated_fp8_cfg(), forward_loop=forward_loop) + experts = model.moe.experts + assert experts.up_proj_input_quantizer.amax is not None + assert experts.down_proj_input_quantizer.amax is not None + for idx in range(NUM_EXPERTS): + assert experts.up_proj_weight_quantizers[idx].amax is not None + assert experts.down_proj_weight_quantizers[idx].amax is not None + finally: + self._cleanup_registry(expert_type) + + def test_export_creates_per_expert_up_down_only(self): + model = _TinyNonGatedMoEModel() + expert_type = type(model.moe.experts) + self._cleanup_registry(expert_type) + + def forward_loop(m): + torch.manual_seed(0) + for _ in range(2): + m(torch.randn(1, 4, HIDDEN_DIM)) + + try: + mtq.quantize(model, self._nongated_fp8_cfg(), forward_loop=forward_loop) + converted = model.moe.experts + _export_fused_experts(converted, torch.float16) + + for idx in range(NUM_EXPERTS): + expert_mod = getattr(converted, str(idx), None) + assert expert_mod is not None, f"Missing expert submodule {idx}" + # Non-gated: up_proj + down_proj, but NO gate_proj. + assert hasattr(expert_mod, "up_proj"), f"Expert {idx} missing up_proj" + assert hasattr(expert_mod, "down_proj"), f"Expert {idx} missing down_proj" + assert not hasattr(expert_mod, "gate_proj"), ( + f"Expert {idx} should NOT have gate_proj (non-gated MLP)" + ) + assert expert_mod.up_proj.weight.shape == (INTERMEDIATE_DIM, HIDDEN_DIM) + assert expert_mod.down_proj.weight.shape == (HIDDEN_DIM, INTERMEDIATE_DIM) + + # Fused params and per-expert quantizer lists are removed. + assert not hasattr(converted, "up_proj") + assert not hasattr(converted, "down_proj") + assert not hasattr(converted, "up_proj_weight_quantizers") + assert not hasattr(converted, "down_proj_weight_quantizers") + finally: + self._cleanup_registry(expert_type) + + def test_enumeration_yields_up_and_down_proj(self): + """weight_attr_names must yield up_proj and down_proj for non-gated experts.""" + model = _TinyNonGatedMoEModel() + expert_type = type(model.moe.experts) + self._cleanup_registry(expert_type) + register_fused_experts_on_the_fly(model) + try: + converted = QuantModuleRegistry.convert(model.moe.experts) + assert set(weight_attr_names(converted)) == {"up_proj", "down_proj"} + finally: + self._cleanup_registry(expert_type) + + def test_split_gated_layout_not_claimed_as_nongated(self): + """A fused container with a separate 3-D gate_proj (split-gated: three + F.linear calls per expert) must NOT be claimed by the non-gated wrapper, + whose two-call toggle and up_proj-storage index recovery assume exactly + two projections. It is left unsupported (None) rather than mis-quantized.""" + + class _SplitGatedExperts(nn.Module): + def __init__(self): + super().__init__() + self.num_experts = NUM_EXPERTS + self.gate_proj = nn.Parameter( + torch.randn(NUM_EXPERTS, INTERMEDIATE_DIM, HIDDEN_DIM) * 0.02 + ) + self.up_proj = nn.Parameter( + torch.randn(NUM_EXPERTS, INTERMEDIATE_DIM, HIDDEN_DIM) * 0.02 + ) + self.down_proj = nn.Parameter( + torch.randn(NUM_EXPERTS, HIDDEN_DIM, INTERMEDIATE_DIM) * 0.02 + ) + self.act_fn = nn.SiLU() + + module = _SplitGatedExperts() + assert _fused_experts_wrapper_class(module) is None + assert _is_fused_experts_module(module) is False + + def test_get_quant_config_resolves_nongated_experts(self): + """get_quant_config must detect the non-gated experts as quantized.""" + model = _TinyNonGatedMoEModel() + expert_type = type(model.moe.experts) + self._cleanup_registry(expert_type) + + def forward_loop(m): + torch.manual_seed(0) + for _ in range(2): + m(torch.randn(1, 4, HIDDEN_DIM)) + + try: + mtq.quantize(model, self._nongated_fp8_cfg(), forward_loop=forward_loop) + # Format resolves (via down_proj) instead of QUANTIZATION_NONE (None). + assert get_quantization_format(model.moe.experts) is not None + # The non-gated experts are reflected in the produced quant config. + quant = get_quant_config(model)["quantization"] + assert quant.get("quant_algo") is not None + finally: + self._cleanup_registry(expert_type)