Skip to content

Upload to gcs #108

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
178 changes: 153 additions & 25 deletions tabpfn_client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from tqdm import tqdm

from tabpfn_client.tabpfn_common_utils import utils as common_utils
from tabpfn_client.constants import CACHE_DIR
from tabpfn_client.constants import CACHE_DIR, CELL_THRESHOLD_LARGE_DATASET
from tabpfn_client.browser_auth import BrowserAuthHandler
from tabpfn_client.tabpfn_common_utils.utils import Singleton
Copy link

@Jabb0 Jabb0 May 7, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The new functionality is completely untested.
Add integration test cases that can be run against a real backend to ensure everything works.

Please add instructions on how to test this.

Copy link
Collaborator Author

@davidotte davidotte Jun 3, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should these integration tests just be for manual testing or should they be integrated into the automated pipeline? Because always testing against the real backend could be pretty slow.


Expand Down Expand Up @@ -182,6 +182,110 @@ def reset_authorization(cls):
cls._access_token = None
cls.httpx_client.headers.pop("Authorization", None)

@classmethod
def generate_upload_url(
    cls, file_type: Literal["x_train", "y_train", "x_test"]
) -> dict:
    """
    Generate a signed URL for direct dataset file upload to cloud storage.

    Parameters
    ----------
    file_type : Literal["x_train", "y_train", "x_test"]
        The type of the file to upload.

    Returns
    -------
    dict
        The server's JSON response; callers read the keys
        ``"signed_url"`` (target for the PUT upload) and
        ``"gcs_path"`` (object path to pass back to the fit/predict
        endpoints).
    """
    # Generic name: this endpoint serves every file_type, not just x_train.
    response = cls.httpx_client.post(
        url=cls.server_endpoints.generate_upload_url.path,
        params={"file_type": file_type},
    )
    cls._validate_response(response, "generate_upload_url")
    return response.json()

@classmethod
def upload_to_gcs(cls, serialized_data: bytes, signed_url: str):
    """
    Upload a serialized dataset file to cloud storage using a signed URL.

    Parameters
    ----------
    serialized_data : bytes
        The serialized dataset file (CSV-formatted bytes).
    signed_url : str
        The signed URL for the file upload.
    """
    response = cls.httpx_client.put(
        signed_url,
        # httpx takes raw request bodies via ``content=``; passing bytes
        # through ``data=`` is deprecated (``data=`` is for form fields).
        content=serialized_data,
        headers={"content-type": "text/csv"},
    )
    cls._validate_response(response, "upload_to_gcs")

@classmethod
def upload_train_set(
    cls,
    X_serialized: bytes,
    y_serialized: bytes,
    num_cells: int,
    tabpfn_systems: list[str],
):
    """
    Upload a train set to server and return the train set UID if successful.

    Datasets whose cell count exceeds ``CELL_THRESHOLD_LARGE_DATASET`` are
    first staged in cloud storage through signed URLs, and only their GCS
    paths are sent to the fit endpoint; smaller datasets are attached to
    the HTTP request directly.

    Parameters
    ----------
    X_serialized : bytes
        The serialized training input samples.
    y_serialized : bytes
        The serialized target values.
    num_cells : int
        The number of cells in the train set.
    tabpfn_systems : list[str]
        The tabpfn systems to use for the fit.

    Returns
    -------
    train_set_uid : str
        The unique ID of the train set in the server.
    """
    if num_cells > CELL_THRESHOLD_LARGE_DATASET:
        # Large dataset: stage both files in GCS, send only their paths.
        upload_info_x = cls.generate_upload_url("x_train")
        upload_info_y = cls.generate_upload_url("y_train")
        cls.upload_to_gcs(X_serialized, upload_info_x["signed_url"])
        cls.upload_to_gcs(y_serialized, upload_info_y["signed_url"])
        response = cls.httpx_client.post(
            url=cls.server_endpoints.fit.path,
            params={
                "x_gcs_path": upload_info_x["gcs_path"],
                "y_gcs_path": upload_info_y["gcs_path"],
                "tabpfn_systems": json.dumps(tabpfn_systems),
            },
        )
    else:
        # Small dataset: attach the serialized files to the request itself.
        file_payload = common_utils.to_httpx_post_file_format(
            [
                ("x_file", "x_train_filename", X_serialized),
                ("y_file", "y_train_filename", y_serialized),
            ]
        )
        response = cls.httpx_client.post(
            url=cls.server_endpoints.fit.path,
            files=file_payload,
            params={"tabpfn_systems": json.dumps(tabpfn_systems)},
        )

    cls._validate_response(response, "fit")
    return response.json()["train_set_uid"]

@classmethod
def fit(cls, X, y, config=None) -> str:
"""
Expand Down Expand Up @@ -223,20 +327,10 @@ def fit(cls, X, y, config=None) -> str:
if cached_dataset_uid:
return cached_dataset_uid

response = cls.httpx_client.post(
url=cls.server_endpoints.fit.path,
files=common_utils.to_httpx_post_file_format(
[
("x_file", "x_train_filename", X_serialized),
("y_file", "y_train_filename", y_serialized),
]
),
params={"tabpfn_systems": json.dumps(tabpfn_systems)},
num_cells = X.shape[0] * (X.shape[1] + 1)
train_set_uid = cls.upload_train_set(
X_serialized, y_serialized, num_cells, tabpfn_systems
)

cls._validate_response(response, "fit")

train_set_uid = response.json()["train_set_uid"]
cls.dataset_uid_cache_manager.add_dataset_uid(dataset_hash, train_set_uid)
return train_set_uid

Expand Down Expand Up @@ -269,6 +363,8 @@ def predict(

x_test_serialized = common_utils.serialize_to_csv_formatted_bytes(x_test)

num_cells = x_test.shape[0] * (x_test.shape[1] + 1)
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why the +1?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For y_test, which has shape (X_test.shape[0], 1)


params = {
"train_set_uid": train_set_uid,
"task": task,
Expand Down Expand Up @@ -303,7 +399,7 @@ def predict(
for attempt in range(max_attempts):
try:
with cls._make_prediction_request(
cached_test_set_uid, x_test_serialized, params
cached_test_set_uid, x_test_serialized, params, num_cells
) as response:
cls._validate_response(response, "predict")
# Handle updates from server
Expand Down Expand Up @@ -395,9 +491,24 @@ def run_progress():
return result

@classmethod
def _make_prediction_request(cls, test_set_uid, x_test_serialized, params):
def _make_prediction_request(
cls,
test_set_uid: Optional[str],
x_test_serialized: bytes,
params: dict,
num_cells: int,
):
"""
Helper function to make the prediction request to the server.
Helper function to upload test set if required and make the prediction request to the server.
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

missing parameter documentation.


Args:
test_set_uid: The unique ID of the train set in the server.
x_test_serialized: The serialized test set.
params: The parameters for the prediction request.
num_cells: The number of cells in the test set.

Returns:
response: Streaming response from the server.
"""
if test_set_uid:
params = params.copy()
Expand All @@ -406,14 +517,31 @@ def _make_prediction_request(cls, test_set_uid, x_test_serialized, params):
method="post", url=cls.server_endpoints.predict.path, params=params
)
else:
response = cls.httpx_client.stream(
method="post",
url=cls.server_endpoints.predict.path,
params=params,
files=common_utils.to_httpx_post_file_format(
[("x_file", "x_test_filename", x_test_serialized)]
),
)
if num_cells > CELL_THRESHOLD_LARGE_DATASET:
# Upload to GCS.
url_response = cls.generate_upload_url("x_test")
cls.upload_to_gcs(x_test_serialized, url_response["signed_url"])
# Make prediction request.
response = cls.httpx_client.stream(
method="post",
url=cls.server_endpoints.predict.path,
params={
**params,
"x_gcs_path": url_response["gcs_path"],
},
)
else:
# Small dataset, so upload directly.
response = cls.httpx_client.stream(
method="post",
url=cls.server_endpoints.predict.path,
files=common_utils.to_httpx_post_file_format(
[
("x_file", "x_test_filename", x_test_serialized),
]
),
params=params,
)
return response

@staticmethod
Expand Down
2 changes: 2 additions & 0 deletions tabpfn_client/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,5 @@
from pathlib import Path

# Local on-disk cache directory for the tabpfn client (presumably holds
# dataset-UID mappings and auth state -- TODO confirm against client.py).
CACHE_DIR = Path(__file__).parent.resolve() / ".tabpfn"

# Datasets with more cells than this are uploaded to cloud storage via
# signed URLs instead of being attached to the HTTP request directly.
# Digit-grouping underscore for readability; value unchanged (500000).
CELL_THRESHOLD_LARGE_DATASET = 500_000
5 changes: 5 additions & 0 deletions tabpfn_client/server_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -113,4 +113,9 @@ endpoints:
path: "/get_api_usage/"
methods: [ "POST" ]
description: "Get prediction hits data for a given user"

# Issues a signed cloud-storage URL so large datasets can be uploaded
# directly instead of being attached to the fit/predict request.
generate_upload_url:
  path: "/generate-upload-url/"
  methods: ["POST"]
  description: "Generate a signed URL for direct dataset file upload to cloud storage"