3 changes: 3 additions & 0 deletions torchtitan/experiments/simple_fsdp/backend.py
@@ -8,6 +8,7 @@

import torch
import torch._functorch.config as functorch_config
from functorch.compile import min_cut_rematerialization_partition
from torchtitan.tools.logging import logger

from .job_config import Compile as CompileConfig
@@ -60,6 +61,7 @@ def aot_eager_autobucketing_reordering_pass(
backend = aot_autograd_backend(
fw_compiler=aot_eager_autobucketing_reordering_pass,
bw_compiler=aot_eager_autobucketing_reordering_pass,
partition_fn=min_cut_rematerialization_partition,
keep_inference_input_mutations=True,
)
elif compile_config.backend == "inductor":
@@ -108,6 +110,7 @@ def aot_eager_transformer_block_bucketing_reordering_pass(
backend = aot_autograd_backend(
fw_compiler=aot_eager_transformer_block_bucketing_reordering_pass,
bw_compiler=aot_eager_transformer_block_bucketing_reordering_pass,
partition_fn=min_cut_rematerialization_partition,
keep_inference_input_mutations=True,
)
elif compile_config.backend == "inductor":
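For context on this change: min_cut_rematerialization_partition is AOTAutograd's min-cut partitioner, which decides which forward intermediates are saved for backward and which are recomputed, and it is the partitioner that honors per-node recompute tags such as the ones set in reshard_after_forward.py. Below is a minimal, self-contained sketch of the same wiring via the public functorch.compile.aot_function API; the toy fn and the pass-through print_graph compiler are illustrative assumptions, not part of this PR.

import torch
from functorch.compile import aot_function, min_cut_rematerialization_partition

def fn(x):
    return torch.sin(x).cos().sum()

def print_graph(gm, sample_inputs):
    # Pass-through "compiler": print the partitioned graph, then run it as-is.
    gm.graph.print_tabular()
    return gm

compiled = aot_function(
    fn,
    fw_compiler=print_graph,
    bw_compiler=print_graph,
    partition_fn=min_cut_rematerialization_partition,
)
out = compiled(torch.randn(8, requires_grad=True))
out.backward()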
9 changes: 6 additions & 3 deletions torchtitan/experiments/simple_fsdp/deepseek_v3/parallelize.py
@@ -30,17 +30,20 @@ def get_transformer_block_buckets(model) -> list[list[str] | str]:
[model.norm, model.output],
]
for layer_id, transformer_block in model.layers.items():
# [TODO](ruisizhang123) add EP support for transformer block bucketing
module_list.append(transformer_block)

def convert_modules_to_fqns(modules, module_to_fqn_mapping):
"""Convert a (possibly nested) list of modules to FQN strings."""
result = []
for m in modules:
if isinstance(m, list):
result.append(convert_modules_to_fqns(m, module_to_fqn_mapping))
# Check that fqn_list is valid: in PP, the bucketed module may
# not be on the current rank, in which case fqn_list is None.
if fqn_list := convert_modules_to_fqns(m, module_to_fqn_mapping):
Reviewer (Contributor): What does the syntax mean -- assigning to fqn_list and checking that it is not None? It feels a bit unusual to read. Also, please add a comment on why we need this check.

Author: Yes, added a comment for it.

result.append(fqn_list)
else:
result.append(module_to_fqn_mapping.get(m, None))
if fqn := module_to_fqn_mapping.get(m):
result.append(fqn)
return result

module_to_name = {m: n for n, m in model.named_modules()}
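On the review question above: if fqn_list := ... uses Python's walrus operator (PEP 572), which binds the call's result to fqn_list and then tests its truthiness in one expression, so both None (module not on the current PP rank) and an empty list are skipped. A small standalone illustration; the mapping and module names here are hypothetical.

# Toy illustration of the assign-and-test ("walrus") pattern used above.
module_to_fqn = {"block0": "layers.0", "block1": "layers.1"}

def lookup_fqns(modules, mapping):
    result = []
    for m in modules:
        # Bind the lookup result to fqn, then test its truthiness:
        # a module missing from the mapping yields None and is skipped.
        if fqn := mapping.get(m):
            result.append(fqn)
    return result

print(lookup_fqns(["block0", "not_on_this_rank", "block1"], module_to_fqn))
# -> ['layers.0', 'layers.1']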
2 changes: 2 additions & 0 deletions torchtitan/experiments/simple_fsdp/llama3/parallelize.py
@@ -51,6 +51,8 @@ def convert_modules_to_fqns(modules, module_to_fqn_mapping):
result = []
for m in modules:
if isinstance(m, list):
# Check that fqn_list is valid: in PP, the bucketed module may
# not be on the current rank, in which case fqn_list is None.
if fqn_list := convert_modules_to_fqns(m, module_to_fqn_mapping):
result.append(fqn_list)
else:
34 changes: 27 additions & 7 deletions torchtitan/experiments/simple_fsdp/reshard_after_forward.py
@@ -65,20 +65,40 @@ def force_recompute_node(node):
# nearby ac_graph_id values
node.meta["ac_graph_id"] = 100000

def tag_view_recompute_nodes(start_node, seen=None):
"""
Recursively walk only *single-user* paths from start_node,
tagging view nodes for recompute.
"""
if seen is None:
seen = set()
if start_node in seen:
return
seen.add(start_node)

# If current node has multiple users, stop descending
if len(start_node.users) != 1:
return

user = next(iter(start_node.users))

# Only continue if this user is a view op
if user.op == "call_function" and user.target.is_view:
force_recompute_node(user)
# Recurse deeper only on this one user
tag_view_recompute_nodes(user, seen)

# Make all-gather nodes (and related nodes) recomputable, to circumvent
# https://github.com/pytorch/pytorch/issues/136433
for node in graph.nodes:
if is_wait_tensor_from_fsdp(node):
ag_node = node.args[0]
force_recompute_node(ag_node) # all_gather
force_recompute_node(node) # wait_tensor
# Force-recompute slice that comes after wait
for user in node.users:
if (
user.op == "call_function"
and user.target == torch.ops.aten.slice.Tensor
):
force_recompute_node(user)

# Recursively tag view ops from all_gather_wait
tag_view_recompute_nodes(node)

# Force-recompute potential dtype casts from all_gather
if (
ag_node.all_input_nodes[0].op == "call_function"
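For context: tag_view_recompute_nodes generalizes the old special case (force-recomputing only an aten.slice after the wait) by walking the unique-user chain downstream of wait_tensor and tagging every consecutive view op for recompute. Below is a standalone sketch of that traversal on an ATen-level FX graph; the toy function, the "recompute" meta key, and the getattr guard on is_view are assumptions for illustration, not torchtitan's exact tagging (which goes through force_recompute_node).

import torch
from torch.fx.experimental.proxy_tensor import make_fx

def f(x):
    y = x * 2                          # stand-in for an all-gather output
    return y.view(4, 4).transpose(0, 1) + 1

gm = make_fx(f)(torch.randn(16))

def tag_view_chain(start):
    # Follow single-user edges from `start`; tag consecutive view ops and
    # stop at the first fan-out or non-view user.
    node = start
    while len(node.users) == 1:
        (user,) = node.users
        if user.op == "call_function" and getattr(user.target, "is_view", False):
            user.meta["recompute"] = True  # illustrative tag only
            node = user
        else:
            break

for n in gm.graph.nodes:
    if n.op == "call_function" and n.target is torch.ops.aten.mul.Tensor:
        tag_view_chain(n)

for n in gm.graph.nodes:
    if n.meta.get("recompute"):
        print("tagged for recompute:", n.name)  # the view and transpose nodes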
2 changes: 0 additions & 2 deletions torchtitan/models/deepseek_v3/__init__.py
@@ -97,8 +97,6 @@
qk_rope_head_dim=64,
v_head_dim=128,
mscale=0.70,
attn_type="flex",
attn_mask_type="block_causal",
),
"236B": DeepSeekV3ModelArgs(
vocab_size=102400,