Add initial version of the new SourcePriorGaussianProcessSurrogate #678
base: main
Conversation
This surrogate implements transfer learning by:
1. Training a source GP on source task data (without task dimension)
2. Using the source GP as a mean prior for the target GP
3. Training the target GP on all data (source + target) with source-informed priors.
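The three-step flow above can be illustrated without any of the actual BayBE/BoTorch machinery. Below is a minimal NumPy-only sketch (all names hypothetical, not the PR's code): a source GP is fitted on source data, and its posterior mean then serves as the prior mean of the target GP, which therefore only has to model the target residuals.

```python
import numpy as np

def rbf(a, b, ls=1.0):
    # Squared-exponential kernel on 1-D inputs.
    d = a[:, None] - b[None, :]
    return np.exp(-0.5 * (d / ls) ** 2)

def gp_posterior_mean(x_train, y_train, x_query, noise=1e-4, prior_mean=None):
    """Posterior mean of a GP with an optional non-zero prior mean function."""
    m_train = prior_mean(x_train) if prior_mean else np.zeros_like(y_train)
    m_query = prior_mean(x_query) if prior_mean else np.zeros(len(x_query))
    K = rbf(x_train, x_train) + noise * np.eye(len(x_train))
    alpha = np.linalg.solve(K, y_train - m_train)
    return m_query + rbf(x_query, x_train) @ alpha

# Source task: plenty of data.
x_s = np.linspace(0, 5, 30)
y_s = np.sin(x_s)
source_mean = lambda x: gp_posterior_mean(x_s, y_s, x)

# Target task: few points of a shifted variant; the source posterior mean
# acts as the target GP's prior mean, so only the +0.2 offset is learned.
x_t = np.array([0.5, 2.0, 4.0])
y_t = np.sin(x_t) + 0.2
pred = gp_posterior_mean(x_t, y_t, np.array([3.0]), prior_mean=source_mean)
```

With only three target points, the prediction at 3.0 already tracks the shifted function closely because the source prior carries the shape of the curve.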
- Implement GPyTKernel class to wrap pretrained kernel
- Comment out IndexKernel in GPBuilder class (reduced search space should not have a task parameter)
- Introduce flag in GPBuilder.create_gp to use prior mean and/or covariance, but raise NotImplementedError
be7aaf8 to dc0e801
…L model
- New class method for creating a GP Surrogate from a given prior GP
- transfer_mode "mean" or "kernel" defines whether mean or kernel from prior should be used
- New GaussianProcessSurrogate._initialize_model method that will create the _model depending on whether a prior is given or not
- New PriorMean and PriorKernel classes in baybe/surrogates/gaussian_process/prior_modules.py (GPyTorch mean and covariance modules created from a prior GP)
- Use new from_prior_gp method in SourcePriorGaussianProcessSurrogate
First of all: Appreciate the detailed PR description :)
Pull Request Overview
This PR introduces a transfer learning implementation for Gaussian Process surrogates through the SourcePriorGaussianProcessSurrogate class. The implementation enables knowledge transfer from source tasks to target tasks using hierarchical Gaussian Process modeling.
Key changes include:
- New transfer learning surrogate that trains separate source and target GPs with shared priors
- Extension of GaussianProcessSurrogate with from_prior_gp class method for creating GP surrogates from pre-trained priors
- New prior modules (PriorMean and PriorKernel) that extract posterior mean or covariance from existing GPs
Reviewed Changes
Copilot reviewed 3 out of 3 changed files in this pull request and generated 3 comments.
| File | Description |
|---|---|
| baybe/surrogates/transfer_learning/source_prior.py | Main implementation of the transfer learning surrogate with hierarchical GP training and prediction logic |
| baybe/surrogates/gaussian_process/prior_modules.py | Helper classes for extracting mean and kernel priors from trained GPs |
| baybe/surrogates/gaussian_process/core.py | Extended GaussianProcessSurrogate with transfer learning capabilities and prior-based initialization |
```python
print(
    "No target data provided. Using copy of source GP as target GP."
)
```
Copilot AI · Oct 29, 2025
Using print() statements for logging in production code is not recommended. Consider using the logging module for better control over output levels and destinations.
```python
print(
    "No target data provided. Using copy of source GP as target GP."
)
from copy import deepcopy
```
Copilot AI · Oct 29, 2025
Import statement should be moved to the top of the file with other imports rather than being inside a function.
```python
from copy import deepcopy

from botorch.models import SingleTaskGP
```
Copilot AI · Oct 29, 2025
Import statements should be moved to the top of the file with other imports rather than being inside a method.
First round of comments. Didn't manage to get through everything. I'd appreciate an example (just as a test) that can be used to see the code in action :)
@@ -0,0 +1,518 @@
```python
"""Abstract base class for source prior surrogates."""
```
I do not see any new tests. Is this new class automatically picked up by our machinery, or has the addition of a test been overlooked?
```python
# Transfer learning fields
_prior_gp = field(init=False, default=None, eq=False)
"""Prior GP to extract mean/covariance from for transfer learning."""
```
| """Prior GP to extract mean/covariance from for transfer learning.""" | |
| """Prior GP to extract mean/covariance from transfer learning.""" |
```python
_prior_gp = field(init=False, default=None, eq=False)
"""Prior GP to extract mean/covariance from for transfer learning."""

_transfer_mode: Literal["mean", "kernel"] | None = field(
```
These new fields are only relevant for special contexts, but we now have them in the base class. Looking at the code, I have the feeling that a dedicated subclass of GaussianProcessSurrogate might be the better place for them. But I am not sure if that would be the preferred option and hence summon the master of all class design @AdrianSosic
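A rough sketch of the subclass option raised here (using plain dataclasses rather than the attrs fields from the diff; all class and field names are illustrative, not a proposal for the actual API):

```python
from dataclasses import dataclass, field
from typing import Optional

@dataclass
class GaussianProcessSurrogate:
    """Base surrogate, kept free of transfer-learning state."""

@dataclass
class PriorBackedGPSurrogate(GaussianProcessSurrogate):
    """Hypothetical subclass that owns the transfer-learning-only fields."""

    _prior_gp: Optional[object] = field(default=None, repr=False)
    _transfer_mode: Optional[str] = None
```

This keeps the base class signature unchanged for users who never touch transfer learning, at the cost of one more class in the hierarchy.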
```python
prior_gp: Fitted SingleTaskGP to use as prior
transfer_mode: "mean" extracts posterior mean as prior mean,
    "kernel" uses prior's covariance
kernel_factory: Kernel factory for new covariance (required for mean mode,
    ignored for kernel mode)
**kwargs: Additional arguments passed to GaussianProcessSurrogate constructor
```
The indentation here looks a bit weird, does this render correctly?
```python
def from_prior_gp(
    cls,
    prior_gp,
    transfer_mode: Literal["mean", "kernel"] = "mean",
```
Any particular reason to choose one over the other?
```python
_target_gp: SingleTaskGP | None = field(init=False, default=None, eq=False)
"""Fitted target Gaussian Process model with source prior."""

def _identify_target_task(self) -> tuple[int, float]:
```
I think you can get rid of some of the comments in the function here, e.g. the "Find the TaskParamter in the search space" one
```python
    return task_idx, target_value

def _validate_transfer_learning_context(self) -> None:
    """Validate that we have a proper transfer learning setup.
```
What exactly is meant by "proper transfer learning setup"? A few more details in the docstring would be nice.
```python
if self._searchspace.task_idx is None:
    raise ValueError(
        "No task parameter found in search space. "
        "SourcePriorGaussianProcessSurrogate requires a TaskParameter "
```
Can't you use self.__class__.__name__ or type(self).__name__ or similar here?
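For example (a toy reduction of the snippet above, not the actual BayBE class):

```python
class SourcePriorGaussianProcessSurrogate:
    def _validate(self, task_idx):
        if task_idx is None:
            # type(self).__name__ keeps the message correct under subclassing,
            # unlike a hard-coded class name in the string.
            raise ValueError(
                "No task parameter found in search space. "
                f"{type(self).__name__} requires a TaskParameter."
            )

try:
    SourcePriorGaussianProcessSurrogate()._validate(None)
except ValueError as err:
    message = str(err)
```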
```python
target_task: Task ID for the target task.

Returns:
    Tuple of (source_data_list, target_data_tuple).
```
The return type looks quite complicated. Would it be easier to either simplify it (maybe by storing things together and separating them later?) or add a more detailed description of the exact format, since e.g. the Union types are not immediately clear?
```python
def _fit(self, train_x: Tensor, train_y: Tensor) -> None:
    """Fit the transfer learning model.

    This method handles the common training workflow for all transfergpbo models:
```
What does transfergpbo refer to? I think that is the general name of the method/the paper. Can you maybe link it at some point, or make it very explicit where the method is from?
SourcePriorGaussianProcessSurrogate Implementation
The SourcePriorGaussianProcessSurrogate in baybe/surrogates/transfer_learning/source_prior.py implements transfer learning through a hierarchical approach:
Training Flow
1. Validation via _validate_transfer_learning_context(), ensuring a TaskParameter is present.
2. _identify_target_task() identifies the task parameter from the search space and determines the target task by examining the active_values of the TaskParameter.
3. _extract_task_data() splits the data:
   - Source data: (X_s, y_s)
   - Target data: (X_t, y_t)
   These pairs contain only features (no task parameter column) after filtering.
4. Fitting (_fit()):
   a) A SingleTaskGP instance is created using BayBE's default configuration (mean and covariance modules). This self._source_gp receives a reduced search space (without task parameter) and is fitted to the source data (X_s, y_s).
   b) A new SingleTaskGP instance is created that uses the posterior mean of the source GP as a mean prior (via the PriorMean class). This self._target_gp is fitted to the target data (X_t, y_t).
Making Predictions (_posterior())
- Source predictions: made using self._source_gp
- Target predictions: made using self._target_gp (which has the source prior)
Individual predictions are combined into unified mean and covariance tensors for the entire batch.
New GaussianProcessSurrogate.from_prior_gp class method to create a GP surrogate by transferring the mean or covariance from a pretrained prior:
- transfer_mode "mean" or "kernel" defines whether mean or kernel from prior should be used
- GaussianProcessSurrogate._initialize_model method that will create the _model depending on whether a prior is given or not
- PriorMean and PriorKernel classes in baybe/surrogates/gaussian_process/prior_modules.py extract posterior mean or covariance from a prior GP
Helper classes
- SourcePriorWrapperModel: provides BoTorch compatibility by wrapping the SourcePriorGaussianProcessSurrogate for the _to_botorch() method.
Current Limitations