diff --git a/src/timesfm.py b/src/timesfm.py
index 88502390..15791f12 100644
--- a/src/timesfm.py
+++ b/src/timesfm.py
@@ -95,7 +95,7 @@ class TimesFm:

   Attributes:
     per_core_batch_size: Batch size on each core for data parallelism.
-    backend: One of "cpu", "gpu" or "tpu".
+    backend: One of "cpu", "gpu" or "tpu" (case-insensitive).
     num_devices: Number of cores provided the backend.
     global_batch_size: per_core_batch_size * num_devices. Each batch of
       inference task will be padded with respect to global_batch_size to
@@ -126,7 +126,7 @@ def __init__(
       num_layers: int,
       model_dims: int,
       per_core_batch_size: int = 32,
-      backend: Literal["cpu", "gpu", "tpu"] = "cpu",
+      backend: str = "cpu",
       quantiles: Sequence[float] | None = None,
       verbose: bool = True,
   ) -> None:
@@ -144,12 +144,12 @@ def __init__(
       num_layers: Number of transformer layers.
       model_dims: Model dimension.
       per_core_batch_size: Batch size on each core for data parallelism.
-      backend: One of "cpu", "gpu" or "tpu".
+      backend: One of "cpu", "gpu" or "tpu" (case-insensitive).
       quantiles: list of output quantiles supported by the model.
       verbose: Whether to print logging messages.
     """
     self.per_core_batch_size = per_core_batch_size
-    self.backend = backend
+    self.backend = backend.lower()
     self.num_devices = jax.local_device_count(self.backend)
     self.global_batch_size = self.per_core_batch_size * self.num_devices