
Commit 00ccbf8

Enable Loss Fn in Graph PP
ghstack-source-id: 56b4b85
Pull Request resolved: #247
Parent: b1c4909

7 files changed: +308 −107 lines

autoparallel/_testing/models/dsv3.py

Lines changed: 6 additions & 0 deletions

```diff
@@ -1565,6 +1565,12 @@ def forward(
         return output


+def dsv3_loss_fn(pred: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
+    return torch.nn.functional.cross_entropy(
+        pred.flatten(0, 1).float(), labels.flatten(0, 1)
+    )
+
+
 ########################
 # Pipeline stuff start #
 ########################
```
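The new loss flattens the batch and sequence dimensions into one before calling `cross_entropy`, which expects `[N, C]` logits and `[N]` integer targets. A minimal sketch of the shapes involved (the function body is copied from the hunk above; the sizes are made up):

```python
import torch

# Copied from the hunk above, driven with invented sizes to show the shapes.
def dsv3_loss_fn(pred: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    return torch.nn.functional.cross_entropy(
        pred.flatten(0, 1).float(), labels.flatten(0, 1)
    )

pred = torch.randn(2, 8, 32)           # assumed [batch=2, seq=8, vocab=32] logits
labels = torch.randint(0, 32, (2, 8))  # assumed [batch=2, seq=8] token ids
loss = dsv3_loss_fn(pred, labels)      # flatten(0, 1) merges batch*seq: [16, 32] vs [16]
print(loss.shape)                      # torch.Size([]) -- a 0-dim scalar loss
```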

autoparallel/api.py

Lines changed: 25 additions & 4 deletions

```diff
@@ -8,7 +8,7 @@
 import warnings
 from contextlib import ExitStack, contextmanager
 from types import MethodType
-from typing import Any, Optional, Union
+from typing import Any, Callable, Optional, Union

 import torch
 from torch._dynamo.functional_export import _dynamo_graph_capture_for_export
@@ -159,7 +159,9 @@ def enable_local_map_wrapping():
     yield


-def _export(model: torch.nn.Module, inputs: tuple[Any]) -> torch.nn.Module:
+def _export(
+    model: torch.nn.Module, model_wrapper: Callable, inputs: tuple[Any]
+) -> torch.fx.GraphModule:
     """
     Thin wrapper around graph capture output that restores the
     original calling convention and attribute fqn. TODO:
@@ -169,7 +171,7 @@ def _export(model: torch.nn.Module, inputs: tuple[Any]) -> torch.nn.Module:
     3) Be more careful about tensor constants names.
     """
     with torch._dynamo.config.patch(install_free_tensors=True):
-        gm = _dynamo_graph_capture_for_export(model)(*inputs)
+        gm = _dynamo_graph_capture_for_export(model_wrapper)(*inputs)
     _restore_state_dict(model, gm)
     return gm

@@ -193,6 +195,7 @@ def __init__(
         ac_stage_size_in_GiB: Optional[Union[float, str]] = "auto",
         reshard_after_forward: bool = True,
         dynamic: bool = False,
+        loss_fn: Optional[Callable] = None,
         **kwargs,
     ):
         self.stack = ExitStack()
@@ -224,6 +227,7 @@ def __init__(
         self.enable_ac = enable_ac
         self.ac_stage_size_in_GiB = ac_stage_size_in_GiB
         self.reshard_after_forward = reshard_after_forward
+        self.loss_fn = loss_fn

         if dynamic:
             self.fake_mode.shape_env = ShapeEnv()
@@ -294,11 +298,27 @@ def build_model_graph(self):
         inputs = self.input_fn()
         if not isinstance(inputs, tuple):
             inputs = (inputs,)
+        model_wrapper: Callable
+        if self.loss_fn is not None:
+
+            def model_with_loss(input, target) -> Any:
+                output = self.model(input)
+                loss = self.loss_fn(output, target)  # type: ignore[misc]
+                return loss
+
+            model_wrapper = model_with_loss
+        else:
+
+            def model_wo_loss(input) -> Any:
+                output = self.model(input)
+                return output
+
+            model_wrapper = model_wo_loss

         with set_dtype_cast(
             True
         ), enable_local_map_wrapping(), torch._dynamo.utils._disable_saved_tensors_hooks_during_tracing():
-            torch_ir_with_fqn = _export(self.model, inputs)
+            torch_ir_with_fqn = _export(self.model, model_wrapper, inputs)
         # TODO Cna't use fake mode here because it clashes with the user level
         # fake mode. Ideally dynamo should reuse the user level fake mode.
         self.joint_with_descriptors = aot_export_joint_with_descriptors(
@@ -326,6 +346,7 @@ def build_model_graph(self):
                 print_output=False, include_stride=True, include_device=True
             ),
         )
+        print(gm.graph)

         self.gm = gm
```
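Taken together, these api.py changes let a caller hand the loss to `AutoParallel` so it is traced into the captured graph: when `loss_fn` is set, `input_fn` must yield an `(input, target)` pair, and the exported graph ends in the loss. A hedged sketch of a call site (the positional arguments and surrounding setup here are assumptions, not the class's documented signature; only the `loss_fn` keyword comes from this commit):

```python
from autoparallel._testing.models.dsv3 import dsv3_loss_fn
from autoparallel.api import AutoParallel

# Hypothetical setup: model, input_fn, and mesh are assumed to be built by
# the caller as elsewhere in this repo.
autop = AutoParallel(
    model,
    input_fn,              # must now return (input, target) when loss_fn is set
    mesh,
    loss_fn=dsv3_loss_fn,  # closed over by model_with_loss during capture
)
autop.build_model_graph()  # traces loss_fn(model(input), target) as one graph
```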

autoparallel/apply_sharding.py

Lines changed: 12 additions & 7 deletions

```diff
@@ -208,13 +208,18 @@ def call_function(self, target, args, kwargs):

         # apply sharding to constructor functions as well
         if target in TENSOR_FACTORY_OPS:
-            val = list(new_args[0])
-            spec = self.sharding_placement[node].output_specs
-            for mesh_size, placement in zip(spec.mesh.shape, spec.placements):
-                if placement.is_shard():
-                    # TODO: fix uneven cases ?
-                    val[placement.dim] //= mesh_size
-            new_args[0] = tuple(val)
+            # scalar_tensor has a scalar as first arg, not a shape
+            if target == torch.ops.aten.scalar_tensor.default:
+                # scalar tensors can't be sharded, so no transformation needed
+                pass
+            else:
+                val = list(new_args[0])
+                spec = self.sharding_placement[node].output_specs
+                for mesh_size, placement in zip(spec.mesh.shape, spec.placements):
+                    if placement.is_shard():
+                        # TODO: fix uneven cases ?
+                        val[placement.dim] //= mesh_size
+                new_args[0] = tuple(val)

         # use DTensor machinery to ensure the view ops are valid
         # otherwise we would end-up forcing global shapes on local tensors
```
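The `else` branch keeps the existing behavior for shape-taking factory ops: divide the global size along each sharded dimension by that mesh dimension's size. A standalone sketch of that arithmetic (the names and the shard-dim encoding are simplified stand-ins for DTensor placements, not the real API):

```python
def localize_shape(global_shape, mesh_shape, shard_dims):
    """Divide sharded dims of a factory op's global shape by the mesh size.

    shard_dims holds, per mesh dim, the tensor dim it shards: None stands in
    for Replicate, an int for Shard(dim) -- mirroring placement.is_shard().
    """
    val = list(global_shape)
    for mesh_size, shard_dim in zip(mesh_shape, shard_dims):
        if shard_dim is not None:
            val[shard_dim] //= mesh_size  # real code still TODO: uneven cases
    return tuple(val)

# A (64, 128) zeros() sharded on dim 0 over a 4-way mesh axis and replicated
# over a 2-way axis becomes a (16, 128) local allocation.
print(localize_shape((64, 128), mesh_shape=(4, 2), shard_dims=(0, None)))
```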

autoparallel/graph_pp_runner.py

Lines changed: 19 additions & 2 deletions

```diff
@@ -275,6 +275,11 @@ def stage_forward(
         # Receive activations for this chunk
         # Activations only come in args form
         composite_args = stage._retrieve_recv_activations(mb_index)
+        if stage.is_last and ctx.target_mbs is not None:
+            assert isinstance(
+                composite_args, tuple
+            ), f"Expected composite args to be a tuple but got {type(composite_args)}"
+            composite_args = composite_args + (ctx.target_mbs[mb_index],)  # type: ignore[index]

         # stage._validate_fwd_input(args, kwargs) Maybe need to validate composite args?
         logger.debug(
@@ -292,6 +297,8 @@ def stage_forward(
         # Output chunks is only used for the last stage since we only merge the output of the last stage
         if stage.is_last:
             stage.output_chunks.append(output)
+            if ctx.target_mbs is not None:
+                ctx.schedule_ref._internal_losses.append(output)

         stage.fwd_cache[mb_index] = (
             output_tuple,  # stage_output
```
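With the loss traced into the last stage's graph, that stage's forward now takes two inputs: the received activations and the microbatch's target. Appending `ctx.target_mbs[mb_index]` to `composite_args` supplies the second argument, and because the stage's output is then the loss itself, it is recorded in `schedule._internal_losses` alongside `output_chunks`.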
```diff
@@ -360,7 +367,7 @@ def stage_full_backward(
         # HACK till we have loss function, we populate the tangents here manually
         bwd_kwargs = {
             "stage_output": loss,
-            "tangents": [torch.randn_like(stage_output[0])],
+            "tangents": [torch.ones_like(stage_output[0])],
             "saved_intermediates": saved_intermediates,
         }
     else:
```
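Switching the seed tangent from `randn_like` to `ones_like` matches what autograd does for a genuine scalar loss: backward starts from dL/dL = 1. The random tangent was only a placeholder from before a loss function existed, as the surviving HACK comment notes.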
```diff
@@ -525,7 +532,9 @@ def _accumulate_stage_grads_and_clear_states(
         stage.state.clear()

     def step(self, *args, **kwargs) -> None:
-
+        has_targets_and_loss = (
+            "losses" in kwargs and "targets" in kwargs if kwargs else False
+        )
         for stage in self.schedule._stages:
             assert isinstance(stage, GraphPipelineStage)
             self._populate_stage_states(stage)
@@ -535,3 +544,11 @@ def step(self, *args, **kwargs) -> None:
         for stage in self.schedule._stages:
             assert isinstance(stage, GraphPipelineStage)
             self._accumulate_stage_grads_and_clear_states(stage)
+            if stage.is_last and has_targets_and_loss:
+                losses = kwargs["losses"]
+                losses.clear()
+                assert (
+                    len(self.schedule._internal_losses) == self.schedule._n_microbatches
+                )
+                losses.extend(self.schedule._internal_losses)
+                self.schedule._internal_losses.clear()
```
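A hedged sketch of driving a step under the new contract (the runner construction and microbatch splitting are assumed; only the `losses`/`targets` kwargs behavior comes from the code above):

```python
import torch

# Hypothetical driver: `runner` (a GraphPPRunner), `inputs`, and `targets` are
# assumed to come from the surrounding training setup. Passing BOTH kwargs is
# what flips has_targets_and_loss and copies the losses back out.
losses: list[torch.Tensor] = []
runner.step(inputs, targets=targets, losses=losses)

# step() clears and refills `losses` with one entry per microbatch, but only
# on ranks that own the last stage; elsewhere the list stays empty.
if losses:
    print(torch.stack(losses).mean())
```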

autoparallel/propagation_rules.py

Lines changed: 44 additions & 6 deletions

```diff
@@ -363,11 +363,8 @@ def randperm_rule(mesh, specs):
     return OpStrategy([OpSpec(spec, input_specs=[spec], redistribute_cost=[[0.0]])])


-# We do a few special things for factory ops
-# - use the factory rule below
-# - fake that they have input schemas so the solver doesn't freak out
-# - convert their sizes to 'local tensor sizes' (divide by mesh dim) during ApplySharding
-TENSOR_FACTORY_OPS = [
+# Factory ops that take a shape as the first argument
+_SHAPE_FACTORY_OPS = [
     torch.ops.aten.zeros.default,
     torch.ops.aten.ones.default,
     torch.ops.aten.full.default,
@@ -376,8 +373,49 @@ def randperm_rule(mesh, specs):
     torch.ops.aten.randn.default,
 ]

+# We do a few special things for factory ops
+# - use the factory rule below
+# - fake that they have input schemas so the solver doesn't freak out
+# - convert their sizes to 'local tensor sizes' (divide by mesh dim) during ApplySharding
+TENSOR_FACTORY_OPS = _SHAPE_FACTORY_OPS + [
+    torch.ops.aten.scalar_tensor.default,  # Special case: creates 0-dim tensor
+]
+
+
+@register_opschema_rule(torch.ops.aten.scalar_tensor.default)
+def scalar_tensor_rule(mesh, op_schema: OpSchema) -> OpStrategy:
+    """
+    Rule for aten.scalar_tensor which creates a scalar (0-dimensional) tensor.
+    Unlike other factory ops, this doesn't take a shape parameter.
+
+    Schema: scalar_tensor(Scalar s, *, ScalarType? dtype=None, ...) -> Tensor
+    """
+    # scalar_tensor creates a 0-dimensional tensor
+    shape = ()
+    stride = ()
+    dtype = torch.get_default_dtype()
+
+    # Check if dtype is specified in kwargs or args
+    if len(op_schema.args_schema) >= 2 and op_schema.args_schema[1] is not None:
+        dtype = op_schema.args_schema[1]  # type: ignore[assignment]
+
+    tensor_meta = TensorMeta(shape, stride, dtype)  # type: ignore[arg-type]
+
+    # For a scalar (0-dim) tensor, we can only replicate across all mesh dimensions
+    placement = (Replicate(),) * mesh.ndim
+    output_specs = DTensorSpec(mesh, placement, tensor_meta=tensor_meta)
+
+    # Similar to factory_rule, we add a dummy input_specs for solver compatibility
+    strategy = OpSpec(
+        output_specs=output_specs,
+        input_specs=[output_specs],
+        redistribute_cost=[[0.0]],
+    )
+
+    return OpStrategy([strategy])
+

-@register_opschema_rule(TENSOR_FACTORY_OPS)
+@register_opschema_rule(_SHAPE_FACTORY_OPS)
 def factory_rule(mesh, op_schema: OpSchema) -> OpStrategy:
     """
     This is an auto-parallel specific util that won't be upstreamed becuase of a UX mismatch.
```
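The key decision in `scalar_tensor_rule` is that a 0-dim tensor admits no `Shard` placement, so the only strategy offered replicates it on every mesh axis. A small sketch of that placement tuple (using the public DTensor `Replicate` type; the mesh dimensionality is made up):

```python
from torch.distributed.tensor import Replicate

mesh_ndim = 2  # assumed 2-D device mesh, e.g. (dp, tp)
placement = (Replicate(),) * mesh_ndim
print(placement)  # (Replicate(), Replicate()) -- no dim exists to shard
```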
