We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
LinearInterpolation
1 parent 737bf39 commit e050614Copy full SHA for e050614
diffrax/global_interpolation.py
@@ -125,10 +125,11 @@ def _index(_ys):
125
prev_t = self.ts[index]
126
next_t = self.ts[index + 1]
127
diff_t = next_t - prev_t
128
-
129
- return (
130
- prev_ys**ω + (next_ys**ω - prev_ys**ω) * (fractional_part / diff_t)
131
- ).ω
+ return jnp.where(
+ diff_t >= jnp.finfo(diff_t.dtype).eps,
+ (prev_ys**ω + (next_ys**ω - prev_ys**ω) * (fractional_part / diff_t)).ω,
+ prev_ys
132
+ )
133
134
@eqx.filter_jit
135
def derivative(self, t: Scalar, left: bool = True) -> PyTree:
0 commit comments