Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 19 additions & 3 deletions monai/networks/blocks/patchembedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,21 +12,22 @@
from __future__ import annotations

from collections.abc import Sequence
from typing import Optional

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import LayerNorm

from monai.networks.blocks.pos_embed_utils import build_sincos_position_embedding
from monai.networks.blocks.pos_embed_utils import build_fourier_position_embedding, build_sincos_position_embedding
from monai.networks.layers import Conv, trunc_normal_
from monai.utils import ensure_tuple_rep, optional_import
from monai.utils.module import look_up_option

Rearrange, _ = optional_import("einops.layers.torch", name="Rearrange")
SUPPORTED_PATCH_EMBEDDING_TYPES = {"conv", "perceptron"}
SUPPORTED_POS_EMBEDDING_TYPES = {"none", "learnable", "sincos"}
SUPPORTED_POS_EMBEDDING_TYPES = {"none", "learnable", "sincos", "fourier"}


class PatchEmbeddingBlock(nn.Module):
Expand All @@ -53,6 +54,7 @@ def __init__(
pos_embed_type: str = "learnable",
dropout_rate: float = 0.0,
spatial_dims: int = 3,
pos_embed_kwargs: Optional[dict] = None,
) -> None:
"""
Args:
Expand All @@ -65,6 +67,8 @@ def __init__(
pos_embed_type: position embedding layer type.
dropout_rate: fraction of the input units to drop.
spatial_dims: number of spatial dimensions.
pos_embed_kwargs: additional arguments for the position embedding. For `sincos`, it can contain
`temperature`, and for `fourier` it can contain `scales`.
"""

super().__init__()
Expand Down Expand Up @@ -105,6 +109,8 @@ def __init__(
self.position_embeddings = nn.Parameter(torch.zeros(1, self.n_patches, hidden_size))
self.dropout = nn.Dropout(dropout_rate)

pos_embed_kwargs = {} if pos_embed_kwargs is None else pos_embed_kwargs

if self.pos_embed_type == "none":
pass
elif self.pos_embed_type == "learnable":
Expand All @@ -114,7 +120,17 @@ def __init__(
for in_size, pa_size in zip(img_size, patch_size):
grid_size.append(in_size // pa_size)

self.position_embeddings = build_sincos_position_embedding(grid_size, hidden_size, spatial_dims)
self.position_embeddings = build_sincos_position_embedding(
grid_size, hidden_size, spatial_dims, **pos_embed_kwargs
)
elif self.pos_embed_type == "fourier":
grid_size = []
for in_size, pa_size in zip(img_size, patch_size):
grid_size.append(in_size // pa_size)

self.position_embeddings = build_fourier_position_embedding(
grid_size, hidden_size, spatial_dims, **pos_embed_kwargs
)
else:
raise ValueError(f"pos_embed_type {self.pos_embed_type} not supported.")

Expand Down
46 changes: 45 additions & 1 deletion monai/networks/blocks/pos_embed_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import torch
import torch.nn as nn

__all__ = ["build_sincos_position_embedding"]
__all__ = ["build_sincos_position_embedding", "build_fourier_position_embedding"]


# From PyTorch internals
Expand All @@ -32,6 +32,50 @@ def parse(x):
return parse


def build_fourier_position_embedding(
    grid_size: Union[int, List[int]], embed_dim: int, spatial_dims: int = 3, scales: Union[float, List[float]] = 1.0
) -> torch.nn.Parameter:
    """
    Builds an (anisotropic) Fourier-feature positional encoding based on the given grid size, embed
    dimension, spatial dimensions, and scales. The scales control the variance of the Fourier features;
    higher values make distant points more distinguishable.
    Reference: https://arxiv.org/abs/2509.02488

    Args:
        grid_size (List[int]): The size of the grid in each spatial dimension.
        embed_dim (int): The dimension of the embedding. Must be even (half sine, half cosine features).
        spatial_dims (int): The number of spatial dimensions (2 for 2D, 3 for 3D).
        scales (List[float]): The scale for every spatial dimension. If a single float is provided,
            the same scale is used for all dimensions.

    Returns:
        pos_embed (nn.Parameter): The Fourier feature position embedding as a fixed parameter.

    Raises:
        ValueError: When ``embed_dim`` is odd, or when ``scales`` is neither a scalar nor a
            1-D sequence of length ``spatial_dims``.
    """

    if embed_dim % 2 != 0:
        # sin/cos features come in pairs, so an odd embed_dim would silently yield embed_dim - 1 channels.
        raise ValueError(f"embed_dim must be even for Fourier position embedding, got {embed_dim}")

    to_tuple = _ntuple(spatial_dims)
    grid_size = to_tuple(grid_size)

    scales = torch.as_tensor(scales, dtype=torch.float32)
    if scales.ndim == 0:
        # A scalar scale is broadcast to every spatial dimension.
        scales = scales.repeat(spatial_dims)
    elif scales.ndim != 1 or scales.numel() != spatial_dims:
        # Validate the per-dimension form: exactly one scale per spatial dimension.
        raise ValueError("Scales must be either a float or a list of floats with length equal to spatial_dims")

    # Random Gaussian projection; scaling a column raises the frequency content along that axis.
    gaussians = torch.normal(0.0, 1.0, (embed_dim // 2, spatial_dims))
    gaussians = gaussians * scales

    # Normalized patch coordinates in [0, 1], flattened to (n_patches, spatial_dims).
    positions = [torch.linspace(0, 1, x) for x in grid_size]
    positions = torch.stack(torch.meshgrid(*positions, indexing="ij"), dim=-1)
    positions = positions.flatten(end_dim=-2)

    x_proj = (2.0 * torch.pi * positions) @ gaussians.T

    pos_emb = torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1)
    pos_emb = pos_emb[None, :, :]

    # Fixed (non-trainable) embedding, consistent with build_sincos_position_embedding.
    return nn.Parameter(pos_emb, requires_grad=False)


def build_sincos_position_embedding(
grid_size: Union[int, List[int]], embed_dim: int, spatial_dims: int = 3, temperature: float = 10000.0
) -> torch.nn.Parameter:
Expand Down
38 changes: 38 additions & 0 deletions tests/networks/blocks/test_patchembedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,19 @@ def test_sincos_pos_embed(self):

self.assertEqual(net.position_embeddings.requires_grad, False)

def test_fourier_pos_embed(self):
net = PatchEmbeddingBlock(
in_channels=1,
img_size=(32, 32, 32),
patch_size=(8, 8, 8),
hidden_size=96,
num_heads=8,
pos_embed_type="fourier",
dropout_rate=0.5,
)

self.assertEqual(net.position_embeddings.requires_grad, False)

def test_learnable_pos_embed(self):
net = PatchEmbeddingBlock(
in_channels=1,
Expand All @@ -101,6 +114,31 @@ def test_learnable_pos_embed(self):
self.assertEqual(net.position_embeddings.requires_grad, True)

def test_ill_arg(self):
with self.assertRaises(ValueError):
PatchEmbeddingBlock(
in_channels=1,
img_size=(128, 128, 128),
patch_size=(16, 16, 16),
hidden_size=128,
num_heads=12,
proj_type="conv",
dropout_rate=5.0,
pos_embed_type="fourier",
pos_embed_kwargs=dict(scales=[1.0, 1.0]),
)

PatchEmbeddingBlock(
in_channels=1,
img_size=(128, 128),
patch_size=(16, 16),
hidden_size=128,
num_heads=12,
proj_type="conv",
dropout_rate=5.0,
pos_embed_type="fourier",
pos_embed_kwargs=dict(scales=[1.0, 1.0, 1.0]),
)

with self.assertRaises(ValueError):
PatchEmbeddingBlock(
in_channels=1,
Expand Down
Loading