From 998856b373e7beb7be349f6ce5bd4ab920d891c8 Mon Sep 17 00:00:00 2001 From: Chenjie Luo Date: Tue, 9 Dec 2025 07:44:00 +0000 Subject: [PATCH 01/11] Support KIMI K2 Thinking int4 checkpoint PTQ Signed-off-by: Chenjie Luo --- examples/llm_ptq/example_utils.py | 7 ++++ examples/llm_ptq/hf_ptq.py | 38 ++++++++++++------- .../llm_ptq/scripts/huggingface_example.sh | 4 +- .../torch/quantization/plugins/huggingface.py | 28 ++++++++++++++ 4 files changed, 61 insertions(+), 16 deletions(-) diff --git a/examples/llm_ptq/example_utils.py b/examples/llm_ptq/example_utils.py index ce3fb0853..c13b9f897 100755 --- a/examples/llm_ptq/example_utils.py +++ b/examples/llm_ptq/example_utils.py @@ -327,6 +327,13 @@ def get_model( device_map=device_map, **model_kwargs, ) + elif hf_config.quantization_config.get("format", None) == "pack-quantized": + model = AutoModelForCausalLM.from_pretrained( + ckpt_path, + torch_dtype=torch.float16, + device_map="auto", + trust_remote_code=trust_remote_code, + ) else: architecture = hf_config.architectures[0] diff --git a/examples/llm_ptq/hf_ptq.py b/examples/llm_ptq/hf_ptq.py index 57f0b5a89..9697adf08 100755 --- a/examples/llm_ptq/hf_ptq.py +++ b/examples/llm_ptq/hf_ptq.py @@ -14,6 +14,7 @@ # limitations under the License. import argparse +import gc import random import time import warnings @@ -510,19 +511,26 @@ def main(args): "input_features" if model_type == "whisper" else "input_ids" ][0:1] - # Generate preview before quantization - if is_nemotron_vl_model and tokenizer is not None: - generated_ids_before_ptq = run_nemotron_vl_preview( - full_model, - tokenizer, - input_ids, - args.pyt_ckpt_path, - "before quantization", - allow_fallback=True, - ) - else: - # Standard generation for non-Nemotron VL models - generated_ids_before_ptq = full_model.generate(input_ids, max_new_tokens=100) + try: + # Generate preview before quantization + if is_nemotron_vl_model and tokenizer is not None: + generated_ids_before_ptq = run_nemotron_vl_preview( + full_model, + tokenizer, + input_ids, + args.pyt_ckpt_path, + "before quantization", + allow_fallback=True, + ) + else: + # Standard generation for non-Nemotron VL models + generated_ids_before_ptq = full_model.generate(input_ids, max_new_tokens=100) + except torch.OutOfMemoryError: + print("Out of memory. Skipping preview generation.") + generated_ids_before_ptq = None + gc.collect() + torch.cuda.empty_cache() + if model_type == "gptoss" and args.qformat == "nvfp4_mlp_only": print("Applying nvfp4 quantization (MoE only) for gpt-oss") @@ -542,7 +550,9 @@ def main(args): # Run some samples torch.cuda.empty_cache() generated_ids_after_ptq = None - if model_type != "llama4" and not is_nemotron_vl_model: + if generated_ids_before_ptq is None: + pass + elif model_type != "llama4" and not is_nemotron_vl_model: # Our fake quantizer may not be fully compatible with torch.compile. 
generated_ids_after_ptq = full_model.generate(input_ids, max_new_tokens=100) elif is_nemotron_vl_model and tokenizer is not None: diff --git a/examples/llm_ptq/scripts/huggingface_example.sh b/examples/llm_ptq/scripts/huggingface_example.sh index 043b690e5..3ea85de9e 100755 --- a/examples/llm_ptq/scripts/huggingface_example.sh +++ b/examples/llm_ptq/scripts/huggingface_example.sh @@ -53,9 +53,9 @@ esac IFS="," for qformat in $QFORMAT; do case $qformat in - fp8 | fp8_pc_pt | fp8_pb_wo | int8_wo | int8_sq | int4_awq | w4a8_awq | fp16 | bf16 | nvfp4 | nvfp4_awq | w4a8_nvfp4_fp8 | w4a8_mxfp4_fp8) ;; + fp8 | fp8_pc_pt | fp8_pb_wo | int8_wo | int8_sq | int4_awq | w4a8_awq | fp16 | bf16 | nvfp4 | nvfp4_awq | w4a8_nvfp4_fp8 | w4a8_mxfp4_fp8 | nvfp4_mlp_only) ;; *) - echo "Unknown quant argument: Expected one of: [fp8, fp8_pc_pt, fp8_pb_wo, int8_wo, int8_sq, int4_awq, w4a8_awq, fp16, bf16, nvfp4, nvfp4_awq, w4a8_nvfp4_fp8, w4a8_mxfp4_fp8]" >&2 + echo "Unknown quant argument: Expected one of: [fp8, fp8_pc_pt, fp8_pb_wo, int8_wo, int8_sq, int4_awq, w4a8_awq, fp16, bf16, nvfp4, nvfp4_awq, w4a8_nvfp4_fp8, w4a8_mxfp4_fp8, nvfp4_mlp_only]" >&2 exit 1 ;; esac diff --git a/modelopt/torch/quantization/plugins/huggingface.py b/modelopt/torch/quantization/plugins/huggingface.py index 31ac2bbbd..458c72bce 100644 --- a/modelopt/torch/quantization/plugins/huggingface.py +++ b/modelopt/torch/quantization/plugins/huggingface.py @@ -22,6 +22,8 @@ from typing import TYPE_CHECKING import torch +from torch import Tensor +from torch.nn.functional import linear try: from torch.distributed.tensor import Shard @@ -501,6 +503,22 @@ def top_k(self, value): self.router.moe_top_k = value +class _QuantCompressedLinear(QuantModule): + def _setup(self): + self.input_quantizer = TensorQuantizer() + self.weight_quantizer = TensorQuantizer() + + def forward(self, input: Tensor) -> Tensor: + from compressed_tensors.quantization import QuantizationStatus + + if self.quantization_status == QuantizationStatus.COMPRESSED: + weight_data = self.compressor.decompress_module(self) + else: + weight_data = self.weight + + return linear(self.input_quantizer(input), self.weight_quantizer(weight_data), self.bias) + + try: from transformers.models.llama4.modeling_llama4 import Llama4TextExperts, Llama4TextMoe @@ -576,6 +594,16 @@ def top_k(self, value): except ImportError: pass +try: + from compressed_tensors.linear.compressed_linear import CompressedLinear + + if CompressedLinear not in QuantModuleRegistry: + QuantModuleRegistry.register({CompressedLinear: "hf.CompressedLinear"})( + _QuantCompressedLinear + ) +except ImportError: + pass + class _QuantGptOssExperts(_QuantFunctionalMixin): """Quantized wrapper for `transformers.GptOssExperts`. From 3aebac77c72f83f29640d05e8fc70f92474f7f6f Mon Sep 17 00:00:00 2001 From: Chenjie Luo Date: Tue, 9 Dec 2025 17:27:45 +0000 Subject: [PATCH 02/11] Fix Signed-off-by: Chenjie Luo --- examples/llm_ptq/hf_ptq.py | 36 +++++++++++++++++------------------- 1 file changed, 17 insertions(+), 19 deletions(-) diff --git a/examples/llm_ptq/hf_ptq.py b/examples/llm_ptq/hf_ptq.py index 9697adf08..d34f9fdbb 100755 --- a/examples/llm_ptq/hf_ptq.py +++ b/examples/llm_ptq/hf_ptq.py @@ -14,7 +14,6 @@ # limitations under the License. 
import argparse -import gc import random import time import warnings @@ -511,25 +510,24 @@ def main(args): "input_features" if model_type == "whisper" else "input_ids" ][0:1] - try: - # Generate preview before quantization - if is_nemotron_vl_model and tokenizer is not None: - generated_ids_before_ptq = run_nemotron_vl_preview( - full_model, - tokenizer, - input_ids, - args.pyt_ckpt_path, - "before quantization", - allow_fallback=True, - ) - else: - # Standard generation for non-Nemotron VL models - generated_ids_before_ptq = full_model.generate(input_ids, max_new_tokens=100) - except torch.OutOfMemoryError: - print("Out of memory. Skipping preview generation.") + # Generate preview before quantization + if model_type == "deepseek": + print( + "Deepseek model may hit OOM during preview generation. Skipping preview generation." + ) generated_ids_before_ptq = None - gc.collect() - torch.cuda.empty_cache() + elif is_nemotron_vl_model and tokenizer is not None: + generated_ids_before_ptq = run_nemotron_vl_preview( + full_model, + tokenizer, + input_ids, + args.pyt_ckpt_path, + "before quantization", + allow_fallback=True, + ) + else: + # Standard generation for non-Nemotron VL models + generated_ids_before_ptq = full_model.generate(input_ids, max_new_tokens=100) if model_type == "gptoss" and args.qformat == "nvfp4_mlp_only": print("Applying nvfp4 quantization (MoE only) for gpt-oss") From 95ee2752e4d2156f248c5898b92f8bee4bce740f Mon Sep 17 00:00:00 2001 From: Chenjie Luo Date: Wed, 10 Dec 2025 00:12:04 +0000 Subject: [PATCH 03/11] Fix Signed-off-by: Chenjie Luo --- examples/llm_ptq/example_utils.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/examples/llm_ptq/example_utils.py b/examples/llm_ptq/example_utils.py index c13b9f897..f31d11b8f 100755 --- a/examples/llm_ptq/example_utils.py +++ b/examples/llm_ptq/example_utils.py @@ -330,7 +330,6 @@ def get_model( elif hf_config.quantization_config.get("format", None) == "pack-quantized": model = AutoModelForCausalLM.from_pretrained( ckpt_path, - torch_dtype=torch.float16, device_map="auto", trust_remote_code=trust_remote_code, ) @@ -353,9 +352,9 @@ def get_model( from_config = auto_model_module._from_config with init_empty_weights(): - # When computing the device_map, assuming half precision by default, + # When computing the device_map, assuming bfloat16 precision by default, # unless specified by the hf_config. 
- torch_dtype = getattr(hf_config, "torch_dtype", torch.float16) + torch_dtype = getattr(hf_config, "torch_dtype", torch.bfloat16) model_kwargs2 = model_kwargs.copy() if auto_model_module != AutoModelForCausalLM: model_kwargs2.pop("trust_remote_code", None) From a09d86f6802e77921db08b73d70bc0976471902d Mon Sep 17 00:00:00 2001 From: Chenjie Luo Date: Mon, 5 Jan 2026 23:50:17 +0000 Subject: [PATCH 04/11] Fix export Signed-off-by: Chenjie Luo --- modelopt/torch/export/layer_utils.py | 3 ++- modelopt/torch/export/unified_export_hf.py | 4 ++++ modelopt/torch/quantization/plugins/huggingface.py | 8 ++++++++ 3 files changed, 14 insertions(+), 1 deletion(-) diff --git a/modelopt/torch/export/layer_utils.py b/modelopt/torch/export/layer_utils.py index e35ee070f..9346e074b 100755 --- a/modelopt/torch/export/layer_utils.py +++ b/modelopt/torch/export/layer_utils.py @@ -345,7 +345,8 @@ def is_moe(module: nn.Module) -> bool: def is_quantlinear(module: nn.Module) -> bool: """Returns whether the module is a quantized linear layer.""" - return "QuantLinear" in type(module).__name__ and "lora" not in type(module).__name__.lower() + name = type(module).__name__ + return ("QuantLinear" in name or "QuantCompressedLinear" in name) and "lora" not in name.lower() def dup_kv_weight(v: torch.Tensor, head_size: int, num_head: int, tp_size: int) -> torch.Tensor: diff --git a/modelopt/torch/export/unified_export_hf.py b/modelopt/torch/export/unified_export_hf.py index a98e455db..2c648073e 100644 --- a/modelopt/torch/export/unified_export_hf.py +++ b/modelopt/torch/export/unified_export_hf.py @@ -390,6 +390,8 @@ def _export_quantized_weight( if weight_scale is not None: sub_module.register_buffer(quantizer_attrs.weight_scale, weight_scale) + torch.cuda.empty_cache() + def _export_hf_checkpoint( model: nn.Module, dtype: torch.dtype | None = None, is_modelopt_qlora: bool = False, **kwargs @@ -516,6 +518,8 @@ def _export_hf_checkpoint( if is_modelopt_qlora and (hasattr(sub_module, "base_layer")): continue + if hasattr(sub_module, "weight_packed"): + sub_module.unpack_weight() if get_quantization_format(sub_module) != QUANTIZATION_NONE: if is_quantlinear(sub_module): with fsdp2_aware_weight_update(model, sub_module, reshard=False): diff --git a/modelopt/torch/quantization/plugins/huggingface.py b/modelopt/torch/quantization/plugins/huggingface.py index 458c72bce..ae7e596fc 100644 --- a/modelopt/torch/quantization/plugins/huggingface.py +++ b/modelopt/torch/quantization/plugins/huggingface.py @@ -518,6 +518,14 @@ def forward(self, input: Tensor) -> Tensor: return linear(self.input_quantizer(input), self.weight_quantizer(weight_data), self.bias) + def unpack_weight(self): + from compressed_tensors.quantization import QuantizationStatus + + if self.quantization_status == QuantizationStatus.COMPRESSED: + self.weight = nn.Parameter(self.compressor.decompress_module(self), requires_grad=False) + del self.weight_packed + del self.weight_scale + try: from transformers.models.llama4.modeling_llama4 import Llama4TextExperts, Llama4TextMoe From 09c12afaaf04db3f6ef7fd85814a3960bed10dd1 Mon Sep 17 00:00:00 2001 From: Chenjie Luo Date: Tue, 6 Jan 2026 21:23:28 +0000 Subject: [PATCH 05/11] Fix Signed-off-by: Chenjie Luo --- examples/llm_ptq/example_utils.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/examples/llm_ptq/example_utils.py b/examples/llm_ptq/example_utils.py index f31d11b8f..dba746e80 100755 --- a/examples/llm_ptq/example_utils.py +++ b/examples/llm_ptq/example_utils.py @@ -174,6 +174,11 @@ def build_quant_cfg( 
quant_cfg["quant_cfg"]["*image*"] = {"enable": False} quant_cfg["quant_cfg"]["*vision*"] = {"enable": False} + if model_type == "deepseek": + # Disable MLA quantization for accuracy. + quant_cfg["quant_cfg"]["*self_attn.q*"] = {"enable": False} + quant_cfg["quant_cfg"]["*self_attn.kv*"] = {"enable": False} + return quant_cfg @@ -328,10 +333,12 @@ def get_model( **model_kwargs, ) elif hf_config.quantization_config.get("format", None) == "pack-quantized": + torch_dtype = getattr(hf_config, "torch_dtype", torch.bfloat16) model = AutoModelForCausalLM.from_pretrained( ckpt_path, device_map="auto", trust_remote_code=trust_remote_code, + torch_dtype=torch_dtype, ) else: architecture = hf_config.architectures[0] From 295bbb7496a7eb4bd8959f6393014d037d4be20f Mon Sep 17 00:00:00 2001 From: Chenjie Luo Date: Wed, 14 Jan 2026 01:21:54 +0000 Subject: [PATCH 06/11] Fix export Signed-off-by: Chenjie Luo --- modelopt/torch/export/quant_utils.py | 8 +++++++- modelopt/torch/export/unified_export_hf.py | 4 ++++ 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/modelopt/torch/export/quant_utils.py b/modelopt/torch/export/quant_utils.py index eee13dc51..fe139a07e 100755 --- a/modelopt/torch/export/quant_utils.py +++ b/modelopt/torch/export/quant_utils.py @@ -881,7 +881,13 @@ def postprocess_state_dict( "v_bmm_quantizer._bias_value": "v_proj.v_bias", "input_quantizer._pre_quant_scale": "pre_quant_scale", } - skip_keys = ["output_quantizer", "_amax", "_bias_value", "input_quantizer._pre_quant_scale"] + skip_keys = [ + "output_quantizer", + "_amax", + "_bias_value", + "input_quantizer._pre_quant_scale", + "weight_shape", + ] # For modelopt-trained LoRA models, we need to remove the base_layer prefix from the keys for deployment if is_modelopt_qlora: diff --git a/modelopt/torch/export/unified_export_hf.py b/modelopt/torch/export/unified_export_hf.py index 2c648073e..36c2e332f 100644 --- a/modelopt/torch/export/unified_export_hf.py +++ b/modelopt/torch/export/unified_export_hf.py @@ -592,6 +592,10 @@ def export_hf_checkpoint( hf_quant_config = convert_hf_quant_config_format(hf_quant_config) + # Remove hf_quantizer from model so post_state_dict can be exported. 
+ if getattr(model, "hf_quantizer", None) is not None: + model.hf_quantizer = None + # Save model model.save_pretrained( export_dir, state_dict=post_state_dict, save_modelopt_state=save_modelopt_state From bfb9a141ea2347151d0d1fd56d84ae79bab83baf Mon Sep 17 00:00:00 2001 From: Chenjie Luo Date: Thu, 15 Jan 2026 16:21:08 +0000 Subject: [PATCH 07/11] Update Signed-off-by: Chenjie Luo --- examples/llm_ptq/example_utils.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/examples/llm_ptq/example_utils.py b/examples/llm_ptq/example_utils.py index 9eaa8c76a..2298b2b31 100755 --- a/examples/llm_ptq/example_utils.py +++ b/examples/llm_ptq/example_utils.py @@ -351,7 +351,10 @@ def get_model( device_map=device_map, **model_kwargs, ) - elif hf_config.quantization_config.get("format", None) == "pack-quantized": + elif ( + hasattr(hf_config, "quantization_config") + and hf_config.quantization_config.get("format", None) == "pack-quantized" + ): torch_dtype = getattr(hf_config, "torch_dtype", torch.bfloat16) model = AutoModelForCausalLM.from_pretrained( ckpt_path, From 55c72246aaf025e3d80668b893b1ca7058e7840f Mon Sep 17 00:00:00 2001 From: Chenjie Luo Date: Thu, 15 Jan 2026 17:00:43 +0000 Subject: [PATCH 08/11] Fix Signed-off-by: Chenjie Luo --- examples/llm_ptq/hf_ptq.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/llm_ptq/hf_ptq.py b/examples/llm_ptq/hf_ptq.py index 2a4a28410..ddf6bfbdc 100755 --- a/examples/llm_ptq/hf_ptq.py +++ b/examples/llm_ptq/hf_ptq.py @@ -623,7 +623,7 @@ def post_quantize( generated_ids_after_ptq = None if generated_ids_before_ptq is None: pass - if model_type != "llama4" and not is_nemotron_vl_model: + elif model_type != "llama4" and not is_nemotron_vl_model: # Our fake quantizer may not be fully compatible with torch.compile. generated_ids_after_ptq = full_model.generate(preview_input_ids, max_new_tokens=100) elif is_nemotron_vl_model and tokenizer is not None: From be81219634fbe68624afc63c85013423afb222a0 Mon Sep 17 00:00:00 2001 From: Chenjie Luo Date: Thu, 15 Jan 2026 17:09:01 +0000 Subject: [PATCH 09/11] Fix Signed-off-by: Chenjie Luo --- examples/llm_ptq/example_utils.py | 8 ++++---- modelopt/torch/quantization/plugins/huggingface.py | 6 ++++-- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/examples/llm_ptq/example_utils.py b/examples/llm_ptq/example_utils.py index 2298b2b31..e213a7f25 100755 --- a/examples/llm_ptq/example_utils.py +++ b/examples/llm_ptq/example_utils.py @@ -187,10 +187,10 @@ def build_quant_cfg( quant_cfg["quant_cfg"]["model*.*attn*k_proj*"] = {"enable": False} quant_cfg["quant_cfg"]["model*.*attn*v_proj*"] = {"enable": False} - if model_type == "deepseek": - # Disable MLA quantization for accuracy. - quant_cfg["quant_cfg"]["*self_attn.q*"] = {"enable": False} - quant_cfg["quant_cfg"]["*self_attn.kv*"] = {"enable": False} + if model_type == "deepseek": + # Disable MLA quantization for accuracy. 
+ quant_cfg["quant_cfg"]["*self_attn.q*"] = {"enable": False} + quant_cfg["quant_cfg"]["*self_attn.kv*"] = {"enable": False} return quant_cfg diff --git a/modelopt/torch/quantization/plugins/huggingface.py b/modelopt/torch/quantization/plugins/huggingface.py index ba32077d7..0ae3a4c0a 100644 --- a/modelopt/torch/quantization/plugins/huggingface.py +++ b/modelopt/torch/quantization/plugins/huggingface.py @@ -607,8 +607,10 @@ def unpack_weight(self): if self.quantization_status == QuantizationStatus.COMPRESSED: self.weight = nn.Parameter(self.compressor.decompress_module(self), requires_grad=False) - del self.weight_packed - del self.weight_scale + if hasattr(self, "weight_packed"): + del self.weight_packed + if hasattr(self, "weight_scale"): + del self.weight_scale try: From aa7f8be010b9a4311e65bba8d8f968c61299da1a Mon Sep 17 00:00:00 2001 From: Chenjie Luo Date: Thu, 15 Jan 2026 17:12:16 +0000 Subject: [PATCH 10/11] Update doc Signed-off-by: Chenjie Luo --- CHANGELOG.rst | 7 +++++++ examples/llm_ptq/README.md | 1 + 2 files changed, 8 insertions(+) diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 978ac209d..42f678b16 100755 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -1,6 +1,13 @@ NVIDIA Model Optimizer Changelog (Linux) ======================================== +0.42 (XXXX-XX-XX) +^^^^^^^^^^^^^^^^^ + +**New Features** + +- Add support for Kimi K2 Thinking model quantization from the original int4 checkpoint. + 0.41 (2026-01-19) ^^^^^^^^^^^^^^^^^ diff --git a/examples/llm_ptq/README.md b/examples/llm_ptq/README.md index 1ceb213c7..6c75b134a 100755 --- a/examples/llm_ptq/README.md +++ b/examples/llm_ptq/README.md @@ -109,6 +109,7 @@ Please reference our [framework scripts](#framework-scripts) and our [docs](http | QWen3 MOE, Next 6 | ✅ | - | - | - | ✅ | | QwQ | ✅ | - | - | - | ✅ | | DeepSeek V3, R1, V3.1, V3.27 | - | - | - | - | ✅ | +| Kimi K2 | - | - | - | - | ✅ | | T5 | ✅ | ✅ | ✅ | ✅ | - | | Whisper | ✅ | ❌ | ❌ | ❌ | - | From 4200d3c9059df9692f1904db37dd3e5b0c562dff Mon Sep 17 00:00:00 2001 From: Chenjie Luo Date: Wed, 21 Jan 2026 06:20:19 +0000 Subject: [PATCH 11/11] Add compressed tensors Signed-off-by: Chenjie Luo --- examples/llm_ptq/requirements.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/examples/llm_ptq/requirements.txt b/examples/llm_ptq/requirements.txt index 3485f10e1..1469d5552 100644 --- a/examples/llm_ptq/requirements.txt +++ b/examples/llm_ptq/requirements.txt @@ -1,3 +1,4 @@ +compressed-tensors==0.12.0 fire flash-attn>=2.6.0 rouge_score>=0.1.2
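
Usage sketch (illustrative, not part of the patch series): the changes above teach get_model() to load compressed-tensors "pack-quantized" checkpoints such as Kimi K2 Thinking, route their CompressedLinear modules through _QuantCompressedLinear during calibration, and unpack weight_packed before the unified HF export. A minimal end-to-end sketch of that flow follows; the checkpoint id, the tiny calibration loop, and the NVFP4_DEFAULT_CFG choice are assumptions for illustration only (hf_ptq.py drives the real flow with its own calibration dataset and qformat flags).

# Illustrative sketch only -- checkpoint id, calibration loop, and config choice are assumptions.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

import modelopt.torch.quantization as mtq
from modelopt.torch.export import export_hf_checkpoint

ckpt = "moonshotai/Kimi-K2-Thinking"  # assumed: a compressed-tensors "pack-quantized" int4 checkpoint

# Mirrors the pack-quantized branch added to get_model(): the checkpoint loads with its
# CompressedLinear modules still holding packed int4 weights (quantization_status == COMPRESSED).
model = AutoModelForCausalLM.from_pretrained(ckpt, device_map="auto", trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(ckpt, trust_remote_code=True)

def forward_loop(m):
    # Tiny stand-in for the real calibration loop in hf_ptq.py.
    batch = tokenizer("Calibration sample.", return_tensors="pt").to(m.device)
    with torch.no_grad():
        m(**batch)

# CompressedLinear is registered in QuantModuleRegistry by this series, so quantize() swaps it
# for _QuantCompressedLinear, which decompresses the packed weight on the fly in forward().
model = mtq.quantize(model, mtq.NVFP4_DEFAULT_CFG, forward_loop)

# The export path calls unpack_weight() on modules carrying weight_packed before writing
# the unified Hugging Face checkpoint.
export_hf_checkpoint(model, export_dir="kimi-k2-thinking-nvfp4")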