Skip to content

Commit 6b04b6d

Browse files
committed
Fix inplace operation in DCRNN fully-connected gates (#59)
* Remove redundant sigmoid call * Rename internal variable inputs_and_state * Remove inplace operation * Remove NotImplementedError * Simplify reset and update gate size
1 parent 3e6f282 commit 6b04b6d

File tree

1 file changed

+7
-17
lines changed

1 file changed

+7
-17
lines changed

torchts/nn/graph.py

Lines changed: 7 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -43,15 +43,9 @@ def __init__(
4343
self.register_buffer("_supports", supports)
4444

4545
num_matrices = len(supports) * self._max_diffusion_step + 1
46-
input_size_gconv = (self._num_units + input_dim) * num_matrices
47-
48-
if self._use_gc_for_ru:
49-
input_size_ru = input_size_gconv
50-
else:
51-
input_size_ru = self._num_units + input_dim
52-
raise NotImplementedError(
53-
"Fully-connected reset and update gates not yet implemented"
54-
)
46+
input_size_fc = self._num_units + input_dim
47+
input_size_gconv = input_size_fc * num_matrices
48+
input_size_ru = input_size_gconv if self._use_gc_for_ru else input_size_fc
5549

5650
output_size = 2 * self._num_units
5751
self._ru_weights = nn.Parameter(torch.empty(input_size_ru, output_size))
@@ -85,22 +79,18 @@ def _fc(self, inputs, state, output_size, bias_start=0.0, reset=True):
8579
shape = (batch_size * self._num_nodes, -1)
8680
inputs = torch.reshape(inputs, shape)
8781
state = torch.reshape(state, shape)
88-
inputs_and_state = torch.cat([inputs, state], dim=-1)
89-
90-
value = torch.sigmoid(torch.matmul(inputs_and_state, self._ru_weights))
91-
value += self._ru_biases
82+
x = torch.cat([inputs, state], dim=-1)
9283

93-
return value
84+
return torch.matmul(x, self._ru_weights) + self._ru_biases
9485

9586
def _gconv(self, inputs, state, output_size, bias_start=0.0, reset=False):
9687
batch_size = inputs.shape[0]
9788
shape = (batch_size, self._num_nodes, -1)
9889
inputs = torch.reshape(inputs, shape)
9990
state = torch.reshape(state, shape)
100-
inputs_and_state = torch.cat([inputs, state], dim=2)
101-
input_size = inputs_and_state.size(2)
91+
x = torch.cat([inputs, state], dim=2)
92+
input_size = x.size(2)
10293

103-
x = inputs_and_state
10494
x0 = x.permute(1, 2, 0)
10595
x0 = torch.reshape(x0, shape=[self._num_nodes, input_size * batch_size])
10696
x = torch.unsqueeze(x0, 0)

0 commit comments

Comments
 (0)