Add uu.Conv2d #87

@SwayStar123

Description

There is only a Conv1d right now, no Conv2d.
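
For reference, here is a rough sketch of the kind of op I mean, following the fan-in / fan-out / batch-size scaling recipe from the unit-scaling paper. This is only my own sketch, not the library's implementation: _ScaleGrad and UnitConv2d are made-up names, and the backward fans ignore stride/padding edge effects.

import torch
import torch.nn.functional as F
from torch import nn, Tensor

class _ScaleGrad(torch.autograd.Function):
    # Identity in the forward pass; multiplies the gradient by `scale` in the backward pass.
    @staticmethod
    def forward(ctx, x: Tensor, scale: float) -> Tensor:
        ctx.scale = scale
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad: Tensor):
        return grad * ctx.scale, None

class UnitConv2d(nn.Module):
    # Hypothetical unit-scaled Conv2d: unit-variance weight init, with separate
    # explicit scales for the output, the input gradient and the weight gradient.
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0):
        super().__init__()
        kh, kw = (kernel_size, kernel_size) if isinstance(kernel_size, int) else kernel_size
        self.stride, self.padding = stride, padding
        self.weight = nn.Parameter(torch.randn(out_channels, in_channels, kh, kw))
        self.fan_in = in_channels * kh * kw    # contributions per output element
        self.fan_out = out_channels * kh * kw  # contributions per input-gradient element

    def forward(self, x: Tensor) -> Tensor:
        B, _, H, W = x.shape
        kh, kw = self.weight.shape[2:]
        h_out = (H + 2 * self.padding - kh) // self.stride + 1
        w_out = (W + 2 * self.padding - kw) // self.stride + 1
        n_pos = B * h_out * w_out  # "batch size" seen by the weight gradient
        alpha = self.fan_in ** -0.5  # forward scale; by the chain rule it also multiplies both gradients
        # Pre-scale the gradients of x and w so that, combined with alpha,
        # grad-input and grad-weight also come out at unit scale.
        x = _ScaleGrad.apply(x, (self.fan_in / self.fan_out) ** 0.5)
        w = _ScaleGrad.apply(self.weight, (self.fan_in / n_pos) ** 0.5)
        return alpha * F.conv2d(x, w, stride=self.stride, padding=self.padding)

With unit-variance inputs and gradients this should give output, input.grad and weight.grad all close to std 1, which is what I'd hope a real uu.Conv2d would guarantee.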

As a workaround I tried a pixel-shuffle-style patchify followed by a linear layer, but that does not give unit scale either (the input gradient comes out far from std 1):

from torch import nn
import torch
from typing import Tuple
from unit_scaling.utils import analyse_module
import unit_scaling as uu

class PatchEmbed(nn.Module):
    def __init__(self, in_channels, embed_dim, patch_size):
        super().__init__()
        self.patch_size = patch_size
        self.linear = uu.Linear(in_channels * patch_size[0] * patch_size[1], embed_dim)
        
    def forward(self, x):
        # Patchify: (B, C, H, W) -> (B, H'*W', C*patch_h*patch_w), where H' = H//patch_h, W' = W//patch_w
        B, C, H, W = x.shape
        x = x.view(B, C, H//self.patch_size[0], self.patch_size[0], W//self.patch_size[1], self.patch_size[1])
        x = x.permute(0, 1, 3, 5, 2, 4).contiguous()  # (B, C, patch_h, patch_w, H', W')
        x = x.view(B, C*self.patch_size[0]*self.patch_size[1], -1)  # (B, C*patch_h*patch_w, H'*W')
        x = x.transpose(1, 2)  # (B, H'*W', C*patch_h*patch_w)
        x = self.linear(x)  # (B, H'*W', E)
        return x

def show_layer_stats(layer: nn.Module, input_shape: Tuple[int, ...]) -> None:
    input = torch.randn(*input_shape, requires_grad=True)
    output = layer(input)
    output.backward(torch.randn_like(output))
    print(f"# {type(layer).__name__}:")
    for k, v in {
        "output": output.std(),
        "input.grad": input.grad.std(),
        **{f"{name}": param.std() for name, param in layer.named_parameters()},
        **{f"{name}.grad": param.grad.std() for name, param in layer.named_parameters()},
    }.items():
        print(f"{k:>20}.std = {v.item():.2f}")

x = torch.randn(256, 3, 32, 32).requires_grad_()
bwd = torch.randn(256, 1024, 1024)

annotated_code = analyse_module(PatchEmbed(3, 1024, (1, 1)), x, bwd)
print(annotated_code)

show_layer_stats(PatchEmbed(3, 1024, (1, 1)), (256, 3, 32, 32))

Output:

def forward(self, x):  (-> 1.0, <- 18.3)
    getattr_1 = x.shape
    getitem = getattr_1[0]
    getitem_1 = getattr_1[1]
    getitem_2 = getattr_1[2]
    getitem_3 = getattr_1[3]
    floordiv = getitem_2 // 1
    floordiv_1 = getitem_3 // 1
    view = x.view(getitem, getitem_1, floordiv, 1, floordiv_1, 1);  (-> 1.0, <- 18.3)
    permute = view.permute(0, 1, 3, 5, 2, 4);  (-> 1.0, <- 18.3)
    contiguous = permute.contiguous();  (-> 1.0, <- 18.3)
    mul = getitem_1 * 1
    mul_1 = mul * 1
    view_1 = contiguous.view(getitem, mul_1, -1);  (-> 1.0, <- 18.3)
    transpose = view_1.transpose(1, 2);  (-> 1.0, <- 18.3)
    linear_weight = self.linear.weight;  (-> 0.988, <- 1.0)
    linear = U.linear(transpose, linear_weight, None, 'to_output_scale');  (-> 0.988, <- 1.0)
    return linear

# PatchEmbed:
              output.std = 1.01
          input.grad.std = 18.59
       linear.weight.std = 1.01
  linear.weight.grad.std = 1.01
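
If I'm reading the trace right, the 18.59 is the constraint at work rather than a bug in my patchify: U.linear is called with constraint='to_output_scale', so the backward pass reuses the forward scale 1/sqrt(fan_in) instead of the ideal 1/sqrt(fan_out) for the input gradient, leaving it off by a factor of sqrt(fan_out/fan_in):

fan_in = 3 * 1 * 1  # in_channels * patch_h * patch_w
fan_out = 1024      # embed_dim
print((fan_out / fan_in) ** 0.5)  # 18.475..., matching input.grad.std above

Perhaps the constraint could be relaxed through uu.Linear's arguments, but a built-in uu.Conv2d would avoid hand-rolling the patchify altogether.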
