Skip to content

Commit 2bda02c

Browse files
committed
Add ODE solver
1 parent 6b04b6d commit 2bda02c

File tree

1 file changed

+58
-0
lines changed

1 file changed

+58
-0
lines changed

torchts/nn/models/ode.py

+58
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
import torch
2+
from torch import nn
3+
4+
from torchts.nn.model import TimeSeriesModel
5+
6+
7+
class ODESolver(TimeSeriesModel):
8+
def __init__(
9+
self, ode, init_vars, init_coeffs, dt, solver="euler", outvar=None, **kwargs
10+
):
11+
super().__init__(**kwargs)
12+
13+
if ode.keys() != init_vars.keys():
14+
raise ValueError("Inconsistent keys in ode and init_vars")
15+
16+
if solver == "euler":
17+
self.solver = self.euler
18+
else:
19+
raise ValueError(f"Unrecognized solver {solver}")
20+
21+
for name, value in init_coeffs.items():
22+
self.register_parameter(name, nn.Parameter(torch.tensor(value)))
23+
24+
self.ode = ode
25+
self.var_names = ode.keys()
26+
self.init_vars = {
27+
name: torch.tensor(value, device=self.device)
28+
for name, value in init_vars.items()
29+
}
30+
self.coeffs = {name: param for name, param in self.named_parameters()}
31+
self.outvar = self.var_names if outvar is None else outvar
32+
self.dt = dt
33+
34+
def euler(self, nt):
35+
pred = {name: value.unsqueeze(0) for name, value in self.init_vars.items()}
36+
37+
for n in range(nt - 1):
38+
# create dictionary containing values from previous time step
39+
prev_val = {var: pred[var][[n]] for var in self.var_names}
40+
41+
for var in self.var_names:
42+
new_val = prev_val[var] + self.ode[var](prev_val, self.coeffs) * self.dt
43+
pred[var] = torch.cat([pred[var], new_val])
44+
45+
# reformat output to contain desired (observed) variables
46+
return torch.stack([pred[var] for var in self.outvar], dim=1)
47+
48+
def forward(self, nt):
49+
return self.solver(nt)
50+
51+
def get_coeffs(self):
52+
return {name: param.item() for name, param in self.named_parameters()}
53+
54+
def _step(self, batch, batch_idx, num_batches):
55+
(x,) = batch
56+
nt = x.shape[0]
57+
pred = self(nt)
58+
return self.criterion(pred, x)

0 commit comments

Comments
 (0)