diff --git a/difflogic/functional.py b/difflogic/functional.py index 1fd181c..ecce76f 100644 --- a/difflogic/functional.py +++ b/difflogic/functional.py @@ -78,6 +78,9 @@ def get_unique_connections(in_dim, out_dim, device='cuda'): '({}) because otherwise not all inputs could be used or considered.'.format( out_dim, in_dim ) + n_max = int(in_dim * (in_dim - 1) / 2) + assert out_dim <= n_max, 'The number of neurons ({}) must not be greater than the number of pair-wise combinations ' \ + 'of the inputs ({})'.format(out_dim, n_max) x = torch.arange(in_dim).long().unsqueeze(0) @@ -100,7 +103,7 @@ def get_unique_connections(in_dim, out_dim, device='cuda'): # If this was not enough, take pairs with offsets >= 2: offset = 2 - while out_dim > a.shape[-1] > offset: + while out_dim > a.shape[-1]: a_, b_ = x[..., :-offset], x[..., offset:] a = torch.cat([a, a_], dim=-1) b = torch.cat([b, b_], dim=-1)