
[NPU] Add group norm support on NPU#1144

Merged
Tcc0403 merged 2 commits into linkedin:main from orangeH25:group-norm/1
Mar 17, 2026

Conversation

@orangeH25
Contributor

Summary

This PR introduces a functional GroupNorm operator for Ascend NPU.

Key improvements:

  • Fixes the runtime error `grid should be less than 65536!` and the UB (unified buffer) overflow that occur when the original GPU-oriented liger-kernel GroupNorm implementation is executed on NPU.
  • Adjusts the kernel launch and tiling strategy to comply with Ascend NPU execution constraints.
  • Resolves numerical accuracy mismatches against PyTorch reference outputs.

While the current implementation is still slower than the HuggingFace implementation in end-to-end benchmarks, it provides a stable and functional GroupNorm path for Ascend NPU.

This PR mainly focuses on correctness and NPU compatibility. Further kernel-level optimizations will be explored in follow-up work.
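The tiling adjustment described above can be illustrated with a minimal, hypothetical sketch (the names `MAX_GRID`, `launch_grid`, and `rows_for_program` are illustrative, not taken from the PR): instead of launching one program per row, the launch grid is capped and each program loops over a strided range of rows, which keeps any problem size under a 65536-program grid limit.

```python
# Hypothetical sketch of a capped-grid tiling strategy: cap the launch
# grid at MAX_GRID and have each "program" process rows with a stride,
# so any number of rows fits under the grid-size constraint.
MAX_GRID = 65535  # assumed largest grid the runtime accepts

def launch_grid(n_rows: int) -> int:
    """Number of programs to launch for n_rows rows."""
    return min(n_rows, MAX_GRID)

def rows_for_program(pid: int, n_rows: int, grid: int):
    """Rows handled by program pid: pid, pid + grid, pid + 2*grid, ..."""
    return list(range(pid, n_rows, grid))

# Every row is covered exactly once even when n_rows > MAX_GRID.
n_rows = 70_000
grid = launch_grid(n_rows)
covered = sorted(r for pid in range(grid) for r in rows_for_program(pid, n_rows, grid))
assert covered == list(range(n_rows))
```

The same pattern is a common way to respect launch-dimension limits on accelerators: the grid stays bounded while per-program loops absorb the remaining work.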

Testing Done

(benchmark screenshot)
  • Hardware Type: Atlas 800I A2
  • Ran `make test` to ensure correctness
  • Ran `make checkstyle` to ensure code style
  • Ran `make test-convergence` to ensure convergence

@orangeH25
Contributor Author

Hi @Tcc0403 , please take a look. Thanks!

Comment on lines +243 to +247
else:
    dW_block = tl.where(mask, DY_block * x_hat, 0.0)
    dB_block = tl.where(mask, DY_block, 0.0)
    tl.atomic_add(DW_scratch_base + global_channel, dW_block, mask=mask)
    tl.atomic_add(DB_scratch_base + global_channel, dB_block, mask=mask)
Collaborator


L377 says

# Placeholder buffers (unused in kernel when COMPUTE_PARAM_GRAD=False)

which contradicts what this block does. Shouldn't this be a no-op here?

Contributor Author


Removed.

I originally kept it to preserve the kernel structure for potential future experiments, but it’s not needed in the current implementation.

Comment on lines +377 to +379
# Placeholder buffers (unused in kernel when COMPUTE_PARAM_GRAD=False)
DW_scratch = torch.empty((1, 1), dtype=torch.float32, device=W.device)
DB_scratch = torch.empty((1, 1), dtype=torch.float32, device=W.device)
Collaborator


Can the placeholder buffers be set to None in triton-ascend, to avoid accidental access in device code?

Contributor Author


Addressed.
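The host-side change can be sketched roughly as follows (a hypothetical stand-in; `make_scratch` and the allocator callback are illustrative names, not the PR's actual code): scratch buffers are only allocated when parameter gradients are computed, and None is passed otherwise so any accidental device-side access fails loudly instead of silently writing to a dummy buffer.

```python
# Hypothetical host-side sketch: allocate dW/dB scratch buffers only
# when COMPUTE_PARAM_GRAD is true; otherwise pass None so misuse is
# caught immediately rather than hidden by a (1, 1) placeholder.
def make_scratch(compute_param_grad: bool, alloc):
    """Return (dW_scratch, dB_scratch), or (None, None) when unused."""
    if compute_param_grad:
        return alloc(), alloc()
    return None, None

# alloc would be e.g. a torch.zeros(...) closure in the real code;
# a plain list stands in for a buffer here.
dW_scratch, dB_scratch = make_scratch(False, lambda: [0.0])
assert dW_scratch is None and dB_scratch is None
```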

@orangeH25
Contributor Author

Hi @Tcc0403, changes applied. Appreciate another review when you have time, thanks!

Collaborator

@Tcc0403 left a comment


LGTM! I left a comment about a potential improvement, but it can be done in another PR!

Comment on lines +239 to +244
if COMPUTE_PARAM_GRAD:
    if SINGLE_CHANNEL_TILE:
        dW_partial = tl.sum(tl.where(mask, DY_block * x_hat, 0.0), axis=1)
        dB_partial = tl.sum(tl.where(mask, DY_block, 0.0), axis=1)
        tl.atomic_add(DW_scratch_base + global_channel, dW_partial, mask=row_mask)
        tl.atomic_add(DB_scratch_base + global_channel, dB_partial, mask=row_mask)
Collaborator


I wonder if we can accumulate dW and dB over the grid loop and store them afterwards, similar to

dW_row += tl.sum(dY_row * (X_row * rstd_row[:, None]), 0)

With this approach we can avoid atomic_add and potentially handle the scenario where `num_col_blocks > 1`.

The solution is not trivial and not guaranteed to improve performance; I'm leaving the comment here as a direction for future work.
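The suggested pattern can be sketched in plain Python (a stand-in for the Triton kernel; the data, shapes, and the two-program split are made up for illustration): each program keeps a private accumulator across its grid loop over row blocks and writes a single partial result at the end, and a cheap final reduction replaces the per-block atomic_add into a shared buffer.

```python
# Illustrative comparison of the two accumulation strategies for dW.
# Data and shapes are invented; both paths must produce the same sums.
rows, cols = 6, 4
dY = [[0.1 * (r + c) for c in range(cols)] for r in range(rows)]
x_hat = [[0.5 - 0.1 * c for c in range(cols)] for r in range(rows)]

n_programs = 2

# atomic-add style: every program read-modify-writes one shared buffer,
# once per row block it processes.
shared = [0.0] * cols
for pid in range(n_programs):
    for r in range(pid, rows, n_programs):      # this program's row blocks
        for c in range(cols):
            shared[c] += dY[r][c] * x_hat[r][c]  # one atomic_add per block

# accumulate-then-store style: each program keeps a private accumulator
# over its grid loop and stores one partial; a final reduction sums them.
partials = []
for pid in range(n_programs):
    acc = [0.0] * cols
    for r in range(pid, rows, n_programs):
        for c in range(cols):
            acc[c] += dY[r][c] * x_hat[r][c]
    partials.append(acc)                         # single store per program
reduced = [sum(p[c] for p in partials) for c in range(cols)]

assert all(abs(a - b) < 1e-12 for a, b in zip(shared, reduced))
```

The private-accumulator path trades extra scratch memory (one partial per program) for the removal of atomic contention, which is why it is plausible but not guaranteed to be faster.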

Contributor Author


Thanks for the suggestion! I’ll look into this.

@Tcc0403 added this pull request to the merge queue Mar 17, 2026
Merged via the queue into linkedin:main with commit 68a7489 Mar 17, 2026
5 of 7 checks passed

2 participants