
adding t-edm #834


Open · mnabian wants to merge 5 commits into main

Conversation

@mnabian (Collaborator) commented Mar 31, 2025

PhysicsNeMo Pull Request

Description

Checklist

  • I am familiar with the Contributing Guidelines.
  • New or existing tests cover these changes.
  • The documentation is up to date with these changes.
  • The CHANGELOG.md is up to date with these changes.
  • An issue is linked to this pull request.

Dependencies

@mnabian requested a review from @CharlelieLrt on Apr 1, 2025, 00:05
@mnabian self-assigned this on Apr 1, 2025
@mnabian (Collaborator, Author) commented Apr 1, 2025:

/blossom-ci

Four more /blossom-ci triggers followed: three from @mnabian (Apr 1, 2025) and one from @Alexey-Kamenev (Collaborator).

@CharlelieLrt (Collaborator) left a comment:

I left a few comments and questions mostly about:

  1. More thorough docstrings, including tensor shapes and the expected signatures of callable arguments. There are now many samplers and preconditioners with different signatures, and it is getting really challenging to use them correctly.
  2. Test coverage for the new features.

I have two major comments:

  1. This PR introduces tEDMPrecond and tEDMLoss but never actually uses them in train.py. I want to make sure that this is not an omission and that shipping them unused is really the intent.
  2. I understand that for generation, tEDMPrecond is only expected to work in combination with the deterministic_sampler. However, I believe the deterministic_sampler suffers from multiple bugs (mostly argument mismatches), so I doubt the combination tEDMPrecond + deterministic_sampler is currently usable as-is.

@CharlelieLrt commented:

IMO these parameters should go in the conf/sampler/ config files rather than in the generation config.

@@ -111,6 +111,15 @@ def main(cfg: DictConfig) -> None:
    else:
        logger0.info("Patch-based training disabled")

    # Parse the t-distribution parameters
@CharlelieLrt commented on Apr 1, 2025:

These checks are unnecessary, as they are already handled by the default kwargs of diffusion_step. Can we make sure that we either:

  1. keep the APIs robust so they can handle default values (see the sketch below), or
  2. are more explicit and enforce that default values are always specified in config files?
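
For illustration, a minimal sketch of option 1, assuming Hydra/OmegaConf configs; the key names (t_dist.use_t_latents, t_dist.nu) are placeholders, not the PR's actual schema:

    from omegaconf import OmegaConf

    cfg = OmegaConf.create({"t_dist": {"nu": 10}})

    # OmegaConf.select falls back to the given default when the key is
    # absent, so the call site needs no separate presence checks and the
    # defaults live in exactly one place.
    use_t_latents = OmegaConf.select(cfg, "t_dist.use_t_latents", default=False)
    nu = OmegaConf.select(cfg, "t_dist.nu", default=0)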

@@ -164,6 +173,10 @@ def main(cfg: DictConfig) -> None:
            solver=cfg.sampler.solver,
        )
    elif cfg.sampler.type == "stochastic":
        if use_t_latents:
@CharlelieLrt commented:

If I understand correctly, we don't support t-EDM + patch-based generation for now. It might be helpful to raise an error to prevent this combination.
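
A minimal sketch of such a guard; the flag names use_t_latents and patching_enabled are assumptions patterned on the surrounding script, not the PR's actual variables:

    # Hypothetical early check: fail fast instead of silently combining
    # Student-t latents with patch-based generation.
    if use_t_latents and patching_enabled:
        raise NotImplementedError(
            "t-EDM (use_t_latents=True) does not support patch-based "
            "generation yet; disable patching or use Gaussian latents."
        )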

@@ -79,6 +79,8 @@ def diffusion_step(  # TODO generalize the module and add defaults
    device: torch.device,
    hr_mean: torch.Tensor = None,
    lead_time_label: torch.Tensor = None,
    use_t_latents: bool = False,
@CharlelieLrt commented:

Can we add a test that exercises diffusion_step with use_t_latents=True?
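
The real test would call diffusion_step directly with its remaining required arguments; as a self-contained illustration, here is the kind of property such a test could assert about the Student-t latents that use_t_latents=True is expected to draw (the sample_t_latents helper below is hypothetical, not the PR's code):

    import torch

    def sample_t_latents(shape, nu):
        # Student-t via the Gaussian / sqrt(Chi2 / nu) mixture; an
        # assumption about what use_t_latents=True does internally.
        z = torch.randn(shape)
        chi2 = torch.distributions.Chi2(torch.tensor(float(nu))).sample(shape)
        return z / torch.sqrt(chi2 / nu)

    def test_t_latents_have_heavier_tails():
        torch.manual_seed(0)
        x = sample_t_latents((100_000,), nu=5)
        # Excess kurtosis of a Student-t with nu=5 is 6 / (nu - 4) = 6,
        # versus 0 for a Gaussian, so a clearly positive value is expected.
        kurt = ((x - x.mean()) ** 4).mean() / x.var() ** 2 - 3
        assert kurt > 1.0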

@@ -79,6 +79,8 @@ def diffusion_step(  # TODO generalize the module and add defaults
    device: torch.device,
    hr_mean: torch.Tensor = None,
    lead_time_label: torch.Tensor = None,
    use_t_latents: bool = False,
    nu: int = 0,
@CharlelieLrt commented:

These two options (use_t_latents and nu) are not documented in the docstring.
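
For example, entries in the numpydoc style the file already uses could look like this (wording is a suggestion only, not the PR's text):

    use_t_latents : bool, optional
        If True, draw the initial latents from a Student-t distribution
        (t-EDM) instead of a standard Gaussian. Defaults to False.
    nu : int, optional
        Degrees of freedom of the Student-t latent distribution; only
        consulted when use_t_latents is True. A Student-t has finite
        variance only for nu > 2. Defaults to 0.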


Parameters
----------
P_mean: float, optional
@CharlelieLrt commented:

I might be missing something, but this t-EDM loss is NOT supposed to be used with CorrDiff, right? The fact that the loss is introduced but never actually used does not help understanding. If it is indeed meant for CorrDiff training, there should be a regression model somewhere in tEDMLoss.__init__, as is the case for the classical CorrDiff ResLoss.

net: torch.nn.Module
    The neural network model that will make predictions.

images: torch.Tensor
@CharlelieLrt commented:

Can we add the shapes of the input/output tensors? I found it a little counter-intuitive that we don't actually output a loss but a pixelwise squared difference, which still needs a reduction by mean or sum afterwards.

Ideally, the docstring should also include the expected signature of the net argument. I realized that the signature of tEDMPrecond is not the same as other CorrDiff preconditioners like EDMPrecond, so not specifying the signature will inevitably lead to confusion.
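
As an illustration of what that documentation should make explicit, a hedged usage sketch (constructor and call signatures are assumptions pieced together from this diff, not verified against the PR):

    # Hypothetical usage; argument names are assumptions.
    loss_fn = tEDMLoss(P_mean=-1.2, nu=10)

    # images: (B, C, H, W). The call returns a pixelwise weighted squared
    # difference of the same shape, NOT a scalar, so the trainer must
    # reduce it before calling backward().
    per_pixel = loss_fn(net=model, images=images)
    loss = per_pixel.mean()  # or .sum(), matching the training recipe
    loss.backward()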

auto_grad: bool = False


class tEDMPrecond(Module):
@CharlelieLrt commented:

Can we add a test for tEDMPrecond?
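
A possible shape-preservation test, sketched under the assumption that tEDMPrecond's constructor is patterned on EDMPrecond (img_resolution, img_channels); the concrete values are placeholders:

    import torch

    def test_tedm_precond_preserves_shape():
        # Constructor kwargs assumed to mirror EDMPrecond; the values are
        # deliberately tiny so the test stays fast.
        net = tEDMPrecond(img_resolution=16, img_channels=3, nu=10)
        x = torch.randn(2, 3, 16, 16)
        sigma = torch.ones(2)
        out = net(x, sigma)
        assert out.shape == x.shape
        assert torch.isfinite(out).all()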

    sigma_min=0,
    sigma_max=float("inf"),
    sigma_data=0.5,
    model_type="DhariwalUNet",
@CharlelieLrt commented:

Are we sure that we want DhariwalUNet to be the default architecture? I thought SongUNet would make more sense.
In any case, would it be possible to document the possible values of the model_type argument?
If for now we want to restrict it to DhariwalUNet, we should remove this option.
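
If the intent is to restrict the backbone for now, a small guard in __init__ would make that explicit; the allowed set below is an assumption:

    # Sketch: fail fast on unsupported backbones instead of accepting an
    # arbitrary model_type string.
    _ALLOWED_MODEL_TYPES = ("SongUNet", "DhariwalUNet")  # assumed set

    if model_type not in _ALLOWED_MODEL_TYPES:
        raise ValueError(
            f"model_type must be one of {_ALLOWED_MODEL_TYPES}, "
            f"got {model_type!r}"
        )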

        **model_kwargs,
    )

    def forward(
@CharlelieLrt commented:

Note that this forward will not work with the current deterministic_sampler, as the condition kwarg will be ignored. See this snippet from deterministic_sampler.py:

        if isinstance(net, EDMPrecond):
            # Conditioning info is passed as keyword arg
            denoised = net(
                x_hat / s(t_hat),
                sigma(t_hat),
                condition=x_lr,
                class_labels=class_labels,
            ).to(torch.float64)
        else:
            denoised = net(x_hat / s(t_hat), x_lr, sigma(t_hat), class_labels).to(
                torch.float64
            )

Note: while working on other CorrDiff things I realized the deterministic_sampler is mostly broken. I am not even sure there is a single use case where it works properly. I did not address these bugs because my understanding is that deterministic_sampler is hardly used at all.
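
One way to make that dispatch cover the new preconditioner, assuming tEDMPrecond's forward mirrors EDMPrecond's keyword interface (an untested suggestion, not something this PR ships):

        # Sketch: widen the isinstance check so conditioning is passed as
        # a keyword for both preconditioners.
        if isinstance(net, (EDMPrecond, tEDMPrecond)):
            denoised = net(
                x_hat / s(t_hat),
                sigma(t_hat),
                condition=x_lr,
                class_labels=class_labels,
            ).to(torch.float64)
        else:
            denoised = net(x_hat / s(t_hat), x_lr, sigma(t_hat), class_labels).to(
                torch.float64
            )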
