-
Notifications
You must be signed in to change notification settings - Fork 17
Upload to gcs #108
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Upload to gcs #108
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -22,7 +22,7 @@ | |
from tqdm import tqdm | ||
|
||
from tabpfn_client.tabpfn_common_utils import utils as common_utils | ||
from tabpfn_client.constants import CACHE_DIR | ||
from tabpfn_client.constants import CACHE_DIR, LARGE_DATASET_THRESHOLD | ||
from tabpfn_client.browser_auth import BrowserAuthHandler | ||
from tabpfn_client.tabpfn_common_utils.utils import Singleton | ||
|
||
|
@@ -223,16 +223,57 @@ def fit(cls, X, y, config=None) -> str: | |
if cached_dataset_uid: | ||
return cached_dataset_uid | ||
|
||
response = cls.httpx_client.post( | ||
url=cls.server_endpoints.fit.path, | ||
files=common_utils.to_httpx_post_file_format( | ||
[ | ||
("x_file", "x_train_filename", X_serialized), | ||
("y_file", "y_train_filename", y_serialized), | ||
] | ||
), | ||
params={"tabpfn_systems": json.dumps(tabpfn_systems)}, | ||
) | ||
num_cells = X.shape[0] * (X.shape[1] + 1) | ||
if num_cells > LARGE_DATASET_THRESHOLD: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. consider moving this whole block to a helper method to keep the fit method clean. |
||
# Generate Upload URLs | ||
response_x = cls.httpx_client.post( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Probably a good idea to hide the low-level HTTP calls and the conversion to JSON behind a method. Ideally a method of an API class, to abstract away these details. |
||
url=cls.server_endpoints.generate_upload_url.path, | ||
json={"file_type": "x_train", "file_name": "x_train_filename"}, | ||
) | ||
cls._validate_response(response_x, "fit") | ||
response_x = response_x.json() | ||
|
||
response_y = cls.httpx_client.post( | ||
url=cls.server_endpoints.generate_upload_url.path, | ||
json={"file_type": "y_train", "file_name": "y_train_filename"}, | ||
) | ||
cls._validate_response(response_y, "fit") | ||
response_y = response_y.json() | ||
|
||
# Upload train and test set to GCS | ||
response = cls.httpx_client.put( | ||
response_x["signed_url"], | ||
data=X_serialized, | ||
headers={"content-type": "text/csv"}, | ||
) | ||
cls._validate_response(response, "fit") | ||
response = cls.httpx_client.put( | ||
response_y["signed_url"], | ||
data=y_serialized, | ||
headers={"content-type": "text/csv"}, | ||
) | ||
cls._validate_response(response, "fit") | ||
# Call fit endpoint | ||
response = cls.httpx_client.post( | ||
url=cls.server_endpoints.fit.path, | ||
params={ | ||
"x_gcs_path": response_x["gcs_path"], | ||
"y_gcs_path": response_y["gcs_path"], | ||
"tabpfn_systems": json.dumps(tabpfn_systems), | ||
}, | ||
) | ||
else: | ||
# Small dataset, so upload directly. | ||
response = cls.httpx_client.post( | ||
url=cls.server_endpoints.fit.path, | ||
files=common_utils.to_httpx_post_file_format( | ||
[ | ||
("x_file", "x_train_filename", X_serialized), | ||
("y_file", "y_train_filename", y_serialized), | ||
] | ||
), | ||
params={"tabpfn_systems": json.dumps(tabpfn_systems)}, | ||
) | ||
|
||
cls._validate_response(response, "fit") | ||
|
||
|
@@ -269,6 +310,8 @@ def predict( | |
|
||
x_test_serialized = common_utils.serialize_to_csv_formatted_bytes(x_test) | ||
|
||
num_cells = x_test.shape[0] * (x_test.shape[1] + 1) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. why the +1? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. For y_test, which has shape (X_test.shape[0], 1) |
||
|
||
params = { | ||
"train_set_uid": train_set_uid, | ||
"task": task, | ||
|
@@ -303,7 +346,7 @@ def predict( | |
for attempt in range(max_attempts): | ||
try: | ||
with cls._make_prediction_request( | ||
cached_test_set_uid, x_test_serialized, params | ||
cached_test_set_uid, x_test_serialized, params, num_cells | ||
) as response: | ||
cls._validate_response(response, "predict") | ||
# Handle updates from server | ||
|
@@ -395,9 +438,11 @@ def run_progress(): | |
return result | ||
|
||
@classmethod | ||
def _make_prediction_request(cls, test_set_uid, x_test_serialized, params): | ||
def _make_prediction_request( | ||
cls, test_set_uid, x_test_serialized, params, num_cells | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. missing annotations. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this setup is not ideal. It would be better if uploading is independent of predict. So the user can:
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is again not related to my PR, right? But yes we can think more about this in the future. |
||
): | ||
""" | ||
Helper function to make the prediction request to the server. | ||
Helper function to upload test set if required and make the prediction request to the server. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. missing parameter documentation. |
||
""" | ||
if test_set_uid: | ||
params = params.copy() | ||
|
@@ -406,14 +451,45 @@ def _make_prediction_request(cls, test_set_uid, x_test_serialized, params): | |
method="post", url=cls.server_endpoints.predict.path, params=params | ||
) | ||
else: | ||
response = cls.httpx_client.stream( | ||
method="post", | ||
url=cls.server_endpoints.predict.path, | ||
params=params, | ||
files=common_utils.to_httpx_post_file_format( | ||
[("x_file", "x_test_filename", x_test_serialized)] | ||
), | ||
) | ||
if num_cells > LARGE_DATASET_THRESHOLD: | ||
# Generate upload URL | ||
url_response = cls.httpx_client.post( | ||
url=cls.server_endpoints.generate_upload_url.path, | ||
json={ | ||
"file_type": "x_test", | ||
"file_name": "x_test_filename", | ||
}, | ||
) | ||
cls._validate_response(url_response, "predict") | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I have concerns about the validate response method.
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I agree, we should definitely tackle this in the future. |
||
url_response = url_response.json() | ||
# Upload test set to GCS | ||
response = cls.httpx_client.put( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Hmm, isn't there a GCS SDK method to do this? Maybe it could handle the upload more efficiently? |
||
url_response["signed_url"], | ||
data=x_test_serialized, | ||
headers={"content-type": "text/csv"}, | ||
) | ||
cls._validate_response(response, "predict") | ||
# Make prediction request | ||
response = cls.httpx_client.stream( | ||
method="post", | ||
url=cls.server_endpoints.predict.path, | ||
params={ | ||
**params, | ||
"x_gcs_path": url_response["gcs_path"], | ||
}, | ||
) | ||
else: | ||
# Small dataset, so upload directly. | ||
response = cls.httpx_client.stream( | ||
method="post", | ||
url=cls.server_endpoints.predict.path, | ||
files=common_utils.to_httpx_post_file_format( | ||
[ | ||
("x_file", "x_test_filename", x_test_serialized), | ||
] | ||
), | ||
params=params, | ||
) | ||
return response | ||
|
||
@staticmethod | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -4,3 +4,5 @@ | |
from pathlib import Path | ||
|
||
CACHE_DIR = Path(__file__).parent.resolve() / ".tabpfn" | ||
|
||
LARGE_DATASET_THRESHOLD = 500000 | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. missing unit. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The new functionality is completely untested.
Add integration test cases that can be run against a real backend to ensure everything works.
Please add instructions on how to test this.