diff --git a/criteria/moco_loss.py b/criteria/moco_loss.py index 8fb13fb..515b056 100644 --- a/criteria/moco_loss.py +++ b/criteria/moco_loss.py @@ -43,7 +43,9 @@ def extract_feats(self, x): x = F.interpolate(x, size=224) x_feats = self.model(x) x_feats = nn.functional.normalize(x_feats, dim=1) - x_feats = x_feats.squeeze() + # x_feats = x_feats.squeeze() + bs, feat_dim, _, _ = x_feats + x_feats = x_feats.reshape((bs, feat_dum)) return x_feats def forward(self, y_hat, y, x):