From bd46688a86d42b14fca6eb2fd9781210eaf12069 Mon Sep 17 00:00:00 2001 From: Sam Lishak Date: Thu, 1 Sep 2022 18:06:13 +0100 Subject: [PATCH 1/2] Add ability for ODE function to manually reject an adaptive step --- torchdiffeq/__init__.py | 1 + torchdiffeq/_impl/__init__.py | 1 + torchdiffeq/_impl/rk_common.py | 49 +++++++++++++++++++--------------- 3 files changed, 30 insertions(+), 21 deletions(-) diff --git a/torchdiffeq/__init__.py b/torchdiffeq/__init__.py index 4eff75dbf..74712bacf 100644 --- a/torchdiffeq/__init__.py +++ b/torchdiffeq/__init__.py @@ -1,4 +1,5 @@ from ._impl import odeint from ._impl import odeint_adjoint from ._impl import odeint_event +from ._impl import RejectStepError __version__ = "0.2.3" diff --git a/torchdiffeq/_impl/__init__.py b/torchdiffeq/_impl/__init__.py index 05b671e9c..6cbcbace2 100644 --- a/torchdiffeq/_impl/__init__.py +++ b/torchdiffeq/_impl/__init__.py @@ -1,2 +1,3 @@ from .odeint import odeint, odeint_event from .adjoint import odeint_adjoint +from .rk_common import RejectStepError \ No newline at end of file diff --git a/torchdiffeq/_impl/rk_common.py b/torchdiffeq/_impl/rk_common.py index 4d85afb18..ca20664a5 100644 --- a/torchdiffeq/_impl/rk_common.py +++ b/torchdiffeq/_impl/rk_common.py @@ -38,6 +38,10 @@ def backward(ctx, grad_scratch): return grad_scratch, grad_scratch[ctx.index], None +class RejectStepError(Exception): + pass + + def _runge_kutta_step(func, y0, f0, t0, dt, t1, tableau): """Take an arbitrary Runge-Kutta step and estimate error. Args: @@ -262,28 +266,31 @@ def _adaptive_step(self, rk_state): # Must be arranged as doing all the step_t handling, then all the jump_t handling, in case we # trigger both. (i.e. interleaving them would be wrong.) - - y1, f1, y1_error, k = _runge_kutta_step(self.func, y0, f0, t0, dt, t1, tableau=self.tableau) - # dtypes: - # y1.dtype == self.y0.dtype - # f1.dtype == self.y0.dtype - # y1_error.dtype == self.dtype - # k.dtype == self.y0.dtype - - ######################################################## - # Error Ratio # - ######################################################## - error_ratio = _compute_error_ratio(y1_error, self.rtol, self.atol, y0, y1, self.norm) - accept_step = error_ratio <= 1 - - # Handle min max stepping - if dt > self.max_step: + try: + y1, f1, y1_error, k = _runge_kutta_step(self.func, y0, f0, t0, dt, t1, tableau=self.tableau) + # dtypes: + # y1.dtype == self.y0.dtype + # f1.dtype == self.y0.dtype + # y1_error.dtype == self.dtype + # k.dtype == self.y0.dtype + except RejectStepError: + # self.func requested the step be rejected accept_step = False - if dt <= self.min_step: - accept_step = True - - # dtypes: - # error_ratio.dtype == self.dtype + else: + ######################################################## + # Error Ratio # + ######################################################## + error_ratio = _compute_error_ratio(y1_error, self.rtol, self.atol, y0, y1, self.norm) + accept_step = error_ratio <= 1 + + # Handle min max stepping + if dt > self.max_step: + accept_step = False + if dt <= self.min_step: + accept_step = True + + # dtypes: + # error_ratio.dtype == self.dtype ######################################################## # Update RK State # From 9f6cf7107323281f6d3815147e5ef9119289da29 Mon Sep 17 00:00:00 2001 From: Sam Lishak Date: Thu, 1 Sep 2022 18:26:08 +0100 Subject: [PATCH 2/2] Min step can be used to prevent infinite loop of timestep getting smaller and smaller --- torchdiffeq/_impl/rk_common.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/torchdiffeq/_impl/rk_common.py b/torchdiffeq/_impl/rk_common.py index ca20664a5..98d064a6c 100644 --- a/torchdiffeq/_impl/rk_common.py +++ b/torchdiffeq/_impl/rk_common.py @@ -273,8 +273,12 @@ def _adaptive_step(self, rk_state): # f1.dtype == self.y0.dtype # y1_error.dtype == self.dtype # k.dtype == self.y0.dtype - except RejectStepError: + except RejectStepError as ex: # self.func requested the step be rejected + # If already at minimum step size, stop integration as can't proceed + if dt <= self.min_step: + raise(ex) + error_ratio = torch.tensor(10.0, dtype=self.dtype, device=self.y0.device) accept_step = False else: ########################################################