-
Notifications
You must be signed in to change notification settings - Fork 17
Upload to gcs #108
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Upload to gcs #108
Conversation
@@ -4,3 +4,5 @@ | |||
from pathlib import Path | |||
|
|||
CACHE_DIR = Path(__file__).parent.resolve() / ".tabpfn" | |||
|
|||
LARGE_DATASET_THRESHOLD = 500000 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
missing unit.
num_cells = X.shape[0] * (X.shape[1] + 1) | ||
if num_cells > LARGE_DATASET_THRESHOLD: | ||
# Generate Upload URLs | ||
response_x = cls.httpx_client.post( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
probably a good idea to hide the low level HTTP calls and the to JSON behind a method. Ideally of an API class to abstact way these details.
params={"tabpfn_systems": json.dumps(tabpfn_systems)}, | ||
) | ||
num_cells = X.shape[0] * (X.shape[1] + 1) | ||
if num_cells > LARGE_DATASET_THRESHOLD: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
consider moving this whole block to a helper method to keep the fit method clean.
"file_name": "x_test_filename", | ||
}, | ||
) | ||
cls._validate_response(url_response, "predict") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I have concerns about the validate response method.
- It does not validate anything if status code is 200.
- It just drops JSON decode errors.
- Seems like the version checking should be a middle-ware and not called for every api request.
- Silently drops a lot of exceptions.
cls._validate_response(url_response, "predict") | ||
url_response = url_response.json() | ||
# Upload test set to GCS | ||
response = cls.httpx_client.put( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
uhm is there not a GCS SDK method to do this? Maybe this can handle the upload more efficiently?
@@ -395,9 +438,11 @@ def run_progress(): | |||
return result | |||
|
|||
@classmethod | |||
def _make_prediction_request(cls, test_set_uid, x_test_serialized, params): | |||
def _make_prediction_request( | |||
cls, test_set_uid, x_test_serialized, params, num_cells |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
missing annotations.
""" | ||
Helper function to make the prediction request to the server. | ||
Helper function to upload test set if required and make the prediction request to the server. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
missing parameter documentation.
@@ -269,6 +310,8 @@ def predict( | |||
|
|||
x_test_serialized = common_utils.serialize_to_csv_formatted_bytes(x_test) | |||
|
|||
num_cells = x_test.shape[0] * (x_test.shape[1] + 1) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
why the +1?
@@ -395,9 +438,11 @@ def run_progress(): | |||
return result | |||
|
|||
@classmethod | |||
def _make_prediction_request(cls, test_set_uid, x_test_serialized, params): | |||
def _make_prediction_request( | |||
cls, test_set_uid, x_test_serialized, params, num_cells |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this setup is not ideal.
There is now code duplication for uploading datasets.
It would be better if uploading is independent of predict. So the user can:
- upload data set (untyped if train or test)
- call fit or predict on an arbitrary dataset.
@@ -22,7 +22,7 @@ | |||
from tqdm import tqdm | |||
|
|||
from tabpfn_client.tabpfn_common_utils import utils as common_utils | |||
from tabpfn_client.constants import CACHE_DIR | |||
from tabpfn_client.constants import CACHE_DIR, LARGE_DATASET_THRESHOLD | |||
from tabpfn_client.browser_auth import BrowserAuthHandler | |||
from tabpfn_client.tabpfn_common_utils.utils import Singleton |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The new functionality is completely untested.
Add integration test cases that can be run against a real backend to ensure everything works.
Please add instructions how to test this.
Change Description
Try to be precise. You can additionally add comments to your PR, this might help the reviewer a lot.
PR corresponds to this server PR.
If you used new dependencies: Did you add them to
requirements.txt
?Who did you ping on Mattermost to review your PR? Please ping that person again whenever you are ready for another review.
Breaking changes
If you made any breaking changes, please update the version number.
Breaking changes are totally fine, we just need to make sure to keep the users informed and the server in sync.
Does this PR break the API? If so, what is the corresponding server commit?
Does this PR break the user interface? If so, why?
Please do not mark comments/conversations as resolved unless you are the assigned reviewer. This helps maintain clarity during the review process.