Skip to content

Commit 6bb93a7

Browse files
Merge branch 'main' into kan-nbeats
2 parents 92213aa + 3821c0b commit 6bb93a7

File tree

4 files changed

+165
-7
lines changed

4 files changed

+165
-7
lines changed

pytorch_forecasting/data/timeseries/_timeseries.py

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1430,25 +1430,34 @@ def _data_to_tensors(self, data: pd.DataFrame) -> dict[str, torch.Tensor]:
14301430
time index
14311431
"""
14321432

1433-
def _to_tensor(cols, long=True) -> torch.Tensor:
1433+
def _to_tensor(cols, long=True, real=False) -> torch.Tensor:
14341434
"""Convert data[cols] to torch tensor.
14351435
14361436
Converts sub-frames to numpy and then to torch tensor.
14371437
Makes the following choices for types:
14381438
1439-
* float columns are converted to torch.float
1440-
* integer columns are converted to torch.int64 or torch.long,
1441-
depending on the long argument
1439+
- real is True:
1440+
* the sub-frame is converted to a torch.float32 tensor
1441+
- long is True (and real is False):
1442+
* the sub-frame is converted to a torch.long tensor
1443+
- real is False and long is False:
1444+
* if all columns are integer or boolean, the sub-frame is
1445+
converted to a torch.int64 tensor
1446+
* if one column is a float, the sub-frame is converted to
1447+
a torch.float32 tensor
14421448
"""
14431449
if not isinstance(cols, list) and cols not in data.columns:
14441450
return None
14451451
if isinstance(cols, list) and len(cols) == 0:
14461452
dtypekind = "f"
14471453
elif isinstance(cols, list): # and len(cols) > 0
1448-
dtypekind = data.dtypes[cols[0]].kind
1454+
# dtypekind = data.dtypes[cols[0]].kind
1455+
dtypekind = np.result_type(*data[cols].dtypes.to_list()).kind
14491456
else:
14501457
dtypekind = data.dtypes[cols].kind
1451-
if not long:
1458+
if real:
1459+
return torch.tensor(data[cols].to_numpy(np.float64), dtype=torch.float)
1460+
elif not long:
14521461
return torch.tensor(data[cols].to_numpy(np.int64), dtype=torch.int64)
14531462
elif dtypekind in "bi":
14541463
return torch.tensor(data[cols].to_numpy(np.int64), dtype=torch.long)
Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
# copyright: pytorch-forecasting developers, BSD-3-Clause License (see LICENSE file)
2+
# copy of the sktime utility of the same name (BSD-3)
3+
"""Doctest checks directed through pytest with conditional skipping."""
4+
5+
from functools import lru_cache
6+
import importlib
7+
import inspect
8+
import pkgutil
9+
10+
EXCLUDE_MODULES_STARTING_WITH = ("all", "test")
11+
12+
13+
def _all_functions(module_name):
14+
"""Get all functions from a module, including submodules.
15+
16+
Excludes:
17+
18+
* modules starting with 'all' or 'test'.
19+
* if the flag ``ONLY_CHANGED_MODULES`` is set, modules that have not changed,
20+
compared to the ``main`` branch.
21+
22+
Parameters
23+
----------
24+
module_name : str
25+
Name of the module.
26+
27+
Returns
28+
-------
29+
functions_list : list
30+
List of tuples (function_name, function_object).
31+
"""
32+
res = _all_functions_cached(module_name)
33+
# copy the result to avoid modifying the cached result
34+
return res.copy()
35+
36+
37+
@lru_cache
38+
def _all_functions_cached(module_name, only_changed_modules=False):
39+
"""Get all functions from a module, including submodules.
40+
41+
Excludes:
42+
43+
* modules starting with 'all' or 'test'.
44+
* if ``only_changed_modules`` is ``True``, modules that have not changed,
45+
compared to the ``main`` branch.
46+
47+
Parameters
48+
----------
49+
module_name : str
50+
Name of the module.
51+
only_changed_modules : bool, optional (default=False)
52+
If True, only functions from modules that have changed are returned.
53+
54+
Returns
55+
-------
56+
functions_list : list
57+
List of tuples (function_name, function_object).
58+
"""
59+
# Import the package
60+
package = importlib.import_module(module_name)
61+
62+
# Initialize an empty list to hold all functions
63+
functions_list = []
64+
65+
# Walk through the package's modules
66+
package_path = package.__path__[0]
67+
for _, modname, _ in pkgutil.walk_packages(
68+
path=[package_path], prefix=package.__name__ + "."
69+
):
70+
# Skip modules starting with 'all' or 'test'
71+
if modname.split(".")[-1].startswith(EXCLUDE_MODULES_STARTING_WITH):
72+
continue
73+
74+
# Import the module
75+
module = importlib.import_module(modname)
76+
77+
# Get all functions from the module
78+
for name, obj in inspect.getmembers(module, inspect.isfunction):
79+
# if function is imported from another module, skip it
80+
if obj.__module__ != module.__name__:
81+
continue
82+
# add the function to the list
83+
functions_list.append((name, obj))
84+
85+
return functions_list
86+
87+
88+
def pytest_generate_tests(metafunc):
89+
"""Test parameterization routine for pytest.
90+
91+
Fixtures parameterized
92+
----------------------
93+
func : all functions from sktime, as returned by _all_functions
94+
if ONLY_CHANGED_MODULES is set, only functions from modules that have changed
95+
"""
96+
# we assume all four arguments are present in the test below
97+
funcs_and_names = _all_functions("pytorch_forecasting")
98+
99+
if len(funcs_and_names) > 0:
100+
names, funcs = zip(*funcs_and_names)
101+
102+
metafunc.parametrize("func", funcs, ids=names)
103+
else:
104+
metafunc.parametrize("func", [])
105+
106+
107+
def test_all_functions_doctest(func):
108+
"""Run doctest for all functions in pytorch-forecasting."""
109+
from skbase.utils.doctest_run import run_doctest
110+
111+
run_doctest(func, name=f"function {func.__name__}")

pytorch_forecasting/utils/_dependencies/_safe_import.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ def _safe_import(import_path, pkg_name=None):
5959
6060
Examples
6161
--------
62-
>>> from pytorch_forecasting.utils.dependencies._safe_import import _safe_import
62+
>>> from pytorch_forecasting.utils._dependencies._safe_import import _safe_import
6363
6464
>>> # Import a top-level module
6565
>>> torch = _safe_import("torch")

tests/test_data/test_timeseries.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -678,3 +678,41 @@ def distance_to_weights(dist):
678678
if idx > 100:
679679
break
680680
print(a)
681+
682+
683+
def test_correct_dtype_inference():
684+
# Create a small dataset
685+
data = pd.DataFrame(
686+
{
687+
"time_idx": np.arange(30),
688+
"value": np.sin(np.arange(30) / 5) + np.random.normal(scale=1, size=30),
689+
"group": ["A"] * 30,
690+
}
691+
)
692+
693+
# Define the dataset
694+
dataset = TimeSeriesDataSet(
695+
data.copy(),
696+
time_idx="time_idx",
697+
target="value",
698+
group_ids=["group"],
699+
static_categoricals=["group"],
700+
max_encoder_length=4,
701+
max_prediction_length=2,
702+
time_varying_unknown_reals=["value"],
703+
target_normalizer=None,
704+
# WATCH THIS
705+
time_varying_known_reals=["time_idx"],
706+
scalers=dict(time_idx=None),
707+
)
708+
709+
# and the dataloader
710+
dataloader = dataset.to_dataloader(batch_size=8)
711+
712+
x, y = next(iter(dataset))
713+
# real features must be real
714+
assert x["x_cont"].dtype is torch.float
715+
716+
x, y = next(iter(dataloader))
717+
# real features must be real
718+
assert x["encoder_cont"].dtype is torch.float

0 commit comments

Comments
 (0)