Skip to content

Commit e2dea7f

Browse files
committed
Add database persistence layer and conversation tracking system
# Background / Context When the user queries lightspeed-stack, they get a conversation ID in the response that they can later use to retrieve the conversation history and continue the conversation. # Issue / Requirement / Reason for change We've put the burden of remembering the conversation ID on the user, which is not ideal. The user (e.g. a UI) has to store the conversation ID somewhere, e.g. in browser local storage, where it can easily get lost. It also complicates the UI implementation. Another issue is that a user could read another user's conversation if they knew its ID; guessing an ID is hard, but not impossible. It's not immediately obvious that the conversation ID is a secret that should not be shared with others. # Solution / Feature Overview Add a database persistence layer to lightspeed-stack to store user conversation IDs alongside the user ID, and a bit of metadata about the conversation. Allow users to retrieve their conversations by calling a new endpoint, `/conversations` without an ID, which returns a list of conversations associated with the user, along with metadata like creation time, last message time, and message count. (Calling `/conversations` with an ID is an existing endpoint that lists the message history for a conversation.) Verify that the user owns the conversation when querying or streaming a conversation or asking for its message history, and automatically associate the conversation with the user in the database when the user queries or streams a conversation. # Implementation Details **Database Layer (src/app/database.py):** - SQLAlchemy - Support for both SQLite and Postgres databases - PostgreSQL custom schema (table namespacing) support with automatic schema creation. This is useful because llama-stack itself needs a database for persisting conversation history, and we want to avoid conflicts with that when users use the same Postgres database as llama-stack. 
- Automatically enable SQL statement tracing when debug logging is enabled **Models (src/models/database):** - Base model class for shared SQLAlchemy configuration - UserConversation model with indexed user_id fields for efficient querying, automatic timestamp tracking (created_at, last_message_at) and message count tracking **Configuration Updates:** - Extended the configuration system to support database connection settings **API:** - Conversations endpoint without an ID now returns a list of conversations for the authenticated user, along with metadata - Query/streaming endpoint automatically associates conversations with the user in the database, and verifies ownership of conversations **Dependencies:** - Added SQLAlchemy dependency **Tests:** Some unit tests were added. # Future Work This change does not include migrations, so any future changes to the database schema require writing custom backwards compatibility code. We should probably add Alembic or something simpler in the near future.
1 parent c8175a8 commit e2dea7f

21 files changed

+984
-41
lines changed

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ dependencies = [
3333
"starlette>=0.47.1",
3434
"aiohttp>=3.12.14",
3535
"authlib>=1.6.0",
36+
"sqlalchemy>=2.0.42",
3637
]
3738

3839
[tool.pyright]

src/app/database.py

Lines changed: 124 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,124 @@
1+
"""Database engine management."""
2+
3+
from pathlib import Path
4+
from typing import Any
5+
6+
from sqlalchemy import create_engine, text
7+
from sqlalchemy.engine.base import Engine
8+
from sqlalchemy.orm import sessionmaker, Session
9+
from log import get_logger, logging
10+
from configuration import configuration
11+
from models.database.base import Base
12+
from models.config import SQLiteDatabaseConfiguration, PostgreSQLDatabaseConfiguration
13+
14+
logger = get_logger(__name__)
15+
16+
# Module-wide SQLAlchemy engine; stays None until initialize_database() assigns it.
engine: Engine | None = None
17+
# Module-wide session factory; stays None until initialize_database() assigns it.
SessionLocal: sessionmaker | None = None
18+
19+
20+
def get_engine() -> Engine:
    """Return the module-wide database engine.

    Raises:
        RuntimeError: If the module-level ``engine`` is still ``None``.
    """
    if engine is not None:
        return engine

    raise RuntimeError(
        "Database engine not initialized. Call initialize_database() first."
    )
27+
28+
29+
def create_tables() -> None:
    """Create every table registered on ``Base.metadata`` using the current engine."""
    db_engine = get_engine()
    Base.metadata.create_all(db_engine)
32+
33+
34+
def get_session() -> Session:
    """Open a new database session from the module-wide session factory.

    Raises:
        RuntimeError: If the module-level ``SessionLocal`` is still ``None``.
    """
    if SessionLocal is not None:
        return SessionLocal()

    raise RuntimeError(
        "Database session not initialized. Call initialize_database() first."
    )
41+
42+
43+
def _create_sqlite_engine(config: SQLiteDatabaseConfiguration, **kwargs: Any) -> Engine:
    """Build a SQLAlchemy engine for the SQLite file named by ``config.db_path``.

    Args:
        config: SQLite settings; only ``db_path`` is read here.
        **kwargs: Extra keyword arguments forwarded to ``create_engine``.

    Raises:
        FileNotFoundError: If the parent directory of ``db_path`` does not exist.
        RuntimeError: If ``create_engine`` raises for any reason.
    """
    db_directory = Path(config.db_path).parent
    if not db_directory.exists():
        raise FileNotFoundError(
            f"SQLite database directory does not exist: {config.db_path}"
        )

    try:
        sqlite_url = f"sqlite:///{config.db_path}"
        return create_engine(sqlite_url, **kwargs)
    except Exception as e:
        logger.exception("Failed to create SQLite engine")
        raise RuntimeError(f"SQLite engine creation failed: {e}") from e
55+
56+
57+
def _create_postgres_engine(
    config: PostgreSQLDatabaseConfiguration, **kwargs: Any
) -> Engine:
    """Create PostgreSQL database engine.

    Builds the connection URL from ``config``, creates the engine and, when a
    non-``public`` namespace is configured, makes sure that schema exists.

    Args:
        config: PostgreSQL settings (user, password, host, port, db, ssl_mode,
            gss_encmode, namespace, ca_cert_path).
        **kwargs: Extra keyword arguments forwarded to ``create_engine``.

    Returns:
        The created engine.

    Raises:
        RuntimeError: If engine creation or schema creation fails.
    """
    # Local import so the block does not depend on a new top-of-file import.
    from urllib.parse import quote_plus  # pylint: disable=import-outside-toplevel

    # Credentials must be percent-encoded: characters such as '@', ':', '/'
    # or '#' in a password would otherwise corrupt the URL structure.
    quoted_user = quote_plus(str(config.user))
    quoted_password = quote_plus(str(config.password))

    postgres_url = (
        f"postgresql://{quoted_user}:{quoted_password}@"
        f"{config.host}:{config.port}/{config.db}"
        f"?sslmode={config.ssl_mode}&gssencmode={config.gss_encmode}"
    )

    is_custom_schema = config.namespace is not None and config.namespace != "public"

    connect_args: dict[str, str] = {}
    if is_custom_schema:
        # Make the custom schema the default search path for every connection.
        connect_args["options"] = f"-csearch_path={config.namespace}"

    if config.ca_cert_path is not None:
        connect_args["sslrootcert"] = str(config.ca_cert_path)

    try:
        postgres_engine = create_engine(
            postgres_url, connect_args=connect_args, **kwargs
        )
    except Exception as e:
        logger.exception("Failed to create PostgreSQL engine")
        raise RuntimeError(f"PostgreSQL engine creation failed: {e}") from e

    if is_custom_schema:
        try:
            with postgres_engine.connect() as connection:
                connection.execute(
                    text(f'CREATE SCHEMA IF NOT EXISTS "{config.namespace}"')
                )
                connection.commit()
            logger.info("Schema '%s' created or already exists", config.namespace)
        except Exception as e:
            logger.exception("Failed to create schema '%s'", config.namespace)
            raise RuntimeError(
                f"Schema creation failed for '{config.namespace}': {e}"
            ) from e

    return postgres_engine
99+
100+
101+
def initialize_database() -> None:
102+
"""Initialize the database engine."""
103+
db_config = configuration.database_configuration
104+
105+
global engine, SessionLocal # pylint: disable=global-statement
106+
107+
# Debug print all SQL statements if our logger is at-least DEBUG level
108+
echo = bool(logger.isEnabledFor(logging.DEBUG))
109+
110+
create_engine_kwargs = {
111+
"echo": echo,
112+
}
113+
114+
match db_config.db_type:
115+
case "sqlite":
116+
sqlite_config = db_config.config
117+
assert isinstance(sqlite_config, SQLiteDatabaseConfiguration)
118+
engine = _create_sqlite_engine(sqlite_config, **create_engine_kwargs)
119+
case "postgres":
120+
postgres_config = db_config.config
121+
assert isinstance(postgres_config, PostgreSQLDatabaseConfiguration)
122+
engine = _create_postgres_engine(postgres_config, **create_engine_kwargs)
123+
124+
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)

src/app/endpoints/conversations.py

Lines changed: 109 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,16 @@
99

1010
from client import AsyncLlamaStackClientHolder
1111
from configuration import configuration
12-
from models.responses import ConversationResponse, ConversationDeleteResponse
12+
from models.responses import (
13+
ConversationResponse,
14+
ConversationDeleteResponse,
15+
ConversationsListResponse,
16+
ConversationDetails,
17+
)
18+
from models.database.conversations import UserConversation
1319
from auth import get_auth_dependency
14-
from utils.endpoints import check_configuration_loaded
20+
from app.database import get_session
21+
from utils.endpoints import check_configuration_loaded, validate_conversation_ownership
1522
from utils.suid import check_suid
1623

1724
logger = logging.getLogger("app.endpoints.handlers")
@@ -66,6 +73,35 @@
6673
},
6774
}
6875

76+
conversations_list_responses: dict[int | str, dict[str, Any]] = {
77+
200: {
78+
"conversations": [
79+
{
80+
"conversation_id": "123e4567-e89b-12d3-a456-426614174000",
81+
"created_at": "2024-01-01T00:00:00Z",
82+
"last_message_at": "2024-01-01T00:05:00Z",
83+
"last_used_model": "gemini/gemini-1.5-flash",
84+
"last_used_provider": "gemini",
85+
"message_count": 5,
86+
},
87+
{
88+
"conversation_id": "456e7890-e12b-34d5-a678-901234567890",
89+
"created_at": "2024-01-01T01:00:00Z",
90+
"last_message_at": "2024-01-01T01:02:00Z",
91+
"last_used_model": "gemini/gemini-2.0-flash",
92+
"last_used_provider": "gemini",
93+
"message_count": 2,
94+
},
95+
]
96+
},
97+
503: {
98+
"detail": {
99+
"response": "Unable to connect to Llama Stack",
100+
"cause": "Connection error.",
101+
}
102+
},
103+
}
104+
69105

70106
def simplify_session_data(session_data: dict) -> list[dict[str, Any]]:
71107
"""Simplify session data to include only essential conversation information.
@@ -109,10 +145,64 @@ def simplify_session_data(session_data: dict) -> list[dict[str, Any]]:
109145
return chat_history
110146

111147

148+
@router.get("/conversations", responses=conversations_list_responses)
def get_conversations_list_endpoint_handler(
    auth: Any = Depends(auth_dependency),
) -> ConversationsListResponse:
    """List the conversations stored for the authenticated user.

    Args:
        auth: Three-element auth tuple from the auth dependency; its first
            element is used as the user ID.

    Returns:
        A ``ConversationsListResponse`` with one ``ConversationDetails`` entry
        per ``UserConversation`` row that belongs to the user.

    Raises:
        HTTPException: With status 500 if anything fails while querying the
            database or building the response.
    """
    check_configuration_loaded(configuration)

    user_id, _, _ = auth

    logger.info("Retrieving conversations for user %s", user_id)

    def _iso_or_none(timestamp: Any) -> Any:
        """Render a timestamp as ISO 8601, passing falsy values through as None."""
        return timestamp.isoformat() if timestamp else None

    with get_session() as session:
        try:
            # Fetch every conversation row owned by this user
            rows = session.query(UserConversation).filter_by(user_id=user_id).all()

            # Convert the rows into response summaries carrying their metadata
            conversations = []
            for row in rows:
                conversations.append(
                    ConversationDetails(
                        conversation_id=row.id,
                        created_at=_iso_or_none(row.created_at),
                        last_message_at=_iso_or_none(row.last_message_at),
                        message_count=row.message_count,
                        last_used_model=row.last_used_model,
                        last_used_provider=row.last_used_provider,
                    )
                )

            logger.info(
                "Found %d conversations for user %s", len(conversations), user_id
            )

            return ConversationsListResponse(conversations=conversations)

        except Exception as e:
            logger.exception(
                "Error retrieving conversations for user %s: %s", user_id, e
            )
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail={
                    "response": "Unknown error",
                    "cause": f"Unknown error while getting conversations for user {user_id}",
                },
            ) from e
200+
201+
112202
@router.get("/conversations/{conversation_id}", responses=conversation_responses)
113203
async def get_conversation_endpoint_handler(
114204
conversation_id: str,
115-
_auth: Any = Depends(auth_dependency),
205+
auth: Any = Depends(auth_dependency),
116206
) -> ConversationResponse:
117207
"""Handle request to retrieve a conversation by ID."""
118208
check_configuration_loaded(configuration)
@@ -128,6 +218,13 @@ async def get_conversation_endpoint_handler(
128218
},
129219
)
130220

221+
user_id, _, _ = auth
222+
223+
validate_conversation_ownership(
224+
user_id=user_id,
225+
conversation_id=conversation_id,
226+
)
227+
131228
agent_id = conversation_id
132229
logger.info("Retrieving conversation %s", conversation_id)
133230

@@ -187,7 +284,7 @@ async def get_conversation_endpoint_handler(
187284
)
188285
async def delete_conversation_endpoint_handler(
189286
conversation_id: str,
190-
_auth: Any = Depends(auth_dependency),
287+
auth: Any = Depends(auth_dependency),
191288
) -> ConversationDeleteResponse:
192289
"""Handle request to delete a conversation by ID."""
193290
check_configuration_loaded(configuration)
@@ -202,6 +299,14 @@ async def delete_conversation_endpoint_handler(
202299
"cause": f"Conversation ID {conversation_id} is not a valid UUID",
203300
},
204301
)
302+
303+
user_id, _, _ = auth
304+
305+
validate_conversation_ownership(
306+
user_id=user_id,
307+
conversation_id=conversation_id,
308+
)
309+
205310
agent_id = conversation_id
206311
logger.info("Deleting conversation %s", conversation_id)
207312

0 commit comments

Comments
 (0)