diff --git a/tabpfn_client/client.py b/tabpfn_client/client.py index f29443e..4569358 100644 --- a/tabpfn_client/client.py +++ b/tabpfn_client/client.py @@ -22,7 +22,7 @@ from tqdm import tqdm from tabpfn_client.tabpfn_common_utils import utils as common_utils -from tabpfn_client.constants import CACHE_DIR +from tabpfn_client.constants import CACHE_DIR, CELL_THRESHOLD_LARGE_DATASET from tabpfn_client.browser_auth import BrowserAuthHandler from tabpfn_client.tabpfn_common_utils.utils import Singleton @@ -182,6 +182,110 @@ def reset_authorization(cls): cls._access_token = None cls.httpx_client.headers.pop("Authorization", None) + @classmethod + def generate_upload_url( + cls, file_type: Literal["x_train", "y_train", "x_test"] + ) -> dict: + """ + Generate a signed URL for direct dataset file upload to cloud storage. + + Parameters + ---------- + file_type : Literal["x_train", "y_train", "x_test"] + The type of the file to upload. + + Returns + ------- + dict + A dictionary containing the signed URL and GCS path. + """ + response_x = cls.httpx_client.post( + url=cls.server_endpoints.generate_upload_url.path, + params={"file_type": file_type}, + ) + cls._validate_response(response_x, "generate_upload_url") + return response_x.json() + + @classmethod + def upload_to_gcs(cls, serialized_data: bytes, signed_url: str): + """ + Upload a serialized dataset file to cloud storage using a signed URL. + + Parameters + ---------- + serialized_data : bytes + The serialized dataset file. + signed_url : str + The signed URL for the file upload. + """ + response = cls.httpx_client.put( + signed_url, + data=serialized_data, + headers={"content-type": "text/csv"}, + ) + cls._validate_response(response, "upload_to_gcs") + + @classmethod + def upload_train_set( + cls, + X_serialized: bytes, + y_serialized: bytes, + num_cells: int, + tabpfn_systems: list[str], + ): + """ + Upload a train set to server and return the train set UID if successful. 
+ + Parameters + ---------- + X_serialized : bytes + The serialized training input samples. + y_serialized : bytes + The serialized target values. + num_cells : int + The number of cells in the train set. + tabpfn_systems : list[str] + The tabpfn systems to use for the fit. + + Returns + ------- + train_set_uid : str + The unique ID of the train set in the server. + """ + if num_cells > CELL_THRESHOLD_LARGE_DATASET: + # Upload train data to GCS + response_x = cls.generate_upload_url("x_train") + response_y = cls.generate_upload_url("y_train") + cls.upload_to_gcs(X_serialized, response_x["signed_url"]) + cls.upload_to_gcs(y_serialized, response_y["signed_url"]) + + # Call fit endpoint + response = cls.httpx_client.post( + url=cls.server_endpoints.fit.path, + params={ + "x_gcs_path": response_x["gcs_path"], + "y_gcs_path": response_y["gcs_path"], + "tabpfn_systems": json.dumps(tabpfn_systems), + }, + ) + else: + # Small dataset, so upload directly. + response = cls.httpx_client.post( + url=cls.server_endpoints.fit.path, + files=common_utils.to_httpx_post_file_format( + [ + ("x_file", "x_train_filename", X_serialized), + ("y_file", "y_train_filename", y_serialized), + ] + ), + params={"tabpfn_systems": json.dumps(tabpfn_systems)}, + ) + + cls._validate_response(response, "fit") + + train_set_uid = response.json()["train_set_uid"] + return train_set_uid + @classmethod def fit(cls, X, y, config=None) -> str: """ @@ -223,20 +327,10 @@ def fit(cls, X, y, config=None) -> str: if cached_dataset_uid: return cached_dataset_uid - response = cls.httpx_client.post( - url=cls.server_endpoints.fit.path, - files=common_utils.to_httpx_post_file_format( - [ - ("x_file", "x_train_filename", X_serialized), - ("y_file", "y_train_filename", y_serialized), - ] - ), - params={"tabpfn_systems": json.dumps(tabpfn_systems)}, + num_cells = X.shape[0] * (X.shape[1] + 1) + train_set_uid = cls.upload_train_set( + X_serialized, y_serialized, num_cells, tabpfn_systems ) - - 
cls._validate_response(response, "fit") - - train_set_uid = response.json()["train_set_uid"] cls.dataset_uid_cache_manager.add_dataset_uid(dataset_hash, train_set_uid) return train_set_uid @@ -269,6 +363,8 @@ def predict( x_test_serialized = common_utils.serialize_to_csv_formatted_bytes(x_test) + num_cells = x_test.shape[0] * (x_test.shape[1] + 1) + params = { "train_set_uid": train_set_uid, "task": task, @@ -303,7 +399,7 @@ def predict( for attempt in range(max_attempts): try: with cls._make_prediction_request( - cached_test_set_uid, x_test_serialized, params + cached_test_set_uid, x_test_serialized, params, num_cells ) as response: cls._validate_response(response, "predict") # Handle updates from server @@ -395,9 +491,24 @@ def run_progress(): return result @classmethod - def _make_prediction_request(cls, test_set_uid, x_test_serialized, params): + def _make_prediction_request( + cls, + test_set_uid: Optional[str], + x_test_serialized: bytes, + params: dict, + num_cells: int, + ): """ - Helper function to make the prediction request to the server. + Helper function to upload test set if required and make the prediction request to the server. + + Args: + test_set_uid: The unique ID of the test set in the server. + x_test_serialized: The serialized test set. + params: The parameters for the prediction request. + num_cells: The number of cells in the test set. + + Returns: + response: Streaming response from the server. """ if test_set_uid: params = params.copy() @@ -406,14 +517,31 @@ def _make_prediction_request(cls, test_set_uid, x_test_serialized, params): method="post", url=cls.server_endpoints.predict.path, params=params ) else: - response = cls.httpx_client.stream( - method="post", - url=cls.server_endpoints.predict.path, - params=params, - files=common_utils.to_httpx_post_file_format( - [("x_file", "x_test_filename", x_test_serialized)] - ), - ) + if num_cells > CELL_THRESHOLD_LARGE_DATASET: + # Upload to GCS. 
+ url_response = cls.generate_upload_url("x_test") + cls.upload_to_gcs(x_test_serialized, url_response["signed_url"]) + # Make prediction request. + response = cls.httpx_client.stream( + method="post", + url=cls.server_endpoints.predict.path, + params={ + **params, + "x_gcs_path": url_response["gcs_path"], + }, + ) + else: + # Small dataset, so upload directly. + response = cls.httpx_client.stream( + method="post", + url=cls.server_endpoints.predict.path, + files=common_utils.to_httpx_post_file_format( + [ + ("x_file", "x_test_filename", x_test_serialized), + ] + ), + params=params, + ) return response @staticmethod diff --git a/tabpfn_client/constants.py b/tabpfn_client/constants.py index 1c5e0f7..30b01a7 100644 --- a/tabpfn_client/constants.py +++ b/tabpfn_client/constants.py @@ -4,3 +4,5 @@ from pathlib import Path CACHE_DIR = Path(__file__).parent.resolve() / ".tabpfn" + +CELL_THRESHOLD_LARGE_DATASET = 500000 diff --git a/tabpfn_client/server_config.yaml b/tabpfn_client/server_config.yaml index 37ddb13..03c974b 100644 --- a/tabpfn_client/server_config.yaml +++ b/tabpfn_client/server_config.yaml @@ -113,4 +113,9 @@ endpoints: path: "/get_api_usage/" methods: [ "POST" ] description: "Get prediction hits data for a given user" + + generate_upload_url: + path: "/generate-upload-url/" + methods: ["POST"] + description: "Generate a signed URL for direct dataset file upload to cloud storage"