6 changes: 6 additions & 0 deletions autoparallel/_testing/models/dsv3.py
@@ -1584,6 +1584,12 @@ def forward(
return output


def dsv3_loss_fn(pred: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
return torch.nn.functional.cross_entropy(
pred.flatten(0, 1).float(), labels.flatten(0, 1)
)


########################
# Pipeline stuff start #
########################
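A quick note on the new dsv3_loss_fn above: it merges the batch and sequence dimensions so that cross_entropy sees an (N, vocab) input and an (N,) target. A minimal sketch with made-up sizes (batch=2, seq=8, vocab=32 are illustrative, not values from this PR):

import torch

pred = torch.randn(2, 8, 32)            # model output with shape (batch, seq, vocab)
labels = torch.randint(0, 32, (2, 8))   # target token ids with shape (batch, seq)

# flatten(0, 1) merges batch and sequence: pred -> (16, 32), labels -> (16,),
# which is the layout torch.nn.functional.cross_entropy expects
loss = torch.nn.functional.cross_entropy(pred.flatten(0, 1).float(), labels.flatten(0, 1))
print(loss.shape)  # torch.Size([]) -- a scalar loss
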
80 changes: 67 additions & 13 deletions autoparallel/api.py
@@ -8,7 +8,7 @@
import warnings
from contextlib import ExitStack, contextmanager
from types import MethodType
from typing import Any, Optional, Union
from typing import Any, Callable, Optional, Union

import torch
from torch._dynamo.functional_export import _dynamo_graph_capture_for_export
@@ -159,17 +159,35 @@ def enable_local_map_wrapping():
yield


def _export(model: torch.nn.Module, inputs: tuple[Any]) -> torch.nn.Module:
def _export(
model: torch.nn.Module, model_wrapper: Callable, inputs: tuple[Any, ...]
Contributor:
nit: model_wrapper can be optional? If not provided, just use the model itself as the tracing entrypoint?

Contributor Author:
The model wrapper is never None or Optional. Even when we don't use a loss function, we still construct a model wrapper. I feel this would be a good API even for future use, when we decide to use an optimizer.

) -> torch.fx.GraphModule:
"""
Thin wrapper around graph capture output that restores the
original calling convention and attribute fqn. TODO:
1) Use bytecode for calling convention instead of pytree for more
seamless UX.
Capture a model graph via Dynamo and restore parameter/buffer metadata.

We need both `model` and `model_wrapper` because:
- `model_wrapper` is the actual callable that gets traced by Dynamo. It may wrap
the model with additional logic (e.g., adding a loss function on top of the model's
forward pass, or preparing inputs in a specific way).
- `model` is the original nn.Module needed to restore the correct fully-qualified
names (FQNs) for parameters and buffers in the traced graph. Without this, the
captured graph would lose the original parameter naming structure.

Args:
model: Original nn.Module with parameter/buffer metadata to restore
model_wrapper: Callable to trace (may wrap model with additional logic)
inputs: Input tensors for tracing

Returns:
GraphModule with restored parameter FQNs and calling convention

TODO:
1) Use bytecode for calling convention instead of pytree for more seamless UX
2) Attach guards
3) Be more careful about tensor constants names.
3) Be more careful about tensor constants names
"""
with torch._dynamo.config.patch(install_free_tensors=True):
gm = _dynamo_graph_capture_for_export(model)(*inputs)
gm = _dynamo_graph_capture_for_export(model_wrapper)(*inputs)
_restore_state_dict(model, gm)
return gm
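
For context, _export now traces the wrapper rather than the model, while the module itself is only used to restore parameter FQNs. As noted in the review thread above, a wrapper is constructed even when no loss function is used. A minimal sketch of what such a wrapper can look like (placeholder names, not code from this PR):

# Illustrative only -- mirrors the wrappers built in build_model_graph below.
def make_wrapper(model, loss_fn=None):
    if loss_fn is None:
        def wrapper(*model_inputs):
            return model(*model_inputs)
    else:
        def wrapper(model_inputs, target):
            return loss_fn(model(*model_inputs), target)
    return wrapper

# gm = _export(model, make_wrapper(model, loss_fn), formatted_inputs)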

@@ -193,6 +211,7 @@ def __init__(
ac_stage_size_in_GiB: Optional[Union[float, str]] = "auto",
reshard_after_forward: bool = True,
dynamic: bool = False,
loss_fn: Optional[Callable] = None,
**kwargs,
):
self.stack = ExitStack()
@@ -224,6 +243,7 @@ def __init__(
self.enable_ac = enable_ac
self.ac_stage_size_in_GiB = ac_stage_size_in_GiB
self.reshard_after_forward = reshard_after_forward
self.loss_fn = loss_fn

if dynamic:
self.fake_mode.shape_env = ShapeEnv()
@@ -291,20 +311,54 @@ def build_model_graph(self):
decomp_table = _get_decomp_table()

with self.fake_mode:
inputs = self.input_fn()
if not isinstance(inputs, tuple):
inputs = (inputs,)
raw_inputs = self.input_fn()

# Define model wrapper based on whether loss_fn is provided
model_wrapper: Callable
# Parse inputs based on whether loss_fn is provided
# If loss_fn is not None: expected format ((inp1, inp2,...), target)
# If loss_fn is None: expected format (inp1, inp2, ...)
if self.loss_fn is not None:
if isinstance(raw_inputs, tuple) and len(raw_inputs) == 2:
model_inputs, target = raw_inputs
# Normalize inputs to always be a tuple
if not isinstance(model_inputs, tuple):
model_inputs = (model_inputs,)
formatted_inputs = (model_inputs, target)

def model_with_loss(model_inputs, target) -> Any:
output = self.model(*model_inputs)
loss = self.loss_fn(output, target) # type: ignore[misc]
return loss

model_wrapper = model_with_loss
else:
raise ValueError(
"When loss_fn is provided, input_fn must return (inputs, target) "
"where inputs can be a single tensor or tuple of tensors"
)
else:
# No loss function, inputs are just model inputs
formatted_inputs = (
raw_inputs if isinstance(raw_inputs, tuple) else (raw_inputs,)
)

def model_wo_loss(*model_inputs) -> Any:
output = self.model(*model_inputs)
return output

model_wrapper = model_wo_loss

with set_dtype_cast(
True
), enable_local_map_wrapping(), torch._dynamo.utils._disable_saved_tensors_hooks_during_tracing():
torch_ir_with_fqn = _export(self.model, inputs)
torch_ir_with_fqn = _export(self.model, model_wrapper, formatted_inputs)
# TODO Can't use fake mode here because it clashes with the user level
# fake mode. Ideally dynamo should reuse the user level fake mode.
self.joint_with_descriptors = aot_export_joint_with_descriptors(
self.stack,
torch_ir_with_fqn,
inputs,
formatted_inputs,
decompositions=decomp_table,
)
gm = self.joint_with_descriptors.graph_module
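To make the new input_fn contract concrete: with a loss_fn the callable must return ((inp1, inp2, ...), target), where a single input tensor is also accepted and normalized to a one-element tuple; without a loss_fn it returns the model inputs directly. A hedged sketch with hypothetical shapes and names:

import torch

# With a loss function: input_fn returns ((inp1, inp2, ...), target)
def input_fn_with_loss():
    tokens = torch.randint(0, 32, (2, 8))   # hypothetical model input
    labels = torch.randint(0, 32, (2, 8))   # hypothetical target
    return ((tokens,), labels)

# Without a loss function: input_fn returns (inp1, inp2, ...)
def input_fn_without_loss():
    tokens = torch.randint(0, 32, (2, 8))
    return (tokens,)

# autop = AutoParallel(..., loss_fn=dsv3_loss_fn)  # other constructor arguments elided
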
19 changes: 12 additions & 7 deletions autoparallel/apply_sharding.py
@@ -208,13 +208,18 @@ def call_function(self, target, args, kwargs):

# apply sharding to constructor functions as well
if target in TENSOR_FACTORY_OPS:
val = list(new_args[0])
spec = self.sharding_placement[node].output_specs
for mesh_size, placement in zip(spec.mesh.shape, spec.placements):
if placement.is_shard():
# TODO: fix uneven cases ?
val[placement.dim] //= mesh_size
new_args[0] = tuple(val)
# scalar_tensor has a scalar as first arg, not a shape
if target == torch.ops.aten.scalar_tensor.default:
# scalar tensors can't be sharded, so no transformation needed
pass
else:
val = list(new_args[0])
spec = self.sharding_placement[node].output_specs
for mesh_size, placement in zip(spec.mesh.shape, spec.placements):
if placement.is_shard():
# TODO: fix uneven cases ?
val[placement.dim] //= mesh_size
new_args[0] = tuple(val)

# use DTensor machinery to ensure the view ops are valid
# otherwise we would end-up forcing global shapes on local tensors
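The branch above rewrites the shape argument of shape-taking factory ops to local sizes, dividing each sharded dimension by its mesh size, and skips scalar_tensor because a 0-dim tensor has nothing to shard. A standalone sketch of that size rewrite (a toy placement encoding is used here instead of DTensor's Placement objects, and even divisibility is assumed):

def local_factory_size(global_size, mesh_shape, placements):
    # placements[i] describes the output placement on mesh dim i in a toy encoding,
    # e.g. ("shard", tensor_dim) or ("replicate",)
    size = list(global_size)
    for mesh_size, placement in zip(mesh_shape, placements):
        if placement[0] == "shard":
            size[placement[1]] //= mesh_size
    return tuple(size)

# zeros([8, 16]) sharded on tensor dim 0 over a mesh of size 4 -> local size (2, 16)
print(local_factory_size((8, 16), (4,), [("shard", 0)]))  # (2, 16)
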
47 changes: 34 additions & 13 deletions autoparallel/graph_pp_runner.py
@@ -184,6 +184,20 @@ def _run_reduce_grad_module(
return sharded_grads


def _accumulate_stage_grads(
unsharded_grads: list[Union[torch.Tensor, None]],
grads_to_accumulate: list[Union[torch.Tensor, None]],
) -> None:
assert len(unsharded_grads) == len(grads_to_accumulate)
assert not all(grad is None for grad in grads_to_accumulate), "All grads are None"
    for i, grad_to_accumulate in enumerate(grads_to_accumulate):
        if grad_to_accumulate is not None:
            if unsharded_grads[i] is None:
                # assign through the list so the new gradient is actually stored;
                # rebinding a loop variable would leave the entry as None
                unsharded_grads[i] = grad_to_accumulate
            else:
                unsharded_grads[i] += grad_to_accumulate


def _run_forward_microbatch(
stage: GraphPipelineStage, *args, numerics_logs: Optional[list[str]] = None
) -> tuple[Any, Any]:
@@ -216,17 +230,9 @@ def _run_backward_microbatch(
)

unsharded_grads = backward_stage.state["unsharded_grads"]
grads_to_accumulate = param_buffer_grads[
: len(backward_stage.state["sharded_params"])
]
assert len(unsharded_grads) == len(grads_to_accumulate)
assert not all(grad is None for grad in grads_to_accumulate), "All grads are None"
for unsharded_grad, grad_to_accumulate in zip(unsharded_grads, grads_to_accumulate):
if grad_to_accumulate is not None:
if unsharded_grad is None:
unsharded_grad = grad_to_accumulate
else:
unsharded_grad += grad_to_accumulate
grads_to_accumulate = param_buffer_grads[: backward_stage.graph_meta.num_params]
_accumulate_stage_grads(unsharded_grads, grads_to_accumulate)

return input_grads


@@ -275,6 +281,11 @@ def stage_forward(
# Receive activations for this chunk
# Activations only come in args form
composite_args = stage._retrieve_recv_activations(mb_index)
if stage.is_last and ctx.target_mbs is not None:
assert isinstance(
composite_args, tuple
), f"Expected composite args to be a tuple but got {type(composite_args)}"
composite_args = composite_args + (ctx.target_mbs[mb_index],) # type: ignore[index]

# stage._validate_fwd_input(args, kwargs) Maybe need to validate composite args?
logger.debug(
@@ -292,6 +303,8 @@
# Output chunks is only used for the last stage since we only merge the output of the last stage
if stage.is_last:
stage.output_chunks.append(output)
if ctx.target_mbs is not None:
ctx.schedule_ref._internal_losses.append(output)

stage.fwd_cache[mb_index] = (
output_tuple, # stage_output
Expand Down Expand Up @@ -360,7 +373,7 @@ def stage_full_backward(
# HACK till we have loss function, we populate the tangents here manually
bwd_kwargs = {
"stage_output": loss,
"tangents": [torch.randn_like(stage_output[0])],
"tangents": [torch.ones_like(stage_output[0])],
"saved_intermediates": saved_intermediates,
}
else:
@@ -525,7 +538,9 @@ def _accumulate_stage_grads_and_clear_states(
stage.state.clear()

def step(self, *args, **kwargs) -> None:

has_targets_and_loss = (
"losses" in kwargs and "targets" in kwargs if kwargs else False
)
for stage in self.schedule._stages:
assert isinstance(stage, GraphPipelineStage)
self._populate_stage_states(stage)
Expand All @@ -535,3 +550,9 @@ def step(self, *args, **kwargs) -> None:
for stage in self.schedule._stages:
assert isinstance(stage, GraphPipelineStage)
self._accumulate_stage_grads_and_clear_states(stage)

if has_targets_and_loss:
losses = kwargs["losses"]
assert len(self.schedule._internal_losses) == self.schedule._n_microbatches
losses.extend(self.schedule._internal_losses)
self.schedule._internal_losses.clear()
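
With this change, step() follows the torch.distributed.pipelining convention of collecting per-microbatch losses into a caller-provided list when both targets and losses are passed. A hedged usage sketch; `runner` and `target_microbatches` are placeholders, not names from this PR:

losses = []
# runner.step(targets=target_microbatches, losses=losses)
# After the call, `losses` has been extended with one loss tensor per microbatch,
# drained from the schedule's _internal_losses buffer.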
50 changes: 44 additions & 6 deletions autoparallel/propagation_rules.py
@@ -363,11 +363,8 @@ def randperm_rule(mesh, specs):
return OpStrategy([OpSpec(spec, input_specs=[spec], redistribute_cost=[[0.0]])])


# We do a few special things for factory ops
# - use the factory rule below
# - fake that they have input schemas so the solver doesn't freak out
# - convert their sizes to 'local tensor sizes' (divide by mesh dim) during ApplySharding
TENSOR_FACTORY_OPS = [
# Factory ops that take a shape as the first argument
_SHAPE_FACTORY_OPS = [
torch.ops.aten.zeros.default,
torch.ops.aten.ones.default,
torch.ops.aten.full.default,
@@ -376,8 +373,49 @@
torch.ops.aten.randn.default,
]

# We do a few special things for factory ops
# - use the factory rule below
# - fake that they have input schemas so the solver doesn't freak out
# - convert their sizes to 'local tensor sizes' (divide by mesh dim) during ApplySharding
TENSOR_FACTORY_OPS = _SHAPE_FACTORY_OPS + [
torch.ops.aten.scalar_tensor.default, # Special case: creates 0-dim tensor
]


@register_opschema_rule(torch.ops.aten.scalar_tensor.default)
def scalar_tensor_rule(mesh, op_schema: OpSchema) -> OpStrategy:
"""
Rule for aten.scalar_tensor which creates a scalar (0-dimensional) tensor.
Unlike other factory ops, this doesn't take a shape parameter.

Schema: scalar_tensor(Scalar s, *, ScalarType? dtype=None, ...) -> Tensor
"""
# scalar_tensor creates a 0-dimensional tensor
shape = ()
stride = ()
dtype = torch.get_default_dtype()

# Check if dtype is specified in kwargs or args
if len(op_schema.args_schema) >= 2 and op_schema.args_schema[1] is not None:
dtype = op_schema.args_schema[1] # type: ignore[assignment]

tensor_meta = TensorMeta(shape, stride, dtype) # type: ignore[arg-type]

# For a scalar (0-dim) tensor, we can only replicate across all mesh dimensions
placement = (Replicate(),) * mesh.ndim
output_specs = DTensorSpec(mesh, placement, tensor_meta=tensor_meta)

# Similar to factory_rule, we add a dummy input_specs for solver compatibility
strategy = OpSpec(
output_specs=output_specs,
input_specs=[output_specs],
redistribute_cost=[[0.0]],
)

return OpStrategy([strategy])


@register_opschema_rule(TENSOR_FACTORY_OPS)
@register_opschema_rule(_SHAPE_FACTORY_OPS)
def factory_rule(mesh, op_schema: OpSchema) -> OpStrategy:
"""
This is an auto-parallel specific util that won't be upstreamed because of a UX mismatch.
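
The premise behind scalar_tensor_rule is easy to check: aten.scalar_tensor produces a 0-dimensional tensor, so there is no dimension to shard and Replicate on every mesh dimension is the only sensible placement, which is why the rule emits a single Replicate OpSpec with a dummy input spec, mirroring factory_rule:

import torch

t = torch.ops.aten.scalar_tensor.default(3.14)
print(t.dim(), t.shape)  # 0 torch.Size([]) -- nothing to shard, so Replicate everywhere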