Extend quantile regression to multiple quantiles #187


Open · wants to merge 20 commits into base: main

460 changes: 327 additions & 133 deletions examples/quantile-regression/lstm-quantile-regression.ipynb

Large diffs are not rendered by default.
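The notebook diff is not rendered here. In outline, the example now trains one model against several quantiles at once using the list-valued loss added below. A minimal sketch of that pattern (the architecture and data here are illustrative placeholders, not the notebook's actual contents):

```python
import torch
import torch.nn as nn
from torchts.nn.loss import quantile_loss

quantiles = [0.05, 0.5, 0.95]

class QuantileLSTM(nn.Module):
    """Toy stand-in: an LSTM with one output column per quantile."""

    def __init__(self, hidden_size=16):
        super().__init__()
        self.lstm = nn.LSTM(input_size=1, hidden_size=hidden_size, batch_first=True)
        self.head = nn.Linear(hidden_size, len(quantiles))

    def forward(self, x):
        out, _ = self.lstm(x)
        return self.head(out[:, -1])  # (batch, len(quantiles))

model = QuantileLSTM()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
x = torch.randn(64, 10, 1)  # (batch, seq_len, features)
y = torch.randn(64, 1)      # (batch, 1)

for _ in range(5):
    optimizer.zero_grad()
    loss = quantile_loss(model(x), y, quantiles)  # all quantiles in one pass
    loss.backward()
    optimizer.step()
```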

12 changes: 9 additions & 3 deletions tests/nn/test_loss.py
@@ -7,13 +7,13 @@
@pytest.fixture
def y_true():
data = [1, 2, 3]
-    return torch.tensor(data)
+    return torch.tensor(data).reshape(-1, 1)


@pytest.fixture
def y_pred():
data = [1.1, 1.9, 3.1]
-    return torch.tensor(data)
+    return torch.tensor(data).reshape(-1, 1)


def test_masked_mae_loss(y_true, y_pred):
@@ -23,7 +23,13 @@ def test_masked_mae_loss(y_true, y_pred):


@pytest.mark.parametrize(
"quantile, expected_loss", [(0.05, 0.065), (0.5, 0.05), (0.95, 0.035)]
"quantile, expected_loss",
[
(0.05, 0.065),
(0.5, 0.05),
(0.95, 0.035),
([0.05, 0.5, 0.95], 0.065 + 0.05 + 0.035),
],
)
def test_quantile_loss(y_true, y_pred, quantile, expected_loss):
"""Test quantile_loss()"""
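For reference, the parametrized expectations follow directly from the pinball loss defined in torchts/nn/loss.py below: with errors e = y_true - y_pred = [-0.1, 0.1, -0.1], each element contributes max((q - 1) * e, q * e), and the list case is the sum of the three per-quantile means. A quick standalone check (not part of the PR):

```python
import torch

y_true = torch.tensor([1.0, 2.0, 3.0]).reshape(-1, 1)
y_pred = torch.tensor([1.1, 1.9, 3.1]).reshape(-1, 1)
errors = y_true - y_pred

for q, expected in [(0.05, 0.065), (0.5, 0.05), (0.95, 0.035)]:
    loss = torch.max((q - 1) * errors, q * errors).mean()
    # e.g. q = 0.05: (0.095 + 0.005 + 0.095) / 3 = 0.065
    assert torch.isclose(loss, torch.tensor(expected), atol=1e-6)
```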
39 changes: 34 additions & 5 deletions torchts/nn/loss.py
@@ -1,3 +1,6 @@
from typing import List, Union

import numpy as np
import torch


@@ -21,18 +24,44 @@ def masked_mae_loss(y_pred, y_true):
return loss.mean()


-def quantile_loss(y_pred: torch.Tensor, y_true: torch.Tensor, quantile: float) -> float:
+def quantile_loss(
+    y_pred: torch.Tensor, y_true: torch.Tensor, quantile: Union[float, List[float]]
+) -> torch.Tensor:
"""Calculate quantile loss

Args:
y_pred (torch.tensor): Predicted values
y_true (torch.tensor): True values
-        quantile (float): quantile (e.g. 0.5 for median)
+        quantile (float or list of floats): quantile(s) to evaluate (e.g. 0.5 for the median)

Returns:
-        float: output losses
+        torch.Tensor: scalar loss, summed across quantiles when a list is given
"""
-    errors = y_true - y_pred
+    if isinstance(quantile, list):
+        # repeat the target so it lines up with one prediction column per quantile
+        errors = torch.repeat_interleave(y_true, len(quantile), dim=1) - y_pred
+        quantile = torch.FloatTensor(quantile)
+        quantile = quantile.repeat(1, y_true.shape[-1])
+    else:
+        errors = y_true - y_pred

loss = torch.max((quantile - 1) * errors, quantile * errors)
-    loss = torch.mean(loss)
+    loss = torch.mean(loss, dim=0)  # per-quantile mean over the batch
+    loss = torch.sum(loss)  # sum across quantiles

return loss
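A sketch of the new list-valued path, mirroring the test fixture above: y_true keeps a single column, y_pred carries one column per requested quantile, and the per-quantile mean losses (0.065, 0.05, 0.035) are summed:

```python
import torch
from torchts.nn.loss import quantile_loss

y_true = torch.tensor([[1.0], [2.0], [3.0]])  # (n, 1)
y_pred = torch.tensor([[1.1, 1.1, 1.1],
                       [1.9, 1.9, 1.9],
                       [3.1, 3.1, 3.1]])      # (n, 3): one column per quantile
print(quantile_loss(y_pred, y_true, [0.05, 0.5, 0.95]))  # tensor(0.1500)
```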


def quantile_err(prediction, y):
    """Compute conformity scores for conformal prediction.

    prediction: array whose first three columns are the lower quantile,
    middle quantile (50%), and upper quantile, in that order.
    """
    y_lower = prediction[:, 0]
    y_upper = prediction[:, 2]
    # Signed errors of the predicted bounds against the true values: an entry is
    # negative when the true value falls inside the [lower, upper] interval
    error_low = y_lower - y.view(-1)
    error_high = y.view(-1) - y_upper
    # For each sample, keep the larger of the two errors (the conformity score)
    err = np.maximum(error_high, error_low)
    return err
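A small numeric sketch of the conformity scores this produces: an entry is negative when the true value falls inside the predicted interval, and positive by the amount the interval misses it:

```python
import torch
from torchts.nn.loss import quantile_err

prediction = torch.tensor([[0.8, 1.0, 1.3],   # lower, median, upper
                           [1.7, 2.0, 2.2]])
y = torch.tensor([[1.1], [2.4]])
print(quantile_err(prediction, y))
# approximately [-0.2, 0.2]: sample 1 is covered with 0.2 to spare,
# sample 2 overshoots the upper bound by 0.2
```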
139 changes: 131 additions & 8 deletions torchts/nn/model.py
@@ -1,9 +1,16 @@
from abc import abstractmethod
from functools import partial

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F
from pytorch_lightning import LightningModule, Trainer
-from torch.utils.data import DataLoader, TensorDataset
+from sklearn.model_selection import train_test_split
+from torch.utils.data import DataLoader, TensorDataset, random_split

from torchts.nn.loss import quantile_err


class TimeSeriesModel(LightningModule):
@@ -14,6 +21,7 @@ class TimeSeriesModel(LightningModule):
optimizer_args (dict): Arguments for the optimizer
criterion: Loss function
criterion_args (dict): Arguments for the loss function
        significance (float): Miscoverage level for conformal prediction (e.g. 0.1 for 90% intervals)
        method (str): Set to "conformal" to enable conformal prediction
        mode (str): Data mode for the conformal split ("regression" or "time_series")
scheduler (torch.optim.lr_scheduler): Learning rate scheduler
scheduler_args (dict): Arguments for the scheduler
scaler (torchts.utils.scaler.Scaler): Scaler
Expand All @@ -25,14 +33,20 @@ def __init__(
optimizer_args=None,
criterion=F.mse_loss,
criterion_args=None,
significance=None,
method=None,
mode=None,
scheduler=None,
scheduler_args=None,
scaler=None,
):
super().__init__()
self.criterion = criterion
self.criterion_args = criterion_args
self.significance = significance
self.method = method
self.scaler = scaler
self.mode = mode

if optimizer_args is not None:
self.optimizer = partial(optimizer, **optimizer_args)
@@ -54,9 +68,36 @@ def fit(self, x, y, max_epochs=10, batch_size=128):
batch_size (int): Batch size for torch.utils.data.DataLoader
"""
dataset = TensorDataset(x, y)
-        loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

        trainer = Trainer(max_epochs=max_epochs)
-        trainer.fit(self, loader)

        # conformal prediction: split the data into training and calibration sets
if self.method == "conformal":
lengths = [int(len(dataset) * 0.6), len(dataset) - int(len(dataset) * 0.6)]
if self.mode == "regression":
train_dataset, cal_dataset = random_split(dataset, lengths)
if self.mode == "time_series":
train_dataset, cal_dataset = random_split(dataset, lengths)
            # self.train_dataset, self.cal_dataset = train_test_split(dataset, test_size=0.4, shuffle=False)
train_dataloader = DataLoader(
train_dataset, batch_size=batch_size, shuffle=True
)
cal_dataloader = DataLoader(
cal_dataset, batch_size=batch_size, shuffle=True
)
trainer.fit(self, train_dataloader, cal_dataloader)

        else:
            # no calibration split: train on the full dataset
            loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
            trainer.fit(self, loader)
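For reference, the 60/40 train/calibration split used when method == "conformal" behaves like this standalone sketch (sizes illustrative):

```python
import torch
from torch.utils.data import TensorDataset, random_split

x, y = torch.randn(100, 1), torch.randn(100, 1)
dataset = TensorDataset(x, y)
lengths = [int(len(dataset) * 0.6), len(dataset) - int(len(dataset) * 0.6)]
train_dataset, cal_dataset = random_split(dataset, lengths)
print(len(train_dataset), len(cal_dataset))  # 60 40
```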

def prepare_batch(self, batch):
return batch
@@ -72,7 +113,6 @@ def _step(self, batch, batch_idx, num_batches):
Returns: loss for the batch
"""
x, y = self.prepare_batch(batch)

if self.training:
batches_seen = batch_idx + self.current_epoch * num_batches
else:
@@ -82,10 +122,23 @@

if self.scaler is not None:
y = self.scaler.inverse_transform(y)
-            pred = self.scaler.inverse_transform(pred)
+            y = torch.tensor(y).float()
+            pred = self.scaler.inverse_transform(pred.detach())
+            pred = torch.tensor(pred, requires_grad=True).float()

if self.criterion_args is not None:
-            loss = self.criterion(pred, y, **self.criterion_args)
+            # during validation, build conformal intervals from the calibration set
+            if (not self.training) and self.method == "conformal":
+                intervals = np.zeros((x.shape[0], 3))
+                # tile the calibrated error distances across the batch
+                err_dist = np.hstack([self.err_dist] * x.shape[0])
+
+                intervals[:, 0] = pred[:, 0] - err_dist[0, :]
+                intervals[:, 1] = pred[:, 1]
+                intervals[:, -1] = pred[:, -1] + err_dist[1, :]
+                loss = self.criterion(intervals, y, **self.criterion_args)
+            else:
+                loss = self.criterion(pred, y, **self.criterion_args)
else:
loss = self.criterion(pred, y)

@@ -116,7 +169,13 @@ def validation_step(self, batch, batch_idx):
batch (torch.Tensor): Output of the torch.utils.data.DataLoader
batch_idx (int): Integer displaying index of this batch
"""
-        val_loss = self._step(batch, batch_idx, len(self.trainer.val_dataloader))
+        # calibrate on the validation set so the intervals are not fit to the training data
+        if self.method == "conformal":
+            self.err_dist = self.calibration(
+                batch, batch_idx, len(self.trainer.val_dataloaders)
+            )
+        val_loss = self._step(batch, batch_idx, len(self.trainer.val_dataloaders))
self.log("val_loss", val_loss)
return val_loss

@@ -127,10 +186,61 @@ def test_step(self, batch, batch_idx):
batch (torch.Tensor): Output of the torch.utils.data.DataLoader
batch_idx (int): Integer displaying index of this batch
"""
-        test_loss = self._step(batch, batch_idx, len(self.trainer.test_dataloader))
+        test_loss = self._step(batch, batch_idx, len(self.trainer.test_dataloaders))
self.log("test_loss", test_loss)
return test_loss

def calibration(self, batch, batch_idx, num_batches):
"""

Args:
batch: Output of the torch.utils.data.DataLoader
batch_idx: Integer displaying index of this batch

Returns: err_dist for the calibration set
"""
x, y = self.prepare_batch(batch)
batches_seen = batch_idx
pred = self(x, y, batches_seen)

if self.scaler is not None:
y = self.scaler.inverse_transform(y)
y = torch.tensor(y).float()
pred = self.scaler.inverse_transform(pred.detach())
pred = torch.tensor(pred).float()

cal_scores = quantile_err(pred, y)

        # Sort calibration scores in ascending order
        nc = np.sort(cal_scores, 0)

        # rank of the smallest error that guarantees the requested coverage
        index = int(np.ceil((1 - self.significance) * (nc.shape[0] + 1))) - 1
        # clamp to a valid index
        index = min(max(index, 0), nc.shape[0] - 1)
        # use the same distance for the lower and upper bounds
        err_dist = np.vstack([nc[index], nc[index]])

return err_dist
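The index computed above is the usual split-conformal rank: the ceil((1 - significance) * (n + 1))-th smallest conformity score. For example, with 99 calibration scores and significance = 0.1 (a 90% coverage target):

```python
import numpy as np

significance = 0.1
nc = np.sort(np.random.randn(99), 0)  # sorted conformity scores
index = int(np.ceil((1 - significance) * (nc.shape[0] + 1))) - 1
index = min(max(index, 0), nc.shape[0] - 1)
print(index)  # 89: the 90th-smallest score becomes the error distance
```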

    def calibration_pred(self, x):
        """Predict intervals by applying the calibrated error distances.

        Args:
            x (torch.Tensor): Input data

        Returns: Predicted intervals (lower, median, upper)
        """
        pred = self.predict(x)
        intervals = np.zeros((x.shape[0], 3))
        # tile the calibrated error distances across the batch
        err_dist = np.hstack([self.err_dist] * x.shape[0])

        intervals[:, 0] = pred[:, 0] - err_dist[0, :]
        intervals[:, 1] = pred[:, 1]
        intervals[:, -1] = pred[:, -1] + err_dist[1, :]
        return intervals
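A standalone numeric sketch of this interval construction (numbers illustrative):

```python
import numpy as np

pred = np.array([[0.9, 1.0, 1.1],  # lower, median, upper point predictions
                 [1.8, 2.0, 2.2]])
err_dist = np.hstack([np.vstack([0.15, 0.15])] * pred.shape[0])  # (2, n): tiled error distances

intervals = np.zeros((pred.shape[0], 3))
intervals[:, 0] = pred[:, 0] - err_dist[0, :]    # widen the lower bound
intervals[:, 1] = pred[:, 1]                     # the median passes through
intervals[:, -1] = pred[:, -1] + err_dist[1, :]  # widen the upper bound
print(intervals)  # [[0.75 1.   1.25] [1.65 2.   2.35]]
```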

@abstractmethod
def forward(self, x, y=None, batches_seen=None):
"""Forward pass.
@@ -153,6 +263,19 @@ def predict(self, x):
"""
return self(x).detach()

    def conformal_predict(self, x):
        """Runs model inference, adding conformal intervals when enabled.

        Args:
            x (torch.Tensor): Input data

        Returns:
            Predicted conformal intervals (lower, median, upper), or plain
            predictions when method is not "conformal"
        """
        if self.method == "conformal":
            return self.calibration_pred(x)
        return self(x).detach()

def configure_optimizers(self):
"""Configure optimizer.
