Performance optimization for full size deepseek v3 on XPK cluster #388

Draft
jialei777 wants to merge 2 commits into main from jialei/ds-gmm

Conversation


@jialei777 (Collaborator) commented Sep 9, 2025

This PR adds sharding and the related all-to-all dispatch logic around the GMM kernel in the MoE layer.
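For reviewers unfamiliar with the pattern: the dispatch around a grouped matmul (GMM) sorts tokens by their assigned expert, runs one matmul per expert over its contiguous token group, then scatters the results back. On an XPK cluster the "group by expert" step becomes an all-to-all collective across devices; the single-host NumPy sketch below is purely illustrative (names and shapes are hypothetical, not the actual torchprime code):

```python
import numpy as np

def moe_dispatch_gmm(tokens, expert_ids, expert_weights):
  """Illustrative MoE dispatch -> GMM -> combine, all local.

  tokens:         [num_tokens, d_model]
  expert_ids:     [num_tokens] int, expert assigned to each token
  expert_weights: [num_experts, d_model, d_out]
  """
  num_tokens, _ = tokens.shape
  num_experts, _, d_out = expert_weights.shape

  # Dispatch: sort tokens by expert id so each expert's tokens are
  # contiguous. In the sharded setting this is the all-to-all step.
  order = np.argsort(expert_ids, kind="stable")
  sorted_tokens = tokens[order]
  group_sizes = np.bincount(expert_ids, minlength=num_experts)

  # GMM: one matmul per expert over its contiguous token group.
  out = np.empty((num_tokens, d_out), dtype=tokens.dtype)
  start = 0
  for e, size in enumerate(group_sizes):
    out[start:start + size] = sorted_tokens[start:start + size] @ expert_weights[e]
    start += size

  # Combine: undo the sort so outputs line up with the input tokens.
  result = np.empty_like(out)
  result[order] = out
  return result
```

In the real model the per-expert loop is replaced by the fused GMM kernel, and the sort/unsort permutations become the dispatch and combine all-to-alls that this PR is sharding.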

Training currently fails on the sharding of the all_to_all collective in the MoE layer:

```
Error executing job with overrides: ['model=deepseek-v3', 'ici_mesh.fsdp=4', 'logging_steps=1', 'dataset.block_size=4096', 'task.global_batch_size=4', 'profile_start_step=1', 'task.lr_scheduler.type=constant', 'task.max_steps=6', 'model.num_hidden_layers=2', 'model.first_k_dense_replace=1']
Traceback (most recent call last):
  File "/home/jialeic_google_com/work/torchprime/torchprime/torch_xla_models/train.py", line 94, in main
    trainer.train_loop()
  File "/home/jialeic_google_com/work/torchprime/torchprime/torch_xla_models/trainer/base_trainer.py", line 263, in train_loop
    loss, grad_norm = self.train_step(batch)
                      ^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jialeic_google_com/miniconda3/envs/xla0905-py312-2/lib/python3.12/contextlib.py", line 80, in inner
    with self._recreate_cm():
         ^^^^^^^^^^^^^^^^^^^
  File "/home/jialeic_google_com/miniconda3/envs/xla0905-py312-2/lib/python3.12/contextlib.py", line 144, in __exit__
    next(self.gen)
  File "/home/jialeic_google_com/miniconda3/envs/xla0905-py312-2/lib/python3.12/site-packages/torch_xla/torch_xla.py", line 214, in _compile
    sync()
  File "/home/jialeic_google_com/miniconda3/envs/xla0905-py312-2/lib/python3.12/site-packages/torch_xla/torch_xla.py", line 87, in sync
    torch_xla._XLAC._xla_step_marker(
ValueError: RET_CHECK failure (third_party/tensorflow/compiler/xla/service/spmd/spmd_partitioner.cc:5636) hlo->has_sharding() Side-effect HLO must have sharding: %all-to-all.921 = (bf16[1,131072,7168]{2,1,0}) all-to-all(%add.919), replica_groups={{0},{1},{2},{3}}, constrain_layout=true

Set the environment variable HYDRA_FULL_ERROR=1 for a complete stack trace.
```

