-
Notifications
You must be signed in to change notification settings - Fork 12
Open
Description
There is only a Conv1d right now, no Conv2d.
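I tried using a pixel shuffle (rearranging each patch into the channel dimension) followed by a linear layer as a replacement; however, that does not work either: the output and the weight gradient are unit-scaled, but the gradient with respect to the input is not (std ≈ 18 instead of ≈ 1, see the output below).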
from torch import nn
import torch
from typing import Tuple
from unit_scaling.utils import analyse_module
import unit_scaling as uu
class PatchEmbed(nn.Module):
    """Embed non-overlapping image patches with a single linear projection.

    The input image is cut into ``patch_h x patch_w`` patches, every patch is
    flattened (channel-major, then patch row, then patch column) and the
    resulting vectors are projected to ``embed_dim`` by ``uu.Linear``.

    Args:
        in_channels: Number of channels ``C`` of the input image.
        embed_dim: Size ``E`` of each output token.
        patch_size: ``(patch_h, patch_w)``; a single ``int`` is accepted as
            shorthand for a square patch.
    """

    def __init__(self, in_channels, embed_dim, patch_size):
        super().__init__()
        if isinstance(patch_size, int):
            # Convenience: ``patch_size=4`` means a 4x4 patch.
            patch_size = (patch_size, patch_size)
        self.patch_size = patch_size
        patch_h, patch_w = patch_size[0], patch_size[1]
        self.linear = uu.Linear(in_channels * patch_h * patch_w, embed_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Turn a ``(B, C, H, W)`` image batch into ``(B, H'*W', E)`` tokens.

        ``H' = H // patch_h`` and ``W' = W // patch_w``. ``H`` and ``W`` have
        to be exact multiples of the patch size, otherwise the first ``view``
        fails because the element counts do not match.
        """
        # NOTE(review): the annotated output of ``analyse_module`` looks like
        # a symbolic trace of this method -- keep it free of branching on
        # shape values so that it stays traceable (TODO confirm).
        patch_h, patch_w = self.patch_size[0], self.patch_size[1]
        batch, channels, height, width = x.shape

        # (B, C, H', patch_h, W', patch_w): split both spatial axes.
        patches = x.view(
            batch, channels, height // patch_h, patch_h, width // patch_w, patch_w
        )
        # (B, C, patch_h, patch_w, H', W'): move the in-patch axes next to C.
        patches = patches.permute(0, 1, 3, 5, 2, 4).contiguous()
        # (B, C*patch_h*patch_w, H'*W'): one column per patch.
        patches = patches.view(batch, channels * patch_h * patch_w, -1)
        # (B, H'*W', C*patch_h*patch_w): one row (token) per patch.
        tokens = patches.transpose(1, 2)
        return self.linear(tokens)  # (B, H'*W', E)
def show_layer_stats(layer: nn.Module, input_shape: Tuple[int, ...]) -> None:
    """Print the standard deviation of a layer's activations and gradients.

    A standard-normal input of shape ``input_shape`` is pushed through
    ``layer`` and a standard-normal gradient is back-propagated from the
    output. The std of the output, of the input gradient, of every named
    parameter and of every parameter gradient is then printed, one per line.

    Gradients that are ``None`` after the backward pass (frozen or unused
    parameters, or an input the output does not depend on) are skipped
    instead of raising ``AttributeError``.

    Args:
        layer: Module to probe; it is called with a single tensor argument.
        input_shape: Shape of the random input tensor.
    """
    # Not called ``input`` so the builtin of that name is not shadowed.
    sample = torch.randn(*input_shape, requires_grad=True)
    output = layer(sample)
    output.backward(torch.randn_like(output))

    named_params = list(layer.named_parameters())

    # Insertion order defines the print order: output, input gradient,
    # parameters, parameter gradients.
    stats = {"output": output.std()}
    if sample.grad is not None:
        stats["input.grad"] = sample.grad.std()
    stats.update((name, param.std()) for name, param in named_params)
    stats.update(
        (f"{name}.grad", param.grad.std())
        for name, param in named_params
        if param.grad is not None
    )

    print(f"# {type(layer).__name__}:")
    for key, value in stats.items():
        print(f"{key:>20}.std = {value.item():.2f}")
# Reproduction driver: a batch of 256 RGB 32x32 images and a 1x1 patch size,
# i.e. 32*32 = 1024 tokens per image, each embedded into 1024 dimensions.
image_shape = (256, 3, 32, 32)

# Random input (tracked for gradients) and a random gradient for the output.
x = torch.randn(*image_shape, requires_grad=True)
bwd = torch.randn(256, 1024, 1024)

# Print the annotated forward code returned by analyse_module.
annotated_code = analyse_module(PatchEmbed(3, 1024, (1, 1)), x, bwd)
print(annotated_code)

# Cross-check with a freshly initialised layer and plain autograd.
show_layer_stats(PatchEmbed(3, 1024, (1, 1)), image_shape)
Output:
def forward(self, x): (-> 1.0, <- 18.3)
getattr_1 = x.shape
getitem = getattr_1[0]
getitem_1 = getattr_1[1]
getitem_2 = getattr_1[2]
getitem_3 = getattr_1[3]
floordiv = getitem_2 // 1
floordiv_1 = getitem_3 // 1
view = x.view(getitem, getitem_1, floordiv, 1, floordiv_1, 1); (-> 1.0, <- 18.3)
permute = view.permute(0, 1, 3, 5, 2, 4); (-> 1.0, <- 18.3)
contiguous = permute.contiguous(); (-> 1.0, <- 18.3)
mul = getitem_1 * 1
mul_1 = mul * 1
view_1 = contiguous.view(getitem, mul_1, -1); (-> 1.0, <- 18.3)
transpose = view_1.transpose(1, 2); (-> 1.0, <- 18.3)
linear_weight = self.linear.weight; (-> 0.988, <- 1.0)
linear = U.linear(transpose, linear_weight, None, 'to_output_scale'); (-> 0.988, <- 1.0)
return linear
# PatchEmbed:
output.std = 1.01
input.grad.std = 18.59
linear.weight.std = 1.01
linear.weight.grad.std = 1.01
Metadata
Metadata
Assignees
Labels
No labels