|
| 1 | +#!/usr/bin/env python |
| 2 | + |
| 3 | +# Copyright (c) 2024 Intel Corporation |
| 4 | +# |
| 5 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 6 | +# you may not use this file except in compliance with the License. |
| 7 | +# You may obtain a copy of the License at |
| 8 | +# |
| 9 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 10 | +# |
| 11 | +# Unless required by applicable law or agreed to in writing, software |
| 12 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 13 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 14 | +# See the License for the specific language governing permissions and |
| 15 | +# limitations under the License. |
| 16 | + |
| 17 | +import argparse |
| 18 | +import os |
| 19 | +import uuid |
| 20 | +from operator import itemgetter |
| 21 | +from typing import Any, List, Mapping, Optional, Sequence |
| 22 | + |
| 23 | +from langchain.prompts import ChatPromptTemplate |
| 24 | +from langchain.schema.document import Document |
| 25 | +from langchain.schema.output_parser import StrOutputParser |
| 26 | +from langchain.schema.runnable.passthrough import RunnableAssign |
| 27 | +from langchain_benchmarks import clone_public_dataset, registry |
| 28 | +from langchain_benchmarks.rag import get_eval_config |
| 29 | +from langchain_community.embeddings import HuggingFaceEmbeddings, HuggingFaceHubEmbeddings |
| 30 | +from langchain_community.llms import HuggingFaceEndpoint |
| 31 | +from langchain_community.vectorstores import Redis |
| 32 | +from langchain_core.callbacks.manager import CallbackManagerForLLMRun |
| 33 | +from langchain_core.language_models.llms import LLM |
| 34 | +from langchain_core.prompt_values import ChatPromptValue |
| 35 | +from langchain_openai import ChatOpenAI |
| 36 | +from langsmith.client import Client |
| 37 | +from transformers import AutoTokenizer, LlamaForCausalLM |
| 38 | + |
# Parameters and settings

# Default value of --llm_endpoint_url (the LLM serving endpoint).
ENDPOINT_URL_GAUDI2 = "http://localhost:8000"
# Not referenced in the visible part of this file; presumably the URL to pass
# via --llm_endpoint_url when --llm_service_api is "vllm-openai" -- TODO confirm.
ENDPOINT_URL_VLLM = "http://localhost:8001/v1"
# Default value of --embedding_endpoint_url.
TEI_ENDPOINT = "http://localhost:8002"
# Placeholder that is not referenced in the visible part of this file; the
# dataset name actually used comes from the required --langchain_dataset flag.
LANG_CHAIN_DATASET = "<Dataset name to add>"
# Default value of --model_name; also the name the tokenizer is loaded from.
HF_MODEL_NAME = "<Model name to add>"
# Subtracted from MAX_INPUT_TOKENS in format_docs() to get the per-document
# token budget.
PROMPT_TOKENS_LEN = 214  # Magic number for prompt template tokens
# Upper bound on input tokens; format_docs() crops retrieved text against it.
MAX_INPUT_TOKENS = 1024
# Generation limit passed to the LLM clients and embedded in the project name.
MAX_OUTPUT_TOKENS = 128

# Generate a unique run ID for this experiment
# (six hex characters, appended to the LangSmith project name in run_test()).
run_uid = uuid.uuid4().hex[:6]

# Set in the __main__ section via AutoTokenizer.from_pretrained(); read by
# crop_tokens(), which therefore must not be called before it is assigned.
tokenizer = None
| 53 | + |
| 54 | + |
def crop_tokens(prompt, max_len):
    """Cut ``prompt`` down to its first ``max_len`` token ids and decode it.

    The text is tokenized with the module-level ``tokenizer``, the resulting
    ``input_ids`` are sliced to the leading ``max_len`` positions, and the
    slice is decoded back to text with special tokens skipped.

    Args:
        prompt: Text to crop.
        max_len: Number of leading token ids to keep.

    Returns:
        The decoded text of the kept token ids.
    """
    encoded = tokenizer(prompt, return_tensors="pt")
    token_ids = encoded["input_ids"]
    # Keep every row of the batch, but only the leading ``max_len`` columns.
    kept_ids = token_ids[:, :max_len]
    decoded_batch = tokenizer.batch_decode(
        kept_ids,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False,
    )
    return decoded_batch[0]
| 62 | + |
| 63 | + |
# Once the retriever has returned its documents, this function renders
# them as one string that is handed to the LLM as context.
def format_docs(docs: Sequence[Document]) -> str:
    """Render retrieved documents as a single ``<documents>`` string.

    Each document is wrapped in a ``<document>`` element carrying its
    ``source`` metadata and page content, and the wrapped text is then cropped
    with ``crop_tokens`` to ``MAX_INPUT_TOKENS - PROMPT_TOKENS_LEN`` tokens.
    Because the crop is applied after wrapping, a long document can lose its
    closing tags.

    Args:
        docs: Documents returned by the retriever.

    Returns:
        The cropped document strings joined by newlines inside a
        ``<documents>`` element.
    """
    token_budget = MAX_INPUT_TOKENS - PROMPT_TOKENS_LEN
    cropped_docs = [
        crop_tokens(
            f"<document index='{position}'>\n"
            f"<source>{document.metadata.get('source')}</source>\n"
            f"<doc_content>{document.page_content}</doc_content>\n"
            "</document>",
            token_budget,
        )
        for position, document in enumerate(docs)
    ]
    joined = "\n".join(cropped_docs)
    return f"<documents>\n{joined}\n</documents>"
| 81 | + |
| 82 | + |
def ingest_dataset(args, langchain_docs):
    """Clone the benchmark dataset and index its documents in Redis.

    Args:
        args: Parsed command-line namespace; reads ``embedding_endpoint_url``,
            ``db_index``, ``db_schema`` and ``db_url``.
        langchain_docs: Retrieval task object providing ``dataset_id``,
            ``name`` and ``get_docs()``.
    """
    clone_public_dataset(langchain_docs.dataset_id, dataset_name=langchain_docs.name)

    documents = list(langchain_docs.get_docs())
    embedder = HuggingFaceHubEmbeddings(model=args.embedding_endpoint_url)

    page_texts = [document.page_content for document in documents]
    page_metadata = [document.metadata for document in documents]

    # The returned vector store is not needed here; only the indexing matters.
    Redis.from_texts(
        texts=page_texts,
        metadatas=page_metadata,
        embedding=embedder,
        index_name=args.db_index,
        index_schema=args.db_schema,
        redis_url=args.db_url,
    )
| 98 | + |
| 99 | + |
def GetLangchainDataset(args):
    """Return the registry entry named by ``args.langchain_dataset``.

    The benchmark registry is first narrowed to entries whose ``Type`` is
    ``"RetrievalTask"`` and then indexed by the requested dataset name.
    """
    return registry.filter(Type="RetrievalTask")[args.langchain_dataset]
| 104 | + |
| 105 | + |
| 106 | +def buildchain(args): |
| 107 | + embedder = HuggingFaceHubEmbeddings(model=args.embedding_endpoint_url) |
| 108 | + vectorstore = Redis.from_existing_index( |
| 109 | + embedding=embedder, index_name=args.db_index, schema=args.db_schema, redis_url=args.db_url |
| 110 | + ) |
| 111 | + retriever = vectorstore.as_retriever(search_kwargs={"k": 1}) |
| 112 | + prompt = ChatPromptTemplate.from_messages( |
| 113 | + [ |
| 114 | + ( |
| 115 | + "system", |
| 116 | + "You are an AI assistant answering questions about LangChain." |
| 117 | + "\n{context}\n" |
| 118 | + "Respond solely based on the document content.", |
| 119 | + ), |
| 120 | + ("human", "{question}"), |
| 121 | + ] |
| 122 | + ) |
| 123 | + |
| 124 | + llm = None |
| 125 | + match args.llm_service_api: |
| 126 | + case "tgi-gaudi": |
| 127 | + llm = HuggingFaceEndpoint( |
| 128 | + endpoint_url=args.llm_endpoint_url, |
| 129 | + max_new_tokens=MAX_OUTPUT_TOKENS, |
| 130 | + top_k=10, |
| 131 | + top_p=0.95, |
| 132 | + typical_p=0.95, |
| 133 | + temperature=1.0, |
| 134 | + repetition_penalty=1.03, |
| 135 | + streaming=False, |
| 136 | + truncate=1024, |
| 137 | + ) |
| 138 | + case "vllm-openai": |
| 139 | + llm = ChatOpenAI( |
| 140 | + model=args.model_name, |
| 141 | + openai_api_key="EMPTY", |
| 142 | + openai_api_base=args.llm_endpoint_url, |
| 143 | + max_tokens=MAX_OUTPUT_TOKENS, |
| 144 | + temperature=1.0, |
| 145 | + top_p=0.95, |
| 146 | + streaming=False, |
| 147 | + frequency_penalty=1.03, |
| 148 | + ) |
| 149 | + |
| 150 | + response_generator = (prompt | llm | StrOutputParser()).with_config( |
| 151 | + run_name="GenerateResponse", |
| 152 | + ) |
| 153 | + |
| 154 | + # This is the final response chain. |
| 155 | + # It fetches the "question" key from the input dict, |
| 156 | + # passes it to the retriever, then formats as a string. |
| 157 | + |
| 158 | + chain = ( |
| 159 | + RunnableAssign( |
| 160 | + {"context": (itemgetter("question") | retriever | format_docs).with_config(run_name="FormatDocs")} |
| 161 | + ) |
| 162 | + # The "RunnableAssign" above returns a dict with keys |
| 163 | + # question (from the original input) and |
| 164 | + # context: the string-formatted docs. |
| 165 | + # This is passed to the response_generator above |
| 166 | + | response_generator |
| 167 | + ) |
| 168 | + return chain |
| 169 | + |
| 170 | + |
def run_test(args, chain):
    """Run ``chain`` over the LangSmith dataset and return the run results.

    No evaluators are configured (``evaluation=None``); the run is recorded
    under a project whose name encodes the LLM service, model name, output
    token limit, concurrency level and the per-process ``run_uid``.

    Args:
        args: Parsed command-line namespace; reads ``langchain_dataset``,
            ``llm_service_api``, ``model_name`` and ``concurrency``.
        chain: The chain built by ``buildchain``.

    Returns:
        Whatever ``Client.run_on_dataset`` returns. Earlier versions discarded
        it, so callers that ignore the return value are unaffected.
    """
    client = Client()
    project_name = (
        f"{args.llm_service_api}-{args.model_name} "
        f"op-{MAX_OUTPUT_TOKENS} cl-{args.concurrency} iter-{run_uid}"
    )
    return client.run_on_dataset(
        dataset_name=args.langchain_dataset,
        llm_or_chain_factory=chain,
        evaluation=None,
        project_name=project_name,
        project_metadata={
            "index_method": "basic",
        },
        concurrency_level=args.concurrency,
        verbose=True,
    )
| 184 | + |
| 185 | + |
def str_to_bool(value):
    """Convert a command-line string such as ``"true"`` or ``"0"`` to a bool.

    ``argparse`` with ``type=bool`` calls ``bool()`` on the raw string, so every
    non-empty value -- including ``"False"`` -- becomes ``True``. This parser
    interprets the text instead.

    Args:
        value: The raw argument string (an actual ``bool`` is returned as-is).

    Returns:
        ``True`` for true/t/yes/y/on/1, ``False`` for false/f/no/n/off/0 and
        for the empty string (case-insensitive).

    Raises:
        argparse.ArgumentTypeError: If the text is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    normalized = value.strip().lower()
    if normalized in ("true", "t", "yes", "y", "on", "1"):
        return True
    if normalized in ("false", "f", "no", "n", "off", "0", ""):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-l",
        "--llm_endpoint_url",
        type=str,
        required=False,
        default=ENDPOINT_URL_GAUDI2,
        help="LLM Service Endpoint URL",
    )
    parser.add_argument(
        "-e",
        "--embedding_endpoint_url",
        type=str,
        default=TEI_ENDPOINT,
        required=False,
        help="Embedding Service Endpoint URL",
    )
    parser.add_argument("-m", "--model_name", type=str, default=HF_MODEL_NAME, required=False, help="Model Name")
    parser.add_argument("-ht", "--huggingface_token", type=str, required=True, help="Huggingface API token")
    parser.add_argument("-lt", "--langchain_token", type=str, required=True, help="langchain API token")
    parser.add_argument(
        "-d",
        "--langchain_dataset",
        type=str,
        required=True,
        help="langchain dataset name Refer: https://docs.smith.langchain.com/evaluation/quickstart ",
    )

    parser.add_argument("-c", "--concurrency", type=int, default=16, required=False, help="Concurrency Level")

    parser.add_argument(
        "-lm",
        "--llm_service_api",
        type=str,
        default="tgi-gaudi",
        # Reject unsupported backends at parse time rather than deep in the run.
        choices=["tgi-gaudi", "vllm-openai"],
        required=False,
        help='Choose between "tgi-gaudi" or "vllm-openai"',
    )

    parser.add_argument(
        "-ig",
        "--ingest_dataset",
        # type=bool would turn "False" into True; parse the text instead.
        # nargs="?" with const=True additionally allows a bare "-ig".
        type=str_to_bool,
        nargs="?",
        const=True,
        default=False,
        required=False,
        help="Set True to ingest dataset",
    )

    parser.add_argument("-dbu", "--db_url", type=str, required=True, help="Vector DB URL")

    parser.add_argument("-dbs", "--db_schema", type=str, required=True, help="Vector DB Schema")

    parser.add_argument("-dbi", "--db_index", type=str, required=True, help="Vector DB Index Name")

    args = parser.parse_args()

    # Export the credentials before anything contacts a service: the optional
    # ingestion step below already clones a LangSmith dataset and calls the
    # embedding endpoint, so setting these afterwards is too late for it.
    os.environ["LANGCHAIN_ENDPOINT"] = "https://api.smith.langchain.com"
    os.environ["LANGCHAIN_API_KEY"] = args.langchain_token
    os.environ["HUGGINGFACEHUB_API_TOKEN"] = args.huggingface_token

    if args.ingest_dataset:
        langchain_doc = GetLangchainDataset(args)
        ingest_dataset(args, langchain_doc)

    # Module-level assignment: crop_tokens() reads this global tokenizer.
    tokenizer = AutoTokenizer.from_pretrained(args.model_name)

    chain = buildchain(args)
    run_test(args, chain)
0 commit comments