
Commit 9274e6f

Cache the padding by pre-padding before the random crop
1 parent 1e6f1bf · commit 9274e6f

2 files changed (+3, -5 lines)

utils/trainer.py

Lines changed: 1 addition & 1 deletion
@@ -5,7 +5,7 @@
 
 import torch
 from timed_decorator.simple_timed import timed
-from torch import GradScaler, Tensor, nn
+from torch import GradScaler, Tensor
 from torch.backends import cudnn
 from torch.nn.utils import clip_grad_norm_
 from torch.utils.tensorboard import SummaryWriter

utils/transforms.py

Lines changed: 2 additions & 4 deletions
@@ -39,7 +39,6 @@ def forward(self, x: Tensor) -> Tensor:
 class StepCompose(v2.Compose):
     def __init__(self, transforms: Sequence[Callable]):
         super().__init__(transforms)
-        self.need_step = [x for x in self.transforms if hasattr(x, "step")]
 
     def init(self, x: Tensor) -> Tensor:
         for transform in self.transforms:
@@ -126,14 +125,13 @@ def train_cached(self):
             [
                 v2.ToImage(),
                 v2.ToDtype(torch.float32, scale=True),
+                v2.Pad(padding=4, fill=0 if self.args.fill is None else self.args.fill)
             ]
         )
 
     def train_runtime(self):
         transforms = [
-            v2.RandomCrop(
-                32, padding=4, fill=0 if self.args.fill is None else self.args.fill
-            ),
+            v2.RandomCrop(32),
         ]
 
         if self.args.autoaug:
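
Net effect of the two hunks: the constant 4-pixel padding moves from the per-iteration RandomCrop into the cached (one-time) pipeline, so each image is padded once and only the cheap 32x32 crop runs every iteration. A minimal sketch of the idea, assuming 32x32 (CIFAR-style) inputs and the default fill of 0; the StepCompose/cache wiring from the repo is omitted:

import torch
from torchvision.transforms import v2

# One-time, cacheable step: convert and pad each 32x32 image to 40x40 once.
cached_pipeline = v2.Compose(
    [
        v2.ToImage(),
        v2.ToDtype(torch.float32, scale=True),
        v2.Pad(padding=4, fill=0),
    ]
)

# Per-iteration step: only the random 32x32 crop runs on every epoch.
runtime_crop = v2.RandomCrop(32)

img = torch.randint(0, 256, (3, 32, 32), dtype=torch.uint8)
padded = cached_pipeline(img)             # computed once, result can be stored in the cache
assert tuple(padded.shape[-2:]) == (40, 40)
crop = runtime_crop(padded)               # cheap per-epoch augmentation
assert tuple(crop.shape[-2:]) == (32, 32)

# The crop samples the same positions as the previous runtime-only transform,
# v2.RandomCrop(32, padding=4, fill=0), but without re-padding on every iteration.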
