
Conversation

@ruisizhang123 (Contributor) commented Nov 24, 2025

Validates DSV3 manual bucketing when EP/TP are enabled. Tested on the DSV3-16B model. Depends on a PyTorch PR.

(Single Node: BS = 1)

| Node | Method | Parallelism | Memory | TPS | Trace |
|------|--------|-------------|--------|-----|-------|
| 1-Node (8H100) | SimpleFSDP (aot_eager) | FSDP=4 EP=2 | 51.11GiB (53.80%) | 5,136 | Link |
| 1-Node (8H100) | FSDP2-eager | FSDP=4 EP=2 | 59.54GiB (62.68%) | 5,942 | Link |
| 1-Node (8H100) | SimpleFSDP (aot_eager) | FSDP=2 TP=2 EP=2 | 42.21GiB (44.43%) | 2,285 | Link |
| 1-Node (8H100) | FSDP2-eager | FSDP=2 TP=2 EP=2 | 45.41GiB (47.80%) | 2,349 | Link |
1. Example Trace
[Screenshot: example trace (2025-12-10)]

meta-cla bot added the CLA Signed label Nov 24, 2025
@ruisizhang123 marked this pull request as draft on November 24, 2025 17:19
@ruisizhang123 force-pushed the ruisi/fix_manual_bucketing_dsv3 branch from f931aa9 to 88b700b on December 11, 2025 05:23
@ruisizhang123 marked this pull request as ready for review on December 11, 2025 05:24
),
"16B": DeepSeekV3ModelArgs(
    vocab_size=102400,
    dim=2048,
Contributor Author:

@tianyu-l Should we have another config to allow users to turn flex attention on/off? Currently, flex attention doesn't work well with AC here. cc @soulitzer for AC issue follow-up!

Contributor:

what was the symptom?

also if it doesn't work why do we add an entry for it -- is it for repro?

Contributor Author:

No, it's that 16B_flexatten doesn't work. But in the current DSV3 implementation, flex attention is turned on by default. I want to have a model config that turns off flex attention by default.
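
For illustration, a minimal sketch of the kind of entry I mean (the `use_flex_attn` field name here is an assumption about the args schema, not the actual PR code):

```python
"16B": DeepSeekV3ModelArgs(
    vocab_size=102400,
    dim=2048,
    # Assumed flag name: keep flex attention off by default so AC works;
    # a separate flex-attention variant can opt back in explicitly.
    use_flex_attn=False,
),
```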

Contributor Author:

is there a reason we want to enable flex attention instead of other attention implementations by default in DSV3?

for m in modules:
    if isinstance(m, list):
-       result.append(convert_modules_to_fqns(m, module_to_fqn_mapping))
+       if fqn_list := convert_modules_to_fqns(m, module_to_fqn_mapping):
Contributor:

what does the syntax mean -- assigning to fqn_list and then checking it isn't None/empty? It feels a bit unusual to read.

Also please add a comment on why we need this check

Contributor Author:

yes, added a comment for it.
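
For reference, a minimal sketch of how the commented check might read (the overall structure of convert_modules_to_fqns and its non-list branch are assumptions, not the PR's exact code):

```python
def convert_modules_to_fqns(modules, module_to_fqn_mapping):
    """Map (possibly nested lists of) modules to their FQNs."""
    result = []
    for m in modules:
        if isinstance(m, list):
            # A nested group may resolve to no FQNs at all (e.g. its modules
            # were filtered out of the mapping). The walrus operator binds the
            # recursive result to fqn_list and only appends it when non-empty,
            # so we never emit empty buckets.
            if fqn_list := convert_modules_to_fqns(m, module_to_fqn_mapping):
                result.append(fqn_list)
        else:
            fqn = module_to_fqn_mapping.get(m)
            if fqn is not None:
                result.append(fqn)
    return result
```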

@ruisizhang123 force-pushed the ruisi/fix_manual_bucketing_dsv3 branch 2 times, most recently from 35ad842 to 3b0fdda on December 12, 2025 01:12
Comment on lines 69 to 75
VIEW_OPS = {
    torch.ops.aten.slice.Tensor,
    torch.ops.aten.view.default,
    torch.ops.aten.reshape.default,
    torch.ops.aten.transpose.int,
}

Contributor Author:

@bdhirsh Following today's discussion, I updated reshard_after_fwd to enforce that all VIEW_OPS after a wait are recomputed.

In this fx graph in tlparse (link), view_63-view_65 + transpose_8 are enforced to be recomputed. Thus, I can successfully get the correct reshard_after_fwd semantics.

However, in bwd the _grouped_mm is also recomputed, seemingly because we enforce this region to be MUST_RECOMPUTE. I feel like I'm in a rabbit hole: if I don't RECOMPUTE transpose_8, I don't get correct FSDP semantics, but if I do RECOMPUTE transpose_8, the follow-up _grouped_mm gets recomputed as well.

Wonder if you think I should fix this from the simplefsdp side or the partitioner side? 🤔

_to_copy_32: "bf16[1, 256, 256][65536, 256, 1]cuda:0" = torch.ops.aten._to_copy.default(view_58, dtype = torch.bfloat16);  view_58 = None
all_gather_into_tensor_19: "bf16[4, 256, 256][65536, 256, 1]cuda:0" = torch.ops._c10d_functional.all_gather_into_tensor.default(_to_copy_32, 4, '1');  _to_copy_32 = None
wait_tensor_22: "bf16[4, 256, 256][65536, 256, 1]cuda:0" = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_19);  all_gather_into_tensor_19 = None
view_63: "bf16[4, 256, 256][65536, 256, 1]cuda:0" = torch.ops.aten.view.default(wait_tensor_22, [4, 256, 256]);  wait_tensor_22 = None
view_64: "bf16[4, 256, 256][65536, 256, 1]cuda:0" = torch.ops.aten.view.default(view_63, [4, 256, 256]);  view_63 = None
view_65: "bf16[4, 256, 256][65536, 256, 1]cuda:0" = torch.ops.aten.view.default(view_64, [4, 256, 256]);  view_64 = None
transpose_8: "bf16[4, 256, 256][65536, 1, 256]cuda:0" = torch.ops.aten.transpose.int(view_65, -2, -1);  view_65 = None
_grouped_mm: "bf16[8*(((u2 + u3 + 39)//8)), 256][256, 1]cuda:0" = torch.ops.aten._grouped_mm.default(index_1, transpose_8, cumsum_2);  transpose_8 = None

Contributor:

grouped_mm getting recomputed too definitely seems wrong - I would have expected that as part of running SAC (outside of the graph pass) we would have set it to MUST_SAVE
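
For context, a rough sketch of that expectation using the public SAC policy API (the actual SimpleFSDP/partitioner wiring is assumed, not shown):

```python
import torch
from torch.utils.checkpoint import (
    CheckpointPolicy,
    create_selective_checkpoint_contexts,
)

def _policy_fn(ctx, op, *args, **kwargs):
    # Pin the expensive grouped_mm so the partitioner cannot recompute it;
    # cheap ops (views, transposes) remain free to be recomputed.
    if op == torch.ops.aten._grouped_mm.default:
        return CheckpointPolicy.MUST_SAVE
    return CheckpointPolicy.PREFER_RECOMPUTE

# Example usage inside an SAC region (block/inputs are placeholders):
# out = torch.utils.checkpoint.checkpoint(
#     block,
#     inputs,
#     use_reentrant=False,
#     context_fn=lambda: create_selective_checkpoint_contexts(_policy_fn),
# )
```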

torch.ops.aten.slice.Tensor,
torch.ops.aten.view.default,
torch.ops.aten.reshape.default,
torch.ops.aten.transpose.int,
Contributor:

dumb nit: we can probably generalize this by just having your graph pass check for func.is_view (bool property on OpOverload that is true for all view ops)
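
e.g., a rough sketch of that check (the surrounding graph-pass loop is assumed):

```python
import torch

def _is_view_node(node: torch.fx.Node) -> bool:
    # True for any view op, without maintaining a hard-coded VIEW_OPS set.
    return (
        node.op == "call_function"
        and isinstance(node.target, torch._ops.OpOverload)
        and node.target.is_view
    )
```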

Contributor Author:

updated

@ruisizhang123 force-pushed the ruisi/fix_manual_bucketing_dsv3 branch from 3b0fdda to d547ce2 on December 12, 2025 20:02
@ruisizhang123 force-pushed the ruisi/fix_manual_bucketing_dsv3 branch from d547ce2 to d26eb87 on December 12, 2025 20:09