Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions tensorrt_llm/models/commandr/config.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-FileCopyrightText: Copyright (c) 2022-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
Expand Down Expand Up @@ -60,7 +60,8 @@ def from_hugging_face(

dtype = infer_dtype(dtype, getattr(hf_config, 'torch_dtype', None))

if hf_config.tie_word_embeddings:
tie_word_embeddings = getattr(hf_config, 'tie_word_embeddings', False)
if tie_word_embeddings:
kwargs['use_parallel_embedding'] = True
kwargs['embedding_sharding_dim'] = 0

Expand All @@ -82,6 +83,7 @@ def from_hugging_face(
rotary_base=hf_config.rope_theta,
attn_bias=hf_config.attention_bias,
qk_layernorm=hf_config.use_qk_norm,
tie_word_embeddings=tie_word_embeddings,
mapping=mapping,
quantization=quant_config,
**kwargs)
11 changes: 8 additions & 3 deletions tensorrt_llm/models/model_weights_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -335,7 +335,8 @@ def update_key_mapping(self, model):
if self.tllm_to_externel_key_dict['layers'] != 'layers':
del self.tllm_to_externel_key_dict['layers']

# Share embedding; only applies to standard structure with lm_head and transformer.vocab_embedding
# Share embedding when config.tie_word_embeddings is set; only applies
# to the standard lm_head + transformer.vocab_embedding structure.
if hasattr(self.model, 'lm_head') and hasattr(
self.model, 'transformer') and hasattr(self.model.transformer,
'vocab_embedding'):
Expand All @@ -344,12 +345,16 @@ def update_key_mapping(self, model):
vocab_embed_weights = self.load_tensor(
self.translate_to_external_key(
'transformer.vocab_embedding.weight'))
if lm_head_weights is None and vocab_embed_weights is not None:
tie_word_embeddings = getattr(model.config, 'tie_word_embeddings',
False)
if (tie_word_embeddings and lm_head_weights is None
and vocab_embed_weights is not None):
self.tllm_to_externel_key_dict[
'lm_head'] = self.tllm_to_externel_key_dict[
'transformer'] + '.' + self.tllm_to_externel_key_dict[
'vocab_embedding']
elif lm_head_weights is not None and vocab_embed_weights is None:
elif (tie_word_embeddings and lm_head_weights is not None
and vocab_embed_weights is None):
self.tllm_to_externel_key_dict[
'vocab_embedding'] = self.tllm_to_externel_key_dict[
'lm_head']
Expand Down