48 changes: 48 additions & 0 deletions torchstat/compute_flops.py
@@ -6,10 +6,16 @@
def compute_flops(module, inp, out):
if isinstance(module, nn.Conv2d):
return compute_Conv2d_flops(module, inp, out)
elif isinstance(module, nn.Conv3d):
return compute_Conv3d_flops(module, inp, out)
elif isinstance(module, nn.BatchNorm2d):
return compute_BatchNorm2d_flops(module, inp, out)
elif isinstance(module, nn.BatchNorm3d):
return compute_BatchNorm3d_flops(module, inp, out)
elif isinstance(module, (nn.AvgPool2d, nn.MaxPool2d)):
return compute_Pool2d_flops(module, inp, out)
elif isinstance(module, (nn.AvgPool3d, nn.MaxPool3d)):
return compute_Pool3d_flops(module, inp, out)
elif isinstance(module, (nn.ReLU, nn.ReLU6, nn.PReLU, nn.ELU, nn.LeakyReLU)):
return compute_ReLU_flops(module, inp, out)
elif isinstance(module, nn.Upsample):
@@ -47,6 +53,31 @@ def compute_Conv2d_flops(module, inp, out):
return total_flops


def compute_Conv3d_flops(module, inp, out):
# inp is the (single) input tensor captured by the forward hook
assert isinstance(module, nn.Conv3d)
assert len(inp.size()) == 5 and len(inp.size()) == len(out.size())

batch_size = inp.size()[0]
in_c = inp.size()[1]
k_t, k_h, k_w = module.kernel_size
out_c, out_t, out_h, out_w = out.size()[1:]
groups = module.groups

filters_per_channel = out_c // groups
conv_per_position_flops = k_t * k_h * k_w * in_c * filters_per_channel
active_elements_count = batch_size * out_t * out_h * out_w

total_conv_flops = conv_per_position_flops * active_elements_count

bias_flops = 0
if module.bias is not None:
bias_flops = out_c * active_elements_count

total_flops = total_conv_flops + bias_flops
return total_flops
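
A quick sanity check of the formula (not part of the patch; the shapes are illustrative, and the import assumes this branch is installed):

import torch
import torch.nn as nn
from torchstat.compute_flops import compute_Conv3d_flops

m = nn.Conv3d(4, 8, kernel_size=3, padding=1)      # bias=True by default
x = torch.randn(2, 4, 8, 16, 16)                   # (N, C, T, H, W)
y = m(x)                                           # (2, 8, 8, 16, 16)
per_position = 3 * 3 * 3 * 4 * (8 // m.groups)     # kernel taps x in_c x filters per group
positions = 2 * 8 * 16 * 16                        # batch x out_t x out_h x out_w
assert compute_Conv3d_flops(m, x, y) == per_position * positions + 8 * positions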


def compute_BatchNorm2d_flops(module, inp, out):
assert isinstance(module, nn.BatchNorm2d)
assert len(inp.size()) == 4 and len(inp.size()) == len(out.size())
@@ -57,6 +88,16 @@ def compute_BatchNorm2d_flops(module, inp, out):
return batch_flops


def compute_BatchNorm3d_flops(module, inp, out):
assert isinstance(module, nn.BatchNorm3d)
assert len(inp.size()) == 5 and len(inp.size()) == len(out.size())
in_c, in_t, in_h, in_w = inp.size()[1:]
batch_flops = np.prod(inp.shape)
if module.affine:
batch_flops *= 2
return batch_flops
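
Correspondingly, an affine BatchNorm3d is counted at two FLOPs per input element (illustrative check, same assumptions as above):

import torch
import torch.nn as nn
from torchstat.compute_flops import compute_BatchNorm3d_flops

bn = nn.BatchNorm3d(8)                             # affine=True by default
x = torch.randn(2, 8, 4, 16, 16)
assert compute_BatchNorm3d_flops(bn, x, bn(x)) == 2 * x.numel()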


def compute_ReLU_flops(module, inp, out):
assert isinstance(module, (nn.ReLU, nn.ReLU6, nn.PReLU, nn.ELU, nn.LeakyReLU))
batch_size = inp.size()[0]
@@ -74,12 +115,19 @@ def compute_Pool2d_flops(module, inp, out):
return np.prod(inp.shape)


def compute_Pool3d_flops(module, inp, out):
assert isinstance(module, (nn.MaxPool3d, nn.AvgPool3d))
assert len(inp.size()) == 5 and len(inp.size()) == len(out.size())
return np.prod(inp.shape)


def compute_Linear_flops(module, inp, out):
assert isinstance(module, nn.Linear)
assert len(inp.size()) == 2 and len(out.size()) == 2
batch_size = inp.size()[0]
return batch_size * inp.size()[1] * out.size()[1]


def compute_Upsample_flops(module, inp, out):
assert isinstance(module, nn.Upsample)
output_size = out[0]
95 changes: 95 additions & 0 deletions torchstat/compute_madd.py
@@ -27,6 +27,28 @@ def compute_Conv2d_madd(module, inp, out):
return total_mul + total_add


def compute_Conv3d_madd(module, inp, out):
assert isinstance(module, nn.Conv3d)
assert len(inp.size()) == 5 and len(inp.size()) == len(out.size())

in_c = inp.size()[1]
k_t, k_h, k_w = module.kernel_size
out_c, out_t, out_h, out_w = out.size()[1:]
groups = module.groups

# ops per output element
kernel_mul = k_t * k_h * k_w * (in_c // groups)
kernel_add = kernel_mul - 1 + (0 if module.bias is None else 1)

kernel_mul_group = kernel_mul * out_t * out_h * out_w * (out_c // groups)
kernel_add_group = kernel_add * out_t * out_h * out_w * (out_c // groups)

total_mul = kernel_mul_group * groups
total_add = kernel_add_group * groups

return total_mul + total_add
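
A hand check of the count (illustrative; note that, as in the 2D version, MAdds are counted per sample, without the batch dimension):

import torch
import torch.nn as nn
from torchstat.compute_madd import compute_Conv3d_madd

m = nn.Conv3d(4, 8, kernel_size=3, padding=1)
x = torch.randn(1, 4, 8, 16, 16)
y = m(x)                                           # (1, 8, 8, 16, 16)
muls = 3 * 3 * 3 * 4                               # 108 multiplies per output element
adds = muls - 1 + 1                                # 107 accumulations + 1 bias add
assert compute_Conv3d_madd(m, x, y) == (muls + adds) * (8 * 16 * 16) * 8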


def compute_ConvTranspose2d_madd(module, inp, out):
assert isinstance(module, nn.ConvTranspose2d)
assert len(inp.size()) == 4 and len(inp.size()) == len(out.size())
@@ -48,6 +70,27 @@ def compute_ConvTranspose2d_madd(module, inp, out):
return total_mul + total_add


def compute_ConvTranspose3d_madd(module, inp, out):
assert isinstance(module, nn.ConvTranspose3d)
assert len(inp.size()) == 5 and len(inp.size()) == len(out.size())

in_c, in_t, in_h, in_w = inp.size()[1:]
k_t, k_h, k_w = module.kernel_size
out_c, out_t, out_h, out_w = out.size()[1:]
groups = module.groups

kernel_mul = k_t * k_h * k_w * (in_c // groups)
kernel_add = kernel_mul - 1 + (0 if module.bias is None else 1)

kernel_mul_group = kernel_mul * in_t * in_h * in_w * (out_c // groups)
kernel_add_group = kernel_add * in_t * in_h * in_w * (out_c // groups)

total_mul = kernel_mul_group * groups
total_add = kernel_add_group * groups

return total_mul + total_add
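
As in the 2D counterpart, the transposed-conv count is taken per input position rather than per output position (illustrative check, same assumptions as above):

import torch
import torch.nn as nn
from torchstat.compute_madd import compute_ConvTranspose3d_madd

m = nn.ConvTranspose3d(4, 8, kernel_size=2, stride=2)
x = torch.randn(1, 4, 4, 8, 8)
y = m(x)                                           # (1, 8, 8, 16, 16)
muls = 2 * 2 * 2 * 4                               # 32 multiplies per input position
adds = muls - 1 + 1                                # 31 accumulations + 1 bias add
assert compute_ConvTranspose3d_madd(m, x, y) == (muls + adds) * (4 * 8 * 8) * 8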


def compute_BatchNorm2d_madd(module, inp, out):
assert isinstance(module, nn.BatchNorm2d)
assert len(inp.size()) == 4 and len(inp.size()) == len(out.size())
@@ -61,6 +104,19 @@ def compute_BatchNorm2d_madd(module, inp, out):
return 4 * in_c * in_h * in_w


def compute_BatchNorm3d_madd(module, inp, out):
assert isinstance(module, nn.BatchNorm3d)
assert len(inp.size()) == 5 and len(inp.size()) == len(out.size())

in_c, in_t, in_h, in_w = inp.size()[1:]

# 1. sub mean
# 2. div standard deviation
# 3. mul alpha
# 4. add beta
return 4 * in_c * in_t * in_h * in_w


def compute_MaxPool2d_madd(module, inp, out):
assert isinstance(module, nn.MaxPool2d)
assert len(inp.size()) == 4 and len(inp.size()) == len(out.size())
@@ -74,6 +130,19 @@ def compute_MaxPool2d_madd(module, inp, out):
return (k_h * k_w - 1) * out_h * out_w * out_c


def compute_MaxPool3d_madd(module, inp, out):
assert isinstance(module, nn.MaxPool3d)
assert len(inp.size()) == 5 and len(inp.size()) == len(out.size())

if isinstance(module.kernel_size, (tuple, list)):
k_t, k_h, k_w = module.kernel_size
else:
k_t, k_h, k_w = module.kernel_size, module.kernel_size, module.kernel_size
out_c, out_t, out_h, out_w = out.size()[1:]

return (k_t * k_h * k_w - 1) * out_t * out_h * out_w * out_c


def compute_AvgPool2d_madd(module, inp, out):
assert isinstance(module, nn.AvgPool2d)
assert len(inp.size()) == 4 and len(inp.size()) == len(out.size())
@@ -90,6 +159,22 @@ def compute_AvgPool2d_madd(module, inp, out):
return (kernel_add + kernel_avg) * (out_h * out_w) * out_c


def compute_AvgPool3d_madd(module, inp, out):
assert isinstance(module, nn.AvgPool3d)
assert len(inp.size()) == 5 and len(inp.size()) == len(out.size())

if isinstance(module.kernel_size, (tuple, list)):
k_t, k_h, k_w = module.kernel_size
else:
k_t, k_h, k_w = module.kernel_size, module.kernel_size, module.kernel_size
out_c, out_t, out_h, out_w = out.size()[1:]

kernel_add = k_t * k_h * k_w - 1
kernel_avg = 1

return (kernel_add + kernel_avg) * (out_t * out_h * out_w) * out_c
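
For pooling, max pooling counts k_t * k_h * k_w - 1 comparisons per output element and average pooling counts the same number of adds plus one divide (illustrative check):

import torch
import torch.nn as nn
from torchstat.compute_madd import compute_AvgPool3d_madd

p = nn.AvgPool3d(kernel_size=2)
x = torch.randn(1, 8, 8, 16, 16)
y = p(x)                                           # (1, 8, 4, 8, 8)
assert compute_AvgPool3d_madd(p, x, y) == (7 + 1) * (4 * 8 * 8) * 8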


def compute_ReLU_madd(module, inp, out):
assert isinstance(module, (nn.ReLU, nn.ReLU6))

@@ -140,14 +225,24 @@ def compute_Bilinear_madd(module, inp1, inp2, out):
def compute_madd(module, inp, out):
if isinstance(module, nn.Conv2d):
return compute_Conv2d_madd(module, inp, out)
elif isinstance(module, nn.Conv3d):
return compute_Conv3d_madd(module, inp, out)
elif isinstance(module, nn.ConvTranspose2d):
return compute_ConvTranspose2d_madd(module, inp, out)
elif isinstance(module, nn.ConvTranspose3d):
return compute_ConvTranspose3d_madd(module, inp, out)
elif isinstance(module, nn.BatchNorm2d):
return compute_BatchNorm2d_madd(module, inp, out)
elif isinstance(module, nn.BatchNorm3d):
return compute_BatchNorm3d_madd(module, inp, out)
elif isinstance(module, nn.MaxPool2d):
return compute_MaxPool2d_madd(module, inp, out)
elif isinstance(module, nn.MaxPool3d):
return compute_MaxPool3d_madd(module, inp, out)
elif isinstance(module, nn.AvgPool2d):
return compute_AvgPool2d_madd(module, inp, out)
elif isinstance(module, nn.AvgPool3d):
return compute_AvgPool3d_madd(module, inp, out)
elif isinstance(module, (nn.ReLU, nn.ReLU6)):
return compute_ReLU_madd(module, inp, out)
elif isinstance(module, nn.Softmax):
40 changes: 40 additions & 0 deletions torchstat/compute_memory.py
@@ -10,12 +10,18 @@ def compute_memory(module, inp, out):
return compute_PReLU_memory(module, inp, out)
elif isinstance(module, nn.Conv2d):
return compute_Conv2d_memory(module, inp, out)
elif isinstance(module, nn.Conv3d):
return compute_Conv3d_memory(module, inp, out)
elif isinstance(module, nn.BatchNorm2d):
return compute_BatchNorm2d_memory(module, inp, out)
elif isinstance(module, nn.BatchNorm3d):
return compute_BatchNorm3d_memory(module, inp, out)
elif isinstance(module, nn.Linear):
return compute_Linear_memory(module, inp, out)
elif isinstance(module, (nn.AvgPool2d, nn.MaxPool2d)):
return compute_Pool2d_memory(module, inp, out)
elif isinstance(module, (nn.AvgPool3d, nn.MaxPool3d)):
return compute_Pool3d_memory(module, inp, out)
else:
print(f"[Memory]: {type(module).__name__} is not supported!")
return (0, 0)
@@ -59,6 +65,21 @@ def compute_Conv2d_memory(module, inp, out):
return (mread, mwrite)


def compute_Conv3d_memory(module, inp, out):
# inp is the (single) input tensor captured by the forward hook
assert isinstance(module, nn.Conv3d)
assert len(inp.size()) == 5 and len(inp.size()) == len(out.size())

batch_size = inp.size()[0]
in_c = inp.size()[1]
out_c, out_t, out_h, out_w = out.size()[1:]

# Reads include all parameters (weights, plus bias if the module has one).
mread = batch_size * (inp.size()[1:].numel() + num_params(module))
mwrite = batch_size * out_t * out_c * out_h * out_w
return (mread, mwrite)
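
Illustrative check: reads cover the input plus every parameter, writes cover the output, both in element counts (same assumptions as above):

import torch
import torch.nn as nn
from torchstat.compute_memory import compute_Conv3d_memory

m = nn.Conv3d(4, 8, kernel_size=3, padding=1)
x = torch.randn(2, 4, 8, 16, 16)
y = m(x)                                           # (2, 8, 8, 16, 16)
params = 8 * 4 * 3 * 3 * 3 + 8                     # weights + bias
assert compute_Conv3d_memory(m, x, y) == (2 * (4 * 8 * 16 * 16 + params),
                                          2 * (8 * 8 * 16 * 16))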


def compute_BatchNorm2d_memory(module, inp, out):
assert isinstance(module, nn.BatchNorm2d)
assert len(inp.size()) == 4 and len(inp.size()) == len(out.size())
@@ -69,6 +90,16 @@ def compute_BatchNorm2d_memory(module, inp, out):
return (mread, mwrite)


def compute_BatchNorm3d_memory(module, inp, out):
assert isinstance(module, nn.BatchNorm3d)
assert len(inp.size()) == 5 and len(inp.size()) == len(out.size())
batch_size, in_c, in_t, in_h, in_w = inp.size()

mread = batch_size * (inp.size()[1:].numel() + 2 * in_c)
mwrite = inp.size().numel()
return (mread, mwrite)


def compute_Linear_memory(module, inp, out):
assert isinstance(module, nn.Linear)
assert len(inp.size()) == 2 and len(out.size()) == 2
@@ -86,3 +117,12 @@ def compute_Pool2d_memory(module, inp, out):
mread = batch_size * inp.size()[1:].numel()
mwrite = batch_size * out.size()[1:].numel()
return (mread, mwrite)


def compute_Pool3d_memory(module, inp, out):
assert isinstance(module, (nn.MaxPool3d, nn.AvgPool3d))
assert len(inp.size()) == 5 and len(inp.size()) == len(out.size())
batch_size = inp.size()[0]
mread = batch_size * inp.size()[1:].numel()
mwrite = batch_size * out.size()[1:].numel()
return (mread, mwrite)
4 changes: 3 additions & 1 deletion torchstat/statistics.py
@@ -48,7 +48,9 @@ def convert_leaf_modules_to_stat_tree(leaf_modules):
class ModelStat(object):
def __init__(self, model, input_size, query_granularity=1):
assert isinstance(model, nn.Module)
assert isinstance(input_size, (tuple, list)) and len(input_size) == 3
assert isinstance(input_size, (tuple, list)) and len(input_size) in (3, 4)
self._model = model
self._input_size = input_size
self._query_granularity = query_granularity
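
With the relaxed assertion, a 4-element input_size (C, T, H, W) is accepted, so a 3D network can be profiled like this (illustrative; the model choice is arbitrary, and full coverage depends on the dispatch tables above):

import torchvision
from torchstat import stat

model = torchvision.models.video.r3d_18()
stat(model, (3, 16, 112, 112))                     # channels, frames, height, width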