From 877c105cde8ad86df4ef7684e126a8043332b418 Mon Sep 17 00:00:00 2001 From: Boyuan Feng Date: Mon, 27 Oct 2025 20:32:42 -0700 Subject: [PATCH] init --- torchtitan/components/loss.py | 2 +- torchtitan/experiments/flux/loss.py | 2 +- torchtitan/experiments/simple_fsdp/llama3/parallelize.py | 1 + torchtitan/experiments/vlm/infra/loss.py | 2 +- torchtitan/models/deepseek_v3/train_configs/debug_model.toml | 2 +- torchtitan/models/llama3/infra/parallelize.py | 2 +- torchtitan/models/llama3/train_configs/llama3_8b.toml | 2 +- torchtitan/models/llama4/infra/parallelize.py | 1 + 8 files changed, 8 insertions(+), 6 deletions(-) diff --git a/torchtitan/components/loss.py b/torchtitan/components/loss.py index 3036611317..8d51544a75 100644 --- a/torchtitan/components/loss.py +++ b/torchtitan/components/loss.py @@ -28,7 +28,7 @@ def build_cross_entropy_loss(job_config: JobConfig, **kwargs): loss_fn = cross_entropy_loss if job_config.compile.enable and "loss" in job_config.compile.components: logger.info("Compiling the loss function with torch.compile") - loss_fn = torch.compile(loss_fn, backend=job_config.compile.backend) + loss_fn = torch.compile(loss_fn, backend=job_config.compile.backend, mode="light") return loss_fn diff --git a/torchtitan/experiments/flux/loss.py b/torchtitan/experiments/flux/loss.py index 5a175e741d..648f54b28b 100644 --- a/torchtitan/experiments/flux/loss.py +++ b/torchtitan/experiments/flux/loss.py @@ -23,5 +23,5 @@ def build_mse_loss(job_config: JobConfig, **kwargs): loss_fn = mse_loss if job_config.compile.enable and "loss" in job_config.compile.components: logger.info("Compiling the loss function with torch.compile") - loss_fn = torch.compile(loss_fn, backend=job_config.compile.backend) + loss_fn = torch.compile(loss_fn, backend=job_config.compile.backend, mode="light") return loss_fn diff --git a/torchtitan/experiments/simple_fsdp/llama3/parallelize.py b/torchtitan/experiments/simple_fsdp/llama3/parallelize.py index 446dcaff89..45bf4eed36 100644 --- a/torchtitan/experiments/simple_fsdp/llama3/parallelize.py +++ b/torchtitan/experiments/simple_fsdp/llama3/parallelize.py @@ -133,6 +133,7 @@ def parallelize_llama( model, backend=get_compile_backend(backend), fullgraph=True, + mode="light", ) return model diff --git a/torchtitan/experiments/vlm/infra/loss.py b/torchtitan/experiments/vlm/infra/loss.py index bba51f2819..093165229d 100644 --- a/torchtitan/experiments/vlm/infra/loss.py +++ b/torchtitan/experiments/vlm/infra/loss.py @@ -109,5 +109,5 @@ def build_token_imbalance_ce_loss( loss_fn = partial(token_imbalance_ce_loss, token_mesh=token_mesh, ft_pg=ft_pg) if job_config.compile.enable and "loss" in job_config.compile.components: logger.info("Compiling the loss function with torch.compile") - loss_fn = torch.compile(loss_fn, backend=job_config.compile.backend) + loss_fn = torch.compile(loss_fn, backend=job_config.compile.backend, mode="light") return loss_fn diff --git a/torchtitan/models/deepseek_v3/train_configs/debug_model.toml b/torchtitan/models/deepseek_v3/train_configs/debug_model.toml index 1951cc4350..3dfc7ed2df 100644 --- a/torchtitan/models/deepseek_v3/train_configs/debug_model.toml +++ b/torchtitan/models/deepseek_v3/train_configs/debug_model.toml @@ -4,7 +4,7 @@ description = "DeepSeek-V3 debug training" print_config = false [profiling] -enable_profiling = false +enable_profiling = true save_traces_folder = "profile_trace" profile_freq = 10 enable_memory_snapshot = false diff --git a/torchtitan/models/llama3/infra/parallelize.py b/torchtitan/models/llama3/infra/parallelize.py index 4944af569e..b8d7d0430b 100644 --- a/torchtitan/models/llama3/infra/parallelize.py +++ b/torchtitan/models/llama3/infra/parallelize.py @@ -243,7 +243,7 @@ def apply_compile(model: nn.Module, compile_config: CompileConfig): """ for layer_id, transformer_block in model.layers.named_children(): transformer_block = torch.compile( - transformer_block, backend=compile_config.backend, fullgraph=True + transformer_block, backend=compile_config.backend, fullgraph=True, mode="light" ) model.layers.register_module(layer_id, transformer_block) diff --git a/torchtitan/models/llama3/train_configs/llama3_8b.toml b/torchtitan/models/llama3/train_configs/llama3_8b.toml index ef86d783bf..7dd15ff97a 100644 --- a/torchtitan/models/llama3/train_configs/llama3_8b.toml +++ b/torchtitan/models/llama3/train_configs/llama3_8b.toml @@ -37,7 +37,7 @@ dataset = "c4" [parallelism] data_parallel_replicate_degree = 1 -data_parallel_shard_degree = -1 +data_parallel_shard_degree = 8 tensor_parallel_degree = 1 pipeline_parallel_degree = 1 context_parallel_degree = 1 diff --git a/torchtitan/models/llama4/infra/parallelize.py b/torchtitan/models/llama4/infra/parallelize.py index 1f579ccd04..171f694a84 100644 --- a/torchtitan/models/llama4/infra/parallelize.py +++ b/torchtitan/models/llama4/infra/parallelize.py @@ -520,6 +520,7 @@ def apply_compile(model: nn.Module, compile_config: CompileConfig): transformer_block, backend=compile_config.backend, fullgraph=fullgraph, + mode="light", ) model.layers.register_module(layer_id, transformer_block)