diff --git a/CHANGELOG.rst b/CHANGELOG.rst
index 978ac209d..42f678b16 100755
--- a/CHANGELOG.rst
+++ b/CHANGELOG.rst
@@ -1,6 +1,13 @@
NVIDIA Model Optimizer Changelog (Linux)
========================================
+0.42 (XXXX-XX-XX)
+^^^^^^^^^^^^^^^^^
+
+**New Features**
+
+- Add support for quantizing the Kimi K2 Thinking model from its original int4 checkpoint.
+
0.41 (2026-01-19)
^^^^^^^^^^^^^^^^^
diff --git a/examples/llm_ptq/README.md b/examples/llm_ptq/README.md
index 1ceb213c7..6c75b134a 100755
--- a/examples/llm_ptq/README.md
+++ b/examples/llm_ptq/README.md
@@ -109,6 +109,7 @@ Please reference our [framework scripts](#framework-scripts) and our [docs](http
| QWen3 MOE, Next 6 | ✅ | - | - | - | ✅ |
| QwQ | ✅ | - | - | - | ✅ |
| DeepSeek V3, R1, V3.1, V3.27 | - | - | - | - | ✅ |
+| Kimi K2 | - | - | - | - | ✅ |
| T5 | ✅ | ✅ | ✅ | ✅ | - |
| Whisper | ✅ | ❌ | ❌ | ❌ | - |
diff --git a/examples/llm_ptq/example_utils.py b/examples/llm_ptq/example_utils.py
index 312267991..e213a7f25 100755
--- a/examples/llm_ptq/example_utils.py
+++ b/examples/llm_ptq/example_utils.py
@@ -187,6 +187,11 @@ def build_quant_cfg(
quant_cfg["quant_cfg"]["model*.*attn*k_proj*"] = {"enable": False}
quant_cfg["quant_cfg"]["model*.*attn*v_proj*"] = {"enable": False}
+ if model_type == "deepseek":
+ # Disable quantization of the MLA q/kv projections to preserve accuracy.
+ quant_cfg["quant_cfg"]["*self_attn.q*"] = {"enable": False}
+ quant_cfg["quant_cfg"]["*self_attn.kv*"] = {"enable": False}
+
return quant_cfg
@@ -346,6 +351,17 @@ def get_model(
device_map=device_map,
**model_kwargs,
)
+ elif (
+ hasattr(hf_config, "quantization_config")
+ and hf_config.quantization_config.get("format", None) == "pack-quantized"
+ ):
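+ # "pack-quantized" indicates a compressed-tensors packed-weight checkpoint (e.g. the
+ # original int4 Kimi K2 release); load it as-is so the packed linear layers can be
+ # handled by the _QuantCompressedLinear plugin during quantization.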
+ torch_dtype = getattr(hf_config, "torch_dtype", torch.bfloat16)
+ model = AutoModelForCausalLM.from_pretrained(
+ ckpt_path,
+ device_map="auto",
+ trust_remote_code=trust_remote_code,
+ torch_dtype=torch_dtype,
+ )
else:
architecture = hf_config.architectures[0]
@@ -366,9 +382,9 @@ def get_model(
from_config = auto_model_module._from_config
with init_empty_weights():
- # When computing the device_map, assuming half precision by default,
+ # When computing the device_map, assuming bfloat16 precision by default,
# unless specified by the hf_config.
- torch_dtype = getattr(hf_config, "torch_dtype", torch.float16)
+ torch_dtype = getattr(hf_config, "torch_dtype", torch.bfloat16)
model_kwargs2 = model_kwargs.copy()
if auto_model_module != AutoModelForCausalLM:
model_kwargs2.pop("trust_remote_code", None)
diff --git a/examples/llm_ptq/hf_ptq.py b/examples/llm_ptq/hf_ptq.py
index a9862a742..ddf6bfbdc 100755
--- a/examples/llm_ptq/hf_ptq.py
+++ b/examples/llm_ptq/hf_ptq.py
@@ -575,7 +575,10 @@ def pre_quantize(
][0:1]
# Generate preview before quantization
- if is_nemotron_vl_model and tokenizer is not None:
+ if model_type == "deepseek":
+ # DeepSeek preview generation may run out of GPU memory, so skip it.
+ generated_ids_before_ptq = None
+ elif is_nemotron_vl_model and tokenizer is not None:
generated_ids_before_ptq = run_nemotron_vl_preview(
full_model,
tokenizer,
@@ -618,7 +621,9 @@ def post_quantize(
# Run some samples
torch.cuda.empty_cache()
generated_ids_after_ptq = None
- if model_type != "llama4" and not is_nemotron_vl_model:
+ if generated_ids_before_ptq is None:
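+ # No pre-quantization preview was generated (e.g. skipped for DeepSeek), so skip it here too.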
+ pass
+ elif model_type != "llama4" and not is_nemotron_vl_model:
# Our fake quantizer may not be fully compatible with torch.compile.
generated_ids_after_ptq = full_model.generate(preview_input_ids, max_new_tokens=100)
elif is_nemotron_vl_model and tokenizer is not None:
diff --git a/examples/llm_ptq/scripts/huggingface_example.sh b/examples/llm_ptq/scripts/huggingface_example.sh
index 043b690e5..3ea85de9e 100755
--- a/examples/llm_ptq/scripts/huggingface_example.sh
+++ b/examples/llm_ptq/scripts/huggingface_example.sh
@@ -53,9 +53,9 @@ esac
IFS=","
for qformat in $QFORMAT; do
case $qformat in
- fp8 | fp8_pc_pt | fp8_pb_wo | int8_wo | int8_sq | int4_awq | w4a8_awq | fp16 | bf16 | nvfp4 | nvfp4_awq | w4a8_nvfp4_fp8 | w4a8_mxfp4_fp8) ;;
+ fp8 | fp8_pc_pt | fp8_pb_wo | int8_wo | int8_sq | int4_awq | w4a8_awq | fp16 | bf16 | nvfp4 | nvfp4_awq | w4a8_nvfp4_fp8 | w4a8_mxfp4_fp8 | nvfp4_mlp_only) ;;
*)
- echo "Unknown quant argument: Expected one of: [fp8, fp8_pc_pt, fp8_pb_wo, int8_wo, int8_sq, int4_awq, w4a8_awq, fp16, bf16, nvfp4, nvfp4_awq, w4a8_nvfp4_fp8, w4a8_mxfp4_fp8]" >&2
+ echo "Unknown quant argument: Expected one of: [fp8, fp8_pc_pt, fp8_pb_wo, int8_wo, int8_sq, int4_awq, w4a8_awq, fp16, bf16, nvfp4, nvfp4_awq, w4a8_nvfp4_fp8, w4a8_mxfp4_fp8, nvfp4_mlp_only]" >&2
exit 1
;;
esac
diff --git a/modelopt/torch/export/layer_utils.py b/modelopt/torch/export/layer_utils.py
index e35ee070f..9346e074b 100755
--- a/modelopt/torch/export/layer_utils.py
+++ b/modelopt/torch/export/layer_utils.py
@@ -345,7 +345,8 @@ def is_moe(module: nn.Module) -> bool:
def is_quantlinear(module: nn.Module) -> bool:
"""Returns whether the module is a quantized linear layer."""
- return "QuantLinear" in type(module).__name__ and "lora" not in type(module).__name__.lower()
+ name = type(module).__name__
+ return ("QuantLinear" in name or "QuantCompressedLinear" in name) and "lora" not in name.lower()
def dup_kv_weight(v: torch.Tensor, head_size: int, num_head: int, tp_size: int) -> torch.Tensor:
diff --git a/modelopt/torch/export/quant_utils.py b/modelopt/torch/export/quant_utils.py
index eee13dc51..fe139a07e 100755
--- a/modelopt/torch/export/quant_utils.py
+++ b/modelopt/torch/export/quant_utils.py
@@ -881,7 +881,13 @@ def postprocess_state_dict(
"v_bmm_quantizer._bias_value": "v_proj.v_bias",
"input_quantizer._pre_quant_scale": "pre_quant_scale",
}
- skip_keys = ["output_quantizer", "_amax", "_bias_value", "input_quantizer._pre_quant_scale"]
+ skip_keys = [
+ "output_quantizer",
+ "_amax",
+ "_bias_value",
+ "input_quantizer._pre_quant_scale",
+ "weight_shape",
+ ]
# For modelopt-trained LoRA models, we need to remove the base_layer prefix from the keys for deployment
if is_modelopt_qlora:
diff --git a/modelopt/torch/export/unified_export_hf.py b/modelopt/torch/export/unified_export_hf.py
index 1dd1c1822..abe7f814d 100644
--- a/modelopt/torch/export/unified_export_hf.py
+++ b/modelopt/torch/export/unified_export_hf.py
@@ -392,6 +392,8 @@ def _export_quantized_weight(
if weight_scale is not None:
sub_module.register_buffer(quantizer_attrs.weight_scale, weight_scale)
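+ # Release cached GPU memory after each exported weight; large checkpoints can
+ # otherwise run out of memory during export.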
+ torch.cuda.empty_cache()
+
def _export_hf_checkpoint(
model: nn.Module, dtype: torch.dtype | None = None, is_modelopt_qlora: bool = False, **kwargs
@@ -518,6 +520,8 @@ def _export_hf_checkpoint(
if is_modelopt_qlora and (hasattr(sub_module, "base_layer")):
continue
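+ # Compressed-tensors modules still hold packed int4 weights; materialize the full
+ # weight tensor before quantized export.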
+ if hasattr(sub_module, "weight_packed"):
+ sub_module.unpack_weight()
if get_quantization_format(sub_module) != QUANTIZATION_NONE:
if is_quantlinear(sub_module):
with fsdp2_aware_weight_update(model, sub_module, reshard=False):
@@ -590,6 +594,10 @@ def export_hf_checkpoint(
hf_quant_config = convert_hf_quant_config_format(hf_quant_config)
+ # Remove hf_quantizer from model so post_state_dict can be exported.
+ if getattr(model, "hf_quantizer", None) is not None:
+ model.hf_quantizer = None
+
# Save model
model.save_pretrained(
export_dir, state_dict=post_state_dict, save_modelopt_state=save_modelopt_state
diff --git a/modelopt/torch/quantization/plugins/huggingface.py b/modelopt/torch/quantization/plugins/huggingface.py
index 30fdc5244..0ae3a4c0a 100644
--- a/modelopt/torch/quantization/plugins/huggingface.py
+++ b/modelopt/torch/quantization/plugins/huggingface.py
@@ -22,6 +22,8 @@
from typing import TYPE_CHECKING
import torch
+from torch import Tensor
+from torch.nn.functional import linear
try:
from torch.distributed.tensor import Shard
@@ -585,6 +587,32 @@ def top_k(self, value):
self.router.moe_top_k = value
+class _QuantCompressedLinear(QuantModule):
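+ """Quantized wrapper for compressed-tensors `CompressedLinear` modules.
+
+ Decompresses the packed weight on the fly in `forward` (or once via `unpack_weight`)
+ so the standard input and weight TensorQuantizers can be applied.
+ """
+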
+ def _setup(self):
+ self.input_quantizer = TensorQuantizer()
+ self.weight_quantizer = TensorQuantizer()
+
+ def forward(self, input: Tensor) -> Tensor:
+ from compressed_tensors.quantization import QuantizationStatus
+
+ if self.quantization_status == QuantizationStatus.COMPRESSED:
+ weight_data = self.compressor.decompress_module(self)
+ else:
+ weight_data = self.weight
+
+ return linear(self.input_quantizer(input), self.weight_quantizer(weight_data), self.bias)
+
+ def unpack_weight(self):
+ from compressed_tensors.quantization import QuantizationStatus
+
+ if self.quantization_status == QuantizationStatus.COMPRESSED:
+ self.weight = nn.Parameter(self.compressor.decompress_module(self), requires_grad=False)
+ if hasattr(self, "weight_packed"):
+ del self.weight_packed
+ if hasattr(self, "weight_scale"):
+ del self.weight_scale
+
+
try:
from transformers.models.llama4.modeling_llama4 import Llama4TextExperts, Llama4TextMoe
@@ -660,6 +688,16 @@ def top_k(self, value):
except ImportError:
pass
+try:
+ from compressed_tensors.linear.compressed_linear import CompressedLinear
+
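+ # Map compressed-tensors CompressedLinear to the ModelOpt wrapper above so packed
+ # int4 checkpoints (e.g. Kimi K2) can be calibrated and re-quantized.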
+ if CompressedLinear not in QuantModuleRegistry:
+ QuantModuleRegistry.register({CompressedLinear: "hf.CompressedLinear"})(
+ _QuantCompressedLinear
+ )
+except ImportError:
+ pass
+
class _QuantGptOssExperts(_QuantFunctionalMixin):
"""Quantized wrapper for `transformers.GptOssExperts`.