
Commit e1cc1ce

[ENH] Kolmogorov Arnold Block for NBeats (#1751)
### Description

Fixes: #1741

This PR adds Kolmogorov-Arnold Network (KAN) blocks to NBeats and also refactors NBeats. The implementation of the KAN blocks' layers is taken from the [original paper code](https://github.com/KindXiaoming/pykan/tree/master).

### Changes in Structure

* Introduced the `NBEATSKAN` module, which enables usage of **KAN blocks** within the `NBEATS` architecture.
* Integrated `KANLayer` logic, implemented in `kan_layer.py`, which handles KAN-specific operations such as:
  * **spline coefficient computation**,
  * **grid initialization and updates**, etc.
* Imported `KANLayer` into `submodules.py` for block operations, allowing `NBEATSKAN` to delegate block-level behavior through `use_kan=True`.
* Added the `NBEATSAdapter` class to encapsulate **common methods** shared by both `NBEATS` and `NBEATSKAN`, including:
  * standard training and forward logic, etc.
  * **Excludes** block initialization (`__init__`), which is defined separately in each class to maintain architectural flexibility.

### GridUpdateCallback

When training **KAN-based models**, the grid can be **iteratively refined** during training for better performance. To support this, logic from the original [pykan](https://github.com/KindXiaoming/pykan) implementation has been adapted into a **custom callback** named `GridUpdateCallback`, which automatically updates the grid at specified training steps, improving model accuracy and convergence. This callback has been **tested successfully** and demonstrates **improved results** in practice. An illustrative sketch of the idea is shown below; full example usage is provided in `examples/nbeats_with_kan.py`.
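For intuition only, here is a minimal sketch of the callback idea. It assumes the KAN layers are reachable via `pl_module.modules()` and expose `update_grid_from_samples(x)` (which the `KANLayer` in this PR does); the forward pre-hook caching is an assumption of this sketch, and the shipped `GridUpdateCallback` may differ in detail.

```python
import lightning.pytorch as pl

from pytorch_forecasting.layers._kan._kan_layer import KANLayer


class GridUpdateSketch(pl.Callback):
    """Illustrative stand-in for GridUpdateCallback, not the shipped class."""

    def __init__(self, update_interval: int = 3):
        self.update_interval = update_interval
        self._inputs = {}  # maps each KANLayer to its most recent input

    def on_train_start(self, trainer, pl_module):
        # Cache every KANLayer's last input via forward pre-hooks, so the
        # grid can later be refit against real samples.
        for module in pl_module.modules():
            if isinstance(module, KANLayer):
                module.register_forward_pre_hook(
                    lambda m, args: self._inputs.__setitem__(m, args[0].detach())
                )

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        # Every `update_interval` steps, refit each layer's spline grid (and
        # its coefficients) to the input distribution seen during training.
        if (trainer.global_step + 1) % self.update_interval != 0:
            return
        for layer, x in self._inputs.items():
            layer.update_grid_from_samples(x)
```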
1 parent 21a46a2 commit e1cc1ce

17 files changed: +1900 −575 lines


docs/source/models.rst

Lines changed: 1 addition & 0 deletions

```diff
@@ -27,6 +27,7 @@ and you should take into account. Here is an overview over the pros and cons of
     :py:class:`~pytorch_forecasting.models.rnn.RecurrentNetwork`, "x", "x", "x", "", "", "", "", "x", "", 2
     :py:class:`~pytorch_forecasting.models.mlp.DecoderMLP`, "x", "x", "x", "x", "", "x", "", "x", "x", 1
     :py:class:`~pytorch_forecasting.models.nbeats.NBeats`, "", "", "x", "", "", "", "", "", "", 1
+    :py:class:`~pytorch_forecasting.models.nbeats.NBeatsKAN`, "", "", "x", "", "", "", "", "", "", 1
     :py:class:`~pytorch_forecasting.models.nhits.NHiTS`, "x", "x", "x", "", "", "", "", "", "", 1
     :py:class:`~pytorch_forecasting.models.deepar.DeepAR`, "x", "x", "x", "", "x", "x", "x [#deepvar]_ ", "x", "", 3
     :py:class:`~pytorch_forecasting.models.temporal_fusion_transformer.TemporalFusionTransformer`, "x", "x", "x", "x", "", "x", "", "x", "x", 4
```

examples/nbeats_with_kan.py

Lines changed: 105 additions & 0 deletions

```python
import sys

import lightning.pytorch as pl
from lightning.pytorch.callbacks import EarlyStopping
import pandas as pd

from pytorch_forecasting import NBeatsKAN, TimeSeriesDataSet
from pytorch_forecasting.data import NaNLabelEncoder
from pytorch_forecasting.data.examples import generate_ar_data
from pytorch_forecasting.models.nbeats import GridUpdateCallback

sys.path.append("..")


print("load data")
data = generate_ar_data(seasonality=10.0, timesteps=400, n_series=100)
data["static"] = 2
data["date"] = pd.Timestamp("2020-01-01") + pd.to_timedelta(data.time_idx, "D")
validation = data.series.sample(20)


max_encoder_length = 150
max_prediction_length = 20

training_cutoff = data["time_idx"].max() - max_prediction_length

context_length = max_encoder_length
prediction_length = max_prediction_length

training = TimeSeriesDataSet(
    data[lambda x: x.time_idx < training_cutoff],
    time_idx="time_idx",
    target="value",
    categorical_encoders={"series": NaNLabelEncoder().fit(data.series)},
    group_ids=["series"],
    min_encoder_length=context_length,
    max_encoder_length=context_length,
    max_prediction_length=prediction_length,
    min_prediction_length=prediction_length,
    time_varying_unknown_reals=["value"],
    randomize_length=None,
    add_relative_time_idx=False,
    add_target_scales=False,
)

validation = TimeSeriesDataSet.from_dataset(
    training, data, min_prediction_idx=training_cutoff
)
batch_size = 128
train_dataloader = training.to_dataloader(
    train=True, batch_size=batch_size, num_workers=0
)
val_dataloader = validation.to_dataloader(
    train=False, batch_size=batch_size, num_workers=0
)


early_stop_callback = EarlyStopping(
    monitor="val_loss", min_delta=1e-4, patience=10, verbose=False, mode="min"
)
# updates KAN layers' grid after every 3 steps during training
grid_update_callback = GridUpdateCallback(update_interval=3)

trainer = pl.Trainer(
    max_epochs=1,
    accelerator="auto",
    gradient_clip_val=0.1,
    callbacks=[early_stop_callback, grid_update_callback],
    limit_train_batches=15,
    # limit_val_batches=1,
    # fast_dev_run=True,
    # logger=logger,
    # profiler=True,
)


net = NBeatsKAN.from_dataset(
    training,
    learning_rate=3e-2,
    log_interval=10,
    log_val_interval=1,
    log_gradient_flow=False,
    weight_decay=1e-2,
)
print(f"Number of parameters in network: {net.size() / 1e3:.1f}k")

# # find optimal learning rate
# # remove logging and artificial epoch size
# # (requires `from lightning.pytorch.tuner import Tuner`)
# net.hparams.log_interval = -1
# net.hparams.log_val_interval = -1
# trainer.limit_train_batches = 1.0
# # run learning rate finder
# res = Tuner(trainer).lr_find(
#     net, train_dataloaders=train_dataloader, val_dataloaders=val_dataloader, min_lr=1e-5, max_lr=1e2  # noqa: E501
# )
# print(f"suggested learning rate: {res.suggestion()}")
# fig = res.plot(show=True, suggest=True)
# fig.show()
# net.hparams.learning_rate = res.suggestion()

trainer.fit(
    net,
    train_dataloaders=train_dataloader,
    val_dataloaders=val_dataloader,
)
```
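A hypothetical follow-up, not part of the committed example: once `trainer.fit(...)` finishes, forecasts can be generated with the standard pytorch-forecasting `predict` API.

```python
# Generate forecasts for the validation dataloader (illustrative only).
predictions = net.predict(val_dataloader)
print(predictions.shape)  # one row per series, prediction_length columns
```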

pytorch_forecasting/__init__.py

Lines changed: 2 additions & 0 deletions

```diff
@@ -43,6 +43,7 @@
     DeepAR,
     MultiEmbedding,
     NBeats,
+    NBeatsKAN,
     NHiTS,
     RecurrentNetwork,
     TemporalFusionTransformer,
@@ -73,6 +74,7 @@
     "TemporalFusionTransformer",
     "TiDEModel",
     "NBeats",
+    "NBeatsKAN",
     "NHiTS",
     "Baseline",
     "DeepAR",
```
pytorch_forecasting/layers/_kan/__init__.py

Lines changed: 7 additions & 0 deletions

```python
"""
KAN (Kolmogorov Arnold Network) layer implementation.
"""

from pytorch_forecasting.layers._kan._kan_layer import KANLayer

__all__ = ["KANLayer"]
```
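A quick smoke test (illustrative, not part of the commit) that mirrors the docstring example of `KANLayer` below:

```python
import torch

from pytorch_forecasting.layers._kan import KANLayer

# A single KAN layer mapping 3 inputs to 5 outputs through learned splines.
layer = KANLayer(in_dim=3, out_dim=5)
x = torch.normal(0, 1, size=(100, 3))
y = layer(x)
print(y.shape)  # torch.Size([100, 5])
```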
pytorch_forecasting/layers/_kan/_kan_layer.py

Lines changed: 237 additions & 0 deletions

```python
# The following implementation of KANLayer is inspired by the pykan library.
# Reference: https://github.com/KindXiaoming/pykan/blob/master/kan/KANLayer.py

import numpy as np
import torch
import torch.nn as nn

from pytorch_forecasting.layers._kan._utils import (
    coef2curve,
    curve2coef,
    extend_grid,
    sparse_mask,
)


class KANLayer(nn.Module):
    """
    Initialize a KANLayer.

    Parameters
    ----------
    in_dim : int
        input dimension. Default: 3.
    out_dim : int
        output dimension. Default: 2.
    num : int
        the number of grid intervals = G. Default: 5.
    k : int
        the order of piecewise polynomial. Default: 3.
    noise_scale : float
        the scale of noise injected at initialization. Default: 0.5.
    scale_base_mu : float
        the scale of the residual function b(x) is initialized to be
        N(scale_base_mu, scale_base_sigma^2).
    scale_base_sigma : float
        the scale of the residual function b(x) is initialized to be
        N(scale_base_mu, scale_base_sigma^2).
    scale_sp : float
        the scale of the base function spline(x).
    base_fun : function
        residual function b(x). Default: None (torch.nn.SiLU() is used).
    grid_eps : float
        When grid_eps = 1, the grid is uniform; when grid_eps = 0, the grid is
        partitioned using percentiles of samples. 0 < grid_eps < 1 interpolates
        between the two extremes.
    grid_range : list or np.array of shape (2,)
        setting the range of grids. Default: None ([-1, 1] is used).
    sp_trainable : bool
        If true, scale_sp is trainable.
    sb_trainable : bool
        If true, scale_base is trainable.
    sparse_init : bool
        if sparse_init = True, sparse initialization is applied.

    Examples
    --------
    The following is an example from the original `pykan` library, adapted here
    for illustration within the PyTorch Forecasting integration.

    Install the `pykan` package first (``pip install pykan``), then use:

    >>> from kan.KANLayer import *
    >>> model = KANLayer(in_dim=3, out_dim=5)
    >>> (model.in_dim, model.out_dim)
    (3, 5)
    """

    def __init__(
        self,
        in_dim=3,
        out_dim=2,
        num=5,
        k=3,
        noise_scale=0.5,
        scale_base_mu=0.0,
        scale_base_sigma=1.0,
        scale_sp=1.0,
        base_fun=None,
        grid_eps=0.02,
        grid_range=None,
        sp_trainable=True,
        sb_trainable=True,
        sparse_init=False,
    ):
        super().__init__()

        # Handle mutable parameters
        if grid_range is None:
            grid_range = [-1, 1]
        if base_fun is None:
            base_fun = torch.nn.SiLU()
        # size
        self.out_dim = out_dim
        self.in_dim = in_dim
        self.num = num
        self.k = k

        grid = torch.linspace(grid_range[0], grid_range[1], steps=num + 1)[
            None, :
        ].expand(self.in_dim, num + 1)
        grid = extend_grid(grid, k_extend=k)
        self.grid = torch.nn.Parameter(grid).requires_grad_(False)
        noises = (
            (torch.rand(self.num + 1, self.in_dim, self.out_dim) - 1 / 2)
            * noise_scale
            / num
        )

        self.coef = torch.nn.Parameter(
            curve2coef(self.grid[:, k:-k].permute(1, 0), noises, self.grid, k)
        )

        if sparse_init:
            self.mask = torch.nn.Parameter(sparse_mask(in_dim, out_dim)).requires_grad_(
                False
            )
        else:
            self.mask = torch.nn.Parameter(torch.ones(in_dim, out_dim)).requires_grad_(
                False
            )

        self.scale_base = torch.nn.Parameter(
            scale_base_mu * 1 / np.sqrt(in_dim)
            + scale_base_sigma
            * (torch.rand(in_dim, out_dim) * 2 - 1)
            * 1
            / np.sqrt(in_dim)
        ).requires_grad_(sb_trainable)
        self.scale_sp = torch.nn.Parameter(
            torch.ones(in_dim, out_dim) * scale_sp * 1 / np.sqrt(in_dim) * self.mask
        ).requires_grad_(sp_trainable)  # make scale trainable
        self.base_fun = base_fun

        self.grid_eps = grid_eps

    def forward(self, x):
        """
        KANLayer forward given input x.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor of shape (batch_size, in_dim), where:
            - batch_size is the number of input samples.
            - in_dim is the input feature dimension.

        Returns
        -------
        y : torch.Tensor
            Output tensor of shape (batch_size, out_dim), the result of applying
            spline and residual transformations followed by weighted summation
            over the input dimension.

        Examples
        --------
        The following is adapted from the original `pykan` library for
        illustration within the PyTorch Forecasting integration.

        >>> model = KANLayer(in_dim=3, out_dim=5)
        >>> x = torch.normal(0, 1, size=(100, 3))
        >>> y = model(x)
        >>> y.shape
        torch.Size([100, 5])
        """

        base = self.base_fun(x)  # (batch, in_dim)
        y = coef2curve(x_eval=x, grid=self.grid, coef=self.coef, k=self.k)
        y = (
            self.scale_base[None, :, :] * base[:, :, None]
            + self.scale_sp[None, :, :] * y
        )
        y = self.mask[None, :, :] * y
        y = torch.sum(y, dim=1)
        return y

    def update_grid_from_samples(self, x):
        """
        Update grid from samples.

        Parameters
        ----------
        x : 2D torch.float
            inputs, shape (number of samples, input dimension)

        Returns
        -------
        None

        Examples
        --------
        >>> model = KANLayer(in_dim=1, out_dim=1, num=5, k=3)
        >>> print(model.grid.data)
        >>> x = torch.linspace(-3, 3, steps=100)[:, None]
        >>> model.update_grid_from_samples(x)
        >>> print(model.grid.data)
        """

        batch = x.shape[0]
        x_pos = torch.sort(x, dim=0)[0]
        y_eval = coef2curve(x_pos, self.grid, self.coef, self.k)
        num_interval = self.grid.shape[1] - 1 - 2 * self.k

        def get_grid(num_interval):
            """
            Generate adaptive or uniform grid points from sorted input samples.

            Parameters
            ----------
            num_interval : int
                Number of intervals between grid points.

            Returns
            -------
            grid : torch.Tensor
                New grid of shape (in_dim, num_interval + 1).
            """
            ids = [int(batch / num_interval * i) for i in range(num_interval)] + [-1]
            grid_adaptive = x_pos[ids, :].permute(1, 0)
            h = (grid_adaptive[:, [-1]] - grid_adaptive[:, [0]]) / num_interval
            grid_uniform = (
                grid_adaptive[:, [0]]
                + h * torch.arange(num_interval + 1, device=h.device)[None, :]
            )
            grid = self.grid_eps * grid_uniform + (1 - self.grid_eps) * grid_adaptive
            return grid

        grid = get_grid(num_interval)
        self.grid.data = extend_grid(grid, k_extend=self.k)
        self.coef.data = curve2coef(x_pos, y_eval, self.grid, self.k)
```
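Restating the `forward` pass above as a formula (nothing beyond the code): for a batch $x \in \mathbb{R}^{B \times d_\text{in}}$, the layer computes

$$
y_{bj} = \sum_{i=1}^{d_\text{in}} m_{ij} \left( w^\text{base}_{ij}\, b(x_{bi}) + w^\text{sp}_{ij}\, \operatorname{spline}_{ij}(x_{bi}) \right),
$$

where $b$ is `base_fun` (SiLU by default), $m$ is the fixed `mask`, $w^\text{base}$ and $w^\text{sp}$ are `scale_base` and `scale_sp`, and $\operatorname{spline}_{ij}$ is the order-$k$ B-spline curve defined by `grid` and `coef`. Likewise, `update_grid_from_samples` blends a percentile-based grid with a uniform one, $g = \varepsilon\, g_\text{uniform} + (1 - \varepsilon)\, g_\text{adaptive}$ with $\varepsilon$ = `grid_eps`, then refits `coef` so the curve is preserved on the new grid.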
