@kalama-ai (Collaborator) commented Oct 24, 2025

SourcePriorGaussianProcessSurrogate Implementation

The SourcePriorGaussianProcessSurrogate in baybe/surrogates/transfer_learning/source_prior.py implements transfer learning through a hierarchical approach:

  1. Training a source GP on source task data (without task dimension)
  2. Using the mean or covariance of the source GP as a prior for the target GP
  3. Training the target GP on target data

Training Flow

  1. The surrogate receives a search space and validates the transfer learning context via _validate_transfer_learning_context(), ensuring a TaskParameter is present.
  2. _identify_target_task() identifies the task parameter from the search space and determines the target task by examining the active_values of the TaskParameter.
  3. Data Extraction: The surrogate splits the training data into source and target pairs using _extract_task_data():
    - Source data: (X_s, y_s)
    - Target data: (X_t, y_t)
    These pairs contain only features (no task parameter column) after filtering.
  4. Model Training (_fit())
    a) A SingleTaskGP instance is created using BayBE's default configuration (mean and covariance modules). This model, stored as self._source_gp, is built on a reduced search space (without the task parameter) and fitted to the source data (X_s, y_s).
    b) A second SingleTaskGP instance is created that uses the posterior mean of the source GP as its prior mean (via the PriorMean class). This model, stored as self._target_gp, is fitted to the target data (X_t, y_t). A minimal sketch of this two-stage fit is shown below.
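
As a reference point, here is a minimal, self-contained sketch of the two-stage fit described above, written directly against BoTorch/GPyTorch rather than BayBE's internals. The class SourcePosteriorMean and the random toy data are illustrative stand-ins for PriorMean and the extracted task data, not the actual implementation.

```python
import torch
from botorch.fit import fit_gpytorch_mll
from botorch.models import SingleTaskGP
from gpytorch.means import Mean
from gpytorch.mlls import ExactMarginalLogLikelihood


class SourcePosteriorMean(Mean):
    """Mean module returning the posterior mean of a fitted source GP."""

    def __init__(self, source_gp: SingleTaskGP):
        super().__init__()
        # Freeze the source GP so its hyperparameters are not re-optimized
        # while the target GP is being fitted.
        for p in source_gp.parameters():
            p.requires_grad_(False)
        self.source_gp = source_gp

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        with torch.no_grad():
            return self.source_gp.posterior(x).mean.squeeze(-1)


# Stage 1: fit the source GP on source-task data (no task dimension).
X_s = torch.rand(20, 3, dtype=torch.double)
y_s = torch.rand(20, 1, dtype=torch.double)
source_gp = SingleTaskGP(X_s, y_s, outcome_transform=None)
fit_gpytorch_mll(ExactMarginalLogLikelihood(source_gp.likelihood, source_gp))

# Stage 2: fit the target GP on target-task data, using the source GP's
# posterior mean as its prior mean.
X_t = torch.rand(5, 3, dtype=torch.double)
y_t = torch.rand(5, 1, dtype=torch.double)
target_gp = SingleTaskGP(
    X_t, y_t, mean_module=SourcePosteriorMean(source_gp), outcome_transform=None
)
fit_gpytorch_mll(ExactMarginalLogLikelihood(target_gp.likelihood, target_gp))
```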

Making Predictions (_posterior())

  1. The method receives candidates in computational representation, extracts batch dimensions, and initializes empty tensors for mean and covariance predictions.
  2. Candidates are filtered into source and target groups using the task parameter column.
  3. Separate Predictions:
    - Source predictions: Made using the self._source_gp
    - Target predictions: Made using the self._target_gp (which carries the source prior)
    Individual predictions are combined into unified mean and covariance tensors for the entire batch (see the sketch below).
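
For illustration, a sketch of how such routing can be done for the mean part of the prediction (covariances would be combined analogously, block by block). The function name, arguments, and the flat (n, d + 1) candidate layout are assumptions for the example and do not mirror the exact _posterior signature.

```python
import torch


def route_posterior_means(
    candidates: torch.Tensor,  # shape (n, d + 1): feature columns plus the task column
    task_idx: int,             # index of the task parameter column
    target_value: float,       # computational encoding of the target task
    source_gp,                 # fitted SingleTaskGP for the source task
    target_gp,                 # fitted SingleTaskGP carrying the source prior
) -> torch.Tensor:
    """Predict a mean per candidate, dispatching to the matching GP."""
    # Split candidates by task membership and drop the task column itself.
    is_target = candidates[:, task_idx] == target_value
    feature_cols = [i for i in range(candidates.shape[-1]) if i != task_idx]
    features = candidates[:, feature_cols]

    # Fill a unified mean tensor from the per-model predictions.
    mean = torch.empty(candidates.shape[0], dtype=candidates.dtype)
    with torch.no_grad():
        if is_target.any():
            mean[is_target] = target_gp.posterior(features[is_target]).mean.squeeze(-1)
        if (~is_target).any():
            mean[~is_target] = source_gp.posterior(features[~is_target]).mean.squeeze(-1)
    return mean
```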

New GaussianProcessSurrogate.from_prior_gp class method to create a GP surrogate by transferring the mean or covariance from a pretrained prior:

  • transfer_mode ("mean" or "kernel") defines whether the mean or the kernel of the prior should be used
  • new GaussianProcessSurrogate._initialize_model method that creates the _model depending on whether a prior is given or not
  • PriorMean and PriorKernel classes in baybe/surrogates/gaussian_process/prior_modules.py extract the posterior mean or covariance from a prior GP (see the usage sketch below)
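
A hedged usage sketch of the new class method, based on the signature and docstring quoted in the review comments below (prior_gp, transfer_mode, and an optional kernel factory); exact argument names, defaults, and import paths may differ in the merged version.

```python
import torch
from botorch.fit import fit_gpytorch_mll
from botorch.models import SingleTaskGP
from gpytorch.mlls import ExactMarginalLogLikelihood

from baybe.surrogates import GaussianProcessSurrogate

# Fit a prior GP on (made-up) source-task data.
X_s = torch.rand(20, 3, dtype=torch.double)
y_s = torch.rand(20, 1, dtype=torch.double)
prior_gp = SingleTaskGP(X_s, y_s)
fit_gpytorch_mll(ExactMarginalLogLikelihood(prior_gp.likelihood, prior_gp))

# "kernel" mode reuses the prior's trained covariance module; per the docstring
# quoted in the review below, no kernel factory is needed in this mode ("mean"
# mode would additionally take one for building the new covariance).
surrogate = GaussianProcessSurrogate.from_prior_gp(prior_gp, transfer_mode="kernel")
```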

Helper classes

  1. SourcePriorWrapperModel: Provides BoTorch compatibility by wrapping the SourcePriorGaussianProcessSurrogate for the _to_botorch() method.
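
For orientation, a minimal sketch of what such a BoTorch-compatible wrapper can look like: a Model subclass that forwards posterior queries to the fitted target GP after dropping the task column. The class name and attributes are hypothetical and do not reproduce the actual SourcePriorWrapperModel.

```python
import torch
from botorch.models.model import Model
from botorch.posteriors import Posterior


class TargetTaskWrapperModel(Model):
    """Minimal BoTorch model forwarding posterior queries to a fitted target GP."""

    def __init__(self, target_gp, task_idx: int):
        super().__init__()
        self.target_gp = target_gp  # fitted SingleTaskGP with the source prior
        self.task_idx = task_idx    # column index of the task parameter

    @property
    def num_outputs(self) -> int:
        return 1  # single-output setting, matching the limitation noted below

    def posterior(self, X: torch.Tensor, **kwargs) -> Posterior:
        # During recommendation all candidates belong to the target task, so it
        # suffices to drop the task column and delegate to the target GP.
        feature_cols = [i for i in range(X.shape[-1]) if i != self.task_idx]
        return self.target_gp.posterior(X[..., feature_cols], **kwargs)
```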

Current Limitations

  1. Single Active Value: The model expects exactly one active value for the task parameter in the search space.
  2. Single Source Limitation: The surrogate is limited to one source task. Extending to multiple sources would require developing methods to combine multiple source GPs into unified prior mean (and potentially covariance) functions.
  3. Single Output: Multi-output extensions haven't been considered yet.

This surrogate implements transfer learning by:
    1. Training a source GP on source task data (without task dimension)
    2. Using the source GP as a mean prior for the target GP
    3. Training the target GP on all data (source + target) with source-informed priors.
- Implement GPyTKernel class to wrap a pretrained kernel
- Comment out IndexKernel in the GPBuilder class (the reduced search space should not have a task parameter)
- Introduce flag in GPBuilder.create_gp to use prior mean and/or covariance, but raise NotImplementedError
@AdrianSosic changed the title from "Add initial version of the new SOurcePriorGaussianProcessSurrogate." to "Add initial version of the new SourcePriorGaussianProcessSurrogate" on Oct 27, 2025
@kalama-ai force-pushed the feature/source_prior_gp branch from be7aaf8 to dc0e801 on October 27, 2025 15:08
…L model

- New class method for creating a GP Surrogate from a given prior GP
- transfer_mode "mean" or "kernel" defines whether mean or kernel from prior should be used
- new GaussianProcessSurrogate._initialize_model method that will create the _model depending on
whether a prior is given or not
- New PriorMean and PriorKernel classes in baybe/surrogates/gaussian_process/prior_modules.py
(GPyTorch mean and covariance modules created from a prior GP)
- use new from_prior_gp method in SourcePriorGaussianProcessSurrogate
@kalama-ai marked this pull request as ready for review on October 28, 2025 16:59
@AVHopp requested a review from Copilot on October 29, 2025 09:31
@AVHopp (Collaborator) commented Oct 29, 2025

First of all: Appreciate the detailed PR description :)

Copilot AI (Contributor) left a comment

Pull Request Overview

This PR introduces a transfer learning implementation for Gaussian Process surrogates through the SourcePriorGaussianProcessSurrogate class. The implementation enables knowledge transfer from source tasks to target tasks using hierarchical Gaussian Process modeling.

Key changes include:

  • New transfer learning surrogate that trains separate source and target GPs with shared priors
  • Extension of GaussianProcessSurrogate with from_prior_gp class method for creating GP surrogates from pre-trained priors
  • New prior modules (PriorMean and PriorKernel) that extract posterior mean or covariance from existing GPs

Reviewed Changes

Copilot reviewed 3 out of 3 changed files in this pull request and generated 3 comments.

  • baybe/surrogates/transfer_learning/source_prior.py: Main implementation of the transfer learning surrogate with hierarchical GP training and prediction logic
  • baybe/surrogates/gaussian_process/prior_modules.py: Helper classes for extracting mean and kernel priors from trained GPs
  • baybe/surrogates/gaussian_process/core.py: Extended GaussianProcessSurrogate with transfer learning capabilities and prior-based initialization


Comment on lines +337 to +339:

print(
    "No target data provided. Using copy of source GP as target GP."
)

Copilot AI commented Oct 29, 2025:
Using print() statements for logging in production code is not recommended. Consider using the logging module for better control over output levels and destinations.
from copy import deepcopy

Copilot AI commented Oct 29, 2025:
Import statement should be moved to the top of the file with other imports rather than being inside a function.
Comment on lines +155 to +157:

from copy import deepcopy

from botorch.models import SingleTaskGP

Copilot AI commented Oct 29, 2025:
Import statements should be moved to the top of the file with other imports rather than being inside a method.
@AVHopp (Collaborator) left a comment:
First round of comments. Didn't manage to get through everything. I'd appreciate an example (just as a test) that can be used to see the code in action :)

@@ -0,0 +1,518 @@
"""Abstract base class for source prior surrogates."""

@AVHopp: I do not see any new tests. Is this new class automatically picked up by our machinery, or has the addition of a test been overlooked?


# Transfer learning fields
_prior_gp = field(init=False, default=None, eq=False)
"""Prior GP to extract mean/covariance from for transfer learning."""

@AVHopp suggested change:
- """Prior GP to extract mean/covariance from for transfer learning."""
+ """Prior GP to extract mean/covariance from transfer learning."""

_prior_gp = field(init=False, default=None, eq=False)
"""Prior GP to extract mean/covariance from for transfer learning."""

_transfer_mode: Literal["mean", "kernel"] | None = field(

@AVHopp: These new fields are only relevant for special contexts, but we now have them in the base class. Looking at the code, I have the feeling that a dedicated subclass of GaussianProcessSurrogate would be the cleaner option here. But I am not sure if that would be the preferred option and hence summon the master of all class design, @AdrianSosic.

Comment on lines +142 to +147
prior_gp: Fitted SingleTaskGP to use as prior
transfer_mode: "mean" extracts posterior mean as prior mean,
"kernel" uses prior's covariance
kernel_factory: Kernel factory for new covariance (required for mean mode,
ignored for kernel mode)
**kwargs: Additional arguments passed to GaussianProcessSurrogate constructor

@AVHopp: The indentation here looks a bit weird; does this render correctly?

def from_prior_gp(
cls,
prior_gp,
transfer_mode: Literal["mean", "kernel"] = "mean",

@AVHopp: Any particular reason to choose one over the other?

_target_gp: SingleTaskGP | None = field(init=False, default=None, eq=False)
"""Fitted target Gaussian Process model with source prior."""

def _identify_target_task(self) -> tuple[int, float]:

@AVHopp: I think you can get rid of some of the comments in the function here, e.g. the "Find the TaskParameter in the search space" one.

return task_idx, target_value

def _validate_transfer_learning_context(self) -> None:
"""Validate that we have a proper transfer learning setup.

@AVHopp: What exactly is meant by "proper transfer learning setup"? A few more details in the docstring would be nice.

if self._searchspace.task_idx is None:
raise ValueError(
"No task parameter found in search space. "
"SourcePriorGaussianProcessSurrogate requires a TaskParameter "

@AVHopp: Can't you use self.__class__.__name__ or type(self).__name__ or similar here?

target_task: Task ID for the target task.
Returns:
Tuple of (source_data_list, target_data_tuple).

@AVHopp: The return type looks quite complicated. Would it be easier to either simplify it (maybe by storing things together and separating them later?) or add a more detailed description of the exact format? There are e.g. also the Union types, which are not immediately clear.

def _fit(self, train_x: Tensor, train_y: Tensor) -> None:
"""Fit the transfer learning model.
This method handles the common training workflow for all transfergpbo models:

@AVHopp: What does transfergpbo refer to? I think that is the general name of the method/the paper. Can you maybe link this at some point or make it very explicit where the method is from?
