diff --git a/torchtitan/components/loss.py b/torchtitan/components/loss.py
index 6fb80f39cb..8f858ce7ff 100644
--- a/torchtitan/components/loss.py
+++ b/torchtitan/components/loss.py
@@ -28,7 +28,7 @@ def build_cross_entropy_loss(job_config: JobConfig, **kwargs):
     loss_fn = cross_entropy_loss
     if job_config.compile.enable and "loss" in job_config.compile.components:
         logger.info("Compiling the loss function with torch.compile")
-        loss_fn = torch.compile(loss_fn, backend=job_config.compile.backend)
+        loss_fn = torch.compile(loss_fn, backend=job_config.compile.backend, mode="light")
     return loss_fn

diff --git a/torchtitan/experiments/simple_fsdp/llama3/parallelize.py b/torchtitan/experiments/simple_fsdp/llama3/parallelize.py
index d61e74a5dd..eb3f1db0b1 100644
--- a/torchtitan/experiments/simple_fsdp/llama3/parallelize.py
+++ b/torchtitan/experiments/simple_fsdp/llama3/parallelize.py
@@ -148,6 +148,7 @@ def parallelize_llama(
             model,
             backend=get_compile_backend(backend),
             fullgraph=True,
+            mode="light",
         )

     return model
diff --git a/torchtitan/experiments/vlm/infra/loss.py b/torchtitan/experiments/vlm/infra/loss.py
index bba51f2819..093165229d 100644
--- a/torchtitan/experiments/vlm/infra/loss.py
+++ b/torchtitan/experiments/vlm/infra/loss.py
@@ -109,5 +109,5 @@ def build_token_imbalance_ce_loss(
     loss_fn = partial(token_imbalance_ce_loss, token_mesh=token_mesh, ft_pg=ft_pg)
     if job_config.compile.enable and "loss" in job_config.compile.components:
         logger.info("Compiling the loss function with torch.compile")
-        loss_fn = torch.compile(loss_fn, backend=job_config.compile.backend)
+        loss_fn = torch.compile(loss_fn, backend=job_config.compile.backend, mode="light")
     return loss_fn
diff --git a/torchtitan/models/deepseek_v3/train_configs/debug_model.toml b/torchtitan/models/deepseek_v3/train_configs/debug_model.toml
index 1951cc4350..3dfc7ed2df 100644
--- a/torchtitan/models/deepseek_v3/train_configs/debug_model.toml
+++ b/torchtitan/models/deepseek_v3/train_configs/debug_model.toml
@@ -4,7 +4,7 @@ description = "DeepSeek-V3 debug training"
 print_config = false

 [profiling]
-enable_profiling = false
+enable_profiling = true
 save_traces_folder = "profile_trace"
 profile_freq = 10
 enable_memory_snapshot = false
diff --git a/torchtitan/models/llama3/infra/parallelize.py b/torchtitan/models/llama3/infra/parallelize.py
index 86ac3a6dfe..b0c3de420e 100644
--- a/torchtitan/models/llama3/infra/parallelize.py
+++ b/torchtitan/models/llama3/infra/parallelize.py
@@ -242,7 +242,7 @@ def apply_compile(model: nn.Module, compile_config: CompileConfig):
     """
     for layer_id, transformer_block in model.layers.named_children():
         transformer_block = torch.compile(
-            transformer_block, backend=compile_config.backend, fullgraph=True
+            transformer_block, backend=compile_config.backend, fullgraph=True, mode="light"
         )
         model.layers.register_module(layer_id, transformer_block)

diff --git a/torchtitan/models/llama3/train_configs/llama3_8b.toml b/torchtitan/models/llama3/train_configs/llama3_8b.toml
index ef86d783bf..7dd15ff97a 100644
--- a/torchtitan/models/llama3/train_configs/llama3_8b.toml
+++ b/torchtitan/models/llama3/train_configs/llama3_8b.toml
@@ -37,7 +37,7 @@ dataset = "c4"

 [parallelism]
 data_parallel_replicate_degree = 1
-data_parallel_shard_degree = -1
+data_parallel_shard_degree = 8
 tensor_parallel_degree = 1
 pipeline_parallel_degree = 1
 context_parallel_degree = 1
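Side note on the llama3_8b.toml hunk above: data_parallel_shard_degree = -1 lets torchtitan infer the FSDP shard degree from whatever ranks remain, while 8 pins it explicitly, so the launch world size must equal the product of all parallelism degrees. A quick sanity check of that product (plain Python arithmetic; the variable names are illustrative, not torchtitan identifiers):

    # Degrees from the [parallelism] table after this change.
    dp_replicate, dp_shard, tp, pp, cp = 1, 8, 1, 1, 1

    # Each degree multiplies into the rank count the job must be launched with.
    world_size = dp_replicate * dp_shard * tp * pp * cp
    assert world_size == 8  # e.g. a single 8-GPU node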
diff --git a/torchtitan/models/llama4/infra/parallelize.py b/torchtitan/models/llama4/infra/parallelize.py
index aff029f736..4b906bd298 100644
--- a/torchtitan/models/llama4/infra/parallelize.py
+++ b/torchtitan/models/llama4/infra/parallelize.py
@@ -546,7 +546,7 @@ def apply_compile(model: nn.Module, compile_config: CompileConfig):
                 moe,
                 attr_name,
                 torch.compile(
-                    submod, backend=compile_config.backend, fullgraph=True
+                    submod, backend=compile_config.backend, fullgraph=True, mode="light",
                 ),
             )
         else:
@@ -554,7 +554,7 @@ def apply_compile(model: nn.Module, compile_config: CompileConfig):
                 block,
                 attr_name,
                 torch.compile(
-                    submod, backend=compile_config.backend, fullgraph=True
+                    submod, backend=compile_config.backend, fullgraph=True, mode="light",
                 ),
             )

@@ -565,6 +565,7 @@ def apply_compile(model: nn.Module, compile_config: CompileConfig):
             transformer_block,
             backend=compile_config.backend,
             fullgraph=True,
+            mode="light",
         )
         model.layers.register_module(layer_id, transformer_block)
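Taken together, these hunks opt every torch.compile call site into the same compile mode. A minimal standalone sketch of the resulting pattern, assuming a PyTorch build whose torch.compile accepts mode="light" (stable releases may reject the mode, and the toy module here is purely illustrative):

    import torch
    import torch.nn as nn

    # Toy stand-in for a transformer block; only the compile call mirrors the diff.
    block = nn.Sequential(nn.Linear(64, 64), nn.GELU(), nn.Linear(64, 64))

    # fullgraph=True forbids graph breaks, as in the hunks above; mode="light"
    # is assumed to be available in this PyTorch build.
    compiled = torch.compile(block, backend="inductor", fullgraph=True, mode="light")
    out = compiled(torch.randn(8, 64))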