[ROCm] Allow bf16/bf16/fp32 in nvte_multi_tensor_gemm dispatcher #573

Open

lizamd wants to merge 1 commit into ROCm:dev from lizamd:fix/ck-grouped-gemm-bf16-fp32-output


Conversation

@lizamd

@lizamd lizamd commented May 4, 2026

The is_supported_dtype check in nvte_multi_tensor_gemm previously required A==B==D for the fp16/bf16 path, which rejected the common bf16/bf16/fp32 case where the GEMM output is fp32 for gradient accumulation. This forced a fallback to multi_stream_cublas_gemm (a per-expert hipblaslt loop), bypassing the CK grouped GEMM kernel entirely on ROCm.

The CK FP16 dispatcher (ck_tile_grouped_gemm_fp16_dispatch) already supports independent D dtype via TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY (fp32, fp16, bf16). The wrapper check is the only thing that prevents it from being reached.

Relaxed to require A==B in fp16/bf16 and D in {fp32, fp16, bf16}, which matches what the CK dispatcher actually accepts. Verified on Qwen3-30B-A3B MoE training on MI355X (gfx950): the fallback warning rate drops from ~1040/step (every GEMM) to ~28/step (the remaining ~3% are shapes that the CK kernel itself rejects via Kernel::IsSupportedArgument). Throughput is essentially unchanged in this workload because hipblaslt's per-shape autotuning happens to be competitive with the hardcoded CK tile configs for these MoE shapes; the benefit will materialize once the CK dispatcher gains more tile configs (or shape-aware tile selection by aggregate M).
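
A rough sketch of the relaxed acceptance rule (hypothetical function and enum names, not the actual TransformerEngine code; the real guard lives in the is_supported_dtype logic of nvte_multi_tensor_gemm):

// Sketch only: the old rule required A == B == D in {fp16, bf16};
// the new rule requires A == B in {fp16, bf16} and lets D be fp32, fp16, or bf16.
enum class DType { kFloat16, kBFloat16, kFloat32 };

bool ck_fp16_path_supported(DType a, DType b, DType d) {
  const bool inputs_ok =
      (a == b) && (a == DType::kFloat16 || a == DType::kBFloat16);
  const bool output_ok =
      (d == DType::kFloat32) || (d == DType::kFloat16) || (d == DType::kBFloat16);
  return inputs_ok && output_ok;
}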

This is a CUDA path file; the same patch applies to the AMD path via hipify. No CUDA-side behavior change since cuBLAS/cutlass dispatch on NVIDIA still requires A==B==D in the cutlass fast-path pre-conditions.

Description

Please include a brief summary of the changes, relevant motivation and context.

Fixes # (issue)

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Change A
  • Change B

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@lizamd lizamd force-pushed the fix/ck-grouped-gemm-bf16-fp32-output branch 2 times, most recently from 764cb65 to ff19241 on May 5, 2026 at 00:02
@matthiasdiener matthiasdiener added the ci-level 1 (CI test level 1) label on May 5, 2026
@wenchenvincent
Collaborator

@matthiasdiener @aris134 Could you review this PR?

Collaborator

@wangye805 wangye805 left a comment


Can you edit an existing test or add a new test showing that, with your change, bf16/fp16 inputs with fp32 output now go through the CK flow correctly? Also, please paste some benchmarking data into this ticket for future reference.

Comment on lines +1166 to +1171
// CK FP16/BF16 grouped GEMM dispatcher (ck_tile_grouped_gemm_fp16_dispatch)
// already supports independent D dtype via TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY
// (fp32, fp16, bf16). The previous check required A==B==D, which incorrectly
// rejected the common bf16/bf16/fp32 case (training with fp32 gradient
// accumulation), forcing a fallback to the per-expert hipblaslt loop.
// Relaxed to require A==B in fp16/bf16 and D in {fp32, fp16, bf16}.
Contributor


I think this explanation may be better suited for the PR description rather than an inline code comment.

Contributor

@aris134 aris134 left a comment


Agreed that the CK dispatch logic supports the bf16/fp32 combination. I would remove the detailed history comment about the previous fallback behavior, which is better suited to the PR itself.

@lizamd
Author

lizamd commented May 6, 2026 via email

Follow-ups (out of scope for this PR):

- Add more CK tile configs (e.g. TileCfg_64x256x64, TileCfg_128x256x64)
  and shape-aware tile selection by aggregate M per call (a rough sketch
  follows this list). Currently
  throughput is unchanged on this workload because the existing hipblaslt
  fallback is well-tuned and the 3 hardcoded CK tile configs
  (TileCfg_256x256x64, TileCfg_256x128x64, TileCfg_256x128x64_padding)
  don't fit MoE shapes (highly variable per-expert M) optimally. Real
  CK-grouped-GEMM perf wins will materialize once tile selection adapts
  to M.
- Investigate the ~3% of GEMMs that hit Kernel::IsSupportedArgument
  rejection (likely small per-expert M values that fail tile-size
  constraints in the current TileCfg_256x* instantiations).
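
A rough sketch of what shape-aware tile selection by aggregate M could look like (config names and thresholds are placeholders, not proposed values):

#include <cstdint>
#include <numeric>
#include <vector>

// Placeholder configs mirroring the TileCfg_* naming used by the CK dispatcher.
enum class TileCfg { k256x256x64, k256x128x64, k64x256x64 };

// Pick a config from the total M summed over all experts in one grouped call;
// the thresholds here are purely illustrative and would come from benchmarking.
TileCfg select_tile_cfg(const std::vector<int64_t>& per_expert_m) {
  const int64_t total_m =
      std::accumulate(per_expert_m.begin(), per_expert_m.end(), int64_t{0});
  if (total_m >= 8192) return TileCfg::k256x256x64;
  if (total_m >= 1024) return TileCfg::k256x128x64;
  return TileCfg::k64x256x64;
}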
@lizamd lizamd force-pushed the fix/ck-grouped-gemm-bf16-fp32-output branch from ff19241 to d416572 on May 7, 2026 at 17:45
@lizamd
Author

lizamd commented May 7, 2026

@wangye805 @aris134 could you check the new commit?

)
@pytest.mark.parametrize("input_dtype", [torch.bfloat16, torch.float16])
@pytest.mark.parametrize("layout", ["TN", "NT"])
def test_grouped_gemm_fp32_output(input_dtype, layout):
Collaborator


Can it be done by adding configs/parameters to test_grouped_gemm?
