Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 33 additions & 11 deletions src/mcore_bridge/bridge/gpt_bridge.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@

from mcore_bridge.config import ModelConfig
from mcore_bridge.tuners import LoraParallelLinear
from mcore_bridge.utils import (MxFp4Dequantizer, SafetensorLazyLoader, StreamingSafetensorSaver, deep_getattr,
gc_collect, get_logger, is_master, unwrap_model)
from mcore_bridge.utils import (MxFp4Dequantizer, PackedDequantizer, SafetensorLazyLoader, StreamingSafetensorSaver,
deep_getattr, gc_collect, get_logger, is_master, unwrap_model)

logger = get_logger()

Expand Down Expand Up @@ -81,6 +81,10 @@ def __init__(self, config: ModelConfig):

self._fp8_quantizer = None
self.mxfp4_quantizer = MxFp4Dequantizer()
quantization_config = getattr(self.config.hf_config, 'quantization_config', None)
self.packed_quantizer = None
if isinstance(quantization_config, dict) and quantization_config.get('quant_method') == 'compressed-tensors':
self.packed_quantizer = PackedDequantizer(quantization_config)
Comment thread
Jintao-Huang marked this conversation as resolved.
Comment thread
Jintao-Huang marked this conversation as resolved.

dp_size = dist.get_world_size() // self.etp_size // self.ep_size // self.pp_size
expert_decoder_rank_generator = mpu.RankGenerator(
Expand Down Expand Up @@ -923,18 +927,28 @@ def _set_mlp_state(
weight_list = []
start_idx = ep_rank * num_local_experts
for i in range(num_local_experts):
gate_proj_weight = hf_state_dict[f'{start_idx + i}.gate_proj.weight'].load()
up_proj_weight = hf_state_dict[f'{start_idx + i}.up_proj.weight'].load()
weight_list.append(torch.stack([gate_proj_weight, up_proj_weight], dim=0))
if f'{start_idx + i}.gate_proj.weight_packed' in hf_state_dict:
weight = []
for key in ['gate_proj', 'up_proj']:
weight_packed = hf_state_dict[f'{start_idx + i}.{key}.weight_packed'].load()
weight_scale = hf_state_dict[f'{start_idx + i}.{key}.weight_scale'].load()
weight_shape = hf_state_dict[f'{start_idx + i}.{key}.weight_shape'].load()
weight.append(
self.packed_quantizer.convert(weight_packed, weight_scale, weight_shape))
Comment thread
Jintao-Huang marked this conversation as resolved.
else:
gate_proj_weight = hf_state_dict[f'{start_idx + i}.gate_proj.weight'].load()
up_proj_weight = hf_state_dict[f'{start_idx + i}.up_proj.weight'].load()
weight = [gate_proj_weight, up_proj_weight]
weight_list.append(torch.stack(weight, dim=0))
gate_up_proj_weight = torch.concat(weight_list, dim=0)
del weight_list
if has_scale_inv:
scale_inv_list = []
for i in range(num_local_experts):
gate_scale_inv = hf_state_dict[f'{start_idx + i}.gate_proj.weight_scale_inv'].load()
up_scale_inv = hf_state_dict[f'{start_idx + i}.up_proj.weight_scale_inv'].load()
scale_inv_list.append(torch.stack([gate_scale_inv, up_scale_inv], dim=0))
gate_up_scale_inv = torch.concat(scale_inv_list, dim=0)
del weight_list
else:
gate_proj_weight = hf_state_dict['gate_proj.weight'].load()
up_proj_weight = hf_state_dict['up_proj.weight'].load()
Expand Down Expand Up @@ -1162,11 +1176,19 @@ def _set_mlp_state(
down_proj_bias = down_proj_bias[ep_rank * num_local_experts:(ep_rank + 1)
* num_local_experts]
else:
down_proj_weight = torch.concat([
hf_state_dict[f'{i + ep_rank * num_local_experts}.down_proj.weight'].load()
for i in range(num_local_experts)
],
dim=0)
weight_list = []
start_idx = ep_rank * num_local_experts
for i in range(num_local_experts):
if f'{start_idx + i}.down_proj.weight_packed' in hf_state_dict:
weight_packed = hf_state_dict[f'{start_idx + i}.down_proj.weight_packed'].load()
weight_scale = hf_state_dict[f'{start_idx + i}.down_proj.weight_scale'].load()
weight_shape = hf_state_dict[f'{start_idx + i}.down_proj.weight_shape'].load()
weight_list.append(
self.packed_quantizer.convert(weight_packed, weight_scale, weight_shape))
Comment thread
Jintao-Huang marked this conversation as resolved.
else:
weight_list.append(hf_state_dict[f'{start_idx + i}.down_proj.weight'].load())
down_proj_weight = torch.concat(weight_list, dim=0)
del weight_list
if has_scale_inv:
down_scale_inv = torch.concat([
hf_state_dict[f'{i + ep_rank * num_local_experts}.down_proj.weight_scale_inv'].load()
Expand Down
26 changes: 21 additions & 5 deletions src/mcore_bridge/model/mm_gpts/kimi_vl.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,6 @@ class KimiK25Vit(HuggingFaceVit):
module_mapping = {'vision_tower': 'vision_tower', 'mm_projector': 'mm_projector'}
_vision_tower = ['vision_tower']
_aligner = ['mm_projector']
test_mm_type = 'text'

def prepare_model(self, hf_config: PretrainedConfig):
output = []
Expand All @@ -82,15 +81,32 @@ def prepare_model(self, hf_config: PretrainedConfig):
assert hf_config.vision_config.mm_projector_type == 'patchmerger'
vit_config = VisionTowerConfig(hf_config.vision_config)
proj_config = ProjectorConfig(hf_config.vision_config)
vit_config.torch_dtype = hf_config.torch_dtype
self.vision_tower = MoonViT3dPretrainedModel._from_config(vit_config)
self.mm_projector = PatchMergerMLP(proj_config).to(self.vision_tower.dtype)
self.mm_projector = PatchMergerMLP(proj_config).to(hf_config.torch_dtype)
self.model_cls = get_class_from_dynamic_module('modeling_kimi_k25.KimiK25ForConditionalGeneration',
hf_config.name_or_path)

def get_inputs_embeds(self, inputs_embeds, **kwargs):
pixel_values = kwargs.pop('pixel_values', None)
if pixel_values is not None:
raise NotImplementedError('Kimi-K25 currently only supports plain text training.')
pixel_values = kwargs.get('pixel_values', None)
input_ids = kwargs['input_ids']

if pixel_values is not None and pixel_values.size(0) > 0:
pixel_values = pixel_values.to(self.vision_tower.dtype)
image_features = self._extract_image_features(pixel_values, kwargs['grid_thws'])
if self.mm_projector:
image_features = self.mm_projector(image_features)
Comment thread
Jintao-Huang marked this conversation as resolved.
image_features = torch.cat(image_features, dim=0)
inputs_embeds = inputs_embeds.to(image_features.dtype)
image_mask = (
input_ids == self.config.hf_config.media_placeholder_token_id).unsqueeze(-1).expand_as(inputs_embeds)
inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_features)
return inputs_embeds

# Delegate to the dynamically loaded HF model class (`self.model_cls`): its
# `_extract_image_features` is looked up on the class and invoked with this
# object passed explicitly as `self`, forwarding all positional/keyword args.
def _extract_image_features(self, *args, **kwargs):
# The delegated call runs inside the `patch_hf_config()` context manager
# (defined outside this view); its result is returned unchanged.
with self.patch_hf_config():
return self.model_cls._extract_image_features(self, *args, **kwargs)


register_model(ModelMeta(
ModelType.kimi_k25,
Expand Down
2 changes: 1 addition & 1 deletion src/mcore_bridge/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# Copyright (c) ModelScope Contributors. All rights reserved.
from .dequantizer import Fp8Dequantizer, MxFp4Dequantizer, fp4_to_fp8
from .dequantizer import Fp8Dequantizer, MxFp4Dequantizer, PackedDequantizer, fp4_to_fp8
from .env import get_dist_setting, get_node_setting, is_dist, is_last_rank, is_local_master, is_master
from .import_utils import _LazyModule, is_flash_attn_3_available
from .logger import get_logger
Expand Down
75 changes: 74 additions & 1 deletion src/mcore_bridge/utils/dequantizer.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# Copyright (c) ModelScope Contributors. All rights reserved.
import torch
from typing import Tuple
from typing import Optional, Sequence, Tuple, Union


class Fp8Dequantizer:
Expand Down Expand Up @@ -77,3 +77,76 @@ def fp4_to_fp8(packed: torch.Tensor) -> torch.Tensor:
unpacked = unpacked.reshape(*packed.shape[:-1], 2 * packed.shape[-1])

return unpacked.to(torch.float8_e4m3fn)


class PackedDequantizer:
    """Dequantize INT4/INT8 weights packed into int32 (compressed-tensors ``pack_quantized`` format).

    Follows the weight-decompression steps of the compressed-tensors ``pack_quantized``
    compressor (unpack from int32, then ``dequantize``), but exposes a simple
    ``convert(...)`` API consistent with the other dequantizers in this module.

    Quantization parameters (``num_bits``, ``symmetric``, ``strategy``) are read once from
    ``quantization_config`` (i.e. ``hf_config.quantization_config``) at init time.
    """

    # Strategies that store the zero-point in a packed int32 layout.
    _PACK_ZP_STRATEGIES = ('group', 'channel')

    def __init__(self, quantization_config: dict):
        """Read the weight-quantization settings from ``quantization_config``.

        Only the first (and usually only) ``config_groups`` entry is consulted. Missing or
        ``None`` sections fall back to the defaults: 4 bits, symmetric, ``'group'`` strategy.

        :param quantization_config: the ``quantization_config`` dict of the HF config.
        """
        # ``or {}`` guards against keys that are present but explicitly null in the JSON
        # config; ``.get(key, {})`` alone would hand back ``None`` and crash on the next lookup.
        config_groups = quantization_config.get('config_groups') or {}
        group_cfg = next(iter(config_groups.values()), None) or {}
        weights_cfg = group_cfg.get('weights') or {}

        self.num_bits: int = weights_cfg.get('num_bits', 4)
        self.symmetric: bool = weights_cfg.get('symmetric', True)
        self.strategy: str = weights_cfg.get('strategy', 'group')

    def convert(
        self,
        packed: torch.Tensor,
        scale: torch.Tensor,
        original_shape: Union[torch.Size, torch.Tensor, Sequence[int]],
        zero_point: Optional[torch.Tensor] = None,
        g_idx: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Unpack ``weight_packed`` and dequantize it back to a float weight tensor.

        :param packed: int32 packed weight tensor (``weight_packed``).
        :param scale: per-channel / per-group scale tensor (``weight_scale``).
        :param original_shape: original (unpacked) weight shape (``weight_shape``); a tensor,
            ``torch.Size`` or any sequence of ints.
        :param zero_point: optional zero-point. For asymmetric GROUP/CHANNEL strategies it is
            required, is expected in packed int32 form and is unpacked here; in every other
            case it is forwarded to ``dequantize`` unchanged.
        :param g_idx: optional group index mapping (``weight_g_idx``), forwarded to ``dequantize``.
        :return: the result of ``dequantize`` applied to the unpacked weight.
        :raises ValueError: if the configuration is asymmetric GROUP/CHANNEL and
            ``zero_point`` is ``None``.
        """
        # Normalize the shape to a plain tuple of Python ints.
        if isinstance(original_shape, torch.Tensor):
            original_shape = original_shape.tolist()
        original_shape = tuple(int(dim) for dim in original_shape)

        # Validate with an explicit raise (an ``assert`` would vanish under ``python -O``),
        # and do it before the lazy import so bad arguments fail fast.
        unpack_zero_point = (not self.symmetric) and self.strategy in self._PACK_ZP_STRATEGIES
        if unpack_zero_point and zero_point is None:
            raise ValueError('Asymmetric quant requires zero-point values')

        # Imported lazily so that compressed-tensors is only needed when this path is used.
        from compressed_tensors.compressors.pack_quantized.helpers import unpack_from_int32
        from compressed_tensors.quantization.lifecycle.forward import dequantize

        # Unpack zero_point before dequantization if it was stored in packed int32 form.
        if unpack_zero_point:
            original_zp_shape = (*original_shape[:-1], scale.shape[-1])
            zero_point = unpack_from_int32(zero_point, self.num_bits, original_zp_shape, packed_dim=0)

        unpacked = unpack_from_int32(packed, self.num_bits, original_shape)
        return dequantize(
            x_q=unpacked,
            scale=scale,
            zero_point=zero_point,
            g_idx=g_idx,
        )
Loading