Upload to gcs #108
base: main
Changes from all commits
d0ca019
885439d
12e08f4
b664ffa
bb69665
fc699f6
@@ -22,7 +22,7 @@
 from tqdm import tqdm

 from tabpfn_client.tabpfn_common_utils import utils as common_utils
-from tabpfn_client.constants import CACHE_DIR
+from tabpfn_client.constants import CACHE_DIR, CELL_THRESHOLD_LARGE_DATASET
 from tabpfn_client.browser_auth import BrowserAuthHandler
 from tabpfn_client.tabpfn_common_utils.utils import Singleton

@@ -182,6 +182,110 @@ def reset_authorization(cls):
         cls._access_token = None
         cls.httpx_client.headers.pop("Authorization", None)

+    @classmethod
+    def generate_upload_url(
+        cls, file_type: Literal["x_train", "y_train", "x_test"]
+    ) -> dict:
+        """
+        Generate a signed URL for direct dataset file upload to cloud storage.
+
+        Parameters
+        ----------
+        file_type : Literal["x_train", "y_train", "x_test"]
+            The type of the file to upload.
+
+        Returns
+        -------
+        dict
+            A dictionary containing the signed URL and GCS path.
+        """
+        response_x = cls.httpx_client.post(
+            url=cls.server_endpoints.generate_upload_url.path,
+            params={"file_type": file_type},
+        )
+        cls._validate_response(response_x, "generate_upload_url")
+        return response_x.json()
+
+    @classmethod
+    def upload_to_gcs(cls, serialized_data: bytes, signed_url: str):
+        """
+        Upload a serialized dataset file to cloud storage using a signed URL.
+
+        Parameters
+        ----------
+        serialized_data : bytes
+            The serialized dataset file.
+        signed_url : str
+            The signed URL for the file upload.
+        """
+        response = cls.httpx_client.put(
+            signed_url,
+            data=serialized_data,
+            headers={"content-type": "text/csv"},
+        )
+        cls._validate_response(response, "upload_to_gcs")
+
+    @classmethod
+    def upload_train_set(
+        cls,
+        X_serialized: bytes,
+        y_serialized: bytes,
+        num_cells: int,
+        tabpfn_systems: list[str],
+    ):
+        """
+        Upload a train set to server and return the train set UID if successful.
+
+        Parameters
+        ----------
+        X_serialized : bytes
+            The serialized training input samples.
+        y_serialized : bytes
+            The serialized target values.
+        num_cells : int
+            The number of cells in the train set.
+        tabpfn_systems : list[str]
+            The tabpfn systems to use for the fit.
+
+        Returns
+        -------
+        train_set_uid : str
+            The unique ID of the train set in the server.
+        """
+        if num_cells > CELL_THRESHOLD_LARGE_DATASET:
+            # Upload train data to GCS
+            response_x = cls.generate_upload_url("x_train")
+            response_y = cls.generate_upload_url("y_train")
+            cls.upload_to_gcs(X_serialized, response_x["signed_url"])
+            cls.upload_to_gcs(y_serialized, response_y["signed_url"])
+
+            # Call fit endpoint
+            response = cls.httpx_client.post(
+                url=cls.server_endpoints.fit.path,
+                params={
+                    "x_gcs_path": response_x["gcs_path"],
+                    "y_gcs_path": response_y["gcs_path"],
+                    "tabpfn_systems": json.dumps(tabpfn_systems),
+                },
+            )
+        else:
+            # Small dataset, so upload directly.
+            response = cls.httpx_client.post(
+                url=cls.server_endpoints.fit.path,
+                files=common_utils.to_httpx_post_file_format(
+                    [
+                        ("x_file", "x_train_filename", X_serialized),
+                        ("y_file", "y_train_filename", y_serialized),
+                    ]
+                ),
+                params={"tabpfn_systems": json.dumps(tabpfn_systems)},
+            )
+
+        cls._validate_response(response, "fit")
+
+        train_set_uid = response.json()["train_set_uid"]
+        return train_set_uid
+
     @classmethod
     def fit(cls, X, y, config=None) -> str:
         """
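For context, a minimal standalone sketch, not part of this PR, of the two-step signed-URL flow these methods implement, written against plain httpx. The base URL, token, and endpoint paths are placeholders (the real paths come from cls.server_endpoints); only the signed_url/gcs_path response fields and the text/csv content type are taken from the diff above.

```python
# Sketch of the signed-URL upload pattern. All concrete values are placeholders.
import json

import httpx

BASE_URL = "https://tabpfn-server.example"  # placeholder
TOKEN = "dummy-access-token"                # placeholder

client = httpx.Client(base_url=BASE_URL, headers={"Authorization": f"Bearer {TOKEN}"})


def upload_via_signed_url(file_type: str, payload: bytes) -> str:
    # Step 1: ask the API server for a short-lived signed URL for this file type.
    resp = client.post("/generate_upload_url/", params={"file_type": file_type})
    resp.raise_for_status()
    info = resp.json()  # expected keys per the diff: "signed_url", "gcs_path"

    # Step 2: PUT the raw bytes straight to cloud storage. The signature embedded
    # in the URL authorizes the request, so no Authorization header is needed here.
    put_resp = httpx.put(
        info["signed_url"], content=payload, headers={"content-type": "text/csv"}
    )
    put_resp.raise_for_status()
    return info["gcs_path"]


if __name__ == "__main__":
    # Step 3: point the fit endpoint at the uploaded objects instead of attaching files.
    x_path = upload_via_signed_url("x_train", b"f1,f2\n1.0,2.0\n")
    y_path = upload_via_signed_url("y_train", b"label\n0\n")
    fit_resp = client.post(
        "/fit/",
        params={
            "x_gcs_path": x_path,
            "y_gcs_path": y_path,
            "tabpfn_systems": json.dumps(["tabpfn"]),
        },
    )
    fit_resp.raise_for_status()
    print(fit_resp.json()["train_set_uid"])
```

The point of the pattern is that the large payload never passes through the API server: the server only issues the signed URL and later receives the GCS path.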
@@ -223,20 +327,10 @@ def fit(cls, X, y, config=None) -> str:
         if cached_dataset_uid:
             return cached_dataset_uid

-        response = cls.httpx_client.post(
-            url=cls.server_endpoints.fit.path,
-            files=common_utils.to_httpx_post_file_format(
-                [
-                    ("x_file", "x_train_filename", X_serialized),
-                    ("y_file", "y_train_filename", y_serialized),
-                ]
-            ),
-            params={"tabpfn_systems": json.dumps(tabpfn_systems)},
+        num_cells = X.shape[0] * (X.shape[1] + 1)
+        train_set_uid = cls.upload_train_set(
+            X_serialized, y_serialized, num_cells, tabpfn_systems
         )
-
-        cls._validate_response(response, "fit")
-
-        train_set_uid = response.json()["train_set_uid"]
         cls.dataset_uid_cache_manager.add_dataset_uid(dataset_hash, train_set_uid)
         return train_set_uid

@@ -269,6 +363,8 @@ def predict( | |
|
||
x_test_serialized = common_utils.serialize_to_csv_formatted_bytes(x_test) | ||
|
||
num_cells = x_test.shape[0] * (x_test.shape[1] + 1) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. why the +1? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. For y_test, which has shape (X_test.shape[0], 1) |
||
|
||
params = { | ||
"train_set_uid": train_set_uid, | ||
"task": task, | ||
|
@@ -303,7 +399,7 @@ def predict(
         for attempt in range(max_attempts):
             try:
                 with cls._make_prediction_request(
-                    cached_test_set_uid, x_test_serialized, params
+                    cached_test_set_uid, x_test_serialized, params, num_cells
                 ) as response:
                     cls._validate_response(response, "predict")
                     # Handle updates from server

@@ -395,9 +491,24 @@ def run_progress():
             return result

     @classmethod
-    def _make_prediction_request(cls, test_set_uid, x_test_serialized, params):
+    def _make_prediction_request(
+        cls,
+        test_set_uid: Optional[str],
+        x_test_serialized: bytes,
+        params: dict,
+        num_cells: int,
+    ):
         """
-        Helper function to make the prediction request to the server.
+        Helper function to upload test set if required and make the prediction request to the server.
+
+        Args:
+            test_set_uid: The unique ID of the train set in the server.
+            x_test_serialized: The serialized test set.
+            params: The parameters for the prediction request.
+            num_cells: The number of cells in the test set.

         Returns:
             response: Streaming response from the server.
         """
         if test_set_uid:
             params = params.copy()

Review comment (on the docstring): missing parameter documentation.
@@ -406,14 +517,31 @@ def _make_prediction_request(cls, test_set_uid, x_test_serialized, params):
                 method="post", url=cls.server_endpoints.predict.path, params=params
             )
         else:
-            response = cls.httpx_client.stream(
-                method="post",
-                url=cls.server_endpoints.predict.path,
-                params=params,
-                files=common_utils.to_httpx_post_file_format(
-                    [("x_file", "x_test_filename", x_test_serialized)]
-                ),
-            )
+            if num_cells > CELL_THRESHOLD_LARGE_DATASET:
+                # Upload to GCS.
+                url_response = cls.generate_upload_url("x_test")
+                cls.upload_to_gcs(x_test_serialized, url_response["signed_url"])
+                # Make prediction request.
+                response = cls.httpx_client.stream(
+                    method="post",
+                    url=cls.server_endpoints.predict.path,
+                    params={
+                        **params,
+                        "x_gcs_path": url_response["gcs_path"],
+                    },
+                )
+            else:
+                # Small dataset, so upload directly.
+                response = cls.httpx_client.stream(
+                    method="post",
+                    url=cls.server_endpoints.predict.path,
+                    files=common_utils.to_httpx_post_file_format(
+                        [
+                            ("x_file", "x_test_filename", x_test_serialized),
+                        ]
+                    ),
+                    params=params,
+                )
         return response

     @staticmethod
Review comment:
The new functionality is completely untested.
Add integration test cases that can be run against a real backend to ensure everything works.
Please add instructions on how to test this.

Reply:
Should these integration tests just be for manual testing, or should they be integrated into the automated pipeline? Always testing against the real backend could be pretty slow.
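One way to reconcile both concerns, sketched below under assumptions: keep the real-backend coverage behind an opt-in pytest marker so it never slows the default pipeline but can be run manually (or on a nightly schedule) with `pytest -m real_backend`. The public API names (init, TabPFNClassifier) and the dataset size needed to cross CELL_THRESHOLD_LARGE_DATASET are assumptions, not part of this PR.

```python
# Sketch of an opt-in integration test against the real backend.
import numpy as np
import pytest

from tabpfn_client import init, TabPFNClassifier  # assumed public API


@pytest.mark.real_backend  # register this marker in pytest.ini / pyproject.toml
def test_fit_predict_roundtrip_uses_gcs_for_large_data():
    rng = np.random.default_rng(0)
    # Shape chosen so rows * (cols + 1) should exceed the large-dataset threshold;
    # adjust to whatever the backend actually accepts.
    X = rng.normal(size=(20_000, 100)).astype(np.float32)
    y = (X[:, 0] > 0).astype(int)

    init()  # expects valid credentials / a cached token for the real backend
    clf = TabPFNClassifier()
    clf.fit(X, y)
    preds = clf.predict(X[:100])
    assert preds.shape == (100,)
```

The default CI invocation would then deselect these tests with `-m "not real_backend"`, which keeps the pipeline fast while still giving reviewers a documented way to exercise the GCS path end to end.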