diff --git a/torchdiffeq/_impl/rk_common.py b/torchdiffeq/_impl/rk_common.py index 49cb2fe78..3e8fd7c63 100644 --- a/torchdiffeq/_impl/rk_common.py +++ b/torchdiffeq/_impl/rk_common.py @@ -230,7 +230,7 @@ def _adaptive_step(self, rk_state): def _interp_fit(self, y0, y1, k, dt): """Fit an interpolating polynomial to the results of a Runge-Kutta step.""" dt = dt.type_as(y0) - y_mid = y0 + k.matmul(dt * self.mid).view_as(y0) + y_mid = y0 + k.matmul(dt * self.mid) f0 = k[..., 0] f1 = k[..., -1] return _interp_fit(y0, y1, y_mid, f0, f1, dt)