diff --git a/backend/app/agent_types/tools_agent.py b/backend/app/agent_types/tools_agent.py index 0a061af17..a8c4c3f31 100644 --- a/backend/app/agent_types/tools_agent.py +++ b/backend/app/agent_types/tools_agent.py @@ -51,7 +51,7 @@ async def _get_messages(messages): def should_continue(messages): last_message = messages[-1] # If there is no function call, then we finish - if not last_message.tool_calls: + if not getattr(last_message, "tool_calls", None): return "end" # Otherwise if there is, we continue else: diff --git a/backend/app/api/assistants.py b/backend/app/api/assistants.py index dda15581d..8458b48c5 100644 --- a/backend/app/api/assistants.py +++ b/backend/app/api/assistants.py @@ -1,7 +1,7 @@ -from typing import Annotated, List, Optional +from typing import Annotated, List from uuid import uuid4 -from fastapi import APIRouter, HTTPException, Path, Query +from fastapi import APIRouter, HTTPException, Path from pydantic import BaseModel, Field import app.storage as storage @@ -10,8 +10,6 @@ router = APIRouter() -FEATURED_PUBLIC_ASSISTANTS = [] - class AssistantPayload(BaseModel): """Payload for creating an assistant.""" @@ -31,15 +29,9 @@ async def list_assistants(user: AuthedUser) -> List[Assistant]: @router.get("/public/") -async def list_public_assistants( - shared_id: Annotated[ - Optional[str], Query(description="ID of a publicly shared assistant.") - ] = None, -) -> List[Assistant]: +async def list_public_assistants() -> List[Assistant]: """List all public assistants.""" - return await storage.list_public_assistants( - FEATURED_PUBLIC_ASSISTANTS + ([shared_id] if shared_id else []) - ) + return await storage.list_public_assistants() @router.get("/{aid}") diff --git a/backend/app/chatbot.py b/backend/app/chatbot.py index fe19725c6..eeb5b7872 100644 --- a/backend/app/chatbot.py +++ b/backend/app/chatbot.py @@ -1,11 +1,12 @@ from typing import Annotated, List -from app.message_types import add_messages_liberal from langchain_core.language_models.base import LanguageModelLike from langchain_core.messages import BaseMessage, SystemMessage from langgraph.checkpoint import BaseCheckpointSaver from langgraph.graph.state import StateGraph +from app.message_types import add_messages_liberal + def get_chatbot_executor( llm: LanguageModelLike, diff --git a/backend/app/message_types.py b/backend/app/message_types.py index bd1f76956..9ceea94b2 100644 --- a/backend/app/message_types.py +++ b/backend/app/message_types.py @@ -6,7 +6,7 @@ MessageLikeRepresentation, ToolMessage, ) -from langgraph.graph.message import add_messages, Messages +from langgraph.graph.message import Messages, add_messages class LiberalFunctionMessage(FunctionMessage): diff --git a/backend/app/storage.py b/backend/app/storage.py index 1c6e8c355..eb6871bb0 100644 --- a/backend/app/storage.py +++ b/backend/app/storage.py @@ -19,23 +19,16 @@ async def get_assistant(user_id: str, assistant_id: str) -> Optional[Assistant]: """Get an assistant by ID.""" async with get_pg_pool().acquire() as conn: return await conn.fetchrow( - "SELECT * FROM assistant WHERE assistant_id = $1 AND (user_id = $2 OR public = true)", + "SELECT * FROM assistant WHERE assistant_id = $1 AND (user_id = $2 OR public IS true)", assistant_id, user_id, ) -async def list_public_assistants(assistant_ids: Sequence[str]) -> List[Assistant]: +async def list_public_assistants() -> List[Assistant]: """List all the public assistants.""" async with get_pg_pool().acquire() as conn: - return await conn.fetch( - ( - "SELECT * FROM assistant " - "WHERE assistant_id = ANY($1::uuid[]) " - "AND public = true;" - ), - assistant_ids, - ) + return await conn.fetch(("SELECT * FROM assistant WHERE public IS true;")) async def put_assistant( @@ -119,13 +112,22 @@ async def get_thread_state(*, user_id: str, thread_id: str, assistant_id: str): async def update_thread_state( config: RunnableConfig, - values: Union[Sequence[AnyMessage], dict[str, Any]], + values: Union[ + Sequence[AnyMessage], dict[str, Any] + ], # TODO: update to for StateGraphs? *, user_id: str, assistant_id: str, ): """Add state to a thread.""" assistant = await get_assistant(user_id, assistant_id) + as_node = None + # TODO: Somehow these checks don't get called but hten the aupdate_state doesn't + # fail due to an ambiguous node error ¯\(°_o)/¯ + if isinstance(values, dict): + as_node = "__start__" if values.get("type", "") == "human" else None + if isinstance(values, AnyMessage): + as_node = "__start__" if values.type == "human" else None await agent.aupdate_state( { "configurable": { @@ -135,6 +137,7 @@ async def update_thread_state( } }, values, + as_node=as_node, ) diff --git a/backend/app/upload.py b/backend/app/upload.py index bd09939b2..8c12bedb4 100644 --- a/backend/app/upload.py +++ b/backend/app/upload.py @@ -12,8 +12,8 @@ import os from typing import BinaryIO, List, Optional -from langchain_core.document_loaders.blob_loaders import Blob from langchain_community.vectorstores.pgvector import PGVector +from langchain_core.document_loaders.blob_loaders import Blob from langchain_core.runnables import ( ConfigurableField, RunnableConfig, diff --git a/frontend/src/components/Config.tsx b/frontend/src/components/Config.tsx index aeb039073..30dbc4ccb 100644 --- a/frontend/src/components/Config.tsx +++ b/frontend/src/components/Config.tsx @@ -412,11 +412,8 @@ function ToolSelectionField(props: { ); } -function PublicLink(props: { assistantId: string }) { - const currentLink = window.location.href; - const link = currentLink.includes(props.assistantId) - ? currentLink - : currentLink + "?shared_id=" + props.assistantId; +function PublicLink() { + const link = window.location.href; return (
) : ( - <> - {props.config?.public && ( - - )} - + <>{props.config?.public && } ); return (