Upload to gcs #108

Open

wants to merge 4 commits into main

Conversation

davidotte (Collaborator)

Change Description

Try to be precise. You can additionally add comments to your PR; this can help the reviewer a lot.

PR corresponds to this server PR.

If you used new dependencies: Did you add them to requirements.txt?

Who did you ping on Mattermost to review your PR? Please ping that person again whenever you are ready for another review.

Breaking changes

If you made any breaking changes, please update the version number.
Breaking changes are totally fine; we just need to keep the users informed and the server in sync.

Does this PR break the API? If so, what is the corresponding server commit?

Does this PR break the user interface? If so, why?


Please do not mark comments/conversations as resolved unless you are the assigned reviewer. This helps maintain clarity during the review process.

CLAassistant commented May 1, 2025

CLA assistant check
All committers have signed the CLA.

@davidotte davidotte requested review from Jabb0 and noahho May 1, 2025 14:10
@@ -4,3 +4,5 @@
from pathlib import Path

CACHE_DIR = Path(__file__).parent.resolve() / ".tabpfn"

LARGE_DATASET_THRESHOLD = 500000

missing unit.
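
For example, the unit could be documented on the constant itself; a minimal sketch, assuming the threshold is meant as a cell count (matching the num_cells check below):

# Unit: number of cells, i.e. rows * (columns + 1), above which data is uploaded to GCS.
LARGE_DATASET_THRESHOLD = 500000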

num_cells = X.shape[0] * (X.shape[1] + 1)
if num_cells > LARGE_DATASET_THRESHOLD:
    # Generate Upload URLs
    response_x = cls.httpx_client.post(

Probably a good idea to hide the low-level HTTP calls and the JSON handling behind a method, ideally on an API class that abstracts away these details.
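
For illustration, a rough sketch of such a wrapper; the class and method names are hypothetical and not part of the current code:

# Hypothetical wrapper class; only illustrates the shape of the abstraction.
class ServerAPI:
    def __init__(self, httpx_client, base_url: str):
        self.httpx_client = httpx_client
        self.base_url = base_url

    def post_json(self, path: str, **kwargs) -> dict:
        """POST to the server and return the decoded JSON body."""
        response = self.httpx_client.post(f"{self.base_url}{path}", **kwargs)
        response.raise_for_status()
        return response.json()

# Call sites would then shrink to something like:
# upload_urls = api.post_json("/upload/", params={"tabpfn_systems": json.dumps(tabpfn_systems)})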

params={"tabpfn_systems": json.dumps(tabpfn_systems)},
)
num_cells = X.shape[0] * (X.shape[1] + 1)
if num_cells > LARGE_DATASET_THRESHOLD:

Consider moving this whole block into a helper method to keep the fit method clean; a possible sketch is below.
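
A minimal sketch of such a helper, assuming it lives on the same class; the helper name and the URL-request step are hypothetical:

@classmethod
def _upload_if_large(cls, serialized_data: bytes, num_cells: int) -> None:
    """Upload the serialized dataset to GCS when it exceeds the threshold."""
    if num_cells <= LARGE_DATASET_THRESHOLD:
        return
    # _request_upload_url is hypothetical; it would wrap the POST shown above.
    upload_url = cls._request_upload_url()
    cls.httpx_client.put(upload_url, content=serialized_data)

fit() would then only call this helper instead of inlining the branch.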

"file_name": "x_test_filename",
},
)
cls._validate_response(url_response, "predict")

I have concerns about the validate response method; a sketch of a stricter variant is below.

  1. It does not validate anything if the status code is 200.
  2. It just drops JSON decode errors.
  3. Seems like the version checking should be middleware and not called for every API request.
  4. It silently drops a lot of exceptions.
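
To make this concrete, a hedged sketch of a stricter variant; the error-body shape ("detail" key) is an assumption about the server contract:

@classmethod
def _validate_response(cls, response, method_name: str) -> None:
    # Raise on every non-success status instead of silently continuing.
    # (json is already imported in this module.)
    if response.status_code != 200:
        try:
            detail = response.json().get("detail", response.text)
        except json.JSONDecodeError:
            detail = response.text  # keep the raw body rather than dropping the error
        raise RuntimeError(f"{method_name} failed ({response.status_code}): {detail}")
    # Client/server version checks would move into an httpx event hook
    # (middleware) rather than being repeated for every request.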

cls._validate_response(url_response, "predict")
url_response = url_response.json()
# Upload test set to GCS
response = cls.httpx_client.put(

Hmm, is there not a GCS SDK method to do this? Maybe it could handle the upload more efficiently?
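
For comparison, a sketch of the official client-library upload; note it assumes GCP credentials on the caller's side, which the signed-URL PUT in this PR avoids, so it may not be applicable here:

# Sketch only: requires the google-cloud-storage package and client credentials.
from google.cloud import storage

def upload_via_sdk(bucket_name: str, blob_name: str, data: bytes) -> None:
    client = storage.Client()
    blob = client.bucket(bucket_name).blob(blob_name)
    blob.upload_from_string(data, content_type="text/csv")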

@@ -395,9 +438,11 @@ def run_progress():
return result

@classmethod
def _make_prediction_request(cls, test_set_uid, x_test_serialized, params):
def _make_prediction_request(
cls, test_set_uid, x_test_serialized, params, num_cells

Missing type annotations; for example:
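
A possible annotated signature; the types are inferred from how the arguments appear to be used in this PR, so treat them as assumptions:

@classmethod
def _make_prediction_request(
    cls,
    test_set_uid: str,
    x_test_serialized: bytes,
    params: dict,
    num_cells: int,
) -> dict:  # return type assumed to be the decoded JSON response
    ...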

"""
Helper function to make the prediction request to the server.
Helper function to upload test set if required and make the prediction request to the server.

Missing parameter documentation; for example:
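
The docstring could spell out the parameters; the descriptions below are assumptions inferred from the surrounding diff:

"""
Upload the test set if required and make the prediction request to the server.

Parameters
----------
test_set_uid : str
    Identifier of the test set already registered with the server.
x_test_serialized : bytes
    Test data serialized as CSV-formatted bytes.
params : dict
    Additional prediction parameters forwarded to the server.
num_cells : int
    Cell count of the test set, used to decide whether to upload it to GCS first.
"""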

@@ -269,6 +310,8 @@ def predict(

x_test_serialized = common_utils.serialize_to_csv_formatted_bytes(x_test)

num_cells = x_test.shape[0] * (x_test.shape[1] + 1)

why the +1?

@@ -395,9 +438,11 @@ def run_progress():
return result

@classmethod
def _make_prediction_request(cls, test_set_uid, x_test_serialized, params):
def _make_prediction_request(
cls, test_set_uid, x_test_serialized, params, num_cells

This setup is not ideal: there is now code duplication for uploading datasets.

It would be better if uploading were independent of predict, so the user can (see the sketch after this list):

  1. upload a dataset (untyped, i.e. not yet designated as train or test)
  2. call fit or predict on an arbitrary uploaded dataset.
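
A rough sketch of what that decoupling could look like; all names here are hypothetical and not part of the current client:

class DatasetAPI:
    @classmethod
    def upload_dataset(cls, X) -> str:
        """Upload any dataset (train or test alike) and return its server UID."""
        serialized = common_utils.serialize_to_csv_formatted_bytes(X)
        # _request_upload_url is hypothetical; it would return a signed URL plus a UID.
        upload_url, dataset_uid = cls._request_upload_url()
        cls.httpx_client.put(upload_url, content=serialized)
        return dataset_uid

# fit/predict would then accept a dataset UID instead of uploading themselves:
# uid = DatasetAPI.upload_dataset(x_test)
# ServiceClient.predict(uid, ...)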

@@ -22,7 +22,7 @@
from tqdm import tqdm

from tabpfn_client.tabpfn_common_utils import utils as common_utils
from tabpfn_client.constants import CACHE_DIR
from tabpfn_client.constants import CACHE_DIR, LARGE_DATASET_THRESHOLD
from tabpfn_client.browser_auth import BrowserAuthHandler
from tabpfn_client.tabpfn_common_utils.utils import Singleton
@Jabb0 Jabb0 May 7, 2025

The new functionality is completely untested.
Add integration test cases that can be run against a real backend to ensure everything works; a possible starting point is sketched below.

Please add instructions on how to test this.
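
A minimal integration-test sketch as a starting point; the client entry point and the pytest marker are assumptions and would need to match this repo's test setup:

import numpy as np
import pytest

from tabpfn_client.constants import LARGE_DATASET_THRESHOLD


@pytest.mark.integration  # assumed marker; requires a real backend and a valid login
def test_fit_predict_uploads_large_dataset():
    n_features = 10
    # Just above the threshold so the GCS upload path is exercised.
    n_rows = LARGE_DATASET_THRESHOLD // (n_features + 1) + 1
    X = np.random.rand(n_rows, n_features)
    y = np.random.randint(0, 2, size=n_rows)

    from tabpfn_client import TabPFNClassifier  # assumed public entry point
    clf = TabPFNClassifier()
    clf.fit(X, y)
    predictions = clf.predict(X)

    assert predictions.shape == (n_rows,)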
