Commit ef0b934

Flatten referenced_documents in /v2/conversations response
Signed-off-by: Maysun J Faisal <[email protected]>
1 parent: 88c6edf

12 files changed (+198, -174 lines)

src/app/endpoints/conversations_v2.py

Lines changed: 5 additions & 3 deletions
@@ -323,9 +323,11 @@ def transform_chat_message(entry: CacheEntry) -> dict[str, Any]:
         "type": "assistant"
     }

-    # Check for additional_kwargs and add it to the assistant message if it exists
-    if entry.additional_kwargs:
-        assistant_message["additional_kwargs"] = entry.additional_kwargs.model_dump()
+    # If referenced_documents exist on the entry, add them to the assistant message
+    if entry.referenced_documents is not None:
+        assistant_message["referenced_documents"] = [
+            doc.model_dump(mode='json') for doc in entry.referenced_documents
+        ]

     return {
         "provider": entry.provider,

src/app/endpoints/query.py

Lines changed: 2 additions & 7 deletions
@@ -32,7 +32,7 @@
 from authorization.middleware import authorize
 from client import AsyncLlamaStackClientHolder
 from configuration import configuration
-from models.cache_entry import CacheEntry, AdditionalKwargs
+from models.cache_entry import CacheEntry
 from models.config import Action
 from models.database.conversations import UserConversation
 from models.requests import Attachment, QueryRequest

@@ -334,19 +334,14 @@ async def query_endpoint_handler( # pylint: disable=R0914

     completed_at = datetime.now(UTC).strftime("%Y-%m-%dT%H:%M:%SZ")

-    additional_kwargs_obj = None
-    if referenced_documents:
-        additional_kwargs_obj = AdditionalKwargs(
-            referenced_documents=referenced_documents
-        )
     cache_entry = CacheEntry(
         query=query_request.query,
         response=summary.llm_response,
         provider=provider_id,
         model=model_id,
         started_at=started_at,
         completed_at=completed_at,
-        additional_kwargs=additional_kwargs_obj
+        referenced_documents=referenced_documents if referenced_documents else None
     )

     store_conversation_into_cache(
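
The new CacheEntry construction stores the documents directly and collapses an empty list to None, so turns without retrieved documents store no value at all. A tiny sketch of that pattern (the list contents here are purely illustrative):

```python
# "Empty list becomes None" pattern used when building the cache entry.
referenced_documents = []  # e.g. nothing was retrieved for this turn
stored = referenced_documents if referenced_documents else None
assert stored is None

referenced_documents = [{"doc_url": "https://example.com/a", "doc_title": "Doc A"}]
stored = referenced_documents if referenced_documents else None
assert stored == referenced_documents
```

The same construction appears in streaming_query.py below.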

src/app/endpoints/streaming_query.py

Lines changed: 3 additions & 7 deletions
@@ -44,13 +44,14 @@
 from constants import DEFAULT_RAG_TOOL, MEDIA_TYPE_JSON, MEDIA_TYPE_TEXT
 import metrics
 from metrics.utils import update_llm_token_count_from_turn
-from models.cache_entry import CacheEntry, AdditionalKwargs
+from models.cache_entry import CacheEntry
 from models.config import Action
 from models.database.conversations import UserConversation
 from models.requests import QueryRequest
 from models.responses import ForbiddenResponse, UnauthorizedResponse, ReferencedDocument
 from utils.endpoints import (
     check_configuration_loaded,
+    create_referenced_documents_with_metadata,
     create_rag_chunks_dict,
     get_agent,
     get_system_prompt,

@@ -868,19 +869,14 @@ async def response_generator(

     referenced_documents = create_referenced_documents_with_metadata(summary, metadata_map)

-    additional_kwargs_obj = None
-    if referenced_documents:
-        additional_kwargs_obj = AdditionalKwargs(
-            referenced_documents=referenced_documents
-        )
     cache_entry = CacheEntry(
         query=query_request.query,
         response=summary.llm_response,
         provider=provider_id,
         model=model_id,
         started_at=started_at,
         completed_at=completed_at,
-        additional_kwargs=additional_kwargs_obj
+        referenced_documents=referenced_documents if referenced_documents else None
     )

     store_conversation_into_cache(

src/cache/postgres_cache.py

Lines changed: 38 additions & 35 deletions
@@ -1,12 +1,13 @@
 """PostgreSQL cache implementation."""

+import json
 import psycopg2

 from cache.cache import Cache
 from cache.cache_error import CacheError
-from models.cache_entry import CacheEntry, AdditionalKwargs
+from models.cache_entry import CacheEntry
 from models.config import PostgreSQLDatabaseConfiguration
-from models.responses import ConversationData
+from models.responses import ConversationData, ReferencedDocument
 from log import get_logger
 from utils.connection_decorator import connection

@@ -19,17 +20,18 @@ class PostgresCache(Cache):
     The cache itself lives stored in following table:

     ```
-          Column     |              Type              | Nullable |
-    -----------------+--------------------------------+----------+
-     user_id         | text                           | not null |
-     conversation_id | text                           | not null |
-     created_at      | timestamp without time zone    | not null |
-     started_at      | text                           |          |
-     completed_at    | text                           |          |
-     query           | text                           |          |
-     response        | text                           |          |
-     provider        | text                           |          |
-     model           | text                           |          |
+             Column        |              Type              | Nullable |
+    -----------------------+--------------------------------+----------+
+     user_id               | text                           | not null |
+     conversation_id       | text                           | not null |
+     created_at            | timestamp without time zone    | not null |
+     started_at            | text                           |          |
+     completed_at          | text                           |          |
+     query                 | text                           |          |
+     response              | text                           |          |
+     provider              | text                           |          |
+     model                 | text                           |          |
+     referenced_documents  | jsonb                          |          |
     Indexes:
         "cache_pkey" PRIMARY KEY, btree (user_id, conversation_id, created_at)
         "timestamps" btree (created_at)

@@ -38,16 +40,16 @@ class PostgresCache(Cache):

     CREATE_CACHE_TABLE = """
         CREATE TABLE IF NOT EXISTS cache (
-            user_id           text NOT NULL,
-            conversation_id   text NOT NULL,
-            created_at        timestamp NOT NULL,
-            started_at        text,
-            completed_at      text,
-            query             text,
-            response          text,
-            provider          text,
-            model             text,
-            additional_kwargs jsonb,
+            user_id               text NOT NULL,
+            conversation_id       text NOT NULL,
+            created_at            timestamp NOT NULL,
+            started_at            text,
+            completed_at          text,
+            query                 text,
+            response              text,
+            provider              text,
+            model                 text,
+            referenced_documents  jsonb,
             PRIMARY KEY(user_id, conversation_id, created_at)
         );
     """

@@ -68,15 +70,15 @@ class PostgresCache(Cache):
     """

     SELECT_CONVERSATION_HISTORY_STATEMENT = """
-        SELECT query, response, provider, model, started_at, completed_at, additional_kwargs
+        SELECT query, response, provider, model, started_at, completed_at, referenced_documents
         FROM cache
         WHERE user_id=%s AND conversation_id=%s
        ORDER BY created_at
        """

    INSERT_CONVERSATION_HISTORY_STATEMENT = """
        INSERT INTO cache(user_id, conversation_id, created_at, started_at, completed_at,
-            query, response, provider, model, additional_kwargs)
+            query, response, provider, model, referenced_documents)
        VALUES (%s, %s, CURRENT_TIMESTAMP, %s, %s, %s, %s, %s, %s, %s)
        """

@@ -214,18 +216,18 @@ def get(
         result = []
         for conversation_entry in conversation_entries:
             # Parse it back into an LLMResponse object
-            additional_kwargs_data = conversation_entry[6]
-            additional_kwargs_obj = None
-            if additional_kwargs_data:
-                additional_kwargs_obj = AdditionalKwargs.model_validate(additional_kwargs_data)
+            docs_data = conversation_entry[6]
+            docs_obj = None
+            if docs_data:
+                docs_obj = [ReferencedDocument.model_validate(doc) for doc in docs_data]
             cache_entry = CacheEntry(
                 query=conversation_entry[0],
                 response=conversation_entry[1],
                 provider=conversation_entry[2],
                 model=conversation_entry[3],
                 started_at=conversation_entry[4],
                 completed_at=conversation_entry[5],
-                additional_kwargs=additional_kwargs_obj,
+                referenced_documents=docs_obj,
             )
             result.append(cache_entry)

@@ -253,10 +255,11 @@ def insert_or_append(
             raise CacheError("insert_or_append: cache is disconnected")

         try:
-            additional_kwargs_json = None
-            if cache_entry.additional_kwargs:
-                # Use exclude_none=True to keep JSON clean
-                additional_kwargs_json = cache_entry.additional_kwargs.model_dump_json(exclude_none=True)
+            referenced_documents_json = None
+            if cache_entry.referenced_documents:
+                docs_as_dicts = [doc.model_dump(mode='json') for doc in cache_entry.referenced_documents]
+                referenced_documents_json = json.dumps(docs_as_dicts)
+
             # the whole operation is run in one transaction
             with self.connection.cursor() as cursor:
                 cursor.execute(

@@ -270,7 +273,7 @@ def insert_or_append(
                     cache_entry.response,
                     cache_entry.provider,
                     cache_entry.model,
-                    additional_kwargs_json,
+                    referenced_documents_json,
                 ),
             )

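
A small sketch of the serialization round trip the PostgreSQL cache now performs, without a database (ReferencedDocument is a simplified, assumed stand-in): on write, the documents are dumped to a JSON array bound to the jsonb column; on read, psycopg2 typically hands jsonb back as already-decoded Python data, which is re-validated into models.

```python
import json

from pydantic import BaseModel


class ReferencedDocument(BaseModel):  # simplified stand-in for models.responses.ReferencedDocument
    doc_url: str
    doc_title: str


docs = [ReferencedDocument(doc_url="https://example.com/a", doc_title="Doc A")]

# Write path (insert_or_append): list of models -> JSON string for the jsonb column.
docs_as_dicts = [doc.model_dump(mode="json") for doc in docs]
referenced_documents_json = json.dumps(docs_as_dicts)

# Read path (get): the jsonb value usually comes back as Python lists/dicts already,
# so json.loads here only stands in for the driver's decoding.
docs_data = json.loads(referenced_documents_json)
restored = [ReferencedDocument.model_validate(doc) for doc in docs_data]

assert restored == docs
```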

src/cache/sqlite_cache.py

Lines changed: 37 additions & 34 deletions
@@ -7,9 +7,9 @@

 from cache.cache import Cache
 from cache.cache_error import CacheError
-from models.cache_entry import CacheEntry, AdditionalKwargs
+from models.cache_entry import CacheEntry
 from models.config import SQLiteDatabaseConfiguration
-from models.responses import ConversationData
+from models.responses import ConversationData, ReferencedDocument
 from log import get_logger
 from utils.connection_decorator import connection

@@ -22,17 +22,18 @@ class SQLiteCache(Cache):
    The cache itself is stored in following table:

    ```
-          Column     |            Type             | Nullable |
-    -----------------+-----------------------------+----------+
-     user_id         | text                        | not null |
-     conversation_id | text                        | not null |
-     created_at      | int                         | not null |
-     started_at      | text                        |          |
-     completed_at    | text                        |          |
-     query           | text                        |          |
-     response        | text                        |          |
-     provider        | text                        |          |
-     model           | text                        |          |
+             Column        |            Type             | Nullable |
+    -----------------------+-----------------------------+----------+
+     user_id               | text                        | not null |
+     conversation_id       | text                        | not null |
+     created_at            | int                         | not null |
+     started_at            | text                        |          |
+     completed_at          | text                        |          |
+     query                 | text                        |          |
+     response              | text                        |          |
+     provider              | text                        |          |
+     model                 | text                        |          |
+     referenced_documents  | text                        |          |
    Indexes:
        "cache_pkey" PRIMARY KEY, btree (user_id, conversation_id, created_at)
        "cache_key_key" UNIQUE CONSTRAINT, btree (key)

@@ -43,16 +44,16 @@ class SQLiteCache(Cache):

    CREATE_CACHE_TABLE = """
        CREATE TABLE IF NOT EXISTS cache (
-            user_id           text NOT NULL,
-            conversation_id   text NOT NULL,
-            created_at        int NOT NULL,
-            started_at        text,
-            completed_at      text,
-            query             text,
-            response          text,
-            provider          text,
-            model             text,
-            additional_kwargs text,
+            user_id               text NOT NULL,
+            conversation_id       text NOT NULL,
+            created_at            int NOT NULL,
+            started_at            text,
+            completed_at          text,
+            query                 text,
+            response              text,
+            provider              text,
+            model                 text,
+            referenced_documents  text,
            PRIMARY KEY(user_id, conversation_id, created_at)
        );
    """

@@ -73,15 +74,15 @@ class SQLiteCache(Cache):
    """

    SELECT_CONVERSATION_HISTORY_STATEMENT = """
-        SELECT query, response, provider, model, started_at, completed_at, additional_kwargs
+        SELECT query, response, provider, model, started_at, completed_at, referenced_documents
        FROM cache
        WHERE user_id=? AND conversation_id=?
        ORDER BY created_at
        """

    INSERT_CONVERSATION_HISTORY_STATEMENT = """
        INSERT INTO cache(user_id, conversation_id, created_at, started_at, completed_at,
-            query, response, provider, model, additional_kwargs)
+            query, response, provider, model, referenced_documents)
        VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
        """

@@ -212,18 +213,19 @@ def get(

        result = []
        for conversation_entry in conversation_entries:
-            additional_kwargs_json = conversation_entry[6]
-            additional_kwargs_obj = None
-            if additional_kwargs_json:
-                additional_kwargs_obj = AdditionalKwargs.model_validate_json(additional_kwargs_json)
+            docs_json_str = conversation_entry[6]
+            docs_obj = None
+            if docs_json_str:
+                docs_data = json.loads(docs_json_str)
+                docs_obj = [ReferencedDocument.model_validate(doc) for doc in docs_data]
            cache_entry = CacheEntry(
                query=conversation_entry[0],
                response=conversation_entry[1],
                provider=conversation_entry[2],
                model=conversation_entry[3],
                started_at=conversation_entry[4],
                completed_at=conversation_entry[5],
-                additional_kwargs=additional_kwargs_obj,
+                referenced_documents=docs_obj,
            )
            result.append(cache_entry)

@@ -253,9 +255,10 @@ def insert_or_append(
        cursor = self.connection.cursor()
        current_time = time()

-        additional_kwargs_json = None
-        if cache_entry.additional_kwargs:
-            additional_kwargs_json = cache_entry.additional_kwargs.model_dump_json(exclude_none=True)
+        referenced_documents_json = None
+        if cache_entry.referenced_documents:
+            docs_as_dicts = [doc.model_dump(mode='json') for doc in cache_entry.referenced_documents]
+            referenced_documents_json = json.dumps(docs_as_dicts)

        cursor.execute(
            self.INSERT_CONVERSATION_HISTORY_STATEMENT,

@@ -269,7 +272,7 @@ def insert_or_append(
                cache_entry.response,
                cache_entry.provider,
                cache_entry.model,
-                additional_kwargs_json,
+                referenced_documents_json,
            ),
        )

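
The SQLite variant stores the same payload in a plain text column, so the read path has to json.loads the string itself before re-validating. A minimal round-trip sketch against an in-memory database (the table and model here are simplified assumptions, not the module's real schema):

```python
import json
import sqlite3

from pydantic import BaseModel


class ReferencedDocument(BaseModel):  # simplified stand-in
    doc_url: str
    doc_title: str


conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE cache (query text, referenced_documents text)")

docs = [ReferencedDocument(doc_url="https://example.com/a", doc_title="Doc A")]
referenced_documents_json = json.dumps([d.model_dump(mode="json") for d in docs])
conn.execute(
    "INSERT INTO cache VALUES (?, ?)",
    ("How do I install it?", referenced_documents_json),
)

row = conn.execute("SELECT query, referenced_documents FROM cache").fetchone()
docs_json_str = row[1]
docs_obj = None
if docs_json_str:
    docs_obj = [ReferencedDocument.model_validate(doc) for doc in json.loads(docs_json_str)]

assert docs_obj == docs
```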

src/models/cache_entry.py

Lines changed: 1 addition & 5 deletions
@@ -4,10 +4,6 @@
 from typing import List
 from models.responses import ReferencedDocument

-class AdditionalKwargs(BaseModel):
-    """A structured model for the 'additional_kwargs' dictionary."""
-    referenced_documents: List[ReferencedDocument] = Field(default_factory=list)
-

 class CacheEntry(BaseModel):
     """Model representing a cache entry.

@@ -26,4 +22,4 @@ class CacheEntry(BaseModel):
     model: str
     started_at: str
     completed_at: str
-    additional_kwargs: AdditionalKwargs | None = None
+    referenced_documents: List[ReferencedDocument] | None = None
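
To show what the flattened model means for a serialized cache entry (ReferencedDocument is again a simplified, assumed stand-in, and CacheEntry is trimmed to the fields relevant here): the dump now carries referenced_documents as a top-level optional list rather than a nested additional_kwargs object.

```python
from typing import List, Optional

from pydantic import BaseModel


class ReferencedDocument(BaseModel):  # simplified stand-in
    doc_url: str
    doc_title: str


class CacheEntry(BaseModel):  # trimmed sketch of the real model
    query: str
    response: str
    referenced_documents: Optional[List[ReferencedDocument]] = None


entry = CacheEntry(
    query="How do I install it?",
    response="Use the installer.",
    referenced_documents=[ReferencedDocument(doc_url="https://example.com/a", doc_title="Doc A")],
)

# New, flattened dump:
# {'query': ..., 'response': ..., 'referenced_documents': [{'doc_url': ..., 'doc_title': ...}]}
print(entry.model_dump())

# Before this commit the same data was nested:
# {'query': ..., 'response': ..., 'additional_kwargs': {'referenced_documents': [...]}}
```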

src/models/responses.py

Lines changed: 0 additions & 1 deletion
@@ -7,7 +7,6 @@
 from pydantic import AnyUrl, BaseModel, Field

 from llama_stack_client.types import ProviderInfo
-from models.cache_entry import ConversationData


 class ModelsResponse(BaseModel):
