diff --git a/app/lambda/handler.py b/app/lambda/handler.py index 3ec2860..de8c481 100644 --- a/app/lambda/handler.py +++ b/app/lambda/handler.py @@ -2,12 +2,14 @@ import time import logging -from typedb.driver import ( - TypeDB, +from typedb_http_driver import ( + TypeDBHttpDriver, Credentials, DriverOptions, TransactionType, TransactionOptions, + driver, + Transaction, ) logger = logging.getLogger(__name__) @@ -40,7 +42,7 @@ def wrapper(*args, **kwargs): # Global driver instance for reuse across Lambda invocations _global_driver = None _driver_created_at = None -_driver_timeout = 300 # 5 minutes timeout for driver reuse +_driver_timeout = 600 # 10 minutes timeout for driver reuse (increased for HTTP driver) def _transaction_options(): @@ -50,26 +52,46 @@ def _transaction_options(): def _cors_response(status_code, body): """Create a response with CORS headers""" - return { + response_body = json.dumps(body) if isinstance(body, (dict, list)) else str(body) + + response = { "statusCode": status_code, "headers": { "Access-Control-Allow-Origin": "*", "Access-Control-Allow-Methods": "GET, POST, PUT, DELETE, OPTIONS", "Access-Control-Allow-Headers": "Content-Type, Authorization", }, - "body": json.dumps(body) if isinstance(body, (dict, list)) else str(body), + "body": response_body, } + + # Log the response being sent to frontend + logger.info(f"๐ค Sending response to frontend:") + logger.info(f" Status: {status_code}") + logger.info(f" Body length: {len(response_body)} chars") + logger.info(f" Body preview: {response_body[:200]}{'...' if len(response_body) > 200 else ''}") + + return response def handler(event, context): handler_start = time.time() + + # Log incoming request + method = event["httpMethod"] + path = event.get("path", "") + body = event.get("body", "") + + logger.info(f"๐ฅ Incoming request:") + logger.info(f" Method: {method}") + logger.info(f" Path: {path}") + logger.info(f" Body length: {len(body) if body else 0} chars") + if body: + logger.info(f" Body preview: {body[:200]}{'...' if len(body) > 200 else ''}") + logger.debug(f"Lambda invoked with event: {json.dumps(event, default=str)}") _create_database_and_schema() - method = event["httpMethod"] - path = event.get("path", "") - # Handle CORS preflight requests if method == "OPTIONS": response = _cors_response(200, "") @@ -77,6 +99,7 @@ def handler(event, context): try: response = handle_request(event, method, path) except Exception as e: + logger.error(f"โ Request failed: {str(e)}") response = _cors_response(400, {"error": str(e)}) handler_duration = (time.time() - handler_start) * 1000 @@ -147,7 +170,7 @@ def handle_request(event, method, path): result = reset_database() return _cors_response(200, result) - logger.debug(f"No route found for {method} request to {path}") + logger.warning(f"๐ซ No route found for {method} request to {path}") return _cors_response(404, {"error": "Not found"}) @@ -171,9 +194,7 @@ def create_user(payload: dict): profile_picture_uri = payload.get("profile_picture_uri", "") try: - with _driver().transaction( - db_name, TransactionType.WRITE, _transaction_options() - ) as tx: + with Transaction(_driver(), db_name, TransactionType.WRITE, _transaction_options()) as tx: # Create user with username query = f"insert $u isa user, has user-name '{username}'" @@ -193,11 +214,13 @@ def create_user(payload: dict): tx.query(query).resolve() tx.commit() - return { + result = { "message": "User created successfully", "username": username, "email": emails, } + logger.info(f"โ User created: {username} with {len(emails)} email(s)") + return result except Exception as e: error_msg = str(e) @@ -217,16 +240,16 @@ def create_group(payload: dict): group_name = payload["group_name"] try: - with _driver().transaction( - db_name, TransactionType.WRITE, _transaction_options() - ) as tx: + with Transaction(_driver(), db_name, TransactionType.WRITE, _transaction_options()) as tx: # Create group with group name query = f"insert $g isa group, has group-name '{group_name}';" tx.query(query).resolve() tx.commit() - return {"message": "Group created successfully", "group_name": group_name} + result = {"message": "Group created successfully", "group_name": group_name} + logger.info(f"โ Group created: {group_name}") + return result except Exception as e: error_msg = str(e) @@ -240,9 +263,7 @@ def create_group(payload: dict): def list_users(): logger.debug("Listing users") - with _driver().transaction( - db_name, TransactionType.READ, _transaction_options() - ) as tx: + with Transaction(_driver(), db_name, TransactionType.READ, _transaction_options()) as tx: query_start = time.time() result = ( tx.query( @@ -262,6 +283,7 @@ def list_users(): list_start = time.time() result = list(result) list_duration = (time.time() - list_start) * 1000 # noqa + logger.info(f"๐ Listed {len(result)} users") return result @@ -269,15 +291,14 @@ def list_users(): @log_execution_time("list_groups") def list_groups(): logger.debug("Listing groups") - with _driver().transaction( - db_name, TransactionType.READ, _transaction_options() - ) as tx: + with Transaction(_driver(), db_name, TransactionType.READ, _transaction_options()) as tx: result = ( tx.query('match $g isa group; fetch { "group_name": $g.group-name};') .resolve() .as_concept_documents() ) result = list(result) + logger.info(f"๐ Listed {len(result)} groups") return result @@ -293,9 +314,7 @@ def add_member_to_group(group_name: str, payload: dict): if "username" in payload and "group_name" in payload: raise ValueError("Provide either 'username' or 'group_name', not both") - with _driver().transaction( - db_name, TransactionType.WRITE, _transaction_options() - ) as tx: + with Transaction(_driver(), db_name, TransactionType.WRITE, _transaction_options()) as tx: if "username" in payload: # Adding a user to the group username = payload["username"] @@ -324,20 +343,20 @@ def add_member_to_group(group_name: str, payload: dict): tx.query(query).resolve() tx.commit() - return { + result = { "message": f"{member_type.capitalize()} added to group successfully", "group_name": group_name, "member_type": member_type, "member_name": member_name, } + logger.info(f"โ Added {member_type} '{member_name}' to group '{group_name}'") + return result @log_execution_time("list_direct_group_members") def list_direct_group_members(group_name: str): logger.debug(f"Listing direct group members for {group_name}") - with _driver().transaction( - db_name, TransactionType.READ, _transaction_options() - ) as tx: + with Transaction(_driver(), db_name, TransactionType.READ, _transaction_options()) as tx: result = ( tx.query( f"match " @@ -361,19 +380,17 @@ def list_direct_group_members(group_name: str): @log_execution_time("list_all_group_members") def list_all_group_members(group_name: str): logger.debug(f"Listing all group members for {group_name}") - with _driver().transaction( - db_name, TransactionType.READ, _transaction_options() - ) as tx: + with Transaction(_driver(), db_name, TransactionType.READ, _transaction_options()) as tx: # Use the group-members function from the schema to get all members recursively result = ( tx.query( f"match " f' $group isa group, has group-name "{group_name}"; ' - f" let $members in group-members($group); " + f" let $member in group-members($group); " f" $member isa! $member-type; " f"fetch {{" f' "member_type": $member-type, ' - f' "member_name": $members.name, ' + f' "member_name": $member.name, ' f' "group_name": $group.group-name' f"}};" ) @@ -389,9 +406,7 @@ def list_all_group_members(group_name: str): def list_principal_groups(principal_name: str, principal_type: str): """List direct groups for either a user or group principal""" logger.debug(f"Listing direct groups for {principal_name} of type {principal_type}") - with _driver().transaction( - db_name, TransactionType.READ, _transaction_options() - ) as tx: + with Transaction(_driver(), db_name, TransactionType.READ, _transaction_options()) as tx: if principal_type == "user": name_attr = "user-name" else: # group @@ -419,9 +434,7 @@ def list_principal_groups(principal_name: str, principal_type: str): def list_all_principal_groups(principal_name: str, principal_type: str): """List all groups (transitive) for either a user or group principal""" logger.debug(f"Listing all groups for {principal_name} of type {principal_type}") - with _driver().transaction( - db_name, TransactionType.READ, _transaction_options() - ) as tx: + with Transaction(_driver(), db_name, TransactionType.READ, _transaction_options()) as tx: if principal_type == "user": name_attr = "user-name" else: # group @@ -452,8 +465,8 @@ def reset_database(): driver = _driver() # Delete database if it exists - if driver.databases.contains(db_name): - driver.databases.get(db_name).delete() + if driver.database_exists(db_name): + driver.delete_database(db_name) logger.debug(f"Database '{db_name}' deleted") _create_database_and_schema() @@ -465,14 +478,12 @@ def reset_database(): def _create_database_and_schema(): driver = _driver() # Check if database exists, create only if it doesn't - if db_name not in [db.name for db in driver.databases.all()]: - driver.databases.create(db_name) + if not driver.database_exists(db_name): + driver.create_database(db_name) # Check if schema already exists by looking for user type schema_check_start = time.time() - with _driver().transaction( - db_name, TransactionType.READ, _transaction_options() - ) as tx: + with Transaction(_driver(), db_name, TransactionType.READ, _transaction_options()) as tx: check_start = time.time() row = list( tx.query("match entity $t; reduce $count = count;") @@ -485,15 +496,15 @@ def _create_database_and_schema(): f"๐ Schema check transaction completed in {schema_check_duration:.2f}ms" ) - if row.get("count").get() == 0: + if row["data"]["count"]["value"] == 0: logger.debug("Loading schema from file") schema_load_start = time.time() - with _driver().transaction( - db_name, TransactionType.SCHEMA, _transaction_options() - ) as schema_tx: + with Transaction(_driver(), db_name, TransactionType.SCHEMA, _transaction_options()) as schema_tx: # Load schema from file file_start = time.time() - with open("schema.tql", "r") as f: + import os + schema_path = os.path.join(os.path.dirname(__file__), "schema.tql") + with open(schema_path, "r") as f: schema_content = f.read() file_duration = (time.time() - file_start) * 1000 # noqa @@ -526,12 +537,13 @@ def _driver(): logger.debug("โป๏ธ Reusing existing driver") return _global_driver elif expired: + logger.debug("โฐ Driver expired, cleaning up") _cleanup_driver() # Create new driver or existing one is expired driver_start = time.time() try: - _global_driver = TypeDB.driver( + _global_driver = driver( server_host, Credentials("admin", "password"), DriverOptions(is_tls_enabled=False), @@ -559,8 +571,7 @@ def _cleanup_driver(): if _global_driver is not None: logger.debug("๐งน Cleaning up global driver") try: - # TypeDB drivers don't have explicit close methods in the Python API - # The connections will be cleaned up when the driver object is GC'd + _global_driver.close() _global_driver = None _driver_created_at = None logger.debug("โ Global driver cleaned up") diff --git a/app/lambda/requirements.txt b/app/lambda/requirements.txt index 5d27ac2..eed6988 100644 --- a/app/lambda/requirements.txt +++ b/app/lambda/requirements.txt @@ -1,2 +1 @@ ---index-url=https://repo.typedb.com/public/public-snapshot/python/simple/ -typedb-driver==0.0.0+bf3f4548451b471f9964c550dfcfe07723059482 \ No newline at end of file +requests>=2.25.0 \ No newline at end of file diff --git a/app/lambda/typedb_http_driver.py b/app/lambda/typedb_http_driver.py new file mode 100644 index 0000000..d95f2ef --- /dev/null +++ b/app/lambda/typedb_http_driver.py @@ -0,0 +1,355 @@ +""" +HTTP-based TypeDB driver for Python +Ported from TypeScript implementation for lambda usage +""" + +import json +import logging +import time +from typing import Dict, List, Optional, Union, Any, Tuple +from dataclasses import dataclass +from contextlib import contextmanager +import requests +from requests.adapters import HTTPAdapter +from urllib3.util.retry import Retry + +logger = logging.getLogger(__name__) + + +@dataclass +class DriverParams: + """Driver connection parameters""" + username: str + password: str + addresses: List[str] + + +@dataclass +class TransactionOptions: + """Transaction configuration options""" + schema_lock_acquire_timeout_millis: Optional[int] = None + transaction_timeout_millis: Optional[int] = None + + +@dataclass +class QueryOptions: + """Query execution options""" + include_instance_types: Optional[bool] = None + answer_count_limit: Optional[int] = None + + +class TypeDBHttpError(Exception): + """Base exception for TypeDB HTTP driver errors""" + def __init__(self, message: str, code: Optional[str] = None, status_code: Optional[int] = None): + super().__init__(message) + self.code = code + self.status_code = status_code + + +class TypeDBHttpDriver: + """HTTP-based TypeDB driver""" + + def __init__(self, params: DriverParams): + self.params = params + self.token: Optional[str] = None + self.base_url = f"http://{params.addresses[0]}" + + # Setup HTTP session with retry strategy - optimized for Lambda + self.session = requests.Session() + + # Lighter retry strategy for Lambda (fewer retries) + retry_strategy = Retry( + total=2, # Reduced from 3 + backoff_factor=0.5, # Faster backoff + status_forcelist=[429, 500, 502, 503, 504], + ) + adapter = HTTPAdapter(max_retries=retry_strategy) + self.session.mount("http://", adapter) + self.session.mount("https://", adapter) + + # Pre-authenticate to reduce cold start time + self._get_token() + + def _get_token(self) -> str: + """Get authentication token, refreshing if necessary""" + if self.token: + return self.token + + return self._refresh_token() + + def _refresh_token(self) -> str: + """Refresh authentication token""" + url = f"{self.base_url}/v1/signin" + payload = { + "username": self.params.username, + "password": self.params.password + } + + try: + response = self.session.post(url, json=payload, timeout=10) + response.raise_for_status() + + data = response.json() + self.token = data["token"] + return self.token + + except requests.exceptions.RequestException as e: + raise TypeDBHttpError(f"Failed to authenticate: {str(e)}") + + def _make_request(self, method: str, path: str, data: Optional[Dict] = None, + params: Optional[Dict] = None) -> requests.Response: + """Make authenticated HTTP request with automatic token refresh""" + url = f"{self.base_url}{path}" + headers = { + "Authorization": f"Bearer {self._get_token()}", + "Content-Type": "application/json" + } + + try: + response = self.session.request( + method, url, json=data, params=params, headers=headers, timeout=30 + ) + + # Handle token expiration + if response.status_code == 401: + self.token = None # Clear expired token + headers["Authorization"] = f"Bearer {self._get_token()}" + response = self.session.request( + method, url, json=data, params=params, headers=headers, timeout=30 + ) + + return response + + except requests.exceptions.RequestException as e: + raise TypeDBHttpError(f"HTTP request failed: {str(e)}") + + def _handle_response(self, response: requests.Response) -> Dict[str, Any]: + """Handle HTTP response and extract JSON data""" + try: + if response.status_code >= 400: + error_data = response.json() if response.content else {} + error_msg = error_data.get("message", f"HTTP {response.status_code}") + error_code = error_data.get("code", None) + raise TypeDBHttpError(error_msg, code=error_code, status_code=response.status_code) + + # Handle empty responses + if not response.content: + return {} + + return response.json() + + except json.JSONDecodeError as e: + raise TypeDBHttpError(f"Invalid JSON response: {str(e)}") + + # Database operations + def get_databases(self) -> List[Dict[str, str]]: + """Get list of all databases""" + response = self._make_request("GET", "/v1/databases") + data = self._handle_response(response) + return data.get("databases", []) + + def create_database(self, name: str) -> None: + """Create a new database""" + response = self._make_request("POST", f"/v1/databases/{name}", {}) + self._handle_response(response) + + def delete_database(self, name: str) -> None: + """Delete a database""" + response = self._make_request("DELETE", f"/v1/databases/{name}") + self._handle_response(response) + + def database_exists(self, name: str) -> bool: + """Check if a database exists""" + databases = self.get_databases() + return any(db["name"] == name for db in databases) + + # Transaction operations + def open_transaction(self, database_name: str, transaction_type: str, + options: Optional[TransactionOptions] = None) -> str: + """Open a transaction and return transaction ID""" + payload = { + "databaseName": database_name, + "transactionType": transaction_type + } + + if options: + transaction_options = {} + if options.schema_lock_acquire_timeout_millis is not None: + transaction_options["schemaLockAcquireTimeoutMillis"] = options.schema_lock_acquire_timeout_millis + if options.transaction_timeout_millis is not None: + transaction_options["transactionTimeoutMillis"] = options.transaction_timeout_millis + payload["transactionOptions"] = transaction_options + + response = self._make_request("POST", "/v1/transactions/open", payload) + data = self._handle_response(response) + return data["transactionId"] + + def commit_transaction(self, transaction_id: str) -> None: + """Commit a transaction""" + response = self._make_request("POST", f"/v1/transactions/{transaction_id}/commit", {}) + self._handle_response(response) + + def close_transaction(self, transaction_id: str) -> None: + """Close a transaction""" + response = self._make_request("POST", f"/v1/transactions/{transaction_id}/close", {}) + self._handle_response(response) + + def rollback_transaction(self, transaction_id: str) -> None: + """Rollback a transaction""" + response = self._make_request("POST", f"/v1/transactions/{transaction_id}/rollback", {}) + self._handle_response(response) + + # Query operations + def query(self, transaction_id: str, query: str, + options: Optional[QueryOptions] = None) -> Dict[str, Any]: + """Execute a query in a transaction""" + payload = {"query": query} + + if options: + query_options = {} + if options.include_instance_types is not None: + query_options["includeInstanceTypes"] = options.include_instance_types + if options.answer_count_limit is not None: + query_options["answerCountLimit"] = options.answer_count_limit + payload["queryOptions"] = query_options + + response = self._make_request("POST", f"/v1/transactions/{transaction_id}/query", payload) + return self._handle_response(response) + + def one_shot_query(self, query: str, commit: bool, database_name: str, + transaction_type: str, transaction_options: Optional[TransactionOptions] = None, + query_options: Optional[QueryOptions] = None) -> Dict[str, Any]: + """Execute a one-shot query""" + payload = { + "query": query, + "commit": commit, + "databaseName": database_name, + "transactionType": transaction_type + } + + if transaction_options: + tx_opts = {} + if transaction_options.schema_lock_acquire_timeout_millis is not None: + tx_opts["schemaLockAcquireTimeoutMillis"] = transaction_options.schema_lock_acquire_timeout_millis + if transaction_options.transaction_timeout_millis is not None: + tx_opts["transactionTimeoutMillis"] = transaction_options.transaction_timeout_millis + payload["transactionOptions"] = tx_opts + + if query_options: + q_opts = {} + if query_options.include_instance_types is not None: + q_opts["includeInstanceTypes"] = query_options.include_instance_types + if query_options.answer_count_limit is not None: + q_opts["answerCountLimit"] = query_options.answer_count_limit + payload["queryOptions"] = q_opts + + response = self._make_request("POST", "/v1/query", payload) + return self._handle_response(response) + + def close(self): + """Close the driver and cleanup resources""" + if hasattr(self, 'session'): + self.session.close() + + +class Transaction: + """Transaction context manager""" + + def __init__(self, driver: TypeDBHttpDriver, database_name: str, + transaction_type: str, options: Optional[TransactionOptions] = None): + self.driver = driver + self.database_name = database_name + self.transaction_type = transaction_type + self.options = options + self.transaction_id: Optional[str] = None + + def __enter__(self): + self.transaction_id = self.driver.open_transaction( + self.database_name, self.transaction_type, self.options + ) + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + if self.transaction_id: + self.driver.close_transaction(self.transaction_id) + + def query(self, query: str, options: Optional[QueryOptions] = None) -> 'QueryResult': + """Execute a query in this transaction""" + if not self.transaction_id: + raise TypeDBHttpError("Transaction not open") + + response = self.driver.query(self.transaction_id, query, options) + return QueryResult(response) + + def commit(self): + """Commit this transaction""" + if not self.transaction_id: + raise TypeDBHttpError("Transaction not open") + + self.driver.commit_transaction(self.transaction_id) + + def rollback(self): + """Rollback this transaction""" + if not self.transaction_id: + raise TypeDBHttpError("Transaction not open") + + self.driver.rollback_transaction(self.transaction_id) + + +class QueryResult: + """Query result wrapper with convenience methods""" + + def __init__(self, response: Dict[str, Any]): + self.response = response + self.answer_type = response.get("answerType", "ok") + self.query_type = response.get("queryType", "read") + self.comment = response.get("comment") + self.query = response.get("query") + self.answers = response.get("answers", []) + + def resolve(self) -> 'QueryResult': + """Resolve the query result (for compatibility with GRPC driver)""" + return self + + def as_concept_documents(self) -> List[Dict[str, Any]]: + """Get results as concept documents""" + if self.answer_type != "conceptDocuments": + raise TypeDBHttpError(f"Cannot get concept documents from {self.answer_type} response") + return self.answers + + def as_concept_rows(self) -> List[Dict[str, Any]]: + """Get results as concept rows""" + if self.answer_type != "conceptRows": + raise TypeDBHttpError(f"Cannot get concept rows from {self.answer_type} response") + return self.answers + + +# Convenience functions for compatibility with GRPC driver +def driver(address: str, credentials: 'Credentials', options: Optional[Dict] = None) -> TypeDBHttpDriver: + """Create a TypeDB HTTP driver (compatibility function)""" + params = DriverParams( + username=credentials.username, + password=credentials.password, + addresses=[address] + ) + return TypeDBHttpDriver(params) + + +# Transaction type constants for compatibility +class TransactionType: + READ = "read" + WRITE = "write" + SCHEMA = "schema" + + +# Driver options for compatibility +class DriverOptions: + def __init__(self, is_tls_enabled: bool = False): + self.is_tls_enabled = is_tls_enabled + + +# Credentials for compatibility +class Credentials: + def __init__(self, username: str, password: str): + self.username = username + self.password = password diff --git a/app/web/index.html b/app/web/index.html index 92ea014..7b629e1 100644 --- a/app/web/index.html +++ b/app/web/index.html @@ -56,6 +56,12 @@