
Commit 855fbfe

ckhened and Antonyvance authored
RAG end to end perf measurements using Langsmith (#60)
Co-authored-by: Antony Vance <[email protected]>
1 parent c4ba63e commit 855fbfe

File tree

3 files changed: +775 -0 lines changed


ChatQnA/langchain/test/README.md

Lines changed: 31 additions & 0 deletions
## Performance measurement tests with LangSmith

Pre-requisite: Sign up for LangSmith [https://www.langchain.com/langsmith] and get the API token <br />

### Steps to run perf measurements with the tgi_gaudi.ipynb jupyter notebook

1. This dir is mounted at /test in the qna-rag-redis-server container
2. Make sure the Redis container and the LLM serving endpoint are up and running
3. Enter the qna-rag-redis-server container and start the jupyter notebook server (you can specify the IP address to bind; jupyter will run on port 8888)
```
docker exec -it qna-rag-redis-server bash
cd /test
jupyter notebook --allow-root --ip=X.X.X.X
```
4. Open jupyter in your browser and open the tgi_gaudi.ipynb notebook
5. Update all the configuration parameters in the second cell of the notebook (see the note on credentials after this list)
6. Clear all the cells and run all the cells
7. The last cell, which calls client.run_on_dataset(), runs the langchain Q&A test and captures measurements on the LangSmith server. The URL for accessing the test results can be obtained from the output of that cell.

<br/><br/>
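For reference, the environment variables below are the ones the companion end_to_end_rag_test.py script sets for LangSmith and Hugging Face access. Exporting them before starting jupyter (an assumption, not a step from the original instructions) is one way to make the same credentials available to the notebook:

```
export LANGCHAIN_ENDPOINT="https://api.smith.langchain.com"
export LANGCHAIN_API_KEY="<langsmith api key>"
export HUGGINGFACEHUB_API_TOKEN="<huggingface token>"
```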
### Steps to run perf measurements with the end_to_end_rag_test.py python script

1. This dir is mounted at /test in the qna-rag-redis-server container
2. Make sure the Redis container and the LLM serving endpoint are up and running
3. Enter the qna-rag-redis-server container and run the python script (a filled-in example follows this list)
```
docker exec -it qna-rag-redis-server bash
cd /test
python end_to_end_rag_test.py -l "<LLM serving endpoint URL - TGI or vLLM>" -e <TEI embedding serving endpoint URL> -m <LLM model name> -ht "<huggingface token>" -lt <langsmith api key> -dbs "<path to schema>" -dbu "<redis server URL>" -dbi "<DB index name>" -d "<langsmith dataset name>"
```
4. Check the results on the LangSmith server
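For illustration, a filled-in invocation might look like the sketch below. The endpoint URLs match the script defaults; the model name, schema path, index name, and dataset name are assumptions, not values from the original instructions:

```
docker exec -it qna-rag-redis-server bash
cd /test
python end_to_end_rag_test.py -l "http://localhost:8000" -e http://localhost:8002 \
    -m meta-llama/Llama-2-7b-chat-hf -ht "<huggingface token>" -lt <langsmith api key> \
    -dbs "/test/redis_schema.yml" -dbu "redis://localhost:6379" -dbi "rag-redis" \
    -d "LangChain Docs Q&A"
```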
ChatQnA/langchain/test/end_to_end_rag_test.py

Lines changed: 248 additions & 0 deletions
#!/usr/bin/env python

# Copyright (c) 2024 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import os
import uuid
from operator import itemgetter
from typing import Sequence

from langchain.prompts import ChatPromptTemplate
from langchain.schema.document import Document
from langchain.schema.output_parser import StrOutputParser
from langchain.schema.runnable.passthrough import RunnableAssign
from langchain_benchmarks import clone_public_dataset, registry
from langchain_community.embeddings import HuggingFaceHubEmbeddings
from langchain_community.llms import HuggingFaceEndpoint
from langchain_community.vectorstores import Redis
from langchain_openai import ChatOpenAI
from langsmith.client import Client
from transformers import AutoTokenizer
# Parameters and settings
ENDPOINT_URL_GAUDI2 = "http://localhost:8000"
ENDPOINT_URL_VLLM = "http://localhost:8001/v1"
TEI_ENDPOINT = "http://localhost:8002"
LANG_CHAIN_DATASET = "<Dataset name to add>"
HF_MODEL_NAME = "<Model name to add>"
PROMPT_TOKENS_LEN = 214  # Magic number for prompt template tokens
MAX_INPUT_TOKENS = 1024
MAX_OUTPUT_TOKENS = 128

# Generate a unique run ID for this experiment
run_uid = uuid.uuid4().hex[:6]

# Set in __main__ from the model name before any chain runs
tokenizer = None
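# Context budget per request (illustrative note, not in the original script):
# MAX_INPUT_TOKENS - PROMPT_TOKENS_LEN = 1024 - 214 = 810 tokens remain for
# the retrieved context after reserving room for the prompt template.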
def crop_tokens(prompt, max_len):
    """Tokenize the prompt and truncate it to at most max_len tokens."""
    inputs = tokenizer(prompt, return_tensors="pt")
    # Keep only the first max_len token ids
    inputs_cropped = inputs["input_ids"][0:, 0:max_len]
    # Decode back to a string without special tokens
    prompt_cropped = tokenizer.batch_decode(
        inputs_cropped, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )[0]
    return prompt_cropped
# After the retriever fetches documents, this function
# formats them into a string to present to the LLM
def format_docs(docs: Sequence[Document]) -> str:
    formatted_docs = []
    for i, doc in enumerate(docs):
        doc_string = (
            f"<document index='{i}'>\n"
            f"<source>{doc.metadata.get('source')}</source>\n"
            f"<doc_content>{doc.page_content[0:]}</doc_content>\n"
            "</document>"
        )
        # Truncate the retrieved data so the prompt plus context fits in MAX_INPUT_TOKENS
        cropped = crop_tokens(doc_string, MAX_INPUT_TOKENS - PROMPT_TOKENS_LEN)
        formatted_docs.append(cropped)
    formatted_str = "\n".join(formatted_docs)
    return f"<documents>\n{formatted_str}\n</documents>"
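# Illustrative example (not in the original script): for a single retrieved
# document, format_docs produces a string of the form
#   <documents>
#   <document index='0'>
#   <source>https://example.com/page</source>
#   <doc_content>...truncated page text...</doc_content>
#   </document>
#   </documents>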
def ingest_dataset(args, langchain_docs):
    # Clone the public LangSmith dataset and embed its documents into Redis
    clone_public_dataset(langchain_docs.dataset_id, dataset_name=langchain_docs.name)
    docs = list(langchain_docs.get_docs())
    embedder = HuggingFaceHubEmbeddings(model=args.embedding_endpoint_url)

    _ = Redis.from_texts(
        texts=[d.page_content for d in docs],
        metadatas=[d.metadata for d in docs],
        embedding=embedder,
        index_name=args.db_index,
        index_schema=args.db_schema,
        redis_url=args.db_url,
    )


def get_langchain_dataset(args):
    registry_retrieved = registry.filter(Type="RetrievalTask")
    langchain_docs = registry_retrieved[args.langchain_dataset]
    return langchain_docs
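# Note (assumption, based on langchain-benchmarks): the registry exposes
# retrieval tasks such as "LangChain Docs Q&A"; args.langchain_dataset must
# match one of those task names for the lookup above to succeed.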
def buildchain(args):
    embedder = HuggingFaceHubEmbeddings(model=args.embedding_endpoint_url)
    vectorstore = Redis.from_existing_index(
        embedding=embedder, index_name=args.db_index, schema=args.db_schema, redis_url=args.db_url
    )
    retriever = vectorstore.as_retriever(search_kwargs={"k": 1})
    prompt = ChatPromptTemplate.from_messages(
        [
            (
                "system",
                "You are an AI assistant answering questions about LangChain."
                "\n{context}\n"
                "Respond solely based on the document content.",
            ),
            ("human", "{question}"),
        ]
    )

    llm = None
    match args.llm_service_api:
        case "tgi-gaudi":
            llm = HuggingFaceEndpoint(
                endpoint_url=args.llm_endpoint_url,
                max_new_tokens=MAX_OUTPUT_TOKENS,
                top_k=10,
                top_p=0.95,
                typical_p=0.95,
                temperature=1.0,
                repetition_penalty=1.03,
                streaming=False,
                truncate=1024,
            )
        case "vllm-openai":
            llm = ChatOpenAI(
                model=args.model_name,
                openai_api_key="EMPTY",
                openai_api_base=args.llm_endpoint_url,
                max_tokens=MAX_OUTPUT_TOKENS,
                temperature=1.0,
                top_p=0.95,
                streaming=False,
                frequency_penalty=1.03,
            )
        case _:
            raise ValueError(f'Unsupported --llm_service_api "{args.llm_service_api}"')

    response_generator = (prompt | llm | StrOutputParser()).with_config(
        run_name="GenerateResponse",
    )

    # This is the final response chain.
    # It fetches the "question" key from the input dict,
    # passes it to the retriever, then formats the docs as a string.
    chain = (
        RunnableAssign(
            {"context": (itemgetter("question") | retriever | format_docs).with_config(run_name="FormatDocs")}
        )
        # The "RunnableAssign" above returns a dict with the keys
        # "question" (from the original input) and
        # "context" (the string-formatted docs),
        # which is passed to the response_generator above.
        | response_generator
    )
    return chain
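# Illustrative usage (not in the original script): the chain built above is a
# standard LangChain runnable, so a single question can be answered directly,
# e.g. buildchain(args).invoke({"question": "What is LangChain?"}),
# which retrieves context from Redis, formats it, and generates a response.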
def run_test(args, chain):
    client = Client()
    client.run_on_dataset(
        dataset_name=args.langchain_dataset,
        llm_or_chain_factory=chain,
        evaluation=None,
        project_name=f"{args.llm_service_api}-{args.model_name} op-{MAX_OUTPUT_TOKENS} cl-{args.concurrency} iter-{run_uid}",
        project_metadata={
            "index_method": "basic",
        },
        concurrency_level=args.concurrency,
        verbose=True,
    )
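# Note: with verbose=True, run_on_dataset prints the URL of the LangSmith
# project for this run, where the per-question latency and token measurements
# can be inspected (this is the URL the README refers to).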
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-l",
        "--llm_endpoint_url",
        type=str,
        required=False,
        default=ENDPOINT_URL_GAUDI2,
        help="LLM Service Endpoint URL",
    )
    parser.add_argument(
        "-e",
        "--embedding_endpoint_url",
        type=str,
        default=TEI_ENDPOINT,
        required=False,
        help="Embedding Service Endpoint URL",
    )
    parser.add_argument("-m", "--model_name", type=str, default=HF_MODEL_NAME, required=False, help="Model Name")
    parser.add_argument("-ht", "--huggingface_token", type=str, required=True, help="Huggingface API token")
    parser.add_argument("-lt", "--langchain_token", type=str, required=True, help="LangSmith API token")
    parser.add_argument(
        "-d",
        "--langchain_dataset",
        type=str,
        required=True,
        help="LangSmith dataset name. Refer: https://docs.smith.langchain.com/evaluation/quickstart",
    )

    parser.add_argument("-c", "--concurrency", type=int, default=16, required=False, help="Concurrency Level")

    parser.add_argument(
        "-lm",
        "--llm_service_api",
        type=str,
        default="tgi-gaudi",
        required=False,
        help='Choose between "tgi-gaudi" and "vllm-openai"',
    )

    parser.add_argument(
        "-ig", "--ingest_dataset", action="store_true", required=False, help="Ingest the dataset before running the test"
    )

    parser.add_argument("-dbu", "--db_url", type=str, required=True, help="Vector DB URL")

    parser.add_argument("-dbs", "--db_schema", type=str, required=True, help="Vector DB Schema")

    parser.add_argument("-dbi", "--db_index", type=str, required=True, help="Vector DB Index Name")

    args = parser.parse_args()

    # Export credentials before any LangSmith or Hugging Face Hub calls
    os.environ["LANGCHAIN_ENDPOINT"] = "https://api.smith.langchain.com"
    os.environ["LANGCHAIN_API_KEY"] = args.langchain_token
    os.environ["HUGGINGFACEHUB_API_TOKEN"] = args.huggingface_token

    if args.ingest_dataset:
        langchain_doc = get_langchain_dataset(args)
        ingest_dataset(args, langchain_doc)

    # Load the tokenizer used by crop_tokens for context truncation
    tokenizer = AutoTokenizer.from_pretrained(args.model_name)

    chain = buildchain(args)
    run_test(args, chain)
