[#3324][fix] Honor tie_word_embeddings when sharing lm_head and embedding weights #13664
javierdejesusda wants to merge 1 commit into NVIDIA:main
Conversation
[#3324][fix] Honor tie_word_embeddings when sharing lm_head and embedding weights

`ModelWeightsLoader.update_key_mapping` silently tied `lm_head` and `transformer.vocab_embedding` whenever exactly one of the two weights was missing from the checkpoint, regardless of `model.config.tie_word_embeddings`. That hid genuine load failures behind tied weights for models that do not tie embeddings.

Gate both tying branches on `getattr(model.config, 'tie_word_embeddings', False)` so a missing weight surfaces as a regular load error when tying was not requested. This aligns the legacy TensorRT backend with the PyTorch backend, which already checks `tie_word_embeddings` in its weight mapper.

`CohereConfig.from_hugging_face` previously read `hf_config.tie_word_embeddings` only to flip `use_parallel_embedding` and never propagated it to the TRT-LLM config. Without the propagation the new gate would refuse to tie for tied Cohere/CommandR checkpoints, where the silent auto-tying was load-bearing. Pass the flag through to keep tied Cohere checkpoints loading correctly.

Signed-off-by: javierdejesusda <javier.dejesusj9@gmail.com>
📝 Walkthrough
Changes ensure embedding and language model head weights are only tied during model loading when explicitly configured via `tie_word_embeddings`.
Estimated code review effort: 🎯 2 (Simple) | ⏱️ ~10 minutes
🚥 Pre-merge checks: ✅ 4 passed | ❌ 1 failed (1 warning)
🧹 Nitpick comments (1)
tensorrt_llm/models/model_weights_loader.py (1)
338-357: ⚡ Quick win: Add explicit regression tests for both gate outcomes.
Please add focused coverage for:
- `tie_word_embeddings=False` + missing `lm_head.weight` ⇒ load error, and
- `tie_word_embeddings=True` + one side missing ⇒ remap succeeds.
This will lock in the intended behavior and prevent regressions in `update_key_mapping`.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tensorrt_llm/models/model_weights_loader.py` around lines 338-357, add two unit tests targeting the tie-word-embedding logic:
- One test should set `model.config.tie_word_embeddings=False` and ensure that when `lm_head.weight` is missing (`load_tensor` returns `None` for `translate_to_external_key('lm_head.weight')`) the loader raises/returns the expected load error, exercising the branch where `lm_head_weights` is `None` and `tie_word_embeddings` is `False`.
- The second test should set `model.config.tie_word_embeddings=True`, simulate one side missing (e.g., `lm_head.weight` present but `transformer.vocab_embedding.weight` missing, or vice versa), and assert that the remapping in `update_key_mapping` occurs by verifying `self.tllm_to_externel_key_dict` is updated to point `'lm_head'` to `'transformer.vocab_embedding'` (or the inverse), as implemented in the branch that assigns `self.tllm_to_externel_key_dict['lm_head'] = self.tllm_to_externel_key_dict['transformer'] + '.' + self.tllm_to_externel_key_dict['vocab_embedding']`.

Use mocks for `load_tensor`/`translate_to_external_key` to control presence/absence of tensors.
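For concreteness, here is one hedged sketch of the two requested tests. Bypassing `__init__` via `__new__`, the stubbed `load_tensor`/`translate_to_external_key` signatures, the initial contents of `tllm_to_externel_key_dict`, and the `update_key_mapping(model)` call are all assumptions built from the identifiers quoted in the prompt; the first test may instead need to assert a raised error, depending on how the loader reports missing weights.

```python
import unittest
from unittest.mock import MagicMock

from tensorrt_llm.models.model_weights_loader import ModelWeightsLoader


class TestTieWordEmbeddingsGate(unittest.TestCase):

    def _make_loader(self, tie: bool):
        # Bypass __init__ so only the pieces the gate touches are stubbed.
        loader = ModelWeightsLoader.__new__(ModelWeightsLoader)
        loader.tllm_to_externel_key_dict = {
            'transformer': 'model',
            'vocab_embedding': 'embed_tokens',
            'lm_head': 'lm_head',
        }
        # lm_head is absent from the checkpoint; everything else is present.
        loader.translate_to_external_key = lambda key: key
        loader.load_tensor = (
            lambda key, **kw: None if 'lm_head' in key else object())
        model = MagicMock()
        model.config.tie_word_embeddings = tie
        return loader, model

    def test_missing_lm_head_without_tying_is_not_remapped(self):
        loader, model = self._make_loader(tie=False)
        loader.update_key_mapping(model)  # assumed signature
        # Mapping must be left alone so the missing weight fails loudly later.
        self.assertEqual(loader.tllm_to_externel_key_dict['lm_head'],
                         'lm_head')

    def test_missing_lm_head_with_tying_remaps_to_embedding(self):
        loader, model = self._make_loader(tie=True)
        loader.update_key_mapping(model)  # assumed signature
        self.assertEqual(loader.tllm_to_externel_key_dict['lm_head'],
                         'model.embed_tokens')


if __name__ == '__main__':
    unittest.main()
```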
📒 Files selected for processing (2)
- tensorrt_llm/models/commandr/config.py
- tensorrt_llm/models/model_weights_loader.py
Description
`ModelWeightsLoader.update_key_mapping` silently tied `lm_head` and `transformer.vocab_embedding` whenever exactly one of the two weights was missing from the checkpoint, regardless of `model.config.tie_word_embeddings`. That hid genuine load failures behind tied weights for models that do not tie embeddings, which is the issue poedator reported in #3324.

This PR gates both tying branches on `getattr(model.config, 'tie_word_embeddings', False)` so a missing weight surfaces as a regular load error when tying was not requested. The PyTorch backend already does this check in `tensorrt_llm/_torch/models/checkpoints/hf/weight_mapper.py` (`HfWeightMapper.should_skip_module`); this brings the legacy TensorRT backend in line.
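For illustration, a minimal sketch of the gated branch, written as a hypothetical helper `update_lm_head_mapping` standing in for the logic inside `update_key_mapping`; the identifier names (`load_tensor`, `translate_to_external_key`, `tllm_to_externel_key_dict`) follow the review prompt above, and the actual code in `model_weights_loader.py` may be shaped differently:

```python
def update_lm_head_mapping(loader, model):
    """Hypothetical sketch of the gated tying branch (not the real code)."""
    tie_word_embeddings = getattr(model.config, 'tie_word_embeddings', False)
    lm_head_missing = loader.load_tensor(
        loader.translate_to_external_key('lm_head.weight')) is None
    if lm_head_missing and tie_word_embeddings:
        # Tying was requested: remap lm_head onto the shared embedding key.
        loader.tllm_to_externel_key_dict['lm_head'] = (
            loader.tllm_to_externel_key_dict['transformer'] + '.' +
            loader.tllm_to_externel_key_dict['vocab_embedding'])
    # With tie_word_embeddings False, the mapping is left untouched and the
    # missing lm_head.weight surfaces as a regular load error downstream.
    # (The mirror branch for a missing vocab_embedding is gated the same way.)
```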
`CohereConfig.from_hugging_face` previously read `hf_config.tie_word_embeddings` only to flip `use_parallel_embedding` and never propagated it to the TRT-LLM config. Without that propagation the new gate would refuse to tie for tied Cohere/CommandR checkpoints, where the silent auto-tying was load-bearing. The flag is now passed through so tied Cohere checkpoints continue to load.
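A self-contained sketch of the propagation, using an illustrative stand-in class (the real `CohereConfig` has many more fields and a different constructor; only the forwarded flag is the point here):

```python
from dataclasses import dataclass


# Illustrative stand-in for the TRT-LLM config, not the real CohereConfig.
@dataclass
class CohereConfigSketch:
    tie_word_embeddings: bool = False

    @classmethod
    def from_hugging_face(cls, hf_config) -> "CohereConfigSketch":
        # Before the fix, hf_config.tie_word_embeddings was read only to
        # decide use_parallel_embedding and then dropped; the fix also
        # stores it so the loader's gate can honor it at load time.
        return cls(
            tie_word_embeddings=getattr(hf_config, 'tie_word_embeddings',
                                        False))
```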
Fixes #3324
Test Coverage
- `examples/test_commandr.py::test_llm_commandr_v01_single_gpu_summary` (`l0_a30.yml:129-130`) exercises the Cohere conversion path end-to-end with a real `c4ai-command-r-v01` checkpoint, covering the propagation change.
- `tie_word_embeddings=False` and a missing `lm_head.weight` now raises a load error instead of silently tying to `transformer.vocab_embedding.weight`.
- `tie_word_embeddings=True` and `lm_head.weight` absent still resolves to the embedding weight (tying preserved when explicitly requested).
PR Checklist
Summary by CodeRabbit
Bug Fixes
- Embedding and language model head weights are now tied during loading only when `tie_word_embeddings` is explicitly configured; otherwise a missing weight surfaces as a regular load error.