Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
123 changes: 67 additions & 56 deletions app/lambda/handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,14 @@
import time
import logging

from typedb.driver import (
TypeDB,
from typedb_http_driver import (
TypeDBHttpDriver,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: would be great if we could make the driver configurable, to make it easy to switch between different implementations (e.g., via an environment variable configured for the Lambda function).

Shouldn't be a blocker here, but would be a nice addition to simplify testing different drivers.. 👍

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah to be honest I've been wondering why i can't reproduce this in any other way. I wonder if it's something with Docker not doing so well with HTTP streaming or something! I don't quite know yet, but will definitely be digging into this more. If i can get it ironed out i'd take out the HTTP driver to be honest -- not as elegant/easy to use!

Credentials,
DriverOptions,
TransactionType,
TransactionOptions,
driver,
Transaction,
)

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -40,7 +42,7 @@ def wrapper(*args, **kwargs):
# Global driver instance for reuse across Lambda invocations
_global_driver = None
_driver_created_at = None
_driver_timeout = 300 # 5 minutes timeout for driver reuse
_driver_timeout = 600 # 10 minutes timeout for driver reuse (increased for HTTP driver)


def _transaction_options():
Expand All @@ -50,33 +52,54 @@ def _transaction_options():

def _cors_response(status_code, body):
"""Create a response with CORS headers"""
return {
response_body = json.dumps(body) if isinstance(body, (dict, list)) else str(body)

response = {
"statusCode": status_code,
"headers": {
"Access-Control-Allow-Origin": "*",
"Access-Control-Allow-Methods": "GET, POST, PUT, DELETE, OPTIONS",
"Access-Control-Allow-Headers": "Content-Type, Authorization",
},
"body": json.dumps(body) if isinstance(body, (dict, list)) else str(body),
"body": response_body,
}

# Log the response being sent to frontend
logger.info(f"📤 Sending response to frontend:")
logger.info(f" Status: {status_code}")
logger.info(f" Body length: {len(response_body)} chars")
logger.info(f" Body preview: {response_body[:200]}{'...' if len(response_body) > 200 else ''}")

return response


def handler(event, context):
handler_start = time.time()

# Log incoming request
method = event["httpMethod"]
path = event.get("path", "")
body = event.get("body", "")

logger.info(f"📥 Incoming request:")
logger.info(f" Method: {method}")
logger.info(f" Path: {path}")
logger.info(f" Body length: {len(body) if body else 0} chars")
if body:
logger.info(f" Body preview: {body[:200]}{'...' if len(body) > 200 else ''}")

logger.debug(f"Lambda invoked with event: {json.dumps(event, default=str)}")

_create_database_and_schema()

method = event["httpMethod"]
path = event.get("path", "")

# Handle CORS preflight requests
if method == "OPTIONS":
response = _cors_response(200, "")
else:
try:
response = handle_request(event, method, path)
except Exception as e:
logger.error(f"❌ Request failed: {str(e)}")
response = _cors_response(400, {"error": str(e)})

handler_duration = (time.time() - handler_start) * 1000
Expand Down Expand Up @@ -147,7 +170,7 @@ def handle_request(event, method, path):
result = reset_database()
return _cors_response(200, result)

logger.debug(f"No route found for {method} request to {path}")
logger.warning(f"🚫 No route found for {method} request to {path}")
return _cors_response(404, {"error": "Not found"})


Expand All @@ -171,9 +194,7 @@ def create_user(payload: dict):
profile_picture_uri = payload.get("profile_picture_uri", "")

try:
with _driver().transaction(
db_name, TransactionType.WRITE, _transaction_options()
) as tx:
with Transaction(_driver(), db_name, TransactionType.WRITE, _transaction_options()) as tx:
# Create user with username
query = f"insert $u isa user, has user-name '{username}'"

Expand All @@ -193,11 +214,13 @@ def create_user(payload: dict):
tx.query(query).resolve()
tx.commit()

return {
result = {
"message": "User created successfully",
"username": username,
"email": emails,
}
logger.info(f"✅ User created: {username} with {len(emails)} email(s)")
return result

except Exception as e:
error_msg = str(e)
Expand All @@ -217,16 +240,16 @@ def create_group(payload: dict):
group_name = payload["group_name"]

try:
with _driver().transaction(
db_name, TransactionType.WRITE, _transaction_options()
) as tx:
with Transaction(_driver(), db_name, TransactionType.WRITE, _transaction_options()) as tx:
# Create group with group name
query = f"insert $g isa group, has group-name '{group_name}';"

tx.query(query).resolve()
tx.commit()

return {"message": "Group created successfully", "group_name": group_name}
result = {"message": "Group created successfully", "group_name": group_name}
logger.info(f"✅ Group created: {group_name}")
return result

except Exception as e:
error_msg = str(e)
Expand All @@ -240,9 +263,7 @@ def create_group(payload: dict):
def list_users():
logger.debug("Listing users")

with _driver().transaction(
db_name, TransactionType.READ, _transaction_options()
) as tx:
with Transaction(_driver(), db_name, TransactionType.READ, _transaction_options()) as tx:
query_start = time.time()
result = (
tx.query(
Expand All @@ -262,22 +283,22 @@ def list_users():
list_start = time.time()
result = list(result)
list_duration = (time.time() - list_start) * 1000 # noqa
logger.info(f"📋 Listed {len(result)} users")

return result


@log_execution_time("list_groups")
def list_groups():
logger.debug("Listing groups")
with _driver().transaction(
db_name, TransactionType.READ, _transaction_options()
) as tx:
with Transaction(_driver(), db_name, TransactionType.READ, _transaction_options()) as tx:
result = (
tx.query('match $g isa group; fetch { "group_name": $g.group-name};')
.resolve()
.as_concept_documents()
)
result = list(result)
logger.info(f"📋 Listed {len(result)} groups")

return result

Expand All @@ -293,9 +314,7 @@ def add_member_to_group(group_name: str, payload: dict):
if "username" in payload and "group_name" in payload:
raise ValueError("Provide either 'username' or 'group_name', not both")

with _driver().transaction(
db_name, TransactionType.WRITE, _transaction_options()
) as tx:
with Transaction(_driver(), db_name, TransactionType.WRITE, _transaction_options()) as tx:
if "username" in payload:
# Adding a user to the group
username = payload["username"]
Expand Down Expand Up @@ -324,20 +343,20 @@ def add_member_to_group(group_name: str, payload: dict):
tx.query(query).resolve()
tx.commit()

return {
result = {
"message": f"{member_type.capitalize()} added to group successfully",
"group_name": group_name,
"member_type": member_type,
"member_name": member_name,
}
logger.info(f"✅ Added {member_type} '{member_name}' to group '{group_name}'")
return result


@log_execution_time("list_direct_group_members")
def list_direct_group_members(group_name: str):
logger.debug(f"Listing direct group members for {group_name}")
with _driver().transaction(
db_name, TransactionType.READ, _transaction_options()
) as tx:
with Transaction(_driver(), db_name, TransactionType.READ, _transaction_options()) as tx:
result = (
tx.query(
f"match "
Expand All @@ -361,19 +380,17 @@ def list_direct_group_members(group_name: str):
@log_execution_time("list_all_group_members")
def list_all_group_members(group_name: str):
logger.debug(f"Listing all group members for {group_name}")
with _driver().transaction(
db_name, TransactionType.READ, _transaction_options()
) as tx:
with Transaction(_driver(), db_name, TransactionType.READ, _transaction_options()) as tx:
# Use the group-members function from the schema to get all members recursively
result = (
tx.query(
f"match "
f' $group isa group, has group-name "{group_name}"; '
f" let $members in group-members($group); "
f" let $member in group-members($group); "
f" $member isa! $member-type; "
f"fetch {{"
f' "member_type": $member-type, '
f' "member_name": $members.name, '
f' "member_name": $member.name, '
f' "group_name": $group.group-name'
f"}};"
)
Expand All @@ -389,9 +406,7 @@ def list_all_group_members(group_name: str):
def list_principal_groups(principal_name: str, principal_type: str):
"""List direct groups for either a user or group principal"""
logger.debug(f"Listing direct groups for {principal_name} of type {principal_type}")
with _driver().transaction(
db_name, TransactionType.READ, _transaction_options()
) as tx:
with Transaction(_driver(), db_name, TransactionType.READ, _transaction_options()) as tx:
if principal_type == "user":
name_attr = "user-name"
else: # group
Expand Down Expand Up @@ -419,9 +434,7 @@ def list_principal_groups(principal_name: str, principal_type: str):
def list_all_principal_groups(principal_name: str, principal_type: str):
"""List all groups (transitive) for either a user or group principal"""
logger.debug(f"Listing all groups for {principal_name} of type {principal_type}")
with _driver().transaction(
db_name, TransactionType.READ, _transaction_options()
) as tx:
with Transaction(_driver(), db_name, TransactionType.READ, _transaction_options()) as tx:
if principal_type == "user":
name_attr = "user-name"
else: # group
Expand Down Expand Up @@ -452,8 +465,8 @@ def reset_database():

driver = _driver()
# Delete database if it exists
if driver.databases.contains(db_name):
driver.databases.get(db_name).delete()
if driver.database_exists(db_name):
driver.delete_database(db_name)
logger.debug(f"Database '{db_name}' deleted")

_create_database_and_schema()
Expand All @@ -465,14 +478,12 @@ def reset_database():
def _create_database_and_schema():
driver = _driver()
# Check if database exists, create only if it doesn't
if db_name not in [db.name for db in driver.databases.all()]:
driver.databases.create(db_name)
if not driver.database_exists(db_name):
driver.create_database(db_name)

# Check if schema already exists by looking for user type
schema_check_start = time.time()
with _driver().transaction(
db_name, TransactionType.READ, _transaction_options()
) as tx:
with Transaction(_driver(), db_name, TransactionType.READ, _transaction_options()) as tx:
check_start = time.time()
row = list(
tx.query("match entity $t; reduce $count = count;")
Expand All @@ -485,15 +496,15 @@ def _create_database_and_schema():
f"📋 Schema check transaction completed in {schema_check_duration:.2f}ms"
)

if row.get("count").get() == 0:
if row["data"]["count"]["value"] == 0:
logger.debug("Loading schema from file")
schema_load_start = time.time()
with _driver().transaction(
db_name, TransactionType.SCHEMA, _transaction_options()
) as schema_tx:
with Transaction(_driver(), db_name, TransactionType.SCHEMA, _transaction_options()) as schema_tx:
# Load schema from file
file_start = time.time()
with open("schema.tql", "r") as f:
import os
schema_path = os.path.join(os.path.dirname(__file__), "schema.tql")
with open(schema_path, "r") as f:
schema_content = f.read()
file_duration = (time.time() - file_start) * 1000 # noqa

Expand Down Expand Up @@ -526,12 +537,13 @@ def _driver():
logger.debug("♻️ Reusing existing driver")
return _global_driver
elif expired:
logger.debug("⏰ Driver expired, cleaning up")
_cleanup_driver()

# Create new driver or existing one is expired
driver_start = time.time()
try:
_global_driver = TypeDB.driver(
_global_driver = driver(
server_host,
Credentials("admin", "password"),
DriverOptions(is_tls_enabled=False),
Expand Down Expand Up @@ -559,8 +571,7 @@ def _cleanup_driver():
if _global_driver is not None:
logger.debug("🧹 Cleaning up global driver")
try:
# TypeDB drivers don't have explicit close methods in the Python API
# The connections will be cleaned up when the driver object is GC'd
_global_driver.close()
_global_driver = None
_driver_created_at = None
logger.debug("✅ Global driver cleaned up")
Expand Down
3 changes: 1 addition & 2 deletions app/lambda/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,2 +1 @@
--index-url=https://repo.typedb.com/public/public-snapshot/python/simple/
typedb-driver==0.0.0+bf3f4548451b471f9964c550dfcfe07723059482
requests>=2.25.0
Loading