
Commit 2422408

Add database persistence layer and conversation tracking system
# Background / Context

When the user queries lightspeed-stack, they get a conversation ID in the response that they can later use to retrieve the conversation history and continue the conversation.

# Issue / Requirement / Reason for change

We've put the burden of remembering the conversation ID on the user, which is not ideal. The user (e.g. a UI) has to store the conversation ID somewhere, e.g. in browser local storage, but that storage is fragile and the ID can easily get lost. It also complicates the UI implementation.

Another issue is that a user could read another user's conversation if they knew its ID. Guessing an ID is hard, but relying on that is still not ideal, and it is not immediately obvious that the conversation ID is a secret that must not be shared with others.

# Solution / Feature Overview

Add a database persistence layer to lightspeed-stack that stores user conversation IDs and a bit of metadata about each conversation.

Allow users to retrieve their conversations by calling a new endpoint, `/conversations` (without an ID; the existing endpoint that takes an ID lists the message history of a single conversation). It returns a list of conversations associated with the user, along with metadata such as creation time, last message time, and message count.

Verify that the user owns a conversation when querying or streaming it, and automatically associate the conversation with the user in the database when the user queries or streams it.

# Implementation Details

**Database Layer (src/app/database.py):**

- SQLAlchemy
- Support for both SQLite and PostgreSQL databases
- PostgreSQL custom schema (table namespacing) support with automatic schema creation. This is useful because llama-stack itself needs a database for persisting conversation history, and we want to avoid conflicts with it when users point lightspeed-stack at the same PostgreSQL database as llama-stack.
- Automatically enable SQL statement tracing when debug logging is enabled

**Models (src/models/):**

- Base model class for shared SQLAlchemy configuration
- UserConversation model with an indexed user_id field for efficient querying, automatic timestamp tracking (created_at, last_message_at), and message count tracking (a rough sketch of this model appears below)

**Configuration Updates:**

- Extended the configuration system to support database connection settings

**API:**

- The conversations endpoint without an ID now returns a list of conversations for the authenticated user, along with metadata
- The query/streaming endpoints automatically associate conversations with the user in the database and verify ownership of conversations

**Dependencies:**

- Added a SQLAlchemy dependency

**Tests:**

More tests will be added (and unit tests will be fixed) once I get some initial feedback on the implementation.

# Future Work

This change does not include migrations, so any future change to the database schema will require custom backwards-compatibility code. We should add Alembic (or something simpler) in the near future.
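The diff for `src/models/conversations.py` is not part of this excerpt, so the following is only a rough, hypothetical sketch of what the `UserConversation` model described above might look like. The column names (`id`, `user_id`, `model`, `created_at`, `last_message_at`, `message_count`) are taken from the commit description and the endpoint code below; the table name, types, and defaults are assumptions.

```python
# Hypothetical sketch only -- not the actual contents of src/models/conversations.py.
from datetime import UTC, datetime

from sqlalchemy import DateTime, Integer, String
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column


class Base(DeclarativeBase):
    """Stand-in for the shared base class from models.base."""


class UserConversation(Base):
    """Associates a llama-stack conversation with the user who owns it."""

    __tablename__ = "user_conversation"  # assumed table name

    id: Mapped[str] = mapped_column(String, primary_key=True)
    # Indexed so that "all conversations for this user" queries stay cheap.
    user_id: Mapped[str] = mapped_column(String, index=True)
    model: Mapped[str] = mapped_column(String)
    created_at: Mapped[datetime] = mapped_column(
        DateTime, default=lambda: datetime.now(UTC)
    )
    last_message_at: Mapped[datetime] = mapped_column(
        DateTime, default=lambda: datetime.now(UTC)
    )
    message_count: Mapped[int] = mapped_column(Integer, default=0)
```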
1 parent 75a0d1c commit 2422408

19 files changed (+761 / -29 lines)

pyproject.toml

Lines changed: 1 addition & 0 deletions
@@ -33,6 +33,7 @@ dependencies = [
     "starlette>=0.47.1",
     "aiohttp>=3.12.14",
     "authlib>=1.6.0",
+    "sqlalchemy>=2.0.42",
 ]
 
 [tool.pyright]

src/app/database.py

Lines changed: 124 additions & 0 deletions
@@ -0,0 +1,124 @@
+"""Database engine management."""
+
+from pathlib import Path
+from typing import Any
+
+from sqlalchemy import create_engine, text
+from sqlalchemy.engine.base import Engine
+from sqlalchemy.orm import sessionmaker, Session
+from log import get_logger, logging
+from configuration import configuration
+from models.base import Base
+from models.config import SQLiteDatabaseConfiguration, PostgreSQLDatabaseConfiguration
+
+logger = get_logger(__name__)
+
+engine: Engine | None = None
+SessionLocal: sessionmaker | None = None
+
+
+def get_engine() -> Engine:
+    """Get the database engine. Raises an error if not initialized."""
+    if engine is None:
+        raise RuntimeError(
+            "Database engine not initialized. Call initialize_database() first."
+        )
+    return engine
+
+
+def create_tables() -> None:
+    """Create tables."""
+    Base.metadata.create_all(get_engine())
+
+
+def get_session() -> Session:
+    """Get a database session. Raises an error if not initialized."""
+    if SessionLocal is None:
+        raise RuntimeError(
+            "Database session not initialized. Call initialize_database() first."
+        )
+    return SessionLocal()
+
+
+def _create_sqlite_engine(config: SQLiteDatabaseConfiguration, **kwargs: Any) -> Engine:
+    """Create SQLite database engine."""
+    if not Path(config.db_path).parent.exists():
+        raise FileNotFoundError(
+            f"SQLite database directory does not exist: {config.db_path}"
+        )
+
+    try:
+        return create_engine(f"sqlite:///{config.db_path}", **kwargs)
+    except Exception as e:
+        logger.exception("Failed to create SQLite engine")
+        raise RuntimeError(f"SQLite engine creation failed: {e}") from e
+
+
+def _create_postgres_engine(
+    config: PostgreSQLDatabaseConfiguration, **kwargs: Any
+) -> Engine:
+    """Create PostgreSQL database engine."""
+    postgres_url = (
+        f"postgresql://{config.user}:{config.password}@"
+        f"{config.host}:{config.port}/{config.db}"
+        f"?sslmode={config.ssl_mode}&gssencmode={config.gss_encmode}"
+    )
+
+    is_custom_schema = config.namespace is not None and config.namespace != "public"
+
+    connect_args = {}
+    if is_custom_schema:
+        connect_args["options"] = f"-csearch_path={config.namespace}"
+
+    if config.ca_cert_path is not None:
+        connect_args["sslrootcert"] = str(config.ca_cert_path)
+
+    try:
+        postgres_engine = create_engine(
+            postgres_url, connect_args=connect_args, **kwargs
+        )
+    except Exception as e:
+        logger.exception("Failed to create PostgreSQL engine")
+        raise RuntimeError(f"PostgreSQL engine creation failed: {e}") from e
+
+    if is_custom_schema:
+        try:
+            with postgres_engine.connect() as connection:
+                connection.execute(
+                    text(f'CREATE SCHEMA IF NOT EXISTS "{config.namespace}"')
+                )
+                connection.commit()
+            logger.info("Schema '%s' created or already exists", config.namespace)
+        except Exception as e:
+            logger.exception("Failed to create schema '%s'", config.namespace)
+            raise RuntimeError(
+                f"Schema creation failed for '{config.namespace}': {e}"
+            ) from e
+
+    return postgres_engine
+
+
+def initialize_database() -> None:
+    """Initialize the database engine."""
+    db_config = configuration.database_configuration
+
+    global engine, SessionLocal  # pylint: disable=global-statement
+
+    # Debug print all SQL statements if our logger is at-least DEBUG level
+    echo = bool(logger.isEnabledFor(logging.DEBUG))
+
+    create_engine_kwargs = {
+        "echo": echo,
+    }
+
+    match db_config.db_type:
+        case "sqlite":
+            sqlite_config = db_config.config
+            assert isinstance(sqlite_config, SQLiteDatabaseConfiguration)
+            engine = _create_sqlite_engine(sqlite_config, **create_engine_kwargs)
+        case "postgres":
+            postgres_config = db_config.config
+            assert isinstance(postgres_config, PostgreSQLDatabaseConfiguration)
+            engine = _create_postgres_engine(postgres_config, **create_engine_kwargs)
+
+    SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
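A minimal usage sketch for the module above, assuming `initialize_database()` is called once at application startup (the real call site is in one of the 19 changed files not shown in this excerpt):

```python
# Hypothetical wiring sketch -- the actual startup code is not part of this excerpt.
from app.database import create_tables, get_session, initialize_database
from models.conversations import UserConversation

initialize_database()  # build the engine from configuration.database_configuration
create_tables()        # create all tables known to Base.metadata (e.g. user conversations)

# get_session() returns a plain SQLAlchemy Session, which works as a context manager.
with get_session() as session:
    total = session.query(UserConversation).count()
    print(f"{total} conversations persisted")
```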

src/app/endpoints/conversations.py

Lines changed: 104 additions & 4 deletions
@@ -9,9 +9,16 @@
 
 from client import AsyncLlamaStackClientHolder
 from configuration import configuration
-from models.responses import ConversationResponse, ConversationDeleteResponse
+from models.responses import (
+    ConversationResponse,
+    ConversationDeleteResponse,
+    ConversationsListResponse,
+    ConversationDetails,
+)
 from auth import get_auth_dependency
-from utils.endpoints import check_configuration_loaded
+from app.database import get_session
+from models.conversations import UserConversation  # pylint: disable=ungrouped-imports
+from utils.endpoints import check_configuration_loaded, validate_conversation_ownership
 from utils.suid import check_suid
 
 logger = logging.getLogger("app.endpoints.handlers")
@@ -66,6 +73,31 @@
     },
 }
 
+conversations_list_responses: dict[int | str, dict[str, Any]] = {
+    200: {
+        "conversations": [
+            {
+                "conversation_id": "123e4567-e89b-12d3-a456-426614174000",
+                "created_at": "2024-01-01T00:00:00Z",
+                "last_message_at": "2024-01-01T00:05:00Z",
+                "message_count": 5,
+            },
+            {
+                "conversation_id": "456e7890-e12b-34d5-a678-901234567890",
+                "created_at": "2024-01-01T01:00:00Z",
+                "last_message_at": "2024-01-01T01:02:00Z",
+                "message_count": 2,
+            },
+        ]
+    },
+    503: {
+        "detail": {
+            "response": "Unable to connect to Llama Stack",
+            "cause": "Connection error.",
+        }
+    },
+}
+
 
 def simplify_session_data(session_data: dict) -> list[dict[str, Any]]:
     """Simplify session data to include only essential conversation information.
@@ -109,10 +141,63 @@ def simplify_session_data(session_data: dict) -> list[dict[str, Any]]:
     return chat_history
 
 
+@router.get("/conversations", responses=conversations_list_responses)
+def get_conversations_list_endpoint_handler(
+    auth: Any = Depends(auth_dependency),
+) -> ConversationsListResponse:
+    """Handle request to retrieve all conversations for the authenticated user."""
+    check_configuration_loaded(configuration)
+
+    user_id, _, _ = auth
+
+    logger.info("Retrieving conversations for user %s", user_id)
+
+    with get_session() as session:
+        try:
+            # Get all conversations for this user
+            user_conversations = (
+                session.query(UserConversation).filter_by(user_id=user_id).all()
+            )
+
+            # Return conversation summaries with metadata
+            conversations = [
+                ConversationDetails(
+                    conversation_id=conv.id,
+                    created_at=conv.created_at.isoformat() if conv.created_at else None,
+                    last_message_at=(
+                        conv.last_message_at.isoformat()
+                        if conv.last_message_at
+                        else None
+                    ),
+                    message_count=conv.message_count,
+                    model=conv.model,
+                )
+                for conv in user_conversations
+            ]
+
+            logger.info(
+                "Found %d conversations for user %s", len(conversations), user_id
+            )
+
+            return ConversationsListResponse(conversations=conversations)
+
+        except Exception as e:
+            logger.exception(
+                "Error retrieving conversations for user %s: %s", user_id, e
+            )
+            raise HTTPException(
+                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+                detail={
+                    "response": "Unknown error",
+                    "cause": f"Unknown error while getting conversations for user {user_id}",
+                },
+            ) from e
+
+
 @router.get("/conversations/{conversation_id}", responses=conversation_responses)
 async def get_conversation_endpoint_handler(
     conversation_id: str,
-    _auth: Any = Depends(auth_dependency),
+    auth: Any = Depends(auth_dependency),
 ) -> ConversationResponse:
     """Handle request to retrieve a conversation by ID."""
     check_configuration_loaded(configuration)
@@ -128,6 +213,13 @@ async def get_conversation_endpoint_handler(
         },
     )
 
+    user_id, _, _ = auth
+
+    validate_conversation_ownership(
+        user_id=user_id,
+        conversation_id=conversation_id,
+    )
+
     agent_id = conversation_id
     logger.info("Retrieving conversation %s", conversation_id)
 
@@ -187,7 +279,7 @@ async def get_conversation_endpoint_handler(
 )
 async def delete_conversation_endpoint_handler(
     conversation_id: str,
-    _auth: Any = Depends(auth_dependency),
+    auth: Any = Depends(auth_dependency),
 ) -> ConversationDeleteResponse:
     """Handle request to delete a conversation by ID."""
     check_configuration_loaded(configuration)
@@ -202,6 +294,14 @@
             "cause": f"Conversation ID {conversation_id} is not a valid UUID",
         },
     )
+
+    user_id, _, _ = auth
+
+    validate_conversation_ownership(
+        user_id=user_id,
+        conversation_id=conversation_id,
+    )
+
     agent_id = conversation_id
     logger.info("Deleting conversation %s", conversation_id)
 
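The `ConversationsListResponse` and `ConversationDetails` models imported above are defined in `src/models/responses.py`, whose diff is not included in this excerpt. Below is a hedged sketch of what they might look like, with field names mirrored from the endpoint code and everything else (optionality, types) assumed:

```python
# Hypothetical sketch of the new response models -- not the actual diff of
# src/models/responses.py; optionality and types are assumptions.
from pydantic import BaseModel


class ConversationDetails(BaseModel):
    """Summary metadata for a single conversation owned by the caller."""

    conversation_id: str
    created_at: str | None = None
    last_message_at: str | None = None
    message_count: int | None = None
    model: str | None = None


class ConversationsListResponse(BaseModel):
    """Response body for GET /conversations."""

    conversations: list[ConversationDetails]
```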

src/app/endpoints/query.py

Lines changed: 43 additions & 1 deletion
@@ -22,11 +22,18 @@
 from auth.interface import AuthTuple
 from client import AsyncLlamaStackClientHolder
 from configuration import configuration
+from app.database import get_session
 import metrics
+from models.conversations import UserConversation
 from models.responses import QueryResponse, UnauthorizedResponse, ForbiddenResponse
 from models.requests import QueryRequest, Attachment
 import constants
-from utils.endpoints import check_configuration_loaded, get_agent, get_system_prompt
+from utils.endpoints import (
+    check_configuration_loaded,
+    get_agent,
+    get_system_prompt,
+    validate_conversation_ownership,
+)
 from utils.mcp_headers import mcp_headers_dependency, handle_mcp_headers_with_toolgroups
 from utils.suid import get_suid
 
@@ -65,6 +72,32 @@ def is_transcripts_enabled() -> bool:
     return configuration.user_data_collection_configuration.transcripts_enabled
 
 
+def persist_user_conversation_details(
+    user_id: str, conversation_id: str, model: str
+) -> None:
+    """Associate conversation to user in the database."""
+    with get_session() as session:
+        existing_conversation = (
+            session.query(UserConversation)
+            .filter_by(id=conversation_id, user_id=user_id)
+            .first()
+        )
+
+        if not existing_conversation:
+            conversation = UserConversation(
+                id=conversation_id, user_id=user_id, model=model, message_count=1
+            )
+            session.add(conversation)
+            logger.debug(
+                "Associated conversation %s to user %s", conversation_id, user_id
+            )
+        else:
+            existing_conversation.last_message_at = datetime.now(UTC)
+            existing_conversation.message_count += 1
+
+        session.commit()
+
+
 @router.post("/query", responses=query_response)
 async def query_endpoint_handler(
     query_request: QueryRequest,
@@ -79,6 +112,11 @@ async def query_endpoint_handler(
 
     user_id, _, token = auth
 
+    if query_request.conversation_id is not None:
+        validate_conversation_ownership(
+            user_id=user_id, conversation_id=query_request.conversation_id
+        )
+
     try:
         # try to get Llama Stack client
         client = AsyncLlamaStackClientHolder().get_client()
@@ -110,6 +148,10 @@
         attachments=query_request.attachments or [],
     )
 
+    persist_user_conversation_details(
+        user_id=user_id, conversation_id=conversation_id, model=model_id
+    )
+
     return QueryResponse(conversation_id=conversation_id, response=response)
 
     # connection to Llama Stack server
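Both this file and `conversations.py` rely on `validate_conversation_ownership()` from `utils/endpoints.py`, whose diff is also not shown here. The sketch below is only an assumption about the general shape of that check, namely that it looks the conversation up in the database and rejects the request (here with HTTP 403) when it belongs to a different user; the actual status codes and lookup details may differ:

```python
# Hypothetical sketch -- not the actual utils/endpoints.py implementation.
from fastapi import HTTPException, status

from app.database import get_session
from models.conversations import UserConversation


def validate_conversation_ownership(user_id: str, conversation_id: str) -> None:
    """Reject the request if the conversation belongs to a different user."""
    with get_session() as session:
        conversation = (
            session.query(UserConversation).filter_by(id=conversation_id).first()
        )
        if conversation is not None and conversation.user_id != user_id:
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                detail={
                    "response": "Access denied",
                    "cause": f"Conversation {conversation_id} does not belong to user {user_id}",
                },
            )
```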
