diff --git a/models/model_clam.py b/models/model_clam.py index 5c727e70..0c0b3e1e 100755 --- a/models/model_clam.py +++ b/models/model_clam.py @@ -82,9 +82,9 @@ def __init__(self, gate = True, size_arg = "small", dropout = 0., k_sample=8, n_ size = self.size_dict[size_arg] fc = [nn.Linear(size[0], size[1]), nn.ReLU(), nn.Dropout(dropout)] if gate: - attention_net = Attn_Net_Gated(L = size[1], D = size[2], dropout = dropout, n_classes = 1) + attention_net = Attn_Net_Gated(L = size[1], D = size[2], dropout = dropout, n_classes = n_classes) else: - attention_net = Attn_Net(L = size[1], D = size[2], dropout = dropout, n_classes = 1) + attention_net = Attn_Net(L = size[1], D = size[2], dropout = dropout, n_classes = n_classes) fc.append(attention_net) self.attention_net = nn.Sequential(*fc) self.classifiers = nn.Linear(size[1], n_classes)