
[#3324][fix] Honor tie_word_embeddings when sharing lm_head and embedding weights #13664

Open

javierdejesusda wants to merge 1 commit into NVIDIA:main from javierdejesusda:fix/3324-tie-word-embeddings

Conversation


javierdejesusda commented Apr 30, 2026

Description

ModelWeightsLoader.update_key_mapping silently tied lm_head and
transformer.vocab_embedding whenever exactly one of the two weights was
missing from the checkpoint, regardless of model.config.tie_word_embeddings.
That hid genuine load failures behind tied weights for models that do not tie
embeddings, which is the issue poedator reported in #3324.

This PR gates both tying branches on
getattr(model.config, 'tie_word_embeddings', False) so a missing weight
surfaces as a regular load error when tying was not requested. The PyTorch
backend already does this check in
tensorrt_llm/_torch/models/checkpoints/hf/weight_mapper.py
(HfWeightMapper.should_skip_module); this brings the legacy TensorRT backend
in line.
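For context, a minimal sketch of the gated behavior (a toy illustration under assumed key names, not the actual ModelWeightsLoader code):

```python
# Toy illustration of the gate, not the real ModelWeightsLoader.update_key_mapping.
def resolve_lm_head_key(checkpoint_keys, config):
    """Pick the checkpoint key for lm_head, tying to the embedding only if requested."""
    if 'lm_head.weight' in checkpoint_keys:
        return 'lm_head.weight'
    if getattr(config, 'tie_word_embeddings', False) \
            and 'transformer.vocab_embedding.weight' in checkpoint_keys:
        # Tying explicitly requested: share the embedding weight with lm_head.
        return 'transformer.vocab_embedding.weight'
    # Tying not requested: let the missing weight surface as a regular load error.
    raise KeyError('lm_head.weight is missing from the checkpoint')
```

The embedding-side branch is gated the same way, so a missing vocab_embedding weight is only aliased to lm_head.weight when the flag is set.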

CohereConfig.from_hugging_face previously read hf_config.tie_word_embeddings
only to flip use_parallel_embedding and never propagated it to the TRT-LLM
config. Without that propagation the new gate would refuse to tie for tied
Cohere/CommandR checkpoints, where the silent auto-tying was load-bearing. The
flag is now passed through so tied Cohere checkpoints continue to load.
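A toy sketch of the propagation (illustrative only; the real CohereConfig constructor takes many more arguments, and the existing use_parallel_embedding handling is unchanged):

```python
from types import SimpleNamespace


# Toy stand-in for CohereConfig showing only the tie_word_embeddings plumbing.
class ToyCohereConfig:
    def __init__(self, tie_word_embeddings: bool = False, **kwargs):
        self.tie_word_embeddings = tie_word_embeddings

    @classmethod
    def from_hugging_face(cls, hf_config, **kwargs):
        # Safe default mirrors the loader-side getattr(..., False) gate.
        tie_word_embeddings = getattr(hf_config, 'tie_word_embeddings', False)
        return cls(tie_word_embeddings=tie_word_embeddings, **kwargs)


hf_config = SimpleNamespace(tie_word_embeddings=True)  # tied checkpoint, e.g. CommandR
assert ToyCohereConfig.from_hugging_face(hf_config).tie_word_embeddings is True
```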

Fixes #3324

Test Coverage

  • Existing integration test
    examples/test_commandr.py::test_llm_commandr_v01_single_gpu_summary
    (l0_a30.yml:129-130) exercises the Cohere conversion path end-to-end with
    a real c4ai-command-r-v01 checkpoint, covering the propagation change.
  • Manually verified: a HF checkpoint with tie_word_embeddings=False and a
    missing lm_head.weight now raises a load error instead of silently tying
    to transformer.vocab_embedding.weight.
  • Manually verified: a HF checkpoint with tie_word_embeddings=True and
    lm_head.weight absent still resolves to the embedding weight (tying
    preserved when explicitly requested).

PR Checklist

  • PR description clearly explains what and why
  • PR follows TRT-LLM coding guidelines
  • No new dependencies introduced
  • No ownership or architecture changes
  • Documentation changes not required for this fix

Summary by CodeRabbit

Bug Fixes

  • Improved handling of the tie_word_embeddings setting: it now has a safe default and is propagated consistently through config construction and weight loading, so embedding weights are shared between the vocabulary embedding and the language-model output head only when explicitly configured.

… embedding weights

ModelWeightsLoader.update_key_mapping silently tied lm_head and
transformer.vocab_embedding whenever exactly one of the two weights was
missing from the checkpoint, regardless of model.config.tie_word_embeddings.
That hides genuine load failures behind tied weights for models that do not
tie embeddings.

Gate both tying branches on getattr(model.config, 'tie_word_embeddings', False)
so a missing weight surfaces as a regular load error when tying was not
requested. This aligns the legacy TensorRT backend with the PyTorch backend,
which already checks tie_word_embeddings in its weight mapper.

CohereConfig.from_hugging_face previously read hf_config.tie_word_embeddings
only to flip use_parallel_embedding and never propagated it to the TRT-LLM
config. Without the propagation the new gate would refuse to tie for tied
Cohere/CommandR checkpoints, where the silent auto-tying was load-bearing.
Pass the flag through to keep tied Cohere checkpoints loading correctly.

Signed-off-by: javierdejesusda <javier.dejesusj9@gmail.com>
coderabbitai Bot commented Apr 30, 2026

Walkthrough

The changes ensure that embedding and language-model-head weights are only tied during model loading when explicitly configured via tie_word_embeddings. Updates include proper config attribute handling with a safe default in the Cohere model config and conditional weight key-sharing logic based on the config value.

Changes

  • Config handling for weight tying (tensorrt_llm/models/commandr/config.py): Added safe retrieval of the tie_word_embeddings attribute with getattr defaulting to False, with the value explicitly propagated to the constructed CohereConfig.
  • Weight loader gating (tensorrt_llm/models/model_weights_loader.py): Modified the embedding weight key-sharing logic to check model.config.tie_word_embeddings before remapping keys, preventing silent unintended weight tying when lm_head.weight is absent from the checkpoint.

Estimated code review effort

🎯 2 (Simple) | ⏱️ ~10 minutes

🚥 Pre-merge checks | ✅ 4 | ❌ 1

❌ Failed checks (1 warning)

  • Docstring Coverage ⚠️ Warning: Docstring coverage is 0.00%, which is insufficient; the required threshold is 80.00%. Resolution: write docstrings for the functions missing them to satisfy the coverage threshold.
✅ Passed checks (4 passed)

  • Title check: The title clearly and concisely describes the main fix: honoring the tie_word_embeddings configuration when sharing lm_head and embedding weights.
  • Description check: The description comprehensively explains the issue, the solution, and the test coverage, and all checklist items are properly reviewed.
  • Linked Issues check: The PR fully addresses issue #3324 by gating weight-tying logic on model.config.tie_word_embeddings and propagating the flag through CohereConfig, exactly as suggested in the linked issue.
  • Out of Scope Changes check: All changes are directly scoped to fixing #3324: gating tying logic in model_weights_loader.py and propagating tie_word_embeddings in CohereConfig.from_hugging_face.



coderabbitai Bot left a comment


🧹 Nitpick comments (1)
tensorrt_llm/models/model_weights_loader.py (1)

338-357: ⚡ Quick win

Add explicit regression tests for both gate outcomes.

Please add focused coverage for:

  1. tie_word_embeddings=False + missing lm_head.weight ⇒ load error, and
  2. tie_word_embeddings=True + one side missing ⇒ remap succeeds (see the sketch after this list).

This will lock in the intended behavior and prevent regressions in update_key_mapping.
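A rough pytest-style sketch of those two cases (make_loader is a hypothetical helper that would build a ModelWeightsLoader over an in-memory checkpoint, and the attribute names follow the prompt below, so a real test would need to adapt them):

```python
# Rough sketch only: make_loader(...) is a hypothetical test helper, and the
# exact failure mode for the untied case depends on how the loader reports
# missing weights downstream of update_key_mapping.
def _tied_key(loader):
    return (loader.tllm_to_externel_key_dict['transformer'] + '.' +
            loader.tllm_to_externel_key_dict['vocab_embedding'])


def test_missing_lm_head_without_tying_is_not_remapped():
    loader, model = make_loader(checkpoint_keys=['transformer.vocab_embedding.weight'],
                                tie_word_embeddings=False)
    loader.update_key_mapping(model)
    # lm_head must not be silently pointed at the embedding; the missing
    # weight should surface later as a regular load error.
    assert loader.tllm_to_externel_key_dict['lm_head'] != _tied_key(loader)


def test_missing_lm_head_with_tying_remaps_to_embedding():
    loader, model = make_loader(checkpoint_keys=['transformer.vocab_embedding.weight'],
                                tie_word_embeddings=True)
    loader.update_key_mapping(model)
    assert loader.tllm_to_externel_key_dict['lm_head'] == _tied_key(loader)
```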
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tensorrt_llm/models/model_weights_loader.py` around lines 338 - 357, Add two
unit tests targeting the behavior in model_weights_loader.py around the
tie-word-embedding logic: one test should set
model.config.tie_word_embeddings=False and ensure that when lm_head.weight is
missing (load_tensor returns None for
translate_to_external_key('lm_head.weight')) the loader raises/returns the
expected load error (exercise the branch where lm_head_weights is None and
tie_word_embeddings is False); the second test should set
model.config.tie_word_embeddings=True and simulate one side missing (e.g.,
lm_head.weight present but transformer.vocab_embedding.weight missing, or vice
versa) and assert that update_key_mapping/remapping occurs by verifying
self.tllm_to_externel_key_dict is updated to point 'lm_head' to
'transformer.vocab_embedding' (or the inverse) as implemented in the branch that
assigns self.tllm_to_externel_key_dict['lm_head'] =
self.tllm_to_externel_key_dict['transformer'] + '.' +
self.tllm_to_externel_key_dict['vocab_embedding']; use mocks for
load_tensor/translate_to_external_key to control presence/absence of tensors.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Enterprise

Run ID: 98383d22-3c73-4c61-88b5-87486b7a3630

📥 Commits

Reviewing files that changed from the base of the PR and between 2c99e52 and 1dfde6b.

📒 Files selected for processing (2)
  • tensorrt_llm/models/commandr/config.py
  • tensorrt_llm/models/model_weights_loader.py

@svc-trtllm-gh-bot added the Community want to contribute label (PRs initiated from Community) on Apr 30, 2026

Labels

Community want to contribute (PRs initiated from Community)

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Risk of silent unwanted tying weights. Must check 'model.config.tie_word_embeddings'

2 participants