Skip to content

Commit e04fd31

Browse files
committed
Add autoregressive neural network models
1 parent 8d8a099 commit e04fd31

File tree

1 file changed

+33
-0
lines changed

1 file changed

+33
-0
lines changed

torchts/nn/models/autoreg.py

+33
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
import torch
2+
from torch import nn
3+
4+
from torchts.nn.model import TimeSeriesModel
5+
6+
7+
class SimpleAR(TimeSeriesModel):
8+
def __init__(self, p, bias=True, **kwargs):
9+
super().__init__(**kwargs)
10+
self.linear = nn.Linear(p, 1, bias=bias)
11+
12+
def forward(self, x):
13+
return self.linear(x)
14+
15+
16+
class MultiAR(TimeSeriesModel):
17+
def __init__(self, p, k, bias=True, **kwargs):
18+
super().__init__(**kwargs)
19+
self.p = p
20+
self.k = k
21+
self.layers = nn.ModuleList(nn.Linear(k, k, bias=False) for _ in range(p))
22+
self.bias = nn.Parameter(torch.zeros(k)) if bias else None
23+
24+
def forward(self, x):
25+
y = torch.zeros(x.shape[0], self.k)
26+
27+
for i in range(self.p):
28+
y += self.layers[i](x[:, i, :])
29+
30+
if self.bias is not None:
31+
y += self.bias
32+
33+
return y

0 commit comments

Comments
 (0)