Skip to content

Commit 8db6a0a

Browse files
committed
Drop unnecessary arguments
The active_dims argument can now be dropped due to #671
1 parent 5190477 commit 8db6a0a

File tree

1 file changed

+14
-13
lines changed
  • baybe/surrogates/gaussian_process

1 file changed

+14
-13
lines changed

baybe/surrogates/gaussian_process/core.py

Lines changed: 14 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -69,9 +69,14 @@ def parameter_bounds(self) -> Tensor:
6969

7070
return torch.from_numpy(self.searchspace.scaling_bounds.values)
7171

72-
def get_numerical_indices(self, n_inputs: int) -> tuple[int, ...]:
73-
"""Get the indices of the regular numerical model inputs."""
74-
return tuple(i for i in range(n_inputs) if i != self.task_idx)
72+
@property
73+
def numerical_indices(self) -> tuple[int, ...]:
74+
"""The indices of the regular numerical model inputs."""
75+
return tuple(
76+
i
77+
for i in range(len(self.searchspace.comp_rep_columns))
78+
if i != self.task_idx
79+
)
7580

7681

7782
@define
@@ -148,37 +153,33 @@ def _fit(self, train_x: Tensor, train_y: Tensor) -> None:
148153

149154
assert self._searchspace is not None # provided by base class
150155
context = _ModelContext(self._searchspace)
151-
batch_shape = train_x.shape[:-2]
152156

153157
# Input/output scaling
154158
# NOTE: For GPs, we let BoTorch handle scaling (see [Scaling Workaround] above)
155159
input_transform = botorch.models.transforms.Normalize(
156160
train_x.shape[-1],
157161
bounds=context.parameter_bounds,
158-
indices=list(context.get_numerical_indices(train_x.shape[-1])),
162+
indices=list(context.numerical_indices),
159163
)
160164
outcome_transform = botorch.models.transforms.Standardize(train_y.shape[-1])
161165

162166
# Mean function
163-
mean_module = gpytorch.means.ConstantMean(batch_shape=batch_shape)
167+
mean_module = gpytorch.means.ConstantMean()
164168

165169
# Covariance function
166170
kernel = self.kernel_factory(context.searchspace, train_x, train_y)
167171
kernel_num_dims = train_x.shape[-1] - context.n_task_dimensions
168-
covar_module = kernel.to_gpytorch(
169-
ard_num_dims=kernel_num_dims,
170-
batch_shape=batch_shape,
171-
active_dims=tuple(range(kernel_num_dims)),
172-
)
172+
covar_module = kernel.to_gpytorch(ard_num_dims=kernel_num_dims)
173173

174174
# Likelihood model
175175
noise_prior = _default_noise_factory(context.searchspace, train_x, train_y)
176176
likelihood = gpytorch.likelihoods.GaussianLikelihood(
177-
noise_prior=noise_prior[0].to_gpytorch(), batch_shape=batch_shape
177+
noise_prior=noise_prior[0].to_gpytorch()
178178
)
179179
likelihood.noise = torch.tensor([noise_prior[1]])
180180

181181
# Model selection
182+
model_cls: type[botorch.models.SingleTaskGP] | type[botorch.models.MultiTaskGP]
182183
if (task_param := context.searchspace._task_parameter) is None:
183184
model_cls = botorch.models.SingleTaskGP
184185
model_kwargs = {}
@@ -201,7 +202,7 @@ def _fit(self, train_x: Tensor, train_y: Tensor) -> None:
201202
mean_module=mean_module,
202203
covar_module=covar_module,
203204
likelihood=likelihood,
204-
**model_kwargs,
205+
**model_kwargs, # type: ignore[arg-type]
205206
)
206207

207208
# TODO: This is still a temporary workaround to avoid overfitting seen in

0 commit comments

Comments (0)