Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
202 changes: 142 additions & 60 deletions notebooks/chatbot_score.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -24,32 +24,32 @@
"outputs": [],
"source": [
"import os\n",
"from dotenv import load_dotenv\n",
"import warnings\n",
"import time\n",
"from dotenv import load_dotenv\n",
"from collections.abc import Callable\n",
"from typing import Any\n",
"import logging\n",
"\n",
"import weaviate\n",
"from weaviate.classes.query import Filter\n",
"\n",
"from langchain.chains.retrieval import create_retrieval_chain\n",
"from langchain.chains import create_retrieval_chain\n",
"from langchain.chains.combine_documents import create_stuff_documents_chain\n",
"from langchain.prompts.chat import (\n",
" ChatPromptTemplate,\n",
" HumanMessagePromptTemplate,\n",
" SystemMessagePromptTemplate,\n",
")\n",
"from langchain_core.prompts import MessagesPlaceholder\n",
"from langchain_openai import ChatOpenAI, OpenAIEmbeddings\n",
"from weaviate.classes.init import Auth\n",
"\n",
"\n",
"from langchain_core.documents.base import Document\n",
"from langchain_core.vectorstores.base import VectorStoreRetriever\n",
"from langchain_core.runnables import Runnable\n",
"from langchain_core.runnables import RunnableLambda\n",
"\n",
"from langchain_core.documents.base import Document\n",
"from langchain_openai import ChatOpenAI, OpenAIEmbeddings\n",
"from langchain_weaviate.vectorstores import WeaviateVectorStore\n",
"from typing import List, Callable"
"\n",
"import weaviate\n",
"from weaviate.classes.query import MetadataQuery, Filter\n",
"from weaviate.classes.init import Auth\n",
"from weaviate.classes.query import Filter\n",
"from weaviate.client import WeaviateClient"
]
},
{
Expand Down Expand Up @@ -89,44 +89,137 @@
{
"cell_type": "code",
"execution_count": null,
"id": "c4773d3f-aadf-4a35-81f9-0587f7d37a70",
"id": "3ebd68f4",
"metadata": {},
"outputs": [],
"source": [
"def configure_custom_retriever(client) -> Callable[[dict], List[Document]]:\n",
" vectorstore = WeaviateVectorStore(\n",
" client=client,\n",
" # index_name=\"LangChain_9787ec4b92d3438a8de3ff04ead7ead6\",\n",
" index_name=\"Ingestion_20250610\",\n",
" text_key=\"page_content\",\n",
" embedding=OpenAIEmbeddings(model=\"text-embedding-3-small\",\n",
" dimensions=1536\n",
" ),\n",
" attributes=[\"source\", \"source_key\"],\n",
" )\n",
"class CustomWeaviateVectorStore(WeaviateVectorStore):\n",
" \"\"\"Custom Vector Store overrides the similarity search function.\"\"\"\n",
"\n",
" def __init__(\n",
" self,\n",
" client: Any,\n",
" index_name: str,\n",
" text_key: str,\n",
" embedding: Any,\n",
" attributes: list | None = None,\n",
" relevance_score_fn: Callable | None = None,\n",
" use_multi_tenancy: bool | None = None,\n",
" ) -> None:\n",
" \"\"\"Initialize the CustomWeaviateVectorStore class.\"\"\"\n",
" if use_multi_tenancy is None:\n",
" use_multi_tenancy = False\n",
"\n",
" sources = [\"github\", \"jira\", \"lsst_bib\", \"webpage\", \"discourse\"]\n",
" filters = Filter.by_property(\"source_key\").contains_any(sources)\n",
" self.client = client\n",
" self.index_name = index_name\n",
" self.text_key = text_key\n",
" self.embedding = embedding\n",
"\n",
" def retrieve(inputs: dict) -> List[Document]:\n",
" query = inputs[\"input\"]\n",
" super().__init__(\n",
" client=client,\n",
" index_name=index_name,\n",
" text_key=text_key,\n",
" embedding=embedding,\n",
" attributes=attributes,\n",
" relevance_score_fn=relevance_score_fn,\n",
" use_multi_tenancy=use_multi_tenancy,\n",
" )\n",
"\n",
" results = vectorstore.similarity_search_with_score(\n",
" def similarity_search(\n",
" self, query: str, k: int = 4, **kwargs: Any\n",
" ) -> list[Document]:\n",
" \"\"\"\n",
" Return list of documents most similar to the query text and their\n",
" score. A higher score means more similarity, with a max of 1.\n",
" \"\"\"\n",
" where_filter = kwargs.get(\"where_filter\")\n",
" collection = self.client.collections.get(self.index_name)\n",
" response = collection.query.hybrid(\n",
" query=query,\n",
" k=6,\n",
" filters=filters,\n",
" limit=k,\n",
" filters=where_filter,\n",
" alpha=1,\n",
" return_metadata=MetadataQuery(score=True, explain_score=True),\n",
" )\n",
" docs = []\n",
" for doc, score in results:\n",
" doc.metadata[\"similarity_score\"] = score\n",
" docs.append(doc)\n",
" return docs\n",
"\n",
" return retrieve\n",
" results = []\n",
" for obj in response.objects:\n",
" text = obj.properties.get(\"page_content\", \"\")\n",
" metadata = obj.properties.copy() if obj.properties else {}\n",
" metadata[\"score\"] = (\n",
" obj.metadata.score\n",
" ) # Inject the score into metadata\n",
" results.append(Document(page_content=text, metadata=metadata))\n",
" return results"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c4773d3f-aadf-4a35-81f9-0587f7d37a70",
"metadata": {},
"outputs": [],
"source": [
"def configure_client() -> WeaviateClient:\n",
" \"\"\"Configure the Weaviate client.\"\"\"\n",
" openai_api_key = os.getenv(\"OPENAI_API_KEY\")\n",
" weaviate_api_key = os.getenv(\"WEAVIATE_API_KEY\")\n",
" http_host = os.getenv(\"HTTP_HOST\")\n",
" grpc_host = os.getenv(\"GRPC_HOST\")\n",
"\n",
" if openai_api_key is None:\n",
" raise ValueError(\"OPENAI_API_KEY environment variable is not set\")\n",
" if weaviate_api_key is None:\n",
" raise ValueError(\"WEAVIATE_API_KEY environment variable is not set\")\n",
" if http_host is None:\n",
" raise ValueError(\"HTTP_HOST environment variable is not set\")\n",
" if grpc_host is None:\n",
" raise ValueError(\"GRPC_HOST environment variable is not set\")\n",
"\n",
" return weaviate.connect_to_custom(\n",
" http_host=http_host,\n",
" http_port=8080, # Database on port 80 in USDF\n",
" http_secure=False,\n",
" grpc_host=grpc_host,\n",
" grpc_port=50051,\n",
" grpc_secure=False,\n",
" auth_credentials=Auth.api_key(weaviate_api_key),\n",
" headers={\"X-OpenAI-Api-Key\": openai_api_key},\n",
" skip_init_checks=True,\n",
" )\n",
"\n",
"\n",
"def configure_retriever() -> VectorStoreRetriever:\n",
" \"\"\"Configure the Weaviate retriever.\"\"\"\n",
" selected_sources = [\n",
" \"github\", \"jira\", \"lsst_bib\", \"webpage\", \"discourse\"\n",
" ]\n",
" if selected_sources:\n",
" filters = Filter.by_property(\"source_key\").contains_any(\n",
" selected_sources\n",
" )\n",
"\n",
" search_kwargs = {\n",
" \"k\": 6,\n",
" \"where_filter\": filters\n",
" }\n",
"\n",
" return CustomWeaviateVectorStore(\n",
" client=configure_client(),\n",
" index_name=\"Ingestion_20250610\",\n",
" text_key=\"page_content\",\n",
" embedding=OpenAIEmbeddings(\n",
" model=\"text-embedding-3-small\", dimensions=1536\n",
" ),\n",
" attributes=[\"source\", \"source_key\"], # Metadata to fetch\n",
" ).as_retriever(\n",
" search_type=\"similarity\",\n",
" search_kwargs=search_kwargs,\n",
" )\n",
"\n",
"\n",
"def create_qa_chain(\n",
" input_retriever: Callable[[str], List[Document]],\n",
" input_retriever: Callable[[str], list[Document]],\n",
") -> Runnable:\n",
" \"\"\"Create a QA chain for the chatbot using a custom retriever.\"\"\"\n",
"\n",
Expand Down Expand Up @@ -177,35 +270,24 @@
" skip_init_checks=True,\n",
" )\n",
" \n",
" start = time.time()\n",
" retriever = configure_custom_retriever(client)\n",
" elapsed = time.time() - start\n",
" print(f\"Configure retriever took {elapsed:.2f}s\")\n",
" \n",
" start = time.time()\n",
" user_query = \"wWhat is LSST Cam?\"\n",
" retriever = configure_retriever()\n",
" qa_chain = create_qa_chain(retriever)\n",
" elapsed = time.time() - start\n",
" print(f\"Create QA chain took {elapsed:.2f}s\")\n",
"\n",
" query = \"instantiate lsst.daf.butler.Butler in Python tutorial?\"\n",
"\n",
" start = time.time()\n",
" result = qa_chain.invoke({\n",
" \"input\": query,\n",
" \"chat_history\": [],\n",
" })\n",
" elapsed = time.time() - start\n",
" print(f\"Invoke took {elapsed:.2f}s\")\n",
"\n",
" result = qa_chain.invoke(\n",
" {\n",
" \"input\": user_query,\n",
" \"chat_history\": []\n",
" }\n",
" )\n",
" print(\"Input:\\n\", result[\"input\"])\n",
" print(\"\\nAnswer:\\n\", result[\"answer\"])\n",
" \n",
"\n",
" print(\"\\nContext Documents:\")\n",
" for i, doc in enumerate(result[\"context\"], 1):\n",
" print(f\"\\n--- Document {i} ---\")\n",
" print(\"Similarity Score:\", doc.metadata.get(\"score\"))\n",
" print(\"Source:\", doc.metadata.get(\"source\"))\n",
" print(\"Repo:\", doc.metadata.get(\"repo\"))\n",
" print(\"Similarity Score:\", doc.metadata.get(\"similarity_score\"))\n",
" print(\"Content:\\n\", doc.page_content)\n",
"\n",
"\n",
Expand Down
Loading