Skip to content

Commit b5eb048

Browse files
authored
[RHDHPAI-978] Topic summary of initial query (#564)
* update list conversations API to include topic_summary Signed-off-by: Stephanie <[email protected]> * generate docs Signed-off-by: Stephanie <[email protected]> * fix format Signed-off-by: Stephanie <[email protected]> * ignore pylint error Signed-off-by: Stephanie <[email protected]> * fix skip_userid_check Signed-off-by: Stephanie <[email protected]> * add topic summary Signed-off-by: Stephanie <[email protected]> * fix pylint Signed-off-by: Stephanie <[email protected]> * fix follow-up queries issue on topic-summary Signed-off-by: Stephanie <[email protected]> * rebase Signed-off-by: Stephanie <[email protected]> --------- Signed-off-by: Stephanie <[email protected]>
1 parent 80c7938 commit b5eb048

21 files changed

+1455
-64
lines changed

docs/openapi.json

Lines changed: 57 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -744,7 +744,9 @@
744744
},
745745
"conversations": [
746746
{
747-
"conversation_id": "123e4567-e89b-12d3-a456-426614174000"
747+
"conversation_id": "123e4567-e89b-12d3-a456-426614174000",
748+
"topic_summary": "This is a topic summary",
749+
"last_message_timestamp": "2024-01-01T00:00:00Z"
748750
}
749751
]
750752
}
@@ -1419,6 +1421,37 @@
14191421
"title": "ConversationCacheConfiguration",
14201422
"description": "Conversation cache configuration."
14211423
},
1424+
"ConversationData": {
1425+
"properties": {
1426+
"conversation_id": {
1427+
"type": "string",
1428+
"title": "Conversation Id"
1429+
},
1430+
"topic_summary": {
1431+
"anyOf": [
1432+
{
1433+
"type": "string"
1434+
},
1435+
{
1436+
"type": "null"
1437+
}
1438+
],
1439+
"title": "Topic Summary"
1440+
},
1441+
"last_message_timestamp": {
1442+
"type": "number",
1443+
"title": "Last Message Timestamp"
1444+
}
1445+
},
1446+
"type": "object",
1447+
"required": [
1448+
"conversation_id",
1449+
"topic_summary",
1450+
"last_message_timestamp"
1451+
],
1452+
"title": "ConversationData",
1453+
"description": "Model representing conversation data returned by cache list operations.\n\nAttributes:\n conversation_id: The conversation ID\n topic_summary: The topic summary for the conversation (can be None)\n last_message_timestamp: The timestamp of the last message in the conversation"
1454+
},
14221455
"ConversationDeleteResponse": {
14231456
"properties": {
14241457
"conversation_id": {
@@ -1536,14 +1569,29 @@
15361569
"openai",
15371570
"gemini"
15381571
]
1572+
},
1573+
"topic_summary": {
1574+
"anyOf": [
1575+
{
1576+
"type": "string"
1577+
},
1578+
{
1579+
"type": "null"
1580+
}
1581+
],
1582+
"title": "Topic Summary",
1583+
"description": "Topic summary for the conversation",
1584+
"examples": [
1585+
"Openshift Microservices Deployment Strategies"
1586+
]
15391587
}
15401588
},
15411589
"type": "object",
15421590
"required": [
15431591
"conversation_id"
15441592
],
15451593
"title": "ConversationDetails",
1546-
"description": "Model representing the details of a user conversation.\n\nAttributes:\n conversation_id: The conversation ID (UUID).\n created_at: When the conversation was created.\n last_message_at: When the last message was sent.\n message_count: Number of user messages in the conversation.\n last_used_model: The last model used for the conversation.\n last_used_provider: The provider of the last used model.\n\nExample:\n ```python\n conversation = ConversationDetails(\n conversation_id=\"123e4567-e89b-12d3-a456-426614174000\"\n created_at=\"2024-01-01T00:00:00Z\",\n last_message_at=\"2024-01-01T00:05:00Z\",\n message_count=5,\n last_used_model=\"gemini/gemini-2.0-flash\",\n last_used_provider=\"gemini\",\n )\n ```"
1594+
"description": "Model representing the details of a user conversation.\n\nAttributes:\n conversation_id: The conversation ID (UUID).\n created_at: When the conversation was created.\n last_message_at: When the last message was sent.\n message_count: Number of user messages in the conversation.\n last_used_model: The last model used for the conversation.\n last_used_provider: The provider of the last used model.\n topic_summary: The topic summary for the conversation.\n\nExample:\n ```python\n conversation = ConversationDetails(\n conversation_id=\"123e4567-e89b-12d3-a456-426614174000\"\n created_at=\"2024-01-01T00:00:00Z\",\n last_message_at=\"2024-01-01T00:05:00Z\",\n message_count=5,\n last_used_model=\"gemini/gemini-2.0-flash\",\n last_used_provider=\"gemini\",\n topic_summary=\"Openshift Microservices Deployment Strategies\",\n )\n ```"
15471595
},
15481596
"ConversationResponse": {
15491597
"properties": {
@@ -1604,7 +1652,7 @@
16041652
"conversations"
16051653
],
16061654
"title": "ConversationsListResponse",
1607-
"description": "Model representing a response for listing conversations of a user.\n\nAttributes:\n conversations: List of conversation details associated with the user.\n\nExample:\n ```python\n conversations_list = ConversationsListResponse(\n conversations=[\n ConversationDetails(\n conversation_id=\"123e4567-e89b-12d3-a456-426614174000\",\n created_at=\"2024-01-01T00:00:00Z\",\n last_message_at=\"2024-01-01T00:05:00Z\",\n message_count=5,\n last_used_model=\"gemini/gemini-2.0-flash\",\n last_used_provider=\"gemini\",\n ),\n ConversationDetails(\n conversation_id=\"456e7890-e12b-34d5-a678-901234567890\"\n created_at=\"2024-01-01T01:00:00Z\",\n message_count=2,\n last_used_model=\"gemini/gemini-2.0-flash\",\n last_used_provider=\"gemini\",\n )\n ]\n )\n ```",
1655+
"description": "Model representing a response for listing conversations of a user.\n\nAttributes:\n conversations: List of conversation details associated with the user.\n\nExample:\n ```python\n conversations_list = ConversationsListResponse(\n conversations=[\n ConversationDetails(\n conversation_id=\"123e4567-e89b-12d3-a456-426614174000\",\n created_at=\"2024-01-01T00:00:00Z\",\n last_message_at=\"2024-01-01T00:05:00Z\",\n message_count=5,\n last_used_model=\"gemini/gemini-2.0-flash\",\n last_used_provider=\"gemini\",\n topic_summary=\"Openshift Microservices Deployment Strategies\",\n ),\n ConversationDetails(\n conversation_id=\"456e7890-e12b-34d5-a678-901234567890\"\n created_at=\"2024-01-01T01:00:00Z\",\n message_count=2,\n last_used_model=\"gemini/gemini-2.0-flash\",\n last_used_provider=\"gemini\",\n topic_summary=\"RHDH Purpose Summary\",\n )\n ]\n )\n ```",
16081656
"examples": [
16091657
{
16101658
"conversations": [
@@ -1614,14 +1662,16 @@
16141662
"last_message_at": "2024-01-01T00:05:00Z",
16151663
"last_used_model": "gemini/gemini-2.0-flash",
16161664
"last_used_provider": "gemini",
1617-
"message_count": 5
1665+
"message_count": 5,
1666+
"topic_summary": "Openshift Microservices Deployment Strategies"
16181667
},
16191668
{
16201669
"conversation_id": "456e7890-e12b-34d5-a678-901234567890",
16211670
"created_at": "2024-01-01T01:00:00Z",
16221671
"last_used_model": "gemini/gemini-2.5-flash",
16231672
"last_used_provider": "gemini",
1624-
"message_count": 2
1673+
"message_count": 2,
1674+
"topic_summary": "RHDH Purpose Summary"
16251675
}
16261676
]
16271677
}
@@ -1631,7 +1681,7 @@
16311681
"properties": {
16321682
"conversations": {
16331683
"items": {
1634-
"type": "string"
1684+
"$ref": "#/components/schemas/ConversationData"
16351685
},
16361686
"type": "array",
16371687
"title": "Conversations"
@@ -1642,7 +1692,7 @@
16421692
"conversations"
16431693
],
16441694
"title": "ConversationsListResponseV2",
1645-
"description": "Model representing a response for listing conversations of a user.\n\nAttributes:\n conversations: List of conversation IDs associated with the user."
1695+
"description": "Model representing a response for listing conversations of a user.\n\nAttributes:\n conversations: List of conversation data associated with the user."
16461696
},
16471697
"CustomProfile": {
16481698
"properties": {

src/app/endpoints/conversations.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -214,6 +214,7 @@ async def get_conversations_list_endpoint_handler(
214214
message_count=conv.message_count,
215215
last_used_model=conv.last_used_model,
216216
last_used_provider=conv.last_used_provider,
217+
topic_summary=conv.topic_summary,
217218
)
218219
for conv in user_conversations
219220
]

src/app/endpoints/conversations_v2.py

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,8 @@
8383
"conversations": [
8484
{
8585
"conversation_id": "123e4567-e89b-12d3-a456-426614174000",
86+
"topic_summary": "This is a topic summary",
87+
"last_message_timestamp": "2024-01-01T00:00:00Z",
8688
}
8789
]
8890
}
@@ -102,6 +104,8 @@ async def get_conversations_list_endpoint_handler(
102104

103105
logger.info("Retrieving conversations for user %s", user_id)
104106

107+
skip_userid_check = auth[2]
108+
105109
if configuration.conversation_cache is None:
106110
logger.warning("Converastion cache is not configured")
107111
raise HTTPException(
@@ -112,7 +116,7 @@ async def get_conversations_list_endpoint_handler(
112116
},
113117
)
114118

115-
conversations = configuration.conversation_cache.list(user_id, False)
119+
conversations = configuration.conversation_cache.list(user_id, skip_userid_check)
116120
logger.info("Conversations for user %s: %s", user_id, len(conversations))
117121

118122
return ConversationsListResponseV2(conversations=conversations)
@@ -132,6 +136,8 @@ async def get_conversation_endpoint_handler(
132136
user_id = auth[0]
133137
logger.info("Retrieving conversation %s for user %s", conversation_id, user_id)
134138

139+
skip_userid_check = auth[2]
140+
135141
if configuration.conversation_cache is None:
136142
logger.warning("Converastion cache is not configured")
137143
raise HTTPException(
@@ -144,7 +150,9 @@ async def get_conversation_endpoint_handler(
144150

145151
check_conversation_existence(user_id, conversation_id)
146152

147-
conversation = configuration.conversation_cache.get(user_id, conversation_id, False)
153+
conversation = configuration.conversation_cache.get(
154+
user_id, conversation_id, skip_userid_check
155+
)
148156
chat_history = [transform_chat_message(entry) for entry in conversation]
149157

150158
return ConversationResponse(
@@ -168,6 +176,8 @@ async def delete_conversation_endpoint_handler(
168176
user_id = auth[0]
169177
logger.info("Deleting conversation %s for user %s", conversation_id, user_id)
170178

179+
skip_userid_check = auth[2]
180+
171181
if configuration.conversation_cache is None:
172182
logger.warning("Converastion cache is not configured")
173183
raise HTTPException(
@@ -181,7 +191,9 @@ async def delete_conversation_endpoint_handler(
181191
check_conversation_existence(user_id, conversation_id)
182192

183193
logger.info("Deleting conversation %s for user %s", conversation_id, user_id)
184-
deleted = configuration.conversation_cache.delete(user_id, conversation_id, False)
194+
deleted = configuration.conversation_cache.delete(
195+
user_id, conversation_id, skip_userid_check
196+
)
185197

186198
if deleted:
187199
return ConversationDeleteResponse(
@@ -215,7 +227,8 @@ def check_conversation_existence(user_id: str, conversation_id: str) -> None:
215227
if configuration.conversation_cache is None:
216228
return
217229
conversations = configuration.conversation_cache.list(user_id, False)
218-
if conversation_id not in conversations:
230+
conversation_ids = [conv.conversation_id for conv in conversations]
231+
if conversation_id not in conversation_ids:
219232
logger.error("No conversation found for conversation ID %s", conversation_id)
220233
raise HTTPException(
221234
status_code=status.HTTP_404_NOT_FOUND,

src/app/endpoints/query.py

Lines changed: 56 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,8 @@
4646
from utils.endpoints import (
4747
check_configuration_loaded,
4848
get_agent,
49+
get_topic_summary_system_prompt,
50+
get_temp_agent,
4951
get_system_prompt,
5052
store_conversation_into_cache,
5153
validate_conversation_ownership,
@@ -98,7 +100,11 @@ def is_transcripts_enabled() -> bool:
98100

99101

100102
def persist_user_conversation_details(
101-
user_id: str, conversation_id: str, model: str, provider_id: str
103+
user_id: str,
104+
conversation_id: str,
105+
model: str,
106+
provider_id: str,
107+
topic_summary: Optional[str],
102108
) -> None:
103109
"""Associate conversation to user in the database."""
104110
with get_session() as session:
@@ -112,6 +118,7 @@ def persist_user_conversation_details(
112118
user_id=user_id,
113119
last_used_model=model,
114120
last_used_provider=provider_id,
121+
topic_summary=topic_summary,
115122
message_count=1,
116123
)
117124
session.add(conversation)
@@ -169,9 +176,42 @@ def evaluate_model_hints(
169176
return model_id, provider_id
170177

171178

179+
async def get_topic_summary(
180+
question: str, client: AsyncLlamaStackClient, model_id: str
181+
) -> str:
182+
"""Get a topic summary for a question.
183+
184+
Args:
185+
question: The question to be validated.
186+
client: The AsyncLlamaStackClient to use for the request.
187+
model_id: The ID of the model to use.
188+
Returns:
189+
str: The topic summary for the question.
190+
"""
191+
topic_summary_system_prompt = get_topic_summary_system_prompt(configuration)
192+
agent, session_id, _ = await get_temp_agent(
193+
client, model_id, topic_summary_system_prompt
194+
)
195+
response = await agent.create_turn(
196+
messages=[UserMessage(role="user", content=question)],
197+
session_id=session_id,
198+
stream=False,
199+
toolgroups=None,
200+
)
201+
response = cast(Turn, response)
202+
return (
203+
interleaved_content_as_str(response.output_message.content)
204+
if (
205+
getattr(response, "output_message", None) is not None
206+
and getattr(response.output_message, "content", None) is not None
207+
)
208+
else ""
209+
)
210+
211+
172212
@router.post("/query", responses=query_response)
173213
@authorize(Action.QUERY)
174-
async def query_endpoint_handler(
214+
async def query_endpoint_handler( # pylint: disable=R0914
175215
request: Request,
176216
query_request: QueryRequest,
177217
auth: Annotated[AuthTuple, Depends(auth_dependency)],
@@ -200,7 +240,7 @@ async def query_endpoint_handler(
200240
# log Llama Stack configuration
201241
logger.info("Llama stack config: %s", configuration.llama_stack_configuration)
202242

203-
user_id, _, _, token = auth
243+
user_id, _, _skip_userid_check, token = auth
204244

205245
user_conversation: UserConversation | None = None
206246
if query_request.conversation_id:
@@ -251,6 +291,16 @@ async def query_endpoint_handler(
251291
# Update metrics for the LLM call
252292
metrics.llm_calls_total.labels(provider_id, model_id).inc()
253293

294+
# Get the initial topic summary for the conversation
295+
topic_summary = None
296+
with get_session() as session:
297+
existing_conversation = (
298+
session.query(UserConversation).filter_by(id=conversation_id).first()
299+
)
300+
if not existing_conversation:
301+
topic_summary = await get_topic_summary(
302+
query_request.query, client, model_id
303+
)
254304
# Convert RAG chunks to dictionary format once for reuse
255305
logger.info("Processing RAG chunks...")
256306
rag_chunks_dict = [chunk.model_dump() for chunk in summary.rag_chunks]
@@ -278,6 +328,7 @@ async def query_endpoint_handler(
278328
conversation_id=conversation_id,
279329
model=model_id,
280330
provider_id=provider_id,
331+
topic_summary=topic_summary,
281332
)
282333

283334
store_conversation_into_cache(
@@ -288,6 +339,8 @@ async def query_endpoint_handler(
288339
model_id,
289340
query_request.query,
290341
summary.llm_response,
342+
_skip_userid_check,
343+
topic_summary,
291344
)
292345

293346
# Convert tool calls to response format

0 commit comments

Comments
 (0)