diff --git a/fastdeploy/model_executor/forward_meta.py b/fastdeploy/model_executor/forward_meta.py
index 9e512f32355..c9a26d1d722 100644
--- a/fastdeploy/model_executor/forward_meta.py
+++ b/fastdeploy/model_executor/forward_meta.py
@@ -158,7 +158,9 @@ class ForwardMeta:
     # for prefill
     exist_prefill: bool = False
 
+    # for mla & dsa
     position_ids: Optional[paddle.Tensor] = None
+    mask_encoder_batch: Optional[paddle.Tensor] = None
 
     def clear_caches(self):
         """Safely clean up the caches"""
diff --git a/fastdeploy/model_executor/layers/attention/mla_attention_backend.py b/fastdeploy/model_executor/layers/attention/mla_attention_backend.py
index 61ccc4e16e7..d0cd2f10760 100644
--- a/fastdeploy/model_executor/layers/attention/mla_attention_backend.py
+++ b/fastdeploy/model_executor/layers/attention/mla_attention_backend.py
@@ -272,6 +272,16 @@ def __init__(
         self.num_layers_draft_model: int = int(fd_config.speculative_config.method == SpecMethod.MTP)
 
         self.num_heads: int = num_heads
+        self.heads_need_padding = False
+        if self.num_heads < 64 and fd_config.parallel_config.tensor_parallel_size > 1:
+            self.padding_num_heads = 64 - self.num_heads
+            self.heads_need_padding = True
+            logger.warning(
+                "MLA num_attention_heads is less than 64; padding num_heads to 64 for the attention kernel. "
+                "current num_heads=%d, tp_size=%d",
+                self.num_heads,
+                fd_config.parallel_config.tensor_parallel_size,
+            )
         self.head_dim: int = fd_config.model_config.head_dim
         self.num_layers: int = fd_config.model_config.num_hidden_layers
 
@@ -280,7 +290,9 @@
         self.qk_rope_head_dim: int = fd_config.model_config.qk_rope_head_dim
         self.qk_head_dim: int = fd_config.model_config.qk_nope_head_dim + fd_config.model_config.qk_rope_head_dim
         self.attn_softmax_scale: float = self.qk_head_dim**-0.5
-        if fd_config.model_config.rope_scaling:
+        self.rope_scaling = getattr(fd_config.model_config, "rope_scaling", None)
+        if self.rope_scaling and "factor" in self.rope_scaling:
+            # if fd_config.model_config.rope_scaling:
             mscale_all_dim = fd_config.model_config.rope_scaling.get("mscale_all_dim", False)  # 1.0
             scaling_factor = fd_config.model_config.rope_scaling["factor"]  # 40
             mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim))
@@ -604,6 +616,10 @@ def forward_mixed(
 
         if int(os.getenv("USE_FLASH_MLA", "0")) == 0:
             assert self.num_heads <= 64, "paddle mla attention support failed"
+            if self.heads_need_padding:
+                q = paddle.nn.functional.pad(
+                    q, [0, (self.padding_num_heads) * (self.kv_lora_rank + self.qk_rope_head_dim)], value=0.0
+                ).contiguous()
             # Multi-head latent attention computation
             fmha_out = multi_head_latent_attention(
                 q,
@@ -646,6 +662,8 @@
                 True,  # causal
                 speculate_decoder,
             )
+            if self.heads_need_padding:
+                fmha_out = fmha_out[:, : self.num_heads * self.kv_lora_rank].contiguous()
             return fmha_out
 
         else:
@@ -661,6 +679,12 @@
             tile_scheduler_metadata, num_splits = flash_mla.get_mla_metadata()
             token_num = q.shape[0]
             decoder_q.reshape_([-1, 1, self.num_heads, 576])
+            if self.heads_need_padding:
+                padded_q = paddle.zeros(
+                    [decoder_q.shape[0], decoder_q.shape[1], 64, decoder_q.shape[3]], dtype=decoder_q.dtype
+                )
+                padded_q[:, :, : self.num_heads, :] = decoder_q
+                decoder_q = padded_q
 
             new_cache_shape = latent_cache.shape
             assert new_cache_shape[1] == 1
@@ -679,6 +703,8 @@
                 softmax_scale=self.attn_softmax_scale,
                 causal=True,
             )
+            if self.heads_need_padding:
+                decoder_res = decoder_res[:, :, : self.num_heads, :].contiguous()
 
             final_res = insert_decoder_result_back(
                 decoder_res,
diff --git a/fastdeploy/model_executor/models/deepseek_v3.py b/fastdeploy/model_executor/models/deepseek_v3.py
index 4e75ba1d90b..f6b7c417089 100644
--- a/fastdeploy/model_executor/models/deepseek_v3.py
+++ b/fastdeploy/model_executor/models/deepseek_v3.py
@@ -289,7 +289,7 @@ def __init__(self, fd_config: FDConfig, layer_id: int, prefix: str = "") -> None
             v_head_dim=self.v_head_dim,
         )
         self.rope_scaling = getattr(fd_config.model_config, "rope_scaling", None)
-        if self.rope_scaling:
+        if self.rope_scaling and "factor" in self.rope_scaling:
             mscale_all_dim = self.rope_scaling.get("mscale_all_dim", False)
             scaling_factor = self.rope_scaling["factor"]
             mscale = self.yarn_get_mscale(scaling_factor, float(mscale_all_dim))
@@ -344,8 +344,6 @@ def forward(
         self,
         forward_meta: ForwardMeta,
         hidden_states: paddle.Tensor,
-        position_ids: paddle.Tensor,
-        mask_encoder_batch: paddle.Tensor,
     ):
         """ """
 
@@ -363,7 +361,7 @@
         query_nope, query_pe = query.split([self.qk_nope_head_dim, self.qk_rope_head_dim], axis=-1)
 
         key_pe.reshape_([-1, 1, self.qk_rope_head_dim])
-        query_pe, key_pe = self.rotary_emb(position_ids, query_pe, key_pe)
+        query_pe, key_pe = self.rotary_emb(forward_meta.position_ids, query_pe, key_pe)
 
         compressed_kv = self.kv_a_layernorm(compressed_kv)[0]
 
@@ -400,7 +398,7 @@
             fmha_out_prefill.reshape_([-1, self.num_attention_heads_tp, self.qk_head_dim])
             fmha_out_prefill = fmha_out_prefill[:, :, : self.v_head_dim]
             fmha_out_prefill.reshape_([-1, self.num_attention_heads_tp * self.v_head_dim])
-            fmha_out_prefill = fmha_out_prefill * mask_encoder_batch.cast(fmha_out_prefill.dtype)
+            fmha_out_prefill = fmha_out_prefill * forward_meta.mask_encoder_batch.cast(fmha_out_prefill.dtype)
             fmha_out = fmha_out_prefill
 
         if need_do_decode:  # max_dec_len_this_time
@@ -617,7 +615,7 @@ def __init__(
         # self.buffer = paddle.zeros([2048 * 2048], dtype=paddle.uint8)
 
     def forward(
-        self, forward_meta: ForwardMeta, hidden_states: paddle.Tensor, qr: paddle.Tensor, positions, rotary_emb
+        self, forward_meta: ForwardMeta, hidden_states: paddle.Tensor, qr: paddle.Tensor, rotary_emb
     ) -> paddle.Tensor:
 
         self.indexer_cache = forward_meta.caches[2 * self.layer_id + 1]
@@ -629,7 +627,7 @@ def forward(
         k, _ = self.k_norm(k)
         k_pe, k_nope = paddle.split(k, [self.rope_dim, self.index_head_dim - self.rope_dim], axis=-1)
 
-        q_pe, k_pe = rotary_emb(positions, q_pe, k_pe.unsqueeze(1))
+        q_pe, k_pe = rotary_emb(forward_meta.position_ids, q_pe, k_pe.unsqueeze(1))
         q_pe = q_pe.reshape(-1, self.index_n_heads, self.rope_dim)
         k_pe = k_pe.reshape(-1, 1, self.rope_dim)
 
@@ -853,7 +851,7 @@ def __init__(self, fd_config: FDConfig, layer_id: int, prefix: str = "") -> None
             v_head_dim=self.v_head_dim,
         )
         self.rope_scaling = getattr(fd_config.model_config, "rope_scaling", None)
-        if self.rope_scaling:
+        if self.rope_scaling and "factor" in self.rope_scaling:
             mscale_all_dim = self.rope_scaling.get("mscale_all_dim", False)
             scaling_factor = self.rope_scaling["factor"]
             mscale = self.yarn_get_mscale(scaling_factor, float(mscale_all_dim))
@@ -926,8 +924,6 @@ def forward(
         self,
         forward_meta: ForwardMeta,
         hidden_states: paddle.Tensor,
-        position_ids: paddle.Tensor,
-        mask_encoder_batch: paddle.Tensor,
     ):
         """ """
         qkv_a_out = self.qkv_a_proj_with_mqa(hidden_states)
@@ -940,15 +936,13 @@
         query = self.q_a_layernorm(query)[0]
 
         # DSA indexer
-        indexer_top_k = self.indexer(
-            forward_meta, hidden_states, query, position_ids, rotary_emb=self.indexer_rotary_emb
-        )
+        indexer_top_k = self.indexer(forward_meta, hidden_states, query, rotary_emb=self.indexer_rotary_emb)
 
         query = self.q_b_proj(query)
         query.reshape_([-1, self.num_attention_heads_tp, self.qk_head_dim])
         query_nope, query_pe = query.split([self.qk_nope_head_dim, self.qk_rope_head_dim], axis=-1)
 
-        query_pe, key_pe = self.rotary_emb(position_ids, query_pe, key_pe)
+        query_pe, key_pe = self.rotary_emb(forward_meta.position_ids, query_pe, key_pe)
 
         q_nope_out = self.kv_b_proj_bmm(query_nope.transpose([1, 0, 2]).contiguous(), proj_type="k")
         q_input = paddle.concat([q_nope_out.transpose([1, 0, 2]).contiguous(), query_pe], axis=-1)
@@ -1044,8 +1038,6 @@ def forward(
         forward_meta: ForwardMeta,
         hidden_states: paddle.Tensor,
         residual: paddle.Tensor,
-        position_ids: paddle.Tensor,
-        mask_encoder_batch: paddle.Tensor,
     ):
         """ """
         if hidden_states.shape[0] > 0:
@@ -1053,7 +1045,7 @@
             hidden_states, residual = self.input_layernorm(
                 hidden_states, residual_input=residual, forward_meta=forward_meta
             )
-            hidden_states = self.self_attn(forward_meta, hidden_states, position_ids, mask_encoder_batch)
+            hidden_states = self.self_attn(forward_meta, hidden_states)
 
             hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
         else:
@@ -1108,8 +1100,6 @@ def forward(
         self,
         ids_remove_padding: paddle.Tensor,
         forward_meta: ForwardMeta,
-        position_ids: paddle.Tensor,
-        mask_encoder_batch: paddle.Tensor,
     ):
         """ """
         hidden_states = self.embed_tokens(ids_remove_padding=ids_remove_padding, forward_meta=forward_meta)
@@ -1120,8 +1110,6 @@ def forward(
                 forward_meta,
                 hidden_states,
                 residual,
-                position_ids,
-                mask_encoder_batch,
             )
 
         out = self.norm(hidden_states, residual, forward_meta=forward_meta)[0]
@@ -1297,12 +1285,10 @@ def forward(
         forward_meta: ForwardMeta,
     ):
         ids_remove_padding = inputs["ids_remove_padding"]
-        forward_meta.position_ids, mask_encoder_batch = self.pre_process(forward_meta)
+        forward_meta.position_ids, forward_meta.mask_encoder_batch = self.pre_process(forward_meta)
         hidden_states = self.model(
             ids_remove_padding=ids_remove_padding,
             forward_meta=forward_meta,
-            position_ids=forward_meta.position_ids,
-            mask_encoder_batch=mask_encoder_batch,
         )
 
         return hidden_states
@@ -1353,3 +1339,29 @@ class DeepSeekV32PretrainedModel(DeepSeekV3PretrainedModel):
     @classmethod
     def arch_name(self):
         return "DeepseekV32ForCausalLM"
+
+
+@ModelRegistry.register_model_class(
+    architecture="Glm4MoeLiteForCausalLM",
+    module_name="deepseek_v3",
+    category=ModelCategory.TEXT_GENERATION,
+    primary_use=ModelCategory.TEXT_GENERATION,
+)
+class Glm4MoeLiteForCausalLM(DeepseekV3ForCausalLM):
+    """
+    Glm4MoeLiteForCausalLM
+    """
+
+    @classmethod
+    def name(cls):
+        return "Glm4MoeLiteForCausalLM"
+
+
+class Glm4MoeLitePretrainedModel(DeepSeekV3PretrainedModel):
+    """
+    Glm4MoeLite
+    """
+
+    @classmethod
+    def arch_name(self):
+        return "Glm4MoeLiteForCausalLM"