164 changes: 131 additions & 33 deletions baybe/surrogates/gaussian_process/core.py
@@ -3,7 +3,7 @@
from __future__ import annotations

import gc
from typing import TYPE_CHECKING, ClassVar
from typing import TYPE_CHECKING, ClassVar, Literal

from attrs import define, field
from attrs.validators import instance_of
@@ -24,6 +24,7 @@
DefaultKernelFactory,
_default_noise_factory,
)
from baybe.surrogates.gaussian_process.prior_modules import PriorKernel, PriorMean
from baybe.utils.conversion import to_string

if TYPE_CHECKING:
@@ -113,11 +114,74 @@ class GaussianProcessSurrogate(Surrogate):
_model = field(init=False, default=None, eq=False)
"""The actual model."""

# Transfer learning fields
_prior_gp = field(init=False, default=None, eq=False)
"""Prior GP to extract mean/covariance from for transfer learning."""
Collaborator (suggested change):
- """Prior GP to extract mean/covariance from for transfer learning."""
+ """Prior GP to extract mean/covariance from transfer learning."""


_transfer_mode: Literal["mean", "kernel"] | None = field(
Collaborator: These new fields are only relevant for special contexts, but we now have them in the base class. Looking at the code, I have the feeling that a dedicated subclass of GaussianProcessSurrogate might be the cleaner option. But I am not sure if that would be the preferred option and hence summon the master of all class design @AdrianSosic.

init=False, default=None, eq=False
)
"""Transfer learning mode: 'mean' uses prior as mean function, 'kernel' uses prior covariance."""

@staticmethod
def from_preset(preset: GaussianProcessPreset) -> GaussianProcessSurrogate:
"""Create a Gaussian process surrogate from one of the defined presets."""
return make_gp_from_preset(preset)

@classmethod
def from_prior_gp(
cls,
prior_gp,
transfer_mode: Literal["mean", "kernel"] = "mean",
Collaborator: Any particular reason to choose one over the other?

kernel_factory: KernelFactory | None = None,
**kwargs,
) -> GaussianProcessSurrogate:
"""Create a GP surrogate from a prior GP.

Args:
prior_gp: Fitted SingleTaskGP to use as prior
transfer_mode: "mean" extracts posterior mean as prior mean,
"kernel" uses prior's covariance
kernel_factory: Kernel factory for new covariance (required for mean mode,
ignored for kernel mode)
**kwargs: Additional arguments passed to GaussianProcessSurrogate constructor
Collaborator (on lines +142 to +147): The indentation here looks a bit weird; does this render correctly?


Returns:
New GaussianProcessSurrogate instance with prior mean or covariance

Raises:
ValueError: If prior_gp is not fitted or configuration is invalid
"""
from copy import deepcopy

from botorch.models import SingleTaskGP
Copilot AI (Oct 29, 2025, on lines +155 to +157): Import statements should be moved to the top of the file with other imports rather than being inside a method.

# Validate prior GP is fitted
if not isinstance(prior_gp, SingleTaskGP):
raise ValueError("prior_gp must be a fitted SingleTaskGP instance")
if not hasattr(prior_gp, "train_inputs") or prior_gp.train_inputs is None:
raise ValueError("Prior GP must be fitted (have train_inputs) before use")

# Validate transfer mode configuration
if transfer_mode not in ["mean", "kernel"]:
raise ValueError("transfer_mode must be 'mean' or 'kernel'")

if transfer_mode == "mean" and kernel_factory is None:
Collaborator: Should we also raise a warning if transfer_mode == "kernel" and not (kernel_factory is None)? Just to make sure that people know that this is a somewhat incorrect setup.

raise ValueError("kernel_factory is required for mean transfer mode")

# For kernel transfer, kernel_factory is ignored (we use prior's kernel)
if transfer_mode == "kernel":
kernel_factory = kernel_factory or DefaultKernelFactory()
Collaborator: Wouldn't this break if a kernel_factory was provided, or am I misunderstanding what the or would do in this case?


# Create new surrogate instance
instance = cls(kernel_or_factory=kernel_factory, **kwargs)

# Configure for transfer learning
instance._prior_gp = deepcopy(prior_gp)
instance._transfer_mode = transfer_mode

return instance
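
For orientation, a minimal, self-contained usage sketch of the new classmethod (not part of the diff): the toy tensors are illustrative, and the DefaultKernelFactory import path is an assumption based on the imports at the top of core.py.

import torch
from botorch.fit import fit_gpytorch_mll
from botorch.models import SingleTaskGP
from gpytorch.mlls import ExactMarginalLogLikelihood

from baybe.surrogates import GaussianProcessSurrogate
from baybe.surrogates.gaussian_process.presets import DefaultKernelFactory  # import path assumed

# Fit a prior GP on source-task data (toy data for illustration).
source_x = torch.rand(20, 3, dtype=torch.float64)
source_y = torch.rand(20, 1, dtype=torch.float64)
prior_gp = SingleTaskGP(source_x, source_y)
fit_gpytorch_mll(ExactMarginalLogLikelihood(prior_gp.likelihood, prior_gp))

# "mean" mode: the prior's posterior mean becomes the new GP's mean function,
# so a kernel factory for the new covariance must be supplied.
surrogate_mean = GaussianProcessSurrogate.from_prior_gp(
    prior_gp, transfer_mode="mean", kernel_factory=DefaultKernelFactory()
)

# "kernel" mode: the prior's covariance module is reused as a frozen kernel.
surrogate_kernel = GaussianProcessSurrogate.from_prior_gp(prior_gp, transfer_mode="kernel")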

@override
def to_botorch(self) -> GPyTorchModel:
return self._model
@@ -140,62 +204,81 @@ def _make_target_scaler_factory() -> type[OutcomeTransform] | None:
def _posterior(self, candidates_comp_scaled: Tensor, /) -> Posterior:
return self._model.posterior(candidates_comp_scaled)

@override
def _fit(self, train_x: Tensor, train_y: Tensor) -> None:
def _initialize_model(
self,
train_x: Tensor,
train_y: Tensor,
context: _ModelContext,
batch_shape,
) -> None:
"""Initialize the GP model with appropriate mean and covariance modules.

Handles both standard GP creation and creation of GP from given prior.

Args:
train_x: Training input data
train_y: Training target data
context: Model context containing searchspace information
batch_shape: Batch shape for the training data
"""
import botorch
import gpytorch
import torch

# FIXME[typing]: It seems there is currently no better way to inform the type
# checker that the attribute is available at the time of the function call
assert self._searchspace is not None

context = _ModelContext(self._searchspace)

numerical_idxs = context.get_numerical_indices(train_x.shape[-1])

# For GPs, we let botorch handle the scaling. See [Scaling Workaround] above.
input_transform = botorch.models.transforms.Normalize(
train_x.shape[-1],
bounds=context.parameter_bounds,
indices=list(numerical_idxs),
)
outcome_transform = botorch.models.transforms.Standardize(train_y.shape[-1])

# extract the batch shape of the training data
batch_shape = train_x.shape[:-2]
# Configure input/output transforms
if self._prior_gp is not None:
Collaborator: There are a lot of if...else constructions here where one of the checks typically is if self._prior_gp is not None. Would it maybe be reasonable to have this as a top-level check? So having the code structured like

if self._prior_gp is not None:
    .... # do the TL stuff
    return
.... # do the non-TL stuff

and separate the two cases like that? Might not be necessary if we use inheritance though, and not sure if this would actually improve readability (but my feeling would be that it does).

# Use prior's transforms for consistency in transfer learning
input_transform = self._prior_gp.input_transform
outcome_transform = self._prior_gp.outcome_transform
else:
# Standard transform setup
input_transform = botorch.models.transforms.Normalize(
train_x.shape[-1],
bounds=context.parameter_bounds,
indices=numerical_idxs,
)
outcome_transform = botorch.models.transforms.Standardize(train_y.shape[-1])

# create GP mean
mean_module = gpytorch.means.ConstantMean(batch_shape=batch_shape)
# Configure mean module
if self._prior_gp is not None and self._transfer_mode == "mean":
mean_module = PriorMean(self._prior_gp, batch_shape=batch_shape)
else:
mean_module = gpytorch.means.ConstantMean(batch_shape=batch_shape)

# define the covariance module for the numeric dimensions
base_covar_module = self.kernel_factory(
context.searchspace, train_x, train_y
).to_gpytorch(
ard_num_dims=train_x.shape[-1] - context.n_task_dimensions,
active_dims=numerical_idxs,
batch_shape=batch_shape,
)
# Configure base covariance module
if self._prior_gp is not None and self._transfer_mode == "kernel":
base_covar_module = PriorKernel(self._prior_gp.covar_module)
else:
# Use kernel factory
base_covar_module = self.kernel_factory(
context.searchspace, train_x, train_y
).to_gpytorch(
ard_num_dims=train_x.shape[-1] - context.n_task_dimensions,
active_dims=numerical_idxs,
batch_shape=batch_shape,
)

# create GP covariance
# Handle multi-task covariance (keep existing logic)
if not context.is_multitask:
covar_module = base_covar_module
else:
task_covar_module = gpytorch.kernels.IndexKernel(
num_tasks=context.n_tasks,
active_dims=context.task_idx,
rank=context.n_tasks, # TODO: make controllable
rank=context.n_tasks,
)
covar_module = base_covar_module * task_covar_module

# create GP likelihood
# Configure likelihood (keep existing logic)
noise_prior = _default_noise_factory(context.searchspace, train_x, train_y)
likelihood = gpytorch.likelihoods.GaussianLikelihood(
noise_prior=noise_prior[0].to_gpytorch(), batch_shape=batch_shape
)
likelihood.noise = torch.tensor([noise_prior[1]])

# construct and fit the Gaussian process
# Create the model
self._model = botorch.models.SingleTaskGP(
train_x,
train_y,
@@ -206,6 +289,21 @@ def _fit(self, train_x: Tensor, train_y: Tensor) -> None:
likelihood=likelihood,
)

@override
def _fit(self, train_x: Tensor, train_y: Tensor) -> None:
import botorch
import gpytorch

# FIXME[typing]: It seems there is currently no better way to inform the type
# checker that the attribute is available at the time of the function call
assert self._searchspace is not None

context = _ModelContext(self._searchspace)
batch_shape = train_x.shape[:-2]

# Initialize model
self._initialize_model(train_x, train_y, context, batch_shape)

# TODO: This is still a temporary workaround to avoid overfitting seen in
# low-dimensional TL cases. More robust settings are being researched.
if context.n_task_dimensions > 0:
98 changes: 98 additions & 0 deletions baybe/surrogates/gaussian_process/prior_modules.py
@@ -0,0 +1,98 @@
from __future__ import annotations

from typing import Any
from torch import Tensor

from copy import deepcopy

import gpytorch
import torch
from botorch.models import SingleTaskGP


class PriorMean(gpytorch.means.Mean):
"""GPyTorch mean module using a trained GP as prior mean.
This mean module wraps a trained Gaussian Process and uses its predictions
as the mean function for another GP.
Collaborator (on lines +14 to +17): Would change it such that the first line says "Gaussian Process" and the subsequent ones use the abbreviation.

"""

def __init__(
self, gp: SingleTaskGP, batch_shape: torch.Size = torch.Size(), **kwargs: Any
) -> None:
"""Initialize the GP-based mean module.
Args:
gp: Trained Gaussian Process to use as mean function.
batch_shape: Batch shape for the mean module.
**kwargs: Additional keyword arguments.
"""
super().__init__()
# See https://github.com/cornellius-gp/gpytorch/issues/743
self.gp: SingleTaskGP = deepcopy(gp)
self.batch_shape: torch.Size = batch_shape
for param in self.gp.parameters():
param.requires_grad = False

def reset_gp(self) -> None:
"""Reset the GP to evaluation mode for prediction."""
self.gp.eval()
self.gp.likelihood.eval()

def forward(self, input: Tensor) -> Tensor:
"""Compute the mean function using the wrapped GP.
Args:
input: Input tensor for which to compute the mean.
Returns:
Mean predictions from the wrapped GP.
"""
self.reset_gp()
with torch.no_grad(), gpytorch.settings.fast_pred_var():
with gpytorch.settings.detach_test_caches(False):
mean = self.gp(input).mean.detach()
mean = mean.reshape(torch.broadcast_shapes(self.batch_shape, input.shape[:-1]))
return mean


class PriorKernel(gpytorch.kernels.Kernel):
"""GPyTorch kernel module wrapping a pre-trained kernel.
This kernel module wraps a trained kernel and uses it as a fixed kernel
component in another GP. The wrapped kernel's parameters are frozen.
"""

def __init__(self, kernel, **kwargs):
"""Initialize the kernel wrapper.
Args:
kernel: Pre-trained kernel to wrap.
**kwargs: Additional keyword arguments.
"""
super().__init__()
# See https://github.com/cornellius-gp/gpytorch/issues/743
self.base_kernel = deepcopy(kernel)
for param in self.base_kernel.parameters():
param.requires_grad = False

def reset(self):
"""Reset the wrapped kernel to evaluation mode."""
self.base_kernel.eval()

def forward(self, x1, x2, **params):
"""Compute kernel matrix using the wrapped kernel.
Args:
x1: First set of input points.
x2: Second set of input points.
**params: Additional kernel parameters.
Returns:
Kernel matrix computed by the wrapped kernel.
"""
self.reset()
with gpytorch.settings.fast_pred_var():
with gpytorch.settings.detach_test_caches(False):
k = self.base_kernel.forward(x1, x2, **params)
return k
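
For completeness, a small standalone sketch of how the two wrapper modules might be exercised in isolation; the data is illustrative, and in practice the prior GP would be fitted before being wrapped.

import torch
from botorch.models import SingleTaskGP

from baybe.surrogates.gaussian_process.prior_modules import PriorKernel, PriorMean

source_x = torch.rand(20, 3, dtype=torch.float64)
source_y = torch.rand(20, 1, dtype=torch.float64)
prior_gp = SingleTaskGP(source_x, source_y)  # would normally be fitted first

test_x = torch.rand(5, 3, dtype=torch.float64)

# PriorMean: the frozen prior GP's posterior mean, evaluated at the test points.
mean_module = PriorMean(prior_gp)
print(mean_module(test_x).shape)  # torch.Size([5])

# PriorKernel: the prior's covariance module reused with frozen hyperparameters.
kernel_module = PriorKernel(prior_gp.covar_module)
print(kernel_module(test_x, test_x).shape)  # torch.Size([5, 5]), evaluated lazily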