Merged
2 changes: 2 additions & 0 deletions fastdeploy/model_executor/forward_meta.py
@@ -158,7 +158,9 @@ class ForwardMeta:
# for prefill
exist_prefill: bool = False

# for MLA & DSA
position_ids: Optional[paddle.Tensor] = None
mask_encoder_batch: Optional[paddle.Tensor] = None

def clear_caches(self):
"""Safely clean up the caches"""
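For context, the two new fields carry per-step tensors that used to be threaded through every `forward()` call; the model-runner change further down in this diff fills them once via `pre_process`, and the attention layers then read them from `forward_meta`. A minimal sketch of the pattern (the dataclass below is a hypothetical stand-in, not the real ForwardMeta, and the mask shape/dtype are assumptions):

```python
from dataclasses import dataclass
from typing import Optional

import paddle


@dataclass
class ForwardMetaSketch:  # hypothetical stand-in for ForwardMeta
    position_ids: Optional[paddle.Tensor] = None
    mask_encoder_batch: Optional[paddle.Tensor] = None


meta = ForwardMetaSketch()
# pre_process() fills the fields once per step ...
meta.position_ids = paddle.arange(8, dtype="int64")
meta.mask_encoder_batch = paddle.ones([8, 1], dtype="int32")  # assumed shape/dtype
# ... and layers read meta.position_ids / meta.mask_encoder_batch instead of
# receiving them as forward() arguments (see the deepseek_v3.py hunks below).
```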
@@ -272,6 +272,16 @@ def __init__(
self.num_layers_draft_model: int = int(fd_config.speculative_config.method == SpecMethod.MTP)

self.num_heads: int = num_heads
self.heads_need_padding = False
if self.num_heads < 64 and fd_config.parallel_config.tensor_parallel_size > 1:
self.padding_num_heads = 64 - self.num_heads
self.heads_need_padding = True
logger.warning(
"MLA num_attention_heads is less than 64; padding the attention heads to 64. "
"current num_heads=%d, tp_size=%d",
self.num_heads,
fd_config.parallel_config.tensor_parallel_size,
)
self.head_dim: int = fd_config.model_config.head_dim
self.num_layers: int = fd_config.model_config.num_hidden_layers
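The padding rule only fires when tensor parallelism leaves this rank with fewer than 64 heads; the `forward_mixed` hunks below then pad the query up to 64 heads before the MLA kernels and slice the result back. A small worked example of the bookkeeping (a 32-head model split over 2 ranks, i.e. 16 heads on this rank, is hypothetical):

```python
# Hypothetical per-rank head count: a 32-head model on tp_size=2 -> 16 heads here.
num_heads = 16
tp_size = 2

heads_need_padding = num_heads < 64 and tp_size > 1            # True
padding_num_heads = 64 - num_heads if heads_need_padding else 0
print(heads_need_padding, padding_num_heads)                    # True 48
```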

@@ -280,7 +290,9 @@ def __init__(
self.qk_rope_head_dim: int = fd_config.model_config.qk_rope_head_dim
self.qk_head_dim: int = fd_config.model_config.qk_nope_head_dim + fd_config.model_config.qk_rope_head_dim
self.attn_softmax_scale: float = self.qk_head_dim**-0.5
if fd_config.model_config.rope_scaling:
self.rope_scaling = getattr(fd_config.model_config, "rope_scaling", None)
if self.rope_scaling and "factor" in self.rope_scaling:
# if fd_config.model_config.rope_scaling:
mscale_all_dim = fd_config.model_config.rope_scaling.get("mscale_all_dim", False) # 1.0
scaling_factor = fd_config.model_config.rope_scaling["factor"] # 40
mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim))
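The extra `"factor" in self.rope_scaling` check means a rope_scaling dict that carries no YaRN `factor` (rather than just a missing/None dict) now also skips this branch. For reference, `yarn_get_mscale` typically implements the standard YaRN correction below (shown as the common DeepSeek-style formulation, not copied from this repo), and the values noted in the comments above work out as follows:

```python
import math


def yarn_get_mscale(scale: float = 1.0, mscale: float = 1.0) -> float:
    # Standard YaRN attention-scale correction; 1.0 when no scaling is applied.
    if scale <= 1:
        return 1.0
    return 0.1 * mscale * math.log(scale) + 1.0


mscale = yarn_get_mscale(40.0, 1.0)       # factor=40, mscale_all_dim=1.0 -> ~1.369
# The softmax scale is then typically rescaled by mscale**2:
# attn_softmax_scale *= mscale * mscale   # ~1.87x on top of qk_head_dim**-0.5
```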
@@ -604,6 +616,10 @@ def forward_mixed(

if int(os.getenv("USE_FLASH_MLA", "0")) == 0:
assert self.num_heads <= 64, "paddle MLA attention supports at most 64 heads"
if self.heads_need_padding:
q = paddle.nn.functional.pad(
q, [0, (self.padding_num_heads) * (self.kv_lora_rank + self.qk_rope_head_dim)], value=0.0
).contiguous()
# Multi-head latent attention (MLA) computation
fmha_out = multi_head_latent_attention(
q,
Expand Down Expand Up @@ -646,6 +662,8 @@ def forward_mixed(
True, # causal
speculate_decoder,
)
if self.heads_need_padding:
fmha_out = fmha_out[:, : self.num_heads * self.kv_lora_rank].contiguous()

return fmha_out
else:
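The pad/slice pair above is a plain round trip on the flattened head axis: the query grows from `num_heads` to 64 heads' worth of columns before `multi_head_latent_attention`, and the padded heads are dropped from the output. A self-contained shape sketch (DeepSeek-V3-style dims assumed; `paddle.concat` with zeros is used here, equivalent in effect to the `F.pad` call in the hunk):

```python
import paddle

# Assumed DeepSeek-V3 MLA dims: kv_lora_rank=512, qk_rope_head_dim=64, 16 heads on this rank.
token_num, num_heads = 4, 16
kv_lora_rank, qk_rope_head_dim = 512, 64
padding_num_heads = 64 - num_heads

q = paddle.randn([token_num, num_heads * (kv_lora_rank + qk_rope_head_dim)])

# Pad the flattened head axis up to 64 heads before the MLA kernel ...
pad_cols = padding_num_heads * (kv_lora_rank + qk_rope_head_dim)
q_padded = paddle.concat([q, paddle.zeros([token_num, pad_cols], dtype=q.dtype)], axis=-1)
print(q_padded.shape)  # [4, 36864] == [token_num, 64 * 576]

# ... and keep only the real heads of the kernel output afterwards.
fmha_out = paddle.randn([token_num, 64 * kv_lora_rank])  # stand-in for the kernel output
fmha_out = fmha_out[:, : num_heads * kv_lora_rank]
print(fmha_out.shape)  # [4, 8192] == [token_num, num_heads * kv_lora_rank]
```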
@@ -661,6 +679,12 @@ def forward_mixed(
tile_scheduler_metadata, num_splits = flash_mla.get_mla_metadata()
token_num = q.shape[0]
decoder_q.reshape_([-1, 1, self.num_heads, 576])
if self.heads_need_padding:
padded_q = paddle.zeros(
[decoder_q.shape[0], decoder_q.shape[1], 64, decoder_q.shape[3]], dtype=decoder_q.dtype
)
padded_q[:, :, : self.num_heads, :] = decoder_q
decoder_q = padded_q

new_cache_shape = latent_cache.shape
assert new_cache_shape[1] == 1
@@ -679,6 +703,8 @@ def forward_mixed(
softmax_scale=self.attn_softmax_scale,
causal=True,
)
if self.heads_need_padding:
decoder_res = decoder_res[:, :, : self.num_heads, :].contiguous()

final_res = insert_decoder_result_back(
decoder_res,
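The decode path plays the same trick with explicit head dimensions: the 4-D query is scattered into a zero-filled 64-head buffer before `flash_mla`, and only the real heads are sliced back out of the result. A brief sketch (shapes assumed to match the 576-wide absorbed query above; the kernel output is a stand-in):

```python
import paddle

bsz, num_heads, kv_lora_rank = 2, 16, 512   # hypothetical decode batch

decoder_q = paddle.randn([bsz, 1, num_heads, 576])
padded_q = paddle.zeros([bsz, 1, 64, 576], dtype=decoder_q.dtype)
padded_q[:, :, :num_heads, :] = decoder_q    # real heads first, zeros for the rest

decoder_res = paddle.randn([bsz, 1, 64, kv_lora_rank])  # stand-in for the flash_mla output
decoder_res = decoder_res[:, :, :num_heads, :]           # drop the padded heads again
print(decoder_res.shape)  # [2, 1, 16, 512]
```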
60 changes: 36 additions & 24 deletions fastdeploy/model_executor/models/deepseek_v3.py
@@ -289,7 +289,7 @@ def __init__(self, fd_config: FDConfig, layer_id: int, prefix: str = "") -> None
v_head_dim=self.v_head_dim,
)
self.rope_scaling = getattr(fd_config.model_config, "rope_scaling", None)
if self.rope_scaling:
if self.rope_scaling and "factor" in self.rope_scaling:
mscale_all_dim = self.rope_scaling.get("mscale_all_dim", False)
scaling_factor = self.rope_scaling["factor"]
mscale = self.yarn_get_mscale(scaling_factor, float(mscale_all_dim))
@@ -344,8 +344,6 @@ def forward(
self,
forward_meta: ForwardMeta,
hidden_states: paddle.Tensor,
position_ids: paddle.Tensor,
mask_encoder_batch: paddle.Tensor,
):
""" """

@@ -363,7 +361,7 @@ def forward(
query_nope, query_pe = query.split([self.qk_nope_head_dim, self.qk_rope_head_dim], axis=-1)

key_pe.reshape_([-1, 1, self.qk_rope_head_dim])
query_pe, key_pe = self.rotary_emb(position_ids, query_pe, key_pe)
query_pe, key_pe = self.rotary_emb(forward_meta.position_ids, query_pe, key_pe)

compressed_kv = self.kv_a_layernorm(compressed_kv)[0]

@@ -400,7 +398,7 @@ def forward(
fmha_out_prefill.reshape_([-1, self.num_attention_heads_tp, self.qk_head_dim])
fmha_out_prefill = fmha_out_prefill[:, :, : self.v_head_dim]
fmha_out_prefill.reshape_([-1, self.num_attention_heads_tp * self.v_head_dim])
fmha_out_prefill = fmha_out_prefill * mask_encoder_batch.cast(fmha_out_prefill.dtype)
fmha_out_prefill = fmha_out_prefill * forward_meta.mask_encoder_batch.cast(fmha_out_prefill.dtype)
fmha_out = fmha_out_prefill

if need_do_decode: # max_dec_len_this_time
@@ -617,7 +615,7 @@ def __init__(
# self.buffer = paddle.zeros([2048 * 2048], dtype=paddle.uint8)

def forward(
self, forward_meta: ForwardMeta, hidden_states: paddle.Tensor, qr: paddle.Tensor, positions, rotary_emb
self, forward_meta: ForwardMeta, hidden_states: paddle.Tensor, qr: paddle.Tensor, rotary_emb
) -> paddle.Tensor:
self.indexer_cache = forward_meta.caches[2 * self.layer_id + 1]

@@ -629,7 +627,7 @@ def forward(
k, _ = self.k_norm(k)
k_pe, k_nope = paddle.split(k, [self.rope_dim, self.index_head_dim - self.rope_dim], axis=-1)

q_pe, k_pe = rotary_emb(positions, q_pe, k_pe.unsqueeze(1))
q_pe, k_pe = rotary_emb(forward_meta.position_ids, q_pe, k_pe.unsqueeze(1))
q_pe = q_pe.reshape(-1, self.index_n_heads, self.rope_dim)
k_pe = k_pe.reshape(-1, 1, self.rope_dim)

@@ -853,7 +851,7 @@ def __init__(self, fd_config: FDConfig, layer_id: int, prefix: str = "") -> None
v_head_dim=self.v_head_dim,
)
self.rope_scaling = getattr(fd_config.model_config, "rope_scaling", None)
if self.rope_scaling:
if self.rope_scaling and "factor" in self.rope_scaling:
mscale_all_dim = self.rope_scaling.get("mscale_all_dim", False)
scaling_factor = self.rope_scaling["factor"]
mscale = self.yarn_get_mscale(scaling_factor, float(mscale_all_dim))
@@ -926,8 +924,6 @@ def forward(
self,
forward_meta: ForwardMeta,
hidden_states: paddle.Tensor,
position_ids: paddle.Tensor,
mask_encoder_batch: paddle.Tensor,
):
""" """
qkv_a_out = self.qkv_a_proj_with_mqa(hidden_states)
Expand All @@ -940,15 +936,13 @@ def forward(
query = self.q_a_layernorm(query)[0]

# DSA indexer
indexer_top_k = self.indexer(
forward_meta, hidden_states, query, position_ids, rotary_emb=self.indexer_rotary_emb
)
indexer_top_k = self.indexer(forward_meta, hidden_states, query, rotary_emb=self.indexer_rotary_emb)

query = self.q_b_proj(query)
query.reshape_([-1, self.num_attention_heads_tp, self.qk_head_dim])
query_nope, query_pe = query.split([self.qk_nope_head_dim, self.qk_rope_head_dim], axis=-1)

query_pe, key_pe = self.rotary_emb(position_ids, query_pe, key_pe)
query_pe, key_pe = self.rotary_emb(forward_meta.position_ids, query_pe, key_pe)
q_nope_out = self.kv_b_proj_bmm(query_nope.transpose([1, 0, 2]).contiguous(), proj_type="k")
q_input = paddle.concat([q_nope_out.transpose([1, 0, 2]).contiguous(), query_pe], axis=-1)
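For orientation, this is the "weight-absorbed" MLA query: the no-RoPE part is projected into the `kv_lora_rank` latent space by `kv_b_proj_bmm` and concatenated with the RoPE part, which is where the 576-wide query used by the attention backend hunks above comes from. A shape-only sketch with the usual DeepSeek-V3 dims (512 + 64) assumed; the random tensors merely stand in for the real projections:

```python
import paddle

tokens, heads = 4, 16                        # hypothetical token count / per-rank heads
kv_lora_rank, qk_rope_head_dim = 512, 64

q_nope_out = paddle.randn([heads, tokens, kv_lora_rank])    # stand-in for kv_b_proj_bmm(...)
query_pe = paddle.randn([tokens, heads, qk_rope_head_dim])  # stand-in for the RoPE'd part

q_input = paddle.concat([q_nope_out.transpose([1, 0, 2]), query_pe], axis=-1)
print(q_input.shape)  # [4, 16, 576] -- matches the 576 used in the MLA backend above
```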

@@ -1044,16 +1038,14 @@ def forward(
forward_meta: ForwardMeta,
hidden_states: paddle.Tensor,
residual: paddle.Tensor,
position_ids: paddle.Tensor,
mask_encoder_batch: paddle.Tensor,
):
""" """
if hidden_states.shape[0] > 0:
hidden_states, residual = self.input_layernorm(
hidden_states, residual_input=residual, forward_meta=forward_meta
)

hidden_states = self.self_attn(forward_meta, hidden_states, position_ids, mask_encoder_batch)
hidden_states = self.self_attn(forward_meta, hidden_states)

hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
else:
@@ -1108,8 +1100,6 @@ def forward(
self,
ids_remove_padding: paddle.Tensor,
forward_meta: ForwardMeta,
position_ids: paddle.Tensor,
mask_encoder_batch: paddle.Tensor,
):
""" """
hidden_states = self.embed_tokens(ids_remove_padding=ids_remove_padding, forward_meta=forward_meta)
Expand All @@ -1120,8 +1110,6 @@ def forward(
forward_meta,
hidden_states,
residual,
position_ids,
mask_encoder_batch,
)
out = self.norm(hidden_states, residual, forward_meta=forward_meta)[0]

@@ -1297,12 +1285,10 @@ def forward(
forward_meta: ForwardMeta,
):
ids_remove_padding = inputs["ids_remove_padding"]
forward_meta.position_ids, mask_encoder_batch = self.pre_process(forward_meta)
forward_meta.position_ids, forward_meta.mask_encoder_batch = self.pre_process(forward_meta)
hidden_states = self.model(
ids_remove_padding=ids_remove_padding,
forward_meta=forward_meta,
position_ids=forward_meta.position_ids,
mask_encoder_batch=mask_encoder_batch,
)
return hidden_states

@@ -1353,3 +1339,29 @@ class DeepSeekV32PretrainedModel(DeepSeekV3PretrainedModel):
@classmethod
def arch_name(self):
return "DeepseekV32ForCausalLM"


@ModelRegistry.register_model_class(
architecture="Glm4MoeLiteForCausalLM",
module_name="deepseek_v3",
category=ModelCategory.TEXT_GENERATION,
primary_use=ModelCategory.TEXT_GENERATION,
)
class Glm4MoeLiteForCausalLM(DeepseekV3ForCausalLM):
"""
Glm4MoeLiteForCausalLM
"""

@classmethod
def name(cls):
return "Glm4MoeLiteForCausalLM"


class Glm4MoeLitePretrainedModel(DeepSeekV3PretrainedModel):
"""
Glm4MoeLite
"""

@classmethod
def arch_name(self):
return "Glm4MoeLiteForCausalLM"