
Figconvnet performance improvements #822

Open · wants to merge 8 commits into base: main
3 changes: 2 additions & 1 deletion examples/cfd/external_aerodynamics/figconvnet/train.py
@@ -326,7 +326,8 @@ def train(config: DictConfig, signal_handler: SignalHandler):
if config.train.lr_scheduler_mode == "iteration":
scheduler.step()
tot_iter += 1
torch.cuda.empty_cache()
# Calling torch.cuda.empty_cache() every iteration is a performance bottleneck.
# torch.cuda.empty_cache()

if config.train.lr_scheduler_mode == "epoch":
scheduler.step()
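Note (not part of the diff): `torch.cuda.empty_cache()` returns all unused cached blocks to the driver, which is expensive and forces later allocations to go back through `cudaMalloc`, so calling it on every iteration stalls the training loop. If occasional cache release is still wanted (for example to work around fragmentation), a hedged sketch of a periodic variant could look like the following; `train_step`, `num_iters`, and the interval are illustrative placeholders, not PR code.

```python
import torch

EMPTY_CACHE_INTERVAL = 1000  # assumed interval, tune per workload


def train_loop(train_step, num_iters: int) -> None:
    for tot_iter in range(num_iters):
        train_step(tot_iter)  # forward / backward / optimizer step
        # Release cached blocks only occasionally: empty_cache() calls into
        # cudaFree, which is costly, so keep it out of the per-step hot path.
        if torch.cuda.is_available() and (tot_iter + 1) % EMPTY_CACHE_INTERVAL == 0:
            torch.cuda.empty_cache()
```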
17 changes: 11 additions & 6 deletions physicsnemo/models/figconvnet/components/encodings.py
@@ -31,14 +31,19 @@ def __init__(self, num_channels: int, data_range: float = 2.0):
self.num_channels = num_channels
self.data_range = data_range

def forward(self, x):
freqs = 2 ** torch.arange(
start=0, end=self.num_channels // 2, device=x.device
).to(x.dtype)
freqs = 2 ** torch.arange(start=0, end=self.num_channels // 2)
freqs = (2 * np.pi / self.data_range) * freqs

self.register_buffer("freqs", freqs)

def forward(self, x):

x = x.unsqueeze(-1)
# Make freq to have the same dimensions as x. X can be of any shape
freqs = freqs.reshape((1,) * (len(x.shape) - 1) + freqs.shape)
# Reshape frequencies to match input dimensions for broadcasting
# Create a shape with 1s for all dimensions except the last one
# For example, if x is (batch, points, 3), create (1, 1, freqs.shape)
broadcast_shape = (1,) * (len(x.shape) - 1) + self.freqs.shape
freqs = self.freqs.reshape(broadcast_shape)
x = x * freqs
x = torch.cat([x.cos(), x.sin()], dim=-1).flatten(start_dim=-2)
return x
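Note (not part of the diff): the change above moves the frequency computation out of `forward` and into `__init__`, storing the result with `register_buffer` so the tensor follows the module's device and dtype and is no longer rebuilt on every call. A minimal standalone sketch of the resulting pattern, assembled from the added lines (the class name here is illustrative):

```python
import numpy as np
import torch
import torch.nn as nn


class SinCosEncodingSketch(nn.Module):
    """Illustrative consolidation of the new encoding code shown above."""

    def __init__(self, num_channels: int, data_range: float = 2.0):
        super().__init__()
        self.num_channels = num_channels
        self.data_range = data_range
        # Compute frequencies once; the buffer moves with .to()/.cuda() and is
        # saved in the state dict, so there is no per-forward allocation.
        freqs = 2 ** torch.arange(start=0, end=num_channels // 2)
        freqs = (2 * np.pi / data_range) * freqs
        self.register_buffer("freqs", freqs)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x.unsqueeze(-1)
        # Reshape freqs with leading singleton dims so it broadcasts against x,
        # whatever rank x has, e.g. (batch, points, 3, 1) vs (1, 1, 1, num_channels // 2).
        broadcast_shape = (1,) * (len(x.shape) - 1) + self.freqs.shape
        x = x * self.freqs.reshape(broadcast_shape)
        return torch.cat([x.cos(), x.sin()], dim=-1).flatten(start_dim=-2)
```

With this pattern, an input of shape `(batch, points, 3)` produces an output of shape `(batch, points, 3 * num_channels)`, matching the concatenated cosine and sine features.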
99 changes: 91 additions & 8 deletions physicsnemo/models/figconvnet/components/mlp.py
@@ -21,6 +21,13 @@
from jaxtyping import Float
from torch import Tensor

try:
import transformer_engine.pytorch as te

HAS_TE = True
except ImportError:
HAS_TE = False


class LinearBlock(nn.Module):
"""Simple linear block with ReLU and dropout
@@ -33,18 +40,26 @@ class LinearBlock(nn.Module):
Number of output channels
activation : type[nn.Module]
Activation function, default nn.GELU
use_te_norm : bool, optional
If True, use transformer_engine LayerNorm, else use torch.nn LayerNorm
"""

def __init__(
self,
in_channels: int,
out_channels: int,
activation: type[nn.Module] = nn.GELU,
use_te_norm: bool = True,
):
super().__init__()
if use_te_norm and not HAS_TE:
raise ImportError(
"transformer_engine is not available but use_te_norm=True. "
"Either install transformer_engine or set use_te_norm=False"
)
self.block = nn.Sequential(
nn.Linear(in_channels, out_channels, bias=False),
nn.LayerNorm(out_channels),
te.LayerNorm(out_channels) if use_te_norm else nn.LayerNorm(out_channels),
activation(),
)

@@ -53,24 +68,52 @@ def forward(self, x: Float[Tensor, "... C1"]) -> Float[Tensor, "... C2"]:


class ResidualLinearBlock(nn.Module):
"""MLPBlock."""
"""Residual Linear Block. Performs the following operations:
- Linear layer
- LayerNorm
- Activation
- Linear layer
- LayerNorm
- Add skip connection

Parameters
----------
in_channels : int
Number of input channels
out_channels : int
Number of output channels
hidden_channels : int, optional
Number of hidden channels
activation : type[nn.Module]
Activation function, default nn.GELU
use_te_norm : bool, optional
If True, use transformer_engine LayerNorm, else use torch.nn LayerNorm
"""

def __init__(
self,
in_channels: int,
out_channels: int,
hidden_channels: int = None,
activation: type[nn.Module] = nn.GELU,
use_te_norm: bool = True,
):
super().__init__()
if use_te_norm and not HAS_TE:
raise ImportError(
"transformer_engine is not available but use_te_norm=True. "
"Either install transformer_engine or set use_te_norm=False"
)
if hidden_channels is None:
hidden_channels = in_channels
self.blocks = nn.Sequential(
nn.Linear(in_channels, hidden_channels),
nn.LayerNorm(hidden_channels),
te.LayerNorm(hidden_channels)
if use_te_norm
else nn.LayerNorm(hidden_channels),
activation(),
nn.Linear(hidden_channels, out_channels),
nn.LayerNorm(out_channels),
te.LayerNorm(out_channels) if use_te_norm else nn.LayerNorm(out_channels),
)
self.shortcut = (
nn.Identity()
@@ -101,6 +144,8 @@ class MLP(nn.Module):
Whether to use residual connections, default False.
activation : type[nn.Module]
Activation function, default nn.GELU
use_te_norm : bool, optional
If True, use transformer_engine LayerNorm, else use torch.nn LayerNorm
"""

def __init__(
@@ -110,6 +155,7 @@ def __init__(
hidden_channels: List[int],
use_residual: bool = False,
activation: type[nn.Module] = nn.GELU,
use_te_norm: bool = True,
):
"""
:param channels: list of channels
@@ -126,11 +172,17 @@ channels[i],
channels[i],
channels[i + 1],
activation=activation,
use_te_norm=use_te_norm,
)
)
else:
self.layers.append(
LinearBlock(channels[i], channels[i + 1], activation=activation)
LinearBlock(
channels[i],
channels[i + 1],
activation=activation,
use_te_norm=use_te_norm,
)
)

def forward(self, x: Float[Tensor, "... C1"]) -> Float[Tensor, "... C2"]:
@@ -143,25 +195,56 @@ def forward(self, x: Float[Tensor, "... C1"]) -> Float[Tensor, "... C2"]:


class MLPBlock(nn.Module):
"""MLPBlock."""
"""MLPBlock. Performs the following operations:
- Linear layer
- LayerNorm
- Activation
- Linear layer
- LayerNorm

Parameters
----------
in_channels : int
Number of input channels
hidden_channels : int, optional
Number of hidden channels. Defaults to in_channels if None
out_channels : int, optional
Number of output channels. Defaults to in_channels if None
activation : type[nn.Module]
Activation function, default nn.GELU
use_te_norm : bool, optional
If True, use transformer_engine LayerNorm, else use torch.nn LayerNorm
"""

def __init__(
self,
in_channels: int,
hidden_channels: int = None,
out_channels: int = None,
activation: type[nn.Module] = nn.GELU,
use_te_norm: bool = True,
):
super().__init__()
if use_te_norm and not HAS_TE:
raise ImportError(
"transformer_engine is not available but use_te_norm=True. "
"Either install transformer_engine or set use_te_norm=False"
)
if hidden_channels is None:
hidden_channels = in_channels
if out_channels is None:
out_channels = in_channels
self.in_channels = in_channels
self.fc1 = nn.Linear(in_channels, hidden_channels)
self.norm1 = nn.LayerNorm(hidden_channels)
self.norm1 = (
te.LayerNorm(hidden_channels)
if use_te_norm
else nn.LayerNorm(hidden_channels)
)
self.fc2 = nn.Linear(hidden_channels, out_channels)
self.norm2 = nn.LayerNorm(out_channels)
self.norm2 = (
te.LayerNorm(out_channels) if use_te_norm else nn.LayerNorm(out_channels)
)
self.shortcut = nn.Linear(in_channels, out_channels)
self.activation = activation()

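Note (not part of the diff): each block above now takes a `use_te_norm` flag and guards the optional `transformer_engine` dependency with a module-level `HAS_TE` flag, raising `ImportError` when TE normalization is requested but the package is missing. A hedged usage sketch that falls back to `torch.nn.LayerNorm` automatically instead of raising, assuming the import path matches the file path shown above:

```python
import torch

# Assumed import path, derived from the file location in this diff.
from physicsnemo.models.figconvnet.components.mlp import HAS_TE, LinearBlock

# Passing use_te_norm=HAS_TE selects te.LayerNorm when transformer_engine is
# installed and falls back to nn.LayerNorm otherwise, instead of raising.
block = LinearBlock(in_channels=64, out_channels=128, use_te_norm=HAS_TE)

if torch.cuda.is_available():
    # transformer_engine layers expect CUDA tensors, so run the example on GPU.
    block = block.cuda()
    y = block(torch.rand(32, 64, device="cuda"))  # -> shape (32, 128)
```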