
Fix Flash Attention 3 API compatibility for window size parameters#2704

Open
jhvmhg wants to merge 11 commits into NVIDIA:main from jhvmhg:fix/flash_attn3_support_CP

Conversation


@jhvmhg jhvmhg commented Feb 25, 2026

Replace single window_size parameter with window_size_left and window_size_right in flash_attn_fwd function to align with flash-attn v2.7.0+ API changes.

  • Update function signature in flash_attn_interface
  • Maintain backward compatibility where possible
  • Ensure consistency with Flash Attention v2 implementation

Description

Please include a brief summary of the changes, relevant motivation and context.

Fixes # (issue)

  1. Fix window size parameters in flash_attn_fwd - Replaces the single window_size parameter with separate window_size_left and window_size_right parameters to match the updated flash-attn v2.7.0+ API.
  2. Fix causal parameter naming in flash_attn_bwd - Renames causal to is_causal in the backward function signature for consistency with the latest flash-attn interface.

Motivation:

The flash-attn library v2.7.0+ introduced breaking API changes that cause compatibility issues with TransformerEngine's Flash Attention 3 integration. These updates ensure seamless operation with newer versions of the flash-attn library while maintaining correctness of both forward and backward attention computations.

Related API Changes:

  • flash-attn v2.7.0+ split window_size into window_size_left and window_size_right
  • flash-attn v3+ renamed the causal parameter to is_causal in the backward pass
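The version split described above can be sketched as a small compatibility helper. This is an illustrative stand-in, not TransformerEngine code: `build_window_kwargs` and the boolean version flags are hypothetical names mirroring the version-gated branching discussed in this PR.

```python
# Hypothetical sketch of the version-gated kwarg construction described above.
# The flags mirror fa_utils.v2_7_0_plus / fa_utils.v2_3_plus in spirit but are
# plain booleans here, not the actual TransformerEngine utilities.

def build_window_kwargs(use_flash_attn_3, v2_7_0_plus, v2_3_plus,
                        window_size=(-1, -1)):
    """Return forward kwargs matching the flash-attn API for each version."""
    kwargs = {}
    if use_flash_attn_3 or v2_7_0_plus:
        # flash-attn v2.7.0+ (and FA3) take two separate integer parameters
        kwargs["window_size_left"] = window_size[0]
        kwargs["window_size_right"] = window_size[1]
    elif v2_3_plus:
        # flash-attn v2.3 through v2.6 take a single (left, right) tuple
        kwargs["window_size"] = window_size
    # older flash-attn versions have no sliding-window parameter at all
    return kwargs
```

Note the ordering: the FA3/v2.7.0+ check comes first, so FA3 always receives the split parameters even when an older flash-attn v2 is co-installed.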

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Replace single window_size parameter with window_size_left and window_size_right
    in flash_attn_fwd function to align with flash-attn v2.7.0+ API changes.
  • Rename causal parameter to is_causal in flash_attn_bwd function to align
    with flash-attn v3.

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Replace single window_size parameter with window_size_left and window_size_right
in flash_attn_fwd function to align with flash-attn v2.7.0+ API changes.

- Update function signature in flash_attn_interface
- Maintain backward compatibility where possible
- Ensure consistency with Flash Attention v2 implementation

Signed-off-by: Chaoyang Mei <1192554423@qq.com>
Signed-off-by: meichaoyang001 <meichaoyang001@ke.com>

greptile-apps bot commented Feb 25, 2026

Greptile Summary

This PR fixes Flash Attention 3 API compatibility by replacing the single window_size tuple parameter with separate window_size_left / window_size_right parameters in all forward and backward call sites that use FA3 or FA2 v2.7.0+, and renames the backward causal keyword to is_causal when calling the FA3 backward interface. The changes are consistently applied across all three context-parallel attention classes (AttnFuncWithCPAndKVP2P, AttnFuncWithCPAndKVAllGather, AttnFuncWithCPAndQKVOA2A) and their helper functions (cp_p2p_fwd_flash_attn, cp_p2p_bwd_flash_attn), fitting naturally into the existing version-gated branching pattern already present in the file.

Key changes:

  • use_flash_attn_3 correctly moved out of the legacy window_size branch and into the window_size_left / window_size_right branch alongside fa_utils.v2_7_0_plus
  • FA3 backward calls now pass is_causal=causal_ via fa_backward_kwargs instead of a positional causal= kwarg
  • The version-gating logic is applied symmetrically in all affected locations, with API parameter splits/renames matching the documented flash-attn v2.7.0+/v3 changes. No regressions introduced to the FA2 path.

Confidence Score: 4/5

  • This PR is safe to merge; the API compatibility fixes are logically consistent across all affected call sites.
  • The version-gating logic is applied symmetrically in all six changed locations (3 forward, 3 backward) and the parameter splits/renames match the documented flash-attn v2.7.0+/v3 API changes. The one style issue (redundant per-iteration kwarg assignment in the AllGather backward loop) is non-functional. No regressions introduced to the FA2 path.
  • No files require special attention beyond the single changed file.

Flowchart

```mermaid
%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A[Flash Attention Call Site] --> B{use_flash_attn_3?}
    B -->|Yes| C[window_size_left / window_size_right\nis_causal for backward]
    B -->|No| D{fa_utils.v2_7_0_plus?}
    D -->|Yes| E[window_size_left / window_size_right\ncausal for backward]
    D -->|No| F{fa_utils.v2_3_plus?}
    F -->|Yes| G[window_size tuple\ncausal for backward]
    F -->|No| H[No window_size param\ncausal for backward]
    C --> I[flash_attn_fwd_v3 / flash_attn_bwd_v3]
    E --> J[flash_attn_varlen_fwd/bwd]
    G --> J
    H --> J
```

Comments Outside Diff (1)

  1. transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py, line 3218-3221 (link)

    The is_causal/causal key is set inside the per-step loop (starting at line 3152) on every iteration even though the value is constant; it could be hoisted above the loop for clarity and efficiency.

    Consider moving this assignment outside the loop alongside the other static fa_backward_kwargs initializations.
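The hoisting suggestion can be illustrated with a minimal sketch. The function names and dict contents below are hypothetical stand-ins for the `fa_backward_kwargs` handling in context_parallel.py, not the actual code:

```python
# Minimal sketch of the hoisting suggestion above. build_backward_kwargs and
# run_backward_steps are hypothetical stand-ins, not TransformerEngine code.

def build_backward_kwargs(use_flash_attn_3, attn_mask_type):
    """Build the static backward kwargs once, outside the per-step loop."""
    fa_backward_kwargs = {}
    # FA3 renamed 'causal' to 'is_causal'; the value is constant across
    # steps, so it can be assigned here rather than on every iteration.
    key = "is_causal" if use_flash_attn_3 else "causal"
    fa_backward_kwargs[key] = "causal" in attn_mask_type
    return fa_backward_kwargs

def run_backward_steps(num_steps, use_flash_attn_3, attn_mask_type):
    fa_backward_kwargs = build_backward_kwargs(use_flash_attn_3, attn_mask_type)
    calls = []
    for _ in range(num_steps):
        # A real implementation would call flash_attn_bwd(..., **fa_backward_kwargs)
        calls.append(dict(fa_backward_kwargs))
    return calls
```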

Last reviewed commit: 04f3f78

@greptile-apps greptile-apps bot left a comment

1 file reviewed, 1 comment


greptile-apps bot commented Feb 25, 2026

Additional Comments (1)

transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py
Removed the causal parameter, but other flash_attn_bwd calls in this file (lines 3222, 3832) still pass it; verify this inconsistency is intentional.

Rename causal parameter to is_causal in flash_attn_bwd function to align
with flash-attn v2.7.0+ API changes. This ensures consistency with the
updated flash-attn library interface for backward pass operations.

Signed-off-by: meichaoyang001 <meichaoyang001@ke.com>
@jhvmhg jhvmhg force-pushed the fix/flash_attn3_support_CP branch from a245229 to f9752ca Compare February 25, 2026 07:54
@jhvmhg jhvmhg (Author) left a comment

Fix Flash Attention 3 backward API parameter naming

Rename causal parameter to is_causal in flash_attn_bwd function to align
with flash-attn v2.7.0+ API changes. This ensures consistency with the
updated flash-attn library interface for backward pass operations.

@greptile-apps greptile-apps bot left a comment

1 file reviewed, no comments

jhvmhg and others added 2 commits February 25, 2026 15:56
Rename causal parameter to is_causal in flash_attn_bwd function to align
with flash-attn v2.7.0+ API changes. This ensures consistency with the
updated flash-attn library interface for backward pass operations.

Signed-off-by: meichaoyang001 <meichaoyang001@ke.com>
@greptile-apps greptile-apps bot left a comment

1 file reviewed, no comments

Replace keyword arguments with positional arguments in flash_attn_fwd and
flash_attn_bwd to abstract away parameter naming differences (causal vs
is_causal) between flash-attn versions. This provides a more robust
interface that is resilient to future API changes in the flash-attn library.

- Convert window_size_left, window_size_right, and causal parameters to
  positional args in both forward and backward functions
- Eliminate version-specific parameter naming dependencies
- Simplify compatibility handling across flash-attn v2.7.0+ variants

Signed-off-by: meichaoyang001 <meichaoyang001@ke.com>
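The positional-argument approach in the commit above can be sketched as follows, assuming two hypothetical backward wrappers whose signatures differ only in the name of the causal keyword (these are illustrative stand-ins, not the real flash-attn functions):

```python
# Sketch of passing arguments positionally so the caller is insulated from
# the 'causal' vs 'is_causal' keyword rename. flash_attn_bwd_v2 and
# flash_attn_bwd_v3 are hypothetical stand-ins for the real interfaces.

def flash_attn_bwd_v2(window_size_left, window_size_right, causal):
    return ("v2", window_size_left, window_size_right, causal)

def flash_attn_bwd_v3(window_size_left, window_size_right, is_causal):
    return ("v3", window_size_left, window_size_right, is_causal)

def call_backward(fn, window_size_left, window_size_right, causal):
    # Positional passing satisfies both signatures, so the caller never
    # needs to know which keyword name this flash-attn version uses.
    return fn(window_size_left, window_size_right, causal)
```

The trade-off of this design is readability at the call site versus resilience to keyword renames; positional calls also break if a future version reorders parameters, so they are only robust against name changes.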
@greptile-apps greptile-apps bot left a comment

1 file reviewed, 1 comment

```python
softmax_lse_per_step[i],
*fa_backward_args_thd,
causal="causal" in ctx.attn_mask_type,
ctx.attn_mask_type,
```

@greptile-apps greptile-apps bot commented:

ctx.attn_mask_type is a string (e.g., "causal", "no_mask"), not a boolean. It should be "causal" in ctx.attn_mask_type to convert it to a boolean.

Suggested change:

```diff
-ctx.attn_mask_type,
+"causal" in ctx.attn_mask_type,
```

@jhvmhg jhvmhg closed this Feb 25, 2026
Rename causal parameter to is_causal in flash_attn_bwd function to align
with flash-attn v3 API changes. This ensures consistency with the
updated flash-attn library interface for backward pass operations.

Signed-off-by: meichaoyang001 <meichaoyang001@ke.com>
@jhvmhg jhvmhg reopened this Feb 25, 2026
@jhvmhg jhvmhg closed this Feb 25, 2026
@jhvmhg jhvmhg reopened this Feb 25, 2026
@greptile-apps greptile-apps bot left a comment

1 file reviewed, no comments

@ptrendx ptrendx requested a review from cyanguwa February 25, 2026 19:57
@cyanguwa cyanguwa requested a review from mk-61 February 26, 2026 00:04
@cyanguwa (Collaborator)

@mk-61 I think the changes look good, but could you please follow through with the CI, especially the L3_FA_version tests, to make sure the new changes pass the SWA tests for FA3? Thanks!

@vcherepanov-nv (Collaborator)

LGTM

@sudhakarsingh27 (Collaborator)

/te-ci pytorch L3


greptile-apps bot commented Mar 4, 2026

Additional Comments (1)

transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py, line 944
Condition-priority bug for FA3 + flash-attn v2.3–v2.6

When use_flash_attn_3=True AND flash-attn v2 is installed between versions 2.3 and 2.6 (fa_utils.v2_3_plus=True, fa_utils.v2_7_0_plus=False), the first condition fires and sets the legacy window_size tuple parameter. However, the actual function called is _flash_attn_fwd_v3, which requires window_size_left and window_size_right parameters, resulting in a runtime error.

The use_flash_attn_3 check should take priority to ensure FA3 always receives the new API regardless of which flash-attn v2 version is co-installed.

Suggested fix for this location and similar ones at lines 1192–1208 and 3213–3217:

```python
if use_flash_attn_3 or fa_utils.v2_7_0_plus:
    fa_forward_kwargs["window_size_left"] = -1
    fa_forward_kwargs["window_size_right"] = -1
elif fa_utils.v2_3_plus:
    fa_forward_kwargs["window_size"] = (-1, -1)
```

@sudhakarsingh27 (Collaborator)

/te-ci pytorch L3

1 similar comment
@sudhakarsingh27 (Collaborator)

/te-ci pytorch L3

4 participants