Add initial version of the new SourcePriorGaussianProcessSurrogate #678
base: main
Changes from all commits
90f3835
9e4f795
1646129
dc0e801
3a6f925
4ef5ac4
@@ -3,7 +3,7 @@
from __future__ import annotations

import gc
from typing import TYPE_CHECKING, ClassVar
from typing import TYPE_CHECKING, ClassVar, Literal

from attrs import define, field
from attrs.validators import instance_of

@@ -24,6 +24,7 @@
    DefaultKernelFactory,
    _default_noise_factory,
)
from baybe.surrogates.gaussian_process.prior_modules import PriorKernel, PriorMean
from baybe.utils.conversion import to_string

if TYPE_CHECKING:
|
@@ -113,11 +114,74 @@ class GaussianProcessSurrogate(Surrogate):
    _model = field(init=False, default=None, eq=False)
    """The actual model."""

    # Transfer learning fields
    _prior_gp = field(init=False, default=None, eq=False)
    """Prior GP to extract mean/covariance from for transfer learning."""

    _transfer_mode: Literal["mean", "kernel"] | None = field(

Collaborator: These new fields are only relevant for special contexts, but we now have them in the base class. Looking at the code, I have the feeling that having a dedicated subclass of the …

        init=False, default=None, eq=False
    )
    """Transfer learning mode: 'mean' uses prior as mean function, 'kernel' uses prior covariance."""

    @staticmethod
    def from_preset(preset: GaussianProcessPreset) -> GaussianProcessSurrogate:
        """Create a Gaussian process surrogate from one of the defined presets."""
        return make_gp_from_preset(preset)

    @classmethod
    def from_prior_gp(
        cls,
        prior_gp,
        transfer_mode: Literal["mean", "kernel"] = "mean",

Collaborator: Any particular reason to choose one over the other?

        kernel_factory: KernelFactory | None = None,
        **kwargs,
    ) -> GaussianProcessSurrogate:
        """Create a GP surrogate from a prior GP.

        Args:
            prior_gp: Fitted SingleTaskGP to use as prior
            transfer_mode: "mean" extracts posterior mean as prior mean,
                "kernel" uses prior's covariance
            kernel_factory: Kernel factory for new covariance (required for mean mode,
                ignored for kernel mode)
            **kwargs: Additional arguments passed to GaussianProcessSurrogate constructor

Collaborator (on lines +142 to +147): The indentation here looks a bit weird, does this render correctly?

        Returns:
            New GaussianProcessSurrogate instance with prior mean or covariance

        Raises:
            ValueError: If prior_gp is not fitted or configuration is invalid
        """
        from copy import deepcopy

        from botorch.models import SingleTaskGP

        # Validate prior GP is fitted
        if not isinstance(prior_gp, SingleTaskGP):
            raise ValueError("prior_gp must be a fitted SingleTaskGP instance")
        if not hasattr(prior_gp, "train_inputs") or prior_gp.train_inputs is None:
            raise ValueError("Prior GP must be fitted (have train_inputs) before use")

        # Validate transfer mode configuration
        if transfer_mode not in ["mean", "kernel"]:
            raise ValueError("transfer_mode must be 'mean' or 'kernel'")

        if transfer_mode == "mean" and kernel_factory is None:

Collaborator: Should we also raise a warning if …

            raise ValueError("kernel_factory is required for mean transfer mode")

        # For kernel transfer, kernel_factory is ignored (we use prior's kernel)
        if transfer_mode == "kernel":
            kernel_factory = kernel_factory or DefaultKernelFactory()

Collaborator: Wouldn't this break if a …

        # Create new surrogate instance
        instance = cls(kernel_or_factory=kernel_factory, **kwargs)

        # Configure for transfer learning
        instance._prior_gp = deepcopy(prior_gp)
        instance._transfer_mode = transfer_mode

        return instance

    @override
    def to_botorch(self) -> GPyTorchModel:
        return self._model
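For orientation, a hypothetical usage sketch of the new classmethod (not part of the diff; the `baybe.surrogates` and `kernel_factory` import paths below are assumptions based on the imports shown above):

import torch
from botorch.fit import fit_gpytorch_mll
from botorch.models import SingleTaskGP
from gpytorch.mlls import ExactMarginalLogLikelihood

from baybe.surrogates import GaussianProcessSurrogate  # assumed public import path
from baybe.surrogates.gaussian_process.kernel_factory import DefaultKernelFactory  # assumed path

# Fit a prior GP on source-task data
train_x = torch.rand(30, 4, dtype=torch.double)
train_y = torch.sin(train_x.sum(dim=-1, keepdim=True))
prior_gp = SingleTaskGP(train_x, train_y)
fit_gpytorch_mll(ExactMarginalLogLikelihood(prior_gp.likelihood, prior_gp))

# "mean" mode: the surrogate uses the prior's posterior mean as its mean function,
# so a kernel factory for its own covariance is required
surrogate = GaussianProcessSurrogate.from_prior_gp(
    prior_gp, transfer_mode="mean", kernel_factory=DefaultKernelFactory()
)

# "kernel" mode: the surrogate reuses the prior's (frozen) covariance module instead
surrogate_kernel = GaussianProcessSurrogate.from_prior_gp(prior_gp, transfer_mode="kernel")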
|
@@ -140,62 +204,81 @@ def _make_target_scaler_factory() -> type[OutcomeTransform] | None:
    def _posterior(self, candidates_comp_scaled: Tensor, /) -> Posterior:
        return self._model.posterior(candidates_comp_scaled)

    @override
    def _fit(self, train_x: Tensor, train_y: Tensor) -> None:
    def _initialize_model(
        self,
        train_x: Tensor,
        train_y: Tensor,
        context: _ModelContext,
        batch_shape,
    ) -> None:
        """Initialize the GP model with appropriate mean and covariance modules.

        Handles both standard GP creation and creation of GP from given prior.

        Args:
            train_x: Training input data
            train_y: Training target data
            context: Model context containing searchspace information
            batch_shape: Batch shape for the training data
        """
        import botorch
        import gpytorch
        import torch

        # FIXME[typing]: It seems there is currently no better way to inform the type
        # checker that the attribute is available at the time of the function call
        assert self._searchspace is not None

        context = _ModelContext(self._searchspace)

        numerical_idxs = context.get_numerical_indices(train_x.shape[-1])

        # For GPs, we let botorch handle the scaling. See [Scaling Workaround] above.
        input_transform = botorch.models.transforms.Normalize(
            train_x.shape[-1],
            bounds=context.parameter_bounds,
            indices=list(numerical_idxs),
        )
        outcome_transform = botorch.models.transforms.Standardize(train_y.shape[-1])

        # extract the batch shape of the training data
        batch_shape = train_x.shape[:-2]
        # Configure input/output transforms
        if self._prior_gp is not None:
Collaborator: There are a lot of `if self._prior_gp is not None:` checks here. Should we instead do

    if self._prior_gp is not None:
        ...  # do the TL stuff
        return
    ...  # do the non-TL stuff

and separate the two cases like that? Might not be necessary if we use inheritance though, and not sure if this would actually improve readability (but my feeling would be that it does).
            # Use prior's transforms for consistency in transfer learning
            input_transform = self._prior_gp.input_transform
            outcome_transform = self._prior_gp.outcome_transform
        else:
            # Standard transform setup
            input_transform = botorch.models.transforms.Normalize(
                train_x.shape[-1],
                bounds=context.parameter_bounds,
                indices=numerical_idxs,
            )
            outcome_transform = botorch.models.transforms.Standardize(train_y.shape[-1])

        # create GP mean
        mean_module = gpytorch.means.ConstantMean(batch_shape=batch_shape)
        # Configure mean module
        if self._prior_gp is not None and self._transfer_mode == "mean":
            mean_module = PriorMean(self._prior_gp, batch_shape=batch_shape)
        else:
            mean_module = gpytorch.means.ConstantMean(batch_shape=batch_shape)

        # define the covariance module for the numeric dimensions
        base_covar_module = self.kernel_factory(
            context.searchspace, train_x, train_y
        ).to_gpytorch(
            ard_num_dims=train_x.shape[-1] - context.n_task_dimensions,
            active_dims=numerical_idxs,
            batch_shape=batch_shape,
        )
        # Configure base covariance module
        if self._prior_gp is not None and self._transfer_mode == "kernel":
            base_covar_module = PriorKernel(self._prior_gp.covar_module)
        else:
            # Use kernel factory
            base_covar_module = self.kernel_factory(
                context.searchspace, train_x, train_y
            ).to_gpytorch(
                ard_num_dims=train_x.shape[-1] - context.n_task_dimensions,
                active_dims=numerical_idxs,
                batch_shape=batch_shape,
            )

        # create GP covariance
        # Handle multi-task covariance (keep existing logic)
        if not context.is_multitask:
            covar_module = base_covar_module
        else:
            task_covar_module = gpytorch.kernels.IndexKernel(
                num_tasks=context.n_tasks,
                active_dims=context.task_idx,
                rank=context.n_tasks,  # TODO: make controllable
                rank=context.n_tasks,
            )
            covar_module = base_covar_module * task_covar_module

        # create GP likelihood
        # Configure likelihood (keep existing logic)
        noise_prior = _default_noise_factory(context.searchspace, train_x, train_y)
        likelihood = gpytorch.likelihoods.GaussianLikelihood(
            noise_prior=noise_prior[0].to_gpytorch(), batch_shape=batch_shape
        )
        likelihood.noise = torch.tensor([noise_prior[1]])

        # construct and fit the Gaussian process
        # Create the model
        self._model = botorch.models.SingleTaskGP(
            train_x,
            train_y,
|
@@ -206,6 +289,21 @@ def _fit(self, train_x: Tensor, train_y: Tensor) -> None:
            likelihood=likelihood,
        )

    @override
    def _fit(self, train_x: Tensor, train_y: Tensor) -> None:
        import botorch
        import gpytorch

        # FIXME[typing]: It seems there is currently no better way to inform the type
        # checker that the attribute is available at the time of the function call
        assert self._searchspace is not None

        context = _ModelContext(self._searchspace)
        batch_shape = train_x.shape[:-2]

        # Initialize model
        self._initialize_model(train_x, train_y, context, batch_shape)

        # TODO: This is still a temporary workaround to avoid overfitting seen in
        # low-dimensional TL cases. More robust settings are being researched.
        if context.n_task_dimensions > 0:
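As an aside for reviewers, here is a minimal standalone sketch (plain BoTorch/GPyTorch, no BayBE `_ModelContext`; my illustration, not code from the PR) of the kind of model assembly that `_initialize_model` performs and of a standard fitting step afterwards; in the transfer-learning branches, the mean and covariance modules below would be replaced by `PriorMean`/`PriorKernel` instances wrapping the prior GP:

import botorch
import gpytorch
import torch

train_x = torch.rand(25, 3, dtype=torch.double)
train_y = train_x.prod(dim=-1, keepdim=True)
batch_shape = train_x.shape[:-2]

# Explicit mean, covariance, and likelihood modules for the model
mean_module = gpytorch.means.ConstantMean(batch_shape=batch_shape)
covar_module = gpytorch.kernels.MaternKernel(
    nu=2.5, ard_num_dims=train_x.shape[-1], batch_shape=batch_shape
)
likelihood = gpytorch.likelihoods.GaussianLikelihood(batch_shape=batch_shape)

# Assemble the BoTorch model with input/outcome scaling handled by transforms
model = botorch.models.SingleTaskGP(
    train_x,
    train_y,
    mean_module=mean_module,
    covar_module=covar_module,
    likelihood=likelihood,
    input_transform=botorch.models.transforms.Normalize(train_x.shape[-1]),
    outcome_transform=botorch.models.transforms.Standardize(train_y.shape[-1]),
)

# Standard marginal-likelihood fit of the assembled model
botorch.fit.fit_gpytorch_mll(
    gpytorch.mlls.ExactMarginalLogLikelihood(model.likelihood, model)
)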
@@ -0,0 +1,98 @@ (new file)
from __future__ import annotations

from typing import Any
from torch import Tensor

from copy import deepcopy

import gpytorch
import torch
from botorch.models import SingleTaskGP


class PriorMean(gpytorch.means.Mean):
    """GPyTorch mean module using a trained GP as prior mean.

    This mean module wraps a trained Gaussian Process and uses its predictions
    as the mean function for another GP.
    """

Collaborator (on lines +14 to +17): Would change it such that the first line says "Gaussian Process" and the subsequent ones use the abbreviation.

    def __init__(
        self, gp: SingleTaskGP, batch_shape: torch.Size = torch.Size(), **kwargs: Any
    ) -> None:
        """Initialize the GP-based mean module.

        Args:
            gp: Trained Gaussian Process to use as mean function.
            batch_shape: Batch shape for the mean module.
            **kwargs: Additional keyword arguments.
        """
        super().__init__()
        # See https://github.com/cornellius-gp/gpytorch/issues/743
        self.gp: SingleTaskGP = deepcopy(gp)
        self.batch_shape: torch.Size = batch_shape
        for param in self.gp.parameters():
            param.requires_grad = False

    def reset_gp(self) -> None:
        """Reset the GP to evaluation mode for prediction."""
        self.gp.eval()
        self.gp.likelihood.eval()

    def forward(self, input: Tensor) -> Tensor:
        """Compute the mean function using the wrapped GP.

        Args:
            input: Input tensor for which to compute the mean.

        Returns:
            Mean predictions from the wrapped GP.
        """
        self.reset_gp()
        with torch.no_grad(), gpytorch.settings.fast_pred_var():
            with gpytorch.settings.detach_test_caches(False):
                mean = self.gp(input).mean.detach()
        mean = mean.reshape(torch.broadcast_shapes(self.batch_shape, input.shape[:-1]))
        return mean
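A small isolation sketch (my illustration, not part of the diff; it assumes the new file lands at `baybe.surrogates.gaussian_process.prior_modules`, as the import in the first file suggests) showing how `PriorMean` evaluates the wrapped, frozen GP:

import torch
from botorch.models import SingleTaskGP

from baybe.surrogates.gaussian_process.prior_modules import PriorMean  # assumed path

# Toy source-task GP to serve as the prior
train_x = torch.rand(10, 2, dtype=torch.double)
train_y = train_x.sum(dim=-1, keepdim=True)
prior_gp = SingleTaskGP(train_x, train_y)

# The mean module returns the prior GP's (detached) predictions at the query points
mean_module = PriorMean(prior_gp)
test_x = torch.rand(5, 2, dtype=torch.double)
mean_values = mean_module(test_x)
print(mean_values.shape)  # torch.Size([5]): one prior prediction per test point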

class PriorKernel(gpytorch.kernels.Kernel):
    """GPyTorch kernel module wrapping a pre-trained kernel.

    This kernel module wraps a trained kernel and uses it as a fixed kernel
    component in another GP. The wrapped kernel's parameters are frozen.
    """

    def __init__(self, kernel, **kwargs):
        """Initialize the kernel wrapper.

        Args:
            kernel: Pre-trained kernel to wrap.
            **kwargs: Additional keyword arguments.
        """
        super().__init__()
        # See https://github.com/cornellius-gp/gpytorch/issues/743
        self.base_kernel = deepcopy(kernel)
        for param in self.base_kernel.parameters():
            param.requires_grad = False

    def reset(self):
        """Reset the wrapped kernel to evaluation mode."""
        self.base_kernel.eval()

    def forward(self, x1, x2, **params):
        """Compute kernel matrix using the wrapped kernel.

        Args:
            x1: First set of input points.
            x2: Second set of input points.
            **params: Additional kernel parameters.

        Returns:
            Kernel matrix computed by the wrapped kernel.
        """
        self.reset()
        with gpytorch.settings.fast_pred_var():
            with gpytorch.settings.detach_test_caches(False):
                k = self.base_kernel.forward(x1, x2, **params)
        return k
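And an analogous sketch for `PriorKernel` (again my illustration, not part of the diff, with the same assumed import path):

import torch
from botorch.models import SingleTaskGP

from baybe.surrogates.gaussian_process.prior_modules import PriorKernel  # assumed path

# Toy source-task GP whose covariance module we want to reuse
train_x = torch.rand(10, 2, dtype=torch.double)
train_y = train_x.sum(dim=-1, keepdim=True)
prior_gp = SingleTaskGP(train_x, train_y)

# Wrap the prior's covariance module; its hyperparameters stay frozen
kernel = PriorKernel(prior_gp.covar_module)
x1 = torch.rand(4, 2, dtype=torch.double)
x2 = torch.rand(3, 2, dtype=torch.double)
K = kernel(x1, x2).to_dense()  # evaluate the lazily constructed kernel tensor
print(K.shape)  # torch.Size([4, 3])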