-
Notifications
You must be signed in to change notification settings - Fork 59
Use Botorch MultiTaskGP for transfer learning #549
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
AVHopp
wants to merge
26
commits into
main
Choose a base branch
from
tl_benchmarking_investigation
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
+133
−55
Open
Changes from all commits
Commits
Show all changes
26 commits
Select commit
Hold shift + click to select a range
ff3579e
Replace SingleTaskGP+IndexKernel with MultiTaskGP for transfer learning
Hrovatin 175b342
Expand tests with different types of available data
Hrovatin 678f11a
Improve the approach used to get the single task parameter
Hrovatin 09d25c4
Correct anti-pattern in test
Hrovatin 68de45d
Implement review comments
Hrovatin 2bb9fc0
Clarify why single active value is required
Hrovatin 221250b
Improve test description
Hrovatin 80213a8
Change TaskParameter computational representation to int
Hrovatin 1208995
Remove int conversion for new int-based comp_df of task parameter
Hrovatin 3a80d0d
Add test for transfer learning with multiple active task parameter va…
Hrovatin 93f422d
Remove constraint to use single active task parameter value
Hrovatin 74833a9
Update tests and assert that multiple active values are recommended
Hrovatin f83f8e1
Remove mypy errors
Hrovatin 18b081f
Remove check that both tasks were recommended as this may not always …
Hrovatin f6c947b
Update baybe/surrogates/gaussian_process/core.py
Hrovatin cb57237
Update tests/test_transfer_learning.py
Hrovatin 8e8cbb9
Remove unnecessary comments
Hrovatin 4b059e3
Clarify tests
Hrovatin c1d8946
Reuse parent method for integer casting
AdrianSosic 57c690f
Add temporary _task_parameter property to SearchSpace class
AdrianSosic f8578e4
Refactor GP fitting method
AdrianSosic 4292ddd
Refactor transfer learning tests using parametrization/fixtures
AdrianSosic cbb6da9
Update CHANGELOG.md
AdrianSosic 149b4d0
Use parametrization instead of request
AdrianSosic 5190477
Directly specify active_dims in kernel
Hrovatin 8db6a0a
Drop unnecessary arguments
AdrianSosic File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Some comments aren't visible on the classic Files Changed page.
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,68 @@ | ||
| """Tests for transfer learning.""" | ||
|
|
||
| from typing import Literal | ||
|
|
||
| import pandas as pd | ||
| import pytest | ||
|
|
||
| from baybe import Campaign | ||
| from baybe.parameters import NumericalContinuousParameter, TaskParameter | ||
| from baybe.recommenders.pure.bayesian.botorch import BotorchRecommender | ||
| from baybe.searchspace import SearchSpace | ||
| from baybe.targets import NumericalTarget | ||
|
|
||
|
|
||
| @pytest.fixture | ||
| def campaign( | ||
| training_data: Literal["source", "target", "both"], | ||
| active_tasks: Literal["target_only", "both"], | ||
| ) -> Campaign: | ||
| """A transfer-learning campaign with various active tasks and training data.""" | ||
| assert training_data in ["source", "target", "both"] | ||
| assert active_tasks in ["target_only", "both"] | ||
|
|
||
| source = "B" | ||
| target = "A" | ||
| parameters = [ | ||
| NumericalContinuousParameter("x", (0, 5)), | ||
| TaskParameter( | ||
| "task", | ||
| values=(target, source), | ||
| active_values=( | ||
| (target,) if active_tasks == "target_only" else (target, source) | ||
| ), | ||
| ), | ||
| ] | ||
| searchspace = SearchSpace.from_product(parameters=parameters) | ||
| objective = NumericalTarget(name="y").to_objective() | ||
| recommender = BotorchRecommender() | ||
| lookup = pd.DataFrame( | ||
| { | ||
| "x": [1.0, 2.0, 3.0, 4.0], | ||
| "y": [1.0, 2.0, 3.0, 4.0], | ||
| "task": [target] * 2 + [source] * 2, | ||
| } | ||
| ) | ||
|
|
||
| if training_data == "source": | ||
| lookup = lookup[lookup["task"] == source] | ||
| elif training_data == "target": | ||
| lookup = lookup[lookup["task"] == target] | ||
|
|
||
| campaign = Campaign( | ||
| searchspace=searchspace, | ||
| objective=objective, | ||
| recommender=recommender, | ||
| ) | ||
| campaign.add_measurements(lookup) | ||
|
|
||
| return campaign | ||
|
|
||
|
|
||
| @pytest.mark.parametrize("active_tasks", ["target_only", "both"]) | ||
| @pytest.mark.parametrize("training_data", ["source", "target", "both"]) | ||
| def test_recommendation(campaign: Campaign): | ||
| """Transfer learning recommendation works regardless of which task are | ||
| present in the training data and which tasks are active. | ||
| """ # noqa: D205 | ||
| campaign.recommend(1) |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.