@@ -30,12 +30,8 @@ def __init__(self, stiffness, mass, alpha_f=0.4, alpha_m=0.2) -> None:
30
30
self .stiffness = stiffness
31
31
self .mass = mass
32
32
33
- def rhs_eval_points (self , dt ) -> List [float ]:
34
- return [(1 - self .alpha_f ) * dt ]
35
-
36
- def do_step (self , u , v , a , f , dt ) -> Tuple [float , float , float ]:
37
- if isinstance (f , list ): # if f is list, turn it into a number
38
- f = f [0 ]
33
+ def do_step (self , u , v , a , rhs , dt ) -> Tuple [float , float , float ]:
34
+ f = rhs ((1 - self .alpha_f ) * dt )
39
35
40
36
m = 3 * [None ]
41
37
m [0 ] = (1 - self .alpha_m ) / (self .beta * dt ** 2 )
@@ -71,31 +67,25 @@ def __init__(self, ode_system) -> None:
71
67
self .ode_system = ode_system
72
68
pass
73
69
74
- def rhs_eval_points (self , dt ) -> List [float ]:
75
- return [self .c [0 ] * dt , self .c [1 ] * dt , self .c [2 ] * dt , self .c [3 ] * dt ]
76
-
77
- def do_step (self , u , v , a , f , dt ) -> Tuple [float , float , float ]:
70
+ def do_step (self , u , v , a , rhs , dt ) -> Tuple [float , float , float ]:
78
71
assert (isinstance (u , type (v )))
79
72
80
73
n_stages = 4
81
74
82
- if isinstance (f , numbers .Number ): # if f is number, assume constant f
83
- f = n_stages * [f ]
84
-
85
75
if isinstance (u , np .ndarray ):
86
76
x = np .concatenate ([u , v ])
87
- rhs = [ np .concatenate ([np .array ([0 , 0 ]), f [ i ]]) for i in range ( n_stages )]
77
+ def f ( t ): return np .concatenate ([np .array ([0 , 0 ]), rhs ( t )])
88
78
elif isinstance (u , numbers .Number ):
89
79
x = np .array ([u , v ])
90
- rhs = [ np .array ([0 , f [ i ]]) for i in range ( n_stages )]
80
+ def f ( t ): return np .array ([0 , rhs ( t )])
91
81
else :
92
82
raise Exception (f"Cannot handle input type { type (u )} of u and v" )
93
83
94
84
s = n_stages * [None ]
95
- s [0 ] = self .ode_system .dot (x ) + rhs [0 ]
96
- s [1 ] = self .ode_system .dot (x + self .a [1 , 0 ] * s [0 ] * dt ) + rhs [1 ]
97
- s [2 ] = self .ode_system .dot (x + self .a [2 , 1 ] * s [1 ] * dt ) + rhs [2 ]
98
- s [3 ] = self .ode_system .dot (x + self .a [3 , 2 ] * s [2 ] * dt ) + rhs [3 ]
85
+ s [0 ] = self .ode_system .dot (x ) + f ( self . c [0 ] * dt )
86
+ s [1 ] = self .ode_system .dot (x + self .a [1 , 0 ] * s [0 ] * dt ) + f ( self . c [1 ] * dt )
87
+ s [2 ] = self .ode_system .dot (x + self .a [2 , 1 ] * s [1 ] * dt ) + f ( self . c [2 ] * dt )
88
+ s [3 ] = self .ode_system .dot (x + self .a [3 , 2 ] * s [2 ] * dt ) + f ( self . c [3 ] * dt )
99
89
100
90
x_new = x
101
91
@@ -119,14 +109,7 @@ def __init__(self, ode_system) -> None:
119
109
self .ode_system = ode_system
120
110
pass
121
111
122
- def rhs_eval_points (self , dt ) -> List [float ]:
123
- return np .linspace (0 , dt , 5 ) # will create an interpolant from this later
124
-
125
- def do_step (self , u , v , a , f , dt ) -> Tuple [float , float , float ]:
126
- from brot .interpolation import do_lagrange_interpolation
127
-
128
- ts = self .rhs_eval_points (dt )
129
-
112
+ def do_step (self , u , v , a , rhs , dt ) -> Tuple [float , float , float ]:
130
113
t0 = 0
131
114
132
115
assert (isinstance (u , type (v )))
@@ -135,25 +118,24 @@ def do_step(self, u, v, a, f, dt) -> Tuple[float, float, float]:
135
118
x0 = np .concatenate ([u , v ])
136
119
f = np .array (f )
137
120
assert (u .shape [0 ] == f .shape [1 ])
138
- def rhs_fun (t , x ): return np .concatenate ([np .array ([np .zeros_like (t ), np .zeros_like (t )]), [
139
- do_lagrange_interpolation (t , ts , f [:, i ]) for i in range (u .shape [0 ])]])
121
+ def rhs_fun (t ): return np .concatenate ([np .array ([np .zeros_like (t ), np .zeros_like (t )]), rhs (t )])
140
122
elif isinstance (u , numbers .Number ):
141
123
x0 = np .array ([u , v ])
142
- def rhs_fun (t , x ): return np .array ([np .zeros_like (t ), do_lagrange_interpolation ( t , ts , f )])
124
+ def rhs_fun (t ): return np .array ([np .zeros_like (t ), rhs ( t )])
143
125
else :
144
126
raise Exception (f"Cannot handle input type { type (u )} of u and v" )
145
127
146
128
def fun (t , x ):
147
- return self .ode_system .dot (x ) + rhs_fun (t , x )
129
+ return self .ode_system .dot (x ) + rhs_fun (t )
148
130
149
- # use large rtol and atol to circumvent error control.
131
+ # use adaptive time stepping; dense_output=True allows us to sample from continuous function later
150
132
ret = sp .integrate .solve_ivp (fun , [t0 , t0 + dt ], x0 , method = "Radau" ,
151
- first_step = dt , max_step = dt , rtol = 10e10 , atol = 10e10 )
133
+ dense_output = True , rtol = 10e-5 , atol = 10e-9 )
152
134
153
135
a_new = None
154
136
if isinstance (u , np .ndarray ):
155
137
u_new , v_new = ret .y [0 :2 , - 1 ], ret .y [2 :4 , - 1 ]
156
138
elif isinstance (u , numbers .Number ):
157
139
u_new , v_new = ret .y [:, - 1 ]
158
140
159
- return u_new , v_new , a_new
141
+ return u_new , v_new , a_new , ret . sol
0 commit comments