diff --git a/tabpfn_client/client.py b/tabpfn_client/client.py
index f29443e..0bb57ab 100644
--- a/tabpfn_client/client.py
+++ b/tabpfn_client/client.py
@@ -22,7 +22,7 @@ from tqdm import tqdm
 
 from tabpfn_client.tabpfn_common_utils import utils as common_utils
-from tabpfn_client.constants import CACHE_DIR
+from tabpfn_client.constants import CACHE_DIR, LARGE_DATASET_THRESHOLD
 from tabpfn_client.browser_auth import BrowserAuthHandler
 from tabpfn_client.tabpfn_common_utils.utils import Singleton
 
 
@@ -223,16 +223,57 @@ def fit(cls, X, y, config=None) -> str:
         if cached_dataset_uid:
             return cached_dataset_uid
 
-        response = cls.httpx_client.post(
-            url=cls.server_endpoints.fit.path,
-            files=common_utils.to_httpx_post_file_format(
-                [
-                    ("x_file", "x_train_filename", X_serialized),
-                    ("y_file", "y_train_filename", y_serialized),
-                ]
-            ),
-            params={"tabpfn_systems": json.dumps(tabpfn_systems)},
-        )
+        num_cells = X.shape[0] * (X.shape[1] + 1)
+        if num_cells > LARGE_DATASET_THRESHOLD:
+            # Generate Upload URLs
+            response_x = cls.httpx_client.post(
+                url=cls.server_endpoints.generate_upload_url.path,
+                json={"file_type": "x_train", "file_name": "x_train_filename"},
+            )
+            cls._validate_response(response_x, "fit")
+            response_x = response_x.json()
+
+            response_y = cls.httpx_client.post(
+                url=cls.server_endpoints.generate_upload_url.path,
+                json={"file_type": "y_train", "file_name": "y_train_filename"},
+            )
+            cls._validate_response(response_y, "fit")
+            response_y = response_y.json()
+
+            # Upload the X and y training files to GCS
+            response = cls.httpx_client.put(
+                response_x["signed_url"],
+                data=X_serialized,
+                headers={"content-type": "text/csv"},
+            )
+            cls._validate_response(response, "fit")
+            response = cls.httpx_client.put(
+                response_y["signed_url"],
+                data=y_serialized,
+                headers={"content-type": "text/csv"},
+            )
+            cls._validate_response(response, "fit")
+            # Call fit endpoint
+            response = cls.httpx_client.post(
+                url=cls.server_endpoints.fit.path,
+                params={
+                    "x_gcs_path": response_x["gcs_path"],
+                    "y_gcs_path": response_y["gcs_path"],
+                    "tabpfn_systems": json.dumps(tabpfn_systems),
+                },
+            )
+        else:
+            # Small dataset, so upload directly.
+            response = cls.httpx_client.post(
+                url=cls.server_endpoints.fit.path,
+                files=common_utils.to_httpx_post_file_format(
+                    [
+                        ("x_file", "x_train_filename", X_serialized),
+                        ("y_file", "y_train_filename", y_serialized),
+                    ]
+                ),
+                params={"tabpfn_systems": json.dumps(tabpfn_systems)},
+            )
 
         cls._validate_response(response, "fit")
 
@@ -269,6 +310,8 @@ def predict(
 
         x_test_serialized = common_utils.serialize_to_csv_formatted_bytes(x_test)
 
+        num_cells = x_test.shape[0] * (x_test.shape[1] + 1)
+
         params = {
             "train_set_uid": train_set_uid,
             "task": task,
@@ -303,7 +346,7 @@
         for attempt in range(max_attempts):
             try:
                 with cls._make_prediction_request(
-                    cached_test_set_uid, x_test_serialized, params
+                    cached_test_set_uid, x_test_serialized, params, num_cells
                 ) as response:
                     cls._validate_response(response, "predict")
                     # Handle updates from server
@@ -395,9 +438,11 @@ def run_progress():
         return result
 
     @classmethod
-    def _make_prediction_request(cls, test_set_uid, x_test_serialized, params):
+    def _make_prediction_request(
+        cls, test_set_uid, x_test_serialized, params, num_cells
+    ):
         """
-        Helper function to make the prediction request to the server.
+        Helper function to upload test set if required and make the prediction request to the server.
         """
         if test_set_uid:
             params = params.copy()
@@ -406,14 +451,45 @@ def _make_prediction_request(cls, test_set_uid, x_test_serialized, params):
                 method="post", url=cls.server_endpoints.predict.path, params=params
             )
         else:
-            response = cls.httpx_client.stream(
-                method="post",
-                url=cls.server_endpoints.predict.path,
-                params=params,
-                files=common_utils.to_httpx_post_file_format(
-                    [("x_file", "x_test_filename", x_test_serialized)]
-                ),
-            )
+            if num_cells > LARGE_DATASET_THRESHOLD:
+                # Generate upload URL
+                url_response = cls.httpx_client.post(
+                    url=cls.server_endpoints.generate_upload_url.path,
+                    json={
+                        "file_type": "x_test",
+                        "file_name": "x_test_filename",
+                    },
+                )
+                cls._validate_response(url_response, "predict")
+                url_response = url_response.json()
+                # Upload test set to GCS
+                response = cls.httpx_client.put(
+                    url_response["signed_url"],
+                    data=x_test_serialized,
+                    headers={"content-type": "text/csv"},
+                )
+                cls._validate_response(response, "predict")
+                # Make prediction request
+                response = cls.httpx_client.stream(
+                    method="post",
+                    url=cls.server_endpoints.predict.path,
+                    params={
+                        **params,
+                        "x_gcs_path": url_response["gcs_path"],
+                    },
+                )
+            else:
+                # Small dataset, so upload directly.
+                response = cls.httpx_client.stream(
+                    method="post",
+                    url=cls.server_endpoints.predict.path,
+                    files=common_utils.to_httpx_post_file_format(
+                        [
+                            ("x_file", "x_test_filename", x_test_serialized),
+                        ]
+                    ),
+                    params=params,
+                )
 
         return response
 
     @staticmethod
diff --git a/tabpfn_client/constants.py b/tabpfn_client/constants.py
index 1c5e0f7..8b945f9 100644
--- a/tabpfn_client/constants.py
+++ b/tabpfn_client/constants.py
@@ -4,3 +4,5 @@ from pathlib import Path
 
 CACHE_DIR = Path(__file__).parent.resolve() / ".tabpfn"
 
+
+LARGE_DATASET_THRESHOLD = 500000
diff --git a/tabpfn_client/server_config.yaml b/tabpfn_client/server_config.yaml
index 37ddb13..03c974b 100644
--- a/tabpfn_client/server_config.yaml
+++ b/tabpfn_client/server_config.yaml
@@ -113,4 +113,9 @@ endpoints:
     path: "/get_api_usage/"
     methods: [ "POST" ]
     description: "Get prediction hits data for a given user"
 
+
+  generate_upload_url:
+    path: "/generate-upload-url/"
+    methods: ["POST"]
+    description: "Generate a signed URL for direct dataset file upload to cloud storage"