Fix Flash Attention 3 API compatibility for window size parameters #2704
jhvmhg wants to merge 11 commits into NVIDIA:main from fix/flash_attn3_support_CP
Conversation
Replace single window_size parameter with window_size_left and window_size_right in flash_attn_fwd function to align with flash-attn v2.7.0+ API changes.

- Update function signature in flash_attn_interface
- Maintain backward compatibility where possible
- Ensure consistency with Flash Attention v2 implementation

Signed-off-by: Chaoyang Mei <1192554423@qq.com>
Signed-off-by: meichaoyang001 <meichaoyang001@ke.com>
Greptile Summary: This PR fixes Flash Attention 3 API compatibility by replacing the single window_size parameter with separate window_size_left and window_size_right parameters.
Confidence Score: 4/5
Flowchart

```mermaid
%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A[Flash Attention Call Site] --> B{use_flash_attn_3?}
    B -->|Yes| C[window_size_left / window_size_right\nis_causal for backward]
    B -->|No| D{fa_utils.v2_7_0_plus?}
    D -->|Yes| E[window_size_left / window_size_right\ncausal for backward]
    D -->|No| F{fa_utils.v2_3_plus?}
    F -->|Yes| G[window_size tuple\ncausal for backward]
    F -->|No| H[No window_size param\ncausal for backward]
    C --> I[flash_attn_fwd_v3 / flash_attn_bwd_v3]
    E --> J[flash_attn_varlen_fwd/bwd]
    G --> J
    H --> J
```
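The dispatch in the flowchart can be sketched as a small helper. The boolean flag names mirror `fa_utils`, but the helper itself is a hypothetical illustration, not TransformerEngine code:

```python
# Hypothetical sketch of the version dispatch shown in the flowchart.
def select_window_kwargs(window_size, use_flash_attn_3, v2_7_0_plus, v2_3_plus):
    """Map a (left, right) window tuple onto the kwargs each
    flash-attn generation expects."""
    left, right = window_size
    if use_flash_attn_3 or v2_7_0_plus:
        # FA3 and flash-attn v2.7.0+ take two scalar window arguments.
        return {"window_size_left": left, "window_size_right": right}
    if v2_3_plus:
        # flash-attn v2.3+ takes a single (left, right) tuple.
        return {"window_size": (left, right)}
    # Older releases have no sliding-window parameter at all.
    return {}

print(select_window_kwargs((128, 0), True, False, False))
```

A dict of kwargs keeps the call sites uniform while only the key names vary per version.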
Additional Comments (1)
Rename causal parameter to is_causal in flash_attn_bwd function to align with flash-attn v2.7.0+ API changes. This ensures consistency with the updated flash-attn library interface for backward pass operations. Signed-off-by: meichaoyang001 <meichaoyang001@ke.com>
Force-pushed from a245229 to f9752ca
jhvmhg left a comment:
Fix Flash Attention 3 backward API parameter naming
Rename causal parameter to is_causal in flash_attn_bwd function to align
with flash-attn v2.7.0+ API changes. This ensures consistency with the
updated flash-attn library interface for backward pass operations.
Signed-off-by: meichaoyang001 <meichaoyang001@ke.com>
for more information, see https://pre-commit.ci
… fix/flash_attn3_support_CP
Replace keyword arguments with positional arguments in flash_attn_fwd and flash_attn_bwd to abstract away parameter naming differences (causal vs is_causal) between flash-attn versions. This provides a more robust interface that is resilient to future API changes in the flash-attn library.

- Convert window_size_left, window_size_right, and causal parameters to positional args in both forward and backward functions
- Eliminate version-specific parameter naming dependencies
- Simplify compatibility handling across flash-attn v2.7.0+ variants

Signed-off-by: meichaoyang001 <meichaoyang001@ke.com>
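The positional-argument approach in this commit can be illustrated with a toy example. Here `fa2_bwd` and `fa3_bwd` are hypothetical stand-ins for the real flash-attn backward entry points, which differ only in whether the last parameter is named `causal` or `is_causal`:

```python
# Stand-in for the flash-attn v2 backward signature.
def fa2_bwd(dout, window_size_left, window_size_right, causal):
    return ("v2", window_size_left, window_size_right, causal)

# Stand-in for the FA3 backward signature (renamed flag).
def fa3_bwd(dout, window_size_left, window_size_right, is_causal):
    return ("v3", window_size_left, window_size_right, is_causal)

def run_backward(fn, dout, left, right, causal_flag):
    # Passing positionally makes the call site agnostic to the
    # causal / is_causal naming difference between versions.
    return fn(dout, left, right, causal_flag)

print(run_backward(fa3_bwd, None, 128, 0, False))
```

The same call site then works against either signature with no version branch for the flag name.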
```python
softmax_lse_per_step[i],
*fa_backward_args_thd,
causal="causal" in ctx.attn_mask_type,
ctx.attn_mask_type,
```
`ctx.attn_mask_type` is a string (e.g., `"causal"`, `"no_mask"`), not a boolean. It should be `"causal" in ctx.attn_mask_type` to convert it to a boolean.
```diff
-            ctx.attn_mask_type,
+            "causal" in ctx.attn_mask_type,
```
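The conversion the review suggests relies on substring membership: `attn_mask_type` is a label, and any causal variant contains the word "causal". The mask names below are examples of TransformerEngine values, not an exhaustive list:

```python
# Convert a mask-type label string to the boolean flash-attn expects.
def is_causal_mask(attn_mask_type: str) -> bool:
    # Substring check covers "causal" as well as combined variants
    # such as "padding_causal".
    return "causal" in attn_mask_type

print(is_causal_mask("causal"))          # True
print(is_causal_mask("padding_causal"))  # True
print(is_causal_mask("no_mask"))         # False
```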
Rename causal parameter to is_causal in flash_attn_bwd function to align with flash-attn v3 API changes. This ensures consistency with the updated flash-attn library interface for backward pass operations. Signed-off-by: meichaoyang001 <meichaoyang001@ke.com>
@mk-61 I think the changes look good, but could you please follow through with the CI, especially the L3_FA_version tests, to make sure the new changes pass the SWA tests for FA3? Thanks!

LGTM

/te-ci pytorch L3
Additional Comments (1)
The suggested fix for this location and similar ones at lines 1192–1208 and 3213–3217:

```python
if use_flash_attn_3 or fa_utils.v2_7_0_plus:
    fa_forward_kwargs["window_size_left"] = -1
    fa_forward_kwargs["window_size_right"] = -1
elif fa_utils.v2_3_plus:
    fa_forward_kwargs["window_size"] = (-1, -1)
```
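The `-1` defaults above follow the flash-attn convention that a negative window bound means "unbounded on that side", so `(-1, -1)` recovers full (non-sliding-window) attention. A small visibility checker, as an illustrative sketch rather than library code, makes this concrete:

```python
def in_window(q_idx, k_idx, left, right):
    """True if key position k_idx is visible from query position q_idx
    under a (left, right) sliding window; -1 disables that bound."""
    ok_left = left < 0 or q_idx - k_idx <= left
    ok_right = right < 0 or k_idx - q_idx <= right
    return ok_left and ok_right

print(in_window(100, 0, -1, -1))  # full attention: always visible
print(in_window(100, 0, 16, 0))   # key 0 is outside a 16-token left window
```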
/te-ci pytorch L3

1 similar comment

/te-ci pytorch L3
Replace single window_size parameter with window_size_left and window_size_right in flash_attn_fwd function to align with flash-attn v2.7.0+ API changes.
Description
Motivation:
The flash-attn library v2.7.0+ introduced breaking API changes that cause compatibility issues with TransformerEngine's Flash Attention 3 integration. These updates ensure seamless operation with newer versions of the flash-attn library while maintaining correctness of both forward and backward attention computations.
Related API changes:

- flash-attn v2.7.0+ split window_size into window_size_left and window_size_right
- flash-attn v3+ renamed the causal parameter to is_causal in the backward pass
Changes

- Replace the single window_size parameter with window_size_left and window_size_right in the flash_attn_fwd function to align with flash-attn v2.7.0+ API changes.
- Rename the causal parameter to is_causal in the flash_attn_bwd function to align with flash-attn v3.