Various CorrDiff optimizations for a drastic increase in training efficiency #809

Merged · 21 commits · May 1, 2025
CHANGELOG.md (13 additions, 0 deletions)

@@ -23,6 +23,19 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html)
- ERA5 download example updated to use current file format convention and
restricts global statistics computation to the training set
- Support for training custom StormCast models and various other improvements for StormCast
- Updated CorrDiff training code to support multiple patch iterations, to
  amortize the regression cost, and the use of `torch.compile`
- Refactored `physicsnemo/models/diffusion/layers.py` to optimize the data type
  casting workflow, avoiding unnecessary casting under autocast mode
- Refactored Conv2d to enable fusion of conv2d with bias addition
- Refactored GroupNorm, UNetBlock, SongUNet, SongUNetPosEmbd to support usage of
  Apex GroupNorm, fusion of activation with GroupNorm, and the AMP workflow
- Updated SongUNetPosEmbd to avoid unnecessary HtoD Memcpy of `pos_embd`
- Updated `from_checkpoint` to accommodate conversion between Apex-optimized
  and non-optimized checkpoints
- Refactored CorrDiff NVTX annotation workflow to be configurable
- Refactored `ResidualLoss` to support patch-accumulating training that
  amortizes the regression cost (a sketch follows this changelog excerpt)

### Deprecated

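The patch-accumulating training referenced in the last changelog entry above can be illustrated with a minimal sketch. Everything below is an editorial assumption rather than the PR's `ResidualLoss` implementation: the helper `random_patches`, the function names, and the assumption that the conditioning input is already on the high-resolution grid.

```python
import torch

def random_patches(residual, cond, n, patch):
    # Cut n random, spatially aligned patches from the residual and the
    # conditioning input (hypothetical helper).
    b, _, h, w = residual.shape
    res_list, cond_list = [], []
    for _ in range(n):
        y = torch.randint(0, h - patch + 1, (1,)).item()
        x = torch.randint(0, w - patch + 1, (1,)).item()
        res_list.append(residual[:, :, y:y + patch, x:x + patch])
        cond_list.append(cond[:, :, y:y + patch, x:x + patch])
    return torch.cat(res_list), torch.cat(cond_list)

def patch_accumulated_loss(regression_net, diffusion_loss, x_lr, x_hr,
                           patch_num=16, patches_per_iter=2, patch=448):
    # Run the expensive regression UNet once per image batch; its output is
    # reused by every patch iteration below, which amortizes its cost.
    with torch.no_grad():
        mean = regression_net(x_lr)
    residual = x_hr - mean  # the diffusion model learns this residual

    total = x_hr.new_zeros(())
    done = 0
    while done < patch_num:
        n = min(patches_per_iter, patch_num - done)
        res_p, cond_p = random_patches(residual, x_lr, n, patch)
        total = total + diffusion_loss(res_p, cond_p).sum()
        done += n
    return total / patch_num
```

The key design point is that the regression forward pass runs once per image batch while the (cheaper) diffusion loss runs once per patch iteration, which is what the changelog means by amortizing the regression cost.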
@@ -14,7 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

-name: diffusion
+name: patched_diffusion
# Model type.
hr_mean_conditioning: True
# Recommended to use high-res conditioning for diffusion.
@@ -23,4 +23,4 @@ model_args:
# Per-resolution multipliers for the number of channels.
channel_mult: [1, 2, 2, 2, 2]
# Resolutions at which self-attention layers are applied.
-attention_levels: [28]
+attn_resolutions: [28]
@@ -0,0 +1,127 @@ (new file)
# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

hydra:
  job:
    chdir: true
    name: patched_diffusion_opt
  run:
    dir: ./output/${hydra:job.name}
  searchpath:
    - pkg://conf/base  # Do not modify

# Base parameters for dataset, model, training, and validation
defaults:

  - dataset: hrrr_corrdiff_synthetic
    # The dataset type for training.
    # Accepted values:
    #   `gefs_hrrr`: full GEFS-HRRR dataset for continental US.
    #   `hrrr_mini`: smaller HRRR dataset (continental US), for fast experiments.
    #   `cwb`: full CWB dataset for Taiwan.
    #   `custom`: user-defined dataset. Parameters need to be specified below.

  - model: patched_diffusion
    # The model type.
    # Accepted values:
    #   `regression`: a regression UNet for deterministic predictions
    #   `lt_aware_ce_regression`: similar to `regression` but with lead-time
    #     conditioning
    #   `diffusion`: a diffusion UNet for residual predictions
    #   `patched_diffusion`: a more memory-efficient diffusion model
    #   `lt_aware_patched_diffusion`: similar to `patched_diffusion` but
    #     with lead-time conditioning

  - model_size: normal
    # The model size configuration.
    # Accepted values:
    #   `normal`: normal model size
    #   `mini`: smaller model size for fast experiments

  - training: ${model}
    # The base training parameters. Determined by the model type.
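# Editorial note (assumption, not from the source): with Hydra, the defaults
# above can typically be overridden from the command line instead of editing
# this file; the entry-point name `train.py` below is hypothetical:
#   python train.py model=lt_aware_patched_diffusion model_size=mini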


# Dataset parameters. Used for the `custom` dataset type.
# Modify or add below the parameters that should be passed as arguments to the
# user-defined dataset class.
dataset:
  data_path: ./data
  # Path to the .nc data file
  stats_path: ./data/stats.json
  # Path to the JSON stats file
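  # For example (hypothetical keys, shown only to illustrate the pass-through
  # to a `custom` dataset class; not part of the original config):
  # input_variables: ["u10m", "v10m"]
  # normalize: true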

# Training parameters
training:
  hp:
    training_duration: 200000000
    # Training duration, measured in number of processed samples
    total_batch_size: 512
    # Total batch size
    batch_size_per_gpu: 4
    # Batch size per GPU

    patch_shape_x: 448
    patch_shape_y: 448
    # Patch size. Patch-based training is used if these dimensions differ from
    # img_shape_x and img_shape_y.
    patch_num: 16
    # Number of patches from a single sample. Total number of patches is
    # patch_num * total_batch_size.
    max_patch_per_gpu: 9
    # Maximum number of patches a GPU can hold
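    # Editorial note (an interpretation, not from the source): with the values
    # above, each GPU produces batch_size_per_gpu * patch_num = 4 * 16 = 64
    # patches per step; max_patch_per_gpu = 9 permits floor(9 / 4) = 2 patches
    # per image per iteration, i.e. ceil(16 / 2) = 8 patch iterations that all
    # reuse a single regression forward pass.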

    lr: 0.0002
    # Learning rate
    grad_clip_threshold: 1e6
    # Gradient clipping threshold
    lr_decay: 0.7
    # Learning rate decay factor
    lr_rampup: 1000000
    # Learning rate ramp-up duration

  # Performance
  perf:
    fp_optimizations: amp-bf16
    # Floating point mode, one of ["fp32", "fp16", "amp-fp16", "amp-bf16"]
    # "amp-{fp16,bf16}" activates Automatic Mixed Precision (AMP) with {float16, bfloat16}
    dataloader_workers: 4
    # DataLoader worker processes
    songunet_checkpoint_level: 0  # 0 means no checkpointing
    # Gradient checkpointing level; the value is the number of layers to checkpoint
    # optimization_mode: True
    use_apex_gn: True
    # Use the Apex GroupNorm kernels (see the changelog entry above)
    torch_compile: True
    # Wrap the model with `torch.compile`
    profile_mode: False
    # Enable profiling instrumentation (e.g., NVTX annotations)

  io:
    regression_checkpoint_path: /lustre/fsw/portfolios/coreai/users/asui/video-corrdiff-checkpoints/training-state-regression-000513.mdlus
    # Path to load the regression checkpoint from

    print_progress_freq: 1000
    # How often to print progress
    save_checkpoint_freq: 500000
    # How often to save the checkpoints, measured in number of processed samples
    validation_freq: 5000
    # How often to record the validation loss, measured in number of processed samples
    validation_steps: 10
    # How many loss evaluations are used to compute the validation loss per checkpoint

# Parameters for wandb logging
wandb:
  mode: offline
  # Configure whether to use wandb: "offline", "online", "disabled"
  results_dir: "./wandb"
  # Directory to store wandb results
  watch_model: false
  # If true, wandb will track model parameters and gradients
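To close, a minimal sketch of what the `perf` flags above (`fp_optimizations: amp-bf16`, `torch_compile: True`) typically correspond to in a PyTorch training step. The step function, its names, and the loop structure are editorial assumptions, not the PR's actual training code:

```python
import torch

def setup_model(model, use_compile=True):
    # `perf.torch_compile: True` -> compile the model once up front; the first
    # call pays a one-time graph-capture cost, later calls run the fused graph.
    return torch.compile(model) if use_compile else model

def train_step(model, loss_fn, optimizer, batch):
    optimizer.zero_grad(set_to_none=True)
    # `perf.fp_optimizations: amp-bf16` -> autocast to bfloat16. bfloat16 keeps
    # float32's exponent range, so no GradScaler is needed (amp-fp16 would
    # normally pair with torch.amp.GradScaler).
    with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
        loss = loss_fn(model, batch)
    loss.backward()
    optimizer.step()
    return loss.detach()
```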