adding t-edm #834
base: main
Conversation
I left a few comments and questions, mostly about:
- More thorough docstrings, including tensor shapes and the expected signatures of callable arguments. There are now many samplers and preconditioners with different signatures, and it is becoming genuinely challenging to use them correctly.
- Test coverage for the new features.
I have two major comments:
- This PR introduces `tEDMPrecond` and `tEDMLoss`, but never actually uses them in `train.py`. I want to make sure that this is not an omission and that it is really intentional.
- I understand that for generation, `tEDMPrecond` is only expected to work in combination with the `deterministic_sampler`. However, I believe the `deterministic_sampler` suffers from multiple bugs (mostly argument mismatches), so I doubt that the combination `tEDMPrecond` + `deterministic_sampler` is currently usable as it is.
IMO these parameters should go in the conf/sampler/ config files rather than in the generation config.
@@ -111,6 +111,15 @@ def main(cfg: DictConfig) -> None:
    else:
        logger0.info("Patch-based training disabled")

    # Parse the t-distribution parameters
These checks are unnecessary, as they are already handled by the default kwargs of `diffusion_step`. Can we make sure that we either:
- keep the APIs robust so they can handle default values, or
- are more explicit and enforce that default values are always specified in the config files?
(A rough sketch of the first option follows.)
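For illustration only, a minimal sketch of the first option: forward just the values that are explicitly set in the config and let `diffusion_step`'s own default kwargs handle the rest. The config keys ("use_t_latents" and "nu" under cfg.sampler) and the helper name are assumptions, not code from this PR.

# Sketch only (not the PR's code): rely on diffusion_step's defaults instead
# of re-validating them here; only forward options that are actually set.
from omegaconf import DictConfig


def t_dist_kwargs(cfg: DictConfig) -> dict:
    """Collect t-distribution options that are explicitly present in the config."""
    kwargs = {}
    for key in ("use_t_latents", "nu"):
        value = cfg.sampler.get(key)  # returns None when the key is absent
        if value is not None:
            kwargs[key] = value
    return kwargs


# Usage: diffusion_step(..., **t_dist_kwargs(cfg))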
@@ -164,6 +173,10 @@ def main(cfg: DictConfig) -> None:
            solver=cfg.sampler.solver,
        )
    elif cfg.sampler.type == "stochastic":
        if use_t_latents:
If I understand correctly, we don't support t-EDM + patch-based generation for now. It might be helpful to raise an error to avoid this combination, as sketched below.
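For example (a sketch only; the flag name for patch-based generation is a placeholder, since the actual config key is not visible in this diff):

def check_t_edm_compatibility(use_t_latents: bool, patching_enabled: bool) -> None:
    """Fail fast when t-EDM latents are requested together with patch-based generation."""
    if use_t_latents and patching_enabled:
        raise NotImplementedError(
            "t-EDM (use_t_latents=True) is not supported with patch-based "
            "generation yet; disable one of the two options."
        )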
@@ -79,6 +79,8 @@ def diffusion_step( # TODO generalize the module and add defaults
    device: torch.device,
    hr_mean: torch.Tensor = None,
    lead_time_label: torch.Tensor = None,
    use_t_latents: bool = False,
Can we add a test that exercises `diffusion_step` with `use_t_latents=True`? A rough sketch follows.
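Something along these lines (a sketch, not a working test: only `device`, `use_t_latents`, and `nu` come from this diff; the `net` argument name, the mock net, the import path, and the remaining required arguments are placeholders that need to be adapted to the actual `diffusion_step` signature):

import pytest
import torch


class _IdentityNet(torch.nn.Module):
    """Minimal stand-in for the denoiser passed to diffusion_step."""

    def forward(self, x, *args, **kwargs):
        return x


@pytest.mark.parametrize("use_t_latents,nu", [(False, 0), (True, 10)])
def test_diffusion_step_t_latents(use_t_latents, nu):
    from generate import diffusion_step  # placeholder import path

    out = diffusion_step(
        net=_IdentityNet(),
        device=torch.device("cpu"),
        use_t_latents=use_t_latents,
        nu=nu,
        # ... remaining required arguments of diffusion_step ...
    )
    assert torch.isfinite(out).all()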
@@ -79,6 +79,8 @@ def diffusion_step( # TODO generalize the module and add defaults
    device: torch.device,
    hr_mean: torch.Tensor = None,
    lead_time_label: torch.Tensor = None,
    use_t_latents: bool = False,
    nu: int = 0,
These two options are not documented in the docstring; see the sketch below for possible entries.
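For example, entries along these lines (the wording and the described semantics are my suggestion, to be checked against the implementation):

    use_t_latents : bool, optional
        If True, draw the initial latents from a Student-t distribution
        (t-EDM) instead of a Gaussian. Default is False.
    nu : int, optional
        Degrees of freedom of the Student-t distribution; only relevant
        when ``use_t_latents`` is True. Default is 0.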
    Parameters
    ----------
    P_mean: float, optional
I might be missing something, but this t-EDM loss is NOT supposed to be used with CorrDiff, right? The fact that this loss is introduced but never actually used does not help comprehension, but if it is indeed supposed to be used for CorrDiff training, there should be a regression model somewhere in `tEDMLoss.__init__`, as is the case for the classical CorrDiff `ResLoss`. A rough sketch of what that could look like is below.
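Purely to illustrate the point, a very rough sketch of a ResLoss-style constructor (parameter names other than P_mean are not from this PR, and the defaults are placeholders):

import torch


class tEDMLoss:
    """Illustrative skeleton only, not the PR's implementation."""

    def __init__(
        self,
        regression_net: torch.nn.Module,
        P_mean: float = 0.0,
        nu: int = 10,
    ):
        # Keep the pre-trained regression model around so the diffusion loss
        # can be computed on the residual, as the classical ResLoss does.
        self.regression_net = regression_net
        self.P_mean = P_mean
        self.nu = nu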
    net: torch.nn.Module
        The neural network model that will make predictions.

    images: torch.Tensor
Can we add the shapes of the input/output tensors? I found it a little counter-intuitive that we don't actually output a loss but a pixelwise squared difference (so it needs a reduction by `mean` or `sum` afterwards).
Ideally, the docstring should also include the expected signature of the `net` argument. I realized that the signature of `tEDMPrecond` is not the same as other CorrDiff preconditioners like `EDMPrecond`, so not specifying the signature will inevitably lead to confusion. A possible docstring fragment is sketched below.
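For illustration, a docstring fragment with the level of detail I have in mind (the exact shapes, parameter names, and the `net` call signature below are assumptions to be verified against the code):

    Parameters
    ----------
    net : torch.nn.Module
        Denoiser to train. Note that its expected call signature differs
        from EDMPrecond; something like
        net(x, img_lr, sigma, class_labels=None, **model_kwargs).
    images : torch.Tensor
        High-resolution target images of shape (B, C_hr, H, W).
    img_lr : torch.Tensor
        Low-resolution conditioning images of shape (B, C_lr, H, W).

    Returns
    -------
    torch.Tensor
        Pixelwise weighted squared error of shape (B, C_hr, H, W). The
        result is NOT reduced; apply .mean() or .sum() to get a scalar loss.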
    auto_grad: bool = False


class tEDMPrecond(Module):
Can we add a test for `tEDMPrecond`? A rough shape test is sketched below.
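For example (the import path, the constructor arguments other than those visible in this diff, and the forward call signature are guesses that need adjusting):

import torch


def test_tedm_precond_forward_shape():
    # Placeholder import path; adjust to wherever tEDMPrecond lives in the PR.
    from physicsnemo.models.diffusion import tEDMPrecond

    net = tEDMPrecond(
        img_resolution=16,  # assumed constructor arguments
        img_channels=3,
        model_type="SongUNet",
    )
    x = torch.randn(2, 3, 16, 16)
    sigma = torch.ones(2)
    # Call signature assumed to follow the positional convention used in the
    # deterministic_sampler's else-branch: (x, condition, sigma, class_labels).
    out = net(x, x, sigma, None)
    assert out.shape == x.shape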
        sigma_min=0,
        sigma_max=float("inf"),
        sigma_data=0.5,
        model_type="DhariwalUNet",
Are we sure that we want `DhariwalUNet` to be the default architecture? I thought `SongUNet` would make more sense.
In any case, would it be possible to document the possible values for the `model_type` argument?
If for now we want to restrict it to `DhariwalUNet`, we should remove this option.
        **model_kwargs,
    )

    def forward(
Note that this `forward` will not work with the current `deterministic_sampler`, as the `condition` kwarg will be ignored. See this snippet from `deterministic_sampler.py`:
if isinstance(net, EDMPrecond):
# Conditioning info is passed as keyword arg
denoised = net(
x_hat / s(t_hat),
sigma(t_hat),
condition=x_lr,
class_labels=class_labels,
).to(torch.float64)
else:
denoised = net(x_hat / s(t_hat), x_lr, sigma(t_hat), class_labels).to(
torch.float64
)
Note: when working on other CorrDiff things I realized the `deterministic_sampler` is mostly broken. I am not even sure there is a single use case where it works properly. I did not address these bugs because my understanding is that the `deterministic_sampler` is almost never used.
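If it helps, one possible direction is to dispatch on the denoiser's actual forward signature rather than on isinstance, so new preconditioners are not silently routed to the wrong branch. This is only a sketch of the idea (it reuses the variables from the sampler snippet above and is not a tested patch):

import inspect

# Pass the conditioning as a keyword only if the denoiser's forward actually
# accepts a "condition" parameter; otherwise fall back to the positional call.
accepts_condition = "condition" in inspect.signature(net.forward).parameters
if accepts_condition:
    denoised = net(
        x_hat / s(t_hat), sigma(t_hat), condition=x_lr, class_labels=class_labels
    ).to(torch.float64)
else:
    denoised = net(x_hat / s(t_hat), x_lr, sigma(t_hat), class_labels).to(
        torch.float64
    )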
PhysicsNeMo Pull Request
Description
Checklist
Dependencies