diff --git a/Changelog.md b/Changelog.md index a9f9c850..7aec7582 100644 --- a/Changelog.md +++ b/Changelog.md @@ -30,6 +30,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - bugfix/28-cannot-auth-tools-list (2024-11-30) ### Changed + - feat/290-api-as-a-tool (2025-05-06) - feat/288-add-contact-create-to-oauth (2025-05-05) - feat/287-limit-input-buttons (2025-05-04) - feat/281-public-server-display-on-landing (2025-05-03) diff --git a/backend/.vscode/launch.json b/backend/.vscode/launch.json index ff7e64d5..418b64e0 100644 --- a/backend/.vscode/launch.json +++ b/backend/.vscode/launch.json @@ -61,6 +61,15 @@ "console": "integratedTerminal", "justMyCode": false, "envFile": "${workspaceFolder}/.env.test" + }, + { + "name": "Python: Run src.tools.api-new", + "type": "debugpy", + "request": "launch", + "module": "src.tools.api_new", + "console": "integratedTerminal", + "justMyCode": false, + "envFile": "${workspaceFolder}/.env" } ] } \ No newline at end of file diff --git a/backend/migrations/env.py b/backend/migrations/env.py index 84a09037..d8223f54 100644 --- a/backend/migrations/env.py +++ b/backend/migrations/env.py @@ -8,7 +8,7 @@ from sqlalchemy import engine_from_config from sqlalchemy import pool from alembic import context -from src.models import Base +from src.services.db import Base from src.constants import DB_URI config = context.config diff --git a/backend/migrations/versions/0010_add_tools_table.py b/backend/migrations/versions/0010_add_tools_table.py new file mode 100644 index 00000000..8c6e5a99 --- /dev/null +++ b/backend/migrations/versions/0010_add_tools_table.py @@ -0,0 +1,48 @@ +"""Add tools table + +Revision ID: 0010 +Revises: 0009 +Create Date: 2025-05-09T00:00:00.000000 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision: str = '0010' +down_revision: Union[str, None] = '0009' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + +def upgrade() -> None: + # Create tools table + op.create_table( + 'tools', + sa.Column('id', postgresql.UUID(as_uuid=True), server_default=sa.text('gen_random_uuid()'), nullable=False), + sa.Column('name', sa.String(), nullable=False), + sa.Column('description', sa.String(), nullable=True), + sa.Column('url', sa.String(), nullable=True), + sa.Column('spec', sa.JSON(), nullable=True), + sa.Column('headers', sa.JSON(), nullable=True), + sa.Column('tags', sa.JSON(), nullable=True), + sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False), + sa.Column('updated_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False), + sa.Column('user_id', postgresql.UUID(as_uuid=True), nullable=False), + sa.ForeignKeyConstraint(['user_id'], ['users.id'], ondelete='CASCADE'), + sa.PrimaryKeyConstraint('id') + ) + + # Create indexes + op.create_index('tools_user_id_idx', 'tools', ['user_id'], unique=False) + op.create_index('tools_name_idx', 'tools', ['name'], unique=False) + op.create_index('tools_url_idx', 'tools', ['url'], unique=False) + # op.create_index('tools_tags_idx', 'tools', ['tags'], unique=False) + # Create unique index on name and user_id + op.create_index('tools_name_user_id_idx', 'tools', ['name', 'user_id'], unique=True) + +def downgrade() -> None: + # Drop tools table + op.drop_table('tools') \ No newline at end of file diff --git a/backend/src/controllers/agent.py b/backend/src/controllers/agent.py index b69fb01d..060f8173 100644 --- a/backend/src/controllers/agent.py +++ b/backend/src/controllers/agent.py @@ -11,12 +11,13 @@ from src.repos.user_repo import UserRepo from src.utils.logger import logger from src.utils.a2a import process_a2a_streaming, process_a2a - +from src.repos.tool_repo import ToolRepo class AgentController: def __init__(self, db: AsyncSession, user_id: str = None, agent_id: str = None): # type: ignore self.user_repo = UserRepo(db=db, user_id=user_id) self.agent_repo = AgentRepo(db=db, user_id=user_id) + self.tool_repo = ToolRepo(db=db, user_id=user_id) self.agent_id = agent_id async def anew_thread( @@ -41,7 +42,7 @@ async def anew_thread( # else: # return await process_a2a(new_thread, thread_id) - agent = Agent(config=config, user_repo=self.user_repo) + agent = Agent(config=config, user_repo=self.user_repo, tool_repo=self.tool_repo) await agent.abuilder(tools=new_thread.tools, model_name=new_thread.model, mcp=new_thread.mcp, @@ -89,7 +90,7 @@ async def aexisting_thread( # else: # return await process_a2a(existing_thread, thread_id) - agent = Agent(config=config, user_repo=self.user_repo) + agent = Agent(config=config, user_repo=self.user_repo, tool_repo=self.tool_repo) await agent.abuilder(tools=existing_thread.tools, model_name=existing_thread.model, mcp=existing_thread.mcp, @@ -129,7 +130,7 @@ async def async_agent_thread( "system": settings.get("system") or None } - agent = Agent(config=config, user_repo=self.user_repo) + agent = Agent(config=config, user_repo=self.user_repo, tool_repo=self.tool_repo) await agent.abuilder(tools=settings.get("tools", []), model_name=settings.get("model"), mcp=settings.get("mcp", None)) if thread_id: messages = agent.existing_thread(query, settings.get("images")) diff --git a/backend/src/entities/__init__.py b/backend/src/entities/__init__.py index 50b954da..17358c59 100644 --- a/backend/src/entities/__init__.py +++ b/backend/src/entities/__init__.py @@ -138,4 +138,24 @@ class SearchKwargs(dict): k: int = 3 fetch_k: int = 2 lambda_mult: float = 0.5 - filter: str = None \ No newline at end of file + filter: str = None + +class AgentTool(BaseModel): + name: str + description: str + url: str + spec: dict | str = None + headers: dict = None + tags: list[str] = None + + model_config = { + "json_schema_extra": {"example": {"name": "tool1", "description": "tool1 description", "url": "http://localhost:8050/openapi.json", "spec": None, "headers": {}, "tags": ["tag1", "tag2"]}} + } + +class AgentToolList(BaseModel): + tools: list[AgentTool] = Field(...) + + model_config = { + "json_schema_extra": {"example": {"tools": [{"name": "tool1", "description": "tool1 description", "url": "https://tool1.com", "spec": None, "headers": {}, "tags": ["tag1", "tag2"]}]}} + } + \ No newline at end of file diff --git a/backend/src/models/__init__.py b/backend/src/models/__init__.py index 029776a3..0ce46fde 100644 --- a/backend/src/models/__init__.py +++ b/backend/src/models/__init__.py @@ -1,356 +1,7 @@ -import json -from typing import Optional -from datetime import datetime -import sqlalchemy as sa -from sqlalchemy import Column, String, DateTime, Text, ForeignKey, Boolean -from sqlalchemy.ext.declarative import declarative_base -from sqlalchemy.sql import func -from sqlalchemy.dialects.postgresql import UUID -from passlib.context import CryptContext -from pydantic import BaseModel -from sqlalchemy.orm import relationship -from sqlalchemy.sql import text - -Base = sa.orm.declarative_base() -pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto") - -class ProtectedUser(BaseModel): - id: str - username: str - email: str - name: str - created_at: datetime - updated_at: Optional[datetime] = None - -class User(Base): - __tablename__ = "users" - - id = Column( - UUID(as_uuid=True), - primary_key=True, - index=True, - server_default=sa.text("uuid_generate_v4()") - ) - username = Column(String, unique=True, index=True) - email = Column(String, unique=True, index=True) - name = Column(String) - hashed_password = Column(String) - created_at = Column(DateTime(timezone=True), server_default=func.now()) - updated_at = Column(DateTime(timezone=True), onupdate=func.now()) - - @staticmethod - def get_password_hash(password: str) -> str: - return pwd_context.hash(password) - - @staticmethod - def verify_password(plain_password: str, hashed_password: str) -> bool: - return pwd_context.verify(plain_password, hashed_password) - - def protected(self) -> ProtectedUser: - """Return a dictionary representation of user without sensitive data.""" - return ProtectedUser( - id=str(self.id), - username=self.username, - email=self.email, - name=self.name, - created_at=self.created_at, - updated_at=self.updated_at - ) - -class Token(Base): - __tablename__ = "tokens" - - user_id = Column( - UUID(as_uuid=True), - ForeignKey('users.id', ondelete='CASCADE'), - primary_key=True - ) - key = Column(String, primary_key=True) - value = Column(Text, nullable=False) - created_at = Column(DateTime(timezone=True), server_default=func.now()) - updated_at = Column(DateTime(timezone=True), onupdate=func.now()) - - @staticmethod - def encrypt_value(value: str, secret_key: str) -> str: - from cryptography.fernet import Fernet - f = Fernet(secret_key.encode()) - return f.encrypt(value.encode()).decode() - - @staticmethod - def decrypt_value(encrypted_value: str, secret_key: str) -> str: - from cryptography.fernet import Fernet - f = Fernet(secret_key.encode()) - return f.decrypt(encrypted_value.encode()).decode() - -class Settings(Base): - __tablename__ = "settings" - - id = Column( - UUID(as_uuid=True), - primary_key=True, - server_default=sa.text("gen_random_uuid()") - ) - user_id = Column( - UUID(as_uuid=True), - ForeignKey('users.id', ondelete='CASCADE'), - primary_key=True - ) - name = Column(String, nullable=False) - slug = Column(String, nullable=False, unique=True, index=True) - value = Column(sa.JSON, nullable=False) - created_at = Column(DateTime(timezone=True), server_default=func.now()) - updated_at = Column(DateTime(timezone=True), server_default=func.now(), onupdate=func.now()) - - # Only relationship to Revisions, not directly to Agents - revisions = relationship("Revision", back_populates="setting") - - def to_dict(self) -> dict: - return { - "id": str(self.id), - "name": self.name, - "slug": self.slug, - "value": self.value - } - - @staticmethod - def generate_slug(name: str) -> str: - """Generate a URL-friendly slug from the name.""" - import re - # Convert to lowercase and replace spaces with dashes - slug = name.lower().strip().replace(' ', '-') - # Remove special characters - slug = re.sub(r'[^a-z0-9-]', '', slug) - # Replace multiple dashes with single dash - slug = re.sub(r'-+', '-', slug) - return slug - - def __init__(self, name: str, value: dict, **kwargs): - """Initialize a new setting with auto-generated slug.""" - super().__init__(**kwargs) - self.name = name - self.value = value - if 'slug' not in kwargs: - self.slug = self.generate_slug(name) - -class Agent(Base): - __tablename__ = "agents" - - id = Column(UUID(as_uuid=True), primary_key=True, server_default=text("gen_random_uuid()")) - user_id = Column( - UUID(as_uuid=True), - ForeignKey('users.id', ondelete='CASCADE'), - nullable=False - ) - name = Column(String, nullable=False) - slug = Column(String, nullable=False, unique=True) - description = Column(String, nullable=True) - public = Column(Boolean, nullable=False, server_default='false') - revision_number = Column(sa.Integer, nullable=False) - created_at = Column(DateTime(timezone=True), server_default=text("now()"), nullable=False) - updated_at = Column(DateTime(timezone=True), server_default=text("now()"), onupdate=text("now()"), nullable=False) - - # No direct relationship to Settings anymore - # Only a relationship to Revisions - revisions = relationship("Revision", back_populates="agent", cascade="all, delete-orphan") - - # New method to get the current settings via the active revision - @property - def active_revision(self): - # Find the revision that matches the current revision_number - for revision in self.revisions: - if revision.revision_number == self.revision_number: - return revision - return None - - @property - def settings(self): - revision = self.active_revision - if revision: - return revision.setting - return None - - def to_dict(self, include_setting: bool = True) -> dict: - result = { - "id": str(self.id), - "user_id": str(self.user_id), - "name": self.name, - "slug": self.slug, - "description": self.description, - "public": self.public, - "revision_number": self.revision_number, - "created_at": self.created_at.isoformat() if self.created_at else None, - "updated_at": self.updated_at.isoformat() if self.updated_at else None - } - - if include_setting and self.settings: - result["setting"] = self.settings.to_dict() +from src.models.server import Server +from src.models.user import User, Token, ProtectedUser +from src.models.setting import Settings +from src.models.thread import Thread +from src.models.agent import Agent, Revision - return result - - def to_json(self) -> str: - return json.dumps(self.to_dict()) - - @staticmethod - def generate_slug(name: str) -> str: - """Generate a URL-friendly slug from the name.""" - import re - # Convert to lowercase and replace spaces with dashes - slug = name.lower().strip().replace(' ', '-') - # Remove special characters - slug = re.sub(r'[^a-z0-9-]', '', slug) - # Replace multiple dashes with single dash - slug = re.sub(r'-+', '-', slug) - return slug - - def __init__(self, name: str, description: str, public: bool, **kwargs): - """Initialize a new setting with auto-generated slug.""" - super().__init__(**kwargs) - self.name = name - self.description = description - self.public = public - if 'slug' not in kwargs: - self.slug = self.generate_slug(name) - -class Revision(Base): - __tablename__ = "revisions" - - id = Column(UUID(as_uuid=True), primary_key=True, server_default=text("gen_random_uuid()")) - agent_id = Column(UUID(as_uuid=True), ForeignKey('agents.id', ondelete='CASCADE'), nullable=False) - user_id = Column(UUID(as_uuid=True), ForeignKey('users.id', ondelete='CASCADE'), nullable=False) - settings_id = Column(UUID(as_uuid=True), ForeignKey('settings.id', ondelete='CASCADE'), nullable=False) - revision_number = Column(sa.Integer, nullable=False) - name = Column(String, nullable=True) - description = Column(String, nullable=True) - created_at = Column(DateTime(timezone=True), server_default=text("now()"), nullable=False) - updated_at = Column(DateTime(timezone=True), server_default=text("now()"), onupdate=text("now()"), nullable=False) - - # Relationships - agent = relationship("Agent", back_populates="revisions") - setting = relationship("Settings", back_populates="revisions", foreign_keys=[settings_id]) - - def to_dict(self) -> dict: - return { - "id": str(self.id), - "agent_id": str(self.agent_id), - "user_id": str(self.user_id), - "settings_id": str(self.settings_id), - "revision_number": self.revision_number, - "name": self.name, - "description": self.description, - "created_at": self.created_at.isoformat() if self.created_at else None, - "updated_at": self.updated_at.isoformat() if self.updated_at else None - } - -class Thread(Base): - __tablename__ = "threads" - - user = Column( - UUID(as_uuid=True), - ForeignKey('users.id', ondelete='CASCADE'), - primary_key=True, - nullable=False - ) - thread = Column( - UUID(as_uuid=True), - primary_key=True, - nullable=False - ) - agent = Column( - UUID(as_uuid=True), - ForeignKey('agents.id', ondelete='CASCADE'), - nullable=True - ) - created_at = Column( - DateTime(timezone=True), - server_default=func.now(), - nullable=False - ) - - # Add relationships - user_relation = relationship("User", backref="threads") - agent_relation = relationship("Agent", backref="threads") - - def to_dict(self) -> dict: - return { - "user": str(self.user), - "thread": str(self.thread), - "agent": str(self.agent) if self.agent else None, - "created_at": self.created_at.isoformat() if self.created_at else None - } - -class Server(Base): - __tablename__ = "servers" - - id = Column( - UUID(as_uuid=True), - primary_key=True, - server_default=sa.text("gen_random_uuid()") - ) - user_id = Column( - UUID(as_uuid=True), - ForeignKey('users.id', ondelete='CASCADE'), - nullable=False - ) - name = Column(String, nullable=False) - slug = Column(String, nullable=False, unique=True, index=True) - description = Column(Text, nullable=True) - type = Column(String, nullable=False) # 'mcp' or 'a2a' - config = Column(sa.JSON, nullable=False) - documentation = Column(Text, nullable=True) # Markdown documentation - documentation_url = Column(String, nullable=True) # External documentation URL - public = Column(Boolean, nullable=False, server_default='false') - created_at = Column(DateTime(timezone=True), server_default=func.now(), nullable=False) - updated_at = Column(DateTime(timezone=True), server_default=func.now(), onupdate=func.now(), nullable=False) - - # User relationship - user = relationship("User", backref="servers") - - def to_dict(self, include_config: bool = True) -> dict: - result = { - "id": str(self.id), - "user_id": str(self.user_id), - "name": self.name, - "slug": self.slug, - "description": self.description, - "type": self.type, - "public": self.public, - "documentation": self.documentation, - "documentation_url": self.documentation_url, - "created_at": self.created_at.isoformat() if self.created_at else None, - "updated_at": self.updated_at.isoformat() if self.updated_at else None - } - - if include_config: - result["config"] = self.config - - return result - - def to_json(self) -> str: - return json.dumps(self.to_dict()) - - @staticmethod - def generate_slug(name: str) -> str: - """Generate a URL-friendly slug from the name.""" - import re - # Convert to lowercase and replace spaces with dashes - slug = name.lower().strip().replace(' ', '-') - # Remove special characters - slug = re.sub(r'[^a-z0-9-]', '', slug) - # Replace multiple dashes with single dash - slug = re.sub(r'-+', '-', slug) - return slug - - def __init__(self, name: str, type: str, config: dict, description: str = None, - documentation: str = None, documentation_url: str = None, - public: bool = False, **kwargs): - """Initialize a new server with auto-generated slug.""" - super().__init__(**kwargs) - self.name = name - self.type = type - self.config = config - self.description = description - self.documentation = documentation - self.documentation_url = documentation_url - self.public = public - if 'slug' not in kwargs: - self.slug = self.generate_slug(name) \ No newline at end of file +__all__ = ["Server", "User", "Token", "ProtectedUser", "Settings", "Thread", "Agent", "Revision"] \ No newline at end of file diff --git a/backend/src/models/agent.py b/backend/src/models/agent.py new file mode 100644 index 00000000..ccfbe046 --- /dev/null +++ b/backend/src/models/agent.py @@ -0,0 +1,160 @@ +import sqlalchemy as sa +from sqlalchemy import JSON, Column, String, Boolean, DateTime, ForeignKey, UUID, text +from sqlalchemy.orm import relationship +import json + +from src.services.db import Base + +class Agent(Base): + __tablename__ = "agents" + + id = Column(UUID(as_uuid=True), primary_key=True, server_default=text("gen_random_uuid()")) + user_id = Column( + UUID(as_uuid=True), + ForeignKey('users.id', ondelete='CASCADE'), + nullable=False + ) + name = Column(String, nullable=False) + slug = Column(String, nullable=False, unique=True) + description = Column(String, nullable=True) + public = Column(Boolean, nullable=False, server_default='false') + revision_number = Column(sa.Integer, nullable=False) + created_at = Column(DateTime(timezone=True), server_default=text("now()"), nullable=False) + updated_at = Column(DateTime(timezone=True), server_default=text("now()"), onupdate=text("now()"), nullable=False) + + # No direct relationship to Settings anymore + # Only a relationship to Revisions + revisions = relationship("Revision", back_populates="agent", cascade="all, delete-orphan") + + # New method to get the current settings via the active revision + @property + def active_revision(self): + # Find the revision that matches the current revision_number + for revision in self.revisions: + if revision.revision_number == self.revision_number: + return revision + return None + + @property + def settings(self): + revision = self.active_revision + if revision: + return revision.setting + return None + + def to_dict(self, include_setting: bool = True) -> dict: + result = { + "id": str(self.id), + "user_id": str(self.user_id), + "name": self.name, + "slug": self.slug, + "description": self.description, + "public": self.public, + "revision_number": self.revision_number, + "created_at": self.created_at.isoformat() if self.created_at else None, + "updated_at": self.updated_at.isoformat() if self.updated_at else None + } + + if include_setting and self.settings: + result["setting"] = self.settings.to_dict() + + return result + + def to_json(self) -> str: + return json.dumps(self.to_dict()) + + @staticmethod + def generate_slug(name: str) -> str: + """Generate a URL-friendly slug from the name.""" + import re + # Convert to lowercase and replace spaces with dashes + slug = name.lower().strip().replace(' ', '-') + # Remove special characters + slug = re.sub(r'[^a-z0-9-]', '', slug) + # Replace multiple dashes with single dash + slug = re.sub(r'-+', '-', slug) + return slug + + def __init__(self, name: str, description: str, public: bool, **kwargs): + """Initialize a new setting with auto-generated slug.""" + super().__init__(**kwargs) + self.name = name + self.description = description + self.public = public + if 'slug' not in kwargs: + self.slug = self.generate_slug(name) + +class Revision(Base): + __tablename__ = "revisions" + + id = Column(UUID(as_uuid=True), primary_key=True, server_default=text("gen_random_uuid()")) + agent_id = Column(UUID(as_uuid=True), ForeignKey('agents.id', ondelete='CASCADE'), nullable=False) + user_id = Column(UUID(as_uuid=True), ForeignKey('users.id', ondelete='CASCADE'), nullable=False) + settings_id = Column(UUID(as_uuid=True), ForeignKey('settings.id', ondelete='CASCADE'), nullable=False) + revision_number = Column(sa.Integer, nullable=False) + name = Column(String, nullable=True) + description = Column(String, nullable=True) + created_at = Column(DateTime(timezone=True), server_default=text("now()"), nullable=False) + updated_at = Column(DateTime(timezone=True), server_default=text("now()"), onupdate=text("now()"), nullable=False) + + # Relationships + agent = relationship("Agent", back_populates="revisions") + setting = relationship("Settings", back_populates="revisions", foreign_keys=[settings_id]) + + def to_dict(self) -> dict: + return { + "id": str(self.id), + "agent_id": str(self.agent_id), + "user_id": str(self.user_id), + "settings_id": str(self.settings_id), + "revision_number": self.revision_number, + "name": self.name, + "description": self.description, + "created_at": self.created_at.isoformat() if self.created_at else None, + "updated_at": self.updated_at.isoformat() if self.updated_at else None + } + + +class Tool(Base): + __tablename__ = "tools" + + id = Column(UUID(as_uuid=True), primary_key=True, server_default=text("gen_random_uuid()")) + user_id = Column(UUID(as_uuid=True), ForeignKey('users.id', ondelete='CASCADE'), nullable=False) + name = Column(String, nullable=False) + url = Column(String, nullable=True) + spec = Column(sa.JSON, nullable=True) # Fallback if no URL is provided, could be YAML or JSON format + headers = Column(sa.JSON, nullable=True, info={'encrypted': True}) # Encrypted at rest + tags = Column(sa.JSON, nullable=True) + description = Column(String, nullable=True) + created_at = Column(DateTime(timezone=True), server_default=text("now()"), nullable=False) + updated_at = Column(DateTime(timezone=True), server_default=text("now()"), onupdate=text("now()"), nullable=False) + + # Relationships + # user = relationship("User", back_populates="tools") + + def to_dict(self) -> dict: + return { + "id": str(self.id), + # "user_id": str(self.user_id), + "name": self.name, + "slug": self.generate_slug(self.name), + "description": self.description, + "url": self.url, + "spec": self.spec, + "headers": self.headers if self.headers else {}, + "tags": self.tags if self.tags else [], + "created_at": self.created_at.isoformat() if self.created_at else None, + "updated_at": self.updated_at.isoformat() if self.updated_at else None + } + + @staticmethod + def generate_slug(name: str) -> str: + """Generate a URL-friendly slug from the name.""" + import re + # Convert to lowercase and replace spaces with underscores + slug = name.lower().strip().replace(' ', '_') + # Remove special characters + slug = re.sub(r'[^a-z0-9_]', '', slug) + # Replace multiple underscores with single underscore + slug = re.sub(r'_+', '_', slug) + return slug \ No newline at end of file diff --git a/backend/src/models/server.py b/backend/src/models/server.py new file mode 100644 index 00000000..736e096a --- /dev/null +++ b/backend/src/models/server.py @@ -0,0 +1,83 @@ +import sqlalchemy as sa +from sqlalchemy import Column, String, Text, Boolean, DateTime, func, ForeignKey, UUID +from sqlalchemy.orm import relationship +import json + +from src.services.db import Base + +class Server(Base): + __tablename__ = "servers" + + id = Column( + UUID(as_uuid=True), + primary_key=True, + server_default=sa.text("gen_random_uuid()") + ) + user_id = Column( + UUID(as_uuid=True), + ForeignKey('users.id', ondelete='CASCADE'), + nullable=False + ) + name = Column(String, nullable=False) + slug = Column(String, nullable=False, unique=True, index=True) + description = Column(Text, nullable=True) + type = Column(String, nullable=False) # 'mcp' or 'a2a' + config = Column(sa.JSON, nullable=False) + documentation = Column(Text, nullable=True) # Markdown documentation + documentation_url = Column(String, nullable=True) # External documentation URL + public = Column(Boolean, nullable=False, server_default='false') + created_at = Column(DateTime(timezone=True), server_default=func.now(), nullable=False) + updated_at = Column(DateTime(timezone=True), server_default=func.now(), onupdate=func.now(), nullable=False) + + # User relationship + user = relationship("User", backref="servers") + + def to_dict(self, include_config: bool = True) -> dict: + result = { + "id": str(self.id), + "user_id": str(self.user_id), + "name": self.name, + "slug": self.slug, + "description": self.description, + "type": self.type, + "public": self.public, + "documentation": self.documentation, + "documentation_url": self.documentation_url, + "created_at": self.created_at.isoformat() if self.created_at else None, + "updated_at": self.updated_at.isoformat() if self.updated_at else None + } + + if include_config: + result["config"] = self.config + + return result + + def to_json(self) -> str: + return json.dumps(self.to_dict()) + + @staticmethod + def generate_slug(name: str) -> str: + """Generate a URL-friendly slug from the name.""" + import re + # Convert to lowercase and replace spaces with dashes + slug = name.lower().strip().replace(' ', '-') + # Remove special characters + slug = re.sub(r'[^a-z0-9-]', '', slug) + # Replace multiple dashes with single dash + slug = re.sub(r'-+', '-', slug) + return slug + + def __init__(self, name: str, type: str, config: dict, description: str = None, + documentation: str = None, documentation_url: str = None, + public: bool = False, **kwargs): + """Initialize a new server with auto-generated slug.""" + super().__init__(**kwargs) + self.name = name + self.type = type + self.config = config + self.description = description + self.documentation = documentation + self.documentation_url = documentation_url + self.public = public + if 'slug' not in kwargs: + self.slug = self.generate_slug(name) \ No newline at end of file diff --git a/backend/src/models/setting.py b/backend/src/models/setting.py new file mode 100644 index 00000000..1a4c6c6b --- /dev/null +++ b/backend/src/models/setting.py @@ -0,0 +1,54 @@ +import sqlalchemy as sa +from sqlalchemy import Column, String, DateTime, func, ForeignKey, UUID +from sqlalchemy.orm import relationship +from src.services.db import Base + +class Settings(Base): + __tablename__ = "settings" + + id = Column( + UUID(as_uuid=True), + primary_key=True, + server_default=sa.text("gen_random_uuid()") + ) + user_id = Column( + UUID(as_uuid=True), + ForeignKey('users.id', ondelete='CASCADE'), + primary_key=True + ) + name = Column(String, nullable=False) + slug = Column(String, nullable=False, unique=True, index=True) + value = Column(sa.JSON, nullable=False) + created_at = Column(DateTime(timezone=True), server_default=func.now()) + updated_at = Column(DateTime(timezone=True), server_default=func.now(), onupdate=func.now()) + + # Only relationship to Revisions, not directly to Agents + revisions = relationship("Revision", back_populates="setting") + + def to_dict(self) -> dict: + return { + "id": str(self.id), + "name": self.name, + "slug": self.slug, + "value": self.value + } + + @staticmethod + def generate_slug(name: str) -> str: + """Generate a URL-friendly slug from the name.""" + import re + # Convert to lowercase and replace spaces with dashes + slug = name.lower().strip().replace(' ', '-') + # Remove special characters + slug = re.sub(r'[^a-z0-9-]', '', slug) + # Replace multiple dashes with single dash + slug = re.sub(r'-+', '-', slug) + return slug + + def __init__(self, name: str, value: dict, **kwargs): + """Initialize a new setting with auto-generated slug.""" + super().__init__(**kwargs) + self.name = name + self.value = value + if 'slug' not in kwargs: + self.slug = self.generate_slug(name) \ No newline at end of file diff --git a/backend/src/models/thread.py b/backend/src/models/thread.py new file mode 100644 index 00000000..07880e98 --- /dev/null +++ b/backend/src/models/thread.py @@ -0,0 +1,42 @@ +import sqlalchemy as sa +from sqlalchemy import Column, DateTime, func, ForeignKey, UUID +from sqlalchemy.orm import relationship +from src.services.db import Base + +class Thread(Base): + __tablename__ = "threads" + + user = Column( + UUID(as_uuid=True), + ForeignKey('users.id', ondelete='CASCADE'), + primary_key=True, + nullable=False + ) + thread = Column( + UUID(as_uuid=True), + primary_key=True, + nullable=False + ) + agent = Column( + UUID(as_uuid=True), + ForeignKey('agents.id', ondelete='CASCADE'), + nullable=True + ) + created_at = Column( + DateTime(timezone=True), + server_default=func.now(), + nullable=False + ) + + # Add relationships + user_relation = relationship("User", backref="threads") + agent_relation = relationship("Agent", backref="threads") + + def to_dict(self) -> dict: + return { + "user": str(self.user), + "thread": str(self.thread), + "agent": str(self.agent) if self.agent else None, + "created_at": self.created_at.isoformat() if self.created_at else None + } + \ No newline at end of file diff --git a/backend/src/models/user.py b/backend/src/models/user.py new file mode 100644 index 00000000..295bcbf8 --- /dev/null +++ b/backend/src/models/user.py @@ -0,0 +1,77 @@ +import sqlalchemy as sa +from typing import Optional +from sqlalchemy import Column, String, Text, DateTime, func, ForeignKey, UUID +from passlib.context import CryptContext +from pydantic import BaseModel +from datetime import datetime +from src.services.db import Base + +pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto") + +class ProtectedUser(BaseModel): + id: str + username: str + email: str + name: str + created_at: datetime + updated_at: Optional[datetime] = None + +class User(Base): + __tablename__ = "users" + + id = Column( + UUID(as_uuid=True), + primary_key=True, + index=True, + server_default=sa.text("uuid_generate_v4()") + ) + username = Column(String, unique=True, index=True) + email = Column(String, unique=True, index=True) + name = Column(String) + hashed_password = Column(String) + created_at = Column(DateTime(timezone=True), server_default=func.now()) + updated_at = Column(DateTime(timezone=True), onupdate=func.now()) + + @staticmethod + def get_password_hash(password: str) -> str: + return pwd_context.hash(password) + + @staticmethod + def verify_password(plain_password: str, hashed_password: str) -> bool: + return pwd_context.verify(plain_password, hashed_password) + + def protected(self) -> ProtectedUser: + """Return a dictionary representation of user without sensitive data.""" + return ProtectedUser( + id=str(self.id), + username=self.username, + email=self.email, + name=self.name, + created_at=self.created_at, + updated_at=self.updated_at + ) + +class Token(Base): + __tablename__ = "tokens" + + user_id = Column( + UUID(as_uuid=True), + ForeignKey('users.id', ondelete='CASCADE'), + primary_key=True + ) + key = Column(String, primary_key=True) + value = Column(Text, nullable=False) + created_at = Column(DateTime(timezone=True), server_default=func.now()) + updated_at = Column(DateTime(timezone=True), onupdate=func.now()) + + @staticmethod + def encrypt_value(value: str, secret_key: str) -> str: + from cryptography.fernet import Fernet + f = Fernet(secret_key.encode()) + return f.encrypt(value.encode()).decode() + + @staticmethod + def decrypt_value(encrypted_value: str, secret_key: str) -> str: + from cryptography.fernet import Fernet + f = Fernet(secret_key.encode()) + return f.decrypt(encrypted_value.encode()).decode() \ No newline at end of file diff --git a/backend/src/repos/tool_repo.py b/backend/src/repos/tool_repo.py new file mode 100644 index 00000000..5dafa10e --- /dev/null +++ b/backend/src/repos/tool_repo.py @@ -0,0 +1,119 @@ + +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.exc import IntegrityError +from src.tools import tools as STATIC_TOOLS, attach_tool_details +from src.models.agent import Tool as AgentTool +from src.tools.api import generate_tools_from_openapi_spec +from langchain_core.tools import StructuredTool +class ToolRepo: + _instance = None + + def __new__(cls, db: AsyncSession, user_id: str = None): + if cls._instance is None: + cls._instance = super(ToolRepo, cls).__new__(cls) + cls._instance.db = None + cls._instance.user_id = None + return cls._instance + + def __init__(self, db: AsyncSession, user_id: str = None): + # Update attributes if they've changed + self.db = db + self.user_id = user_id + + async def list_tool_with_details(self) -> list[AgentTool]: + tools: list[StructuredTool] = [] + if self.user_id: + user_tools = await self._list_user_tools() + tools.extend(user_tools) + + tools.extend(STATIC_TOOLS) + tool_details = [attach_tool_details({'id':tool.name, 'description':tool.description, 'args':tool.args, 'tags':tool.tags}) for tool in tools] + return tool_details + + async def _list_user_tools(self) -> list[AgentTool]: + res = await self.db.execute(select(AgentTool).where(AgentTool.user_id == self.user_id)) + user_tools = res.scalars().all() + tool_dicts = [] + for tool in user_tools: + api_tools = generate_tools_from_openapi_spec(openapi_url=tool.url, headers=tool.headers) + tool_dicts.extend(api_tools) + return tool_dicts + + async def list_tools(self) -> list[StructuredTool]: + tools = [] + if self.user_id: + user_tools = await self._list_user_tools() + tools.extend(user_tools) + tools.extend(STATIC_TOOLS) + return tools + + async def create( + self, + name: str, + description: str, + url: str, + spec: dict | str = None, + headers: dict = None, + tags: list[str] = None + ) -> AgentTool: + try: + tool = AgentTool( + name=name, + description=description, + user_id=self.user_id, + url=url, + spec=spec, + headers=headers, + tags=tags + ) + self.db.add(tool) + await self.db.commit() + await self.db.refresh(tool) + return tool + except IntegrityError as e: + await self.db.rollback() + raise ValueError(f"A tool with this name already exists") + except Exception as e: + await self.db.rollback() + raise e + + async def find_by_name(self, name: str) -> AgentTool: + tool = await self.db.execute(select(AgentTool).where(AgentTool.name == name, AgentTool.user_id == self.user_id)) + return tool.scalar_one_or_none() + + async def find_by_id(self, id: str) -> AgentTool: + tool = await self.db.execute(select(AgentTool).where(AgentTool.id == id, AgentTool.user_id == self.user_id)) + return tool.scalar_one_or_none() + + async def update( + self, + id: str, + name: str, + description: str, + url: str, + spec: dict | str = None, + headers: dict = None, + tags: list[str] = None + ) -> AgentTool: + tool = await self.find_by_id(id) + if tool: + tool.name = name + tool.description = description + tool.url = url + tool.spec = spec + tool.headers = headers + tool.tags = tags + self.db.commit() + self.db.refresh(tool) + return tool + else: + raise ValueError("Tool not found") + + async def delete(self, id: str) -> None: + tool = await self.find_by_id(id) + if tool: + await self.db.delete(tool) + await self.db.commit() + else: + raise ValueError("Tool not found") \ No newline at end of file diff --git a/backend/src/routes/v0/__init__.py b/backend/src/routes/v0/__init__.py index 1cbe64ef..f4e5af76 100644 --- a/backend/src/routes/v0/__init__.py +++ b/backend/src/routes/v0/__init__.py @@ -1,6 +1,7 @@ from .llm import router as llm from .thread import router as thread from .tool import router as tool +from .tool.custom import router as tool_custom from .retrieve import router as retrieve from .source import router as source from .info import router as info @@ -12,4 +13,4 @@ from .model import router as model from .server import router as server -__all__ = ["llm", "thread", "tool", "retrieve", "source", "info", "auth", "token", "storage", "settings", "agent", "model", "server"] \ No newline at end of file +__all__ = ["llm", "thread", "tool", "tool_custom", "retrieve", "source", "info", "auth", "token", "storage", "settings", "agent", "model", "server"] \ No newline at end of file diff --git a/backend/src/routes/v0/llm.py b/backend/src/routes/v0/llm.py index 45d62274..91a43b13 100644 --- a/backend/src/routes/v0/llm.py +++ b/backend/src/routes/v0/llm.py @@ -5,7 +5,7 @@ from sqlalchemy.ext.asyncio import AsyncSession from langchain.chat_models import init_chat_model -from src.models import ProtectedUser, User +from src.models import ProtectedUser from src.entities import Answer, ChatInput, NewThread, ExistingThread from src.services.db import get_async_db from src.utils.auth import get_optional_user diff --git a/backend/src/routes/v0/tool.py b/backend/src/routes/v0/tool.py deleted file mode 100644 index be991b15..00000000 --- a/backend/src/routes/v0/tool.py +++ /dev/null @@ -1,294 +0,0 @@ -from typing import Dict, Any, List, Optional -from fastapi import status, Depends, APIRouter -from fastapi.responses import JSONResponse -from sqlalchemy.orm import Session - -from src.entities.a2a import A2AServer -from src.constants.examples import A2A_GET_AGENT_CARD_EXAMPLE, MCP_REQ_BODY_EXAMPLE -from src.constants import APP_LOG_LEVEL -from src.models import ProtectedUser -from src.repos.user_repo import UserRepo -from src.utils.auth import verify_credentials -from src.services.db import get_db -from src.services.mcp import McpService - -TAG = "Tool" -router = APIRouter(tags=[TAG]) - -################################################################################ -### List Tools -################################################################################ -from src.tools import tools, attach_tool_details -tool_names = [attach_tool_details({'id':tool.name, 'description':tool.description, 'args':tool.args, 'tags':tool.tags}) for tool in tools] -tools_response = {"tools": tool_names} -@router.get( - "/tools", - tags=[TAG], - responses={ - status.HTTP_200_OK: { - "description": "All tools.", - "content": { - "application/json": { - "example": tools_response - } - } - } - } -) -def list_tools( - # user: ProtectedUser = Depends(verify_credentials) -): - return JSONResponse( - content=tools_response, - status_code=status.HTTP_200_OK - ) - - -################################################################################ -### List MCP Info -################################################################################ -from pydantic import BaseModel, Field - -class MCPInfo(BaseModel): - mcp: Optional[Dict[str, Any]] = None - mcpServers: Optional[Dict[str, Any]] = None - - model_config = { - "json_schema_extra": {"example": MCP_REQ_BODY_EXAMPLE} - } - - -@router.post( - "/tools/mcp/info", - tags=[TAG], - responses={ - status.HTTP_200_OK: { - "description": "All tools.", - "content": { - "application/json": { - "example": [] - } - } - } - } -) -async def list_mcp_info( - config: MCPInfo -): - try: - agent_session = McpService() - mcp_config = config.mcpServers or config.mcp - if not mcp_config: - return JSONResponse( - content={'error': 'No MCP servers or MCP config found'}, - status_code=status.HTTP_400_BAD_REQUEST - ) - await agent_session.setup(mcp_config) - tools = agent_session.tools() - return JSONResponse( - content={'mcp': [ - {k: v for k, v in tool.model_dump().items() if k not in ['func', 'coroutine']} - for tool in tools - ]}, - status_code=status.HTTP_200_OK - ) - except Exception as e: - return JSONResponse( - content={'error': str(e)}, - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR - ) - finally: - await agent_session.cleanup() - -################################################################################ -### List A2A Info -################################################################################ -from src.utils.a2a import A2ACardResolver, A2AClient -from src.entities.a2a import A2AServers - -@router.post( - "/tools/a2a/info", - tags=[TAG], - responses={ - status.HTTP_200_OK: { - "description": "All capabilities.", - "content": { - "application/json": { - "example": A2A_GET_AGENT_CARD_EXAMPLE - } - } - } - } -) -async def get_a2a_agent_card( - body: A2AServers -): - try: - results = [] - - if not body.a2a: - return JSONResponse( - content={'error': 'No A2A servers or A2A config found'}, - status_code=status.HTTP_400_BAD_REQUEST - ) - - for server_name, server in body.a2a.items(): - try: - a2a_card_resolver = A2ACardResolver(server.base_url, server.agent_card_path) - agent_card = a2a_card_resolver.get_agent_card() - results.append(agent_card.model_dump()) - except Exception as server_error: - results.append({"error": str(server_error), "base_url": server.base_url}) - - return JSONResponse( - content={'agent_cards': results}, - status_code=status.HTTP_200_OK - ) - except Exception as e: - return JSONResponse( - content={'error': str(e)}, - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR - ) - -################################################################################ -### Invoke A2A Agent -################################################################################ -# @router.post( -# "/tools/a2a/invoke", -# tags=[TAG], -# responses={ -# status.HTTP_200_OK: { -# "description": "Invoke a agent.", -# "content": { -# "application/json": { -# "example": {} -# } -# } -# } -# } -# ) -# async def invoke_a2a_agent( -# body: dict[str, Any] -# ): -# try: -# a2a_card = A2ACardResolver(**body) -# a2a_client = A2AClient(a2a_card) -# response = a2a_client.invoke(body['task']) - -# return JSONResponse( -# content={'answer': 'Agent invoked'}, -# status_code=status.HTTP_200_OK -# ) -# except Exception as e: -# return JSONResponse( -# content={'error': str(e)}, -# status_code=status.HTTP_500_INTERNAL_SERVER_ERROR -# ) - -################################################################################ -### Test Tool -################################################################################ -from pydantic import BaseModel -from typing import Dict, Any, Optional -import traceback - -class ToolRequest(BaseModel): - tool_id: str - input: Dict[str, Any] - metadata: Optional[Dict[str, Any]] = None - -################################################################################ -### Test Tool -################################################################################ -from pydantic import BaseModel -from typing import Dict, Any, Optional -import traceback -import os - -class ToolRequest(BaseModel): - args: Dict[str, Any] - # metadata: Optional[Dict[str, Any]] = None - -@router.post( - "/tools/{tool_id}/invoke", - tags=[TAG], - responses={ - status.HTTP_200_OK: { - "description": "Tool execution result.", - "content": { - "application/json": { - "example": { - "result": "Tool execution result", - "success": True - } - } - } - }, - status.HTTP_400_BAD_REQUEST: { - "description": "Invalid tool or arguments.", - "content": { - "application/json": { - "example": { - "error": "Tool not found or invalid arguments", - "success": False - } - } - } - }, - status.HTTP_500_INTERNAL_SERVER_ERROR: { - "description": "Error executing tool.", - "content": { - "application/json": { - "example": { - "error": "Internal server error", - "success": False - } - } - } - } - } -) -async def invoke_tool( - tool_id: str, - request: ToolRequest, - user: ProtectedUser = Depends(verify_credentials), - db: Session = Depends(get_db) -): - """ - Invoke a tool by executing it with the provided arguments. - """ - try: - # Find the tool by id - from src.tools import tools, dynamic_tools - selected_tool = next((tool for tool in tools if tool.name == tool_id), None) - - if not selected_tool: - return JSONResponse( - content={"error": f"Tool with id '{tool_id}' not found", "success": False}, - status_code=status.HTTP_400_BAD_REQUEST - ) - - user_repo = UserRepo(db, user.id) - # Use dynamic_tools to properly set metadata - tool_with_metadata = dynamic_tools([tool_id], {"user_repo": user_repo})[0] - - # Execute the tool with the provided arguments - output = tool_with_metadata.invoke(input=request.args) - - return JSONResponse( - content={"output": output, "success": True}, - status_code=status.HTTP_200_OK - ) - - except Exception as e: - error_message = str(e) - error_traceback = traceback.format_exc() - - return JSONResponse( - content={ - "error": error_message, - "traceback": error_traceback if "DEBUG" in APP_LOG_LEVEL else None, - "success": False - }, - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR - ) \ No newline at end of file diff --git a/backend/src/routes/v0/tool/__init__.py b/backend/src/routes/v0/tool/__init__.py new file mode 100644 index 00000000..ee740d3b --- /dev/null +++ b/backend/src/routes/v0/tool/__init__.py @@ -0,0 +1,18 @@ +from fastapi import status, Depends, APIRouter +from fastapi.responses import JSONResponse + +from src.models import ProtectedUser +from src.utils.auth import get_optional_user +from src.routes.v0.tool.custom import router as tool_custom +from src.routes.v0.tool.invoke import router as tool_invoke +from src.routes.v0.tool.create import router as tool_create +from src.routes.v0.tool.list import router as tool_list + +TAG = "Tool" +router = APIRouter(tags=[TAG]) + +## Attach custom router +router.include_router(tool_custom) +router.include_router(tool_invoke) +router.include_router(tool_create) +router.include_router(tool_list) \ No newline at end of file diff --git a/backend/src/routes/v0/tool/create.py b/backend/src/routes/v0/tool/create.py new file mode 100644 index 00000000..cf4a45cd --- /dev/null +++ b/backend/src/routes/v0/tool/create.py @@ -0,0 +1,45 @@ +from fastapi import status, Depends, APIRouter, Body, HTTPException +from fastapi.responses import JSONResponse +from sqlalchemy.ext.asyncio import AsyncSession +from src.services.db import get_async_db +from src.models import ProtectedUser +from src.utils.auth import verify_credentials +from src.repos.tool_repo import ToolRepo +from src.entities import AgentTool + +router = APIRouter() + +@router.post( + "/tools", + responses={ + status.HTTP_201_CREATED: { + "description": "Tool created successfully", + } + } +) +async def create_tool( + user: ProtectedUser = Depends(verify_credentials), + tool: AgentTool = Body(...), + db: AsyncSession = Depends(get_async_db) +): + try: + tool_repo = ToolRepo(db, user.id) + tool = await tool_repo.create( + name=tool.name, + description=tool.description, + url=tool.url, + spec=tool.spec, + headers=tool.headers, + tags=tool.tags + ) + return JSONResponse( + content={"tool": tool.to_dict()}, + status_code=status.HTTP_201_CREATED + ) + except ValueError as e: + if "A tool with this name already exists" in str(e): + raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail=str(e)) + else: + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e)) + except Exception as e: + raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e)) \ No newline at end of file diff --git a/backend/src/routes/v0/tool/custom.py b/backend/src/routes/v0/tool/custom.py new file mode 100644 index 00000000..a3966211 --- /dev/null +++ b/backend/src/routes/v0/tool/custom.py @@ -0,0 +1,141 @@ +from typing import Dict, Any, Optional +from pydantic import BaseModel, Field +from fastapi import status, Depends, APIRouter +from fastapi.responses import JSONResponse + +from src.constants.examples import MCP_REQ_BODY_EXAMPLE, A2A_GET_AGENT_CARD_EXAMPLE +from src.utils.a2a import A2ACardResolver +from src.entities.a2a import A2AServers +from src.services.mcp import McpService + +router = APIRouter() + +class MCPInfo(BaseModel): + mcp: Optional[Dict[str, Any]] = None + mcpServers: Optional[Dict[str, Any]] = None + + model_config = { + "json_schema_extra": {"example": MCP_REQ_BODY_EXAMPLE} + } + +@router.post( + "/tools/mcp/info", + responses={ + status.HTTP_200_OK: { + "description": "All tools.", + "content": { + "application/json": { + "example": [] + } + } + } + } +) +async def list_mcp_info( + config: MCPInfo +): + try: + agent_session = McpService() + mcp_config = config.mcpServers or config.mcp + if not mcp_config: + return JSONResponse( + content={'error': 'No MCP servers or MCP config found'}, + status_code=status.HTTP_400_BAD_REQUEST + ) + await agent_session.setup(mcp_config) + tools = agent_session.tools() + return JSONResponse( + content={'mcp': [ + {k: v for k, v in tool.model_dump().items() if k not in ['func', 'coroutine']} + for tool in tools + ]}, + status_code=status.HTTP_200_OK + ) + except Exception as e: + return JSONResponse( + content={'error': str(e)}, + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR + ) + finally: + await agent_session.cleanup() + +################################################################################ +### List A2A Info +################################################################################ +@router.post( + "/tools/a2a/info", + responses={ + status.HTTP_200_OK: { + "description": "All capabilities.", + "content": { + "application/json": { + "example": A2A_GET_AGENT_CARD_EXAMPLE + } + } + } + } +) +async def get_a2a_agent_card( + body: A2AServers +): + try: + results = [] + + if not body.a2a: + return JSONResponse( + content={'error': 'No A2A servers or A2A config found'}, + status_code=status.HTTP_400_BAD_REQUEST + ) + + for server_name, server in body.a2a.items(): + try: + a2a_card_resolver = A2ACardResolver(server.base_url, server.agent_card_path) + agent_card = a2a_card_resolver.get_agent_card() + results.append(agent_card.model_dump()) + except Exception as server_error: + results.append({"error": str(server_error), "base_url": server.base_url}) + + return JSONResponse( + content={'agent_cards': results}, + status_code=status.HTTP_200_OK + ) + except Exception as e: + return JSONResponse( + content={'error': str(e)}, + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR + ) + +################################################################################ +### Invoke A2A Agent +################################################################################ +# @router.post( +# "/tools/a2a/invoke", +# tags=[TAG], +# responses={ +# status.HTTP_200_OK: { +# "description": "Invoke a agent.", +# "content": { +# "application/json": { +# "example": {} +# } +# } +# } +# } +# ) +# async def invoke_a2a_agent( +# body: dict[str, Any] +# ): +# try: +# a2a_card = A2ACardResolver(**body) +# a2a_client = A2AClient(a2a_card) +# response = a2a_client.invoke(body['task']) + +# return JSONResponse( +# content={'answer': 'Agent invoked'}, +# status_code=status.HTTP_200_OK +# ) +# except Exception as e: +# return JSONResponse( +# content={'error': str(e)}, +# status_code=status.HTTP_500_INTERNAL_SERVER_ERROR +# ) \ No newline at end of file diff --git a/backend/src/routes/v0/tool/invoke.py b/backend/src/routes/v0/tool/invoke.py new file mode 100644 index 00000000..62fa8135 --- /dev/null +++ b/backend/src/routes/v0/tool/invoke.py @@ -0,0 +1,102 @@ +import traceback +from typing import Dict, Any +from fastapi import status, Depends, APIRouter +from fastapi.responses import JSONResponse +from pydantic import BaseModel +from sqlalchemy.orm import Session + +from src.repos.tool_repo import ToolRepo +from src.constants import APP_LOG_LEVEL +from src.models import ProtectedUser +from src.repos.user_repo import UserRepo +from src.utils.auth import verify_credentials +from src.services.db import get_db + +router = APIRouter() + +class ToolRequest(BaseModel): + args: Dict[str, Any] + +@router.post( + "/tools/{tool_id}/invoke", + responses={ + status.HTTP_200_OK: { + "description": "Tool execution result.", + "content": { + "application/json": { + "example": { + "result": "Tool execution result", + "success": True + } + } + } + }, + status.HTTP_400_BAD_REQUEST: { + "description": "Invalid tool or arguments.", + "content": { + "application/json": { + "example": { + "error": "Tool not found or invalid arguments", + "success": False + } + } + } + }, + status.HTTP_500_INTERNAL_SERVER_ERROR: { + "description": "Error executing tool.", + "content": { + "application/json": { + "example": { + "error": "Internal server error", + "success": False + } + } + } + } + } +) +async def invoke_tool( + tool_id: str, + request: ToolRequest, + user: ProtectedUser = Depends(verify_credentials), + db: Session = Depends(get_db) +): + """ + Invoke a tool by executing it with the provided arguments. + """ + from src.tools import dynamic_tools + try: + tool_repo = ToolRepo(db, user.id) + tools = await tool_repo.list_tools() + selected_tool = [tool for tool in tools if tool.name == tool_id] + + if not selected_tool: + return JSONResponse( + content={"error": f"Tool with id '{tool_id}' not found", "success": False}, + status_code=status.HTTP_400_BAD_REQUEST + ) + + user_repo = UserRepo(db, user.id) + # Use dynamic_tools to properly set metadata + tool_with_metadata = dynamic_tools([tool_id], {"user_repo": user_repo})[0] + + # Execute the tool with the provided arguments + output = tool_with_metadata.invoke(input=request.args) + + return JSONResponse( + content={"output": output, "success": True}, + status_code=status.HTTP_200_OK + ) + + except Exception as e: + error_message = str(e) + error_traceback = traceback.format_exc() + + return JSONResponse( + content={ + "error": error_message, + "traceback": error_traceback if "DEBUG" in APP_LOG_LEVEL else None, + "success": False + }, + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR + ) \ No newline at end of file diff --git a/backend/src/routes/v0/tool/list.py b/backend/src/routes/v0/tool/list.py new file mode 100644 index 00000000..948f960c --- /dev/null +++ b/backend/src/routes/v0/tool/list.py @@ -0,0 +1,36 @@ +from fastapi import status, Depends, APIRouter, Body, HTTPException, Query +from fastapi.responses import JSONResponse +from sqlalchemy.ext.asyncio import AsyncSession +from src.services.db import get_async_db +from src.models import ProtectedUser +from src.utils.auth import get_optional_user +from src.repos.tool_repo import ToolRepo + +router = APIRouter() + +@router.get( + "/tools", + responses={ + status.HTTP_200_OK: { + "description": "Tools listed successfully", + "content": { + "application/json": { + "example": {"tools": [{"id": "123", "name": "Tool 1", "description": "Tool 1 description", "url": "https://tool1.com", "spec": None, "headers": {}, "tags": ["tag1", "tag2"]}]} + } + } + } + } +) +async def list_tools( + user: ProtectedUser = Depends(get_optional_user), + db: AsyncSession = Depends(get_async_db) +): + try: + tool_repo = ToolRepo(db, user.id if user else None) + tools = await tool_repo.list_tool_with_details() + return JSONResponse( + content={"tools": tools}, + status_code=status.HTTP_200_OK + ) + except Exception as e: + raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e)) \ No newline at end of file diff --git a/backend/src/services/db.py b/backend/src/services/db.py index 2768de72..bd084190 100644 --- a/backend/src/services/db.py +++ b/backend/src/services/db.py @@ -2,6 +2,7 @@ import functools from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver from typing import AsyncGenerator, Generator, Callable +import sqlalchemy as sa from sqlalchemy import create_engine from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine from sqlalchemy.orm import sessionmaker @@ -12,6 +13,7 @@ MAX_CONNECTION_POOL_SIZE = None # SQLAlchemy engines +Base = sa.orm.declarative_base() engine = create_engine(DB_URI) SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) diff --git a/backend/src/tools/__init__.py b/backend/src/tools/__init__.py index f99f3357..b0bf179f 100644 --- a/backend/src/tools/__init__.py +++ b/backend/src/tools/__init__.py @@ -1,4 +1,5 @@ from langchain_mcp_adapters.client import MultiServerMCPClient +from langchain_core.tools import BaseTool from langgraph.prebuilt import ToolNode from src.tools.retrieval import retrieval_query, retrieval_add, retrieval_load @@ -28,8 +29,9 @@ def collect_tools(selected_tools: list[str]): return filtered_tools -def dynamic_tools(selected_tools: list[str], metadata: dict = None): +def dynamic_tools(selected_tools: list[str], metadata: dict = None, tool_selection: list[BaseTool] = None): # Filter tools by name + tools = tool_selection if tool_selection else tools filtered_tools = [tool for tool in tools if tool.name in selected_tools] if len(filtered_tools) == 0: raise ValueError(f"No tools found by the names: {selected_tools.join(', ')}") diff --git a/backend/src/tools/a2a.py b/backend/src/tools/a2a.py index 65446d41..523dedae 100644 --- a/backend/src/tools/a2a.py +++ b/backend/src/tools/a2a.py @@ -1,22 +1,39 @@ -from utils.a2a import A2ACardResolver, a2a_builder +""" +A2A (Agent to Agent) tools +""" from langchain_core.tools import StructuredTool -def a2a_tool( - base_url: str, - thread_id: str = None, -): - card = A2ACardResolver(base_url=base_url).get_agent_card() - async def send_task(query: str): - return await a2a_builder( - base_url=base_url, - query=query, - thread_id=thread_id +from src.entities.a2a import A2AServer +from src.utils.a2a import A2ACardResolver, a2a_builder + +def create_a2a_tools( + thread_id: str, + a2a: dict[str, A2AServer], +) -> list[StructuredTool]: + tools = [] + if not a2a: + return tools + + # Loop through each entry in the a2a dictionary + for key, config in a2a.items(): + card = A2ACardResolver( + base_url=config.base_url, + agent_card_path=config.agent_card_path + ).get_agent_card() + + async def send_task(query: str, config=config): + return await a2a_builder( + base_url=config.base_url, + query=query, + thread_id=thread_id + ) + send_task.__doc__ = ( + f"Part of {key} A2A (Agent to Agent) server. " + f"Send query to remote agent: {card.name}. " + f"Agent Card: {card.model_dump_json()}" ) - send_task.__doc__ = ( - f"Send query to remote agent: {card.name}. " - f"Agent Card: {card.model_dump_json()}" - ) - tool = StructuredTool.from_function(coroutine=send_task) - tool.name = card.name.lower().replace(" ", "_") - tool.description = card.description - return tool + tool = StructuredTool.from_function(coroutine=send_task) + tool.name = card.name.lower().replace(" ", "_") + tools.append(tool) + + return tools \ No newline at end of file diff --git a/backend/src/tools/api.py b/backend/src/tools/api.py new file mode 100644 index 00000000..f25b0532 --- /dev/null +++ b/backend/src/tools/api.py @@ -0,0 +1,454 @@ +from typing import Union, Dict, Any, List, Optional, Type +import re +import json +import httpx +import yaml +import asyncio + +from pydantic import BaseModel, create_model +from pydantic.fields import Field, FieldInfo +from langchain_community.agent_toolkits.openapi.toolkit import RequestsToolkit +from langchain_community.utilities.requests import TextRequestsWrapper, GenericRequestsWrapper +from langchain_core.tools import StructuredTool +from langchain_community.tools.requests.tool import BaseRequestsTool + +from src.utils.logger import logger + +def _get_schema(response_json: Union[dict, list]) -> dict: + if isinstance(response_json, list): + response_json = response_json[0] if response_json else {} + return {key: type(value).__name__ for key, value in response_json.items()} + +def _add_endpoint_to_spec( + base_url: str, + endpoint: str, + common_query_parameters: List[Dict[str, Any]], + openapi_spec: Dict[str, Any] +) -> None: + response = httpx.get(base_url + endpoint) + if response.status_code == 200: + schema = _get_schema(response.json()) + openapi_spec["paths"][endpoint] = { + "get": { + "summary": f"Get {endpoint[1:]}", + "parameters": common_query_parameters, + "responses": { + "200": { + "description": "Successful response", + "content": { + "application/json": { + "schema": {"type": "object", "properties": schema} + } + }, + } + }, + } + } + +def get_api_spec( + title: str, + version: str, + base_url: str, + endpoints: List[str], + common_query_parameters: List[Dict[str, Any]], + description: str, + paths: Dict[str, Any] = {}, +) -> str: + openapi_spec: Dict[str, Any] = { + "openapi": "3.0.0", + "info": {"title": title, "version": version, "description": description}, + "servers": [{"url": base_url}], + "paths": paths, + } + # Iterate over the endpoints to construct the paths + for endpoint in endpoints: + _add_endpoint_to_spec(base_url, endpoint, common_query_parameters, openapi_spec) + return yaml.dump(openapi_spec, sort_keys=False) + +def openapi_from_url(url: str) -> str: + try: + response = httpx.get(url) + if response.status_code == 200: + return yaml.dump(response.json(), sort_keys=False) + else: + raise Exception(f"Failed to get OpenAPI spec from {url}") + except Exception as e: + raise Exception(f"Failed to get OpenAPI spec from {url}: {e}") + +################################################################################ +### Example Usage +################################################################################ +api_spec = get_api_spec( + title="JSONPlaceholder API", + version="1.0.0", + base_url="https://jsonplaceholder.typicode.com", + endpoints=["/posts", "/comments"], + common_query_parameters=[], + description="JSONPlaceholder API", +) +toolkit = RequestsToolkit( + requests_wrapper=TextRequestsWrapper(headers={}), + allow_dangerous_requests=True, +) +tools = toolkit.get_tools() +system_message = """You have access to an API to help answer user queries. +Here is documentation on the API: +{api_spec} +""".format(api_spec=api_spec) + +from langchain_community.utilities.requests import GenericRequestsWrapper +from langchain_core.tools import BaseTool + +def get_base_tool( + headers: Dict[str, Any] +) -> BaseTool: + requests_wrapper = GenericRequestsWrapper(headers=headers) + base_tool = BaseRequestsTool(requests_wrapper=requests_wrapper) + return base_tool + +def get_wrapper( + headers: Dict[str, Any] +) -> GenericRequestsWrapper: + requests_wrapper = GenericRequestsWrapper(headers=headers) + return requests_wrapper + +def construct_api_tool( + name: str, + description: str, + method: str, + headers: Dict[str, Any], + metadata: Dict[str, Any] = {}, + tags: List[str] = [], + verbose: bool = False, +) -> StructuredTool: + requests_wrapper = GenericRequestsWrapper(headers=headers) + if method == "GET": + async def api_request(url: str): + return await requests_wrapper.aget(url) + elif method == "POST": + async def api_request(url: str, data: Dict[str, Any]): + return await requests_wrapper.apost(url, data=data) + elif method == "PUT": + async def api_request(url: str, data: Dict[str, Any]): + return await requests_wrapper.aput(url, data=data) + elif method == "PATCH": + async def api_request(url: str, data: Dict[str, Any]): + return await requests_wrapper.apatch(url, data=data) + elif method == "DELETE": + async def api_request(url: str): + return await requests_wrapper.adelete(url) + + api_request.__doc__ = description + tool = StructuredTool.from_function(coroutine=api_request) + tool.name = re.sub(r'[^a-zA-Z0-9]', '_', name.lower()) + tool.description = description + tool.metadata = metadata + tool.tags = tags + tool.verbose = verbose + return tool + +################################################################################ +### API from OpenAPI spec +################################################################################ +async def fetch_openapi_spec( + url: str, + headers: Optional[Dict[str, Any]] = None, +) -> Dict[str, Any]: + """ + Fetches the OpenAPI specification JSON from a given URL asynchronously. + """ + wrapper = GenericRequestsWrapper(headers=headers or {}) + raw = await wrapper.aget(url) + return json.loads(raw) + + +def make_api_call_func( + method: str, + url: str, + headers: Dict[str, Any], + description: str, + data: Optional[Dict[str, Any]] = None, + path_params: Optional[Dict[str, Any]] = None, +) -> Any: + """ + Factory that creates an async function for the given HTTP method and URL. + + Args: + method: HTTP method (GET, POST, PUT, DELETE) + url: Base URL for the API endpoint + headers: HTTP headers to include in the request + description: Description of what the API call does + data: Optional request body data for POST/PUT requests + path_params: Optional dictionary of path parameters to format into the URL + """ + async def api_call(): + wrapper = GenericRequestsWrapper(headers=headers) + # Format URL with path parameters if provided + formatted_url = url.format(**(path_params or {})) + + if method == "GET": + return await wrapper.aget(formatted_url) + elif method == "POST": + return await wrapper.apost(formatted_url, data=data or {}) + elif method == "PUT": + return await wrapper.aput(formatted_url, data=data or {}) + elif method == "PATCH": + return await wrapper.apatch(formatted_url, data=data or {}) + elif method == "DELETE": + return await wrapper.adelete(formatted_url) + else: + raise ValueError(f"Unsupported HTTP method: {method}") + + api_call.__doc__ = description + return api_call + + +def resolve_ref(spec: dict, ref: str) -> dict: + """ + Resolves a $ref string (e.g., '#/components/schemas/JobCreate') to the actual object in the OpenAPI spec. + + Args: + spec: The full OpenAPI spec as a dictionary. + ref: The $ref string to resolve. + + Returns: + The resolved object (e.g., the schema dict). + """ + if not ref.startswith("#/"): + raise ValueError("Only local refs are supported") + parts = ref.lstrip("#/").split("/") + obj = spec + for part in parts: + obj = obj[part] + return obj + +def resolve_ref_recursive(spec: dict, obj: dict) -> dict: + """ + Recursively resolves all $ref values in a schema object using the OpenAPI spec. + + Args: + spec: The full OpenAPI spec as a dictionary. + obj: The schema object (may contain $ref). + + Returns: + The schema object with all $ref values resolved. + """ + if isinstance(obj, dict): + if "$ref" in obj: + # Resolve the reference and recurse + resolved = resolve_ref(spec, obj["$ref"]) + return resolve_ref_recursive(spec, resolved) + else: + # Recurse into all dict values + return {k: resolve_ref_recursive(spec, v) for k, v in obj.items()} + elif isinstance(obj, list): + # Recurse into all list items + return [resolve_ref_recursive(spec, item) for item in obj] + else: + # Base case: return as is + return obj + + +def get_field_type(field_type: str) -> Any: + """ + Map field type strings to actual Python types. + """ + type_mapping = { + "string": str, + "integer": int, + "number": float, + "boolean": bool, + "object": Dict[str, Any], + "array": list + } + return type_mapping.get(field_type, Any) + +def create_schema(model_name: str, fields_json: Dict[str, Any]) -> Type[BaseModel]: + """ + Create a Pydantic model dynamically from a JSON object. + + :param model_name: The name of the model. + :param fields_json: A dictionary representing the fields from a JSON object. + :return: A dynamically created Pydantic model. + """ + + if fields_json.get("type") == "string": + fields = {model_name: (str, Field(description=fields_json.get("description", "")))} + return create_model(model_name, **fields) + elif fields_json.get("type") == "integer": + fields = {model_name: (int, Field(description=fields_json.get("description", "")))} + return create_model(model_name, **fields) + elif fields_json.get("type") == "number": + fields = {model_name: (float, Field(description=fields_json.get("description", "")))} + return create_model(model_name, **fields) + elif fields_json.get("type") == "boolean": + fields = {model_name: (bool, Field(description=fields_json.get("description", "")))} + return create_model(model_name, **fields) + elif fields_json.get("type") == "array": + fields = {model_name: (list, Field(description=fields_json.get("description", "")))} + return create_model(model_name, **fields) + elif fields_json.get("type") == "object": + fields = {model_name: (dict, Field(description=fields_json.get("description", "")))} + return create_model(model_name, **fields) + + fields = {} + for field_name, field_info in fields_json.items(): + if field_info.get("type") == "object" and "properties" in field_info: + field_type = create_schema(field_info.get('title'), field_info.get("properties")) + else: + field_type = get_field_type(field_info.get("type", "")) + + field_params = {"description": field_info.get("description", "")} + if field_info.get("required", False): + field_params["default"] = field_info.get('default', None) or ... + else: + field_params["default"] = field_info.get("default", None) + fields[field_name] = (field_type, Field(**field_params)) + + return create_model(model_name, **fields) + +def generate_tools_from_openapi_spec( + openapi_url: str, + headers: Optional[Dict[str, Any]] = None, + verbose: bool = False, +) -> List[StructuredTool]: + """ + Generates a list of StructuredTool instances for each endpoint defined in an OpenAPI spec. + + Args: + openapi_url: URL pointing to the OpenAPI JSON spec. + headers: Optional default headers for all requests. + verbose: Whether to set verbose=True on each tool. + + Returns: + A list of StructuredTool objects, one per (path, method) in the spec. + """ + # Load spec asynchronously + spec = asyncio.get_event_loop().run_until_complete( + fetch_openapi_spec(openapi_url, headers=headers) + ) + + # Determine base server URL + server_url = spec.get("servers", [{}])[0].get("url", "") or openapi_url.replace("/openapi.json", "") + + tools: List[StructuredTool] = [] + for path, operations in spec.get("paths", {}).items(): + for http_method, operation in operations.items(): + method = http_method.upper() + args_schema = get_args_schema(spec, operation) + name = re.sub(r'[^a-zA-Z0-9_]', '_', operation.get("summary").lower()) + + # Use summary or description from spec + description = ( + operation.get("summary") + or operation.get("description", f"{method} {path}") + ) + + # Construct full URL for this endpoint + full_url = server_url.rstrip("/") + path + + # Create the coroutine function for this endpoint + api_func = make_api_call_func( + method, + full_url, + headers or {}, + description, + ) + + # Build the tool + tool = StructuredTool.from_function(coroutine=api_func, args_schema=args_schema) + tool.name = name + tool.description = description + tool.tags = operation.get("tags", []) + tool.metadata = operation.get("x-metadata", {}) + tool.verbose = verbose + + tools.append(tool) + + return tools + +# helper to turn a FieldInfo into a (type, default-or-Field) tuple +def _to_field_def(fi: FieldInfo): + default = fi.get_default() + return fi.annotation, Field(default=default, description=fi.description) + +def merge_models(summary: str, reqBody=None, pathParams=None, queryParams=None): + all_defs = {} + if pathParams: + for name, fi in pathParams.model_fields.items(): + # Create a new Field with metadata instead of modifying the FieldInfo + field_annotation, field_obj = _to_field_def(fi) + # Add metadata to the new Field object + field_with_metadata = Field( + default=field_obj.default, + description=field_obj.description, + metadata={'in': 'path'} + ) + all_defs[name] = (field_annotation, field_with_metadata) + if queryParams: + for name, fi in queryParams.model_fields.items(): + field_annotation, field_obj = _to_field_def(fi) + field_with_metadata = Field( + default=field_obj.default, + description=field_obj.description, + metadata={'in': 'query'} + ) + all_defs[name] = (field_annotation, field_with_metadata) + if reqBody: + for name, fi in reqBody.model_fields.items(): + field_annotation, field_obj = _to_field_def(fi) + field_with_metadata = Field( + default=field_obj.default, + description=field_obj.description, + metadata={'in': 'body'} + ) + all_defs[name] = (field_annotation, field_with_metadata) + if all_defs: + model_name = summary.replace(" ", "") + # unpack all_defs into create_model + return create_model(model_name, **all_defs) + else: + return None + +def get_args_schema(spec: Dict[str, Any], operation: Dict[str, Any]) -> Dict[str, Any]: + try: + reqBody = None + pathParams = None + queryParams = None + if operation.get('requestBody'): + ref = operation.get('requestBody').get('content').get('application/json').get('schema').get('$ref') + if ref: + schema = resolve_ref(spec, ref) + fully_resolved_schema = resolve_ref_recursive(spec, schema) + reqBody = create_schema(fully_resolved_schema.get('title'), fully_resolved_schema.get('properties')) + + if operation.get('parameters'): + for param in operation.get('parameters'): + if param.get('in') == 'query': + queryParams = create_schema(param.get('name'), param.get('schema')) + elif param.get('in') == 'path': + pathParams = create_schema(param.get('name'), param.get('schema')) + + merged = merge_models(operation.get('summary'), reqBody, pathParams, queryParams) + return merged + except Exception as e: + logger.exception(f"Failed to get args schema for {operation.get('summary')}: {e}") + raise Exception(f"Failed to get args schema for {operation.get('summary')}: {e}") + +if __name__ == "__main__": + tool = construct_api_tool( + name="Get a post", + description="Get a post", + method="GET", + headers={}, + verbose=True + ) + print(tool) + + # Example usage + tools = generate_tools_from_openapi_spec( + "http://localhost:8050/openapi.json" + ) + for t in tools: + print(t.name, "-", t.description) + print(t.args_schema) \ No newline at end of file diff --git a/backend/src/tools/sql.py b/backend/src/tools/sql.py index 57fd08dd..15a9979b 100644 --- a/backend/src/tools/sql.py +++ b/backend/src/tools/sql.py @@ -10,6 +10,7 @@ from src.utils.logger import logger from src.utils.llm import LLMWrapper from src.constants.llm import ModelName + @tool def sql_query_read(question: str): """Execute a read-only query against a PostgreSQL database based on a natural language question. diff --git a/backend/src/utils/agent.py b/backend/src/utils/agent.py index 2faef2ad..dcc37e9c 100644 --- a/backend/src/utils/agent.py +++ b/backend/src/utils/agent.py @@ -6,9 +6,11 @@ from langgraph.checkpoint.postgres import PostgresSaver from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver from langgraph.prebuilt import create_react_agent -from langchain_core.tools import StructuredTool, tool from psycopg.connection_async import AsyncConnection +from pydantic import BaseModel +from src.repos.tool_repo import ToolRepo +from src.tools.a2a import create_a2a_tools from src.repos.thread_repo import ThreadRepo from src.entities.a2a import A2AServer from src.services.mcp import McpService @@ -20,13 +22,9 @@ from src.entities import Answer, Thread from src.utils.logger import logger from src.flows.chatbot import chatbot_builder -from langchain.chat_models import init_chat_model from src.services.db import create_async_pool, get_checkpoint_db -from pydantic import BaseModel from src.utils.format import get_base64_image -import sys - -from src.utils.a2a import A2ACardResolver, A2AClient, a2a_builder +from src.tools.api import generate_tools_from_openapi_spec class StreamContext(BaseModel): @@ -35,7 +33,7 @@ class StreamContext(BaseModel): event: str = '' class Agent: - def __init__(self, config: dict, user_repo: UserRepo = None): + def __init__(self, config: dict, user_repo: UserRepo = None, tool_repo: ToolRepo = None): self.connection_kwargs = { "autocommit": True, "prepare_threshold": 0, @@ -47,6 +45,7 @@ def __init__(self, config: dict, user_repo: UserRepo = None): self.graph = None self.pool: AsyncConnection = None # Don't create pool in constructor self.user_repo = user_repo + self.tool_repo = tool_repo self.model_name = config.get("model_name", None) self.llm: LLMWrapper = None self.tools = config.get("tools", []) @@ -333,9 +332,11 @@ async def abuilder( a2a: dict[str, A2AServer] = None, model_name: str = ModelName.ANTHROPIC_CLAUDE_3_7_SONNET_LATEST, checkpointer: AsyncPostgresSaver = None, - debug: bool = True if APP_LOG_LEVEL == "DEBUG" else False + debug: bool = True if APP_LOG_LEVEL == "DEBUG" else False, + name: str = "EnsoAgent" ): - self.tools = [] if len(tools) == 0 else dynamic_tools(selected_tools=tools, metadata={'user_repo': self.user_repo}) + tool_selection = await self.tool_repo.list_tools() + self.tools = [] if len(tools) == 0 else dynamic_tools(selected_tools=tools, metadata={'user_repo': self.user_repo}, tool_selection=tool_selection) self.llm = LLMWrapper(model_name=model_name, tools=self.tools, user_repo=self.user_repo) self.checkpointer = checkpointer system = self.config.get('configurable').get("system", None) @@ -343,33 +344,16 @@ async def abuilder( if mcp and len(mcp.keys()) > 0: await self.agent_session.setup(mcp) self.tools.extend(self.agent_session.tools()) - + + # Get A2A tools if provided if a2a and len(a2a.keys()) > 0: - # Check if a2a is a dictionary with multiple entries - if isinstance(a2a, dict): - # Loop through each entry in the a2a dictionary - for key, config in a2a.items(): - - card = A2ACardResolver( - base_url=config.base_url, - agent_card_path=config.agent_card_path - ).get_agent_card() - - async def send_task(query: str): - return await a2a_builder( - base_url=config.base_url, - query=query, - thread_id=self.thread_id - ) - send_task.__doc__ = ( - f"Send query to remote agent: {card.name}. " - f"Agent Card: {card.model_dump_json()}" - ) - tool = StructuredTool.from_function(coroutine=send_task) - tool.name = card.name.lower().replace(" ", "_") - # tool.name = key + "_" + card.name.lower().replace(" ", "_") - # tool.description = card.description - self.tools.append(tool) + a2a_tools = create_a2a_tools(thread_id=self.thread_id, a2a=a2a) + self.tools.extend(a2a_tools) + + # tools_from_spec = generate_tools_from_openapi_spec( + # "http://localhost:8050/openapi.json" + # ) + # self.tools.extend(tools_from_spec) if self.tools: graph = create_react_agent(self.llm, prompt=system, tools=self.tools, checkpointer=self.checkpointer) @@ -378,9 +362,9 @@ async def send_task(query: str): graph = builder.compile(checkpointer=self.checkpointer) if debug: - graph.debug = True + graph.debug = debug self.graph = graph - self.graph.name = "EnsoAgent" + self.graph.name = name return graph async def aprocess( diff --git a/backend/src/utils/auth.py b/backend/src/utils/auth.py index c97a78e1..1cff5955 100644 --- a/backend/src/utils/auth.py +++ b/backend/src/utils/auth.py @@ -79,7 +79,8 @@ async def verify_credentials( request.state.user_repo = UserRepo(db, request.state.user.id) return user.protected() - except JWTError: + except JWTError as e: + logger.error(f"JWTError: {e}") raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="Could not validate credentials", diff --git a/docker-compose.yml b/docker-compose.yml index 3cef65c5..471df172 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -5,6 +5,7 @@ services: postgres: image: pgvector/pgvector:pg16 container_name: postgres + restart: always environment: POSTGRES_USER: admin POSTGRES_PASSWORD: test1234 @@ -18,6 +19,7 @@ services: pgadmin: image: dpage/pgadmin4 container_name: pgadmin + restart: always environment: PGADMIN_DEFAULT_EMAIL: admin@example.com PGADMIN_DEFAULT_PASSWORD: test1234