Upload to gcs #108

Open. Wants to merge 4 commits into base: main
120 changes: 98 additions & 22 deletions tabpfn_client/client.py
@@ -22,7 +22,7 @@
from tqdm import tqdm

from tabpfn_client.tabpfn_common_utils import utils as common_utils
from tabpfn_client.constants import CACHE_DIR
from tabpfn_client.constants import CACHE_DIR, LARGE_DATASET_THRESHOLD
from tabpfn_client.browser_auth import BrowserAuthHandler
from tabpfn_client.tabpfn_common_utils.utils import Singleton
Review comment by @Jabb0 (May 7, 2025):
The new functionality is completely untested.
Add integration test cases that can be run against a real backend to ensure everything works.

Please add instructions on how to test this.


@@ -223,16 +223,57 @@ def fit(cls, X, y, config=None) -> str:
if cached_dataset_uid:
return cached_dataset_uid

response = cls.httpx_client.post(
url=cls.server_endpoints.fit.path,
files=common_utils.to_httpx_post_file_format(
[
("x_file", "x_train_filename", X_serialized),
("y_file", "y_train_filename", y_serialized),
]
),
params={"tabpfn_systems": json.dumps(tabpfn_systems)},
)
num_cells = X.shape[0] * (X.shape[1] + 1)
if num_cells > LARGE_DATASET_THRESHOLD:
Review comment: consider moving this whole block to a helper method to keep the fit method clean.

# Generate Upload URLs
response_x = cls.httpx_client.post(
Review comment: probably a good idea to hide the low-level HTTP calls and the JSON handling behind a method, ideally on an API class to abstract away these details.

url=cls.server_endpoints.generate_upload_url.path,
json={"file_type": "x_train", "file_name": "x_train_filename"},
)
cls._validate_response(response_x, "fit")
response_x = response_x.json()

response_y = cls.httpx_client.post(
url=cls.server_endpoints.generate_upload_url.path,
json={"file_type": "y_train", "file_name": "y_train_filename"},
)
cls._validate_response(response_y, "fit")
response_y = response_y.json()

# Upload train and test set to GCS
response = cls.httpx_client.put(
response_x["signed_url"],
data=X_serialized,
headers={"content-type": "text/csv"},
)
cls._validate_response(response, "fit")
response = cls.httpx_client.put(
response_y["signed_url"],
data=y_serialized,
headers={"content-type": "text/csv"},
)
cls._validate_response(response, "fit")
# Call fit endpoint
response = cls.httpx_client.post(
url=cls.server_endpoints.fit.path,
params={
"x_gcs_path": response_x["gcs_path"],
"y_gcs_path": response_y["gcs_path"],
"tabpfn_systems": json.dumps(tabpfn_systems),
},
)
else:
# Small dataset, so upload directly.
response = cls.httpx_client.post(
url=cls.server_endpoints.fit.path,
files=common_utils.to_httpx_post_file_format(
[
("x_file", "x_train_filename", X_serialized),
("y_file", "y_train_filename", y_serialized),
]
),
params={"tabpfn_systems": json.dumps(tabpfn_systems)},
)

cls._validate_response(response, "fit")

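The helper extraction suggested in the review above could look roughly like this. The endpoint path and the "signed_url"/"gcs_path" keys mirror the diff, but the helper itself is hypothetical and not part of the PR:

```python
# Hypothetical helper, not part of the PR: request a signed upload URL,
# PUT the serialized CSV bytes to it, and return the resulting GCS path.
# The endpoint path and the "signed_url"/"gcs_path" keys mirror the diff.
def upload_via_signed_url(client, file_type: str, file_name: str, data: bytes) -> str:
    url_response = client.post(
        url="/generate-upload-url/",
        json={"file_type": file_type, "file_name": file_name},
    )
    url_response.raise_for_status()
    payload = url_response.json()

    put_response = client.put(
        payload["signed_url"],
        data=data,
        headers={"content-type": "text/csv"},
    )
    put_response.raise_for_status()
    return payload["gcs_path"]
```

With something like this, the large-dataset branch of fit would reduce to two helper calls followed by the POST to the fit endpoint.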
@@ -269,6 +310,8 @@ def predict(

x_test_serialized = common_utils.serialize_to_csv_formatted_bytes(x_test)

num_cells = x_test.shape[0] * (x_test.shape[1] + 1)
Review comment: why the +1?

Reply (Collaborator Author): For y_test, which has shape (X_test.shape[0], 1)
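The thread above can be checked with a little arithmetic. The threshold value comes from constants.py in this diff; the shapes are made up for illustration:

```python
# The cell count includes the single target column, hence the "+ 1".
LARGE_DATASET_THRESHOLD = 500_000  # value from tabpfn_client/constants.py in this PR

def needs_gcs_upload(n_rows: int, n_features: int) -> bool:
    num_cells = n_rows * (n_features + 1)  # features plus one y column
    return num_cells > LARGE_DATASET_THRESHOLD

# 10_000 rows x 20 features -> 210_000 cells: direct multipart upload
# 50_000 rows x 20 features -> 1_050_000 cells: signed-URL upload to GCS
```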


params = {
"train_set_uid": train_set_uid,
"task": task,
@@ -303,7 +346,7 @@ def predict(
for attempt in range(max_attempts):
try:
with cls._make_prediction_request(
cached_test_set_uid, x_test_serialized, params
cached_test_set_uid, x_test_serialized, params, num_cells
) as response:
cls._validate_response(response, "predict")
# Handle updates from server
@@ -395,9 +438,11 @@ def run_progress():
return result

@classmethod
def _make_prediction_request(cls, test_set_uid, x_test_serialized, params):
def _make_prediction_request(
cls, test_set_uid, x_test_serialized, params, num_cells
Review comment: missing annotations.

Review comment: this setup is not ideal; there is now code duplication for uploading datasets.

It would be better if uploading were independent of predict, so the user can:

  1. upload a data set (untyped as train or test)
  2. call fit or predict on an arbitrary dataset

Reply by @davidotte (May 12, 2025): This is again not related to my PR, right? But yes, we can think more about this in the future.

):
"""
Helper function to make the prediction request to the server.
Helper function to upload test set if required and make the prediction request to the server.
Review comment: missing parameter documentation.

"""
if test_set_uid:
params = params.copy()
@@ -406,14 +451,45 @@ def _make_prediction_request(cls, test_set_uid, x_test_serialized, params):
method="post", url=cls.server_endpoints.predict.path, params=params
)
else:
response = cls.httpx_client.stream(
method="post",
url=cls.server_endpoints.predict.path,
params=params,
files=common_utils.to_httpx_post_file_format(
[("x_file", "x_test_filename", x_test_serialized)]
),
)
if num_cells > LARGE_DATASET_THRESHOLD:
# Generate upload URL
url_response = cls.httpx_client.post(
url=cls.server_endpoints.generate_upload_url.path,
json={
"file_type": "x_test",
"file_name": "x_test_filename",
},
)
cls._validate_response(url_response, "predict")
Review comment: I have concerns about the validate response method.

  1. It does not validate anything if the status code is 200.
  2. It just drops JSON decode errors.
  3. The version checking seems like it should be middleware rather than being called for every API request.
  4. It silently drops a lot of exceptions.

Reply (Collaborator Author): I agree, we should definitely tackle this in the future.

url_response = url_response.json()
# Upload test set to GCS
response = cls.httpx_client.put(
Review comment: isn't there a GCS SDK method to do this? Maybe it could handle the upload more efficiently.

url_response["signed_url"],
data=x_test_serialized,
headers={"content-type": "text/csv"},
)
cls._validate_response(response, "predict")
# Make prediction request
response = cls.httpx_client.stream(
method="post",
url=cls.server_endpoints.predict.path,
params={
**params,
"x_gcs_path": url_response["gcs_path"],
},
)
else:
# Small dataset, so upload directly.
response = cls.httpx_client.stream(
method="post",
url=cls.server_endpoints.predict.path,
files=common_utils.to_httpx_post_file_format(
[
("x_file", "x_test_filename", x_test_serialized),
]
),
params=params,
)
return response
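One way to realize the "upload first, then fit or predict" proposal from the review thread above. Every name here is hypothetical; this is a design sketch, not the client's API:

```python
# Hypothetical sketch of decoupling upload from fit/predict: a dataset is
# uploaded once (untyped as train or test), and the returned handle is then
# passed to either call. None of these names exist in the PR.
from dataclasses import dataclass
from typing import Optional


@dataclass(frozen=True)
class DatasetHandle:
    gcs_path: Optional[str]   # set when the data went through cloud storage
    payload: Optional[bytes]  # set when the data is small enough to inline

    @property
    def is_remote(self) -> bool:
        return self.gcs_path is not None


def upload_dataset(serialized: bytes, num_cells: int,
                   threshold: int = 500_000) -> DatasetHandle:
    """Decide once, at upload time, how the dataset travels to the server."""
    if num_cells > threshold:
        # The real client would run the signed-URL flow here and store
        # the path returned by the backend.
        return DatasetHandle(gcs_path="gs://bucket/<uploaded-object>", payload=None)
    return DatasetHandle(gcs_path=None, payload=serialized)
```

fit and predict would then take a DatasetHandle and no longer need their own copies of the upload branching, which removes the duplication the reviewer points out.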

@staticmethod
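The concerns raised in the _validate_response thread could be addressed along these lines. This is a sketch under the assumption that raising on any non-2xx code and surfacing decode errors is acceptable, not the client's actual method:

```python
# Sketch of a stricter response check: reject non-2xx outright and surface
# JSON decode failures instead of swallowing them. Version checking is
# deliberately omitted here; per the review, that belongs in middleware.
import json


class ServerError(RuntimeError):
    """Raised when the backend rejects a request or returns a bad payload."""


def validate_response(status_code: int, body: bytes, operation: str) -> dict:
    if not 200 <= status_code < 300:
        raise ServerError(f"{operation} failed with HTTP {status_code}")
    try:
        return json.loads(body)
    except json.JSONDecodeError as err:
        # Do not silently drop malformed payloads.
        raise ServerError(f"{operation} returned invalid JSON") from err
```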
2 changes: 2 additions & 0 deletions tabpfn_client/constants.py
@@ -4,3 +4,5 @@
from pathlib import Path

CACHE_DIR = Path(__file__).parent.resolve() / ".tabpfn"

LARGE_DATASET_THRESHOLD = 500000
Review comment: missing unit.

5 changes: 5 additions & 0 deletions tabpfn_client/server_config.yaml
@@ -113,4 +113,9 @@ endpoints:
path: "/get_api_usage/"
methods: [ "POST" ]
description: "Get prediction hits data for a given user"

generate_upload_url:
path: "/generate-upload-url/"
methods: ["POST"]
description: "Generate a signed URL for direct dataset file upload to cloud storage"
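For reference, the request and response shapes this endpoint appears to use, inferred from the calls in tabpfn_client/client.py. The backend contract itself is not part of this PR, so treat these keys as assumptions:

```python
# Inferred from the client-side calls in this diff; not an authoritative
# backend specification.
request_body = {
    "file_type": "x_train",  # also "y_train" or "x_test" in the diff
    "file_name": "x_train_filename",
}
response_body = {
    "signed_url": "<URL the client PUTs the CSV bytes to>",
    "gcs_path": "<path the client forwards to the fit or predict endpoint>",
}
```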