From 6afc50129262a22aca1c815a619b154a1746c2e7 Mon Sep 17 00:00:00 2001 From: Baptiste Lefort <71429321+blefo@users.noreply.github.com> Date: Mon, 13 May 2024 12:07:22 +0200 Subject: [PATCH 1/3] Making the backend attribute case-insensitive --- src/timesfm.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/timesfm.py b/src/timesfm.py index 88502390..161ee5ec 100644 --- a/src/timesfm.py +++ b/src/timesfm.py @@ -126,7 +126,7 @@ def __init__( num_layers: int, model_dims: int, per_core_batch_size: int = 32, - backend: Literal["cpu", "gpu", "tpu"] = "cpu", + backend: Literal = "cpu", quantiles: Sequence[float] | None = None, verbose: bool = True, ) -> None: @@ -144,12 +144,12 @@ def __init__( num_layers: Number of transformer layers. model_dims: Model dimension. per_core_batch_size: Batch size on each core for data parallelism. - backend: One of "cpu", "gpu" or "tpu". + backend: One of "cpu", "gpu" or "tpu" (case-insensitive). quantiles: list of output quantiles supported by the model. verbose: Whether to print logging messages. """ self.per_core_batch_size = per_core_batch_size - self.backend = backend + self.backend = backend.lower() self.num_devices = jax.local_device_count(self.backend) self.global_batch_size = self.per_core_batch_size * self.num_devices From e114ccde8dd8be536382de45f4c2feeffd8eae03 Mon Sep 17 00:00:00 2001 From: Baptiste Lefort <71429321+blefo@users.noreply.github.com> Date: Mon, 13 May 2024 12:15:50 +0200 Subject: [PATCH 2/3] Making the backend attribute case-insensitive --- src/timesfm.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/timesfm.py b/src/timesfm.py index 161ee5ec..003e524c 100644 --- a/src/timesfm.py +++ b/src/timesfm.py @@ -95,7 +95,7 @@ class TimesFm: Attributes: per_core_batch_size: Batch size on each core for data parallelism. - backend: One of "cpu", "gpu" or "tpu". + backend: One of "cpu", "gpu" or "tpu" (case-insensitive). 
num_devices: Number of cores provided the backend. global_batch_size: per_core_batch_size * num_devices. Each batch of inference task will be padded with respect to global_batch_size to @@ -126,7 +126,7 @@ def __init__( num_layers: int, model_dims: int, per_core_batch_size: int = 32, - backend: Literal = "cpu", + backend: str = "cpu", quantiles: Sequence[float] | None = None, verbose: bool = True, ) -> None: From 67f10335267eb8e2e1f1dd011f99527fe63d01e1 Mon Sep 17 00:00:00 2001 From: Baptiste Lefort <71429321+blefo@users.noreply.github.com> Date: Mon, 13 May 2024 12:25:43 +0200 Subject: [PATCH 3/3] Making the backend attribute case-insensitive --- src/timesfm.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/timesfm.py b/src/timesfm.py index 003e524c..15791f12 100644 --- a/src/timesfm.py +++ b/src/timesfm.py @@ -602,3 +602,5 @@ def forecast_on_df( fcst_df[model_name] = fcst_df[q_col] logging.info("Finished creating output dataframe.") return fcst_df + +