diff --git a/tensorrt_llm/models/commandr/config.py b/tensorrt_llm/models/commandr/config.py
index a2edca61fb78..7685d6188a32 100644
--- a/tensorrt_llm/models/commandr/config.py
+++ b/tensorrt_llm/models/commandr/config.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-FileCopyrightText: Copyright (c) 2022-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 # SPDX-License-Identifier: Apache-2.0
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -60,7 +60,8 @@ def from_hugging_face(
 
         dtype = infer_dtype(dtype, getattr(hf_config, 'torch_dtype', None))
 
-        if hf_config.tie_word_embeddings:
+        tie_word_embeddings = getattr(hf_config, 'tie_word_embeddings', False)
+        if tie_word_embeddings:
             kwargs['use_parallel_embedding'] = True
             kwargs['embedding_sharding_dim'] = 0
 
@@ -82,6 +83,7 @@ def from_hugging_face(
                    rotary_base=hf_config.rope_theta,
                    attn_bias=hf_config.attention_bias,
                    qk_layernorm=hf_config.use_qk_norm,
+                   tie_word_embeddings=tie_word_embeddings,
                    mapping=mapping,
                    quantization=quant_config,
                    **kwargs)
diff --git a/tensorrt_llm/models/model_weights_loader.py b/tensorrt_llm/models/model_weights_loader.py
index a130883e95c7..bd8ca87bf8a3 100644
--- a/tensorrt_llm/models/model_weights_loader.py
+++ b/tensorrt_llm/models/model_weights_loader.py
@@ -335,7 +335,8 @@ def update_key_mapping(self, model):
             if self.tllm_to_externel_key_dict['layers'] != 'layers':
                 del self.tllm_to_externel_key_dict['layers']
 
-        # Share embedding; only applies to standard structure with lm_head and transformer.vocab_embedding
+        # Share embedding when config.tie_word_embeddings is set; only applies
+        # to the standard lm_head + transformer.vocab_embedding structure.
         if hasattr(self.model, 'lm_head') and hasattr(
                 self.model, 'transformer') and hasattr(self.model.transformer,
                                                        'vocab_embedding'):
@@ -344,12 +345,16 @@ def update_key_mapping(self, model):
             vocab_embed_weights = self.load_tensor(
                 self.translate_to_external_key(
                     'transformer.vocab_embedding.weight'))
-            if lm_head_weights is None and vocab_embed_weights is not None:
+            tie_word_embeddings = getattr(model.config, 'tie_word_embeddings',
+                                          False)
+            if (tie_word_embeddings and lm_head_weights is None
+                    and vocab_embed_weights is not None):
                 self.tllm_to_externel_key_dict[
                     'lm_head'] = self.tllm_to_externel_key_dict[
                         'transformer'] + '.' + self.tllm_to_externel_key_dict[
                             'vocab_embedding']
-            elif lm_head_weights is not None and vocab_embed_weights is None:
+            elif (tie_word_embeddings and lm_head_weights is not None
+                    and vocab_embed_weights is None):
                 self.tllm_to_externel_key_dict[
                     'vocab_embedding'] = self.tllm_to_externel_key_dict[
                         'lm_head']
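
For context, a minimal standalone sketch of the behavior these hunks introduce: the `getattr(..., 'tie_word_embeddings', False)` pattern tolerates configs that lack the attribute, and the key aliasing now runs only when embeddings are actually tied. This is not TensorRT-LLM code; `resolve_shared_embedding_keys` and its arguments are hypothetical stand-ins for `self.tllm_to_externel_key_dict` and the loaded tensors.

from types import SimpleNamespace


def resolve_shared_embedding_keys(config, lm_head_weights,
                                  vocab_embed_weights, key_dict):
    # Default to False when the attribute is absent (the guarded getattr
    # pattern both hunks introduce).
    tie_word_embeddings = getattr(config, 'tie_word_embeddings', False)
    if (tie_word_embeddings and lm_head_weights is None
            and vocab_embed_weights is not None):
        # Checkpoint stores only the embedding: read lm_head from it.
        key_dict['lm_head'] = (key_dict['transformer'] + '.' +
                               key_dict['vocab_embedding'])
    elif (tie_word_embeddings and lm_head_weights is not None
          and vocab_embed_weights is None):
        # Checkpoint stores only lm_head: read the embedding from it.
        key_dict['vocab_embedding'] = key_dict['lm_head']
    return key_dict


# A config without the attribute no longer aliases keys just because one
# of the two tensors is missing from the checkpoint:
keys = resolve_shared_embedding_keys(
    SimpleNamespace(),  # no tie_word_embeddings attribute at all
    None,               # lm_head missing
    object(),           # embedding present
    {
        'transformer': 'model',
        'vocab_embedding': 'embed_tokens',
        'lm_head': 'lm_head'
    })
assert keys['lm_head'] == 'lm_head'  # unchanged without tying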