diff --git a/.github/workflows/a2a-tck-test.yml b/.github/workflows/a2a-tck-test.yml new file mode 100644 index 0000000000..6ebd15ad0c --- /dev/null +++ b/.github/workflows/a2a-tck-test.yml @@ -0,0 +1,112 @@ +name: A2A TCK Test + +on: + push: + paths: + - 'a2a/**' + pull_request: + paths: + - 'a2a/**' + workflow_dispatch: + +concurrency: + group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }} + cancel-in-progress: ${{ github.ref != 'refs/heads/main' && github.ref != 'refs/heads/develop' }} + +env: + JAVA_VERSION: 17 + JAVA_DISTRIBUTION: 'corretto' + +jobs: + a2a-tck-test: + runs-on: ubuntu-latest + timeout-minutes: 10 + + permissions: + contents: read + + steps: + - name: Configure Git + run: | + git config --global core.autocrlf input + + - uses: actions/checkout@v5 + + - name: Set up JDK ${{ env.JAVA_VERSION }} + uses: actions/setup-java@v5 + with: + java-version: ${{ env.JAVA_VERSION }} + distribution: ${{ env.JAVA_DISTRIBUTION }} + + - name: Setup Gradle + uses: gradle/actions/setup-gradle@v4 + + - name: Set up Python 3.12 + uses: actions/setup-python@v6 + with: + python-version: '3.12' + + - name: Install uv + uses: astral-sh/setup-uv@v6 + with: + version: "0.8.15" + enable-cache: true + + - name: Setup A2A TCK + working-directory: ./a2a/test-tck + run: ./setup_tck.sh + + - name: Build test server + run: ./gradlew :a2a:test-tck:a2a-test-server-tck:assemble + + - name: Start test server in background + working-directory: ./a2a/test-tck + run: | + ./run_sut.sh > server.log 2>&1 & + SERVER_PID=$! + echo "SERVER_PID=$SERVER_PID" >> $GITHUB_ENV + + # Wait for server to start (max 60 seconds) + timeout 60 bash -c ' + while ! 
grep -q "Responding at http://0.0.0.0:9999" server.log 2>/dev/null; do + sleep 1 + done + ' + + echo "Server started successfully" + + - name: Run TCK tests - Mandatory + working-directory: ./a2a/test-tck + run: ./run_tck.sh --sut-url http://localhost:9999/a2a --category mandatory --report + + - name: Run TCK tests - Capabilities + working-directory: ./a2a/test-tck + run: ./run_tck.sh --sut-url http://localhost:9999/a2a --category capabilities --report + + - name: Run TCK tests - Transport Equivalence + working-directory: ./a2a/test-tck + run: ./run_tck.sh --sut-url http://localhost:9999/a2a --category transport-equivalence --report + + - name: Run TCK tests - Quality + working-directory: ./a2a/test-tck + run: ./run_tck.sh --sut-url http://localhost:9999/a2a --category quality --report + + - name: Run TCK tests - Features + working-directory: ./a2a/test-tck + run: ./run_tck.sh --sut-url http://localhost:9999/a2a --category features --report + + - name: Stop test server + if: always() + run: | + if [ ! 
-z "$SERVER_PID" ]; then + kill $SERVER_PID || true + wait $SERVER_PID 2>/dev/null || true + fi + + - name: Upload TCK reports + if: always() + uses: actions/upload-artifact@v4 + with: + name: a2a-tck-reports + path: a2a/test-tck/a2a-tck/reports/*.html + if-no-files-found: warn diff --git a/.github/workflows/ollama-tests.yml b/.github/workflows/ollama-tests.yml index 5f9ea100b7..48d9a83cde 100644 --- a/.github/workflows/ollama-tests.yml +++ b/.github/workflows/ollama-tests.yml @@ -14,19 +14,74 @@ on: jobs: integration-tests: - + name: ${{ matrix.job-name }} runs-on: ${{ matrix.os }} permissions: contents: read strategy: matrix: - os: [ ubuntu-latest ] + include: + - job-name: "ollama-executor-tests" + test-group: "ai.koog.integration.tests.executor.OllamaExecutorIntegrationTest" + artifact-name: "ollama-executor-tests" + os: ubuntu-latest + - job-name: "ollama-agent-tests" + test-group: "ai.koog.integration.tests.agent.OllamaAgentIntegrationTest" + artifact-name: "ollama-agent-tests" + os: ubuntu-latest + - job-name: "ollama-simple-agent-tests" + test-group: "ai.koog.integration.tests.agent.OllamaSimpleAgentIntegrationTest" + artifact-name: "ollama-simple-agent-tests" + os: ubuntu-latest + fail-fast: false steps: - name: Configure Git + run: git config --global core.autocrlf input + + - name: Free up disk space run: | - git config --global core.autocrlf input + echo "=== Disk space before cleanup ===" + df -h + echo "=== Docker system info ===" + docker system df + + # Stop and remove all containers + docker ps -aq | xargs -r docker rm -f || true + + # Remove all Docker images, networks, and volumes + docker system prune -af --volumes + docker image prune -af + docker volume prune -af + docker network prune -f + + # Clean package caches + sudo apt-get clean + sudo apt-get autoclean + sudo apt-get autoremove -y + + # Remove large directories + sudo rm -rf /usr/share/dotnet + sudo rm -rf /opt/ghc + sudo rm -rf /usr/local/share/boost + sudo rm -rf /opt/az + sudo rm 
-rf /usr/share/swift + sudo rm -rf /usr/local/lib/android + sudo rm -rf /usr/local/.ghcup + sudo rm -rf /home/runner/.dotnet + sudo rm -rf /opt/hostedtoolcache + + # Clean npm and yarn caches + sudo rm -rf /home/runner/.npm + sudo rm -rf /home/runner/.yarn + + # Clean Gradle cache from previous runs + sudo rm -rf /home/runner/.gradle + + echo "=== Disk space after cleanup ===" + df -h + - uses: actions/checkout@v5 - name: Set up JDK 17 uses: actions/setup-java@v5 @@ -38,16 +93,76 @@ jobs: # See: https://github.com/gradle/actions/blob/main/setup-gradle/README.md - name: Setup Gradle uses: gradle/actions/setup-gradle@v4 + with: + cache-disabled: true + + - name: Check disk space before tests + run: | + echo "=== Disk space ===" + df -h + echo "=== Available space in GB ===" + df -BG / | tail -1 | awk '{print "Available: " $4}' + echo "=== Docker info ===" + docker system df + echo "=== Memory info ===" + free -h - name: JvmOllamaTest with Gradle Wrapper env: OLLAMA_IMAGE_URL: ${{ vars.OLLAMA_IMAGE_URL }} - run: ./gradlew jvmOllamaTest --no-parallel --continue + GRADLE_OPTS: "-Dorg.gradle.daemon=false -Xmx1g -XX:MaxMetaspaceSize=512m" + run: | + echo "=== Starting tests with available disk space ===" + df -h / | tail -1 + + ./gradlew jvmOllamaTest \ + --tests "${{ matrix.test-group }}" \ + --no-parallel \ + --no-daemon \ + --no-build-cache \ + --no-configuration-cache \ + --continue \ + --stacktrace + + echo "=== Test completed, checking disk space ===" + df -h / | tail -1 + timeout-minutes: 60 + + - name: Cleanup Docker resources + if: always() + run: | + echo "=== Stopping all containers ===" + docker ps -q | xargs -r docker stop || true + + echo "=== Removing all containers ===" + docker ps -aq | xargs -r docker rm -f || true + + echo "=== Removing Ollama images ===" + docker images | grep ollama | awk '{print $3}' | xargs -r docker rmi -f || true + + echo "=== Pruning system ===" + docker system prune -af --volumes + + echo "=== Final disk space ===" + df -h + + - 
name: Check disk space after tests + if: always() + run: | + echo "=== Final disk space check ===" + df -h + echo "=== Available space in GB ===" + df -BG / | tail -1 | awk '{print "Available: " $4}' + echo "=== Docker system info ===" + docker system df + echo "=== Largest directories in /home/runner ===" + sudo du -sh /home/runner/* 2>/dev/null | sort -hr | head -10 || true - name: Collect reports if: always() uses: actions/upload-artifact@v4 with: - name: reports-${{ matrix.os }} + name: reports-${{ matrix.os }}-${{ matrix.artifact-name }} path: | **/build/reports/ + retention-days: 7 diff --git a/.gitignore b/.gitignore index e24ef2425c..5053c8f2fe 100644 --- a/.gitignore +++ b/.gitignore @@ -10,3 +10,5 @@ local.properties docs/src/main/kotlin/*.kt **/.env .venv +.DS_Store +**/kotlin-js-store diff --git a/AGENT.md b/AGENT.md new file mode 100644 index 0000000000..6831b675ae --- /dev/null +++ b/AGENT.md @@ -0,0 +1,231 @@ +# Koog AI Agent Framework + +Koog is a Kotlin multiplatform framework for building AI agents with graph-based workflows. +It supports JVM and JS targets and integrates with multiple LLM providers +(OpenAI, Anthropic, Google, OpenRouter, Ollama) and Model Context Protocol (MCP). 
+ +## Project Structure + +The project follows a modular architecture with a clear separation of concerns: + +``` +koog/ +├── agents/ +│ ├── agents-core/ # Core abstractions (AIAgent, AIAgentStrategy, AIAgentEnvironment) +│ ├── agents-tools/ # Tool infrastructure (Tool, ToolRegistry, AIAgentTool) +│ ├── agents-features-*/ # Feature implementations (memory, tracing, event handling) +│ ├── agents-mcp/ # Model Context Protocol integration +│ └── agents-test/ # Testing utilities and framework +├── prompt-*/ # LLM interaction layer (executors, models, structured data) +├── embeddings-*/ # Vector embedding support +├── examples/ # Reference implementations and usage patterns +└── build.gradle.kts # Root build configuration +``` + +## Build & Commands + +### Development Commands + +```bash +# Full build including tests +./gradlew build + +# Build without tests +./gradlew assemble + +# Run all JVM tests +./gradlew jvmTest + +# Run all JS tests +./gradlew jsTest + +# Test specific module +./gradlew :agents:agents-core:jvmTest + +# Run specific test class +./gradlew jvmTest --tests "ai.koog.agents.test.SimpleAgentMockedTest" + +# Run specific test method +./gradlew jvmTest --tests "ai.koog.agents.test.SimpleAgentMockedTest.test AIAgent doesn't call tools by default" + +# Compile test classes only (for faster iteration) +./gradlew jvmTestClasses jsTestClasses +``` + +### Development Environment + +- **JDK**: 17+ required for JVM target +- **Build System**: Gradle with version catalogs for dependency management +- **Targets**: JVM, JavaScript (Kotlin Multiplatform), WASM +- **IDE**: IntelliJ IDEA recommended with Kotlin plugin + +## Code Style + +- Follow [Kotlin Coding Conventions](https://kotlinlang.org/docs/coding-conventions.html) +- Use four spaces for indentation (consistent across all files) +- Name test functions as `testXxx` (no backticks for readability) +- Use descriptive variable and function names +- Prefer functional programming patterns where appropriate +- Use 
type-safe builders and DSLs for configuration +- Document public APIs with KDoc comments +- NEVER suppress compiler warnings without a good reason + +## Architecture + +### Core Framework Components + +**AIAgent** — Main orchestrator that executes strategies in coroutine scopes, manages tools via ToolRegistry, +runs features through AIAgentPipeline, and handles LLM communication via PromptExecutor. + +**AIAgentStrategy** — Graph-based execution logic that defines workflows as subgraphs with start/finish nodes, +manages tool selection strategy, and handles termination/error reporting. + +**ToolRegistry** — Centralized, type-safe tool management using a builder pattern: `ToolRegistry { tool(MyTool()) }`. +Supports registry merging with `+` operator. + +**AIAgentFeature** — Extensible capabilities installed into AIAgentPipeline with configuration. +Features have unique storage keys and can intercept agent lifecycle events. + +### Module Organization + +1. **agents-core**: Core abstractions (`AIAgent`, `AIAgentStrategy`, `AIAgentEnvironment`) +2. **agents-tools**: Tool infrastructure (`Tool`, `ToolRegistry`, `AIAgentTool`) +3. **agents-features-***: Feature implementations (memory, tracing, event handling) +4. **agents-mcp**: Model Context Protocol integration +5. **prompt-***: LLM interaction layer (executors, models, structured data) +6. **embeddings-***: Vector embedding support +7. 
**examples**: Reference implementations and usage patterns + +### Key Architectural Patterns + +- **State Machine Graphs**: Agents execute as node graphs with typed edges +- **Feature Pipeline**: Extensible behavior via installable features with lifecycle hooks +- **Environment Abstraction**: Safe tool execution context preventing direct tool calls +- **Type Safety**: Generics ensure compile-time correctness for tool arguments/results +- **Builder Patterns**: Fluent APIs for configuration throughout the framework + +## Testing + +The framework provides comprehensive testing utilities in `agents-test` module: + +### LLM Response Mocking +```kotlin +val mockLLMApi = getMockExecutor(toolRegistry, eventHandler) { + mockLLMAnswer("Hello!") onRequestContains "Hello" + mockLLMToolCall(CreateTool, CreateTool.Args("solve")) onRequestEquals "Solve task" + mockLLMAnswer("Default response").asDefaultResponse +} +``` + +### Tool Behavior Mocking +```kotlin +// Simple return value +mockTool(PositiveToneTool) alwaysReturns "The text has a positive tone." + +// With additional actions +mockTool(NegativeToneTool) alwaysTells { + println("Tool called") + "The text has a negative tone." +} + +// Conditional responses +mockTool(SearchTool) returns SearchTool.Result("Found") onArgumentsMatching { + args.query.contains("important") +} +``` + +### Graph Structure Testing +```kotlin +AIAgent(...) { + withTesting() + + testGraph("test") { + val firstSubgraph = assertSubgraphByName("first") + val secondSubgraph = assertSubgraphByName("second") + + assertEdges { + startNode() alwaysGoesTo firstSubgraph + firstSubgraph alwaysGoesTo secondSubgraph + } + + verifySubgraph(firstSubgraph) { + val askLLM = assertNodeByName("callLLM") + assertNodes { + askLLM withInput "Hello" outputs Message.Assistant("Hello!") + } + } + } +} +``` + +For comprehensive testing examples, see `agents/agents-test/TESTING.md`. 
+ +## Security + +### API Key Management +- **NEVER** commit API keys or secrets to the repository +- Use environment variables for all sensitive configuration +- Store test API keys in a local environment only +- Required environment variables for integration tests: + - `ANTHROPIC_API_TEST_KEY` + - `OPEN_AI_API_TEST_KEY` + - `GEMINI_API_TEST_KEY` + - `OPEN_ROUTER_API_TEST_KEY` + - `OLLAMA_IMAGE_URL` + +### Tool Execution Safety +- Tools execute within controlled `AIAgentEnvironment` contexts +- Direct tool calls are prevented outside agent execution +- Use type-safe tool arguments to prevent injection attacks +- Validate all external inputs in tool implementations + +### Dependency Security +- Regularly update dependencies using Gradle version catalogs +- Use specific version ranges to avoid supply chain attacks +- Review dependencies for known vulnerabilities +- Follow the principle of least privilege in tool implementations + +## Configuration + +### Environment Setup +Set environment variables for integration testing (never commit API keys): +```bash +# Export in your shell or IDE run configuration +export ANTHROPIC_API_TEST_KEY=your_key_here +export OPEN_AI_API_TEST_KEY=your_key_here +export GEMINI_API_TEST_KEY=your_key_here +export OPEN_ROUTER_API_TEST_KEY=your_key_here +export OLLAMA_IMAGE_URL=http://localhost:11434 + +# Or add to ~/.bashrc, ~/.zshrc, or IDE environment variables +``` + +### Gradle Configuration +- Uses version catalogs (`gradle/libs.versions.toml`) for dependency management +- Multiplatform configuration in `build.gradle.kts` +- Test configuration supports both JVM and JS targets + +### Development Environment Requirements +- **JDK**: 17+ (OpenJDK recommended) +- **IDE**: IntelliJ IDEA with Kotlin Multiplatform plugin +- **Optional**: Docker for Ollama local testing + +## Development Workflow + +### Branch Strategy +- **develop**: All development (features and bug fixes) +- **main**: Released versions only +- Base all PRs against 
`develop` branch +- Use descriptive branch names: `feature/agent-memory`, `fix/tool-registry-bug` + +### Code Quality +- **ALWAYS** run `./gradlew build` before submitting PRs +- Ensure all tests pass on JVM, JS, WASM targets +- Follow established patterns in existing code +- Add tests for new functionality +- Update documentation for API changes + +### Commit Guidelines +- Use conventional commit format: `feat:`, `fix:`, `docs:`, `test:` +- Include issue references where applicable +- Keep commits focused and atomic \ No newline at end of file diff --git a/CHANGELOG.md b/CHANGELOG.md index 3c36a97231..0e5fb59525 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,121 @@ +# 0.5.0 + +> Published 2 Oct 2025 + +## Major Features + +- **Full Agent-to-Agent (A2A) Protocol Support**: + - **Multiplatform Kotlin A2A SDK**: Including server and client with JSON-RPC HTTP support. + - **A2A Agent Feature**: seamlessly integrate A2A in your Koog agents +- **Non-Graph API for Strategies**: Introduced non-graph API for creating AI Agent strategies as Kotlin extension + functions with most of Koog's features supported (#560) +- **Agent Persistence and Checkpointing**: + - **Roll back Tool Side-Effects**: Add `RollbackToolRegistry` in the `Persistence` feature in order to roll back + tool calls with side effects when checkpointing. 
+ - **State-Machine Persistence / Message History Switch**: Support switching between full state-machine persistence + and message history persistence (#856) +- **Tool API Improvements**: + - Make `ToolDescriptor` auto-generated for class-based tools (#791) + - Get rid of `ToolArgs` and `ToolResult` limitations for `Tool<*, *>` class (#791) +- **`subgraphWithTask` Simplification**: Get rid of required `finishTool` and support tools as functions in + `subgraphWithTask`, deduce final step automatically by data class (#791) +- **`AIAgentService` Introduced**: Make `AIAgent` state-manageable and single-run explicitly, introduce `AIAgentService` + to manage multiple uniform running agents. +- **New components**: + - Add LLM as a Judge component (#866) + - Tool Calling loop with Structured Output strategy (#829) + +## Improvements + +- Make Koog-based tools exportable via MCP server (KG-388) +- Add `additionalProperties` to LLM clients in order to support custom LLM configurations (#836) +- Allow adjusting context window sizes for Ollama dynamically (#883) +- Refactor streaming api to support tool calls (#747) +- Provide an ability to collect and send a list of nodes and edges out of `AIAgentStrategy` to the client when running + an agent (KG-160) +- Add `excludedProperties` to inline `createJsonStructure` too, update KDocs (#826) +- Refactor binary attachment handling and introduce Base64 serializer (#838) +- In `JsonStructuredData.defaultJson` instance rename class discriminator + from `#type` to `kind` to align with common practices (#772, KG-384) +- Make standard json generator default when creating `JsonStructuredData` + (it was basic before) (#772, KG-384) +- Add default audio configuration and modalities (#817) +- Add `GptAudio` model in OpenAI client (#818) +- Allow re-running of finished agents that have `Persistence` feature installed (#828, KG-193) +- Allow idiomatic node transformations with `.transform { ...}` lambda function (#684) +- Add ability to filter 
messages for every agent feature (KG-376) +- Add support for trace-level attributes in Langfuse integration (#860, KG-427) +- Keep all system messages when compressing message history of the agent (#857) +- Add support for Anthropic's Sonnet 4.5 model in Anthropic/Bedrock providers (#885) +- Refactored LLM client auto-configuration in Spring Boot integration to modular provider-specific classes with + improved validation and security (#886) +- Add LLM Streaming agent events (KG-148) + +## Bug Fixes + +- Fix broken Anthropic models support via Amazon Bedrock (#789) +- Make `AIAgentStorageKey` in agent storage actually unique by removing `data` modifier (#825) +- Fix rerun for agents with Persistence (#828, KG-193) +- Update mcp version to `0.7.2` with fix for Android target (#835) +- Do not include an empty system message in Anthropic request (#887, KG-317) +- Use `maxTokens` from params in Google models (#734) +- Fix finishReason nullability (#771) + +## Deprecations + +- Rename agent interceptors in `EventHandler` and related feature events (KG-376) +- Deprecate concurrent unsafe `AIAgent.asTool` in favor of `AIAgentService.createAgentTool` (#873) +- Rename `Persistency` to `Persistence` everywhere (#896) +- Add `agentId` argument to all `Persistence` methods instead of `persistencyId` class field (#904) + +## Examples + +- Add a basic code-agent example (#808, KG-227) +- Add iOS and Web targets for demo-compose-app (#779, #780) + +# 0.4.2 + +> Published 15 Sep 2025 + +## Improvements + +- Make agents‑mcp support KMP targets to run across more platforms (#756). +- Add LLM client retry support to Spring Boot auto‑configuration to improve resilience on transient failures (#748). +- Add Claude Opus 4.1 model support to Anthropic client to unlock latest reasoning capabilities (#730). +- Add Gemini 2.5 Flash Lite model support to Google client to enable lower‑latency, cost‑efficient generations (#769). 
+- Add Java‑compatible non‑streaming Prompt Executor so Java apps can call Koog without + coroutines ([KG-312](https://youtrack.jetbrains.com/issue/KG-312), #715). +- Support excluding properties in JSON Schema generation to fine‑tune structured outputs (#638). +- Update AWS SDK to latest compatible version for Bedrock integrations. +- Introduce Postgres persistence provider to store agent state and artifacts (#705). +- Update Kotlin to 2.2.10 in dependency configuration for improved performance and language features (#764). +- Refactor executeStreaming to remove suspend for simpler interop and better call sites (#720). +- Add Java‑compatible prompt executor (non‑streaming) wiring and polish across + modules ([KG-312](https://youtrack.jetbrains.com/issue/KG-312), #715). +- Decouple FileSystemEntry from FileSystemProvider to simplify testing and enable alternative providers (#664). + +## Bug Fixes + +- Add missing tool calling support for Bedrock Nova models so agents can invoke functions when using + Nova ([KG-239](https://youtrack.jetbrains.com/issue/KG-239)). +- Add Android target support and migrate Android app to Kotlin Multiplatform to widen KMP + coverage ([KG-315](https://youtrack.jetbrains.com/issue/KG-315), #728, #767). +- Add Spring Boot Java example to jump‑start integration (#739). +- Add Java Spring auto‑config fixes: correct property binding and make Koog starter work out of the box (#698). +- Fix split package issues in OpenAI LLM clients to avoid classpath/load + errors ([KG-305](https://youtrack.jetbrains.com/issue/KG-305), #694). +- Ensure Anthropic tool schemas include the required "type" field in serialized request bodies to prevent validation + errors during tool calling (#582). +- Fix AbstractOpenAILLMClient to correctly handle plain‑text responses in capabilities flow; add integration tests to + prevent regressions (#564). +- Fix GraalVM native image build failure so projects can compile native binaries again (#774). 
+- Fix usages in OpenAI‑based data model to align with recent API changes (#688). +- Fix SpringBootStarters initialization and improve `RetryingClient` (#894) + +## CI and Build + +- Nightly build configuration and dependency submission workflow added (#695, #737). + # 0.4.1 > Published 28 Aug 2025 @@ -13,42 +131,58 @@ Fixed iOS target publication ## Major Features - **Integration with Observability Tools**: - - **Langfuse Integration**: Span adapters for Langfuse client, including open telemetry and graph visualisation ([KG-217](https://youtrack.jetbrains.com/issue/KG-217), [KG-223](https://youtrack.jetbrains.com/issue/KG-223)) - - **W&B Weave Integration**: Span adapters for W&B Weave open telemetry and observability ([KG-217](https://youtrack.jetbrains.com/issue/KG-217), [KG-218](https://youtrack.jetbrains.com/issue/KG-218)) -- **Ktor Integration**: First-class Ktor support via the "Koog" Ktor plugin to register and run agents in Ktor applications (#422). -- **iOS Target Support**: Multiplatform expanded with native iOS targets, enabling agents to run on Apple platforms (#512). -- **Upgraded Structured Output**: Refactored structured output API to be more flexible and add built-in/native provider support for OpenAI and Google, reducing prompt boilerplate and improving validation (#443). 
-- **GPT5 and Custom LLM Parameters Support**: Now GPT5 is available together with custom additional LLM parameters for OpenAI-compatible clients (#631, #517) + - **Langfuse Integration**: Span adapters for Langfuse client, including open telemetry and graph + visualisation ([KG-217](https://youtrack.jetbrains.com/issue/KG-217), [KG-223](https://youtrack.jetbrains.com/issue/KG-223)) + - **W&B Weave Integration**: Span adapters for W&B Weave open telemetry and + observability ([KG-217](https://youtrack.jetbrains.com/issue/KG-217), [KG-218](https://youtrack.jetbrains.com/issue/KG-218)) +- **Ktor Integration**: First-class Ktor support via the "Koog" Ktor plugin to register and run agents in Ktor + applications (#422). +- **iOS Target Support**: Multiplatform expanded with native iOS targets, enabling agents to run on Apple platforms ( + #512). +- **Upgraded Structured Output**: Refactored structured output API to be more flexible and add built-in/native provider + support for OpenAI and Google, reducing prompt boilerplate and improving validation (#443). +- **GPT5 and Custom LLM Parameters Support**: Now GPT5 is available together with custom additional LLM parameters for + OpenAI-compatible clients (#631, #517) - **Resilience and Retries**: - - **Retryable LLM Clients**: Introduce retry logic for LLM clients with sensible defaults to reduce transient failures (#592) - - **Retry Anything with LLM Feedback**: Add a feedback mechanism to the retry component (`subgraphWithRetry`) to observe and tune behavior (#459). + - **Retryable LLM Clients**: Introduce retry logic for LLM clients with sensible defaults to reduce transient + failures (#592) + - **Retry Anything with LLM Feedback**: Add a feedback mechanism to the retry component (`subgraphWithRetry`) to + observe and tune behavior (#459). 
## Improvements - **OpenTelemetry and Observability**: - - Finish reason and unified attributes for inference/tool/message spans and events; extract event body fields to attributes for better querying ([KG-218](https://youtrack.jetbrains.com/issue/KG-218)). - - Mask sensitive data in events/attributes and introduce a “hidden-by-default” string type to keep secrets safe in logs ([KG-259](https://youtrack.jetbrains.com/issue/KG-259)). - - Include all messages into the inference span and add an index for ChoiceEvent to simplify analysis ([KG-172](https://youtrack.jetbrains.com/issue/KG-172)). - - Add tool arguments to `gen_ai.choice` and `gen_ai.assistant.message` events (#462). - - Allow setting a custom OpenTelemetry SDK instance in Koog ([KG-169](https://youtrack.jetbrains.com/issue/KG-169)). + - Finish reason and unified attributes for inference/tool/message spans and events; extract event body fields to + attributes for better querying ([KG-218](https://youtrack.jetbrains.com/issue/KG-218)). + - Mask sensitive data in events/attributes and introduce a “hidden-by-default” string type to keep secrets safe in + logs ([KG-259](https://youtrack.jetbrains.com/issue/KG-259)). + - Include all messages into the inference span and add an index for ChoiceEvent to simplify + analysis ([KG-172](https://youtrack.jetbrains.com/issue/KG-172)). + - Add tool arguments to `gen_ai.choice` and `gen_ai.assistant.message` events (#462). + - Allow setting a custom OpenTelemetry SDK instance in Koog ([KG-169](https://youtrack.jetbrains.com/issue/KG-169)). - **LLM and Providers**: - - Support Google’s “thinking” mode in generation config to improve reasoning quality (#414). - - Add responses API support for OpenAI (#645) - - AWS Bedrock: support Inference Profiles for simpler, consistent configuration (#506) and accept `AWS_SESSION_TOKEN` (#456). - - Add `maxTokens` as prompt parameters for finer control over generation length (#579). 
- - Add `contextLength` and `maxOutputTokens` to `LLModel` (#438, [KG-134](https://youtrack.jetbrains.com/issue/KG-134)) + - Support Google’s “thinking” mode in generation config to improve reasoning quality (#414). + - Add responses API support for OpenAI (#645) + - AWS Bedrock: support Inference Profiles for simpler, consistent configuration (#506) and accept + `AWS_SESSION_TOKEN` (#456). + - Add `maxTokens` as prompt parameters for finer control over generation length (#579). + - Add `contextLength` and `maxOutputTokens` to `LLModel` ( + #438, [KG-134](https://youtrack.jetbrains.com/issue/KG-134)) - **Agent Engine**: - - Add AIAgentPipeline interceptors to uniformly handle node errors; propagate `NodeExecutionError` across features ([KG-170](https://youtrack.jetbrains.com/issue/KG-170)). - - Include finish node processing in the pipeline to ensure finalizers run reliably (#598). + - Add AIAgentPipeline interceptors to uniformly handle node errors; propagate `NodeExecutionError` across + features ([KG-170](https://youtrack.jetbrains.com/issue/KG-170)). + - Include finish node processing in the pipeline to ensure finalizers run reliably (#598). - **File Tools and RAG**: - - Reworked FileSystemProvider with API cleanups and better ergonomics; moved blocking/suspendable operations to `Dispatchers.IO` for improved performance and responsiveness (#557, “Move suspendable operations to Dispatchers.IO”). - - Introduce `filterByRoot` helpers and allow custom path filters in `FilteredFileSystemProvider` for safer agent sandboxes (#494, #508). + - Reworked FileSystemProvider with API cleanups and better ergonomics; moved blocking/suspendable operations to + `Dispatchers.IO` for improved performance and responsiveness (#557, “Move suspendable operations to + Dispatchers.IO”). + - Introduce `filterByRoot` helpers and allow custom path filters in `FilteredFileSystemProvider` for safer agent + sandboxes (#494, #508). 
- Rename `PathFilter` to `TraversalFilter` and make its methods suspendable to support async checks. - Rename `fromAbsoluteString` to `fromAbsolutePathString` for clarity (#567). - Add `ReadFileTool` for reading local file contents where appropriate (#628). - Update kotlin-mcp dependency to v0.6.0 (#523) - ## Bug Fixes - Make `parts` field nullable in Google responses to handle missing content from Gemini models (#652). @@ -56,7 +190,8 @@ Fixed iOS target publication - Fix function calling for `gemini-2.5-flash` models to correctly route tool invocations (#586). - Restore OpenAI `responseFormat` option support in requests (#643). - Correct `o4-mini` vs `gpt-4o-mini` model mix-up in configuration (#573). -- Ensure event body for function calls is valid JSON for telemetry ingestion ([KG-268](https://youtrack.jetbrains.com/issue/KG-268)). +- Ensure event body for function calls is valid JSON for telemetry + ingestion ([KG-268](https://youtrack.jetbrains.com/issue/KG-268)). - Fix duplicated tool names resolution in `AIAgentSubgraphExt` to prevent conflicts (#493). - Fix Azure OpenAI client settings to generate valid endpoint URLs (#478). - Restore `llama3.2:latest` as the default for LLAMA_3_2 to match the provider expectations (#522). @@ -65,14 +200,17 @@ Fixed iOS target publication ## Removals / Breaking Changes -- Remove Google Gemini 1.5 Flash/Pro variants from the catalog ([KG-216](https://youtrack.jetbrains.com/issue/KG-216), #574). +- Remove Google Gemini 1.5 Flash/Pro variants from the catalog ([KG-216](https://youtrack.jetbrains.com/issue/KG-216), + #574). - Drop `execute` extensions for `PromptExecutor` in favor of the unified API (#591). -- File system API cleanup: removed deprecated FSProvider interfaces and methods; `PathFilter` renamed to `TraversalFilter` with suspendable operations; `fromAbsoluteString` renamed to `fromAbsolutePathString`. 
+- File system API cleanup: removed deprecated FSProvider interfaces and methods; `PathFilter` renamed to + `TraversalFilter` with suspendable operations; `fromAbsoluteString` renamed to `fromAbsolutePathString`. ## Examples - Add a web search agent (from Koog live stream 1) showcasing retrieval + summarization (#575). -- Add a trip planning agent example (from Koog live stream 2) demonstrating tools + planning + composite strategy (#595). +- Add a trip planning agent example (from Koog live stream 2) demonstrating tools + planning + composite strategy ( + #595). - Improve BestJokeAgent sample and fix NumberGuessingAgent example (#503, #445). # 0.3.0 @@ -81,12 +219,13 @@ Fixed iOS target publication ## Major Features -- **Agent Persistency and Checkpoints**: Save and restore agent state to local disk, memory, or easily integrate with +- **Agent Persistence and Checkpoints**: Save and restore agent state to local disk, memory, or easily integrate with any cloud storages or databases. Agents can now roll back to any prior state on demand or automatically restore from the latest checkpoint (#305) - **Vector Document Storage**: Store embeddings and documents in persistent storage for retrieval-augmented generation ( RAG), with in-memory and local file implementations (#272) -- **OpenTelemetry Support**: Native integration with OpenTelemetry for unified tracing logs across AI agents (#369, #401, +- **OpenTelemetry Support**: Native integration with OpenTelemetry for unified tracing logs across AI agents (#369, + #401, #423, #426) - **Content Moderation**: Built-in support for moderating models, enabling AI agents to automatically review and filter outputs for safety and compliance (#395) @@ -147,7 +286,7 @@ Fixed iOS target publication - Langfuse Tracing example - Moderation example: Moderating iterative joke-generation conversation - Parallel Nodes Execution example: Generating jokes using 3 different LLMs in parallel, and choosing the funniest one -- Snapshot and 
Persistency example: Taking agent snapshots and restoring its state example +- Snapshot and Persistence example: Taking agent snapshots and restoring its state example # 0.2.1 diff --git a/CLAUDE.md b/CLAUDE.md deleted file mode 100644 index 03c3901120..0000000000 --- a/CLAUDE.md +++ /dev/null @@ -1,93 +0,0 @@ -# CLAUDE.md - -This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository. - -## Project Overview - -This repository contains the Koan Agents framework, a Kotlin multiplatform library for building AI agents. The framework enables creating intelligent agents that interact with tools, handle complex workflows, and maintain context across conversations. - -## Building and Testing - -### Basic Commands - -```bash -# Build the project -./gradlew assemble - -# Compile test classes -./gradlew jvmTestClasses jsTestClasses - -# Run all JVM tests -./gradlew jvmTest - -# Run all JS tests -./gradlew jsTest - -# Run a specific test class -./gradlew jvmTest --tests "fully.qualified.TestClassName" -# Example: -./gradlew jvmTest --tests "ai.koog.agents.test.SimpleAgentIntegrationTest" - -# Run a specific test method -./gradlew jvmTest --tests "fully.qualified.TestClassName.testMethodName" -# Example: -./gradlew jvmTest --tests "ai.koog.agents.test.SimpleAgentIntegrationTest.integration_simpleSingleRunAgentShouldNotCallToolsByDefault" -``` - -## Architecture - -### Key Modules - -1. **agents-core**: Core abstractions and interfaces - - AIAgent, AIAgentStrategy, event handling system, AIAgent, execution strategies, session management - -2. **agents-tools**: Tool infrastructure - - Tool, ToolRegistry, ToolDescriptor - -3. **agents-features**: Extensible agent capabilities - - Memory, tracing, and other features - -4. 
**prompt**: LLM interaction layer - - LLM executors, prompt construction, structured data processing - -### Core Concepts - -- **Agents**: State-machine graphs with nodes that process inputs and produce outputs -- **Tools**: Encapsulated actions with standardized interfaces -- **Strategies**: Define agent behavior and execution flow -- **Features**: Installable extensions that enhance agent capabilities -- **Event Handling**: System for intercepting and processing agent lifecycle events - -### Implementation Pattern - -1. Define tools that agents can use -2. Register tools in the ToolRegistry -3. Configure agent with strategy -4. Set up communication (if integrating with external systems) - -## Testing - -The project has extensive testing support: - -- **Mocking LLM responses**: - ```kotlin - val mockLLMApi = getMockExecutor(toolRegistry, eventHandler) { - mockLLMAnswer("Hello!") onRequestContains "Hello" - mockLLMToolCall(CreateTool, CreateTool.Args("solve")) onRequestEquals "Solve task" - } - ``` - -- **Mocking tool calls**: - ```kotlin - mockTool(PositiveToneTool) alwaysReturns "The text has a positive tone." - ``` - -- **Testing agent graph structure**: - ```kotlin - testGraph { - assertStagesOrder("first", "second") - // ... - } - ``` - -For detailed testing guidelines, refer to `agents/agents-test/TESTING.md`. \ No newline at end of file diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 7a11ff9d91..a97a6861aa 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -39,6 +39,28 @@ so do familiarize yourself with the following guidelines. name test functions as `testXxx`. Don't use backticks in test names. * Comment on the existing issue if you want to work on it. Ensure that the issue not only describes a problem but also describes a solution that has received positive feedback. Propose a solution if none has been suggested. +## Working with AI Code Agents + +This project includes some helpful guidelines to make AI coding assistants work better with codebase. 
+ +### Agent Guidelines + +You'll find an [AGENT.md](AGENT.md) file in the repository root. +Think of it as a cheat sheet for AI assistants that explains: + +- **How the project works** — the overall architecture and main concepts +- **Development workflow** — which commands to run and how to build things +- **Testing patterns** — our approach to mocks and test structure +- **Code conventions** — the style we follow and why + +### How to use `AGENT.md` + +When you're pairing with an AI assistant on this project: + +1. Share the `AGENT.md` file with your code agent of choice (Junie, Claude Code, Cursor, Copilot, etc.) +2. The AI will understand our project structure and conventions better +3. You can even use it as a starting point to create custom configs for specific agents + ## Documentation The documentation is published on https://docs.koog.ai/. To propose changes or improvements to the documentation, go to the https://github.com/JetBrains/koog-docs repository. diff --git a/README.md b/README.md index 009c63129c..845efce6b1 100644 --- a/README.md +++ b/README.md @@ -87,7 +87,7 @@ Currently, the framework supports the JVM, JS, WasmJS and iOS targets. ``` dependencies { - implementation("ai.koog:koog-agents:0.4.2") + implementation("ai.koog:koog-agents:0.5.0") } ``` 2. Make sure that you have `mavenCentral()` in the list of repositories. @@ -97,7 +97,7 @@ Currently, the framework supports the JVM, JS, WasmJS and iOS targets. ``` dependencies { - implementation 'ai.koog:koog-agents:0.4.2' + implementation 'ai.koog:koog-agents:0.5.0' } ``` 2. Make sure that you have `mavenCentral()` in the list of repositories. @@ -109,7 +109,7 @@ Currently, the framework supports the JVM, JS, WasmJS and iOS targets. ai.koog koog-agents-jvm - 0.4.2 + 0.5.0 ``` 2. Make sure that you have `mavenCentral` in the list of repositories. 
diff --git a/a2a/CLAUDE.md b/a2a/CLAUDE.md new file mode 100644 index 0000000000..0f4d2ca415 --- /dev/null +++ b/a2a/CLAUDE.md @@ -0,0 +1,263 @@ +# A2A Module Development Guidelines + +## Module Overview + +The A2A (Agent-to-Agent) module is a **meta-module** within the larger Koog project that implements a comprehensive client and server library for the A2A protocol, a standardized communication protocol for AI agents based on the specification at https://a2a-protocol.org/latest/specification/. + +**Important**: Since this is a meta-module inside a bigger root project, the Gradle wrapper executable is located one directory up. All Gradle commands must use `../gradlew` when running from the a2a directory. + +### Purpose +- **Agent Discovery**: Through AgentCard manifests describing agent capabilities and interfaces +- **Message Exchange**: Standardized message format with support for different content types +- **Task Management**: Long-running operations with state tracking and lifecycle management +- **Push Notifications**: Asynchronous notifications for task updates +- **Multiple Transport Protocols**: JSON-RPC, HTTP+JSON/REST with extensibility for gRPC +- **Authentication & Security**: OpenAPI 3.0 compatible security schemes + +### Submodules Architecture + +1. **a2a-core**: Core abstractions, data models, and transport interfaces + - AgentCard, Message, Task, Event hierarchy + - Transport interfaces: ClientTransport, ServerTransport, RequestHandler + - No external dependencies beyond Kotlin stdlib and serialization + +2. **a2a-client**: High-level client library for communicating with A2A servers + - A2AClient wrapper, AgentCardResolver + - Capability validation, request/response handling + - Depends on: a2a-core, Ktor HTTP client + +3. 
**a2a-server**: Server-side implementation for hosting A2A agents + - A2AServer, AgentExecutor interface, session management + - Storage abstractions with in-memory implementations + - Depends on: a2a-core, coroutines, logging + +4. **a2a-transport**: Multiple transport protocol implementations + - a2a-transport-core-jsonrpc: JSON-RPC protocol base classes + - a2a-transport-client-jsonrpc-http: HTTP JSON-RPC client transport + - a2a-transport-server-jsonrpc-http: HTTP JSON-RPC server transport + - a2a-transport-*-rest: HTTP+JSON/REST protocol implementations + +## Technologies & Libraries + +### Core Dependencies (from gradle/libs.versions.toml) +- **Kotlin Multiplatform** +- **kotlinx-serialization**: JSON serialization for protocol messages +- **kotlinx-coroutines**: Async/concurrent programming, Flow APIs +- **kotlinx-datetime**: Timestamp handling in protocol messages +- **ktor3**: HTTP client/server for transport implementations +- **oshai-kotlin-logging**: Structured logging + +### Testing Libraries +- **kotlin-test**: Multiplatform test framework +- **kotest-assertions**: Rich assertion library for complex objects +- **kotlinx-coroutines-test**: Testing coroutines with runTest +- **testcontainers**: Docker-based integration testing +- **logback-classic**: Runtime logging for tests + +### Platform Support +- **JVM**: Full server and client support +- **JS (Browser)**: Client-only support +- **Future platforms**: Architecture ready for native support + +## Development Guidelines + +### Architecture Decisions +⚠️ **CRITICAL**: Never make design and architecture decisions independently. 
Always ask the user for confirmation before: +- Adding new transport protocols +- Changing storage interfaces +- Modifying core protocol message formats +- Adding new authentication schemes +- Changing session management behavior +- Doing other sorts of major architectural changes + +### Code Style Conventions + +#### API Visibility +- Use `explicitApi()` - all public APIs must have explicit visibility modifiers +- Prefer `public` declarations for APIs, `internal` for implementation details +- Use `@InternalA2AApi` annotation for APIs that are public but internal to the module + +#### Naming Conventions +- Classes: PascalCase (`A2AServer`, `SessionManager`, `AgentExecutor`) +- Interfaces: Same as classes, often ending in -or/-er for behaviors (`AgentExecutor`) +- Functions: camelCase with descriptive names (`onSendMessage`, `getByContext`) +- Properties: camelCase (`agentCard`, `sessionMutex`) + +#### Async Patterns +- Use `suspend` functions for all async operations +- Prefer `Flow` for streaming APIs (task events, message streams) +- Use `RWLock` pattern: `rwLock.withReadLock/withWriteLock` for concurrent access +- Use `Mutex` for single-threaded critical sections: `mutex.withLock` + +#### Error Handling +- Use domain-specific exceptions (`A2AInternalErrorException`, `TaskOperationException`) +- Propagate errors properly in async contexts +- Include contextual information in exception messages + +#### KDoc Documentation +- **Placement**: KDoc directly above declarations (classes, functions, properties) +- **Constructor properties**: Document using `@param` tags in class KDoc +- **Public class properties**: Document using `@property` tags in class KDoc +- **Cross-references**: Use `[ClassName]`, `[ClassName.propertyName]` syntax for linking components +- **Examples**: Include practical code examples for complex APIs +- **Validation rules**: Document constraints with bullet points and clear explanations +- **Exception documentation**: Use `@throws` with specific 
conditions +- **Required**: All public APIs must have KDoc (enforced by `explicitApi()`) + +## Testing Requirements + +### Mandatory Test Execution +⚠️ **ALWAYS** run the specific test suite using Gradle's jvmTest task to ensure changes work correctly: + +```bash +# Run all tests in a specific module +../gradlew :a2a:a2a-server:jvmTest + +# Run a specific test class +../gradlew :a2a:a2a-server:jvmTest --tests "ai.koog.a2a.server.tasks.InMemoryTaskStorageTest" + +# Run a specific test method +../gradlew :a2a:a2a-client:jvmTest --tests "ai.koog.a2a.client.A2AClientIntegrationTest.test get agent card" +``` + +### Testing Approach: Crucial Minimum +- **Focus on core functionality** - test essential behavior, not implementation details +- **Separate dedicated tests** - one test method per core scenario/edge case +- **Avoid verbose, unnecessary tests** - don't test trivial getters/setters +- **Test error conditions** - verify proper exception handling and error states + +#### Preferred Assertion Pattern +**✅ DO**: Assert on whole objects when possible (especially data classes): +```kotlin +assertEquals(expectedTask, actualTask) // Compares all properties +assertEquals(listOf(expectedArtifact), retrieved?.artifacts) +``` + +**❌ AVOID**: Bunch of individual field assertions unless necessary: +```kotlin +// Avoid this pattern: +assertEquals(expected.id, actual.id) +assertEquals(expected.contextId, actual.contextId) +assertEquals(expected.status, actual.status) +// ... 
many more individual assertions +``` + +### Integration Tests +- **Docker-based testing**: a2a-client uses testcontainers with Python A2A server +- **Docker build dependency**: Integration tests automatically build required containers +- **Test isolation**: Each test should be independent and clean up after itself + +### Test Structure Examples + +**Unit Tests** (InMemoryTaskStorageTest pattern): +```kotlin +class InMemoryTaskStorageTest { + private lateinit var storage: InMemoryTaskStorage + + @BeforeTest + fun setUp() { + storage = InMemoryTaskStorage() + } + + @Test + fun testStoreAndRetrieveTask() = runTest { + val task = createTask(id = "task-1", contextId = "context-1") + storage.update(task) + val retrieved = storage.get("task-1") + + assertNotNull(retrieved) + assertEquals(task, retrieved) // Whole object assertion + } +} +``` + +**Integration Tests** (A2AClientIntegrationTest pattern): +```kotlin +@Testcontainers +class A2AClientIntegrationTest { + @Container + val testA2AServer = GenericContainer("test-python-a2a-server") + .withExposedPorts(9999) + .waitingFor(Wait.forListeningPort()) + + @Test + fun `test get agent card`() = runTest { + val agentCard = client.getAgentCard() + val expectedAgentCard = AgentCard(...) + assertEquals(expectedAgentCard, agentCard) // Full object comparison + } +} +``` + +## Implementation Patterns + +### Transport Layer Abstraction +- Implement `ClientTransport` interface for new client protocols +- Implement `ServerTransport` and `RequestHandler` for new server protocols +- Keep protocol logic separate from transport mechanism +- Use suspend functions for all transport operations + +### Storage Interface Pattern +```kotlin +public interface SomeStorage { + public suspend fun save(item: Item) + public suspend fun get(id: String): Item? 
+ public suspend fun delete(id: String) +} + +// In-memory implementation for testing/development +@OptIn(InternalA2AApi::class) +public class InMemorySomeStorage : SomeStorage { + private val rwLock = RWLock() + private val items = mutableMapOf() + + override suspend fun save(item: Item) = rwLock.withWriteLock { + items[item.id] = item + } +} +``` + +### Session Management Pattern +- Use `SessionEventProcessor` for event-driven session handling +- Implement proper cleanup with `Closeable` interface +- Use structured concurrency with `CoroutineScope` +- Handle session lifecycle properly (start → active → closed) + +### Agent Executor Implementation +```kotlin +public class MyAgentExecutor : AgentExecutor { + override suspend fun execute( + context: RequestContext, + eventProcessor: SessionEventProcessor + ) { + // Process request and send events through eventProcessor + eventProcessor.sendTaskEvent(/* task event */) + } +} +``` + +## Some Common Commands + +```bash +# Build specific modules (a2a is a meta-module, build individual modules) +../gradlew :a2a:a2a-core:assemble +../gradlew :a2a:a2a-client:assemble +../gradlew :a2a:a2a-server:assemble + +# Run JVM tests for specific modules +../gradlew :a2a:a2a-core:jvmTest +../gradlew :a2a:a2a-client:jvmTest +../gradlew :a2a:a2a-server:jvmTest +../gradlew :a2a:a2a-transport:a2a-transport-core-jsonrpc:jvmTest + +# Run JS tests for specific modules +../gradlew :a2a:a2a-core:jsTest +../gradlew :a2a:a2a-client:jsTest +../gradlew :a2a:a2a-server:jsTest + +# Build all non-transport a2a modules at once +../gradlew :a2a:a2a-core:assemble :a2a:a2a-client:assemble :a2a:a2a-server:assemble + +# Run all JVM tests across a2a modules +../gradlew :a2a:a2a-core:jvmTest :a2a:a2a-client:jvmTest :a2a:a2a-server:jvmTest +``` diff --git a/a2a/a2a-client/Module.md b/a2a/a2a-client/Module.md new file mode 100644 index 0000000000..f123824d72 --- /dev/null +++ b/a2a/a2a-client/Module.md @@ -0,0 +1,4 @@ +# Module a2a-client + +High-level client 
library for communicating with A2A protocol servers. Provides the A2AClient wrapper with capability validation, +request/response handling, and AgentCard resolution. Built on top of a2a-core abstractions with Ktor HTTP client support. diff --git a/a2a/a2a-client/build.gradle.kts b/a2a/a2a-client/build.gradle.kts new file mode 100644 index 0000000000..a0696d6b25 --- /dev/null +++ b/a2a/a2a-client/build.gradle.kts @@ -0,0 +1,92 @@ +import ai.koog.gradle.publish.maven.Publishing.publishToMaven +import org.gradle.internal.os.OperatingSystem +import org.gradle.kotlin.dsl.support.serviceOf +import java.io.ByteArrayOutputStream + +group = rootProject.group +version = rootProject.version + +plugins { + id("ai.kotlin.multiplatform") + alias(libs.plugins.kotlin.serialization) +} + +kotlin { + sourceSets { + commonMain { + dependencies { + api(project(":a2a:a2a-core")) + api(libs.kotlinx.serialization.json) + api(libs.kotlinx.coroutines.core) + api(libs.ktor.client.core) + api(libs.ktor.client.content.negotiation) + api(libs.ktor.serialization.kotlinx.json) + } + } + + commonTest { + dependencies { + implementation(kotlin("test")) + implementation(libs.kotest.assertions) + implementation(libs.kotlinx.coroutines.test) + } + } + + jvmTest { + dependencies { + implementation(kotlin("test-junit5")) + implementation(project(":a2a:a2a-test")) + implementation(project(":a2a:a2a-transport:a2a-transport-client-jsonrpc-http")) + implementation(project(":test-utils")) + + implementation(libs.ktor.client.cio) + implementation(libs.ktor.client.logging) + implementation(libs.testcontainers.junit) + runtimeOnly(libs.logback.classic) + } + } + + jsTest { + dependencies { + implementation(kotlin("test-js")) + } + } + } + + explicitApi() +} + +publishToMaven() + +tasks.register("dockerBuildTestPythonA2AServer") { + group = "docker" + description = "Build Python A2A test server image" + workingDir = file("../test-python-a2a-server") + commandLine = listOf("docker", "build", "-t", 
"test-python-a2a-server", ".") + + onlyIf { + // do not attempt to check for docker on windows + if (OperatingSystem.current().isWindows) { + return@onlyIf false + } + + try { + val buffer = ByteArrayOutputStream() + + serviceOf().exec { + commandLine = listOf("docker", "--version") + standardOutput = buffer + errorOutput = buffer + } + + true + } catch (_: Exception) { + logger.warn("Docker not available. Skipping task 'dockerBuildTestPythonA2AServer'") + + false + } + } +} +tasks.named("jvmTest") { + dependsOn("dockerBuildTestPythonA2AServer") +} diff --git a/a2a/a2a-client/src/commonMain/kotlin/ai/koog/a2a/client/A2AClient.kt b/a2a/a2a-client/src/commonMain/kotlin/ai/koog/a2a/client/A2AClient.kt new file mode 100644 index 0000000000..7e4da3a6eb --- /dev/null +++ b/a2a/a2a-client/src/commonMain/kotlin/ai/koog/a2a/client/A2AClient.kt @@ -0,0 +1,202 @@ +package ai.koog.a2a.client + +import ai.koog.a2a.exceptions.A2AException +import ai.koog.a2a.model.AgentCard +import ai.koog.a2a.model.CommunicationEvent +import ai.koog.a2a.model.Event +import ai.koog.a2a.model.MessageSendParams +import ai.koog.a2a.model.Task +import ai.koog.a2a.model.TaskIdParams +import ai.koog.a2a.model.TaskPushNotificationConfig +import ai.koog.a2a.model.TaskPushNotificationConfigParams +import ai.koog.a2a.model.TaskQueryParams +import ai.koog.a2a.transport.ClientCallContext +import ai.koog.a2a.transport.ClientTransport +import ai.koog.a2a.transport.Request +import ai.koog.a2a.transport.Response +import kotlinx.coroutines.flow.Flow +import kotlin.concurrent.atomics.AtomicReference +import kotlin.concurrent.atomics.ExperimentalAtomicApi + +/** + * A2A client responsible for sending requests to A2A server. + */ +@OptIn(ExperimentalAtomicApi::class) +public open class A2AClient( + private val transport: ClientTransport, + private val agentCardResolver: AgentCardResolver, +) { + protected var agentCard: AtomicReference = AtomicReference(null) + + /** + * Performs initialization logic. 
+ * Currently only retrieves the [AgentCard]. + */ + public open suspend fun connect() { + getAgentCard() + } + + /** + * Retrieves [AgentCard] by calling [AgentCardResolver.resolve]. + * Saves it to the cache. + */ + public open suspend fun getAgentCard(): AgentCard { + return agentCardResolver.resolve().also { + agentCard.exchange(it) + } + } + + /** + * Retrieves currently cached [AgentCard] + * + * @throws [IllegalStateException] if it's not initialized + */ + public open fun cachedAgentCard(): AgentCard { + return checkNotNull(agentCard.load()) { "Agent card is not initialized." } + } + + /** + * Calls [agent/getAuthenticatedExtendedCard](https://a2a-protocol.org/latest/specification/#710-agentgetauthenticatedextendedcard). + * Updates cached [AgentCard]. + * + * @throws A2AException if server returned an error. + */ + public suspend fun getAuthenticatedExtendedAgentCard( + request: Request, + ctx: ClientCallContext = ClientCallContext.Default + ): Response { + check(cachedAgentCard().supportsAuthenticatedExtendedCard == true) { + "Agent card reports that authenticated extended agent card is not supported." + } + + return transport.getAuthenticatedExtendedAgentCard(request, ctx).also { + agentCard.exchange(it.data) + } + } + + /** + * Calls [message/send](https://a2a-protocol.org/latest/specification/#71-messagesend). + * + * @throws A2AException if server returned an error. + */ + public suspend fun sendMessage( + request: Request, + ctx: ClientCallContext = ClientCallContext.Default + ): Response { + return transport.sendMessage(request, ctx) + } + + /** + * Calls [message/stream](https://a2a-protocol.org/latest/specification/#72-messagestream) + * + * @throws A2AException if server returned an error. + */ + public fun sendMessageStreaming( + request: Request, + ctx: ClientCallContext = ClientCallContext.Default + ): Flow> { + check(cachedAgentCard().capabilities.streaming == true) { + "Agent card reports that streaming is not supported." 
+ } + + return transport.sendMessageStreaming(request, ctx) + } + + /** + * Calls [tasks/get](https://a2a-protocol.org/latest/specification/#73-tasksget) + * + * @throws A2AException if server returned an error. + */ + public suspend fun getTask( + request: Request, + ctx: ClientCallContext = ClientCallContext.Default + ): Response { + return transport.getTask(request, ctx) + } + + /** + * Calls [tasks/cancel](https://a2a-protocol.org/latest/specification/#74-taskscancel) + * + * @throws A2AException if server returned an error. + */ + public suspend fun cancelTask( + request: Request, + ctx: ClientCallContext = ClientCallContext.Default + ): Response { + return transport.cancelTask(request, ctx) + } + + /** + * Calls [tasks/resubscribe](https://a2a-protocol.org/latest/specification/#79-tasksresubscribe) + * + * @throws A2AException if server returned an error. + */ + public fun resubscribeTask( + request: Request, + ctx: ClientCallContext = ClientCallContext.Default + ): Flow> { + return transport.resubscribeTask(request, ctx) + } + + /** + * Calls [tasks/pushNotificationConfig/set](https://a2a-protocol.org/latest/specification/#75-taskspushnotificationconfigset) + * + * @throws A2AException if server returned an error. + */ + public suspend fun setTaskPushNotificationConfig( + request: Request, + ctx: ClientCallContext = ClientCallContext.Default + ): Response { + checkPushNotificationsSupported() + + return transport.setTaskPushNotificationConfig(request, ctx) + } + + /** + * Calls [tasks/pushNotificationConfig/get](https://a2a-protocol.org/latest/specification/#76-taskspushnotificationconfigget) + * + * @throws A2AException if server returned an error. 
+ */ + public suspend fun getTaskPushNotificationConfig( + request: Request, + ctx: ClientCallContext = ClientCallContext.Default + ): Response { + checkPushNotificationsSupported() + + return transport.getTaskPushNotificationConfig(request, ctx) + } + + /** + * Calls [tasks/pushNotificationConfig/list](https://a2a-protocol.org/latest/specification/#77-taskspushnotificationconfiglist) + * + * @throws A2AException if server returned an error. + */ + public suspend fun listTaskPushNotificationConfig( + request: Request, + ctx: ClientCallContext = ClientCallContext.Default + ): Response> { + checkPushNotificationsSupported() + + return transport.listTaskPushNotificationConfig(request, ctx) + } + + /** + * Calls [tasks/pushNotificationConfig/delete](https://a2a-protocol.org/latest/specification/#78-taskspushnotificationconfigdelete) + * + * @throws A2AException if server returned an error. + */ + public suspend fun deleteTaskPushNotificationConfig( + request: Request, + ctx: ClientCallContext = ClientCallContext.Default + ): Response { + checkPushNotificationsSupported() + + return transport.deleteTaskPushNotificationConfig(request, ctx) + } + + protected fun checkPushNotificationsSupported() { + check(cachedAgentCard().capabilities.pushNotifications == true) { + "Agent card reports that push notifications are not supported." 
+ } + } +} diff --git a/a2a/a2a-client/src/commonMain/kotlin/ai/koog/a2a/client/AgentCardResolver.kt b/a2a/a2a-client/src/commonMain/kotlin/ai/koog/a2a/client/AgentCardResolver.kt new file mode 100644 index 0000000000..7a90b89a5a --- /dev/null +++ b/a2a/a2a-client/src/commonMain/kotlin/ai/koog/a2a/client/AgentCardResolver.kt @@ -0,0 +1,63 @@ +package ai.koog.a2a.client + +import ai.koog.a2a.consts.A2AConsts +import ai.koog.a2a.model.AgentCard +import io.ktor.client.HttpClient +import io.ktor.client.call.body +import io.ktor.client.plugins.contentnegotiation.ContentNegotiation +import io.ktor.client.plugins.defaultRequest +import io.ktor.client.request.get +import io.ktor.http.ContentType +import io.ktor.http.contentType +import io.ktor.serialization.kotlinx.json.json +import kotlinx.serialization.json.Json + +/** + * Represents a resolver capable of fetching an [AgentCard]. + * + * Implementations of this interface are responsible for providing the mechanism to retrieve + * the [AgentCard], which may include network requests, local lookups, or other means of resolution. + */ +public interface AgentCardResolver { + /** + * Resolves and retrieves an [AgentCard]. + */ + public suspend fun resolve(): AgentCard +} + +/** + * An [AgentCardResolver] that always returns the provided [agentCard]. + */ +public class ExplicitAgentCardResolver(public val agentCard: AgentCard) : AgentCardResolver { + override suspend fun resolve(): AgentCard = agentCard +} + +/** + * An [AgentCardResolver] that fetches the [AgentCard] from the provided [baseUrl] at [path]. 
+ */ +public class UrlAgentCardResolver( + public val baseUrl: String, + public val path: String = A2AConsts.AGENT_CARD_WELL_KNOWN_PATH, + baseHttpClient: HttpClient = HttpClient(), +) : AgentCardResolver { + private val httpClient: HttpClient = baseHttpClient.config { + defaultRequest { + url(baseUrl) + contentType(ContentType.Application.Json) + } + + install(ContentNegotiation) { + json( + Json { + ignoreUnknownKeys = true + } + ) + } + + expectSuccess = true + } + + override suspend fun resolve(): AgentCard { + return httpClient.get(path).body() + } +} diff --git a/a2a/a2a-client/src/jvmTest/kotlin/ai/koog/a2a/client/A2AClientJsonRpcIntegrationTest.kt b/a2a/a2a-client/src/jvmTest/kotlin/ai/koog/a2a/client/A2AClientJsonRpcIntegrationTest.kt new file mode 100644 index 0000000000..2764d58f4a --- /dev/null +++ b/a2a/a2a-client/src/jvmTest/kotlin/ai/koog/a2a/client/A2AClientJsonRpcIntegrationTest.kt @@ -0,0 +1,110 @@ +package ai.koog.a2a.client + +import ai.koog.a2a.test.BaseA2AProtocolTest +import ai.koog.a2a.transport.client.jsonrpc.http.HttpJSONRPCClientTransport +import ai.koog.test.utils.DockerAvailableCondition +import io.ktor.client.HttpClient +import io.ktor.client.plugins.logging.LogLevel +import io.ktor.client.plugins.logging.Logging +import kotlinx.coroutines.test.runTest +import org.junit.jupiter.api.AfterAll +import org.junit.jupiter.api.BeforeAll +import org.junit.jupiter.api.TestInstance +import org.junit.jupiter.api.extension.ExtendWith +import org.junit.jupiter.api.parallel.Execution +import org.junit.jupiter.api.parallel.ExecutionMode +import org.testcontainers.containers.GenericContainer +import org.testcontainers.containers.wait.strategy.Wait +import org.testcontainers.junit.jupiter.Container +import org.testcontainers.junit.jupiter.Testcontainers +import kotlin.test.Test +import kotlin.time.Duration.Companion.seconds + +/** + * Integration test class for testing the JSON-RPC HTTP communication in the A2A client context. 
+ * This class ensures the proper functioning and correctness of the A2A protocol over HTTP + * using the JSON-RPC standard. + */ +@TestInstance(TestInstance.Lifecycle.PER_CLASS) +@Testcontainers +@ExtendWith(DockerAvailableCondition::class) +@Execution(ExecutionMode.SAME_THREAD, reason = "Working with the same instance of test server.") +class A2AClientJsonRpcIntegrationTest : BaseA2AProtocolTest() { + companion object { + @Container + val testA2AServer: GenericContainer<*> = + GenericContainer("test-python-a2a-server") + .withExposedPorts(9999) + .waitingFor(Wait.forListeningPort()) + } + + override val testTimeout = 10.seconds + + private val httpClient = HttpClient { + install(Logging) { + level = LogLevel.BODY + } + } + + @Suppress("HttpUrlsUsage") + private val agentUrl by lazy { "http://${testA2AServer.host}:${testA2AServer.getMappedPort(9999)}" } + + private lateinit var transport: HttpJSONRPCClientTransport + + override lateinit var client: A2AClient + + @BeforeAll + fun setUp() = runTest { + transport = HttpJSONRPCClientTransport( + url = agentUrl, + baseHttpClient = httpClient + ) + + client = A2AClient( + transport = transport, + agentCardResolver = UrlAgentCardResolver( + baseUrl = agentUrl, + baseHttpClient = httpClient, + ), + ) + + client.connect() + } + + @AfterAll + fun tearDown() = runTest { + transport.close() + } + + @Test + override fun `test get agent card`() = + super.`test get agent card`() + + @Test + override fun `test get authenticated extended agent card`() = + super.`test get authenticated extended agent card`() + + @Test + override fun `test send message`() = + super.`test send message`() + + @Test + override fun `test send message streaming`() = + super.`test send message streaming`() + + @Test + override fun `test get task`() = + super.`test get task`() + + @Test + override fun `test cancel task`() = + super.`test cancel task`() + + @Test + override fun `test resubscribe task`() = + super.`test resubscribe task`() + + @Test + 
override fun `test push notification configs`() = + super.`test push notification configs`() +} diff --git a/a2a/a2a-core/Module.md b/a2a/a2a-core/Module.md new file mode 100644 index 0000000000..f574bc331b --- /dev/null +++ b/a2a/a2a-core/Module.md @@ -0,0 +1,4 @@ +# Module a2a-core + +Core abstractions and data models for the A2A (Agent-to-Agent) protocol implementation. This module provides the foundational +types and interfaces for agent communication, including protocol message structures, transport layer abstractions, and domain-specific exceptions. diff --git a/a2a/a2a-core/build.gradle.kts b/a2a/a2a-core/build.gradle.kts new file mode 100644 index 0000000000..7e99e5353e --- /dev/null +++ b/a2a/a2a-core/build.gradle.kts @@ -0,0 +1,50 @@ +import ai.koog.gradle.publish.maven.Publishing.publishToMaven + +group = rootProject.group +version = rootProject.version + +plugins { + id("ai.kotlin.multiplatform") + alias(libs.plugins.kotlin.serialization) +} + +kotlin { + jvm() + + js(IR) { + browser() + } + + sourceSets { + commonMain { + dependencies { + api(libs.kotlinx.serialization.json) + api(libs.kotlinx.coroutines.core) + api(libs.kotlinx.datetime) + } + } + + commonTest { + dependencies { + implementation(kotlin("test")) + implementation(libs.kotlinx.coroutines.test) + } + } + + jvmTest { + dependencies { + implementation(kotlin("test-junit5")) + } + } + + jsTest { + dependencies { + implementation(kotlin("test-js")) + } + } + } + + explicitApi() +} + +publishToMaven() diff --git a/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/annotations/InternalA2AApi.kt b/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/annotations/InternalA2AApi.kt new file mode 100644 index 0000000000..d8aee42afd --- /dev/null +++ b/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/annotations/InternalA2AApi.kt @@ -0,0 +1,10 @@ +package ai.koog.a2a.annotations + +/** + * Marks an API as internal to the a2a module. 
This annotation indicates that the + * marked API is not intended for public use and is subject to change or removal + * without prior notice. It should be used with caution and only within the intended + * internal scope. + */ +@RequiresOptIn("This API is internal in a2a and should not be used. It could be removed or changed without notice.") +public annotation class InternalA2AApi diff --git a/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/consts/A2AConsts.kt b/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/consts/A2AConsts.kt new file mode 100644 index 0000000000..4863d3b801 --- /dev/null +++ b/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/consts/A2AConsts.kt @@ -0,0 +1,13 @@ +package ai.koog.a2a.consts + +/** + * Some common global A2A constants. + */ +public object A2AConsts { + /** + * The well-known agent card URL following RFC 8615. + * + * More info: https://a2a-protocol.org/latest/specification/#53-recommended-location + */ + public const val AGENT_CARD_WELL_KNOWN_PATH: String = "/.well-known/agent-card.json" +} diff --git a/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/exceptions/Exceptions.kt b/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/exceptions/Exceptions.kt new file mode 100644 index 0000000000..7620ba56cf --- /dev/null +++ b/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/exceptions/Exceptions.kt @@ -0,0 +1,179 @@ +package ai.koog.a2a.exceptions + +import ai.koog.a2a.transport.RequestId + +/** + * Object containing all A2A error codes. 
+ */ +@Suppress("MissingKDocForPublicAPI") +public object A2AErrorCodes { + public const val PARSE_ERROR: Int = -32700 + public const val INVALID_REQUEST: Int = -32600 + public const val METHOD_NOT_FOUND: Int = -32601 + public const val INVALID_PARAMS: Int = -32602 + public const val INTERNAL_ERROR: Int = -32603 + public const val TASK_NOT_FOUND: Int = -32001 + public const val TASK_NOT_CANCELABLE: Int = -32002 + public const val PUSH_NOTIFICATION_NOT_SUPPORTED: Int = -32003 + public const val UNSUPPORTED_OPERATION: Int = -32004 + public const val CONTENT_TYPE_NOT_SUPPORTED: Int = -32005 + public const val INVALID_AGENT_RESPONSE: Int = -32006 + public const val AUTHENTICATED_EXTENDED_CARD_NOT_CONFIGURED: Int = -32007 +} + +/** + * Base class for all A2A exceptions. + */ +public sealed class A2AException( + public override val message: String, + public val errorCode: Int, + public val requestId: RequestId? = null, +) : Exception(message) + +/** + * Server received JSON that was not well-formed. + */ +public class A2AParseException( + message: String = "Invalid JSON payload", + requestId: RequestId? = null, +) : A2AException(message, A2AErrorCodes.PARSE_ERROR, requestId) + +/** + * The JSON payload was valid JSON, but not a valid JSON-RPC Request object. + */ +public class A2AInvalidRequestException( + message: String = "Invalid JSON-RPC Request", + requestId: RequestId? = null, +) : A2AException(message, A2AErrorCodes.INVALID_REQUEST, requestId) + +/** + * The requested A2A RPC method does not exist or is not supported. + */ +public class A2AMethodNotFoundException( + message: String = "Method not found", + requestId: RequestId? = null, +) : A2AException(message, A2AErrorCodes.METHOD_NOT_FOUND, requestId) + +/** + * The params provided for the method are invalid. + */ +public class A2AInvalidParamsException( + message: String = "Invalid method parameters", + requestId: RequestId? 
= null, +) : A2AException(message, A2AErrorCodes.INVALID_PARAMS, requestId) + +/** + * An unexpected error occurred on the server during processing. + */ +public class A2AInternalErrorException( + message: String = "Internal server error", + requestId: RequestId? = null, +) : A2AException(message, A2AErrorCodes.INTERNAL_ERROR, requestId) + +/** + * Reserved for implementation-defined server exceptions. A2A-specific exceptions use this range. + */ +public sealed class A2AServerException( + message: String, + errorCode: Int, + requestId: RequestId? = null, +) : A2AException(message, errorCode, requestId) { + init { + require(errorCode in -32099..-32000) { "Server error code must be in -32099..-32000" } + } +} + +/** + * The specified task id does not correspond to an existing or active task. + * It might be invalid, expired, or already completed and purged. + */ +public class A2ATaskNotFoundException( + message: String = "Task not found", + requestId: RequestId? = null, +) : A2AServerException(message, A2AErrorCodes.TASK_NOT_FOUND, requestId) + +/** + * An attempt was made to cancel a task that is not in a cancelable state. + * The task has already reached a terminal state like completed, failed, or canceled. + */ +public class A2ATaskNotCancelableException( + message: String = "Task cannot be canceled", + requestId: RequestId? = null, +) : A2AServerException(message, A2AErrorCodes.TASK_NOT_CANCELABLE, requestId) + +/** + * Client attempted to use push notification features but the server agent does not support them. + * The server's AgentCard.capabilities.pushNotifications is false. + */ +public class A2APushNotificationNotSupportedException( + message: String = "Push Notification is not supported", + requestId: RequestId? = null, +) : A2AServerException(message, A2AErrorCodes.PUSH_NOTIFICATION_NOT_SUPPORTED, requestId) + +/** + * The requested operation or a specific aspect of it is not supported by this server agent implementation. 
+ * This is broader than just method not found. + */ +public class A2AUnsupportedOperationException( + message: String = "This operation is not supported", + requestId: RequestId? = null, +) : A2AServerException(message, A2AErrorCodes.UNSUPPORTED_OPERATION, requestId) + +/** + * A Media Type provided in the request's message.parts or implied for an artifact is not supported + * by the agent or the specific skill being invoked. + */ +public class A2AContentTypeNotSupportedException( + message: String = "Incompatible content types", + requestId: RequestId? = null, +) : A2AServerException(message, A2AErrorCodes.CONTENT_TYPE_NOT_SUPPORTED, requestId) + +/** + * Agent generated an invalid response for the requested method. + */ +public class A2AInvalidAgentResponseException( + message: String = "Invalid agent response type", + requestId: RequestId? = null, +) : A2AServerException(message, A2AErrorCodes.INVALID_AGENT_RESPONSE, requestId) + +/** + * The agent does not have an Authenticated Extended Card configured. + */ +public class A2AAuthenticatedExtendedCardNotConfiguredException( + message: String = "Authenticated Extended Card not configured", + requestId: RequestId? = null, +) : A2AServerException(message, A2AErrorCodes.AUTHENTICATED_EXTENDED_CARD_NOT_CONFIGURED, requestId) + +/** + * Server returned some unknown error code. + */ +public class A2AUnknownException( + message: String, + errorCode: Int, + requestId: RequestId? = null, +) : A2AException(message, errorCode, requestId) + +/** + * Create appropriate [A2AException] based on the provided errorCode. 
+ */ +public fun createA2AException( + message: String, + errorCode: Int, + requestId: RequestId?, +): A2AException { + return when (errorCode) { + A2AErrorCodes.PARSE_ERROR -> A2AParseException(message, requestId) + A2AErrorCodes.INVALID_REQUEST -> A2AInvalidRequestException(message, requestId) + A2AErrorCodes.METHOD_NOT_FOUND -> A2AMethodNotFoundException(message, requestId) + A2AErrorCodes.INVALID_PARAMS -> A2AInvalidParamsException(message, requestId) + A2AErrorCodes.INTERNAL_ERROR -> A2AInternalErrorException(message, requestId) + A2AErrorCodes.TASK_NOT_FOUND -> A2ATaskNotFoundException(message, requestId) + A2AErrorCodes.TASK_NOT_CANCELABLE -> A2ATaskNotCancelableException(message, requestId) + A2AErrorCodes.PUSH_NOTIFICATION_NOT_SUPPORTED -> A2APushNotificationNotSupportedException(message, requestId) + A2AErrorCodes.UNSUPPORTED_OPERATION -> A2AUnsupportedOperationException(message, requestId) + A2AErrorCodes.CONTENT_TYPE_NOT_SUPPORTED -> A2AContentTypeNotSupportedException(message, requestId) + A2AErrorCodes.INVALID_AGENT_RESPONSE -> A2AInvalidAgentResponseException(message, requestId) + A2AErrorCodes.AUTHENTICATED_EXTENDED_CARD_NOT_CONFIGURED -> A2AAuthenticatedExtendedCardNotConfiguredException(message, requestId) + else -> A2AUnknownException(message, errorCode, requestId) + } +} diff --git a/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/model/AgentCard.kt b/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/model/AgentCard.kt new file mode 100644 index 0000000000..18972d7a90 --- /dev/null +++ b/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/model/AgentCard.kt @@ -0,0 +1,478 @@ +package ai.koog.a2a.model + +import kotlinx.serialization.EncodeDefault +import kotlinx.serialization.SerialName +import kotlinx.serialization.Serializable +import kotlinx.serialization.json.JsonElement +import kotlin.jvm.JvmInline + +/** + * The [AgentCard] is a self-describing manifest for an agent. 
It provides essential metadata including the agent's + * identity, capabilities, skills, supported communication methods, and security requirements. + * + * @property protocolVersion The version of the A2A protocol this agent supports. Default: "0.3.0". + * + * @property name A human-readable name for the agent. Examples: ["Recipe Agent"]. + * + * @property description A human-readable description of the agent, assisting users and other agents in understanding its purpose. + * + * Examples: ["Agent that helps users with recipes and cooking."]. + * + * @property url The preferred endpoint URL for interacting with the agent. This URL MUST support the transport specified by 'preferredTransport'. + * + * Examples: ["https://api.example.com/a2a/v1"]. + * + * @property preferredTransport The transport protocol for the preferred endpoint (the main 'url' field). If not specified, defaults to 'JSONRPC'. + * + * IMPORTANT: The transport specified here MUST be available at the main 'url'. This creates a binding between the main URL and its supported + * transport protocol. Clients should prefer this transport and URL combination when both are supported. + * + * Examples: ["JSONRPC", "GRPC", "HTTP+JSON"]. + * + * @property additionalInterfaces A list of additional supported interfaces (transport and URL combinations). This allows agents to expose multiple + * transports, potentially at different URLs. + * + * Best practices: + * - SHOULD include all supported transports + * - SHOULD include an entry matching the main 'url' and 'preferredTransport' + * - MAY reuse URLs if multiple transports are available at the same endpoint + * - MUST accurately declare the transport available at each URL. + * + * Clients can select any interface from this list based on their transport capabilities and preferences, enabling transport + * negotiation and fallback scenarios. + * + * @property iconUrl An optional URL to an icon for the agent. 
+ * + * @property provider Information about the agent's service provider. + * + * @property version The agent's own version number. The format is defined by the provider. + * + * Examples: ["1.0.0"]. + * + * @property documentationUrl An optional URL to the agent's documentation. + * + * @property capabilities A declaration of optional capabilities supported by the agent. + * + * @property securitySchemes A declaration of the security schemes available to authorize requests. The key is the scheme name. + * Follows the OpenAPI 3.0 Security Scheme Object. + * + * @property security A list of security requirement objects that apply to all agent interactions. Each object lists security schemes that can be used. + * Follows the OpenAPI 3.0 Security Requirement Object. This list can be seen as an OR of ANDs. Each object in the list describes one possible set + * of security requirements that must be present on a request. This allows specifying, for example, "callers must either use OAuth OR an API Key AND mTLS." + * + * Examples: [{"oauth": ["read"]}, {"api-key": [], "mtls": []}]. + * + * @property defaultInputModes Default set of supported input MIME types for all skills, which can be overridden on a per-skill basis. + * + * @property defaultOutputModes Default set of supported output MIME types for all skills, which can be overridden on a per-skill basis. + * + * @property skills The set of skills, or distinct capabilities, that the agent can perform. + * + * @property supportsAuthenticatedExtendedCard If true, the agent can provide an extended agent card with additional details to authenticated users. Defaults to false. + * + * @property signatures JSON Web Signatures computed for this [AgentCard]. 
+ */
+@Serializable
+public data class AgentCard(
+    @EncodeDefault
+    public val protocolVersion: String = "0.3.0",
+    public val name: String,
+    public val description: String,
+    public val url: String,
+    @EncodeDefault
+    public val preferredTransport: TransportProtocol = TransportProtocol.JSONRPC,
+    public val additionalInterfaces: List<AgentInterface>? = null,
+    public val iconUrl: String? = null,
+    public val provider: AgentProvider? = null,
+    public val version: String,
+    public val documentationUrl: String? = null,
+    public val capabilities: AgentCapabilities,
+    public val securitySchemes: SecuritySchemes? = null,
+    public val security: Security? = null,
+    public val defaultInputModes: List<String>,
+    public val defaultOutputModes: List<String>,
+    public val skills: List<AgentSkill>,
+    public val supportsAuthenticatedExtendedCard: Boolean? = false,
+    public val signatures: List<AgentCardSignature>? = null
+) {
+    init {
+        additionalInterfaces?.let { interfaces ->
+            requireNotNull(interfaces.find { it.url == url && it.transport == preferredTransport }) {
+                "If additionalInterfaces are specified, they must include an entry matching the main 'url' and 'preferredTransport'."
+            }
+        }
+    }
+}
+
+/**
+ * The transport protocol for an agent.
+ */
+@JvmInline
+@Serializable
+public value class TransportProtocol(public val value: String) {
+    /**
+     * List of known transport protocols.
+     */
+    public companion object {
+        /**
+         * JSON-RPC protocol.
+         */
+        public val JSONRPC: TransportProtocol = TransportProtocol("JSONRPC")
+
+        /**
+         * HTTP+JSON/REST protocol.
+         */
+        public val HTTP_JSON_REST: TransportProtocol = TransportProtocol("HTTP+JSON/REST")
+
+        /**
+         * GRPC protocol.
+         */
+        public val GRPC: TransportProtocol = TransportProtocol("GRPC")
+    }
+}
+
+/**
+ * Declares a combination of a target URL and a transport protocol for interacting with the agent.
+ * This allows agents to expose the same functionality over multiple transport mechanisms.
+ *
+ * @property url The URL where this interface is available.
Must be a valid absolute HTTPS URL in production.
+ *
+ * Examples: ["https://api.example.com/a2a/v1", "https://grpc.example.com/a2a", "https://rest.example.com/v1"].
+ *
+ * @property transport The transport protocol supported at this URL.
+ *
+ * Examples: ["JSONRPC", "GRPC", "HTTP+JSON"].
+ */
+@Serializable
+public data class AgentInterface(
+    public val url: String,
+    public val transport: TransportProtocol
+)
+
+/**
+ * Represents the service provider of an agent.
+ *
+ * @property organization The name of the agent provider's organization.
+ * @property url A URL for the agent provider's website or relevant documentation.
+ */
+@Serializable
+public data class AgentProvider(
+    public val organization: String,
+    public val url: String
+)
+
+/**
+ * Defines optional capabilities supported by an agent.
+ *
+ * @property streaming Indicates if the agent supports Server-Sent Events (SSE) for streaming responses.
+ *
+ * @property pushNotifications Indicates if the agent supports sending push notifications for asynchronous task updates.
+ *
+ * @property stateTransitionHistory Indicates if the agent provides a history of state transitions for a task.
+ *
+ * TODO: it's not clear from the specification and official Python SDK, what does this field control.
+ * It's not [Task.history], since it always should be present.
+ * There are no further mentions or usages of this field in the official sources.
+ * So currently in our implementation it does not control anything.
+ *
+ * @property extensions A list of protocol extensions supported by the agent.
+ */
+@Serializable
+public data class AgentCapabilities(
+    public val streaming: Boolean? = null,
+    public val pushNotifications: Boolean? = null,
+    public val stateTransitionHistory: Boolean? = null,
+    public val extensions: List<AgentExtension>? = null
+)
+
+/**
+ * A declaration of a protocol extension supported by an Agent.
+ *
+ * @property uri The unique URI identifying the extension.
+ * @property description A human-readable description of how this agent uses the extension.
+ * @property required If true, the client must understand and comply with the extension's requirements to interact with the agent.
+ * @property params Optional, extension-specific configuration parameters.
+ */
+@Serializable
+public data class AgentExtension(
+    public val uri: String,
+    public val description: String? = null,
+    public val required: Boolean? = null,
+    public val params: Map<String, JsonElement>? = null
+)
+
+/**
+ * A declaration of the security schemes available to authorize requests. The key is the scheme name. The value is the
+ * declaration of the security scheme object, which follows the OpenAPI 3.0 Security Scheme Object.
+ */
+public typealias SecuritySchemes = Map<String, SecurityScheme>
+
+/**
+ * A list of alternative security requirements (a logical OR). To authorize a request, a client must satisfy one of the
+ * [SecurityRequirement]s in this list.
+ *
+ * For example, `[{"oauth": ["read"]}, {"apiKey": [], "mtls": []}]` means a client can use either OAuth with the "read" scope
+ * or both an API key and mTLS.
+ *
+ * @see [https://swagger.io/specification/#security-requirement-object]
+ */
+public typealias Security = List<SecurityRequirement>
+
+/**
+ * A set of security schemes that must be satisfied together (a logical AND). The key is a security scheme name, and the
+ * value is a list of required scopes.
+ *
+ * For example, `{"apiKey": [], "mtls": []}` requires both an API key and mTLS.
+ *
+ * @see [https://swagger.io/specification/#security-requirement-object]
+ */
+public typealias SecurityRequirement = Map<String, List<String>>
+
+/**
+ * Defines a security scheme that can be used to secure an agent's endpoints.
+ * This is a discriminated union type based on the OpenAPI 3.0 Security Scheme Object.
+ * + * @see [https://swagger.io/specification/#security-scheme-object] + */ +@Serializable(with = SecuritySchemeSerializer::class) +public sealed interface SecurityScheme { + /** + * The type of the security scheme, used as discriminator. + */ + public val type: String +} + +/** + * Defines a security scheme using an API key. + * + * @property description An optional description for the security scheme. + * @property in The location of the API key. + * @property name The name of the header, query, or cookie parameter to be used. + */ +@Serializable +public data class APIKeySecurityScheme( + @SerialName("in") + public val `in`: In, + public val name: String, + public val description: String? = null, +) : SecurityScheme { + @EncodeDefault + override val type: String = "apiKey" +} + +/** + * The location of the API key. + */ +@Serializable +public enum class In { + @SerialName("cookie") + Cookie, + + @SerialName("header") + Header, + + @SerialName("query") + Query +} + +/** + * Defines a security scheme using HTTP authentication. + * + * @property scheme The name of the HTTP Authentication scheme to be used in the Authorization header, + * as defined in RFC7235 (e.g., "Bearer"). This value should be registered in the IANA Authentication Scheme registry. + * @property bearerFormat A hint to the client to identify how the bearer token is formatted (e.g., "JWT"). + * This is primarily for documentation purposes. + * @property description An optional description for the security scheme. + */ +@Serializable +public data class HTTPAuthSecurityScheme( + public val scheme: String, + public val bearerFormat: String? = null, + public val description: String? = null, +) : SecurityScheme { + @EncodeDefault + override val type: String = "http" +} + +/** + * Defines a security scheme using OAuth 2.0. + * + * @property flows An object containing configuration information for the supported OAuth 2.0 flows. 
+ * @property oauth2MetadataUrl URL to the oauth2 authorization server metadata
+ * [RFC8414](https://datatracker.ietf.org/doc/html/rfc8414). TLS is required.
+ * @property description An optional description for the security scheme.
+ */
+@Serializable
+public data class OAuth2SecurityScheme(
+    public val flows: OAuthFlows,
+    public val oauth2MetadataUrl: String? = null,
+    public val description: String? = null,
+) : SecurityScheme {
+    @EncodeDefault
+    override val type: String = "oauth2"
+}
+
+/**
+ * Defines the configuration for the supported OAuth 2.0 flows.
+ *
+ * @property authorizationCode Configuration for the OAuth Authorization Code flow. Previously called accessCode in OpenAPI 2.0.
+ * @property clientCredentials Configuration for the OAuth Client Credentials flow. Previously called application in OpenAPI 2.0.
+ * @property implicit Configuration for the OAuth Implicit flow.
+ * @property password Configuration for the OAuth Resource Owner Password flow.
+ */
+@Serializable
+public data class OAuthFlows(
+    public val authorizationCode: AuthorizationCodeOAuthFlow? = null,
+    public val clientCredentials: ClientCredentialsOAuthFlow? = null,
+    public val implicit: ImplicitOAuthFlow? = null,
+    public val password: PasswordOAuthFlow? = null
+)
+
+/**
+ * Common interface for OAuth 2.0 flows.
+ */
+@Serializable
+public sealed interface OAuthFlow {
+    /**
+     * The available scopes for the OAuth2 security scheme. A map between the scope name and a short description for it.
+     */
+    public val scopes: Map<String, String>
+
+    /**
+     * The URL to be used for obtaining refresh tokens. This MUST be a URL and use TLS.
+     */
+    public val refreshUrl: String?
+}
+
+/**
+ * Defines configuration details for the OAuth 2.0 Authorization Code flow.
+ *
+ * @property authorizationUrl The authorization URL to be used for this flow. This MUST be a URL and use TLS.
+ * @property tokenUrl The token URL to be used for this flow. This MUST be a URL and use TLS.
+ */
+@Serializable
+public data class AuthorizationCodeOAuthFlow(
+    public val authorizationUrl: String,
+    public val tokenUrl: String,
+    override val scopes: Map<String, String>,
+    override val refreshUrl: String? = null
+) : OAuthFlow
+
+/**
+ * Defines configuration details for the OAuth 2.0 Client Credentials flow.
+ *
+ * @property tokenUrl The token URL to be used for this flow. This MUST be a URL.
+ */
+@Serializable
+public data class ClientCredentialsOAuthFlow(
+    public val tokenUrl: String,
+    override val scopes: Map<String, String>,
+    override val refreshUrl: String? = null
+) : OAuthFlow
+
+/**
+ * Defines configuration details for the OAuth 2.0 Implicit flow.
+ *
+ * @property authorizationUrl The authorization URL to be used for this flow. This MUST be a URL.
+ */
+@Serializable
+public data class ImplicitOAuthFlow(
+    public val authorizationUrl: String,
+    override val scopes: Map<String, String>,
+    override val refreshUrl: String? = null
+) : OAuthFlow
+
+/**
+ * Defines configuration details for the OAuth 2.0 Resource Owner Password flow.
+ *
+ * @property tokenUrl The token URL to be used for this flow. This MUST be a URL.
+ */
+@Serializable
+public data class PasswordOAuthFlow(
+    public val tokenUrl: String,
+    override val scopes: Map<String, String>,
+    override val refreshUrl: String? = null
+) : OAuthFlow
+
+/**
+ * Defines a security scheme using OpenID Connect.
+ *
+ * @property openIdConnectUrl The OpenID Connect Discovery URL for the OIDC provider's metadata.
+ * @property description An optional description for the security scheme.
+ */
+@Serializable
+public data class OpenIdConnectSecurityScheme(
+    public val openIdConnectUrl: String,
+    public val description: String? = null,
+) : SecurityScheme {
+    @EncodeDefault
+    override val type: String = "openIdConnect"
+}
+
+/**
+ * Defines a security scheme using mTLS authentication.
+ *
+ * @property description An optional description for the security scheme.
+ */
+@Serializable
+public data class MutualTLSSecurityScheme(
+    public val description: String?
= null,
+) : SecurityScheme {
+    @EncodeDefault
+    override val type: String = "mutualTLS"
+}
+
+/**
+ * Represents a distinct capability or function that an agent can perform.
+ *
+ * @property id A unique identifier for the agent's skill.
+ *
+ * @property name A human-readable name for the skill.
+ *
+ * @property description A detailed description of the skill, intended to help clients or users understand its purpose and functionality.
+ *
+ * @property tags A set of keywords describing the skill's capabilities.
+ *
+ * Examples: ["cooking", "customer support", "billing"].
+ *
+ * @property examples Example prompts or scenarios that this skill can handle. Provides a hint to the client on how to use the skill.
+ *
+ * Examples: ["I need a recipe for bread"].
+ *
+ * @property inputModes The set of supported input MIME types for this skill, overriding the agent's defaults.
+ *
+ * @property outputModes The set of supported output MIME types for this skill, overriding the agent's defaults.
+ *
+ * @property security Security schemes necessary for the agent to leverage this skill. As in the overall AgentCard.security,
+ * this list represents a logical OR of security requirement objects. Each object is a set of security schemes that must be
+ * used together (a logical AND).
+ *
+ * Examples: [{"google": ["oidc"]}].
+ */
+@Serializable
+public data class AgentSkill(
+    public val id: String,
+    public val name: String,
+    public val description: String,
+    public val tags: List<String>,
+    public val examples: List<String>? = null,
+    public val inputModes: List<String>? = null,
+    public val outputModes: List<String>? = null,
+    public val security: Security? = null
+)
+
+/**
+ * AgentCardSignature represents a JWS signature of an AgentCard.
+ * This follows the JSON format of an RFC 7515 JSON Web Signature (JWS).
+ *
+ * @property protected The protected JWS header for the signature. This is a Base64url-encoded JSON object, as per RFC 7515.
+ * @property signature The computed signature, Base64url-encoded.
+ * @property header The unprotected JWS header values.
+ */
+@Serializable
+public data class AgentCardSignature(
+    @SerialName("protected")
+    public val `protected`: String,
+    public val signature: String,
+    public val header: Map<String, JsonElement>? = null
+)
diff --git a/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/model/Artifact.kt b/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/model/Artifact.kt
new file mode 100644
index 0000000000..74237903e9
--- /dev/null
+++ b/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/model/Artifact.kt
@@ -0,0 +1,24 @@
+package ai.koog.a2a.model
+
+import kotlinx.serialization.Serializable
+import kotlinx.serialization.json.JsonObject
+
+/**
+ * Represents a file, data structure, or other resource generated by an agent during a task.
+ *
+ * @property artifactId A unique identifier (e.g. UUID) for the artifact within the scope of the task.
+ * @property name An optional, human-readable name for the artifact.
+ * @property description An optional, human-readable description of the artifact.
+ * @property parts A list of content parts that make up the artifact.
+ * @property extensions Optional URIs of extensions that are relevant to this artifact.
+ * @property metadata Optional metadata for extensions. The key is an extension-specific identifier.
+ */
+@Serializable
+public data class Artifact(
+    public val artifactId: String,
+    public val name: String? = null,
+    public val description: String? = null,
+    public val parts: List<Part>,
+    public val extensions: List<String>? = null,
+    public val metadata: JsonObject?
= null, +) diff --git a/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/model/Core.kt b/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/model/Core.kt new file mode 100644 index 0000000000..ab3de34149 --- /dev/null +++ b/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/model/Core.kt @@ -0,0 +1,36 @@ +package ai.koog.a2a.model + +import kotlinx.serialization.Serializable + +/** + * Base interface for events. + */ +@Serializable(with = EventSerializer::class) +public sealed interface Event { + /** + * The type used as discriminator. + */ + public val kind: String +} + +/** + * Base interface for communication events, such as messages or tasks. + */ +@Serializable(with = CommunicationEventSerializer::class) +public sealed interface CommunicationEvent : Event + +/** + * Base interface for task events. + */ +@Serializable(with = TaskEventSerializer::class) +public sealed interface TaskEvent : Event { + /** + * The ID of the task associated with this event. + */ + public val taskId: String + + /** + * The ID of the context associated with this event. + */ + public val contextId: String +} diff --git a/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/model/Message.kt b/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/model/Message.kt new file mode 100644 index 0000000000..b479a67b64 --- /dev/null +++ b/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/model/Message.kt @@ -0,0 +1,46 @@ +package ai.koog.a2a.model + +import kotlinx.serialization.EncodeDefault +import kotlinx.serialization.SerialName +import kotlinx.serialization.Serializable +import kotlinx.serialization.json.JsonObject + +/** + * Message role. + */ +@Serializable +public enum class Role { + @SerialName("user") + User, + + @SerialName("agent") + Agent +} + +/** + * Represents a single message in the conversation between a user and an agent. + * + * @property messageId A unique identifier for the message, typically a UUID, generated by the sender. + * @property role Identifies the sender of the message. 
`user` for the client, `agent` for the service.
+ * @property parts An array of content parts that form the message body. A message can be
+ * composed of multiple parts of different types (e.g., text and files).
+ * @property extensions The URIs of extensions that are relevant to this message.
+ * @property taskId The ID of the task this message is part of. Can be omitted for the first message of a new task.
+ * @property referenceTaskIds A list of other task IDs that this message references for additional context.
+ * @property contextId The context ID for this message, used to group related interactions.
+ * @property metadata Optional metadata for extensions. The key is an extension-specific identifier.
+ */
+@Serializable
+public data class Message(
+    public val messageId: String,
+    public val role: Role,
+    public val parts: List<Part>,
+    public val extensions: List<String>? = null,
+    public val taskId: String? = null,
+    public val referenceTaskIds: List<String>? = null,
+    public val contextId: String? = null,
+    public val metadata: JsonObject? = null,
+) : CommunicationEvent {
+    @EncodeDefault
+    override val kind: String = "message"
+}
diff --git a/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/model/MessageSendParams.kt b/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/model/MessageSendParams.kt
new file mode 100644
index 0000000000..2d72fc5e45
--- /dev/null
+++ b/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/model/MessageSendParams.kt
@@ -0,0 +1,35 @@
+package ai.koog.a2a.model
+
+import kotlinx.serialization.Serializable
+import kotlinx.serialization.json.JsonObject
+
+/**
+ * Defines the parameters for a request to send a message to an agent. This can be used
+ * to create a new task, continue an existing one, or restart a task.
+ *
+ * @property message The message object being sent to the agent.
+ * @property configuration Optional configuration for the send request.
+ * @property metadata Optional metadata for extensions.
+ */
+@Serializable
+public data class MessageSendParams(
+    public val message: Message,
+    public val configuration: MessageSendConfiguration? = null,
+    public val metadata: JsonObject? = null,
+)
+
+/**
+ * Defines configuration options for a `message/send` or `message/stream` request.
+ *
+ * @property acceptedOutputModes A list of output MIME types the client is prepared to accept in the response.
+ * @property historyLength The number of most recent messages from the task's history to retrieve in the response.
+ * @property pushNotificationConfig Configuration for the agent to send push notifications for updates after the initial response.
+ * @property blocking If true, the client will wait for the task to complete. The server may reject this if the task is long-running.
+ */
+@Serializable
+public data class MessageSendConfiguration(
+    public val blocking: Boolean? = null,
+    public val acceptedOutputModes: List<String>? = null,
+    public val historyLength: Int? = null,
+    public val pushNotificationConfig: PushNotificationConfig? = null,
+)
diff --git a/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/model/Parts.kt b/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/model/Parts.kt
new file mode 100644
index 0000000000..abbe4ac047
--- /dev/null
+++ b/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/model/Parts.kt
@@ -0,0 +1,103 @@
+package ai.koog.a2a.model
+
+import kotlinx.serialization.EncodeDefault
+import kotlinx.serialization.Serializable
+import kotlinx.serialization.json.JsonObject
+
+/**
+ * Represents a part of a message or artifact.
+ */
+@Serializable(with = PartSerializer::class)
+public sealed interface Part {
+    /**
+     * The type of the part, used as discriminator.
+     */
+    public val kind: String
+
+    /**
+     * Optional metadata associated with this part.
+     */
+    public val metadata: JsonObject?
+}
+
+/**
+ * Represents a text part.
+ *
+ * @property text The string content of the text part.
+ */ +@Serializable +public data class TextPart( + public val text: String, + override val metadata: JsonObject? = null, +) : Part { + @EncodeDefault + override val kind: String = "text" +} + +/** + * Represents a file part. The file content can be provided either directly as bytes or as a URI. + * + * @property file The file content. + */ +@Serializable +public data class FilePart( + public val file: File, + override val metadata: JsonObject? = null, +) : Part { + @EncodeDefault + override val kind: String = "file" +} + +/** + * Represents a file within a part. + */ +@Serializable(with = FileSerializer::class) +public sealed interface File { + /** + * An optional name for the file (e.g., "document.pdf"). + */ + public val name: String? + + /** + * An optional MIME type of the file (e.g., "application/pdf"). + */ + public val mimeType: String? +} + +/** + * Represents a file with its content provided directly as a base64-encoded string. + * + * @property bytes The base64-encoded content of the file. + */ +@Serializable +public data class FileWithBytes( + public val bytes: String, + override val name: String? = null, + override val mimeType: String? = null, +) : File + +/** + * Represents a file with its content located at a specific URI. + * + * @property uri A URL pointing to the file's content. + */ +@Serializable +public data class FileWithUri( + public val uri: String, + override val name: String? = null, + override val mimeType: String? = null, +) : File + +/** + * Represents a structured data part (e.g., JSON). + * + * @property data The structured data content. + */ +@Serializable +public data class DataPart( + public val data: JsonObject, + override val metadata: JsonObject? 
= null, +) : Part { + @EncodeDefault + override val kind: String = "data" +} diff --git a/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/model/Serialization.kt b/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/model/Serialization.kt new file mode 100644 index 0000000000..78c72f911a --- /dev/null +++ b/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/model/Serialization.kt @@ -0,0 +1,92 @@ +package ai.koog.a2a.model + +import kotlinx.serialization.DeserializationStrategy +import kotlinx.serialization.SerializationException +import kotlinx.serialization.json.JsonContentPolymorphicSerializer +import kotlinx.serialization.json.JsonElement +import kotlinx.serialization.json.jsonObject +import kotlinx.serialization.json.jsonPrimitive + +internal object SecuritySchemeSerializer : JsonContentPolymorphicSerializer(SecurityScheme::class) { + override fun selectDeserializer(element: JsonElement): DeserializationStrategy { + val jsonObject = element.jsonObject + val type = jsonObject["type"]?.jsonPrimitive?.content ?: throw SerializationException("Missing 'type' field in SecurityScheme") + + return when (type) { + "apiKey" -> APIKeySecurityScheme.serializer() + "http" -> HTTPAuthSecurityScheme.serializer() + "oauth2" -> OAuth2SecurityScheme.serializer() + "openIdConnect" -> OpenIdConnectSecurityScheme.serializer() + "mutualTLS" -> MutualTLSSecurityScheme.serializer() + else -> throw SerializationException("Unknown SecurityScheme type: $type") + } + } +} + +internal object PartSerializer : JsonContentPolymorphicSerializer(Part::class) { + override fun selectDeserializer(element: JsonElement): DeserializationStrategy { + val jsonObject = element.jsonObject + val kind = jsonObject["kind"]?.jsonPrimitive?.content ?: throw SerializationException("Missing 'kind' field in Part") + + return when (kind) { + "text" -> TextPart.serializer() + "file" -> FilePart.serializer() + "data" -> DataPart.serializer() + else -> throw SerializationException("Unknown Part kind: $kind") + } + } +} + 
+internal object FileSerializer : JsonContentPolymorphicSerializer<File>(File::class) {
+    override fun selectDeserializer(element: JsonElement): DeserializationStrategy<File> {
+        val jsonObject = element.jsonObject
+
+        return when {
+            "bytes" in jsonObject -> FileWithBytes.serializer()
+            "uri" in jsonObject -> FileWithUri.serializer()
+            else -> throw SerializationException("Unknown File type")
+        }
+    }
+}
+
+internal object EventSerializer : JsonContentPolymorphicSerializer<Event>(Event::class) {
+    override fun selectDeserializer(element: JsonElement): DeserializationStrategy<Event> {
+        val jsonObject = element.jsonObject
+        val kind = jsonObject["kind"]?.jsonPrimitive?.content ?: throw SerializationException("Missing 'kind' field in Event")
+
+        return when (kind) {
+            "status-update" -> TaskStatusUpdateEvent.serializer()
+            "artifact-update" -> TaskArtifactUpdateEvent.serializer()
+            "task" -> Task.serializer()
+            "message" -> Message.serializer()
+            else -> throw SerializationException("Unknown kind: $kind")
+        }
+    }
+}
+
+internal object CommunicationEventSerializer : JsonContentPolymorphicSerializer<CommunicationEvent>(CommunicationEvent::class) {
+    override fun selectDeserializer(element: JsonElement): DeserializationStrategy<CommunicationEvent> {
+        val jsonObject = element.jsonObject
+        val kind = jsonObject["kind"]?.jsonPrimitive?.content ?: throw SerializationException("Missing 'kind' field in CommunicationEvent")
+
+        return when (kind) {
+            "task" -> Task.serializer()
+            "message" -> Message.serializer()
+            else -> throw SerializationException("Unknown kind: $kind")
+        }
+    }
+}
+
+internal object TaskEventSerializer : JsonContentPolymorphicSerializer<TaskEvent>(TaskEvent::class) {
+    override fun selectDeserializer(element: JsonElement): DeserializationStrategy<TaskEvent> {
+        val jsonObject = element.jsonObject
+        val kind = jsonObject["kind"]?.jsonPrimitive?.content ?: throw SerializationException("Missing 'kind' field in TaskEvent")
+
+        return when (kind) {
+            "task" -> Task.serializer()
+            "status-update" -> TaskStatusUpdateEvent.serializer()
"artifact-update" -> TaskArtifactUpdateEvent.serializer() + else -> throw SerializationException("Unknown kind: $kind") + } + } +} diff --git a/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/model/Task.kt b/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/model/Task.kt new file mode 100644 index 0000000000..aa007b68a7 --- /dev/null +++ b/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/model/Task.kt @@ -0,0 +1,109 @@ +package ai.koog.a2a.model + +import kotlinx.datetime.Clock +import kotlinx.datetime.Instant +import kotlinx.serialization.EncodeDefault +import kotlinx.serialization.SerialName +import kotlinx.serialization.Serializable +import kotlinx.serialization.json.JsonObject + +/** + * Represents a single, stateful operation or conversation between a client and an agent. + * + * @property id A unique identifier (e.g. UUID) for the task, generated by the server for a new task. + * @property contextId A server-generated unique identifier (e.g. UUID) for maintaining context across multiple related tasks or interactions. + * @property status The current status of the task, including its state and a descriptive message. + * @property history An array of messages exchanged during the task, representing the conversation history. + * @property artifacts A collection of artifacts generated by the agent during the execution of the task. + * @property metadata Optional metadata for extensions. The key is an extension-specific identifier. + */ +@Serializable +public data class Task( + public val id: String, + override val contextId: String, + public val status: TaskStatus, + public val history: List? = null, + public val artifacts: List? = null, + public val metadata: JsonObject? = null, +) : CommunicationEvent, TaskEvent { + @EncodeDefault + override val kind: String = "task" + + override val taskId: String get() = id +} + +/** + * Represents the status of a task at a specific point in time. + * + * @property state The current state of the task's lifecycle. 
+ * @property message An optional, human-readable message providing more details about the current status. + * @property timestamp Timestamp indicating when this status was recorded. + */ +@Serializable +public data class TaskStatus( + public val state: TaskState, + public val message: Message? = null, + public val timestamp: Instant? = Clock.System.now(), +) + +/** + * Defines the lifecycle states of a Task. + * + * @property terminal Indicates whether this is a terminal state. + */ +@Serializable +public enum class TaskState(public val terminal: Boolean) { + /** + * The task has been submitted and is awaiting execution. + */ + @SerialName("submitted") + Submitted(terminal = false), + + /** + * The agent is actively working on the task. + */ + @SerialName("working") + Working(terminal = false), + + /** + * The task is paused and waiting for input from the user. + */ + @SerialName("input-required") + InputRequired(terminal = false), + + /** + * The task has been successfully completed. + */ + @SerialName("completed") + Completed(terminal = true), + + /** + * The task has been canceled by the user. + */ + @SerialName("canceled") + Canceled(terminal = true), + + /** + * The task failed due to an error during execution. + */ + @SerialName("failed") + Failed(terminal = true), + + /** + * The task was rejected by the agent and was not started. + */ + @SerialName("rejected") + Rejected(terminal = true), + + /** + * The task requires authentication to proceed. + */ + @SerialName("auth-required") + AuthRequired(terminal = false), + + /* + * The task is in an unknown or indeterminate state. 
+ */ + @SerialName("unknown") + Unknown(terminal = false), +} diff --git a/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/model/TaskEvents.kt b/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/model/TaskEvents.kt new file mode 100644 index 0000000000..2187b90145 --- /dev/null +++ b/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/model/TaskEvents.kt @@ -0,0 +1,51 @@ +package ai.koog.a2a.model + +import kotlinx.serialization.EncodeDefault +import kotlinx.serialization.Serializable +import kotlinx.serialization.json.JsonObject + +/** + * An event sent by the agent to notify the client of a change in a task's status. + * This is typically used in streaming or subscription models. + * + * @property taskId The ID of the task that was updated. + * @property contextId The context ID associated with the task. + * @property status The new status of the task. + * @property final If true, this is the final event in the stream for this interaction. + * @property metadata Optional metadata for extensions. + */ +@Serializable +public data class TaskStatusUpdateEvent( + override val taskId: String, + override val contextId: String, + public val status: TaskStatus, + public val final: Boolean, + public val metadata: JsonObject? = null, +) : TaskEvent { + @EncodeDefault + override val kind: String = "status-update" +} + +/** + * An event sent by the agent to notify the client that an artifact has been + * generated or updated. This is typically used in streaming models. + * + * @property taskId The ID of the task this artifact belongs to. + * @property contextId The context ID associated with the task. + * @property artifact The artifact that was generated or updated. + * @property append If true, the content of this artifact should be appended to a previously sent artifact with the same ID. + * @property lastChunk If true, this is the final chunk of the artifact. + * @property metadata Optional metadata for extensions. 
+ */ +@Serializable +public data class TaskArtifactUpdateEvent( + override val taskId: String, + override val contextId: String, + public val artifact: Artifact, + public val append: Boolean? = null, + public val lastChunk: Boolean? = null, + public val metadata: JsonObject? = null, +) : TaskEvent { + @EncodeDefault + override val kind: String = "artifact-update" +} diff --git a/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/model/TaskParams.kt b/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/model/TaskParams.kt new file mode 100644 index 0000000000..606d70806c --- /dev/null +++ b/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/model/TaskParams.kt @@ -0,0 +1,44 @@ +package ai.koog.a2a.model + +import kotlinx.serialization.Serializable +import kotlinx.serialization.json.JsonObject + +/** + * Defines parameters containing a task ID, used for simple task operations. + * + * @property id The unique identifier (e.g. UUID) of the task. + * @property metadata Optional metadata associated with this request. + */ +@Serializable +public data class TaskIdParams( + public val id: String, + public val metadata: JsonObject? = null, +) + +/** + * Defines parameters for querying a task, with an option to limit history length. + * + * @property id The unique identifier (e.g. UUID) of the task. + * @property historyLength The number of most recent messages from the task's history to retrieve. + * @property metadata Optional metadata associated with this request. + */ +@Serializable +public data class TaskQueryParams( + public val id: String, + public val historyLength: Int? = null, + public val metadata: JsonObject? = null, +) + +/** + * Defines parameters for fetching a specific push notification configuration for a task. + * + * @property id The unique identifier (e.g. UUID) of the task. + * @property pushNotificationConfigId The ID of the push notification configuration to retrieve. + * @property metadata Optional metadata associated with this request. 
+ */ +@Serializable +public data class TaskPushNotificationConfigParams( + public val id: String, + public val pushNotificationConfigId: String? = null, + public val metadata: JsonObject? = null, +) diff --git a/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/model/TaskPushNotificationConfig.kt b/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/model/TaskPushNotificationConfig.kt new file mode 100644 index 0000000000..f41cbbfe6d --- /dev/null +++ b/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/model/TaskPushNotificationConfig.kt @@ -0,0 +1,43 @@ +package ai.koog.a2a.model + +import kotlinx.serialization.Serializable + +/** + * A container associating a push notification configuration with a specific task. + * + * @property taskId The unique identifier (e.g. UUID) of the task. + * @property pushNotificationConfig The push notification configuration for this task. + */ +@Serializable +public data class TaskPushNotificationConfig( + public val taskId: String, + public val pushNotificationConfig: PushNotificationConfig, +) + +/** + * Defines the configuration for setting up push notifications for task updates. + * + * @property id A unique identifier (e.g. UUID) for the push notification configuration, set by the client to support multiple notification callbacks. + * @property url The callback URL where the agent should send push notifications. + * @property token A unique token for this task or session to validate incoming push notifications. + * @property authentication Optional authentication details for the agent to use when calling the notification URL. + */ +@Serializable +public data class PushNotificationConfig( + public val id: String? = null, + public val url: String, + public val token: String? = null, + public val authentication: PushNotificationAuthenticationInfo? = null, +) + +/** + * Defines authentication details for a push notification endpoint. + * + * @property schemes A list of supported authentication schemes (e.g., 'Basic', 'Bearer'). 
+ * @property credentials Optional credentials required by the push notification endpoint.
+ */
+@Serializable
+public data class PushNotificationAuthenticationInfo(
+    public val schemes: List<String>,
+    public val credentials: String? = null,
+)
diff --git a/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/transport/ClientTransport.kt b/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/transport/ClientTransport.kt
new file mode 100644
index 0000000000..ca1821c373
--- /dev/null
+++ b/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/transport/ClientTransport.kt
@@ -0,0 +1,143 @@
+package ai.koog.a2a.transport
+
+import ai.koog.a2a.exceptions.A2AException
+import ai.koog.a2a.model.AgentCard
+import ai.koog.a2a.model.CommunicationEvent
+import ai.koog.a2a.model.Event
+import ai.koog.a2a.model.MessageSendParams
+import ai.koog.a2a.model.Task
+import ai.koog.a2a.model.TaskIdParams
+import ai.koog.a2a.model.TaskPushNotificationConfig
+import ai.koog.a2a.model.TaskPushNotificationConfigParams
+import ai.koog.a2a.model.TaskQueryParams
+import kotlinx.coroutines.flow.Flow
+import kotlinx.serialization.Serializable
+import kotlinx.serialization.SerializationException
+
+/**
+ * Client transport making requests to [A2A protocol methods](https://a2a-protocol.org/latest/specification/#7-protocol-rpc-methods)
+ * and handling responses from the server.
+ *
+ * Client transport must handle error responses from the server and convert them to appropriate [A2AException]
+ * (e.g. parsing error response data format like JSON error object and throwing corresponding [A2AException] based on the error code).
+ * It must preserve the [A2AException.errorCode] received from the [ServerTransport].
+ *
+ * Client transport may throw exceptions other than [A2AException] for any transport-level errors (e.g. network failures, invalid responses, timeout),
+ * e.g.
[SerializationException] + */ +public interface ClientTransport : AutoCloseable { + /** + * Calls [agent/getAuthenticatedExtendedCard](https://a2a-protocol.org/latest/specification/#710-agentgetauthenticatedextendedcard) + * + * @throws A2AException if server returned an error. + */ + public suspend fun getAuthenticatedExtendedAgentCard( + request: Request, + ctx: ClientCallContext = ClientCallContext.Default + ): Response + + /** + * Calls [message/send](https://a2a-protocol.org/latest/specification/#71-messagesend). + * + * @throws A2AException if server returned an error. + */ + public suspend fun sendMessage( + request: Request, + ctx: ClientCallContext = ClientCallContext.Default + ): Response + + /** + * Calls [message/stream](https://a2a-protocol.org/latest/specification/#72-messagestream) + * + * @throws A2AException if server returned an error. + */ + public fun sendMessageStreaming( + request: Request, + ctx: ClientCallContext = ClientCallContext.Default + ): Flow> + + /** + * Calls [tasks/get](https://a2a-protocol.org/latest/specification/#73-tasksget) + * + * @throws A2AException if server returned an error. + */ + public suspend fun getTask( + request: Request, + ctx: ClientCallContext = ClientCallContext.Default + ): Response + + /** + * Calls [tasks/cancel](https://a2a-protocol.org/latest/specification/#74-taskscancel) + * + * @throws A2AException if server returned an error. + */ + public suspend fun cancelTask( + request: Request, + ctx: ClientCallContext = ClientCallContext.Default + ): Response + + /** + * Calls [tasks/resubscribe](https://a2a-protocol.org/latest/specification/#79-tasksresubscribe) + * + * @throws A2AException if server returned an error. 
+ */ + public fun resubscribeTask( + request: Request, + ctx: ClientCallContext = ClientCallContext.Default + ): Flow> + + /** + * Calls [tasks/pushNotificationConfig/set](https://a2a-protocol.org/latest/specification/#75-taskspushnotificationconfigset) + * + * @throws A2AException if server returned an error. + */ + public suspend fun setTaskPushNotificationConfig( + request: Request, + ctx: ClientCallContext = ClientCallContext.Default + ): Response + + /** + * Calls [tasks/pushNotificationConfig/get](https://a2a-protocol.org/latest/specification/#76-taskspushnotificationconfigget) + * + * @throws A2AException if server returned an error. + */ + public suspend fun getTaskPushNotificationConfig( + request: Request, + ctx: ClientCallContext = ClientCallContext.Default + ): Response + + /** + * Calls [tasks/pushNotificationConfig/list](https://a2a-protocol.org/latest/specification/#77-taskspushnotificationconfiglist) + * + * @throws A2AException if server returned an error. + */ + public suspend fun listTaskPushNotificationConfig( + request: Request, + ctx: ClientCallContext = ClientCallContext.Default + ): Response> + + /** + * Calls [tasks/pushNotificationConfig/delete](https://a2a-protocol.org/latest/specification/#78-taskspushnotificationconfigdelete) + * + * @throws A2AException if server returned an error. + */ + public suspend fun deleteTaskPushNotificationConfig( + request: Request, + ctx: ClientCallContext = ClientCallContext.Default + ): Response +} + +/** + * Represents the client context of a call. + * + * @property additionalHeaders Additional call-specific headers associated with the call. 
+ */ +@Serializable +public data class ClientCallContext( + public val additionalHeaders: Map> = emptyMap(), +) { + @Suppress("MissingKDocForPublicAPI") + public companion object { + public val Default: ClientCallContext = ClientCallContext() + } +} diff --git a/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/transport/Core.kt b/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/transport/Core.kt new file mode 100644 index 0000000000..ad4c2e9437 --- /dev/null +++ b/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/transport/Core.kt @@ -0,0 +1,47 @@ +package ai.koog.a2a.transport + +import kotlinx.serialization.Serializable +import kotlin.uuid.ExperimentalUuidApi +import kotlin.uuid.Uuid + +/** + * A uniquely identifying ID for a request. + */ +@Serializable(with = RequestIdSerializer::class) +public sealed interface RequestId { + /** + * A string representation of the ID. + */ + @Serializable + public data class StringId(val value: String) : RequestId + + /** + * A numeric representation of the ID. + */ + @Serializable + public data class NumberId(val value: Long) : RequestId +} + +/** + * Represents a request containing a unique identifier. + * + * @property id The unique identifier for the request. + * @property data The data payload of the request. + */ +@OptIn(ExperimentalUuidApi::class) +public class Request( + public val data: T, + public val id: RequestId = RequestId.StringId(Uuid.random().toString()), +) + +/** + * Represents a response associated with a request identifier. + * + * @property id The unique identifier for the request associated with this response. + * @property data The response data payload. 
+ */ +@OptIn(ExperimentalUuidApi::class) +public class Response( + public val data: T, + public val id: RequestId = RequestId.StringId(Uuid.random().toString()), +) diff --git a/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/transport/Serialization.kt b/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/transport/Serialization.kt new file mode 100644 index 0000000000..b55cec4339 --- /dev/null +++ b/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/transport/Serialization.kt @@ -0,0 +1,39 @@ +package ai.koog.a2a.transport + +import kotlinx.serialization.KSerializer +import kotlinx.serialization.SerializationException +import kotlinx.serialization.descriptors.SerialDescriptor +import kotlinx.serialization.descriptors.buildClassSerialDescriptor +import kotlinx.serialization.encoding.Decoder +import kotlinx.serialization.encoding.Encoder +import kotlinx.serialization.json.JsonDecoder +import kotlinx.serialization.json.JsonEncoder +import kotlinx.serialization.json.JsonPrimitive +import kotlinx.serialization.json.long +import kotlinx.serialization.json.longOrNull + +internal object RequestIdSerializer : KSerializer { + override val descriptor: SerialDescriptor = buildClassSerialDescriptor("RequestId") + + override fun deserialize(decoder: Decoder): RequestId { + val jsonDecoder = decoder as? JsonDecoder ?: error("Can only deserialize JSON") + + return when (val element = jsonDecoder.decodeJsonElement()) { + is JsonPrimitive -> when { + element.isString -> RequestId.StringId(element.content) + element.longOrNull != null -> RequestId.NumberId(element.long) + else -> throw SerializationException("Invalid RequestId type") + } + + else -> throw SerializationException("Invalid RequestId format") + } + } + + override fun serialize(encoder: Encoder, value: RequestId) { + val jsonEncoder = encoder as? 
JsonEncoder ?: throw SerializationException("Can only serialize JSON") + when (value) { + is RequestId.StringId -> jsonEncoder.encodeString(value.value) + is RequestId.NumberId -> jsonEncoder.encodeLong(value.value) + } + } +} diff --git a/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/transport/ServerTransport.kt b/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/transport/ServerTransport.kt new file mode 100644 index 0000000000..416b0cd80f --- /dev/null +++ b/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/transport/ServerTransport.kt @@ -0,0 +1,227 @@ +package ai.koog.a2a.transport + +import ai.koog.a2a.exceptions.A2AException +import ai.koog.a2a.exceptions.A2AInternalErrorException +import ai.koog.a2a.model.AgentCard +import ai.koog.a2a.model.CommunicationEvent +import ai.koog.a2a.model.Event +import ai.koog.a2a.model.MessageSendParams +import ai.koog.a2a.model.Task +import ai.koog.a2a.model.TaskIdParams +import ai.koog.a2a.model.TaskPushNotificationConfig +import ai.koog.a2a.model.TaskPushNotificationConfigParams +import ai.koog.a2a.model.TaskQueryParams +import kotlinx.coroutines.flow.Flow + +/** + * Server transport processing raw requests made to [A2A protocol methods](https://a2a-protocol.org/latest/specification/#7-protocol-rpc-methods) + * and delegating the processing to [RequestHandler]. + * + * Server transport must respond with appropriate [A2AException] in case of errors while processing the request + * (e.g. method not found or invalid method parameters). It must also handle [A2AException] thrown by the [RequestHandler] methods. + * In case non [A2AException] is thrown, it must be converted to [A2AInternalErrorException] with appropriate message. + * + * Server transport must convert [A2AException] to appropriate response data format (e.g. JSON error object), + * preserving the [A2AException.errorCode] so that it can be properly handled by the [ClientTransport]. 
+ */ +public interface ServerTransport { + /** + * Handler responsible for processing parsed A2A requests. + */ + public val requestHandler: RequestHandler +} + +/** + * Handler responsible for processing parsed A2A requests, implementing + * [A2A protocol methods](https://a2a-protocol.org/latest/specification/#7-protocol-rpc-methods). + */ +public interface RequestHandler { + /** + * Handles [agent/getAuthenticatedExtendedCard](https://a2a-protocol.org/latest/specification/#710-agentgetauthenticatedextendedcard) + * + * @throws A2AException if there is an error with processsing the request. + */ + public suspend fun onGetAuthenticatedExtendedAgentCard( + request: Request, + ctx: ServerCallContext + ): Response + + /** + * Handles [message/send](https://a2a-protocol.org/latest/specification/#71-messagesend). + * + * @throws A2AException if there is an error with processsing the request. + */ + public suspend fun onSendMessage( + request: Request, + ctx: ServerCallContext + ): Response + + /** + * Handles [message/stream](https://a2a-protocol.org/latest/specification/#72-messagestream) + * + * @throws A2AException if there is an error with processsing the request. + */ + public fun onSendMessageStreaming( + request: Request, + ctx: ServerCallContext + ): Flow> + + /** + * Handles [tasks/get](https://a2a-protocol.org/latest/specification/#73-tasksget) + * + * @throws A2AException if there is an error with processsing the request. + */ + public suspend fun onGetTask( + request: Request, + ctx: ServerCallContext + ): Response + + /** + * Handles [tasks/resubscribe](https://a2a-protocol.org/latest/specification/#79-tasksresubscribe) + * + * @throws A2AException if there is an error with processsing the request. 
+ */ + public fun onResubscribeTask( + request: Request, + ctx: ServerCallContext + ): Flow> + + /** + * Handles [tasks/cancel](https://a2a-protocol.org/latest/specification/#74-taskscancel) + * + * @throws A2AException if there is an error with processsing the request. + */ + public suspend fun onCancelTask( + request: Request, + ctx: ServerCallContext + ): Response + + /** + * Handles [tasks/pushNotificationConfig/set](https://a2a-protocol.org/latest/specification/#75-taskspushnotificationconfigset) + * + * @throws A2AException if there is an error with processsing the request. + */ + public suspend fun onSetTaskPushNotificationConfig( + request: Request, + ctx: ServerCallContext + ): Response + + /** + * Handles [tasks/pushNotificationConfig/get](https://a2a-protocol.org/latest/specification/#76-taskspushnotificationconfigget) + * + * @throws A2AException if there is an error with processsing the request. + */ + public suspend fun onGetTaskPushNotificationConfig( + request: Request, + ctx: ServerCallContext + ): Response + + /** + * Handles [tasks/pushNotificationConfig/list](https://a2a-protocol.org/latest/specification/#77-taskspushnotificationconfiglist) + * + * @throws A2AException if there is an error with processsing the request. + */ + public suspend fun onListTaskPushNotificationConfig( + request: Request, + ctx: ServerCallContext + ): Response> + + /** + * Handles [tasks/pushNotificationConfig/delete](https://a2a-protocol.org/latest/specification/#78-taskspushnotificationconfigdelete) + * + * @throws A2AException if there is an error with processsing the request. + */ + public suspend fun onDeleteTaskPushNotificationConfig( + request: Request, + ctx: ServerCallContext + ): Response +} + +/** + * Represents the server context of a call. + * + * This context has [state] associated with it, which is essentially an untyped map. It can be used to store arbitrary + * user-defined data. 
This is useful for extending the base logic with business-dependent logic, e.g., storing user + * information to authorize particular requests. This untyped [state] map has typed accessors for more convenient access, + * so it is recommended to use them when reading from state: [getFromState], [getFromStateOrNull]. + * + * **Note**: Make sure the types of [StateKey] and the value match when populating [state], otherwise [getFromState] + * and [getFromStateOrNull] will throw [IllegalStateException]. + * + * Example usage: + * ```kotlin + * // User-defined data class + * data class User(val id: String) + * + * // Collection of user-defined state keys + * object StateKeys { + * val USER_KEY = StateKey("42") + * } + * + * // On the handler side - copying supplied context and populating state + * override suspend fun onSendMessage( + * request: Request, + * ctx: ServerCallContext + * ): Response { + * val user = ctx.headers.getValue("user-id").let { User(it) } + * val newCtx = ctx.copy(state = ctx.state + (StateKeys.USER_KEY to user)) + * + * super.onSendMessage(request, newCtx) + * } + * + * // On the business logic side - retrieving user data from context + * val user = ctx.getFromState(StateKeys.USER_KEY) + * ``` + * + * @property headers Headers associated with the call. + * @property state State associated with the call, allows storing arbitrary values. To get typed value from the state, + * use [getFromState] or [getFromStateOrNull] with appropriate [StateKey]. + */ +public class ServerCallContext( + public val headers: Map> = emptyMap(), + public val state: Map, Any> = emptyMap() +) { + /** + * Retrieves a value of type [T] associated with the specified [key] from the [state] map. + * If the [key] is not found in the state, returns `null`. + * + * Performs unsafe cast under the hood, so make sure the value is of the expected type. + * + * @param key The state key for which the associated value needs to be retrieved. 
+ */ + public inline fun getFromStateOrNull(key: StateKey): T? { + return state[key]?.let { + it as? T ?: throw IllegalStateException("State value for key $key is not of expected type ${T::class}") + } + } + + /** + * Retrieves a value of type [T] associated with the specified [key] from the [state] map. + * + * Performs unsafe cast under the hood, so make sure the value is of the expected type. + * + * @param key The state key for which the associated value needs to be retrieved. + * @throws NoSuchElementException if the [key] is not found in the state. + */ + public inline fun getFromState(key: StateKey): T { + return getFromStateOrNull(key) ?: throw NoSuchElementException("State key $key not found or null") + } + + /** + * Creates a copy of this [ServerCallContext]. + */ + public fun copy( + headers: Map> = this.headers, + state: Map, Any> = this.state, + ): ServerCallContext = ServerCallContext(headers, state) +} + +/** + * Helper class to be used with [ServerCallContext.state] to store and retrieve values associated with a key in a typed + * manner. + * + * @see ServerCallContext + */ +public class StateKey<@Suppress("unused") T>(public val name: String) { + override fun toString(): String = "${super.toString()}(name=$name)" +} diff --git a/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/utils/KeyedMutex.kt b/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/utils/KeyedMutex.kt new file mode 100644 index 0000000000..73e13f7001 --- /dev/null +++ b/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/utils/KeyedMutex.kt @@ -0,0 +1,146 @@ +package ai.koog.a2a.utils + +import kotlinx.coroutines.sync.Mutex +import kotlinx.coroutines.sync.withLock +import kotlin.contracts.ExperimentalContracts +import kotlin.contracts.InvocationKind +import kotlin.contracts.contract + +/** + * A keyed, suspend-friendly mutex for Kotlin Multiplatform. + * + * Guarantees mutual exclusion per key, without blocking threads. 
+ * API mirrors kotlinx.coroutines Mutex:
+ * - lock(key, owner)
+ * - tryLock(key, owner)
+ * - unlock(key, owner)
+ * - withLock(key, owner) { ... }
+ *
+ * Internals:
+ * - A per-key entry holds a Mutex and a reference count (holders/waiters).
+ * - We increment refs before suspending for lock(key) so entries aren’t removed while waiting.
+ * - Cleanup removes the entry only when refs == 0 and the mutex is not locked.
+ */
+public class KeyedMutex<K> {
+ private class Entry(
+ val mutex: Mutex = Mutex(),
+ var refs: Int = 0
+ )
+
+ /**
+ * Protects access to [entries] and [Entry.refs] updates.
+ */
+ private val mapMutex = Mutex()
+ private val entries = mutableMapOf<K, Entry>()
+
+ /**
+ * Suspends until the lock for [key] is acquired.
+ * Not re-entrant for the same key within the same coroutine.
+ */
+ public suspend fun lock(key: K, owner: Any? = null) {
+ val entry = mapMutex.withLock {
+ val e = entries.getOrPut(key) { Entry() }
+ e.refs += 1
+ e
+ }
+
+ try {
+ entry.mutex.lock(owner)
+ } catch (t: Throwable) {
+ // If lock failed/cancelled before acquiring, roll back the ref and maybe cleanup.
+ mapMutex.withLock {
+ entry.refs -= 1
+ if (entry.refs == 0 && !entry.mutex.isLocked && entries[key] == entry) {
+ entries.remove(key)
+ }
+ }
+ throw t
+ }
+ }
+
+ /**
+ * Attempts to acquire the lock for [key] without suspension.
+ * Returns true if lock was acquired.
+ */
+ public suspend fun tryLock(key: K, owner: Any? = null): Boolean {
+ return mapMutex.withLock {
+ val existing = entries[key]
+ if (existing != null) {
+ if (existing.mutex.tryLock(owner)) {
+ existing.refs += 1
+ true
+ } else {
+ false
+ }
+ } else {
+ // Avoid inserting if we cannot lock; create a fresh entry and try immediately.
+ val e = Entry()
+ val locked = e.mutex.tryLock(owner)
+ if (locked) {
+ e.refs = 1
+ entries[key] = e
+ true
+ } else {
+ // Unlikely with a fresh Mutex, but don't insert on failure.
+ false
+ }
+ }
+ }
+ }
+
+ /**
+ * Releases the lock for [key].
+ * Must be called exactly once per successful lock/tryLock.
+ * @throws IllegalStateException on misuse (mirrors Mutex behavior).
+ */
+ public suspend fun unlock(key: K, owner: Any? = null) {
+ val entry = mapMutex.withLock {
+ entries[key] ?: throw IllegalStateException("Unlock requested for key without active entry")
+ }
+
+ // Perform the actual unlock; may throw if not locked or wrong owner.
+ entry.mutex.unlock(owner)
+
+ // Decrement refs and cleanup if safe.
+ mapMutex.withLock {
+ entry.refs -= 1
+ if (entry.refs == 0 && !entry.mutex.isLocked && entries[key] == entry) {
+ entries.remove(key)
+ }
+ }
+ }
+
+ /**
+ * Optional: observe whether a key appears locked.
+ * For diagnostics/metrics only (not a synchronization primitive).
+ */
+ public suspend fun isLocked(key: K): Boolean =
+ mapMutex.withLock { entries[key]?.mutex?.isLocked == true }
+
+ /**
+ * Checks whether this key is locked by the specified owner.
+ */
+ public suspend fun holdsLock(key: K, owner: Any): Boolean =
+ mapMutex.withLock { entries[key]?.mutex?.holdsLock(owner) == true }
+}
+
+/**
+ * Convenience function mirroring [Mutex.withLock]
+ */
+@OptIn(ExperimentalContracts::class)
+public suspend inline fun <K, T> KeyedMutex<K>.withLock(
+ key: K,
+ owner: Any? = null,
+ action: suspend () -> T
+): T {
+ contract {
+ callsInPlace(action, InvocationKind.EXACTLY_ONCE)
+ }
+
+ lock(key, owner)
+ try {
+ return action()
+ } finally {
+ unlock(key, owner)
+ }
+}
diff --git a/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/utils/RWLock.kt b/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/utils/RWLock.kt
new file mode 100644
index 0000000000..6705868ffa
--- /dev/null
+++ b/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/utils/RWLock.kt
@@ -0,0 +1,49 @@
+package ai.koog.a2a.utils
+
+import ai.koog.a2a.annotations.InternalA2AApi
+import kotlinx.coroutines.sync.Mutex
+import kotlinx.coroutines.sync.withLock
+
+// FIXME copied from agents-core module, because a2a does not depend on other Koog modules.
+// Do we want to make a global utils module for cases like this?
+/**
+ * A KMP read-write lock implementation that allows concurrent read access but ensures exclusive write access.
+ *
+ * This implementation uses [Mutex] to coordinate access for both readers and writers.
+ */
+@InternalA2AApi
+public class RWLock {
+ private val writeMutex = Mutex()
+ private var readersCount = 0
+ private val readersCountMutex = Mutex()
+
+ /**
+ * Run the given [block] of code while holding the read lock.
+ */
+ public suspend fun <T> withReadLock(block: suspend () -> T): T {
+ readersCountMutex.withLock {
+ if (++readersCount == 1) {
+ writeMutex.lock()
+ }
+ }
+
+ return try {
+ block()
+ } finally {
+ readersCountMutex.withLock {
+ if (--readersCount == 0) {
+ writeMutex.unlock()
+ }
+ }
+ }
+ }
+
+ /**
+ * Run the given [block] of code while holding the write lock.
+ */
+ public suspend fun <T> withWriteLock(block: suspend () -> T): T {
+ writeMutex.withLock {
+ return block()
+ }
+ }
+}
diff --git a/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/utils/ResultUtils.kt b/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/utils/ResultUtils.kt
new file mode 100644
index 0000000000..ff73526be5
--- /dev/null
+++ b/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/utils/ResultUtils.kt
@@ -0,0 +1,20 @@
+package ai.koog.a2a.utils
+
+import ai.koog.a2a.annotations.InternalA2AApi
+import kotlinx.coroutines.CancellationException
+
+// FIXME copied from agents-core module, because a2a does not depend on other Koog modules.
+/**
+ * Same as [runCatching], but does not catch [CancellationException], throwing it instead, making it safe to use with coroutines.
+ */ +@InternalA2AApi +public inline fun runCatchingCancellable(block: () -> R): Result { + return try { + Result.success(block()) + } catch (ce: CancellationException) { + throw ce + } catch (e: Exception) { + Result.failure(e) + } +} diff --git a/a2a/a2a-core/src/commonTest/kotlin/ai/koog/a2a/model/AgentCardSerializationTest.kt b/a2a/a2a-core/src/commonTest/kotlin/ai/koog/a2a/model/AgentCardSerializationTest.kt new file mode 100644 index 0000000000..a0e410b2cc --- /dev/null +++ b/a2a/a2a-core/src/commonTest/kotlin/ai/koog/a2a/model/AgentCardSerializationTest.kt @@ -0,0 +1,569 @@ +package ai.koog.a2a.model + +import kotlinx.serialization.json.Json +import kotlin.test.Test +import kotlin.test.assertEquals + +class AgentCardSerializationTest { + @Suppress("PrivatePropertyName") + private val TestJson = Json { + prettyPrint = true + encodeDefaults = false + } + + @Test + fun testMinimalAgentCardSerialization() { + val agentCard = AgentCard( + name = "Test Agent", + description = "A test agent", + url = "https://api.example.com/a2a", + version = "1.0.0", + capabilities = AgentCapabilities(), + defaultInputModes = listOf("text/plain"), + defaultOutputModes = listOf("text/plain"), + skills = listOf( + AgentSkill( + id = "test-skill", + name = "Test Skill", + description = "A test skill", + tags = listOf("test") + ) + ) + ) + + //language=JSON + val expectedJson = """ + { + "protocolVersion": "0.3.0", + "name": "Test Agent", + "description": "A test agent", + "url": "https://api.example.com/a2a", + "preferredTransport": "JSONRPC", + "version": "1.0.0", + "capabilities": {}, + "defaultInputModes": [ + "text/plain" + ], + "defaultOutputModes": [ + "text/plain" + ], + "skills": [ + { + "id": "test-skill", + "name": "Test Skill", + "description": "A test skill", + "tags": [ + "test" + ] + } + ] + } + """.trimIndent() + + val actualJson = TestJson.encodeToString(agentCard) + assertEquals(expectedJson, actualJson) + + // Test deserialization + val deserializedCard = 
TestJson.decodeFromString(actualJson) + assertEquals(agentCard, deserializedCard) + } + + @Test + fun testFullAgentCardSerialization() { + val agentCard = AgentCard( + name = "GeoSpatial Route Planner Agent", + description = "Provides advanced route planning, traffic analysis, and custom map generation services. This agent can calculate optimal routes, estimate travel times considering real-time traffic, and create personalized maps with points of interest.", + url = "https://georoute-agent.example.com/a2a/v1", + additionalInterfaces = listOf( + AgentInterface( + url = "https://georoute-agent.example.com/a2a/v1", + transport = TransportProtocol.JSONRPC + ), + AgentInterface( + url = "https://georoute-agent.example.com/a2a/grpc", + transport = TransportProtocol.GRPC + ), + AgentInterface( + url = "https://georoute-agent.example.com/a2a/json", + transport = TransportProtocol.HTTP_JSON_REST + ) + ), + provider = AgentProvider( + organization = "Example Geo Services Inc.", + url = "https://www.examplegeoservices.com" + ), + iconUrl = "https://georoute-agent.example.com/icon.png", + version = "1.2.0", + documentationUrl = "https://docs.examplegeoservices.com/georoute-agent/api", + capabilities = AgentCapabilities( + streaming = true, + pushNotifications = true, + stateTransitionHistory = false + ), + securitySchemes = mapOf( + "google" to OpenIdConnectSecurityScheme( + openIdConnectUrl = "https://accounts.google.com/.well-known/openid-configuration" + ) + ), + security = listOf( + mapOf("google" to listOf("openid", "profile", "email")) + ), + defaultInputModes = listOf("application/json", "text/plain"), + defaultOutputModes = listOf("application/json", "image/png"), + skills = listOf( + AgentSkill( + id = "route-optimizer-traffic", + name = "Traffic-Aware Route Optimizer", + description = "Calculates the optimal driving route between two or more locations, taking into account real-time traffic conditions, road closures, and user preferences (e.g., avoid tolls, prefer 
highways).", + tags = listOf("maps", "routing", "navigation", "directions", "traffic"), + examples = listOf( + "Plan a route from '1600 Amphitheatre Parkway, Mountain View, CA' to 'San Francisco International Airport' avoiding tolls.", + "{\"origin\": {\"lat\": 37.422, \"lng\": -122.084}, \"destination\": {\"lat\": 37.7749, \"lng\": -122.4194}, \"preferences\": [\"avoid_ferries\"]}" + ), + inputModes = listOf("application/json", "text/plain"), + outputModes = listOf( + "application/json", + "application/vnd.geo+json", + "text/html" + ) + ), + AgentSkill( + id = "custom-map-generator", + name = "Personalized Map Generator", + description = "Creates custom map images or interactive map views based on user-defined points of interest, routes, and style preferences. Can overlay data layers.", + tags = listOf("maps", "customization", "visualization", "cartography"), + examples = listOf( + "Generate a map of my upcoming road trip with all planned stops highlighted.", + "Show me a map visualizing all coffee shops within a 1-mile radius of my current location." + ), + inputModes = listOf("application/json"), + outputModes = listOf( + "image/png", + "image/jpeg", + "application/json", + "text/html" + ) + ) + ), + supportsAuthenticatedExtendedCard = true, + signatures = listOf( + AgentCardSignature( + `protected` = "eyJhbGciOiJFUzI1NiIsInR5cCI6IkpPU0UiLCJraWQiOiJrZXktMSIsImprdSI6Imh0dHBzOi8vZXhhbXBsZS5jb20vYWdlbnQvandrcy5qc29uIn0", + signature = "QFdkNLNszlGj3z3u0YQGt_T9LixY3qtdQpZmsTdDHDe3fXV9y9-B3m2-XgCpzuhiLt8E0tV6HXoZKHv4GtHgKQ" + ) + ) + ) + + //language=JSON + val expectedJson = """ + { + "protocolVersion": "0.3.0", + "name": "GeoSpatial Route Planner Agent", + "description": "Provides advanced route planning, traffic analysis, and custom map generation services. 
This agent can calculate optimal routes, estimate travel times considering real-time traffic, and create personalized maps with points of interest.", + "url": "https://georoute-agent.example.com/a2a/v1", + "preferredTransport": "JSONRPC", + "additionalInterfaces": [ + { + "url": "https://georoute-agent.example.com/a2a/v1", + "transport": "JSONRPC" + }, + { + "url": "https://georoute-agent.example.com/a2a/grpc", + "transport": "GRPC" + }, + { + "url": "https://georoute-agent.example.com/a2a/json", + "transport": "HTTP+JSON/REST" + } + ], + "iconUrl": "https://georoute-agent.example.com/icon.png", + "provider": { + "organization": "Example Geo Services Inc.", + "url": "https://www.examplegeoservices.com" + }, + "version": "1.2.0", + "documentationUrl": "https://docs.examplegeoservices.com/georoute-agent/api", + "capabilities": { + "streaming": true, + "pushNotifications": true, + "stateTransitionHistory": false + }, + "securitySchemes": { + "google": { + "openIdConnectUrl": "https://accounts.google.com/.well-known/openid-configuration", + "type": "openIdConnect" + } + }, + "security": [ + { + "google": [ + "openid", + "profile", + "email" + ] + } + ], + "defaultInputModes": [ + "application/json", + "text/plain" + ], + "defaultOutputModes": [ + "application/json", + "image/png" + ], + "skills": [ + { + "id": "route-optimizer-traffic", + "name": "Traffic-Aware Route Optimizer", + "description": "Calculates the optimal driving route between two or more locations, taking into account real-time traffic conditions, road closures, and user preferences (e.g., avoid tolls, prefer highways).", + "tags": [ + "maps", + "routing", + "navigation", + "directions", + "traffic" + ], + "examples": [ + "Plan a route from '1600 Amphitheatre Parkway, Mountain View, CA' to 'San Francisco International Airport' avoiding tolls.", + "{\"origin\": {\"lat\": 37.422, \"lng\": -122.084}, \"destination\": {\"lat\": 37.7749, \"lng\": -122.4194}, \"preferences\": [\"avoid_ferries\"]}" + ], + 
"inputModes": [ + "application/json", + "text/plain" + ], + "outputModes": [ + "application/json", + "application/vnd.geo+json", + "text/html" + ] + }, + { + "id": "custom-map-generator", + "name": "Personalized Map Generator", + "description": "Creates custom map images or interactive map views based on user-defined points of interest, routes, and style preferences. Can overlay data layers.", + "tags": [ + "maps", + "customization", + "visualization", + "cartography" + ], + "examples": [ + "Generate a map of my upcoming road trip with all planned stops highlighted.", + "Show me a map visualizing all coffee shops within a 1-mile radius of my current location." + ], + "inputModes": [ + "application/json" + ], + "outputModes": [ + "image/png", + "image/jpeg", + "application/json", + "text/html" + ] + } + ], + "supportsAuthenticatedExtendedCard": true, + "signatures": [ + { + "protected": "eyJhbGciOiJFUzI1NiIsInR5cCI6IkpPU0UiLCJraWQiOiJrZXktMSIsImprdSI6Imh0dHBzOi8vZXhhbXBsZS5jb20vYWdlbnQvandrcy5qc29uIn0", + "signature": "QFdkNLNszlGj3z3u0YQGt_T9LixY3qtdQpZmsTdDHDe3fXV9y9-B3m2-XgCpzuhiLt8E0tV6HXoZKHv4GtHgKQ" + } + ] + } + """.trimIndent() + + val actualJson = TestJson.encodeToString(agentCard) + + assertEquals(expectedJson, actualJson) + } + + @Test + fun testTransportProtocolSerialization() { + val jsonRpc = TransportProtocol.JSONRPC + val httpJson = TransportProtocol.HTTP_JSON_REST + val grpc = TransportProtocol.GRPC + val custom = TransportProtocol("CUSTOM") + + assertEquals("\"JSONRPC\"", TestJson.encodeToString(jsonRpc)) + assertEquals("\"HTTP+JSON/REST\"", TestJson.encodeToString(httpJson)) + assertEquals("\"GRPC\"", TestJson.encodeToString(grpc)) + assertEquals("\"CUSTOM\"", TestJson.encodeToString(custom)) + + // Test deserialization + assertEquals(jsonRpc, TestJson.decodeFromString("\"JSONRPC\"")) + assertEquals(httpJson, TestJson.decodeFromString("\"HTTP+JSON/REST\"")) + assertEquals(grpc, TestJson.decodeFromString("\"GRPC\"")) + assertEquals(custom, 
TestJson.decodeFromString("\"CUSTOM\"")) + } + + @Test + fun testSecuritySchemeSerialization() { + // API Key Security Scheme + val apiKeyScheme = APIKeySecurityScheme( + `in` = In.Header, + name = "Authorization", + description = "Bearer token" + ) + //language=JSON + val apiKeyJson = """ + { + "in": "header", + "name": "Authorization", + "description": "Bearer token", + "type": "apiKey" + } + """.trimIndent() + assertEquals(apiKeyJson, TestJson.encodeToString(apiKeyScheme)) + + // HTTP Auth Security Scheme + val httpScheme = HTTPAuthSecurityScheme( + scheme = "Bearer", + bearerFormat = "JWT", + description = "JWT Bearer token" + ) + //language=JSON + val httpJson = """ + { + "scheme": "Bearer", + "bearerFormat": "JWT", + "description": "JWT Bearer token", + "type": "http" + } + """.trimIndent() + assertEquals(httpJson, TestJson.encodeToString(httpScheme)) + + // OAuth2 Security Scheme + val oauth2Scheme = OAuth2SecurityScheme( + flows = OAuthFlows( + authorizationCode = AuthorizationCodeOAuthFlow( + authorizationUrl = "https://auth.example.com/oauth/authorize", + tokenUrl = "https://auth.example.com/oauth/token", + scopes = mapOf("read" to "Read access", "write" to "Write access") + ) + ), + description = "OAuth2 with authorization code flow" + ) + //language=JSON + val expectedOAuth2Json = """ + { + "flows": { + "authorizationCode": { + "authorizationUrl": "https://auth.example.com/oauth/authorize", + "tokenUrl": "https://auth.example.com/oauth/token", + "scopes": { + "read": "Read access", + "write": "Write access" + } + } + }, + "description": "OAuth2 with authorization code flow", + "type": "oauth2" + } + """.trimIndent() + + val oauth2Json = TestJson.encodeToString(oauth2Scheme) + assertEquals(expectedOAuth2Json, oauth2Json) + + // OpenID Connect Security Scheme + val oidcScheme = OpenIdConnectSecurityScheme( + openIdConnectUrl = "https://auth.example.com/.well-known/openid_configuration" + ) + //language=JSON + val oidcJson = """ + { + "openIdConnectUrl": 
"https://auth.example.com/.well-known/openid_configuration", + "type": "openIdConnect" + } + """.trimIndent() + assertEquals(oidcJson, TestJson.encodeToString(oidcScheme)) + + // Mutual TLS Security Scheme + val mtlsScheme = MutualTLSSecurityScheme( + description = "Client certificate authentication" + ) + //language=JSON + val mtlsJson = """ + { + "description": "Client certificate authentication", + "type": "mutualTLS" + } + """.trimIndent() + assertEquals(mtlsJson, TestJson.encodeToString(mtlsScheme)) + + // Test deserialization + assertEquals(apiKeyScheme, TestJson.decodeFromString(apiKeyJson)) + assertEquals(httpScheme, TestJson.decodeFromString(httpJson)) + assertEquals(oauth2Scheme, TestJson.decodeFromString(expectedOAuth2Json)) + assertEquals(oidcScheme, TestJson.decodeFromString(oidcJson)) + assertEquals(mtlsScheme, TestJson.decodeFromString(mtlsJson)) + } + + @Test + fun testOAuthFlowsSerialization() { + val flows = OAuthFlows( + authorizationCode = AuthorizationCodeOAuthFlow( + authorizationUrl = "https://auth.example.com/oauth/authorize", + tokenUrl = "https://auth.example.com/oauth/token", + scopes = mapOf("read" to "Read access"), + refreshUrl = "https://auth.example.com/oauth/refresh" + ), + clientCredentials = ClientCredentialsOAuthFlow( + tokenUrl = "https://auth.example.com/oauth/token", + scopes = mapOf("admin" to "Admin access") + ), + implicit = ImplicitOAuthFlow( + authorizationUrl = "https://auth.example.com/oauth/authorize", + scopes = mapOf("read" to "Read access") + ), + password = PasswordOAuthFlow( + tokenUrl = "https://auth.example.com/oauth/token", + scopes = mapOf("user" to "User access") + ) + ) + + val json = TestJson.encodeToString(flows) + val deserialized = TestJson.decodeFromString(json) + assertEquals(flows, deserialized) + } + + @Test + fun testAgentCapabilitiesSerialization() { + // Test default capabilities + val defaultCapabilities = AgentCapabilities() + //language=JSON + val defaultJson = """ + {} + """.trimIndent() + 
assertEquals(defaultJson, TestJson.encodeToString(defaultCapabilities)) + + // Test full capabilities + val fullCapabilities = AgentCapabilities( + streaming = true, + pushNotifications = true, + stateTransitionHistory = true, + extensions = listOf( + AgentExtension( + uri = "https://example.com/ext/v1", + description = "Test extension", + required = true + ) + ) + ) + //language=JSON + val expectedFullJson = """ + { + "streaming": true, + "pushNotifications": true, + "stateTransitionHistory": true, + "extensions": [ + { + "uri": "https://example.com/ext/v1", + "description": "Test extension", + "required": true + } + ] + } + """.trimIndent() + + val fullJson = TestJson.encodeToString(fullCapabilities) + assertEquals(expectedFullJson, fullJson) + + // Test deserialization + assertEquals(fullCapabilities, TestJson.decodeFromString(expectedFullJson)) + } + + @Test + fun testAgentSkillSerialization() { + val skill = AgentSkill( + id = "test-skill", + name = "Test Skill", + description = "A comprehensive test skill", + tags = listOf("test", "demo"), + examples = listOf("How to test?", "Show me a demo"), + inputModes = listOf("text/plain", "application/json"), + outputModes = listOf("text/plain"), + security = listOf( + mapOf("oauth" to listOf("read", "write")), + mapOf("api-key" to emptyList()) + ) + ) + + //language=JSON + val expectedJson = """ + { + "id": "test-skill", + "name": "Test Skill", + "description": "A comprehensive test skill", + "tags": [ + "test", + "demo" + ], + "examples": [ + "How to test?", + "Show me a demo" + ], + "inputModes": [ + "text/plain", + "application/json" + ], + "outputModes": [ + "text/plain" + ], + "security": [ + { + "oauth": [ + "read", + "write" + ] + }, + { + "api-key": [] + } + ] + } + """.trimIndent() + + val json = TestJson.encodeToString(skill) + assertEquals(expectedJson, json) + + val deserialized = TestJson.decodeFromString(json) + assertEquals(skill, deserialized) + } + + @Test + fun testEnumSerialization() { + 
assertEquals("\"cookie\"", TestJson.encodeToString(In.Cookie)) + assertEquals("\"header\"", TestJson.encodeToString(In.Header)) + assertEquals("\"query\"", TestJson.encodeToString(In.Query)) + + assertEquals(In.Cookie, TestJson.decodeFromString("\"cookie\"")) + assertEquals(In.Header, TestJson.decodeFromString("\"header\"")) + assertEquals(In.Query, TestJson.decodeFromString("\"query\"")) + } + + @Test + fun testAgentCardSignatureSerialization() { + val signature = AgentCardSignature( + `protected` = "eyJ0eXAiOiJKV1QiLCJhbGciOiJSUzI1NiJ9", + signature = "signature-data-here", + header = mapOf("kid" to kotlinx.serialization.json.JsonPrimitive("key-id-123")) + ) + + //language=JSON + val expectedJson = """ + { + "protected": "eyJ0eXAiOiJKV1QiLCJhbGciOiJSUzI1NiJ9", + "signature": "signature-data-here", + "header": { + "kid": "key-id-123" + } + } + """.trimIndent() + + val json = TestJson.encodeToString(signature) + assertEquals(expectedJson, json) + + val deserialized = TestJson.decodeFromString(json) + assertEquals(signature, deserialized) + } +} diff --git a/a2a/a2a-core/src/commonTest/kotlin/ai/koog/a2a/transport/TransportSerializationTest.kt b/a2a/a2a-core/src/commonTest/kotlin/ai/koog/a2a/transport/TransportSerializationTest.kt new file mode 100644 index 0000000000..903357f697 --- /dev/null +++ b/a2a/a2a-core/src/commonTest/kotlin/ai/koog/a2a/transport/TransportSerializationTest.kt @@ -0,0 +1,34 @@ +package ai.koog.a2a.transport + +import kotlinx.serialization.json.Json +import kotlin.test.Test +import kotlin.test.assertEquals + +class TransportSerializationTest { + @Suppress("PrivatePropertyName") + private val TestJson = Json + + @Test + fun testRequestIdStringId() { + val requestId = RequestId.StringId("test-id") + val requestIdJson = """"test-id"""" + + val serialized = TestJson.encodeToString(requestId) + assertEquals(requestIdJson, serialized) + + val deserialized = TestJson.decodeFromString(requestIdJson) + assertEquals(requestId, deserialized) + } + + 
@Test + fun testRequestIdNumberId() { + val requestId = RequestId.NumberId(123L) + val requestIdJson = """123""" + + val serialized = TestJson.encodeToString(requestId) + assertEquals(requestIdJson, serialized) + + val deserialized = TestJson.decodeFromString(requestIdJson) + assertEquals(requestId, deserialized) + } +} diff --git a/a2a/a2a-core/src/commonTest/kotlin/ai/koog/a2a/utils/KeyedMutexTest.kt b/a2a/a2a-core/src/commonTest/kotlin/ai/koog/a2a/utils/KeyedMutexTest.kt new file mode 100644 index 0000000000..0c23d6a8c6 --- /dev/null +++ b/a2a/a2a-core/src/commonTest/kotlin/ai/koog/a2a/utils/KeyedMutexTest.kt @@ -0,0 +1,416 @@ +package ai.koog.a2a.utils + +import kotlinx.coroutines.async +import kotlinx.coroutines.awaitAll +import kotlinx.coroutines.delay +import kotlinx.coroutines.sync.withLock +import kotlinx.coroutines.test.runTest +import kotlinx.coroutines.yield +import kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.assertFailsWith +import kotlin.test.assertFalse +import kotlin.test.assertTrue + +class KeyedMutexTest { + + @Test + fun basicLockUnlock() = runTest { + val mutex = KeyedMutex() + val key = "test-key" + + assertFalse(mutex.isLocked(key), "Key should not be locked initially") + + mutex.lock(key) + assertTrue(mutex.isLocked(key), "Key should be locked after lock()") + + mutex.unlock(key) + assertFalse(mutex.isLocked(key), "Key should be unlocked after unlock()") + } + + @Test + fun basicTryLock() = runTest { + val mutex = KeyedMutex() + val key = "test-key" + + assertTrue(mutex.tryLock(key), "tryLock should succeed on unlocked key") + assertTrue(mutex.isLocked(key), "Key should be locked after tryLock") + assertFalse(mutex.tryLock(key), "tryLock should fail on locked key") + + mutex.unlock(key) + assertFalse(mutex.isLocked(key), "Key should be unlocked after unlock") + } + + @Test + fun withLockConvenienceFunction() = runTest { + val mutex = KeyedMutex() + val key = "test-key" + var executed = false + + val result = 
mutex.withLock(key) { + assertTrue(mutex.isLocked(key), "Key should be locked inside withLock block") + executed = true + "result" + } + + assertEquals("result", result, "withLock should return block result") + assertTrue(executed, "Block should have been executed") + assertFalse(mutex.isLocked(key), "Key should be unlocked after withLock") + } + + @Test + fun differentKeysCanBeLocked() = runTest { + val mutex = KeyedMutex() + val key1 = "key1" + val key2 = "key2" + + mutex.lock(key1) + mutex.lock(key2) + + assertTrue(mutex.isLocked(key1), "Key1 should be locked") + assertTrue(mutex.isLocked(key2), "Key2 should be locked") + + mutex.unlock(key1) + assertFalse(mutex.isLocked(key1), "Key1 should be unlocked") + assertTrue(mutex.isLocked(key2), "Key2 should still be locked") + + mutex.unlock(key2) + assertFalse(mutex.isLocked(key2), "Key2 should be unlocked") + } + + @Test + fun concurrentAccessSameKey() = runTest { + val mutex = KeyedMutex() + val key = "test-key" + val results = mutableListOf() + val accessOrder = mutableListOf() + + val jobs = (1..3).map { i -> + this.async { + mutex.withLock(key) { + accessOrder.add(i) + delay(10) // Simulate some work + results.add(i) + } + } + } + + jobs.awaitAll() + + assertEquals(3, results.size, "All coroutines should complete") + assertEquals(accessOrder, results, "Access order should match completion order due to mutex") + assertFalse(mutex.isLocked(key), "Key should be unlocked after all operations") + } + + @Test + fun parallelAccessDifferentKeys() = runTest { + val mutex = KeyedMutex() + val results = mutableListOf() + val completionOrder = mutableListOf() + + // Use different keys with different work durations to verify parallelism + val keyWorkMap = mapOf( + "fast" to 10L, + "medium" to 15L, + "slow" to 25L + ) + + val jobs = keyWorkMap.map { (key, workDuration) -> + this.async { + mutex.withLock(key) { + results.add("$key-started") + delay(workDuration) + results.add("$key-finished") + completionOrder.add(key) + } + } 
+ } + + jobs.awaitAll() + + assertEquals(6, results.size, "All operations should complete") + assertEquals(3, completionOrder.size, "All keys should complete") + + // If running in parallel, fast operations should finish before slow ones + // even if they started later. Check that "fast" completes before "slow" + val fastIndex = completionOrder.indexOf("fast") + val slowIndex = completionOrder.indexOf("slow") + assertTrue( + fastIndex < slowIndex, + "Fast operation should complete before slow operation if running in parallel. " + + "Completion order: $completionOrder" + ) + } + + @Test + fun unlockWithoutLockThrowsException() = runTest { + val mutex = KeyedMutex() + val key = "test-key" + + assertFailsWith("Should throw when unlocking non-existent key") { + mutex.unlock(key) + } + } + + @Test + fun unlockAfterAlreadyUnlockedThrowsException() = runTest { + val mutex = KeyedMutex() + val key = "test-key" + + mutex.lock(key) + mutex.unlock(key) + + assertFailsWith("Should throw when unlocking already unlocked key") { + mutex.unlock(key) + } + } + + @Test + fun ownerParameterEnforcement() = runTest { + val mutex = KeyedMutex() + val key = "test-key" + val owner1 = "owner1" + val owner2 = "owner2" + + mutex.lock(key, owner1) + + assertFailsWith("Should throw when unlocking with wrong owner") { + mutex.unlock(key, owner2) + } + + // Should succeed with correct owner + mutex.unlock(key, owner1) + assertFalse(mutex.isLocked(key), "Key should be unlocked") + } + + @Test + fun tryLockWithOwnerParameterEnforcement() = runTest { + val mutex = KeyedMutex() + val key = "test-key" + val owner1 = "owner1" + val owner2 = "owner2" + + assertTrue(mutex.tryLock(key, owner1), "First tryLock should succeed") + assertFalse(mutex.tryLock(key, owner2), "Second tryLock with different owner should fail") + + mutex.unlock(key, owner1) + assertTrue(mutex.tryLock(key, owner2), "tryLock should succeed after unlock") + mutex.unlock(key, owner2) + } + + @Test + fun 
exceptionDuringLockRollsBackRefCount() = runTest { + val mutex = KeyedMutex() + val key = "test-key" + + // First, acquire the lock to create an entry + mutex.lock(key) + assertTrue(mutex.isLocked(key), "Key should be locked") + + // Start another coroutine that will try to lock and fail + val job = this.async { + try { + // This should wait since the key is already locked + mutex.lock(key, "owner") + // If we somehow get here, unlock to clean up + mutex.unlock(key, "owner") + } catch (e: Exception) { + // Expected to be cancelled + } + } + + yield() // Let the other coroutine start waiting + + // Cancel the waiting coroutine to simulate an exception during lock + job.cancel() + + // Unlock the original lock + mutex.unlock(key) + + // The entry should be cleaned up properly despite the cancelled waiter + assertFalse(mutex.isLocked(key), "Key should be unlocked and cleaned up") + + // Should be able to lock again normally + assertTrue(mutex.tryLock(key), "Should be able to lock after cleanup") + mutex.unlock(key) + } + + @Test + fun memoryLeakPrevention() = runTest { + val mutex = KeyedMutex() + val keys = (1..100).map { "key$it" } + + // Lock and unlock many keys + keys.forEach { key -> + mutex.withLock(key) { + // Do nothing, just acquire and release + } + } + + // All keys should be unlocked and entries cleaned up + keys.forEach { key -> + assertFalse(mutex.isLocked(key), "Key $key should be unlocked") + } + + // Test with concurrent access to same key + repeat(50) { + val key = "shared-key" + val jobs = (1..10).map { + async { + mutex.withLock(key) { + delay(1) + } + } + } + jobs.awaitAll() + assertFalse(mutex.isLocked(key), "Shared key should be unlocked after round $it") + } + } + + @Test + fun highConcurrencyStressTest() = runTest { + val mutex = KeyedMutex() + val sharedCounter = mutableMapOf() + val keys = listOf("key1", "key2", "key3") + + // Initialize counters + keys.forEach { key -> sharedCounter[key] = 0 } + + val jobs = (1..300).map { i -> + 
this.async { + val key = keys[i % keys.size] + mutex.withLock(key) { + val current = sharedCounter[key]!! + yield() // Force context switch to test race conditions + sharedCounter[key] = current + 1 + } + } + } + + jobs.awaitAll() + + // Verify that each counter was incremented exactly the right number of times + keys.forEach { key -> + val expected = 300 / keys.size + assertEquals(expected, sharedCounter[key], "Counter for $key should be exactly $expected") + assertFalse(mutex.isLocked(key), "Key $key should be unlocked") + } + } + + @Test + fun isLockedObservationOnly() = runTest { + val mutex = KeyedMutex() + val key = "test-key" + + // isLocked should work for non-existent keys + assertFalse(mutex.isLocked(key), "Non-existent key should not be locked") + + mutex.lock(key) + assertTrue(mutex.isLocked(key), "Locked key should return true") + + // isLocked should be usable from concurrent contexts + val job = this.async { + mutex.isLocked(key) + } + + assertTrue(job.await(), "isLocked should work from concurrent context") + mutex.unlock(key) + } + + @Test + fun holdsLockOwnershipCheck() = runTest { + val mutex = KeyedMutex() + val key = "test-key" + val owner1 = "owner1" + val owner2 = "owner2" + + // holdsLock should return false for non-existent keys + assertFalse(mutex.holdsLock(key, owner1), "Non-existent key should not be held by any owner") + + // Lock with owner1 + mutex.lock(key, owner1) + assertTrue(mutex.holdsLock(key, owner1), "owner1 should hold the lock") + assertFalse(mutex.holdsLock(key, owner2), "owner2 should not hold the lock") + + // After unlock, no one should hold the lock + mutex.unlock(key, owner1) + assertFalse(mutex.holdsLock(key, owner1), "owner1 should not hold the lock after unlock") + assertFalse(mutex.holdsLock(key, owner2), "owner2 should not hold the lock after unlock") + } + + @Test + fun holdsLockWithTryLock() = runTest { + val mutex = KeyedMutex() + val key = "test-key" + val owner1 = "owner1" + val owner2 = "owner2" + + // 
Acquire lock with tryLock + assertTrue(mutex.tryLock(key, owner1), "tryLock should succeed") + assertTrue(mutex.holdsLock(key, owner1), "owner1 should hold the lock after tryLock") + assertFalse(mutex.holdsLock(key, owner2), "owner2 should not hold the lock") + + // Second tryLock with different owner should fail + assertFalse(mutex.tryLock(key, owner2), "tryLock with different owner should fail") + assertTrue(mutex.holdsLock(key, owner1), "owner1 should still hold the lock") + assertFalse(mutex.holdsLock(key, owner2), "owner2 should still not hold the lock") + + mutex.unlock(key, owner1) + assertFalse(mutex.holdsLock(key, owner1), "owner1 should not hold lock after unlock") + } + + @Test + fun holdsLockConcurrentAccess() = runTest { + val mutex = KeyedMutex() + val key = "test-key" + val owner = "test-owner" + + mutex.lock(key, owner) + + // holdsLock should be usable from concurrent contexts + val job = this.async { + mutex.holdsLock(key, owner) + } + + assertTrue(job.await(), "holdsLock should work from concurrent context") + mutex.unlock(key, owner) + } + + @Test + fun reentrantBehaviorShouldDeadlock() = runTest { + val mutex = KeyedMutex() + val key = "test-key" + var innerReached = false + + // This test verifies that the mutex is NOT re-entrant + // According to the documentation, it's "Not re-entrant for the same key within the same coroutine" + + val job = this.async { + mutex.withLock(key) { + try { + // This should deadlock since we're trying to lock the same key again + // in the same coroutine. We use a timeout to detect this. 
+ mutex.withLock(key) { + innerReached = true + } + } catch (e: Exception) { + // May throw if timeout or cancellation occurs + } + } + } + + // Give it a short time to potentially complete + delay(100) + + // The job should still be running (deadlocked) + assertFalse(job.isCompleted, "Re-entrant lock should deadlock") + assertFalse(innerReached, "Inner block should not be reached") + + // Clean up by cancelling + job.cancel() + + // Verify the key gets cleaned up properly after cancellation + delay(10) + assertFalse(mutex.isLocked(key), "Key should be cleaned up after cancellation") + } +} diff --git a/a2a/a2a-server/Module.md b/a2a/a2a-server/Module.md new file mode 100644 index 0000000000..8c7536be5b --- /dev/null +++ b/a2a/a2a-server/Module.md @@ -0,0 +1,4 @@ +# Module a2a-server + +Server-side implementation for hosting A2A protocol agents. Provides the A2AServer class, AgentExecutor interface for implementing agent logic, +session management, and storage abstractions for tasks, sessions, and push subscriptions. 
diff --git a/a2a/a2a-server/build.gradle.kts b/a2a/a2a-server/build.gradle.kts new file mode 100644 index 0000000000..5479619ece --- /dev/null +++ b/a2a/a2a-server/build.gradle.kts @@ -0,0 +1,59 @@ +import ai.koog.gradle.publish.maven.Publishing.publishToMaven + +group = rootProject.group +version = rootProject.version + +plugins { + id("ai.kotlin.multiplatform") + alias(libs.plugins.kotlin.serialization) +} + +kotlin { + sourceSets { + commonMain { + dependencies { + api(project(":a2a:a2a-core")) + api(libs.kotlinx.serialization.json) + api(libs.kotlinx.coroutines.core) + api(libs.ktor.client.core) + api(libs.ktor.client.content.negotiation) + api(libs.ktor.serialization.kotlinx.json) + implementation(libs.oshai.kotlin.logging) + } + } + + commonTest { + dependencies { + implementation(project(":a2a:a2a-test")) + implementation(kotlin("test")) + implementation(libs.kotlinx.coroutines.test) + implementation(libs.kotest.assertions) + } + } + + jvmTest { + dependencies { + implementation(kotlin("test-junit5")) + implementation(project(":a2a:a2a-test")) + implementation(project(":a2a:a2a-client")) + implementation(project(":a2a:a2a-transport:a2a-transport-client-jsonrpc-http")) + implementation(project(":a2a:a2a-transport:a2a-transport-server-jsonrpc-http")) + + implementation(libs.ktor.client.cio) + implementation(libs.ktor.client.logging) + implementation(libs.ktor.server.netty) + runtimeOnly(libs.logback.classic) + } + } + + jsTest { + dependencies { + implementation(kotlin("test-js")) + } + } + } + + explicitApi() +} + +publishToMaven() diff --git a/a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/A2AServer.kt b/a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/A2AServer.kt new file mode 100644 index 0000000000..75bc2d49e9 --- /dev/null +++ b/a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/A2AServer.kt @@ -0,0 +1,680 @@ +package ai.koog.a2a.server + +import ai.koog.a2a.exceptions.A2AAuthenticatedExtendedCardNotConfiguredException +import 
ai.koog.a2a.exceptions.A2AInvalidParamsException +import ai.koog.a2a.exceptions.A2APushNotificationNotSupportedException +import ai.koog.a2a.exceptions.A2ATaskNotCancelableException +import ai.koog.a2a.exceptions.A2ATaskNotFoundException +import ai.koog.a2a.exceptions.A2AUnsupportedOperationException +import ai.koog.a2a.model.AgentCard +import ai.koog.a2a.model.CommunicationEvent +import ai.koog.a2a.model.Event +import ai.koog.a2a.model.Message +import ai.koog.a2a.model.MessageSendParams +import ai.koog.a2a.model.Task +import ai.koog.a2a.model.TaskEvent +import ai.koog.a2a.model.TaskIdParams +import ai.koog.a2a.model.TaskPushNotificationConfig +import ai.koog.a2a.model.TaskPushNotificationConfigParams +import ai.koog.a2a.model.TaskQueryParams +import ai.koog.a2a.model.TaskState +import ai.koog.a2a.server.agent.AgentExecutor +import ai.koog.a2a.server.messages.ContextMessageStorage +import ai.koog.a2a.server.messages.InMemoryMessageStorage +import ai.koog.a2a.server.messages.MessageStorage +import ai.koog.a2a.server.notifications.PushNotificationConfigStorage +import ai.koog.a2a.server.notifications.PushNotificationSender +import ai.koog.a2a.server.session.IdGenerator +import ai.koog.a2a.server.session.LazySession +import ai.koog.a2a.server.session.RequestContext +import ai.koog.a2a.server.session.SessionEventProcessor +import ai.koog.a2a.server.session.SessionManager +import ai.koog.a2a.server.session.UuidIdGenerator +import ai.koog.a2a.server.tasks.ContextTaskStorage +import ai.koog.a2a.server.tasks.InMemoryTaskStorage +import ai.koog.a2a.server.tasks.TaskStorage +import ai.koog.a2a.transport.Request +import ai.koog.a2a.transport.RequestHandler +import ai.koog.a2a.transport.Response +import ai.koog.a2a.transport.ServerCallContext +import ai.koog.a2a.utils.KeyedMutex +import ai.koog.a2a.utils.withLock +import io.github.oshai.kotlinlogging.KotlinLogging +import kotlinx.coroutines.CancellationException +import kotlinx.coroutines.CompletableJob +import 
kotlinx.coroutines.CoroutineScope +import kotlinx.coroutines.Job +import kotlinx.coroutines.SupervisorJob +import kotlinx.coroutines.cancel +import kotlinx.coroutines.flow.Flow +import kotlinx.coroutines.flow.channelFlow +import kotlinx.coroutines.flow.firstOrNull +import kotlinx.coroutines.flow.flow +import kotlinx.coroutines.flow.lastOrNull +import kotlinx.coroutines.flow.map +import kotlinx.coroutines.flow.onStart +import kotlinx.coroutines.launch + +/** + * Default implementation of A2A server responsible for handling requests from A2A clients according to the + * [A2A protocol specification](https://a2a-protocol.org/latest/specification/). + * + * This class provides a complete implementation of all A2A protocol methods including message sending, task management, + * and push notifications. However, it **does not** provide any authorization, authentication, or custom validation + * logic. For production use, you should extend this class and add your own security and business logic. + * + * The A2AServer orchestrates the interaction between transport layer, agent executor, and storage components: + * - Receives requests from [RequestHandler] interface methods + * - Delegates actual agent logic to [AgentExecutor] + * - Delegates event processing and persisting to [SessionEventProcessor] + * - Delegates session management to [SessionManager] + * - Handles push notifications using [PushNotificationSender] + * + * ## Production usage with authorization + * + * For production deployments, extend this class to add authorization and custom validation. 
You can leverage + * [ServerCallContext.state] to pass user-defined data through the request pipeline to the [AgentExecutor]: + * + * ```kotlin + * // Define your user data and state keys + * data class AuthenticatedUser(val id: String, val permissions: Set) + * + * object AuthStateKeys { + * val USER = StateKey("authenticated_user") + * } + * + * // Extend A2AServer with authorization + * class AuthorizedA2AServer( + * agentExecutor: AgentExecutor, + * agentCard: AgentCard, + * agentCardExtended: AgentCard? = null, + * taskStorage: TaskStorage = InMemoryTaskStorage(), + * messageStorage: MessageStorage = InMemoryMessageStorage(), + * private val authService: AuthService, // Your auth service + * ) : A2AServer( + * agentExecutor = agentExecutor, + * agentCard = agentCard, + * agentCardExtended = agentCardExtended, + * taskStorage = taskStorage, + * messageStorage = messageStorage, + * ) { + * // Helper method for common auth pattern + * private suspend fun authenticateAndAuthorize( + * ctx: ServerCallContext, + * requiredPermission: String + * ): AuthenticatedUser { + * val token = ctx.headers["Authorization"]?.firstOrNull() + * ?: throw A2AInvalidParamsException("Missing authorization token") + * + * val user = authService.authenticate(token) + * ?: throw A2AInvalidParamsException("Invalid token") + * + * if (!user.permissions.contains(requiredPermission)) { + * throw A2AUnsupportedOperationException("Insufficient permissions") + * } + * + * return user + * } + * + * override suspend fun onSendMessage( + * request: Request, + * ctx: ServerCallContext + * ): Response { + * val user = authenticateAndAuthorize(ctx, requiredPermission = "send_message") + * + * // Pass user data to the agent executor via context state + * val enrichedCtx = ctx.copy( + * state = ctx.state + (AuthStateKeys.USER to user) + * ) + * + * // Delegate to parent implementation with enriched context + * return super.onSendMessage(request, enrichedCtx) + * } + * + * override suspend fun 
onGetTask( + * request: Request, + * ctx: ServerCallContext + * ): Response { + * val user = authenticateAndAuthorize(ctx, requiredPermission = "read_task") + * + * // Optionally validate task ownership + * val task = taskStorage.get(request.data.id, historyLength = 0, includeArtifacts = false) + * if (task?.metadata?.get("owner_id") != user.id) { + * throw A2AUnsupportedOperationException("Access denied to task ${request.data.id}") + * } + * + * val enrichedCtx = ctx.copy( + * state = ctx.state + (AuthStateKeys.USER to user) + * ) + * + * return super.onGetTask(request, enrichedCtx) + * } + * } + * ``` + * + * ## Accessing user data in AgentExecutor + * + * The authenticated user data passed through [ServerCallContext.state] can be accessed in your [AgentExecutor]: + * + * ```kotlin + * class MyAgentExecutor : AgentExecutor { + * override suspend fun execute( + * context: RequestContext, + * eventProcessor: SessionEventProcessor + * ) { + * // Retrieve authenticated user from the context + * val user = context.callContext.getFromState(AuthStateKeys.USER) + * + * // Use user information for personalized agent behavior + * eventProcessor.sendMessage( + * Message( + * role = Role.Agent, + * contextId = context.contextId, + * parts = listOf( + * TextPart("Hello ${user.id}, how can I help you today?") + * ) + * ) + * ) + * } + * + * override suspend fun cancel( + * context: RequestContext, + * eventProcessor: SessionEventProcessor, + * agentJob: Deferred?, + * ) { + * agentJob?.cancelAndJoin() + * // Access user data for audit logging + * val user = context.callContext.getFromStateOrNull(AuthStateKeys.USER) + * log.info("Task ${context.taskId} canceled by user ${user?.id}") + * } + * } + * ``` + * + * ## Complete server setup example + * + * Here's a complete example of setting up and running an A2A server from scratch: + * + * ```kotlin + * // 1. 
Create your agent executor with business logic + * val agentExecutor = object : AgentExecutor { + * override suspend fun execute( + * context: RequestContext, + * eventProcessor: SessionEventProcessor + * ) { + * val userMessage = context.params.message + * + * // Process the message and create a task + * // Send task creation event + * eventProcessor.sendTaskEvent( + * Task( + * id = context.taskId, + * contextId = context.contextId, + * status = TaskStatus( + * state = TaskState.Working, + * // Mark this message as belonging to the created task + * message = userMessage.copy(taskId = context.taskId), + * timestamp = Clock.System.now() + * ), + * ) + * ) + * + * // Simulate some work + * delay(1000) + * + * // Mark task as completed + * eventProcessor.sendTaskEvent( + * TaskStatusUpdateEvent( + * taskId = context.taskId, + * contextId = context.contextId, + * status = TaskStatus( + * state = TaskState.Completed, + * message = Message( + * role = Role.Agent, + * contextId = context.contextId, + * taskId = context.taskId, + * parts = listOf( + * TextPart("Task completed successfully!") + * ) + * ), + * timestamp = Clock.System.now() + * ), + * final = true + * ) + * ) + * } + * } + * + * // 2. Define your agent card describing capabilities + * val agentCard = AgentCard(...) + * + * // 3. Create the A2AServer instance (or your extended version) + * val a2aServer = A2AServer(...) + * + * // 4. Create HTTP JSON-RPC transport + * val transport = HttpJSONRPCServerTransport( + * requestHandler = a2aServer + * ) + * + * // 5. Start the server + * transport.start( + * engineFactory = Netty, + * port = 8080, + * path = "/a2a", + * wait = true, + * agentCard = agentCard, + * agentCardPath = A2AConsts.AGENT_CARD_WELL_KNOWN_PATH + * ) + * ``` + * + * ## Integration with existing Ktor application + * + * If you have an existing Ktor application, you can integrate the A2A server as a route: + * + * ```kotlin + * val agentCard = AgentCard(...) + * val a2aServer = A2AServer(...) 
+ * val transport = HttpJSONRPCServerTransport(a2aServer) + * + * embeddedServer(Netty, port = 8080) { + * install(SSE) // Required for streaming support + * + * // To serve AgentCard instance + * install(ContentNegotiation) { + * json(Json) + * } + * + * routing { + * // Your existing routes... + * + * // Mount A2A JSON-RPC server transport + * transport.transportRoutes(this, "/a2a") + * + * // Serve agent card + * get("/a2a/agent-card.json") { + * call.respond(agentCard) + * } + * } + * }.start(wait = true) + * ``` + * + * @param agentExecutor The executor containing the core agent logic + * @param agentCard The agent card describing this agent's capabilities and metadata + * @param agentCardExtended Optional extended agent card for authenticated requests + * @param taskStorage Storage implementation for persisting tasks (defaults to [InMemoryTaskStorage]) + * @param messageStorage Storage implementation for persisting messages (defaults to [InMemoryMessageStorage]) + * @param pushConfigStorage Optional storage for push notification configurations (defaults to `null`) + * @param pushSender Optional push notification sender implementation (defaults to `null`) + * @param idGenerator Generator for new task and context IDs (defaults to [UuidIdGenerator]) + * @param coroutineScope Scope for managing all sessions, agent jobs, event processing, etc. + * + * @see AgentExecutor for implementing agent business logic + * @see TaskStorage for persisting tasks + * @see MessageStorage for persisting messages + * @see PushNotificationConfigStorage for persisting push notification configurations + * @see PushNotificationSender for sending push notifications + * @see ServerCallContext for passing custom state through the request pipeline + */ +public open class A2AServer( + protected val agentExecutor: AgentExecutor, + protected val agentCard: AgentCard, + protected val agentCardExtended: AgentCard? 
= null, + protected val taskStorage: TaskStorage = InMemoryTaskStorage(), + protected val messageStorage: MessageStorage = InMemoryMessageStorage(), + protected val pushConfigStorage: PushNotificationConfigStorage? = null, + protected val pushSender: PushNotificationSender? = null, + protected val idGenerator: IdGenerator = UuidIdGenerator, + protected val coroutineScope: CoroutineScope = CoroutineScope(SupervisorJob()), +) : RequestHandler { + private companion object { + private val logger = KotlinLogging.logger {} + } + + /** + * Mutex for locking specific tasks by their IDs. + */ + protected val tasksMutex: KeyedMutex = KeyedMutex() + + /** + * Special cancellation key for additional set of task cancellation locks. + */ + protected fun cancelKey(taskId: String): String = "cancel:$taskId" + + protected open val sessionManager: SessionManager = SessionManager( + coroutineScope = coroutineScope, + cancelKey = ::cancelKey, + tasksMutex = tasksMutex, + taskStorage = taskStorage, + pushConfigStorage = pushConfigStorage, + pushSender = pushSender, + ) + + override suspend fun onGetAuthenticatedExtendedAgentCard( + request: Request, + ctx: ServerCallContext + ): Response { + if (agentCard.supportsAuthenticatedExtendedCard != true) { + throw A2AAuthenticatedExtendedCardNotConfiguredException("Extended agent card is not supported") + } + + // Default server implementation does not provide authorization, return extended card directly if present + return Response( + data = agentCardExtended + ?: throw A2AAuthenticatedExtendedCardNotConfiguredException("Extended agent card is supported but not configured on the server"), + id = request.id + ) + } + + /** + * Common logic for handling incoming messages and starting the agent execution. + * Does all the setup and validation, creates event stream. 
+ * + * @return A stream of events from the agent + */ + protected open fun onSendMessageCommon( + request: Request, + ctx: ServerCallContext + ): Flow> = channelFlow { + val message = request.data.message + + if (message.parts.isEmpty()) { + throw A2AInvalidParamsException("Empty message parts are not supported") + } + + val taskId = message.taskId ?: idGenerator.generateTaskId(message) + + val (session, monitoringStarted) = tasksMutex.withLock(taskId) { + // If there's a currently running session for the same task, wait for it to finish. + sessionManager.getSession(taskId)?.join() + + // Check if message links to a task. + val task: Task? = message.taskId?.let { taskId -> + // Check if the specified task exists + val task = taskStorage.get(taskId, historyLength = 0, includeArtifacts = false) + ?: throw A2ATaskNotFoundException("Task '$taskId' not found") + + task + } + + // Create event processor for the session based on the input data. + val eventProcessor = SessionEventProcessor( + contextId = task?.contextId + ?: message.contextId + ?: idGenerator.generateContextId(message), + taskId = taskId, + taskStorage = taskStorage, + ) + + // Create request context based on the request information. 
+ val requestContext = RequestContext( + callContext = ctx, + params = request.data, + taskStorage = ContextTaskStorage(eventProcessor.contextId, taskStorage), + messageStorage = ContextMessageStorage(eventProcessor.contextId, messageStorage), + contextId = eventProcessor.contextId, + taskId = eventProcessor.taskId, + task = task, + ) + + LazySession( + coroutineScope = coroutineScope, + eventProcessor = eventProcessor, + ) { + agentExecutor.execute(requestContext, eventProcessor) + }.let { + it to sessionManager.addSession(it) + } + } + + // Signal that event collection is started + val eventCollectinStarted: CompletableJob = Job() + // Signal that all events have been collected + val eventCollectionFinished: CompletableJob = Job() + + // Subscribe to events stream and start emitting them. + launch { + session.events + .onStart { + eventCollectinStarted.complete() + } + .collect { event -> + send(Response(data = event, id = request.id)) + } + + eventCollectionFinished.complete() + } + + // Ensure event collection is setup to stream events in response. + eventCollectinStarted.join() + // Ensure monitoring is ready to monitor the session. + monitoringStarted.join() + + /* + Start the session to execute the agent and wait for it to finish. + Using await here to propagate any exceptions thrown by the agent execution. + */ + session.agentJob.await() + // Make sure all events have been collected and sent + eventCollectionFinished.join() + } + + override suspend fun onSendMessage( + request: Request, + ctx: ServerCallContext + ): Response { + val messageConfiguration = request.data.configuration + // Reusing streaming logic here, because it's essentially the same, only we need some particular event from the stream + val eventStream = onSendMessageCommon(request, ctx) + + val event = if (messageConfiguration?.blocking == true) { + // If blocking is requested, attempt to wait for the last event, until the current turn of the agent execution is finished. 
+ eventStream.lastOrNull() + } else { + eventStream.firstOrNull() + } ?: throw IllegalStateException("Can't get response from the agent: event stream is empty") + + return when (val eventData = event.data) { + is Message -> Response(data = eventData, id = event.id) + is TaskEvent -> + taskStorage + .get( + eventData.taskId, + historyLength = messageConfiguration?.historyLength, + includeArtifacts = true + ) + ?.let { Response(data = it, id = event.id) } + ?: throw A2ATaskNotFoundException("Task '${eventData.taskId}' not found after the agent execution") + } + } + + override fun onSendMessageStreaming( + request: Request, + ctx: ServerCallContext + ): Flow> = flow { + checkStreamingSupport() + onSendMessageCommon(request, ctx).collect(this) + } + + override suspend fun onGetTask( + request: Request, + ctx: ServerCallContext + ): Response { + val taskParams = request.data + + return Response( + data = taskStorage.get(taskParams.id, historyLength = taskParams.historyLength, includeArtifacts = false) + ?: throw A2ATaskNotFoundException("Task '${taskParams.id}' not found"), + id = request.id, + ) + } + + override suspend fun onCancelTask( + request: Request, + ctx: ServerCallContext + ): Response { + val taskParams = request.data + val taskId = taskParams.id + + /* + Cancellation uses two lock levels. The first is the standard task lock. + If it’s already held by another request, ignore it because cancellation takes priority. + If it’s not held, acquire it to block new requests while the cancellation is in progress. + */ + val lockAcquired = tasksMutex.tryLock(taskId) + + return try { + /* + The second lock is a per-task cancellation lock. + It’s always taken during cancellation to serialize cancel operations and allow them to proceed even if the + regular task lock is held. It prevents overlapping cancels and delays session teardown so the event processor + isn’t closed immediately after the agent job is canceled. 
This allows the cancel handler to emit additional + cancellation events through the same processor and session, ensuring that existing subscribers receive all events. + */ + tasksMutex.withLock(cancelKey(taskId)) { + val session = sessionManager.getSession(taskParams.id) + + val task = taskStorage.get(taskParams.id, historyLength = 0, includeArtifacts = true) + ?: throw A2ATaskNotFoundException("Task '${taskParams.id}' not found") + + // Task is not running, check if it's already in a terminal state. + if (session == null && task.status.state.terminal) { + throw A2ATaskNotCancelableException("Task '${taskParams.id}' is already in terminal state ${task.status.state}") + } + + val eventProcessor = session?.eventProcessor ?: SessionEventProcessor( + contextId = task.contextId, + taskId = task.id, + taskStorage = taskStorage, + ) + + // Create request context based on the request information. + val requestContext = RequestContext( + callContext = ctx, + params = request.data, + taskStorage = ContextTaskStorage(eventProcessor.contextId, taskStorage), + messageStorage = ContextMessageStorage(eventProcessor.contextId, messageStorage), + contextId = eventProcessor.contextId, + taskId = eventProcessor.taskId, + task = task, + ) + + // Attempt to cancel the agent execution and wait until it's finished. + agentExecutor.cancel(requestContext, eventProcessor, session?.agentJob) + + // Return the final task state. 
+ Response( + data = taskStorage.get(taskParams.id, historyLength = 0, includeArtifacts = true) + ?.also { + if (it.status.state != TaskState.Canceled) { + throw A2ATaskNotCancelableException("Task '${taskParams.id}' was not canceled successfully, current state is ${it.status.state}") + } + } + ?: throw A2ATaskNotFoundException("Task '${taskParams.id}' not found"), + id = request.id, + ) + } + } finally { + if (lockAcquired) { + tasksMutex.unlock(taskId) + } + } + } + + override fun onResubscribeTask( + request: Request, + ctx: ServerCallContext + ): Flow> = flow { + checkStreamingSupport() + + val taskParams = request.data + val session = sessionManager.getSession(taskParams.id) ?: return@flow + + session.events + .map { event -> Response(data = event, id = request.id) } + .collect(this) + } + + override suspend fun onSetTaskPushNotificationConfig( + request: Request, + ctx: ServerCallContext + ): Response { + val pushStorage = storageIfPushNotificationSupported() + val taskPushConfig = request.data + + pushStorage.save(taskPushConfig.taskId, taskPushConfig.pushNotificationConfig) + + return Response(data = taskPushConfig, id = request.id) + } + + override suspend fun onGetTaskPushNotificationConfig( + request: Request, + ctx: ServerCallContext + ): Response { + val pushStorage = storageIfPushNotificationSupported() + val pushConfigParams = request.data + + val pushConfig = pushStorage.get(pushConfigParams.id, pushConfigParams.pushNotificationConfigId) + ?: throw NoSuchElementException("Can't find push notification config with id '${pushConfigParams.pushNotificationConfigId}' for task '${pushConfigParams.id}'") + + return Response( + data = TaskPushNotificationConfig( + taskId = pushConfigParams.id, + pushNotificationConfig = pushConfig + ), + id = request.id + ) + } + + override suspend fun onListTaskPushNotificationConfig( + request: Request, + ctx: ServerCallContext + ): Response> { + val pushStorage = storageIfPushNotificationSupported() + val taskParams = 
request.data + + return Response( + data = pushStorage + .getAll(taskParams.id) + .map { TaskPushNotificationConfig(taskId = taskParams.id, pushNotificationConfig = it) }, + id = request.id + ) + } + + override suspend fun onDeleteTaskPushNotificationConfig( + request: Request, + ctx: ServerCallContext + ): Response { + val pushStorage = storageIfPushNotificationSupported() + val taskPushConfigParams = request.data + + pushStorage.delete(taskPushConfigParams.id, taskPushConfigParams.pushNotificationConfigId) + + return Response(data = null, id = request.id) + } + + protected open fun checkStreamingSupport() { + if (agentCard.capabilities.streaming != true) { + throw A2AUnsupportedOperationException("Streaming is not supported by the server") + } + } + + protected open fun storageIfPushNotificationSupported(): PushNotificationConfigStorage { + if (agentCard.capabilities.pushNotifications != true) { + throw A2APushNotificationNotSupportedException("Push notifications are not supported by the server") + } + + if (pushConfigStorage == null) { + throw A2APushNotificationNotSupportedException("Push notifications are supported, but not configured on the server") + } + + return pushConfigStorage + } + + /** + * Cancels [coroutineScope] associated with this server, essentially cancelling all running jobs and sessions. + * + * @param cause Optional cause of the cancellation + */ + public open fun cancel(cause: CancellationException? 
= null) { + coroutineScope.cancel(cause) + } +} diff --git a/a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/agent/AgentExecutor.kt b/a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/agent/AgentExecutor.kt new file mode 100644 index 0000000000..bd6e9c4a69 --- /dev/null +++ b/a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/agent/AgentExecutor.kt @@ -0,0 +1,137 @@ +package ai.koog.a2a.server.agent + +import ai.koog.a2a.exceptions.A2AContentTypeNotSupportedException +import ai.koog.a2a.exceptions.A2ATaskNotCancelableException +import ai.koog.a2a.exceptions.A2AUnsupportedOperationException +import ai.koog.a2a.model.Message +import ai.koog.a2a.model.MessageSendParams +import ai.koog.a2a.model.TaskEvent +import ai.koog.a2a.model.TaskIdParams +import ai.koog.a2a.model.TaskState +import ai.koog.a2a.server.session.RequestContext +import ai.koog.a2a.server.session.SessionEventProcessor +import kotlinx.coroutines.Deferred + +/** + * Implementations of this interface contain the core logic of the agent, + * executing actions based on requests and publishing updates to an event processor. + */ +public interface AgentExecutor { + /** + * Execute the agent's logic for a given request context. + * + * The agent should read necessary information from the [context] and publish [TaskEvent] or [Message] events to + * the [eventProcessor]. This method should return once the agent's execution for this request is complete or + * yields control (e.g., enters a [TaskState.InputRequired] state). + * + * All events must have context id from [RequestContext.contextId] and for task events task id from [RequestContext.taskId]. + * + * Can throw an exception if the input is invalid or the agent fails to execute the request. 
+ * + * Example implementation: + * ```kotlin + * val userMessage = context.params.message + * + * // Process the message and create a task + * // Send task creation event + * eventProcessor.sendTaskEvent( + * Task( + * id = context.taskId, + * contextId = context.contextId, + * status = TaskStatus( + * state = TaskState.Working, + * // Mark this message as belonging to the created task + * message = userMessage.copy(taskId = context.taskId), + * timestamp = Clock.System.now() + * ), + * ) + * ) + * + * // Simulate some work + * delay(1000) + * + * // Mark task as completed + * eventProcessor.sendTaskEvent( + * TaskStatusUpdateEvent( + * taskId = context.taskId, + * contextId = context.contextId, + * status = TaskStatus( + * state = TaskState.Completed, + * message = Message( + * role = Role.Agent, + * contextId = context.contextId, + * taskId = context.taskId, + * parts = listOf( + * TextPart("Task completed successfully!") + * ) + * ), + * timestamp = Clock.System.now() + * ), + * final = true + * ) + * ) + * ``` + * + * @param context The context containing the necessary information and accessors for executing the agent. + * @param eventProcessor The event processor to publish events to. + * @throws Exception if something goes wrong during execution. Should prefer more specific exceptions when possible, + * e.g., [A2AContentTypeNotSupportedException], [A2AUnsupportedOperationException], etc. See full list of available + * A2A exceptions in [ai.koog.a2a.exceptions]. 
Must ensure that the final + * task state will be [TaskState.Canceled], otherwise the task will not be considered canceled, and the requester will + * get [A2ATaskNotCancelableException]. + * + * **IMPORTANT**: This should execute quickly as it runs synchronously with the request. + * + * Default implementation throws [A2ATaskNotCancelableException], meaning cancellation is not supported by default. + * + * Example implementation: + * ```kotlin + * // Cancel agent execution job, if the agent is currently running, to terminate it. + * agentJob?.cancelAndJoin() + * // Send task cancellation event with custom message to event processor + * eventProcessor.sendTaskEvent( + * TaskStatusUpdateEvent( + * taskId = context.taskId, + * contextId = context.contextId, + * status = TaskStatus( + * state = TaskState.Canceled, + * message = Message( + * role = Role.Agent, + * taskId = context.taskId, + * contextId = context.contextId, + * parts = listOf( + * TextPart("Task was canceled by the user") + * ) + * ) + * ), + * final = true, + * ) + * ) + * ``` + * + * @param context The context containing the necessary information and accessors for executing the agent. + * @param eventProcessor The event processor to publish events to. + * @param agentJob Optional job executing the agent logic, if the agent is currently running. + * @throws Exception if something goes wrong during execution or the cancellation is impossible. Should prefer more + * specific exceptions if possible, e.g., [A2ATaskNotCancelableException], [A2AUnsupportedOperationException], etc. + * See full list of available A2A exceptions in [ai.koog.a2a.exceptions]. 
+ */ + public suspend fun cancel( + context: RequestContext, + eventProcessor: SessionEventProcessor, + agentJob: Deferred?, + ) { + throw A2ATaskNotCancelableException("Cancellation is not supported") + } +} diff --git a/a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/exceptions/Exceptions.kt b/a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/exceptions/Exceptions.kt new file mode 100644 index 0000000000..62d54a0a6a --- /dev/null +++ b/a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/exceptions/Exceptions.kt @@ -0,0 +1,26 @@ +package ai.koog.a2a.server.exceptions + +/** + * Indicates an error with task-related operations. + */ +public class TaskOperationException(message: String, cause: Throwable? = null) : Exception(message, cause) + +/** + * Indicates an error with message-related operations. + */ +public class MessageOperationException(message: String, cause: Throwable? = null) : Exception(message, cause) + +/** + * Indicates a failure in sending an event because it was invalid. + */ +public class InvalidEventException(message: String, cause: Throwable? = null) : Exception(message, cause) + +/** + * Indicates errors occurring during push notification operations. + */ +public class PushNotificationException(message: String, cause: Throwable? = null) : Exception(message, cause) + +/** + * Indicates a session is not in the active state. + */ +public class SessionNotActiveException(message: String, cause: Throwable? 
= null) : Exception(message, cause)
diff --git a/a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/messages/InMemoryMessageStorage.kt b/a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/messages/InMemoryMessageStorage.kt
new file mode 100644
index 0000000000..6960c19e15
--- /dev/null
+++ b/a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/messages/InMemoryMessageStorage.kt
@@ -0,0 +1,47 @@
+package ai.koog.a2a.server.messages
+
+import ai.koog.a2a.annotations.InternalA2AApi
+import ai.koog.a2a.model.Message
+import ai.koog.a2a.server.exceptions.MessageOperationException
+import ai.koog.a2a.utils.RWLock
+
+/**
+ * In-memory implementation of [MessageStorage] using a thread-safe map.
+ *
+ * This implementation stores messages in memory grouped by context ID and provides
+ * concurrency safety through a read-write lock: reads take the read lock, mutations
+ * take the write lock.
+ */
+@OptIn(InternalA2AApi::class)
+public class InMemoryMessageStorage : MessageStorage {
+    private val messagesByContext = mutableMapOf<String, MutableList<Message>>()
+    private val rwLock = RWLock()
+
+    override suspend fun save(message: Message): Unit = rwLock.withWriteLock {
+        val contextId = message.contextId
+            ?: throw MessageOperationException("Message must have a contextId to be saved")
+
+        messagesByContext.getOrPut(contextId) { mutableListOf() }.add(message)
+    }
+
+    override suspend fun getByContext(contextId: String): List<Message> = rwLock.withReadLock {
+        // Defensive copy so callers never observe later mutations of the backing list
+        messagesByContext[contextId]?.toList() ?: emptyList()
+    }
+
+    // Fixed: this mutates the map, so it must hold the WRITE lock (was withReadLock),
+    // otherwise a concurrent reader could observe the map mid-mutation.
+    override suspend fun deleteByContext(contextId: String): Unit = rwLock.withWriteLock {
+        messagesByContext -= contextId
+    }
+
+    override suspend fun replaceByContext(contextId: String, messages: List<Message>): Unit = rwLock.withWriteLock {
+        // Validate that all messages have the correct contextId
+        val invalidMessages = messages.filter { it.contextId != contextId }
+        if (invalidMessages.isNotEmpty()) {
+            throw MessageOperationException(
+                "All messages must have contextId '$contextId', but found messages with different contextIds: " +
+                    invalidMessages.map {
it.contextId }.distinct().joinToString() + ) + } + + // Replace all messages for the context + messagesByContext[contextId] = messages.toMutableList() + } +} diff --git a/a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/messages/MessageStorage.kt b/a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/messages/MessageStorage.kt new file mode 100644 index 0000000000..d722f13b38 --- /dev/null +++ b/a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/messages/MessageStorage.kt @@ -0,0 +1,101 @@ +package ai.koog.a2a.server.messages + +import ai.koog.a2a.model.Message +import ai.koog.a2a.server.exceptions.MessageOperationException + +/** + * Storage interface for messages associated with a particular context. + * Can be used to keep track of a conversation history. + * + * Implementations must ensure concurrency safety. + */ +public interface MessageStorage { + /** + * Saves a message to the storage. + * + * @param message the message to save + * @throws MessageOperationException if the message cannot be saved + */ + public suspend fun save(message: Message) + + /** + * Retrieves all messages associated with the given context. + * + * @param contextId the context identifier + */ + public suspend fun getByContext(contextId: String): List + + /** + * Deletes all messages associated with the given context. + * + * @param contextId the context identifier + * @throws MessageOperationException if some messages cannot be deleted + */ + public suspend fun deleteByContext(contextId: String) + + /** + * Replaces all messages associated with the given context. + * + * @param contextId the context identifier + * @throws MessageOperationException if context cannot be replaced + */ + public suspend fun replaceByContext(contextId: String, messages: List) +} + +/** + * Wrapper class around [MessageStorage] for interacting with a particular context. + * Provides convenience methods and verification for context ID. 
+ * + * @param contextId the context identifier + * @param messageStorage the underlying [MessageStorage] implementation + */ +public class ContextMessageStorage( + private val contextId: String, + private val messageStorage: MessageStorage, +) { + /** + * Saves a message to the storage. + * + * @param message the message to save + * @see [MessageStorage.save] + */ + public suspend fun save(message: Message) { + require(message.contextId == contextId) { + "contextId of message must be same as current contextId" + } + + messageStorage.save(message) + } + + /** + * Retrieves all messages associated with the current context. + * + * @see [MessageStorage.getByContext] + */ + public suspend fun getAll(): List { + return messageStorage.getByContext(contextId) + } + + /** + * Deletes all messages associated with the current context. + * + * @see [MessageStorage.deleteByContext] + */ + public suspend fun deleteAll() { + messageStorage.deleteByContext(contextId) + } + + /** + * Replaces all messages associated with the current context. 
+ * + * @param messages the list of messages to replace + * @see [MessageStorage.replaceByContext] + */ + public suspend fun replaceAll(messages: List) { + require(messages.all { it.contextId == contextId }) { + "contextId of messages must be same as current contextId" + } + + messageStorage.replaceByContext(contextId, messages) + } +} diff --git a/a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/notifications/InMemoryPushNotificationConfigStorage.kt b/a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/notifications/InMemoryPushNotificationConfigStorage.kt new file mode 100644 index 0000000000..ca197fc5e3 --- /dev/null +++ b/a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/notifications/InMemoryPushNotificationConfigStorage.kt @@ -0,0 +1,50 @@ +package ai.koog.a2a.server.notifications + +import ai.koog.a2a.annotations.InternalA2AApi +import ai.koog.a2a.model.PushNotificationConfig +import ai.koog.a2a.utils.RWLock + +/** + * In-memory implementation of [PushNotificationConfigStorage] using a thread-safe map. + * + * This implementation stores push notification configurations in memory grouped by task ID + * and provides concurrency safety through read-write locks. + */ +@OptIn(InternalA2AApi::class) +public class InMemoryPushNotificationConfigStorage : PushNotificationConfigStorage { + private val configsByTaskId = mutableMapOf>() + private val rwLock = RWLock() + + override suspend fun save(taskId: String, pushNotificationConfig: PushNotificationConfig): Unit = + rwLock.withWriteLock { + val configId = pushNotificationConfig.id + val taskConfigs = configsByTaskId.getOrPut(taskId) { mutableMapOf() } + + taskConfigs[configId] = pushNotificationConfig + } + + override suspend fun getAll(taskId: String): List = rwLock.withReadLock { + configsByTaskId[taskId]?.values?.toList() ?: emptyList() + } + + override suspend fun get(taskId: String, configId: String?): PushNotificationConfig? 
= rwLock.withReadLock {
+        // NOTE(review): a null configId looks up the literal null key and (assuming config ids are
+        // non-null) always returns null — confirm whether null should instead mean "the default/first
+        // config for the task", as `delete(taskId, null)` treats null as "all configs".
+        configsByTaskId[taskId]?.get(configId)
+    }
+
+    override suspend fun delete(taskId: String, configId: String?): Unit = rwLock.withWriteLock {
+        if (configId == null) {
+            // Delete all configurations for the task
+            configsByTaskId.remove(taskId)
+        } else {
+            configsByTaskId[taskId]?.let { taskConfigs ->
+                // Delete specific configuration
+                taskConfigs.remove(configId)
+
+                // Clean up empty task entry
+                if (taskConfigs.isEmpty()) {
+                    configsByTaskId.remove(taskId)
+                }
+            }
+        }
+    }
+}
diff --git a/a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/notifications/PushNotificationConfigStorage.kt b/a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/notifications/PushNotificationConfigStorage.kt
new file mode 100644
index 0000000000..4f5ea18053
--- /dev/null
+++ b/a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/notifications/PushNotificationConfigStorage.kt
@@ -0,0 +1,43 @@
+package ai.koog.a2a.server.notifications
+
+import ai.koog.a2a.model.PushNotificationConfig
+import ai.koog.a2a.server.exceptions.PushNotificationException
+
+/**
+ * Interface for managing the storage of push notification configurations associated with task updates.
+ *
+ * Implementations must ensure concurrency safety.
+ */
+public interface PushNotificationConfigStorage {
+    /**
+     * Saves a push notification configuration for a specified task ID.
+     *
+     * @param taskId Task ID for which to save the configuration.
+     * @param pushNotificationConfig Config instance containing the details of the notification setup.
+     * @throws PushNotificationException if the configuration cannot be saved.
+     */
+    public suspend fun save(taskId: String, pushNotificationConfig: PushNotificationConfig)
+
+    /**
+     * Retrieves a push notification configuration for a specified task ID and configuration ID.
+     * @param taskId Task ID for which to retrieve the configuration.
+     * @param configId Configuration ID for which to retrieve the configuration.
+ */ + public suspend fun get(taskId: String, configId: String?): PushNotificationConfig? + + /** + * Retrieves all push notification configurations associated with the given task ID. + * + * @param taskId Task ID for which to retrieve the configurations. + */ + public suspend fun getAll(taskId: String): List + + /** + * Deletes all push notification configurations for a specified task ID, optionally deleting a specific configuration instead. + * + * @param taskId Task ID for which to delete the configurations. + * @param configId Optional configuration ID to delete. Defaults to `null`, meaning all configurations for the task will be deleted. + * @throws PushNotificationException if the configuration cannot be deleted. + */ + public suspend fun delete(taskId: String, configId: String? = null) +} diff --git a/a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/notifications/PushNotificationSender.kt b/a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/notifications/PushNotificationSender.kt new file mode 100644 index 0000000000..2be5fcbd74 --- /dev/null +++ b/a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/notifications/PushNotificationSender.kt @@ -0,0 +1,26 @@ +package ai.koog.a2a.server.notifications + +import ai.koog.a2a.model.PushNotificationConfig +import ai.koog.a2a.model.Task + +/** + * Interface for sending push notifications. + * + * [More info on push notifications in specification](https://a2a-protocol.org/latest/specification/#95-push-notification-setup-and-usage) + */ +public interface PushNotificationSender { + public companion object { + /** + * Represents a custom optional HTTP header used to include a token for authenticating A2A notifications. + */ + public const val A2A_NOTIFICATION_TOKEN_HEADER: String = "X-A2A-Notification-Token" + } + + /** + * Sends a push notification. + * + * @param config Push notification configuration. + * @param task Task object to send in the notification. 
+ */ + public suspend fun send(config: PushNotificationConfig, task: Task) +} diff --git a/a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/notifications/SimplePushNotificationSender.kt b/a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/notifications/SimplePushNotificationSender.kt new file mode 100644 index 0000000000..2593b453c9 --- /dev/null +++ b/a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/notifications/SimplePushNotificationSender.kt @@ -0,0 +1,67 @@ +package ai.koog.a2a.server.notifications + +import ai.koog.a2a.model.PushNotificationConfig +import ai.koog.a2a.model.Task +import io.github.oshai.kotlinlogging.KotlinLogging +import io.ktor.client.HttpClient +import io.ktor.client.plugins.contentnegotiation.ContentNegotiation +import io.ktor.client.request.post +import io.ktor.client.request.setBody +import io.ktor.http.HttpHeaders +import io.ktor.serialization.kotlinx.json.json +import io.ktor.utils.io.core.Closeable +import kotlinx.serialization.json.Json + +/** + * Simple implementation of a notification sender. + * Doesn't perform any configuration validation. 
+ * Always takes the first authentication scheme provided in [PushNotificationConfig.authentication]
+ */
+public class SimplePushNotificationSender(
+    baseHttpClient: HttpClient,
+    json: Json = Json,
+) : PushNotificationSender, Closeable {
+    private companion object {
+        private val logger = KotlinLogging.logger {}
+    }
+
+    // Derived client: adds JSON content negotiation and treats non-2xx responses as errors.
+    private val httpClient = baseHttpClient.config {
+        install(ContentNegotiation) {
+            json(json)
+        }
+
+        expectSuccess = true
+    }
+
+    override suspend fun send(config: PushNotificationConfig, task: Task) {
+        try {
+            // Fixed: log messages had an unbalanced quote — `configId='${config.id}` was missing
+            // the closing `'` before " for taskId=".
+            logger.debug { "Sending push notification configId='${config.id}' for taskId='${task.id}'" }
+
+            httpClient.post(config.url) {
+                config.authentication?.let { auth ->
+                    // Simple sender always takes the first scheme from the list
+                    val schema = auth.schemes.firstOrNull()
+                    val credentials = auth.credentials
+
+                    if (schema != null && credentials != null) {
+                        headers[HttpHeaders.Authorization] = "$schema $credentials"
+                    }
+                }
+
+                config.token?.let { token ->
+                    headers[PushNotificationSender.A2A_NOTIFICATION_TOKEN_HEADER] = token
+                }
+
+                setBody(task)
+            }
+
+            logger.debug { "Sent push notification successfully configId='${config.id}' for taskId='${task.id}'" }
+        } catch (e: Exception) {
+            // Best-effort delivery: failures are logged, never propagated to the caller.
+            logger.warn(e) { "Failed to send push notification configId='${config.id}' for taskId='${task.id}'" }
+        }
+    }
+
+    override fun close() {
+        httpClient.close()
+    }
+}
diff --git a/a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/session/IdGenerator.kt b/a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/session/IdGenerator.kt
new file mode 100644
index 0000000000..e773699356
--- /dev/null
+++ b/a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/session/IdGenerator.kt
@@ -0,0 +1,57 @@
+package ai.koog.a2a.server.session
+
+import ai.koog.a2a.model.Message
+import kotlin.uuid.ExperimentalUuidApi
+import kotlin.uuid.Uuid
+
+/**
+ * Interface for generating unique IDs for new contexts and tasks.
+ * + * Called by the server to provide unique IDs when messages lack contextId or taskId, + * preventing race conditions during concurrent agent execution. These generated IDs + * are enforced when agents reply in newly created contexts or create tasks based on + * incoming messages without existing task IDs. + */ +public interface IdGenerator { + /** + * Generates a unique context ID based on the given message. + * + * @param message The message for which the context ID is being generated. + * @return A unique string representing the context ID. + */ + public suspend fun generateContextId(message: Message): String + + /** + * Generates a unique task ID based on the given message. + * + * @param message The message for which the task ID is being generated. + * @return A unique string representing the task ID. + */ + public suspend fun generateTaskId(message: Message): String +} + +/** + * Implementation of the [IdGenerator] interface that generates unique identifiers using UUIDs. + * + * This generator ensures that each context ID and task ID is uniquely identified, leveraging UUIDs + * for randomness and collision resistance. IDs are generated only if the relevant existing ID + * (contextId or taskId) is null in the incoming message. 
+ */ +@OptIn(ExperimentalUuidApi::class) +public object UuidIdGenerator : IdGenerator { + override suspend fun generateContextId(message: Message): String { + require(message.contextId == null) { + "Can't generate contextId for message with existing contextId" + } + + return Uuid.random().toString() + } + + override suspend fun generateTaskId(message: Message): String { + require(message.taskId == null) { + "Can't generate taskId for message with existing taskId" + } + + return Uuid.random().toString() + } +} diff --git a/a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/session/RequestContext.kt b/a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/session/RequestContext.kt new file mode 100644 index 0000000000..ae51c604b2 --- /dev/null +++ b/a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/session/RequestContext.kt @@ -0,0 +1,33 @@ +package ai.koog.a2a.server.session + +import ai.koog.a2a.model.Task +import ai.koog.a2a.server.messages.ContextMessageStorage +import ai.koog.a2a.server.tasks.ContextTaskStorage +import ai.koog.a2a.transport.ServerCallContext + +/** + * Request context associated with each A2A agent-related request, providing essential information and repositories to + * the agent executor. + * + * @property callContext [ServerCallContext] associated with the request. + * @property params Parameters associated with the request. + * @property taskStorage [ContextTaskStorage] associated with the request. + * @property messageStorage [ContextMessageStorage] associated with the request. + * @property contextId The context ID representing either an existing context from the incoming request in [params], + * or a newly generated ID that the agent must use if it decides to reply. + * @property taskId The task ID representing either an existing task from the incoming request in [params], + * or a newly generated ID that the agent must use if it decides to create a new task. 
+ * @property task Optional shallow version of the current task (without message history or artifacts) + * providing lightweight access to general task state information. Present only if the incoming request + * was associated with an existing task. For detailed task information or referenced tasks, + * query [taskStorage] directly. + */ +public class RequestContext( + public val callContext: ServerCallContext, + public val params: T, + public val taskStorage: ContextTaskStorage, + public val messageStorage: ContextMessageStorage, + public val contextId: String, + public val taskId: String, + public val task: Task?, +) diff --git a/a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/session/Session.kt b/a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/session/Session.kt new file mode 100644 index 0000000000..9dd8451f5e --- /dev/null +++ b/a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/session/Session.kt @@ -0,0 +1,88 @@ +package ai.koog.a2a.server.session + +import ai.koog.a2a.model.Event +import kotlinx.coroutines.CoroutineScope +import kotlinx.coroutines.CoroutineStart +import kotlinx.coroutines.Deferred +import kotlinx.coroutines.async +import kotlinx.coroutines.cancelAndJoin +import kotlinx.coroutines.flow.Flow +import kotlinx.coroutines.flow.collect + +/** + * Represents a session with lifecycle management. + * + * @property eventProcessor The session event processor + * @property agentJob The execution process associated with this session's execution + */ +public class Session( + public val eventProcessor: SessionEventProcessor, + public val agentJob: Deferred +) { + /** + * Context ID associated with this session. + */ + public val contextId: String = eventProcessor.contextId + + /** + * Task ID associated with this session. + */ + public val taskId: String = eventProcessor.taskId + + /** + * A stream of events associated with this session. 
+ */ + public val events: Flow = eventProcessor.events + + /** + * Starts the [agentJob], if it hasn't already been started. + */ + public fun start() { + agentJob.start() + } + + /** + * Suspends until the session, i.e., event stream and agent job, complete. + * Waits for the event stream to finish first, to avoid triggering the agent job prematurely. + * Assumes that by the time event stream is finished, agent job will already be completed or canceled. + */ + public suspend fun join() { + events.collect() + agentJob.join() + } + + /** + * [start] and then [join] the session. + */ + public suspend fun startAndJoin() { + start() + join() + } + + /** + * Cancels the execution process, waiting for it to complete, and then closes event processor. + */ + public suspend fun cancelAndJoin() { + agentJob.cancelAndJoin() + eventProcessor.close() + } +} + +/** + * Creates an instance of [Session] with lazily started [Session.agentJob] + * + * @param coroutineScope The coroutine scope to use for running the [block] + * @param eventProcessor The session event processor + * @param block The block to be executed + */ +@Suppress("ktlint:standard:function-naming", "FunctionName") +public fun LazySession( + coroutineScope: CoroutineScope, + eventProcessor: SessionEventProcessor, + block: suspend CoroutineScope.() -> Unit +): Session { + return Session( + eventProcessor = eventProcessor, + agentJob = coroutineScope.async(start = CoroutineStart.LAZY, block = block) + ) +} diff --git a/a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/session/SessionEventProcessor.kt b/a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/session/SessionEventProcessor.kt new file mode 100644 index 0000000000..2e4fc7dbe3 --- /dev/null +++ b/a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/session/SessionEventProcessor.kt @@ -0,0 +1,162 @@ +package ai.koog.a2a.server.session + +import ai.koog.a2a.model.Event +import ai.koog.a2a.model.Message +import ai.koog.a2a.model.Task +import 
ai.koog.a2a.model.TaskEvent +import ai.koog.a2a.model.TaskStatusUpdateEvent +import ai.koog.a2a.server.exceptions.InvalidEventException +import ai.koog.a2a.server.exceptions.SessionNotActiveException +import ai.koog.a2a.server.tasks.TaskStorage +import kotlinx.coroutines.flow.Flow +import kotlinx.coroutines.flow.MutableSharedFlow +import kotlinx.coroutines.flow.filterIsInstance +import kotlinx.coroutines.flow.map +import kotlinx.coroutines.flow.onSubscription +import kotlinx.coroutines.flow.takeWhile +import kotlinx.coroutines.sync.Mutex +import kotlinx.coroutines.sync.withLock +import kotlin.concurrent.atomics.AtomicBoolean +import kotlin.concurrent.atomics.ExperimentalAtomicApi +import kotlin.jvm.JvmInline + +/** + * A session processor responsible for handling session events. + * It validates the events, writes them to [taskStorage] and emits them to the subscribers via [events]. + * + * Validation logic attempts to perform basic verification that events follow what is expected from a proper A2A server implementation. + * These are the main rules: + * + * - **Session type exclusivity**: A session can only handle either [Message] event or [TaskEvent] events, never both. + * - **Context ID validation**: All events must have the same context id as provided [contextId]. + * - **Single message limit**: Only one [Message] can be sent per session, after which the processor closes. + * - **Task ID consistency**: [TaskEvent] events must have task ids equal to provided [taskId]. + * - **Final event enforcement**: After a [TaskStatusUpdateEvent] with `final=true` is sent, processor closes. + * - **Terminal state closure**: When the event with the terminal state is sent, processor closes. + * + * @param taskStorage The storage where task events will be saved. + * @property contextId The contextId associated with this session, representing either an existing context + * from the incoming request or a newly generated ID that must be used for all events in this session. 
+ * @property taskId The taskId associated with this session, representing either an existing task + * from the incoming request or a newly generated ID that must be used if creating a new task. + * Note: This taskId might not correspond to an actually existing task initially - it serves as the + * identifier that will be validated against all [TaskEvent] in this session. + */ +@OptIn(ExperimentalAtomicApi::class) +public class SessionEventProcessor( + public val contextId: String, + public val taskId: String, + private val taskStorage: TaskStorage, +) { + private companion object { + private const val SESSION_CLOSED = "Session event processor is closed, can't send events" + + private const val INVALID_CONTEXT_ID = "Event contextId must be same as provided contextId" + + private const val INVALID_TASK_ID = "Event taskId must be same as provided taskId" + + private const val TASK_EVENT_SENT = + "Task has already been initialized in this session, only TaskEvent's with the same taskId can be sent from now on" + } + + private val _isOpen: AtomicBoolean = AtomicBoolean(true) + + /** + * Whether the session is open. + */ + public val isOpen: Boolean get() = _isOpen.load() + + /** + * Tracks whether a task event was sent in this session, meaning we have to reject [Message] events now. + */ + private var isTaskEventSent: Boolean = false + + private val sessionMutex = Mutex() + + /** + * Helper interface to send actual events or termination signal to close current event stream subscribers on session closure. + */ + private sealed interface FlowEvent { + @JvmInline + value class Data(val data: Event) : FlowEvent + object Close : FlowEvent + } + + private val _events = MutableSharedFlow() + + /** + * A hot flow of events in this session that can be subscribed to. 
+ */ + public val events: Flow = _events + .onSubscription { if (!_isOpen.load()) emit(FlowEvent.Close) } + .takeWhile { it !is FlowEvent.Close } + .filterIsInstance() + .map { it.data } + + /** + * Sends a [Message] to the session event processor. Validates the message against the session context and updates + * the session state accordingly. + * + * @param message The message to be sent. + * @throws [InvalidEventException] for invalid events. + * @see SessionEventProcessor + */ + public suspend fun sendMessage(message: Message): Unit = sessionMutex.withLock { + if (_isOpen.load()) { + if (isTaskEventSent) { + throw InvalidEventException(TASK_EVENT_SENT) + } + + if (message.contextId != this.contextId) { + throw InvalidEventException(INVALID_CONTEXT_ID) + } + + _events.emit(FlowEvent.Data(message)) + _isOpen.store(false) + } else { + throw SessionNotActiveException(SESSION_CLOSED) + } + } + + /** + * Sends a [TaskEvent] to the session event processor. + * Validates the event against the session context and updates [taskStorage]. + * + * @param event The event to be sent. + * @throws [InvalidEventException] for invalid events. + * @see SessionEventProcessor + */ + public suspend fun sendTaskEvent(event: TaskEvent): Unit = sessionMutex.withLock { + if (_isOpen.load()) { + isTaskEventSent = true + + if (event.contextId != this.contextId) { + throw InvalidEventException(INVALID_CONTEXT_ID) + } + + if (event.taskId != this.taskId) { + throw InvalidEventException(INVALID_TASK_ID) + } + + taskStorage.update(event) + _events.emit(FlowEvent.Data(event)) + + val isFinalEvent = (event is TaskStatusUpdateEvent && (event.status.state.terminal || event.final)) || + (event is Task && event.status.state.terminal) + + if (isFinalEvent) { + _isOpen.store(false) + } + } else { + throw SessionNotActiveException(SESSION_CLOSED) + } + } + + /** + * Closes the session event processor, also closing event stream. 
+ */ + public suspend fun close(): Unit = sessionMutex.withLock { + _isOpen.store(false) + _events.emit(FlowEvent.Close) + } +} diff --git a/a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/session/SessionManager.kt b/a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/session/SessionManager.kt new file mode 100644 index 0000000000..53b1550c26 --- /dev/null +++ b/a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/session/SessionManager.kt @@ -0,0 +1,129 @@ +package ai.koog.a2a.server.session + +import ai.koog.a2a.annotations.InternalA2AApi +import ai.koog.a2a.model.TaskEvent +import ai.koog.a2a.server.notifications.PushNotificationConfigStorage +import ai.koog.a2a.server.notifications.PushNotificationSender +import ai.koog.a2a.server.tasks.TaskStorage +import ai.koog.a2a.utils.KeyedMutex +import ai.koog.a2a.utils.RWLock +import ai.koog.a2a.utils.withLock +import io.github.oshai.kotlinlogging.KotlinLogging +import kotlinx.coroutines.CompletableJob +import kotlinx.coroutines.CoroutineScope +import kotlinx.coroutines.Job +import kotlinx.coroutines.flow.firstOrNull +import kotlinx.coroutines.flow.onStart +import kotlinx.coroutines.launch + +/** + * Manages a set of active instances of [LazySession], sends push notifications if configured after each session completes. + * Automatically closes and removes the session when agent job is completed (whether successfully or not). + * + * Additionally, if push notifications are configured, after each task session completes, push notifications are sent with + * the current task state. + * + * Provides the ability to lock a task id. + * + * @param coroutineScope The scope in which the monitoring jobs will be launched. + * @param tasksMutex The mutex for locking specific task ids. + * @param taskStorage The storage for tasks. + * @param pushConfigStorage The storage for push notification configurations. + * @param pushSender The push notification sender. 
+ */ +@OptIn(InternalA2AApi::class) +public class SessionManager( + private val coroutineScope: CoroutineScope, + private val tasksMutex: KeyedMutex, + private val cancelKey: (String) -> String, + private val taskStorage: TaskStorage, + private val pushConfigStorage: PushNotificationConfigStorage? = null, + private val pushSender: PushNotificationSender? = null, +) { + private companion object { + private val logger = KotlinLogging.logger {} + } + + /** + * Map of task id to session. All sessions have task id associated with them, even if the task won't be created. + */ + private val sessions = mutableMapOf() + private val sessionsRwLock = RWLock() + + /** + * Adds a session to a set of active sessions. + * Handles cleanup by closing and removing the session when it is completed (whether successfully or not). + * Sends push notifications if configured after each session completes. + * + * @param session The session to add. + * @return A [CompletableJob] indicating when the monitoring coroutine is started and ready to monitor the session. + * It is crucial to start agent execution only after this job completes, to ensure the monitoring won't skip any events. + * @throws IllegalArgumentException if a session for the same task id already exists. + */ + public suspend fun addSession(session: Session): CompletableJob { + sessionsRwLock.withWriteLock { + check(session.taskId !in sessions) { + "Session for taskId '${session.taskId}' already runs." + } + + sessions[session.taskId] = session + } + + // Signal to indicate the monitoring is started. + val monitoringStarted = Job() + + // Monitor for agent job completion to send push notifications and remove session from the map. 
+        coroutineScope.launch {
+            // Capture the first event while simultaneously signalling that monitoring has started,
+            // so the caller can safely begin agent execution without losing events.
+            val firstEvent = session.events
+                .onStart { monitoringStarted.complete() }
+                .firstOrNull()
+
+            // Wait for the agent job to finish
+            session.agentJob.join()
+
+            /*
+             Check and wait if there's a cancellation request for this task running now and still publishing some events.
+             Then remove it from the session map.
+             */
+            tasksMutex.withLock(cancelKey(session.taskId)) {
+                sessionsRwLock.withWriteLock {
+                    sessions -= session.taskId
+                    session.cancelAndJoin()
+                }
+            }
+
+            // Send push notifications with the current state of the task, after the session completion, if configured.
+            coroutineScope.launch {
+                if (firstEvent is TaskEvent && pushSender != null && pushConfigStorage != null) {
+                    val task = taskStorage.get(session.taskId, historyLength = 0, includeArtifacts = false)
+
+                    if (task != null) {
+                        pushConfigStorage.getAll(session.taskId).forEach { config ->
+                            try {
+                                pushSender.send(config, task)
+                            } catch (e: Exception) {
+                                // Fixed: was a silent `// TODO log error` swallow — keep best-effort
+                                // semantics (do not fail other configs), but record the failure.
+                                logger.warn(e) {
+                                    "Failed to send push notification configId='${config.id}' for taskId='${session.taskId}'"
+                                }
+                            }
+                        }
+                    }
+                }
+            }
+        }
+
+        return monitoringStarted
+    }
+
+    /**
+     * Returns the session for the given task id, if it exists.
+     */
+    public suspend fun getSession(taskId: String): Session? = sessionsRwLock.withReadLock {
+        sessions[taskId]
+    }
+
+    /**
+     * Returns the number of active sessions.
+ */ + public suspend fun activeSessions(): Int = sessionsRwLock.withReadLock { + sessions.size + } +} diff --git a/a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/tasks/InMemoryTaskStorage.kt b/a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/tasks/InMemoryTaskStorage.kt new file mode 100644 index 0000000000..cfa9a8d8dc --- /dev/null +++ b/a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/tasks/InMemoryTaskStorage.kt @@ -0,0 +1,161 @@ +package ai.koog.a2a.server.tasks + +import ai.koog.a2a.annotations.InternalA2AApi +import ai.koog.a2a.model.Task +import ai.koog.a2a.model.TaskArtifactUpdateEvent +import ai.koog.a2a.model.TaskEvent +import ai.koog.a2a.model.TaskStatusUpdateEvent +import ai.koog.a2a.server.exceptions.TaskOperationException +import ai.koog.a2a.utils.RWLock +import kotlinx.serialization.json.JsonObject + +/** + * In-memory implementation of [TaskStorage] using a thread-safe map. + * + * This implementation stores tasks in memory and provides concurrency safety through mutex locks. + * Tasks are indexed by both ID and context ID for efficient retrieval. + */ +@OptIn(InternalA2AApi::class) +public class InMemoryTaskStorage : TaskStorage { + private val tasks = mutableMapOf() + private val tasksByContext = mutableMapOf>() + private val rwLock = RWLock() + + override suspend fun get( + taskId: String, + historyLength: Int?, + includeArtifacts: Boolean + ): Task? 
= rwLock.withReadLock { + historyLength?.let { + require(it >= 0) { "historyLength must be non-negative" } + } + + val task = tasks[taskId] ?: return@withReadLock null + // Need to modify the original full task object to remove some information + val isModificationNeeded = historyLength != null || !includeArtifacts + + if (isModificationNeeded) { + task.copy( + history = if (historyLength != null) { + task.history?.takeLast(historyLength) + } else { + task.history + }, + artifacts = if (includeArtifacts) { + task.artifacts + } else { + null + } + ) + } else { + task + } + } + + override suspend fun getAll( + taskIds: List, + historyLength: Int?, + includeArtifacts: Boolean + ): List = rwLock.withReadLock { + taskIds.mapNotNull { taskId -> + get(taskId, historyLength, includeArtifacts) + } + } + + override suspend fun getByContext( + contextId: String, + historyLength: Int?, + includeArtifacts: Boolean + ): List = rwLock.withReadLock { + val contextTaskIds = tasksByContext[contextId] ?: emptySet() + contextTaskIds.mapNotNull { taskId -> + get(taskId, historyLength, includeArtifacts) + } + } + + override suspend fun update(event: TaskEvent): Unit = rwLock.withWriteLock { + when (event) { + is Task -> { + val oldTask = tasks[event.id] + + if (oldTask != null && event.contextId != oldTask.contextId) { + throw TaskOperationException("Cannot change context for existing task: ${event.id}") + } + + // Store or replace the task + tasks[event.id] = event + + // Update context index + tasksByContext.getOrPut(event.contextId) { mutableSetOf() }.add(event.id) + } + + is TaskStatusUpdateEvent -> { + val existingTask = tasks[event.taskId] + ?: throw TaskOperationException("Cannot update status for non-existing task: ${event.taskId}") + + val updatedTask = existingTask.copy( + status = event.status, + history = existingTask.status.message + ?.let { existingTask.history.orEmpty() + it } + ?: existingTask.history, + metadata = existingTask.metadata + ?.let { JsonObject(it + 
event.metadata.orEmpty()) } + ?: event.metadata, + ) + + tasks[event.taskId] = updatedTask + } + + is TaskArtifactUpdateEvent -> { + val existingTask = tasks[event.taskId] + ?: throw TaskOperationException("Cannot update artifact for non-existing task: ${event.taskId}") + + val currentArtifacts = existingTask.artifacts?.toMutableList() ?: mutableListOf() + val existingArtifactIndex = currentArtifacts.indexOfFirst { it.artifactId == event.artifact.artifactId } + + if (existingArtifactIndex >= 0) { + val existingArtifact = currentArtifacts[existingArtifactIndex] + + currentArtifacts[existingArtifactIndex] = if (event.append == true) { + existingArtifact.copy(parts = existingArtifact.parts + event.artifact.parts) + } else { + event.artifact + } + } else { + currentArtifacts.add(event.artifact) + } + + val updatedTask = existingTask.copy( + artifacts = currentArtifacts, + metadata = existingTask.metadata + ?.let { JsonObject(it + event.metadata.orEmpty()) } + ?: event.metadata + ) + + tasks[event.taskId] = updatedTask + } + } + } + + override suspend fun delete(taskId: String): Unit = rwLock.withWriteLock { + tasks.remove(taskId)?.let { task -> + // Remove from context index + tasksByContext[task.contextId]?.remove(taskId) + if (tasksByContext[task.contextId]?.isEmpty() == true) { + tasksByContext.remove(task.contextId) + } + } + } + + override suspend fun deleteAll(taskIds: List): Unit = rwLock.withWriteLock { + taskIds.forEach { taskId -> + tasks.remove(taskId)?.let { task -> + // Remove from context index + tasksByContext[task.contextId]?.remove(taskId) + if (tasksByContext[task.contextId]?.isEmpty() == true) { + tasksByContext.remove(task.contextId) + } + } + } + } +} diff --git a/a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/tasks/TaskStorage.kt b/a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/tasks/TaskStorage.kt new file mode 100644 index 0000000000..b6ac235bbb --- /dev/null +++ 
b/a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/tasks/TaskStorage.kt @@ -0,0 +1,161 @@ +package ai.koog.a2a.server.tasks + +import ai.koog.a2a.model.Task +import ai.koog.a2a.model.TaskEvent +import ai.koog.a2a.server.exceptions.TaskOperationException + +/** + * Storage interface for managing tasks and their lifecycle events. + * + * Implementations must ensure concurrency safety. + */ +public interface TaskStorage { + /** + * Retrieves a task by ID. + * + * @param taskId the unique task identifier + * @param historyLength the maximum number of messages in conversation history to include in the response + * Set to `null` to include all messages. Defaults to `0`. + * @param includeArtifacts whether to include artifacts in the response. Default is `false`. + * + * @return the task or null if not found + */ + public suspend fun get( + taskId: String, + historyLength: Int? = 0, + includeArtifacts: Boolean = false + ): Task? + + /** + * Retrieves multiple tasks by their IDs. + * + * @param taskIds list of task identifiers + * @param historyLength the maximum number of messages in conversation history to include in the response + * Set to `null` to include all messages. Defaults to `0`. + * @param includeArtifacts whether to include artifacts in the response. Default is `false`. + * @return list of found tasks (may be fewer than requested) + */ + public suspend fun getAll( + taskIds: List, + historyLength: Int? = 0, + includeArtifacts: Boolean = false + ): List + + /** + * Retrieves all tasks associated with a specific context ID. + * + * @param contextId context identifier + * @param historyLength the maximum number of messages in conversation history to include in the response + * Set to `null` to include all messages. Defaults to `0`. + * @param includeArtifacts whether to include artifacts in the response. Default is `false`. + * @return A list of tasks that match the specified context ID. 
+ */ + public suspend fun getByContext( + contextId: String, + historyLength: Int? = 0, + includeArtifacts: Boolean = false + ): List + + /** + * Updates task state based on an event, or creates the task if it doesn't exist. + * + * When the event is a [Task], it will be stored as a new task or replace an existing one. + * When the event is a status/artifact update, it modifies the existing task state. + * All tasks must be created before they can be updated with the [Task] event first. + * Attempts to send a non-[Task] [TaskEvent] for a non-existing task will result in error. + * + * @param event the update event to apply (creation or modification) + * @throws TaskOperationException if the task cannot be created or updated, e.g. [TaskEvent] that is not [Task] is sent for non-existing task id + */ + public suspend fun update(event: TaskEvent) + + /** + * Deletes a task by ID. + * + * @param taskId the task identifier to delete + * @throws TaskOperationException if the task cannot be deleted, e.g. it doesn't exist + */ + public suspend fun delete(taskId: String) + + /** + * Deletes multiple tasks by their IDs. + * + * @param taskIds list of task identifiers to delete + * @throws TaskOperationException if some tasks cannot be deleted, e.g. they don't exist + */ + public suspend fun deleteAll(taskIds: List) +} + +/** + * A specialized wrapper around [TaskStorage] for providing access to the tasks within a specific context. + * + * @param contextId the unique identifier for the current context + * @param taskStorage the underlying task storage implementation + * @see [TaskStorage] + */ +public class ContextTaskStorage( + private val contextId: String, + private val taskStorage: TaskStorage, +) { + /** + * Retrieves a task by ID. + * + * @see [TaskStorage.get] + */ + public suspend fun get( + taskId: String, + historyLength: Int? = 0, + includeArtifacts: Boolean = false + ): Task? 
= taskStorage.get(taskId, historyLength, includeArtifacts) + + /** + * Retrieves multiple tasks by their IDs. + * + * @see [TaskStorage.getAll] + */ + public suspend fun getAll( + taskIds: List, + historyLength: Int? = 0, + includeArtifacts: Boolean = false + ): List = taskStorage.getAll(taskIds, historyLength, includeArtifacts) + + /** + * Retrieves all tasks associated with the current context ID. + * + * @see [TaskStorage.getByContext] + */ + public suspend fun getByContext( + historyLength: Int? = 0, + includeArtifacts: Boolean = false + ): List = taskStorage.getByContext(contextId, historyLength, includeArtifacts) + + /** + * Deletes a task by ID, checking that it belongs to the current context ID. + * + * @see [TaskStorage.delete] + */ + public suspend fun delete(taskId: String) { + get(taskId, historyLength = 0, includeArtifacts = false)?.let { task -> + require(task.contextId == contextId) { + "contextId of the task requested to be deleted must be same as current contextId" + } + + taskStorage.delete(taskId) + } + } + + /** + * Deletes multiple tasks by their IDs, checking that they belong to the current context ID. 
+ * + * @see [TaskStorage.deleteAll] + */ + public suspend fun deleteAll(taskIds: List) { + getAll(taskIds, historyLength = 0, includeArtifacts = false).let { tasks -> + require(tasks.all { it.contextId == contextId }) { + "contextId of the tasks requested to be deleted must be same as current contextId" + } + + taskStorage.deleteAll(taskIds) + } + } +} diff --git a/a2a/a2a-server/src/commonTest/kotlin/ai/koog/a2a/server/messages/InMemoryMessageStorageTest.kt b/a2a/a2a-server/src/commonTest/kotlin/ai/koog/a2a/server/messages/InMemoryMessageStorageTest.kt new file mode 100644 index 0000000000..5395731022 --- /dev/null +++ b/a2a/a2a-server/src/commonTest/kotlin/ai/koog/a2a/server/messages/InMemoryMessageStorageTest.kt @@ -0,0 +1,99 @@ +package ai.koog.a2a.server.messages + +import ai.koog.a2a.model.Message +import ai.koog.a2a.model.Role +import ai.koog.a2a.model.TextPart +import ai.koog.a2a.server.exceptions.MessageOperationException +import kotlinx.coroutines.test.runTest +import kotlin.test.BeforeTest +import kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.assertFailsWith +import kotlin.test.assertTrue + +class InMemoryMessageStorageTest { + private lateinit var storage: InMemoryMessageStorage + + @BeforeTest + fun setUp() { + storage = InMemoryMessageStorage() + } + + @Test + fun testSaveMessage() = runTest { + val message = createMessage("msg-1", "context-1", "Hello, world!") + + storage.save(message) + + val messages = storage.getByContext("context-1") + + assertEquals(1, messages.size) + assertEquals(message, messages[0]) + } + + @Test + fun testSaveMessageWithoutContextId() = runTest { + val message = createMessage("msg-1", null, "Hello, world!") + + assertFailsWith { + storage.save(message) + } + } + + @Test + fun testSaveMultipleMessages() = runTest { + val message1 = createMessage("msg-1", "context-1", "First message") + val message2 = createMessage("msg-2", "context-1", "Second message") + val message3 = createMessage("msg-3", 
"context-2", "Different context") + + storage.save(message1) + storage.save(message2) + storage.save(message3) + + val context1Messages = storage.getByContext("context-1") + assertEquals(2, context1Messages.size) + assertEquals(message1, context1Messages[0]) + assertEquals(message2, context1Messages[1]) + + val context2Messages = storage.getByContext("context-2") + assertEquals(1, context2Messages.size) + assertEquals(message3, context2Messages[0]) + } + + @Test + fun testGetByNonExistentContext() = runTest { + val messages = storage.getByContext("non-existent-context") + assertTrue(messages.isEmpty()) + } + + @Test + fun testDeleteByContext() = runTest { + val message1 = createMessage("msg-1", "context-1", "Message 1") + val message2 = createMessage("msg-2", "context-1", "Message 2") + val message3 = createMessage("msg-3", "context-2", "Different context") + + storage.save(message1) + storage.save(message2) + storage.save(message3) + + storage.deleteByContext("context-1") + + val context1Messages = storage.getByContext("context-1") + assertTrue(context1Messages.isEmpty()) + + val context2Messages = storage.getByContext("context-2") + assertEquals(1, context2Messages.size) + assertEquals(message3, context2Messages[0]) + } + + private fun createMessage( + messageId: String, + contextId: String?, + content: String = "test content" + ) = Message( + messageId = messageId, + role = Role.User, + parts = listOf(TextPart(content)), + contextId = contextId + ) +} diff --git a/a2a/a2a-server/src/commonTest/kotlin/ai/koog/a2a/server/notifications/InMemoryPushNotificationConfigStorageTest.kt b/a2a/a2a-server/src/commonTest/kotlin/ai/koog/a2a/server/notifications/InMemoryPushNotificationConfigStorageTest.kt new file mode 100644 index 0000000000..38e3ac8abf --- /dev/null +++ b/a2a/a2a-server/src/commonTest/kotlin/ai/koog/a2a/server/notifications/InMemoryPushNotificationConfigStorageTest.kt @@ -0,0 +1,92 @@ +package ai.koog.a2a.server.notifications + +import 
ai.koog.a2a.model.PushNotificationConfig +import kotlinx.coroutines.test.runTest +import kotlin.test.BeforeTest +import kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.assertTrue + +class InMemoryPushNotificationConfigStorageTest { + private lateinit var storage: InMemoryPushNotificationConfigStorage + + @BeforeTest + fun setUp() { + storage = InMemoryPushNotificationConfigStorage() + } + + @Test + fun testSaveMultipleConfigsWithMixedIds() = runTest { + val configWithId = PushNotificationConfig( + id = "config-1", + url = "https://webhook1.example.com" + ) + val configWithoutId = PushNotificationConfig( + id = null, + url = "https://webhook2.example.com" + ) + + storage.save("task-1", configWithId) + storage.save("task-1", configWithoutId) + + val retrieved = storage.getAll("task-1") + assertEquals(setOf(configWithId, configWithoutId), retrieved.toSet()) + } + + @Test + fun testOverwriteExistingConfig() = runTest { + val originalConfig = PushNotificationConfig( + id = "config-1", + url = "https://webhook1.example.com" + ) + val updatedConfig = PushNotificationConfig( + id = "config-1", + url = "https://webhook-updated.example.com" + ) + + storage.save("task-1", originalConfig) + storage.save("task-1", updatedConfig) + + val retrieved = storage.getAll("task-1") + assertEquals(setOf(updatedConfig), retrieved.toSet()) + } + + @Test + fun testDeleteAllConfigsForTask() = runTest { + val config = PushNotificationConfig( + id = "config-1", + url = "https://webhook.example.com" + ) + + storage.save("task-1", config) + storage.delete("task-1", null) + + val remaining = storage.getAll("task-1") + assertTrue(remaining.isEmpty()) + } + + @Test + fun testDeleteSpecificConfig() = runTest { + val config1 = PushNotificationConfig( + id = "config-1", + url = "https://webhook1.example.com" + ) + val config2 = PushNotificationConfig( + id = "config-2", + url = "https://webhook2.example.com" + ) + + storage.save("task-1", config1) + storage.save("task-1", 
config2) + storage.delete("task-1", "config-1") + + val remaining = storage.getAll("task-1") + assertEquals(setOf(config2), remaining.toSet()) + } + + @Test + fun testGetAllForNonExistentTask() = runTest { + val configs = storage.getAll("non-existent-task") + assertTrue(configs.isEmpty()) + } +} diff --git a/a2a/a2a-server/src/commonTest/kotlin/ai/koog/a2a/server/session/SessionEventProcessorTest.kt b/a2a/a2a-server/src/commonTest/kotlin/ai/koog/a2a/server/session/SessionEventProcessorTest.kt new file mode 100644 index 0000000000..462b158f65 --- /dev/null +++ b/a2a/a2a-server/src/commonTest/kotlin/ai/koog/a2a/server/session/SessionEventProcessorTest.kt @@ -0,0 +1,377 @@ +package ai.koog.a2a.server.session + +import ai.koog.a2a.model.Artifact +import ai.koog.a2a.model.Event +import ai.koog.a2a.model.Message +import ai.koog.a2a.model.Role +import ai.koog.a2a.model.Task +import ai.koog.a2a.model.TaskArtifactUpdateEvent +import ai.koog.a2a.model.TaskState +import ai.koog.a2a.model.TaskStatus +import ai.koog.a2a.model.TaskStatusUpdateEvent +import ai.koog.a2a.model.TextPart +import ai.koog.a2a.server.exceptions.InvalidEventException +import ai.koog.a2a.server.exceptions.SessionNotActiveException +import ai.koog.a2a.server.tasks.InMemoryTaskStorage +import kotlinx.coroutines.flow.lastOrNull +import kotlinx.coroutines.launch +import kotlinx.coroutines.test.runTest +import kotlinx.coroutines.yield +import kotlinx.datetime.Instant +import kotlin.test.BeforeTest +import kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.assertFailsWith +import kotlin.test.assertFalse +import kotlin.test.assertNull +import kotlin.time.Duration.Companion.seconds + +class SessionEventProcessorTest { + private companion object { + private val TEST_TIMEOUT = 5.seconds + } + + private lateinit var taskStorage: InMemoryTaskStorage + private val contextId = "test-context-1" + private val taskId = "task-1" + + @BeforeTest + fun setUp() { + taskStorage = InMemoryTaskStorage() + } + 
+ private fun createMessage( + messageId: String, + contextId: String, + content: String + ) = Message( + messageId = messageId, + role = Role.User, + parts = listOf(TextPart(content)), + contextId = contextId + ) + + private fun createTask( + id: String, + contextId: String, + state: TaskState = TaskState.Submitted + ) = Task( + id = id, + contextId = contextId, + status = TaskStatus( + state = state, + timestamp = Instant.parse("2023-01-01T10:00:00Z") + ) + ) + + private fun createProcessor( + contextId: String, + taskId: String, + ): SessionEventProcessor = SessionEventProcessor( + contextId = contextId, + taskId = taskId, + taskStorage = taskStorage, + ) + + @Test + fun message_testSendMessageWithInvalidContextId() = runTest(timeout = TEST_TIMEOUT) { + val processor = createProcessor(contextId, taskId) + val message = createMessage("msg-1", "different-context", "Hello") + + assertFailsWith { + processor.sendMessage(message) + } + + processor.close() + } + + @Test + fun message_testSendSecondMessageFails() = runTest(timeout = TEST_TIMEOUT) { + val processor = createProcessor(contextId, taskId) + val message1 = createMessage("msg-1", contextId, "Hello") + val message2 = createMessage("msg-2", contextId, "World") + + processor.sendMessage(message1) + + assertFailsWith { + processor.sendMessage(message2) + } + assertFalse(processor.isOpen) + } + + @Test + fun message_testSendTaskEventAfterMessageFails() = runTest(timeout = TEST_TIMEOUT) { + val processor = createProcessor(contextId, taskId) + val message = createMessage("msg-1", contextId, "Hello") + val task = createTask(taskId, contextId) + + processor.sendMessage(message) + + assertFailsWith { + processor.sendTaskEvent(task) + } + assertFalse(processor.isOpen) + } + + // Task session tests + + @Test + fun task_testSendTaskEventNewTask() = runTest(timeout = TEST_TIMEOUT) { + val processor = createProcessor(contextId, taskId) + val task = createTask(taskId, contextId) + + // Start collecting events before sending 
+ val events = mutableListOf() + val eventsJob = launch { + processor.events.collect { + events.add(it) + } + } + + // Let the eventJob job actually start + yield() + + processor.sendTaskEvent(task) + processor.close() + + // Wait for event collection to complete + eventsJob.join() + + // Verify task was stored and collected + val storedTask = taskStorage.get(taskId) + + assertEquals(task, storedTask) + assertEquals(listOf(task), events.toList()) + } + + @Test + fun task_testSendTaskEventWithExistingTask() = runTest(timeout = TEST_TIMEOUT) { + // Store a task first + val existingTask = createTask(taskId, contextId) + taskStorage.update(existingTask) + + // Create processor with existing task + val processor = createProcessor(contextId, taskId) + + val statusUpdate = TaskStatusUpdateEvent( + taskId = taskId, + contextId = contextId, + status = TaskStatus(state = TaskState.Working), + final = false + ) + + processor.sendTaskEvent(statusUpdate) + processor.close() + + // Verify event was processed + val updatedTask = taskStorage.get(taskId) + assertEquals(statusUpdate.status, updatedTask?.status) + } + + @Test + fun task_testSendTaskEventWithInvalidContextId() = runTest(timeout = TEST_TIMEOUT) { + val processor = createProcessor(contextId, taskId) + val task = createTask(taskId, "different-context") + + assertFailsWith { + processor.sendTaskEvent(task) + } + + processor.close() + } + + @Test + fun task_testSendEventAfterFinalEventFails() = runTest(timeout = TEST_TIMEOUT) { + val processor = createProcessor(contextId, taskId) + val task = createTask(taskId, contextId) + val finalStatusUpdate = TaskStatusUpdateEvent( + taskId = taskId, + contextId = contextId, + status = TaskStatus(state = TaskState.Completed), + final = true + ) + val anotherEvent = TaskArtifactUpdateEvent( + taskId = taskId, + contextId = contextId, + artifact = Artifact( + artifactId = "artifact-1", + parts = listOf(TextPart("content")) + ), + append = false + ) + + processor.sendTaskEvent(task) + 
processor.sendTaskEvent(finalStatusUpdate) + + assertFailsWith { + processor.sendTaskEvent(anotherEvent) + } + assertFalse(processor.isOpen) + } + + @Test + fun task_testSendMessageAfterTaskEventFails() = runTest(timeout = TEST_TIMEOUT) { + val processor = createProcessor(contextId, taskId) + val task = createTask(taskId, contextId) + val message = createMessage("msg-1", contextId, "Hello") + + processor.sendTaskEvent(task) + + assertFailsWith { + processor.sendMessage(message) + } + + processor.close() + } + + @Test + fun task_testTaskArtifactUpdateEvent() = runTest(timeout = TEST_TIMEOUT) { + val processor = createProcessor(contextId, taskId) + val task = createTask(taskId, contextId) + val artifact = Artifact( + artifactId = "artifact-1", + parts = listOf(TextPart("content")) + ) + val artifactEvent = TaskArtifactUpdateEvent( + taskId = taskId, + contextId = contextId, + artifact = artifact, + append = false + ) + + processor.sendTaskEvent(task) + processor.sendTaskEvent(artifactEvent) + processor.close() + + // Verify artifact was stored + val storedTask = taskStorage.get("task-1", includeArtifacts = true) + assertEquals(listOf(artifact), storedTask?.artifacts) + } + + @Test + fun task_testCompleteTaskLifecycle() = runTest(timeout = TEST_TIMEOUT) { + val processor = createProcessor(contextId, taskId) + + // Create task + val task = createTask(taskId, contextId) + processor.sendTaskEvent(task) + + // Update status to working + val workingUpdate = TaskStatusUpdateEvent( + taskId = taskId, + contextId = contextId, + status = TaskStatus(state = TaskState.Working), + final = false + ) + processor.sendTaskEvent(workingUpdate) + + // Add artifact + val artifact = Artifact( + artifactId = "artifact-1", + parts = listOf(TextPart("result")) + ) + val artifactEvent = TaskArtifactUpdateEvent( + taskId = taskId, + contextId = contextId, + artifact = artifact, + append = false + ) + processor.sendTaskEvent(artifactEvent) + + // Complete task + val completedUpdate = 
TaskStatusUpdateEvent( + taskId = taskId, + contextId = contextId, + status = TaskStatus(state = TaskState.Completed), + final = true + ) + processor.sendTaskEvent(completedUpdate) + + processor.close() + + // Verify final state + val finalTask = taskStorage.get(taskId, includeArtifacts = true) + assertEquals(TaskState.Completed, finalTask?.status?.state) + assertEquals(listOf(artifact), finalTask?.artifacts) + } + + // Concurrent scenarios + + @Test + fun concurrent_message_testSendMessageBroadcast() = runTest(timeout = TEST_TIMEOUT) { + val processor = createProcessor(contextId, taskId) + val message = createMessage("msg-1", contextId, "Hello") + + fun collectionJob(events: MutableList) = launch { + processor.events.collect { + events.add(it) + } + } + + // Set up two collectors to test that events are broadcasted properly + val eventsOne = mutableListOf() + val eventsJobOne = collectionJob(eventsOne) + + val eventsTwo = mutableListOf() + val eventsJobTwo = collectionJob(eventsTwo) + + // Let event jobs actually start + yield() + + processor.sendMessage(message) + processor.close() + + eventsJobOne.join() + eventsJobTwo.join() + + assertEquals(listOf(message), eventsOne.toList(), "First collector should collect the message") + assertEquals(listOf(message), eventsTwo.toList(), "Second collector should collect the message") + } + + @Test + fun concurrent_message_testClosedProcessorSendMessageFailsAndEventStreamIsEmpty() = + runTest(timeout = TEST_TIMEOUT) { + val processor = createProcessor(contextId, taskId) + val message1 = createMessage("msg-1", contextId, "Hello") + val message2 = createMessage("msg-2", contextId, "World") + + // Send first message + processor.sendMessage(message1) + + // Close processor and then attempt to send more events + processor.close() + + assertFailsWith("Should not be possible to send events to closed session") { + processor.sendMessage(message2) + } + + assertNull(processor.events.lastOrNull(), "Events stream should be empty after 
closing") + assertFalse(processor.isOpen) + } + + @Test + fun concurrent_task_testClosedProcessorSendTaskEventFailsAndEventStreamIsEmpty() = runTest(timeout = TEST_TIMEOUT) { + val processor = createProcessor(contextId, taskId) + val task = createTask(taskId, contextId) + + // Send first task event + processor.sendTaskEvent(task) + + // Close processor and then attempt to send more events + processor.close() + + val workingUpdate = TaskStatusUpdateEvent( + taskId = taskId, + contextId = contextId, + status = TaskStatus(state = TaskState.Working), + final = false + ) + + assertFailsWith("Should not be possible to send events to closed session") { + processor.sendTaskEvent(workingUpdate) + } + + assertNull(processor.events.lastOrNull(), "Events stream should be empty after closing") + assertFalse(processor.isOpen) + } +} diff --git a/a2a/a2a-server/src/commonTest/kotlin/ai/koog/a2a/server/session/SessionManagerTest.kt b/a2a/a2a-server/src/commonTest/kotlin/ai/koog/a2a/server/session/SessionManagerTest.kt new file mode 100644 index 0000000000..326ea114ae --- /dev/null +++ b/a2a/a2a-server/src/commonTest/kotlin/ai/koog/a2a/server/session/SessionManagerTest.kt @@ -0,0 +1,282 @@ +package ai.koog.a2a.server.session + +import ai.koog.a2a.model.Message +import ai.koog.a2a.model.PushNotificationConfig +import ai.koog.a2a.model.Role +import ai.koog.a2a.model.Task +import ai.koog.a2a.model.TaskState +import ai.koog.a2a.model.TaskStatus +import ai.koog.a2a.model.TaskStatusUpdateEvent +import ai.koog.a2a.model.TextPart +import ai.koog.a2a.server.notifications.InMemoryPushNotificationConfigStorage +import ai.koog.a2a.server.notifications.PushNotificationSender +import ai.koog.a2a.server.tasks.InMemoryTaskStorage +import ai.koog.a2a.utils.KeyedMutex +import kotlinx.coroutines.CoroutineScope +import kotlinx.coroutines.currentCoroutineContext +import kotlinx.coroutines.delay +import kotlinx.coroutines.job +import kotlinx.coroutines.joinAll +import kotlinx.coroutines.test.runTest 
+import kotlinx.coroutines.yield +import kotlinx.datetime.Instant +import kotlin.test.BeforeTest +import kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.assertNull +import kotlin.time.Duration.Companion.seconds + +class SessionManagerTest { + private companion object Companion { + private val TEST_TIMEOUT = 5.seconds + } + + private lateinit var taskStorage: InMemoryTaskStorage + private lateinit var pushConfigStorage: InMemoryPushNotificationConfigStorage + private lateinit var pushSender: MockPushNotificationSender + + private val contextId = "context-1" + private val taskId = "task-1" + + private class MockPushNotificationSender : PushNotificationSender { + val sentNotifications = mutableListOf>() + + override suspend fun send(config: PushNotificationConfig, task: Task) { + sentNotifications.add(config to task) + } + } + + @BeforeTest + fun setUp() { + taskStorage = InMemoryTaskStorage() + pushConfigStorage = InMemoryPushNotificationConfigStorage() + pushSender = MockPushNotificationSender() + } + + private fun createMessage( + messageId: String, + contextId: String, + content: String + ) = Message( + messageId = messageId, + role = Role.User, + parts = listOf(TextPart(content)), + contextId = contextId + ) + + private fun createTask( + id: String, + contextId: String, + state: TaskState = TaskState.Submitted + ) = Task( + id = id, + contextId = contextId, + status = TaskStatus( + state = state, + timestamp = Instant.parse("2023-01-01T10:00:00Z") + ) + ) + + private fun createProcessor( + contextId: String, + taskId: String, + ) = SessionEventProcessor( + contextId = contextId, + taskId = taskId, + taskStorage = taskStorage, + ) + + private fun createManager( + coroutineScope: CoroutineScope, + ) = SessionManager( + coroutineScope = coroutineScope, + cancelKey = { "cancel:$it" }, + tasksMutex = KeyedMutex(), + taskStorage = taskStorage, + pushConfigStorage = pushConfigStorage, + pushSender = pushSender, + ) + + @Test + fun 
testSessionManagerCreation() = runTest(timeout = TEST_TIMEOUT) { + val sessionManager = createManager(this) + + assertEquals(0, sessionManager.activeSessions()) + assertNull(sessionManager.getSession("any-task-id")) + } + + @Test + fun testAddMessageSession() = runTest(timeout = TEST_TIMEOUT) { + val sessionManager = createManager(this) + val eventProcessor = createProcessor(contextId, taskId) + + val message = createMessage("msg-1", contextId, "Hello") + + val session = LazySession( + coroutineScope = this, + eventProcessor = eventProcessor + ) { + eventProcessor.sendMessage(message) + } + + // Start session and wait for completion + sessionManager.addSession(session) + session.startAndJoin() + + // Let the session manager process it + yield() + + // Session should be automatically cleaned up after completion + assertEquals(0, sessionManager.activeSessions()) + } + + @Test + fun testAddTaskSession() = runTest(timeout = TEST_TIMEOUT) { + val sessionManager = createManager(this) + val eventProcessor = createProcessor(contextId, taskId) + + val session = LazySession( + coroutineScope = this, + eventProcessor = eventProcessor + ) { + val task = createTask(taskId, contextId) + eventProcessor.sendTaskEvent(task) + + // Simulate work + delay(400) + + val statusUpdate = TaskStatusUpdateEvent( + taskId = taskId, + contextId = contextId, + status = TaskStatus(state = TaskState.Completed), + final = true + ) + eventProcessor.sendTaskEvent(statusUpdate) + } + + sessionManager.addSession(session) + session.start() + + // Let the session manager process it + yield() + + assertEquals(session, sessionManager.getSession(taskId)) + + session.startAndJoin() + + // Let the session manager process it + yield() + + // Session should be automatically cleaned up after completion + assertEquals(0, sessionManager.activeSessions()) + } + + @Test + fun testMultipleSessions() = runTest(timeout = TEST_TIMEOUT) { + val sessionManager = createManager(this) + + // Create two task sessions + val 
eventProcessor1 = createProcessor("context-1", "task-1") + val eventProcessor2 = createProcessor("context-2", "task-2") + + val session1 = LazySession( + coroutineScope = this, + eventProcessor = eventProcessor1 + ) { + val task = createTask("task-1", "context-1") + eventProcessor1.sendTaskEvent(task) + + // Simulate work + delay(150) + + val statusUpdate = TaskStatusUpdateEvent( + taskId = "task-1", + contextId = "context-1", + status = TaskStatus(state = TaskState.Completed), + final = true + ) + eventProcessor1.sendTaskEvent(statusUpdate) + } + + val session2 = LazySession( + coroutineScope = this, + eventProcessor = eventProcessor2 + ) { + val task = createTask("task-2", "context-2") + eventProcessor2.sendTaskEvent(task) + + // Simulate work + delay(150) + + val statusUpdate = TaskStatusUpdateEvent( + taskId = "task-2", + contextId = "context-2", + status = TaskStatus(state = TaskState.Completed), + final = true + ) + eventProcessor2.sendTaskEvent(statusUpdate) + } + + sessionManager.addSession(session1) + sessionManager.addSession(session2) + session1.start() + session2.start() + + // Let the session manager process it + yield() + + assertEquals(session1, sessionManager.getSession("task-1")) + assertEquals(session2, sessionManager.getSession("task-2")) + + session1.join() + session2.join() + + // Let the session manager process it + yield() + + // All sessions should be automatically cleaned up + assertEquals(0, sessionManager.activeSessions()) + } + + @Test + fun testSessionWithPushNotifications() = runTest(timeout = TEST_TIMEOUT) { + val sessionManager = createManager(this) + val eventProcessor = createProcessor(contextId, taskId) + + // Configure push notification + val config = PushNotificationConfig( + id = "config-1", + url = "https://example.com/webhook" + ) + pushConfigStorage.save("task-1", config) + + val task = createTask("task-1", contextId) + + val session = LazySession( + coroutineScope = this, + eventProcessor = eventProcessor + ) { + 
eventProcessor.sendTaskEvent(task) + + val statusUpdate = TaskStatusUpdateEvent( + taskId = "task-1", + contextId = contextId, + status = TaskStatus(state = TaskState.Completed), + final = true + ) + eventProcessor.sendTaskEvent(statusUpdate) + } + + sessionManager.addSession(session) + session.startAndJoin() + + // Let all coroutines finish, so that the push notifications background job is definitely completed. + currentCoroutineContext().job.children.toList().joinAll() + + // Verify push notification was sent + assertEquals(1, pushSender.sentNotifications.size) + val (sentConfig, sentTask) = pushSender.sentNotifications[0] + assertEquals(config, sentConfig) + assertEquals(TaskState.Completed, sentTask.status.state) + } +} diff --git a/a2a/a2a-server/src/commonTest/kotlin/ai/koog/a2a/server/tasks/InMemoryTaskStorageTest.kt b/a2a/a2a-server/src/commonTest/kotlin/ai/koog/a2a/server/tasks/InMemoryTaskStorageTest.kt new file mode 100644 index 0000000000..258ca53eda --- /dev/null +++ b/a2a/a2a-server/src/commonTest/kotlin/ai/koog/a2a/server/tasks/InMemoryTaskStorageTest.kt @@ -0,0 +1,351 @@ +package ai.koog.a2a.server.tasks + +import ai.koog.a2a.model.Artifact +import ai.koog.a2a.model.Message +import ai.koog.a2a.model.Role +import ai.koog.a2a.model.Task +import ai.koog.a2a.model.TaskArtifactUpdateEvent +import ai.koog.a2a.model.TaskState +import ai.koog.a2a.model.TaskStatus +import ai.koog.a2a.model.TaskStatusUpdateEvent +import ai.koog.a2a.model.TextPart +import ai.koog.a2a.server.exceptions.TaskOperationException +import kotlinx.coroutines.test.runTest +import kotlinx.datetime.Instant +import kotlinx.serialization.json.JsonObject +import kotlinx.serialization.json.JsonPrimitive +import kotlinx.serialization.json.buildJsonObject +import kotlin.test.BeforeTest +import kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.assertFailsWith +import kotlin.test.assertNotNull +import kotlin.test.assertNull +import kotlin.test.assertTrue + +class 
InMemoryTaskStorageTest { + private lateinit var storage: InMemoryTaskStorage + + @BeforeTest + fun setUp() { + storage = InMemoryTaskStorage() + } + + @Test + fun testGetNonExistentTask() = runTest { + val result = storage.get("non-existent-id") + assertNull(result) + } + + @Test + fun testStoreAndRetrieveTask() = runTest { + val task = createTask(id = "task-1", contextId = "context-1") + + storage.update(task) + + val retrieved = storage.get("task-1") + + assertNotNull(retrieved) + assertEquals(task.id, retrieved.id) + assertEquals(task.contextId, retrieved.contextId) + } + + @Test + fun testDeleteTask() = runTest { + val task = createTask(id = "task-1", contextId = "context-1") + storage.update(task) + + storage.delete("task-1") + + val retrieved = storage.get("task-1") + assertNull(retrieved) + } + + @Test + fun testGetByContext() = runTest { + val task1 = createTask(id = "task-1", contextId = "context-1") + val task2 = createTask(id = "task-2", contextId = "context-1") + val task3 = createTask(id = "task-3", contextId = "context-2") + + storage.update(task1) + storage.update(task2) + storage.update(task3) + + val result = storage.getByContext("context-1") + + assertEquals(2, result.size) + assertTrue(result.all { it.contextId == "context-1" }) + assertTrue(result.any { it.id == "task-1" }) + assertTrue(result.any { it.id == "task-2" }) + } + + @Test + fun testTaskStatusUpdateEvent() = runTest { + // Create and store initial task with metadata + val initialMetadata = buildJsonObject { + put("initialKey", JsonPrimitive("initialValue")) + put("sharedKey", JsonPrimitive("originalValue")) + } + val task = createTask(id = "task-1", contextId = "context-1", metadata = initialMetadata) + storage.update(task) + + // Create a status update event with additional metadata + val updateMetadata = buildJsonObject { + put("newKey", JsonPrimitive("newValue")) + put("sharedKey", JsonPrimitive("updatedValue")) + } + val newMessage = createUserMessage("status-msg", "context-1", 
"Task completed successfully") + val newStatus = TaskStatus( + state = TaskState.Completed, + message = newMessage, + timestamp = Instant.parse("2023-01-01T12:00:00Z") + ) + val statusUpdateEvent = TaskStatusUpdateEvent( + taskId = "task-1", + contextId = "context-1", + status = newStatus, + metadata = updateMetadata, + final = true + ) + + // Update task status + storage.update(statusUpdateEvent) + + // Verify the status was updated and metadata was merged + val retrieved = storage.get("task-1") + assertEquals(newStatus, retrieved?.status) + + // Verify metadata merging: original + new with updates overriding + val expectedMetadata = buildJsonObject { + put("initialKey", JsonPrimitive("initialValue")) // preserved from original + put("sharedKey", JsonPrimitive("updatedValue")) // updated from event + put("newKey", JsonPrimitive("newValue")) // added from event + } + assertEquals(expectedMetadata, retrieved?.metadata) + } + + @Test + fun testTaskStatusUpdateEventForNonExistentTask() = runTest { + val statusUpdateEvent = TaskStatusUpdateEvent( + taskId = "non-existent", + contextId = "context-1", + status = TaskStatus(state = TaskState.Completed), + final = true + ) + + assertFailsWith { + storage.update(statusUpdateEvent) + } + } + + @Test + fun testTaskArtifactUpdateEventNewArtifact() = runTest { + // Create and store initial task + val task = createTask(id = "task-1", contextId = "context-1") + storage.update(task) + + // Create artifact update event with new artifact + val artifact = Artifact( + artifactId = "artifact-1", + parts = listOf(TextPart("Initial content")) + ) + val artifactUpdateEvent = TaskArtifactUpdateEvent( + taskId = "task-1", + contextId = "context-1", + artifact = artifact, + append = false + ) + + // Update task with artifact + storage.update(artifactUpdateEvent) + + // Verify the artifact was added + val retrieved = storage.get("task-1", includeArtifacts = true) + assertEquals(listOf(artifact), retrieved?.artifacts) + } + + @Test + fun 
testTaskArtifactUpdateEventReplaceExisting() = runTest { + // Create and store initial task with artifact + val initialArtifact = Artifact( + artifactId = "artifact-1", + parts = listOf(TextPart("Initial content")) + ) + val task = createTask(id = "task-1", contextId = "context-1", artifacts = listOf(initialArtifact)) + storage.update(task) + + // Create artifact update event to replace existing artifact + val newArtifact = Artifact( + artifactId = "artifact-1", + parts = listOf(TextPart("Replaced content")) + ) + val artifactUpdateEvent = TaskArtifactUpdateEvent( + taskId = "task-1", + contextId = "context-1", + artifact = newArtifact, + append = false + ) + + // Update task with new artifact + storage.update(artifactUpdateEvent) + + // Verify the artifact was replaced + val retrieved = storage.get("task-1", includeArtifacts = true) + assertEquals(listOf(newArtifact), retrieved?.artifacts) + } + + @Test + fun testTaskArtifactUpdateEventAppendToExisting() = runTest { + // Create and store initial task with artifact + val initialArtifact = Artifact( + artifactId = "artifact-1", + parts = listOf(TextPart("Initial content")) + ) + val task = createTask(id = "task-1", contextId = "context-1", artifacts = listOf(initialArtifact)) + storage.update(task) + + // Create artifact update event to append to existing artifact + val appendArtifact = Artifact( + artifactId = "artifact-1", + parts = listOf(TextPart(" Additional content")) + ) + val artifactUpdateEvent = TaskArtifactUpdateEvent( + taskId = "task-1", + contextId = "context-1", + artifact = appendArtifact, + append = true + ) + + val resultingArtifact = initialArtifact.copy(parts = initialArtifact.parts + appendArtifact.parts) + + // Update task with appended artifact + storage.update(artifactUpdateEvent) + + // Verify the content was appended + val retrieved = storage.get("task-1", includeArtifacts = true) + assertEquals(listOf(resultingArtifact), retrieved?.artifacts) + } + + @Test + fun 
testTaskArtifactUpdateEventForNonExistentTask() = runTest { + val artifact = Artifact( + artifactId = "artifact-1", + parts = listOf(TextPart("Content")) + ) + val artifactUpdateEvent = TaskArtifactUpdateEvent( + taskId = "non-existent", + contextId = "context-1", + artifact = artifact, + append = false + ) + + assertFailsWith { + storage.update(artifactUpdateEvent) + } + } + + @Test + fun testTaskStatusHistoryPreservation() = runTest { + // Initial task with no message in status + val initialTask = createTask( + id = "task-1", + contextId = "context-1" + ) + storage.update(initialTask) + + // Verify initial task - should have the initial status message but no history yet + val initialTaskFromStorage = storage.get("task-1", historyLength = null) + assertNotNull(initialTaskFromStorage) + assertNull(initialTaskFromStorage.history) + + // Update status with a new message - history is still empty + val firstUpdateMessage = createUserMessage("update-msg-1", "context-1", "Making progress") + val firstUpdateStatus = TaskStatus( + state = TaskState.Working, + message = firstUpdateMessage, + timestamp = Instant.parse("2023-01-01T11:00:00Z") + ) + val firstUpdate = TaskStatusUpdateEvent( + taskId = "task-1", + contextId = "context-1", + status = firstUpdateStatus, + final = false + ) + + storage.update(firstUpdate) + + val afterFirstUpdate = storage.get("task-1", historyLength = null) // Get full history + assertNotNull(afterFirstUpdate) + assertEquals(firstUpdateStatus, afterFirstUpdate.status) + assertNull(afterFirstUpdate.history) + + // Second status update - this should add the previous status message to history + val secondUpdateMessage = createUserMessage("update-msg-2", "context-1", "Almost done") + val secondUpdateStatus = TaskStatus( + state = TaskState.Working, + message = secondUpdateMessage, + timestamp = Instant.parse("2023-01-01T12:00:00Z") + ) + val secondUpdate = TaskStatusUpdateEvent( + taskId = "task-1", + contextId = "context-1", + status = 
secondUpdateStatus, + final = false + ) + storage.update(secondUpdate) + + // Verify second update + val afterSecondUpdate = storage.get("task-1", historyLength = null) + assertNotNull(afterSecondUpdate) + assertEquals(secondUpdateStatus, afterSecondUpdate.status) + assertEquals(listOf(firstUpdateMessage), afterSecondUpdate.history) + + // Final status update + val completionMessage = createUserMessage("completion-msg", "context-1", "Task completed successfully") + val completionStatus = TaskStatus( + state = TaskState.Completed, + message = completionMessage, + timestamp = Instant.parse("2023-01-01T13:00:00Z") + ) + val completionUpdate = TaskStatusUpdateEvent( + taskId = "task-1", + contextId = "context-1", + status = completionStatus, + final = true + ) + storage.update(completionUpdate) + + // Verify final update + val finalTask = storage.get("task-1", historyLength = null) + assertNotNull(finalTask) + assertEquals(completionStatus, finalTask.status) + assertEquals(listOf(firstUpdateMessage, secondUpdateMessage), finalTask.history) + } + + private fun createUserMessage( + messageId: String, + contextId: String, + content: String + ) = Message( + messageId = messageId, + role = Role.User, + parts = listOf(TextPart(content)), + contextId = contextId + ) + + private fun createTask( + id: String, + contextId: String, + status: TaskStatus = TaskStatus(state = TaskState.Submitted), + history: List? = null, + artifacts: List? = null, + metadata: JsonObject? 
= null + ) = Task( + id = id, + contextId = contextId, + status = status, + history = history, + artifacts = artifacts, + metadata = metadata + ) +} diff --git a/a2a/a2a-server/src/jvmTest/kotlin/ai/koog/a2a/server/.gitkeep b/a2a/a2a-server/src/jvmTest/kotlin/ai/koog/a2a/server/.gitkeep new file mode 100644 index 0000000000..e69de29bb2 diff --git a/a2a/a2a-server/src/jvmTest/kotlin/ai/koog/a2a/server/TestAgentExecutor.kt b/a2a/a2a-server/src/jvmTest/kotlin/ai/koog/a2a/server/TestAgentExecutor.kt new file mode 100644 index 0000000000..398d2afe53 --- /dev/null +++ b/a2a/a2a-server/src/jvmTest/kotlin/ai/koog/a2a/server/TestAgentExecutor.kt @@ -0,0 +1,219 @@ +@file:OptIn(ExperimentalUuidApi::class) + +package ai.koog.a2a.server + +import ai.koog.a2a.model.Message +import ai.koog.a2a.model.MessageSendParams +import ai.koog.a2a.model.Role +import ai.koog.a2a.model.Task +import ai.koog.a2a.model.TaskIdParams +import ai.koog.a2a.model.TaskState +import ai.koog.a2a.model.TaskStatus +import ai.koog.a2a.model.TaskStatusUpdateEvent +import ai.koog.a2a.model.TextPart +import ai.koog.a2a.server.agent.AgentExecutor +import ai.koog.a2a.server.session.RequestContext +import ai.koog.a2a.server.session.SessionEventProcessor +import kotlinx.coroutines.Deferred +import kotlinx.coroutines.cancelAndJoin +import kotlinx.coroutines.delay +import kotlinx.datetime.Clock +import kotlin.uuid.ExperimentalUuidApi +import kotlin.uuid.Uuid + +private suspend fun sayHello( + context: RequestContext, + eventProcessor: SessionEventProcessor, +) { + eventProcessor.sendMessage( + Message( + messageId = Uuid.random().toString(), + role = Role.Agent, + parts = listOf(TextPart("Hello World")), + contextId = context.contextId, + taskId = context.taskId + ) + ) +} + +private suspend fun doTask( + context: RequestContext, + eventProcessor: SessionEventProcessor, +) { + val task = Task( + id = context.taskId, + contextId = context.contextId, + status = TaskStatus( + state = TaskState.Submitted, + timestamp = 
Clock.System.now() + ), + history = listOf(context.params.message) + ) + + // Send initial task event + eventProcessor.sendTaskEvent(task) + + // Send task working status update + eventProcessor.sendTaskEvent( + TaskStatusUpdateEvent( + contextId = context.contextId, + taskId = context.taskId, + status = TaskStatus( + state = TaskState.Working, + message = Message( + messageId = Uuid.random().toString(), + role = Role.Agent, + parts = listOf(TextPart("Working on task")), + contextId = context.contextId, + taskId = context.taskId + ), + timestamp = Clock.System.now() + ), + final = false + ) + ) + + // Send task completion status update + eventProcessor.sendTaskEvent( + TaskStatusUpdateEvent( + contextId = context.contextId, + taskId = context.taskId, + status = TaskStatus( + state = TaskState.Completed, + message = Message( + messageId = Uuid.random().toString(), + role = Role.Agent, + parts = listOf(TextPart("Task completed")), + contextId = context.contextId, + taskId = context.taskId + ), + timestamp = Clock.System.now() + ), + final = true + ) + ) +} + +private suspend fun doCancelableTask( + context: RequestContext, + eventProcessor: SessionEventProcessor, +) { + val task = Task( + id = context.taskId, + contextId = context.contextId, + status = TaskStatus( + state = TaskState.Submitted, + timestamp = Clock.System.now() + ), + history = listOf(context.params.message) + ) + + eventProcessor.sendTaskEvent(task) +} + +private suspend fun doLongRunningTask( + context: RequestContext, + eventProcessor: SessionEventProcessor, +) { + val task = Task( + id = context.taskId, + contextId = context.contextId, + status = TaskStatus( + state = TaskState.Submitted, + timestamp = Clock.System.now() + ), + history = listOf(context.params.message) + ) + + eventProcessor.sendTaskEvent(task) + + // Simulate long-running task + repeat(4) { + delay(200) + + eventProcessor.sendTaskEvent( + TaskStatusUpdateEvent( + taskId = context.taskId, + contextId = context.contextId, + status = 
TaskStatus( + state = TaskState.Working, + message = Message( + messageId = Uuid.random().toString(), + role = Role.Agent, + parts = listOf(TextPart("Still working $it")), + contextId = context.contextId, + taskId = context.taskId + ), + timestamp = Clock.System.now() + ), + final = false + ) + ) + } +} + +class TestAgentExecutor : AgentExecutor { + override suspend fun execute(context: RequestContext, eventProcessor: SessionEventProcessor) { + val userMessage = context.params.message + val userInput = userMessage.parts.filterIsInstance() + .joinToString(" ") { it.text } + .lowercase() + + // Test scenarios to test various aspects of A2A + when (userInput) { + "hello world" -> { + sayHello(context, eventProcessor) + } + + "do task" -> { + doTask(context, eventProcessor) + } + + "do cancelable task" -> { + doCancelableTask(context, eventProcessor) + } + + "do long-running task" -> { + doLongRunningTask(context, eventProcessor) + } + + else -> { + eventProcessor.sendMessage( + Message( + messageId = Uuid.random().toString(), + role = Role.Agent, + parts = listOf(TextPart("Sorry, I don't understand you")), + contextId = context.contextId + ) + ) + } + } + } + + override suspend fun cancel( + context: RequestContext, + eventProcessor: SessionEventProcessor, + agentJob: Deferred? 
+ ) { + agentJob?.cancelAndJoin() + + eventProcessor.sendTaskEvent( + TaskStatusUpdateEvent( + contextId = context.contextId, + taskId = context.taskId, + status = TaskStatus( + state = TaskState.Canceled, + message = Message( + messageId = Uuid.random().toString(), + role = Role.Agent, + parts = listOf(TextPart("Task canceled")), + contextId = context.contextId, + taskId = context.taskId + ), + timestamp = Clock.System.now() + ), + final = true + ) + ) + } +} diff --git a/a2a/a2a-server/src/jvmTest/kotlin/ai/koog/a2a/server/jsonrpc/A2AServerJsonRpcIntegrationTest.kt b/a2a/a2a-server/src/jvmTest/kotlin/ai/koog/a2a/server/jsonrpc/A2AServerJsonRpcIntegrationTest.kt new file mode 100644 index 0000000000..fb67bc6abf --- /dev/null +++ b/a2a/a2a-server/src/jvmTest/kotlin/ai/koog/a2a/server/jsonrpc/A2AServerJsonRpcIntegrationTest.kt @@ -0,0 +1,250 @@ +package ai.koog.a2a.server.jsonrpc + +import ai.koog.a2a.model.Message +import ai.koog.a2a.model.MessageSendConfiguration +import ai.koog.a2a.model.MessageSendParams +import ai.koog.a2a.model.Role +import ai.koog.a2a.model.Task +import ai.koog.a2a.model.TaskIdParams +import ai.koog.a2a.model.TaskState +import ai.koog.a2a.model.TaskStatusUpdateEvent +import ai.koog.a2a.model.TextPart +import ai.koog.a2a.transport.Request +import io.kotest.inspectors.shouldForAll +import io.kotest.inspectors.shouldForAtLeastOne +import io.kotest.matchers.nulls.shouldNotBeNull +import io.kotest.matchers.should +import io.kotest.matchers.shouldBe +import io.kotest.matchers.string.shouldStartWith +import io.kotest.matchers.types.shouldBeInstanceOf +import kotlinx.coroutines.Dispatchers +import kotlinx.coroutines.delay +import kotlinx.coroutines.flow.toList +import kotlinx.coroutines.joinAll +import kotlinx.coroutines.launch +import kotlinx.coroutines.test.runTest +import kotlinx.coroutines.withContext +import org.junit.jupiter.api.AfterAll +import org.junit.jupiter.api.BeforeAll +import org.junit.jupiter.api.TestInstance +import 
org.junit.jupiter.api.parallel.Execution +import org.junit.jupiter.api.parallel.ExecutionMode +import kotlin.test.BeforeTest +import kotlin.test.Test +import kotlin.time.Duration.Companion.seconds +import kotlin.uuid.ExperimentalUuidApi +import kotlin.uuid.Uuid + +/** + * Integration test class for testing the JSON-RPC HTTP communication in the A2A server context. + * This class ensures the proper functioning and correctness of the A2A protocol over HTTP + * using the JSON-RPC standard. + */ +@OptIn(ExperimentalUuidApi::class) +@TestInstance(TestInstance.Lifecycle.PER_CLASS) +@Execution(ExecutionMode.SAME_THREAD, reason = "Working with the same instance of test server.") +class A2AServerJsonRpcIntegrationTest : BaseA2AServerJsonRpcTest() { + override val testTimeout = 10.seconds + + @BeforeAll + override fun setup() { + super.setup() + } + + @BeforeTest + override fun initClient() { + super.initClient() + } + + @AfterAll + override fun tearDown() { + super.tearDown() + } + + @Test + override fun `test get agent card`() = + super.`test get agent card`() + + @Test + override fun `test get authenticated extended agent card`() = + super.`test get authenticated extended agent card`() + + @Test + override fun `test send message`() = + super.`test send message`() + + @Test + override fun `test send message streaming`() = + super.`test send message streaming`() + + @Test + override fun `test get task`() = + super.`test get task`() + + @Test + override fun `test cancel task`() = + super.`test cancel task`() + + @Test + override fun `test resubscribe task`() = + super.`test resubscribe task`() + + @Test + override fun `test push notification configs`() = + super.`test push notification configs`() + + /** + * Extended test that wouldn't work with Python A2A SDK server, because their implementation has some problems. + * It doesn't send events emitted in the `cancel` method in AgentExecutor to the subscribers of message/stream or tasks/resubscribe. 
+ * But our server implementation should handle it properly. + */ + @Test + fun `test cancel task cancellation events received`() = runTest(timeout = testTimeout) { + // Need real time for this test + withContext(Dispatchers.Default) { + val createTaskRequest = Request( + data = MessageSendParams( + message = Message( + messageId = Uuid.Companion.random().toString(), + role = Role.User, + parts = listOf( + TextPart("do long-running task"), + ), + contextId = "test-context", + ), + ), + ) + + val taskId = (client.sendMessage(createTaskRequest).data as Task).id + + joinAll( + launch { + val resubscribeTaskRequest = Request( + data = TaskIdParams( + id = taskId, + ) + ) + + val events = client + .resubscribeTask(resubscribeTaskRequest) + .toList() + .map { it.data } + + // All the same task and context + events.shouldForAll { + it.shouldBeInstanceOf { + it.taskId shouldBe taskId + it.contextId shouldBe "test-context" + } + } + + // Has events from `execute` - task is working + events.shouldForAtLeastOne { + it.shouldBeInstanceOf { + it.status.state shouldBe TaskState.Working + it.status.message shouldNotBeNull { + role shouldBe Role.Agent + + parts.shouldForAll { + it.shouldBeInstanceOf { + it.text shouldStartWith "Still working" + } + } + } + } + } + + // Has events from `cancel` - task is canceled + events.shouldForAtLeastOne { + it.shouldBeInstanceOf { + it.status.state shouldBe TaskState.Canceled + it.status.message shouldNotBeNull { + role shouldBe Role.Agent + parts shouldBe listOf(TextPart("Task canceled")) + } + } + } + }, + launch { + // Let the task run for a while + delay(400) + + val cancelTaskRequest = Request( + data = TaskIdParams( + id = taskId, + ) + ) + + val response = client.cancelTask(cancelTaskRequest) + response.data should { + it.id shouldBe taskId + it.contextId shouldBe "test-context" + it.status should { + it.state shouldBe TaskState.Canceled + it.message shouldNotBeNull { + role shouldBe Role.Agent + parts shouldBe listOf(TextPart("Task 
canceled")) + } + } + } + } + ) + } + } + + /** + * Another test that doesn't work with Python A2A SDK server because of its implementation problems. + * It's taken from TCK. Follow-up messages to the running task should be supported. + * In case the task is still running, request should wait for a chance to be processed when the task is done. + */ + @Test + fun `test task send follow-up message`() = runTest(timeout = testTimeout) { + fun createRequest( + taskId: String?, + blocking: Boolean, + ) = Request( + data = MessageSendParams( + message = Message( + messageId = Uuid.Companion.random().toString(), + role = Role.User, + parts = listOf( + TextPart("do long-running task"), + ), + taskId = taskId, + contextId = "test-context" + ), + configuration = MessageSendConfiguration( + blocking = blocking + ) + ) + ) + + // Create a long-running task and return without waiting + val initialRequest = createRequest(taskId = null, blocking = false) + val initialResponse = client.sendMessage(initialRequest) + + val taskId = initialResponse.data.shouldBeInstanceOf().taskId + + // Immediately send a follow-up message to the same task and wait for the response + val followupRequest = createRequest(taskId = taskId, blocking = true) + val followupResponse = client.sendMessage(followupRequest) + + followupResponse.data.shouldBeInstanceOf { + it.taskId shouldBe taskId + it.contextId shouldBe "test-context" + + it.status should { + it.state shouldBe TaskState.Working + it.message shouldNotBeNull { + role shouldBe Role.Agent + + parts.shouldForAll { + it.shouldBeInstanceOf { + it.text shouldStartWith "Still working" + } + } + } + } + } + } +} diff --git a/a2a/a2a-server/src/jvmTest/kotlin/ai/koog/a2a/server/jsonrpc/BaseA2AServerJsonRpcTest.kt b/a2a/a2a-server/src/jvmTest/kotlin/ai/koog/a2a/server/jsonrpc/BaseA2AServerJsonRpcTest.kt new file mode 100644 index 0000000000..d023683889 --- /dev/null +++ 
b/a2a/a2a-server/src/jvmTest/kotlin/ai/koog/a2a/server/jsonrpc/BaseA2AServerJsonRpcTest.kt @@ -0,0 +1,184 @@ +package ai.koog.a2a.server.jsonrpc + +import ai.koog.a2a.client.A2AClient +import ai.koog.a2a.client.UrlAgentCardResolver +import ai.koog.a2a.consts.A2AConsts +import ai.koog.a2a.model.AgentCapabilities +import ai.koog.a2a.model.AgentCard +import ai.koog.a2a.model.AgentSkill +import ai.koog.a2a.model.TransportProtocol +import ai.koog.a2a.server.A2AServer +import ai.koog.a2a.server.TestAgentExecutor +import ai.koog.a2a.server.notifications.InMemoryPushNotificationConfigStorage +import ai.koog.a2a.test.BaseA2AProtocolTest +import ai.koog.a2a.transport.client.jsonrpc.http.HttpJSONRPCClientTransport +import ai.koog.a2a.transport.server.jsonrpc.http.HttpJSONRPCServerTransport +import io.ktor.client.HttpClient +import io.ktor.client.engine.cio.CIO +import io.ktor.client.plugins.HttpTimeout +import io.ktor.client.plugins.logging.LogLevel +import io.ktor.client.plugins.logging.Logging +import io.ktor.server.netty.Netty +import kotlinx.coroutines.runBlocking +import java.net.ServerSocket + +abstract class BaseA2AServerJsonRpcTest : BaseA2AProtocolTest() { + protected var testPort: Int? 
= null + protected val testPath = "/a2a" + protected lateinit var serverUrl: String + + protected lateinit var serverTransport: HttpJSONRPCServerTransport + protected lateinit var clientTransport: HttpJSONRPCClientTransport + protected lateinit var httpClient: HttpClient + + override lateinit var client: A2AClient + + open fun setup(): Unit = runBlocking { + // Discover and take any free port + testPort = ServerSocket(0).use { it.localPort } + serverUrl = "http://localhost:$testPort$testPath" + + // Create agent cards + val agentCard = createAgentCard() + val agentCardExtended = createExtendedAgentCard() + + // Create test agent executor + val testAgentExecutor = TestAgentExecutor() + + // Create A2A server + val a2aServer = A2AServer( + agentExecutor = testAgentExecutor, + agentCard = agentCard, + agentCardExtended = agentCardExtended, + pushConfigStorage = InMemoryPushNotificationConfigStorage() + ) + + // Create server transport + serverTransport = HttpJSONRPCServerTransport(a2aServer) + + // Start server + serverTransport.start( + engineFactory = Netty, + port = testPort!!, + path = testPath, + wait = false, + agentCard = agentCard, + agentCardPath = A2AConsts.AGENT_CARD_WELL_KNOWN_PATH, + ) + + // Create client transport + httpClient = HttpClient(CIO) { + install(Logging) { + level = LogLevel.ALL + } + + install(HttpTimeout) { + requestTimeoutMillis = testTimeout.inWholeMilliseconds + } + } + + clientTransport = HttpJSONRPCClientTransport(serverUrl, httpClient) + + client = A2AClient( + transport = clientTransport, + agentCardResolver = UrlAgentCardResolver( + baseUrl = serverUrl, + path = A2AConsts.AGENT_CARD_WELL_KNOWN_PATH, + baseHttpClient = httpClient, + ) + ) + } + + open fun initClient(): Unit = runBlocking { + client.connect() + } + + open fun tearDown() = runBlocking { + clientTransport.close() + serverTransport.stop() + } + + private fun createAgentCard(): AgentCard = AgentCard( + protocolVersion = "0.3.0", + name = "Hello World Agent", + description 
= "Just a hello world agent", + url = "http://localhost:9999/", + preferredTransport = TransportProtocol.JSONRPC, + additionalInterfaces = null, + iconUrl = null, + provider = null, + version = "1.0.0", + documentationUrl = null, + capabilities = AgentCapabilities( + streaming = true, + pushNotifications = true, + stateTransitionHistory = null, + extensions = null + ), + securitySchemes = null, + security = null, + defaultInputModes = listOf("text"), + defaultOutputModes = listOf("text"), + skills = listOf( + AgentSkill( + id = "hello_world", + name = "Returns hello world", + description = "just returns hello world", + tags = listOf("hello world"), + examples = listOf("hi", "hello world"), + inputModes = null, + outputModes = null, + security = null + ) + ), + supportsAuthenticatedExtendedCard = true, + signatures = null + ) + + private fun createExtendedAgentCard(): AgentCard = AgentCard( + protocolVersion = "0.3.0", + name = "Hello World Agent - Extended Edition", + description = "The full-featured hello world agent for authenticated users.", + url = "http://localhost:9999/", + preferredTransport = TransportProtocol.Companion.JSONRPC, + additionalInterfaces = null, + iconUrl = null, + provider = null, + version = "1.0.1", + documentationUrl = null, + capabilities = AgentCapabilities( + streaming = true, + pushNotifications = true, + stateTransitionHistory = null, + extensions = null + ), + securitySchemes = null, + security = null, + defaultInputModes = listOf("text"), + defaultOutputModes = listOf("text"), + skills = listOf( + AgentSkill( + id = "hello_world", + name = "Returns hello world", + description = "just returns hello world", + tags = listOf("hello world"), + examples = listOf("hi", "hello world"), + inputModes = null, + outputModes = null, + security = null + ), + AgentSkill( + id = "super_hello_world", + name = "Returns a SUPER Hello World", + description = "A more enthusiastic greeting, only for authenticated users.", + tags = listOf("hello world", 
"super", "extended"), + examples = listOf("super hi", "give me a super hello"), + inputModes = null, + outputModes = null, + security = null + ) + ), + supportsAuthenticatedExtendedCard = true, + signatures = null + ) +} diff --git a/a2a/a2a-server/src/jvmTest/kotlin/ai/koog/a2a/server/jsonrpc/StressA2AServerJsonRpcIntegrationTest.kt b/a2a/a2a-server/src/jvmTest/kotlin/ai/koog/a2a/server/jsonrpc/StressA2AServerJsonRpcIntegrationTest.kt new file mode 100644 index 0000000000..dc9b1ea2c1 --- /dev/null +++ b/a2a/a2a-server/src/jvmTest/kotlin/ai/koog/a2a/server/jsonrpc/StressA2AServerJsonRpcIntegrationTest.kt @@ -0,0 +1,47 @@ +package ai.koog.a2a.server.jsonrpc + +import org.junit.jupiter.api.AfterAll +import org.junit.jupiter.api.BeforeAll +import org.junit.jupiter.api.Disabled +import org.junit.jupiter.api.RepeatedTest +import org.junit.jupiter.api.TestInstance +import org.junit.jupiter.api.parallel.Execution +import org.junit.jupiter.api.parallel.ExecutionMode +import kotlin.test.BeforeTest +import kotlin.time.Duration.Companion.seconds + +/** + * Stress-testing event-stream related requests to check that events are processed correctly under a high load. + * Also more samples help with finding some flaky behavior, e.g. race conditions. 
+ */ +@TestInstance(TestInstance.Lifecycle.PER_CLASS) +@Execution(ExecutionMode.SAME_THREAD, reason = "Working with the same instance of test server.") +@Disabled("TODO add stress tests in heavy tests only") // TODO add in heavy tests +class StressA2AServerJsonRpcIntegrationTest : BaseA2AServerJsonRpcTest() { + override val testTimeout = 10.seconds + + @BeforeAll + override fun setup() { + super.setup() + } + + @BeforeTest + override fun initClient() { + super.initClient() + } + + @AfterAll + override fun tearDown() { + super.tearDown() + } + + @RepeatedTest(300, name = "{currentRepetition}/{totalRepetitions}") + fun `stress test cancel task`() = super.`test cancel task`() + + @RepeatedTest(300, name = "{currentRepetition}/{totalRepetitions}") + fun `stress test send message`() = super.`test send message`() + + // Long test, lower repetitions + @RepeatedTest(10, name = "{currentRepetition}/{totalRepetitions}") + fun `stress test resubscribe task`() = super.`test resubscribe task`() +} diff --git a/a2a/a2a-server/src/jvmTest/resources/logback.xml b/a2a/a2a-server/src/jvmTest/resources/logback.xml new file mode 100644 index 0000000000..24a99c370b --- /dev/null +++ b/a2a/a2a-server/src/jvmTest/resources/logback.xml @@ -0,0 +1,11 @@ + + + + %d{HH:mm:ss.SSS} [%thread] %-5level %logger{36} - %msg%n + + + + + + + diff --git a/a2a/a2a-test/Module.md b/a2a/a2a-test/Module.md new file mode 100644 index 0000000000..c1b5c0d71e --- /dev/null +++ b/a2a/a2a-test/Module.md @@ -0,0 +1,4 @@ +# Module a2a-test + +Testing utilities and helpers for A2A protocol implementations. Provides common test fixtures, utilities, and integration +test support for both client and server components. 
diff --git a/a2a/a2a-test/build.gradle.kts b/a2a/a2a-test/build.gradle.kts new file mode 100644 index 0000000000..69bfd6d21c --- /dev/null +++ b/a2a/a2a-test/build.gradle.kts @@ -0,0 +1,34 @@ +plugins { + id("ai.kotlin.multiplatform") + alias(libs.plugins.kotlin.serialization) +} + +kotlin { + sourceSets { + commonMain { + dependencies { + api(project(":a2a:a2a-core")) + api(project(":a2a:a2a-client")) + api(kotlin("test")) + api(kotlin("test-annotations-common")) + api(libs.kotlinx.coroutines.core) + api(libs.kotlinx.coroutines.test) + api(libs.kotlinx.serialization.json) + implementation(libs.kotest.assertions) + implementation(libs.kotlinx.coroutines.test) + } + } + + jvmMain { + dependencies { + api(kotlin("test-junit5")) + } + } + + jsMain { + dependencies { + api(kotlin("test-js")) + } + } + } +} diff --git a/a2a/a2a-test/src/commonMain/kotlin/ai/koog/a2a/test/Stub.kt b/a2a/a2a-test/src/commonMain/kotlin/ai/koog/a2a/test/Stub.kt new file mode 100644 index 0000000000..60504608b9 --- /dev/null +++ b/a2a/a2a-test/src/commonMain/kotlin/ai/koog/a2a/test/Stub.kt @@ -0,0 +1,7 @@ +package ai.koog.a2a.test + +/** + * This class is required for publishing iOS target when there's no commonMain set. 
+ */ +@Suppress("unused") +private class Stub diff --git a/a2a/a2a-test/src/jvmMain/kotlin/ai/koog/a2a/test/BaseA2AProtocolTest.kt b/a2a/a2a-test/src/jvmMain/kotlin/ai/koog/a2a/test/BaseA2AProtocolTest.kt new file mode 100644 index 0000000000..ec04b6769f --- /dev/null +++ b/a2a/a2a-test/src/jvmMain/kotlin/ai/koog/a2a/test/BaseA2AProtocolTest.kt @@ -0,0 +1,437 @@ +package ai.koog.a2a.test + +import ai.koog.a2a.client.A2AClient +import ai.koog.a2a.exceptions.A2AInternalErrorException +import ai.koog.a2a.model.AgentCapabilities +import ai.koog.a2a.model.AgentCard +import ai.koog.a2a.model.AgentSkill +import ai.koog.a2a.model.Message +import ai.koog.a2a.model.MessageSendConfiguration +import ai.koog.a2a.model.MessageSendParams +import ai.koog.a2a.model.PushNotificationAuthenticationInfo +import ai.koog.a2a.model.PushNotificationConfig +import ai.koog.a2a.model.Role +import ai.koog.a2a.model.Task +import ai.koog.a2a.model.TaskIdParams +import ai.koog.a2a.model.TaskPushNotificationConfig +import ai.koog.a2a.model.TaskPushNotificationConfigParams +import ai.koog.a2a.model.TaskQueryParams +import ai.koog.a2a.model.TaskState +import ai.koog.a2a.model.TaskStatusUpdateEvent +import ai.koog.a2a.model.TextPart +import ai.koog.a2a.model.TransportProtocol +import ai.koog.a2a.transport.Request +import io.kotest.assertions.throwables.shouldThrowExactly +import io.kotest.inspectors.shouldForAll +import io.kotest.matchers.collections.shouldHaveSize +import io.kotest.matchers.collections.shouldNotBeEmpty +import io.kotest.matchers.nulls.shouldNotBeNull +import io.kotest.matchers.should +import io.kotest.matchers.shouldBe +import io.kotest.matchers.string.shouldStartWith +import io.kotest.matchers.types.shouldBeInstanceOf +import kotlinx.coroutines.flow.toList +import kotlinx.coroutines.test.runTest +import kotlin.time.Duration +import kotlin.uuid.ExperimentalUuidApi +import kotlin.uuid.Uuid + +/** + * Abstract base class containing transport-agnostic A2A protocol compliance tests. 
+ * + * Concrete test classes should inherit from this class and provide the [client] property + * to run the same test suite against different A2A implementations. + * + * @property client The A2A client instance to test against. Should be connected and ready to use. + */ +@OptIn(ExperimentalUuidApi::class) +@Suppress("FunctionName") +abstract class BaseA2AProtocolTest { + protected abstract val testTimeout: Duration + + /** + * The A2A client instance to test. Must be connected and ready to use. + */ + protected abstract var client: A2AClient + + open fun `test get agent card`() = runTest(timeout = testTimeout) { + val agentCard = client.getAgentCard() + + // Assert on the full AgentCard structure + val expectedAgentCard = AgentCard( + protocolVersion = "0.3.0", + name = "Hello World Agent", + description = "Just a hello world agent", + url = "http://localhost:9999/", + preferredTransport = TransportProtocol.JSONRPC, + additionalInterfaces = null, + iconUrl = null, + provider = null, + version = "1.0.0", + documentationUrl = null, + capabilities = AgentCapabilities( + streaming = true, + pushNotifications = true, + stateTransitionHistory = null, + extensions = null + ), + securitySchemes = null, + security = null, + defaultInputModes = listOf("text"), + defaultOutputModes = listOf("text"), + skills = listOf( + AgentSkill( + id = "hello_world", + name = "Returns hello world", + description = "just returns hello world", + tags = listOf("hello world"), + examples = listOf("hi", "hello world"), + inputModes = null, + outputModes = null, + security = null + ) + ), + supportsAuthenticatedExtendedCard = true, + signatures = null + ) + + agentCard shouldBe expectedAgentCard + } + + open fun `test get authenticated extended agent card`() = runTest(timeout = testTimeout) { + val request = Request(data = null) + + val response = client.getAuthenticatedExtendedAgentCard(request) + + // Assert on the extended agent card structure + val expectedExtendedAgentCard = AgentCard( + 
protocolVersion = "0.3.0", + name = "Hello World Agent - Extended Edition", + description = "The full-featured hello world agent for authenticated users.", + url = "http://localhost:9999/", + preferredTransport = TransportProtocol.JSONRPC, + additionalInterfaces = null, + iconUrl = null, + provider = null, + version = "1.0.1", + documentationUrl = null, + capabilities = AgentCapabilities( + streaming = true, + pushNotifications = true, + stateTransitionHistory = null, + extensions = null + ), + securitySchemes = null, + security = null, + defaultInputModes = listOf("text"), + defaultOutputModes = listOf("text"), + skills = listOf( + AgentSkill( + id = "hello_world", + name = "Returns hello world", + description = "just returns hello world", + tags = listOf("hello world"), + examples = listOf("hi", "hello world"), + inputModes = null, + outputModes = null, + security = null + ), + AgentSkill( + id = "super_hello_world", + name = "Returns a SUPER Hello World", + description = "A more enthusiastic greeting, only for authenticated users.", + tags = listOf("hello world", "super", "extended"), + examples = listOf("super hi", "give me a super hello"), + inputModes = null, + outputModes = null, + security = null + ) + ), + supportsAuthenticatedExtendedCard = true, + signatures = null + ) + + response.data shouldBe expectedExtendedAgentCard + } + + open fun `test send message`() = runTest(timeout = testTimeout) { + val request = Request( + data = MessageSendParams( + message = Message( + messageId = Uuid.random().toString(), + role = Role.User, + parts = listOf( + TextPart("hello world"), + ), + contextId = "test-context" + ), + ) + ) + + val response = client.sendMessage(request) + + response should { + it.id shouldBe request.id + + it.data.shouldBeInstanceOf { + it.role shouldBe Role.Agent + it.parts shouldBe listOf(TextPart("Hello World")) + it.contextId shouldBe "test-context" + } + } + } + + open fun `test send message streaming`() = runTest(timeout = testTimeout) { + 
val createTaskRequest = Request( + data = MessageSendParams( + message = Message( + messageId = Uuid.random().toString(), + role = Role.User, + parts = listOf( + TextPart("do task"), + ), + contextId = "test-context" + ), + ), + ) + + val events = client + .sendMessageStreaming(createTaskRequest) + .toList() + .map { it.data } + + events shouldHaveSize 3 + events[0].shouldBeInstanceOf { + it.contextId shouldBe "test-context" + it.status should { + it.state shouldBe TaskState.Submitted + } + + it.history shouldNotBeNull { + this shouldHaveSize 1 + + this[0] should { + it.role shouldBe Role.User + it.parts shouldBe listOf(TextPart("do task")) + } + } + } + + events[1].shouldBeInstanceOf { + it.contextId shouldBe "test-context" + + it.status should { + it.state shouldBe TaskState.Working + it.message shouldNotBeNull { + role shouldBe Role.Agent + parts shouldBe listOf(TextPart("Working on task")) + } + } + } + + events[2].shouldBeInstanceOf { + it.contextId shouldBe "test-context" + + it.status should { + it.state shouldBe TaskState.Completed + it.message shouldNotBeNull { + role shouldBe Role.Agent + parts shouldBe listOf(TextPart("Task completed")) + } + } + } + } + + open fun `test get task`() = runTest(timeout = testTimeout) { + val createTaskRequest = Request( + data = MessageSendParams( + message = Message( + messageId = Uuid.random().toString(), + role = Role.User, + parts = listOf( + TextPart("do task"), + ), + contextId = "test-context" + ), + ), + ) + + val taskId = (client.sendMessage(createTaskRequest).data as Task).id + + val getTaskRequest = Request( + data = TaskQueryParams( + id = taskId, + historyLength = 1 + ) + ) + + val response = client.getTask(getTaskRequest) + + response.data should { + it.id shouldBe taskId + it.contextId shouldBe "test-context" + it.status should { + it.state shouldBe TaskState.Completed + } + } + } + + open fun `test cancel task`() = runTest(timeout = testTimeout) { + val createTaskRequest = Request( + data = 
MessageSendParams( + message = Message( + messageId = Uuid.random().toString(), + role = Role.User, + parts = listOf( + TextPart("do cancelable task"), + ), + contextId = "test-context" + ), + ), + ) + + val taskId = (client.sendMessage(createTaskRequest).data as Task).id + + val cancelTaskRequest = Request( + data = TaskIdParams( + id = taskId, + ) + ) + + val response = client.cancelTask(cancelTaskRequest) + + response.data should { + it.id shouldBe taskId + it.contextId shouldBe "test-context" + it.status should { + it.state shouldBe TaskState.Canceled + it.message shouldNotBeNull { + role shouldBe Role.Agent + parts shouldBe listOf(TextPart("Task canceled")) + } + } + } + } + + open fun `test resubscribe task`() = runTest(timeout = testTimeout) { + val createTaskRequest = Request( + data = MessageSendParams( + message = Message( + messageId = Uuid.random().toString(), + role = Role.User, + parts = listOf( + TextPart("do long-running task"), + ), + contextId = "test-context" + ), + configuration = MessageSendConfiguration( + blocking = false + ) + ), + ) + + val taskId = (client.sendMessage(createTaskRequest).data as Task).id + + val resubscribeTaskRequest = Request( + data = TaskIdParams( + id = taskId, + ) + ) + + val events = client + .resubscribeTask(resubscribeTaskRequest) + .toList() + .map { it.data } + + events.shouldNotBeEmpty() + + events.shouldForAll { + it.shouldBeInstanceOf { + it.taskId shouldBe taskId + it.contextId shouldBe "test-context" + + it.status should { + it.state shouldBe TaskState.Working + it.message shouldNotBeNull { + role shouldBe Role.Agent + + parts.shouldForAll { + it.shouldBeInstanceOf { + it.text shouldStartWith "Still working" + } + } + } + } + } + } + } + + open fun `test push notification configs`() = runTest(timeout = testTimeout) { + val createTaskRequest = Request( + data = MessageSendParams( + message = Message( + messageId = Uuid.random().toString(), + role = Role.User, + parts = listOf( + TextPart("do long-running 
task"), + ), + contextId = "test-context" + ), + ), + ) + + val taskId = (client.sendMessage(createTaskRequest).data as Task).id + + val pushConfig = TaskPushNotificationConfig( + taskId = taskId, + pushNotificationConfig = PushNotificationConfig( + id = "push-id", + url = "https://localhost:3000", + token = "push-token", + authentication = PushNotificationAuthenticationInfo( + schemes = listOf("bearer"), + credentials = "very-secret-credential" + ) + ) + ) + + val request = Request( + data = pushConfig + ) + + val setPushConfigResponse = client.setTaskPushNotificationConfig(request) + setPushConfigResponse.data shouldBe pushConfig + + val getPushConfigRequest = Request( + data = TaskPushNotificationConfigParams( + id = taskId, + pushNotificationConfigId = pushConfig.pushNotificationConfig.id, + ) + ) + + val getPushConfigResponse = client.getTaskPushNotificationConfig(getPushConfigRequest) + getPushConfigResponse.data shouldBe pushConfig + + val listPushConfigRequest = Request( + data = TaskIdParams( + id = taskId, + ) + ) + + val listPushConfigResponse = client.listTaskPushNotificationConfig(listPushConfigRequest) + listPushConfigResponse.data shouldBe listOf(pushConfig) + + val deletePushConfigRequest = Request( + data = TaskPushNotificationConfigParams( + id = taskId, + pushNotificationConfigId = pushConfig.pushNotificationConfig.id, + ) + ) + + client.deleteTaskPushNotificationConfig(deletePushConfigRequest) + + shouldThrowExactly { + client.getTaskPushNotificationConfig(getPushConfigRequest) + } + } +} diff --git a/a2a/a2a-transport/a2a-transport-client-jsonrpc-http/Module.md b/a2a/a2a-transport/a2a-transport-client-jsonrpc-http/Module.md new file mode 100644 index 0000000000..adaaf5c980 --- /dev/null +++ b/a2a/a2a-transport/a2a-transport-client-jsonrpc-http/Module.md @@ -0,0 +1,3 @@ +# Module a2a-transport-client-jsonrpc-http + +HTTP-based JSON-RPC client transport implementation for A2A protocol. 
Implements the ClientTransport interface using Ktor HTTP client to communicate with JSON-RPC servers over HTTP. diff --git a/a2a/a2a-transport/a2a-transport-client-jsonrpc-http/build.gradle.kts b/a2a/a2a-transport/a2a-transport-client-jsonrpc-http/build.gradle.kts new file mode 100644 index 0000000000..0dfa2d09db --- /dev/null +++ b/a2a/a2a-transport/a2a-transport-client-jsonrpc-http/build.gradle.kts @@ -0,0 +1,52 @@ +import ai.koog.gradle.publish.maven.Publishing.publishToMaven + +group = rootProject.group +version = rootProject.version + +plugins { + id("ai.kotlin.multiplatform") + alias(libs.plugins.kotlin.serialization) +} + +kotlin { + sourceSets { + commonMain { + dependencies { + api(project(":a2a:a2a-transport:a2a-transport-core-jsonrpc")) + api(libs.kotlinx.serialization.json) + api(libs.kotlinx.coroutines.core) + api(libs.ktor.client.core) + api(libs.ktor.client.content.negotiation) + api(libs.ktor.serialization.kotlinx.json) + implementation(libs.oshai.kotlin.logging) + } + } + + commonTest { + dependencies { + implementation(kotlin("test")) + implementation(libs.kotlinx.coroutines.test) + implementation(libs.ktor.client.mock) + } + } + + jvmTest { + dependencies { + implementation(kotlin("test-junit5")) + implementation(libs.mokksy.a2a) + implementation(libs.ktor.client.cio) + runtimeOnly(libs.logback.classic) + } + } + + jsTest { + dependencies { + implementation(kotlin("test-js")) + } + } + } + + explicitApi() +} + +publishToMaven() diff --git a/a2a/a2a-transport/a2a-transport-client-jsonrpc-http/src/commonMain/kotlin/ai/koog/a2a/transport/client/jsonrpc/http/HttpJSONRPCClientTransport.kt b/a2a/a2a-transport/a2a-transport-client-jsonrpc-http/src/commonMain/kotlin/ai/koog/a2a/transport/client/jsonrpc/http/HttpJSONRPCClientTransport.kt new file mode 100644 index 0000000000..11ca3edc76 --- /dev/null +++ b/a2a/a2a-transport/a2a-transport-client-jsonrpc-http/src/commonMain/kotlin/ai/koog/a2a/transport/client/jsonrpc/http/HttpJSONRPCClientTransport.kt @@ 
-0,0 +1,99 @@ +package ai.koog.a2a.transport.client.jsonrpc.http + +import ai.koog.a2a.transport.ClientCallContext +import ai.koog.a2a.transport.jsonrpc.JSONRPCClientTransport +import ai.koog.a2a.transport.jsonrpc.model.JSONRPCJson +import ai.koog.a2a.transport.jsonrpc.model.JSONRPCRequest +import ai.koog.a2a.transport.jsonrpc.model.JSONRPCResponse +import io.ktor.client.HttpClient +import io.ktor.client.call.body +import io.ktor.client.plugins.contentnegotiation.ContentNegotiation +import io.ktor.client.plugins.defaultRequest +import io.ktor.client.plugins.sse.SSE +import io.ktor.client.plugins.sse.sse +import io.ktor.client.request.headers +import io.ktor.client.request.post +import io.ktor.client.request.setBody +import io.ktor.http.ContentType +import io.ktor.http.HttpMethod +import io.ktor.http.contentType +import io.ktor.serialization.kotlinx.json.json +import kotlinx.coroutines.flow.Flow +import kotlinx.coroutines.flow.flow +import kotlinx.coroutines.flow.map + +/** + * Implementation of a JSON-RPC client transport using HTTP as the underlying communication protocol. + * + * This transport sends JSON-RPC requests over HTTP and processes the responses. It also supports + * both standard requests and Server-Sent Events (SSE) for streaming responses. + * + * @param url The URL of the JSON-RPC server endpoint. + * @param baseHttpClient The base [HttpClient] instance, which will be configured internally. 
+ */ +public class HttpJSONRPCClientTransport( + url: String, + baseHttpClient: HttpClient = HttpClient() +) : JSONRPCClientTransport() { + private val httpClient: HttpClient = baseHttpClient.config { + defaultRequest { + url(url) + contentType(ContentType.Application.Json) + } + + install(ContentNegotiation) { + json(JSONRPCJson) + } + + install(SSE) + + expectSuccess = true + } + + override suspend fun request( + request: JSONRPCRequest, + ctx: ClientCallContext + ): JSONRPCResponse { + val response = httpClient.post { + headers { + ctx.additionalHeaders.forEach { (key, values) -> + appendAll(key, values) + } + } + + setBody(request) + } + + return response.body() + } + + override fun requestStreaming( + request: JSONRPCRequest, + ctx: ClientCallContext + ): Flow = flow { + httpClient.sse( + request = { + method = HttpMethod.Post + + headers { + ctx.additionalHeaders.forEach { (key, values) -> + appendAll(key, values) + } + } + + setBody(request) + } + ) { + incoming + .map { event -> + requireNotNull(event.data) { "SSE data must not be null" } + .let { data -> JSONRPCJson.decodeFromString(data) } + } + .collect(this@flow) + } + } + + override fun close() { + httpClient.close() + } +} diff --git a/a2a/a2a-transport/a2a-transport-client-jsonrpc-http/src/commonTest/kotlin/ai/koog/a2a/transport/client/jsonrpc/http/HttpJSONRPCClientTransportTest.kt b/a2a/a2a-transport/a2a-transport-client-jsonrpc-http/src/commonTest/kotlin/ai/koog/a2a/transport/client/jsonrpc/http/HttpJSONRPCClientTransportTest.kt new file mode 100644 index 0000000000..9a033d8c3a --- /dev/null +++ b/a2a/a2a-transport/a2a-transport-client-jsonrpc-http/src/commonTest/kotlin/ai/koog/a2a/transport/client/jsonrpc/http/HttpJSONRPCClientTransportTest.kt @@ -0,0 +1,457 @@ +package ai.koog.a2a.transport.client.jsonrpc.http + +import ai.koog.a2a.exceptions.A2AErrorCodes +import ai.koog.a2a.exceptions.A2AInvalidParamsException +import ai.koog.a2a.model.AgentCapabilities +import ai.koog.a2a.model.AgentCard 
+import ai.koog.a2a.model.AgentSkill +import ai.koog.a2a.model.CommunicationEvent +import ai.koog.a2a.model.Message +import ai.koog.a2a.model.MessageSendParams +import ai.koog.a2a.model.PushNotificationConfig +import ai.koog.a2a.model.Role +import ai.koog.a2a.model.Task +import ai.koog.a2a.model.TaskIdParams +import ai.koog.a2a.model.TaskPushNotificationConfig +import ai.koog.a2a.model.TaskPushNotificationConfigParams +import ai.koog.a2a.model.TaskQueryParams +import ai.koog.a2a.model.TaskState +import ai.koog.a2a.model.TaskStatus +import ai.koog.a2a.model.TextPart +import ai.koog.a2a.transport.ClientTransport +import ai.koog.a2a.transport.Request +import ai.koog.a2a.transport.RequestId +import ai.koog.a2a.transport.Response +import ai.koog.a2a.transport.jsonrpc.A2AMethod +import ai.koog.a2a.transport.jsonrpc.model.JSONRPCError +import ai.koog.a2a.transport.jsonrpc.model.JSONRPCErrorResponse +import ai.koog.a2a.transport.jsonrpc.model.JSONRPCJson +import ai.koog.a2a.transport.jsonrpc.model.JSONRPCRequest +import ai.koog.a2a.transport.jsonrpc.model.JSONRPCSuccessResponse +import ai.koog.a2a.transport.jsonrpc.model.JSONRPC_VERSION +import io.ktor.client.HttpClient +import io.ktor.client.engine.mock.MockEngine +import io.ktor.client.engine.mock.respond +import io.ktor.http.ContentType +import io.ktor.http.HttpHeaders +import io.ktor.http.HttpMethod +import io.ktor.http.content.TextContent +import io.ktor.http.headersOf +import kotlinx.coroutines.test.runTest +import kotlinx.serialization.json.decodeFromJsonElement +import kotlinx.serialization.json.encodeToJsonElement +import kotlin.test.Ignore +import kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.fail +import kotlin.uuid.ExperimentalUuidApi +import kotlin.uuid.Uuid + +@OptIn(ExperimentalUuidApi::class) +class HttpJSONRPCClientTransportTest { + + private val json = JSONRPCJson + + private suspend inline fun testAPIMethod( + method: A2AMethod, + request: Request, + expectedResponse: Response, + 
noinline invoke: suspend ClientTransport.(Request) -> Response, + ) { + val mockEngine = MockEngine { receivedRequest -> + assertEquals(HttpMethod.Post, receivedRequest.method) + assertEquals(ContentType.Application.Json, receivedRequest.body.contentType) + + val requestBodyText = (receivedRequest.body as TextContent).text + val jsonRpcRequest = json.decodeFromString(requestBodyText) + + assertEquals(method.value, jsonRpcRequest.method) + assertEquals(request.id, jsonRpcRequest.id) + assertEquals(request.data, json.decodeFromJsonElement(jsonRpcRequest.params)) + + val jsonRpcResponse = JSONRPCSuccessResponse( + id = expectedResponse.id, + result = json.encodeToJsonElement(expectedResponse.data), + jsonrpc = JSONRPC_VERSION, + ) + + respond( + content = json.encodeToString(jsonRpcResponse), + headers = headersOf(HttpHeaders.ContentType, ContentType.Application.Json.toString()), + ) + } + + val httpClient = HttpClient(mockEngine) + val transport = HttpJSONRPCClientTransport("https://api.example.com/a2a", httpClient) + + val actualResponse = transport.invoke(request) + + assertEquals(expectedResponse.id, actualResponse.id) + assertEquals(expectedResponse.data, actualResponse.data) + + transport.close() + } + + @Test + fun testGetAuthenticatedExtendedAgentCard() = runTest { + val id = RequestId.StringId("test-1") + + val request = Request( + id = id, + data = null, + ) + + val expectedResponse = Response( + id = id, + data = AgentCard( + name = "Test Agent", + description = "A test agent", + url = "https://api.example.com/a2a", + version = "1.0.0", + capabilities = AgentCapabilities(), + defaultInputModes = listOf("text/plain"), + defaultOutputModes = listOf("text/plain"), + skills = listOf( + AgentSkill( + id = "test-skill", + name = "Test Skill", + description = "A test skill", + tags = listOf("test") + ) + ) + ) + ) + + testAPIMethod( + method = A2AMethod.GetAuthenticatedExtendedAgentCard, + request = request, + expectedResponse = expectedResponse, + invoke = { 
getAuthenticatedExtendedAgentCard(it) } + ) + } + + @Test + fun testSendMessage() = runTest { + val id = RequestId.StringId("test-2") + + val testMessage = Message( + messageId = Uuid.random().toString(), + role = Role.User, + parts = listOf(TextPart("Hello, agent!")), + taskId = "task-123" + ) + + val messageSendParams = MessageSendParams( + message = testMessage + ) + + val request = Request( + id = id, + data = messageSendParams, + ) + + val expectedResponse: Response = Response( + id = id, + data = Message( + messageId = "msg-456", + role = Role.Agent, + parts = listOf(TextPart("Hello, user! How can I help you?")), + taskId = "task-123" + ) + ) + + testAPIMethod( + method = A2AMethod.SendMessage, + request = request, + expectedResponse = expectedResponse, + invoke = { sendMessage(it) } + ) + } + + @Ignore + @Test + fun testSendMessageStreaming() = runTest { + // FIXME Can't test it, MockEngine doesn't support SSE capability + } + + @Test + fun testGetTask() = runTest { + val id = RequestId.StringId("test-3") + + val taskQueryParams = TaskQueryParams( + id = "task-123", + historyLength = 10 + ) + + val request = Request( + id = id, + data = taskQueryParams, + ) + + val expectedResponse = Response( + id = id, + data = Task( + id = "task-123", + contextId = "context-456", + status = TaskStatus( + state = TaskState.Working, + message = Message( + messageId = Uuid.random().toString(), + role = Role.Agent, + parts = listOf(TextPart("Working on your request...")) + ) + ), + history = listOf( + Message( + messageId = Uuid.random().toString(), + role = Role.User, + parts = listOf(TextPart("Hello, agent!")), + taskId = "task-123" + ) + ) + ) + ) + + testAPIMethod( + method = A2AMethod.GetTask, + request = request, + expectedResponse = expectedResponse, + invoke = { getTask(it) } + ) + } + + @Test + fun testCancelTask() = runTest { + val id = RequestId.StringId("test-4") + + val taskIdParams = TaskIdParams(id = "task-123") + + val request = Request( + id = id, + data = 
taskIdParams, + ) + + val expectedResponse = Response( + id = id, + data = Task( + id = "task-123", + contextId = "context-456", + status = TaskStatus( + state = TaskState.Canceled, + message = Message( + messageId = Uuid.random().toString(), + role = Role.Agent, + parts = listOf(TextPart("Task has been canceled.")) + ) + ) + ) + ) + + testAPIMethod( + method = A2AMethod.CancelTask, + request = request, + expectedResponse = expectedResponse, + invoke = { cancelTask(it) } + ) + } + + @Ignore + @Test + fun testResubscribeTask() = runTest { + // FIXME Can't test it, MockEngine doesn't support SSE capability + } + + @Test + fun testSetTaskPushNotificationConfig() = runTest { + val id = RequestId.StringId("test-5") + + val pushNotificationConfig = TaskPushNotificationConfig( + taskId = "task-123", + pushNotificationConfig = PushNotificationConfig( + id = "notification-config-1", + url = "https://webhook.example.com/notifications", + token = "webhook-token-123" + ) + ) + + val request = Request( + id = id, + data = pushNotificationConfig, + ) + + val expectedResponse = Response( + id = id, + data = pushNotificationConfig + ) + + testAPIMethod( + method = A2AMethod.SetTaskPushNotificationConfig, + request = request, + expectedResponse = expectedResponse, + invoke = { setTaskPushNotificationConfig(it) } + ) + } + + @Test + fun testGetTaskPushNotificationConfig() = runTest { + val id = RequestId.StringId("test-6") + + val configParams = TaskPushNotificationConfigParams( + id = "task-123", + pushNotificationConfigId = "notification-config-1" + ) + + val request = Request( + id = id, + data = configParams, + ) + + val expectedResponse = Response( + id = id, + data = TaskPushNotificationConfig( + taskId = "task-123", + pushNotificationConfig = PushNotificationConfig( + id = "notification-config-1", + url = "https://webhook.example.com/notifications", + token = "webhook-token-123" + ) + ) + ) + + testAPIMethod( + method = A2AMethod.GetTaskPushNotificationConfig, + request = 
request, + expectedResponse = expectedResponse, + invoke = { getTaskPushNotificationConfig(it) } + ) + } + + @Test + fun testListTaskPushNotificationConfig() = runTest { + val id = RequestId.StringId("test-7") + + val taskIdParams = TaskIdParams(id = "task-123") + + val request = Request( + id = id, + data = taskIdParams, + ) + + val expectedResponse = Response( + id = id, + data = listOf( + TaskPushNotificationConfig( + taskId = "task-123", + pushNotificationConfig = PushNotificationConfig( + id = "notification-config-1", + url = "https://webhook.example.com/notifications", + token = "webhook-token-123" + ) + ), + TaskPushNotificationConfig( + taskId = "task-123", + pushNotificationConfig = PushNotificationConfig( + id = "notification-config-2", + url = "https://webhook2.example.com/notifications", + token = "webhook-token-456" + ) + ) + ) + ) + + testAPIMethod( + method = A2AMethod.ListTaskPushNotificationConfig, + request = request, + expectedResponse = expectedResponse, + invoke = { listTaskPushNotificationConfig(it) } + ) + } + + @Test + fun testDeleteTaskPushNotificationConfig() = runTest { + val id = RequestId.StringId("test-8") + + val configParams = TaskPushNotificationConfigParams( + id = "task-123", + pushNotificationConfigId = "notification-config-1" + ) + + val request = Request( + id = id, + data = configParams, + ) + + val expectedResponse = Response( + id = id, + data = null + ) + + testAPIMethod( + method = A2AMethod.DeleteTaskPushNotificationConfig, + request = request, + expectedResponse = expectedResponse, + invoke = { deleteTaskPushNotificationConfig(it) } + ) + } + + @Test + fun testSendMessageError() = runTest { + val id = RequestId.StringId("test-error-1") + + val testMessage = Message( + messageId = Uuid.random().toString(), + role = Role.User, + parts = listOf(TextPart("Hello, agent!")), + taskId = "invalid-task-id" + ) + + val messageSendParams = MessageSendParams( + message = testMessage + ) + + val request = Request( + id = id, + data = 
messageSendParams, + ) + + val mockEngine = MockEngine { receivedRequest -> + assertEquals(HttpMethod.Post, receivedRequest.method) + assertEquals(ContentType.Application.Json, receivedRequest.body.contentType) + + val requestBodyText = (receivedRequest.body as TextContent).text + val jsonRpcRequest = json.decodeFromString(requestBodyText) + + assertEquals(A2AMethod.SendMessage.value, jsonRpcRequest.method) + assertEquals(request.id, jsonRpcRequest.id) + assertEquals(request.data, json.decodeFromJsonElement(jsonRpcRequest.params)) + + val jsonRpcErrorResponse = JSONRPCErrorResponse( + id = id, + error = JSONRPCError( + code = A2AErrorCodes.INVALID_PARAMS, + message = "Invalid method parameters", + data = json.encodeToJsonElement("The message parameters are invalid") + ), + jsonrpc = JSONRPC_VERSION, + ) + + respond( + content = json.encodeToString(jsonRpcErrorResponse), + headers = headersOf(HttpHeaders.ContentType, ContentType.Application.Json.toString()), + ) + } + + val httpClient = HttpClient(mockEngine) + val transport = HttpJSONRPCClientTransport("https://api.example.com/a2a", httpClient) + + try { + transport.sendMessage(request) + fail("Expected A2AInvalidParamsException to be thrown") + } catch (e: A2AInvalidParamsException) { + assertEquals("Invalid method parameters", e.message) + assertEquals(-32602, e.errorCode) + } + + transport.close() + } +} diff --git a/a2a/a2a-transport/a2a-transport-core-jsonrpc/Module.md b/a2a/a2a-transport/a2a-transport-core-jsonrpc/Module.md new file mode 100644 index 0000000000..a9edb95769 --- /dev/null +++ b/a2a/a2a-transport/a2a-transport-core-jsonrpc/Module.md @@ -0,0 +1,3 @@ +# Module a2a-transport-core-jsonrpc + +Core JSON-RPC protocol implementation for A2A communication. Provides base classes and utilities for implementing JSON-RPC based transport layers, including request/response handling and error mapping. 
diff --git a/a2a/a2a-transport/a2a-transport-core-jsonrpc/build.gradle.kts b/a2a/a2a-transport/a2a-transport-core-jsonrpc/build.gradle.kts new file mode 100644 index 0000000000..dbc55803c5 --- /dev/null +++ b/a2a/a2a-transport/a2a-transport-core-jsonrpc/build.gradle.kts @@ -0,0 +1,47 @@ +import ai.koog.gradle.publish.maven.Publishing.publishToMaven + +group = rootProject.group +version = rootProject.version + +plugins { + id("ai.kotlin.multiplatform") + alias(libs.plugins.kotlin.serialization) +} + +kotlin { + sourceSets { + commonMain { + dependencies { + api(project(":a2a:a2a-core")) + + api(libs.kotlinx.serialization.json) + api(libs.kotlinx.coroutines.core) + + implementation(libs.oshai.kotlin.logging) + } + } + + commonTest { + dependencies { + implementation(kotlin("test")) + } + } + + jvmTest { + dependencies { + implementation(kotlin("test-junit5")) + runtimeOnly(libs.logback.classic) + } + } + + jsTest { + dependencies { + implementation(kotlin("test-js")) + } + } + } + + explicitApi() +} + +publishToMaven() diff --git a/a2a/a2a-transport/a2a-transport-core-jsonrpc/src/commonMain/kotlin/ai/koog/a2a/transport/jsonrpc/A2AMethod.kt b/a2a/a2a-transport/a2a-transport-core-jsonrpc/src/commonMain/kotlin/ai/koog/a2a/transport/jsonrpc/A2AMethod.kt new file mode 100644 index 0000000000..c7b02861b5 --- /dev/null +++ b/a2a/a2a-transport/a2a-transport-core-jsonrpc/src/commonMain/kotlin/ai/koog/a2a/transport/jsonrpc/A2AMethod.kt @@ -0,0 +1,20 @@ +package ai.koog.a2a.transport.jsonrpc + +/** + * A2A JSON-RPC methods. 
+ */ +public enum class A2AMethod( + public val value: String, + public val streaming: Boolean = false +) { + GetAuthenticatedExtendedAgentCard("agent/getAuthenticatedExtendedCard"), + SendMessage("message/send"), + SendMessageStreaming("message/stream", streaming = true), + GetTask("tasks/get"), + CancelTask("tasks/cancel"), + ResubscribeTask("tasks/resubscribe", streaming = true), + SetTaskPushNotificationConfig("tasks/pushNotificationConfig/set"), + GetTaskPushNotificationConfig("tasks/pushNotificationConfig/get"), + ListTaskPushNotificationConfig("tasks/pushNotificationConfig/list"), + DeleteTaskPushNotificationConfig("tasks/pushNotificationConfig/delete"), +} diff --git a/a2a/a2a-transport/a2a-transport-core-jsonrpc/src/commonMain/kotlin/ai/koog/a2a/transport/jsonrpc/JSONRPCClientTransport.kt b/a2a/a2a-transport/a2a-transport-core-jsonrpc/src/commonMain/kotlin/ai/koog/a2a/transport/jsonrpc/JSONRPCClientTransport.kt new file mode 100644 index 0000000000..13ffc238d0 --- /dev/null +++ b/a2a/a2a-transport/a2a-transport-core-jsonrpc/src/commonMain/kotlin/ai/koog/a2a/transport/jsonrpc/JSONRPCClientTransport.kt @@ -0,0 +1,181 @@ +package ai.koog.a2a.transport.jsonrpc + +import ai.koog.a2a.exceptions.A2AException +import ai.koog.a2a.exceptions.createA2AException +import ai.koog.a2a.model.AgentCard +import ai.koog.a2a.model.CommunicationEvent +import ai.koog.a2a.model.Event +import ai.koog.a2a.model.MessageSendParams +import ai.koog.a2a.model.Task +import ai.koog.a2a.model.TaskIdParams +import ai.koog.a2a.model.TaskPushNotificationConfig +import ai.koog.a2a.model.TaskPushNotificationConfigParams +import ai.koog.a2a.model.TaskQueryParams +import ai.koog.a2a.transport.ClientCallContext +import ai.koog.a2a.transport.ClientTransport +import ai.koog.a2a.transport.Request +import ai.koog.a2a.transport.RequestId +import ai.koog.a2a.transport.Response +import ai.koog.a2a.transport.jsonrpc.model.JSONRPCError +import ai.koog.a2a.transport.jsonrpc.model.JSONRPCErrorResponse 
+import ai.koog.a2a.transport.jsonrpc.model.JSONRPCJson +import ai.koog.a2a.transport.jsonrpc.model.JSONRPCRequest +import ai.koog.a2a.transport.jsonrpc.model.JSONRPCResponse +import ai.koog.a2a.transport.jsonrpc.model.JSONRPCSuccessResponse +import ai.koog.a2a.transport.jsonrpc.model.JSONRPC_VERSION +import kotlinx.coroutines.flow.Flow +import kotlinx.coroutines.flow.map +import kotlinx.coroutines.flow.onCompletion +import kotlinx.serialization.json.decodeFromJsonElement +import kotlinx.serialization.json.encodeToJsonElement + +/** + * Abstract transport implementation for JSON-RPC-based client communication. + * Handles sending JSON-RPC requests, processing responses, and mapping them to expected types. + */ +public abstract class JSONRPCClientTransport : ClientTransport { + /** + * Sends a JSON-RPC request and returns the corresponding response. + */ + protected abstract suspend fun request( + request: JSONRPCRequest, + ctx: ClientCallContext, + ): JSONRPCResponse + + /** + * Sends a JSON-RPC request and returns the corresponding response stream. + */ + protected abstract fun requestStreaming( + request: JSONRPCRequest, + ctx: ClientCallContext, + ): Flow + + /** + * Convert generic [Request] to [JSONRPCRequest]. + */ + protected inline fun Request.toJSONRPCRequest(method: A2AMethod): JSONRPCRequest { + return JSONRPCRequest( + id = id, + method = method.value, + params = JSONRPCJson.encodeToJsonElement(data), + jsonrpc = JSONRPC_VERSION, + ) + } + + /** + * Convert [JSONRPCResponse] to generic [Response]. + * + * @throws A2AException if server returned an error. 
 + */ + protected inline fun JSONRPCResponse.toResponse(): Response { + return when (this) { + is JSONRPCSuccessResponse -> Response( + id = id, + data = JSONRPCJson.decodeFromJsonElement(result), + ) + + is JSONRPCErrorResponse -> { + throw error.toA2AException(id) + } + } + } + + protected fun JSONRPCError.toA2AException(id: RequestId?): A2AException { + return createA2AException(message, code, id) + } + + /** + * Generic request processing. + */ + protected suspend inline fun request( + method: A2AMethod, + request: Request, + ctx: ClientCallContext + ): Response { + val jsonrpcRequest = request.toJSONRPCRequest(method) + val jsonrpcResponse = request(jsonrpcRequest, ctx) + + return jsonrpcResponse.toResponse() + } + + /** + * Generic streaming request processing. + */ + protected inline fun requestStreaming( + method: A2AMethod, + request: Request, + ctx: ClientCallContext + ): Flow> { + val jsonrpcRequest = request.toJSONRPCRequest(method) + val jsonrpcResponse = requestStreaming(jsonrpcRequest, ctx) + + return jsonrpcResponse + .map { it.toResponse() } + .onCompletion { thr -> + // Do not wrap A2A exceptions; propagate them directly + if (thr?.cause is A2AException) { + throw thr.cause!! 
+ } + } + } + + override suspend fun getAuthenticatedExtendedAgentCard( + request: Request, + ctx: ClientCallContext + ): Response = + request(A2AMethod.GetAuthenticatedExtendedAgentCard, request, ctx) + + override suspend fun sendMessage( + request: Request, + ctx: ClientCallContext + ): Response = + request(A2AMethod.SendMessage, request, ctx) + + override fun sendMessageStreaming( + request: Request, + ctx: ClientCallContext + ): Flow> = + requestStreaming(A2AMethod.SendMessageStreaming, request, ctx) + + override suspend fun getTask( + request: Request, + ctx: ClientCallContext + ): Response = + request(A2AMethod.GetTask, request, ctx) + + override suspend fun cancelTask( + request: Request, + ctx: ClientCallContext + ): Response = + request(A2AMethod.CancelTask, request, ctx) + + override fun resubscribeTask( + request: Request, + ctx: ClientCallContext + ): Flow> = + requestStreaming(A2AMethod.ResubscribeTask, request, ctx) + + override suspend fun setTaskPushNotificationConfig( + request: Request, + ctx: ClientCallContext + ): Response = + request(A2AMethod.SetTaskPushNotificationConfig, request, ctx) + + override suspend fun getTaskPushNotificationConfig( + request: Request, + ctx: ClientCallContext + ): Response = + request(A2AMethod.GetTaskPushNotificationConfig, request, ctx) + + override suspend fun listTaskPushNotificationConfig( + request: Request, + ctx: ClientCallContext + ): Response> = + request(A2AMethod.ListTaskPushNotificationConfig, request, ctx) + + override suspend fun deleteTaskPushNotificationConfig( + request: Request, + ctx: ClientCallContext + ): Response = + request(A2AMethod.DeleteTaskPushNotificationConfig, request, ctx) +} diff --git a/a2a/a2a-transport/a2a-transport-core-jsonrpc/src/commonMain/kotlin/ai/koog/a2a/transport/jsonrpc/JSONRPCServerTransport.kt b/a2a/a2a-transport/a2a-transport-core-jsonrpc/src/commonMain/kotlin/ai/koog/a2a/transport/jsonrpc/JSONRPCServerTransport.kt new file mode 100644 index 0000000000..ae015b6026 --- 
/dev/null +++ b/a2a/a2a-transport/a2a-transport-core-jsonrpc/src/commonMain/kotlin/ai/koog/a2a/transport/jsonrpc/JSONRPCServerTransport.kt @@ -0,0 +1,227 @@ +package ai.koog.a2a.transport.jsonrpc + +import ai.koog.a2a.annotations.InternalA2AApi +import ai.koog.a2a.exceptions.A2AException +import ai.koog.a2a.exceptions.A2AInternalErrorException +import ai.koog.a2a.exceptions.A2AInvalidParamsException +import ai.koog.a2a.exceptions.A2AInvalidRequestException +import ai.koog.a2a.exceptions.A2AMethodNotFoundException +import ai.koog.a2a.exceptions.A2AParseException +import ai.koog.a2a.transport.Request +import ai.koog.a2a.transport.RequestId +import ai.koog.a2a.transport.Response +import ai.koog.a2a.transport.ServerCallContext +import ai.koog.a2a.transport.ServerTransport +import ai.koog.a2a.transport.jsonrpc.model.JSONRPCError +import ai.koog.a2a.transport.jsonrpc.model.JSONRPCErrorResponse +import ai.koog.a2a.transport.jsonrpc.model.JSONRPCJson +import ai.koog.a2a.transport.jsonrpc.model.JSONRPCRequest +import ai.koog.a2a.transport.jsonrpc.model.JSONRPCResponse +import ai.koog.a2a.transport.jsonrpc.model.JSONRPCSuccessResponse +import ai.koog.a2a.transport.jsonrpc.model.JSONRPC_VERSION +import ai.koog.a2a.utils.runCatchingCancellable +import io.github.oshai.kotlinlogging.KotlinLogging +import kotlinx.coroutines.CancellationException +import kotlinx.coroutines.flow.Flow +import kotlinx.coroutines.flow.catch +import kotlinx.coroutines.flow.flow +import kotlinx.coroutines.flow.map +import kotlinx.serialization.SerializationException +import kotlinx.serialization.json.JsonNull +import kotlinx.serialization.json.JsonObject +import kotlinx.serialization.json.JsonPrimitive +import kotlinx.serialization.json.decodeFromJsonElement +import kotlinx.serialization.json.encodeToJsonElement +import kotlinx.serialization.json.jsonPrimitive + +/** + * Abstract transport implementation for JSON-RPC-based server communication. 
 + * Handles receiving JSON-RPC requests, processing them, and sending responses. + */ +public abstract class JSONRPCServerTransport : ServerTransport { + private companion object { + private val logger = KotlinLogging.logger {} + } + + /** + * Manually parse [raw] string to build a [JSONRPCRequest] while throwing exceptions that A2A TCK expects, according + * to the A2A specification. + */ + protected fun parseJSONRPCRequest(raw: String): Pair { + val jsonBody = try { + JSONRPCJson.decodeFromString(raw) + } catch (e: SerializationException) { + throw A2AParseException("Cannot parse request body to JSON:\n${e.message}") + } + + // According to A2A TCK, need to parse id early to reply with provided id in error messages + val id = jsonBody["id"]?.let { + try { + JSONRPCJson.decodeFromJsonElement(it) + } catch (e: SerializationException) { + throw A2AInvalidRequestException("Cannot parse request id to JSON-RPC id:\n${e.message}") + } + } + + val a2aMethod = (jsonBody["method"] as? JsonPrimitive) + ?.content + ?.let { + A2AMethod.entries.find { m -> m.value == it } + ?: throw A2AMethodNotFoundException("Method not found: $it", id) + } + ?: throw A2AInvalidRequestException("No method parameter", id) + + val params = jsonBody["params"] + ?.let { + try { + JSONRPCJson + .decodeFromJsonElement(it) + .also { + // According to A2A TCK, empty parameter names are not allowed + if (it.keys.any { it.isEmpty() }) { + throw A2AInvalidParamsException("Empty parameter names are not allowed", id) + } + } + } catch (e: SerializationException) { + throw A2AInvalidParamsException("Cannot parse request params to JSON:\n${e.message}", id) + } + } + + val jsonrpc = jsonBody["jsonrpc"] + ?.jsonPrimitive?.content + ?.takeIf { it == JSONRPC_VERSION } + ?: throw A2AInvalidRequestException("Unsupported JSON-RPC version", id) + + val jsonrpcBody = JSONRPCRequest( + id = id ?: throw A2AInvalidRequestException("No id parameter"), + method = a2aMethod.value, + params = params ?: JsonNull, + jsonrpc = 
jsonrpc, + ) + + return jsonrpcBody to a2aMethod + } + + /** + * Handles a JSON-RPC request and returns the corresponding response + * Handles exceptions, mapping all non [A2AException]s to [A2AInternalErrorException], and then converting them to [JSONRPCErrorResponse]. + */ + @OptIn(InternalA2AApi::class) + protected suspend fun onRequest( + request: JSONRPCRequest, + ctx: ServerCallContext, + ): JSONRPCResponse { + return runCatchingCancellable { + when (request.method) { + A2AMethod.GetAuthenticatedExtendedAgentCard.value -> + requestHandler.onGetAuthenticatedExtendedAgentCard(request.toRequest(), ctx) + .toJSONRPCSuccessResponse() + + A2AMethod.SendMessage.value -> + requestHandler.onSendMessage(request.toRequest(), ctx).toJSONRPCSuccessResponse() + + A2AMethod.GetTask.value -> + requestHandler.onGetTask(request.toRequest(), ctx).toJSONRPCSuccessResponse() + + A2AMethod.CancelTask.value -> + requestHandler.onCancelTask(request.toRequest(), ctx).toJSONRPCSuccessResponse() + + A2AMethod.SetTaskPushNotificationConfig.value -> + requestHandler.onSetTaskPushNotificationConfig(request.toRequest(), ctx).toJSONRPCSuccessResponse() + + A2AMethod.GetTaskPushNotificationConfig.value -> + requestHandler.onGetTaskPushNotificationConfig(request.toRequest(), ctx).toJSONRPCSuccessResponse() + + A2AMethod.ListTaskPushNotificationConfig.value -> + requestHandler.onListTaskPushNotificationConfig(request.toRequest(), ctx).toJSONRPCSuccessResponse() + + A2AMethod.DeleteTaskPushNotificationConfig.value -> + requestHandler.onDeleteTaskPushNotificationConfig(request.toRequest(), ctx) + .toJSONRPCSuccessResponse() + + else -> + throw A2AMethodNotFoundException("Non-streaming method not found: ${request.method}") + } + }.getOrElse { it.toJSONRPCErrorResponse(request.id) } + } + + /** + * Handles a JSON-RPC request and returns the corresponding response stream. 
+ * Handles exceptions, mapping all non [A2AException]s to [A2AInternalErrorException], and then converting them to [JSONRPCErrorResponse]. + * Terminates the flow after the first exception. + */ + protected fun onRequestStreaming( + request: JSONRPCRequest, + ctx: ServerCallContext, + ): Flow { + return when (request.method) { + A2AMethod.SendMessageStreaming.value -> + requestHandler.onSendMessageStreaming(request.toRequest(), ctx) + + A2AMethod.ResubscribeTask.value -> + requestHandler.onResubscribeTask(request.toRequest(), ctx) + + else -> + flow { throw A2AMethodNotFoundException("Streaming method not found: ${request.method}") } + }.map { it.toJSONRPCSuccessResponse() as JSONRPCResponse } + .catch { emit(it.toJSONRPCErrorResponse(request.id)) } + } + + /** + * Convert generic [JSONRPCRequest] to [Request]. + * + * @throws A2AInvalidParamsException if request params cannot be parsed to [T]. + */ + protected inline fun JSONRPCRequest.toRequest(): Request { + val data = try { + JSONRPCJson.decodeFromJsonElement(params) + } catch (e: SerializationException) { + throw A2AInvalidParamsException("Cannot parse request params:\n${e.message}") + } + + return Request( + id = id, + data = data + ) + } + + /** + * Convert generic [Response] to [JSONRPCSuccessResponse]. + */ + protected inline fun Response.toJSONRPCSuccessResponse(): JSONRPCSuccessResponse { + return JSONRPCSuccessResponse( + id = id, + result = JSONRPCJson.encodeToJsonElement(data), + jsonrpc = JSONRPC_VERSION, + ) + } + + /** + * Handles exceptions, mapping all non [A2AException]s to [A2AInternalErrorException], and then converting them to [JSONRPCErrorResponse]. + */ + protected fun Throwable.toJSONRPCErrorResponse(requestId: RequestId? 
= null): JSONRPCErrorResponse { + val a2aException: A2AException = when (this) { + is A2AException -> this + is CancellationException -> throw this + is Exception -> { + logger.warn(this) { "Non-A2A exception was detected when responding to request [requestId=$requestId]" } + A2AInternalErrorException("Internal error: ${this.message}") + } + + else -> throw this // Non-exception throwable shouldn't be handled, rethrowing it + } + + return JSONRPCErrorResponse( + id = requestId ?: a2aException.requestId, // if there's no requestId, use the one from the exception + error = a2aException.toJSONRPCError(), + jsonrpc = JSONRPC_VERSION, + ) + } + + protected fun A2AException.toJSONRPCError(): JSONRPCError { + return JSONRPCError( + code = errorCode, + message = message + ) + } +} diff --git a/a2a/a2a-transport/a2a-transport-core-jsonrpc/src/commonMain/kotlin/ai/koog/a2a/transport/jsonrpc/model/Messages.kt b/a2a/a2a-transport/a2a-transport-core-jsonrpc/src/commonMain/kotlin/ai/koog/a2a/transport/jsonrpc/model/Messages.kt new file mode 100644 index 0000000000..0a87cfbc8a --- /dev/null +++ b/a2a/a2a-transport/a2a-transport-core-jsonrpc/src/commonMain/kotlin/ai/koog/a2a/transport/jsonrpc/model/Messages.kt @@ -0,0 +1,57 @@ +@file:Suppress("MissingKDocForPublicAPI") + +package ai.koog.a2a.transport.jsonrpc.model + +import ai.koog.a2a.transport.RequestId +import kotlinx.serialization.Serializable +import kotlinx.serialization.json.JsonElement +import kotlinx.serialization.json.JsonNull + +/** + * Default JSON-RPC version. 
+ */ +public const val JSONRPC_VERSION: String = "2.0" + +@Serializable(with = JSONRPCMessageSerializer::class) +public sealed interface JSONRPCMessage { + public val jsonrpc: String +} + +@Serializable(with = JSONRPCResponseSerializer::class) +public sealed interface JSONRPCResponse : JSONRPCMessage + +@Serializable +public data class JSONRPCRequest( + public val id: RequestId, + val method: String, + val params: JsonElement = JsonNull, + override val jsonrpc: String, +) : JSONRPCMessage + +@Serializable +public data class JSONRPCNotification( + val method: String, + val params: JsonElement = JsonNull, + override val jsonrpc: String, +) : JSONRPCMessage + +@Serializable +public data class JSONRPCSuccessResponse( + public val id: RequestId, + public val result: JsonElement = JsonNull, + override val jsonrpc: String, +) : JSONRPCResponse + +@Serializable +public data class JSONRPCError( + val code: Int, + val message: String, + val data: JsonElement = JsonNull, +) + +@Serializable +public data class JSONRPCErrorResponse( + public val id: RequestId?, + public val error: JSONRPCError, + override val jsonrpc: String, +) : JSONRPCResponse diff --git a/a2a/a2a-transport/a2a-transport-core-jsonrpc/src/commonMain/kotlin/ai/koog/a2a/transport/jsonrpc/model/Serialization.kt b/a2a/a2a-transport/a2a-transport-core-jsonrpc/src/commonMain/kotlin/ai/koog/a2a/transport/jsonrpc/model/Serialization.kt new file mode 100644 index 0000000000..9763919e93 --- /dev/null +++ b/a2a/a2a-transport/a2a-transport-core-jsonrpc/src/commonMain/kotlin/ai/koog/a2a/transport/jsonrpc/model/Serialization.kt @@ -0,0 +1,42 @@ +@file:Suppress("MissingKDocForPublicAPI") + +package ai.koog.a2a.transport.jsonrpc.model + +import kotlinx.serialization.DeserializationStrategy +import kotlinx.serialization.json.Json +import kotlinx.serialization.json.JsonContentPolymorphicSerializer +import kotlinx.serialization.json.JsonElement +import kotlinx.serialization.json.jsonObject +import kotlin.collections.contains + 
+public val JSONRPCJson: Json = Json { + explicitNulls = false + encodeDefaults = false + ignoreUnknownKeys = true +} + +internal object JSONRPCMessageSerializer : JsonContentPolymorphicSerializer(JSONRPCMessage::class) { + override fun selectDeserializer(element: JsonElement): DeserializationStrategy { + val jsonObject = element.jsonObject + + return when { + "method" in jsonObject -> when { + "id" in jsonObject -> JSONRPCRequest.serializer() + else -> JSONRPCNotification.serializer() + } + + else -> JSONRPCResponseSerializer + } + } +} + +internal object JSONRPCResponseSerializer : JsonContentPolymorphicSerializer(JSONRPCResponse::class) { + override fun selectDeserializer(element: JsonElement): DeserializationStrategy { + val jsonObject = element.jsonObject + + return when { + "error" in jsonObject -> JSONRPCErrorResponse.serializer() + else -> JSONRPCSuccessResponse.serializer() + } + } +} diff --git a/a2a/a2a-transport/a2a-transport-core-jsonrpc/src/commonTest/kotlin/ai/koog/a2a/transport/jsonrpc/model/JsonRpcSerializationTest.kt b/a2a/a2a-transport/a2a-transport-core-jsonrpc/src/commonTest/kotlin/ai/koog/a2a/transport/jsonrpc/model/JsonRpcSerializationTest.kt new file mode 100644 index 0000000000..87b084e48d --- /dev/null +++ b/a2a/a2a-transport/a2a-transport-core-jsonrpc/src/commonTest/kotlin/ai/koog/a2a/transport/jsonrpc/model/JsonRpcSerializationTest.kt @@ -0,0 +1,128 @@ +package ai.koog.a2a.transport.jsonrpc.model + +import ai.koog.a2a.transport.RequestId +import kotlinx.serialization.json.JsonPrimitive +import kotlin.test.Test +import kotlin.test.assertEquals + +class JsonRpcSerializationTest { + @Test + fun testJSONRPCError() { + val error = JSONRPCError(code = -32700, message = "Parse error", data = JsonPrimitive("some data")) + //language=JSON + val errorJson = """{"code":-32700,"message":"Parse error","data":"some data"}""" + + val serialized = JSONRPCJson.encodeToString(error) + assertEquals(errorJson, serialized) + + val deserialized = 
JSONRPCJson.decodeFromString(errorJson) + assertEquals(error, deserialized) + } + + @Test + fun testJSONRPCRequest() { + val request: JSONRPCMessage = JSONRPCRequest( + id = RequestId.NumberId(42), + method = "add", + jsonrpc = JSONRPC_VERSION, + ) + + //language=JSON + val requestJson = """{"id":42,"method":"add","jsonrpc":"2.0"}""" + + val serialized = JSONRPCJson.encodeToString(request) + assertEquals(requestJson, serialized) + + val deserialized = JSONRPCJson.decodeFromString(requestJson) + assertEquals(request, deserialized) + } + + @Test + fun testJSONRPCNotification() { + val request: JSONRPCMessage = JSONRPCNotification( + method = "update", + params = JsonPrimitive("notification-params"), + jsonrpc = JSONRPC_VERSION, + ) + + //language=JSON + val notificationJson = """{"method":"update","params":"notification-params","jsonrpc":"2.0"}""" + + val serialized = JSONRPCJson.encodeToString(request) + assertEquals(notificationJson, serialized) + + val deserialized = JSONRPCJson.decodeFromString(notificationJson) + assertEquals(request, deserialized) + } + + @Test + fun testJSONRPCNotificationWithoutParams() { + val request: JSONRPCMessage = JSONRPCNotification( + method = "notify", + jsonrpc = JSONRPC_VERSION, + ) + + //language=JSON + val notificationWithoutParamsJson = """{"method":"notify","jsonrpc":"2.0"}""" + + val serialized = JSONRPCJson.encodeToString(request) + assertEquals(notificationWithoutParamsJson, serialized) + + val deserialized = JSONRPCJson.decodeFromString(notificationWithoutParamsJson) + assertEquals(request, deserialized) + } + + @Test + fun testJSONRPCSuccessResponse() { + val response: JSONRPCMessage = JSONRPCSuccessResponse( + id = RequestId.NumberId(99), + result = JsonPrimitive(100), + jsonrpc = JSONRPC_VERSION, + ) + + //language=JSON + val successResponseJson = """{"id":99,"result":100,"jsonrpc":"2.0"}""" + + val serialized = JSONRPCJson.encodeToString(response) + assertEquals(successResponseJson, serialized) + + val deserialized = 
JSONRPCJson.decodeFromString(successResponseJson) + assertEquals(response, deserialized) + } + + @Test + fun testJSONRPCErrorResponse() { + val response: JSONRPCMessage = JSONRPCErrorResponse( + id = RequestId.NumberId(123), + error = JSONRPCError(code = -32602, message = "Invalid params"), + jsonrpc = JSONRPC_VERSION, + ) + + //language=JSON + val errorResponseJson = """{"id":123,"error":{"code":-32602,"message":"Invalid params"},"jsonrpc":"2.0"}""" + + val serialized = JSONRPCJson.encodeToString(response) + assertEquals(errorResponseJson, serialized) + + val deserialized = JSONRPCJson.decodeFromString(errorResponseJson) + assertEquals(response, deserialized) + } + + @Test + fun testJSONRPCErrorResponseWithoutId() { + val response: JSONRPCMessage = JSONRPCErrorResponse( + id = null, + error = JSONRPCError(code = -32700, message = "Parse error"), + jsonrpc = JSONRPC_VERSION, + ) + + //language=JSON + val errorResponseWithoutIdJson = """{"error":{"code":-32700,"message":"Parse error"},"jsonrpc":"2.0"}""" + + val serialized = JSONRPCJson.encodeToString(response) + assertEquals(errorResponseWithoutIdJson, serialized) + + val deserialized = JSONRPCJson.decodeFromString(errorResponseWithoutIdJson) + assertEquals(response, deserialized) + } +} diff --git a/a2a/a2a-transport/a2a-transport-server-jsonrpc-http/Module.md b/a2a/a2a-transport/a2a-transport-server-jsonrpc-http/Module.md new file mode 100644 index 0000000000..39c8ed9260 --- /dev/null +++ b/a2a/a2a-transport/a2a-transport-server-jsonrpc-http/Module.md @@ -0,0 +1,3 @@ +# Module a2a-transport-server-jsonrpc-http + +HTTP-based JSON-RPC server transport implementation for A2A protocol. Implements the ServerTransport interface using Ktor HTTP server to expose JSON-RPC endpoints for agent communication. 
diff --git a/a2a/a2a-transport/a2a-transport-server-jsonrpc-http/build.gradle.kts b/a2a/a2a-transport/a2a-transport-server-jsonrpc-http/build.gradle.kts new file mode 100644 index 0000000000..4033523c28 --- /dev/null +++ b/a2a/a2a-transport/a2a-transport-server-jsonrpc-http/build.gradle.kts @@ -0,0 +1,49 @@ +import ai.koog.gradle.publish.maven.Publishing.publishToMaven + +group = rootProject.group +version = rootProject.version + +plugins { + id("ai.kotlin.multiplatform") + alias(libs.plugins.kotlin.serialization) +} + +kotlin { + sourceSets { + commonMain { + dependencies { + api(project(":a2a:a2a-transport:a2a-transport-core-jsonrpc")) + api(libs.kotlinx.coroutines.core) + api(libs.kotlinx.serialization.json) + api(libs.ktor.server.core) + implementation(libs.ktor.serialization.kotlinx.json) + implementation(libs.ktor.server.content.negotiation) + implementation(libs.ktor.server.sse) + implementation(libs.ktor.server.cors) + } + } + + commonTest { + dependencies { + implementation(kotlin("test")) + } + } + + jvmTest { + dependencies { + implementation(kotlin("test-junit5")) + implementation(libs.ktor.server.test.host) + } + } + + jsTest { + dependencies { + implementation(kotlin("test-js")) + } + } + } + + explicitApi() +} + +publishToMaven() diff --git a/a2a/a2a-transport/a2a-transport-server-jsonrpc-http/src/commonMain/kotlin/ai/koog/a2a/transport/server/jsonrpc/http/HttpJSONRPCServerTransport.kt b/a2a/a2a-transport/a2a-transport-server-jsonrpc-http/src/commonMain/kotlin/ai/koog/a2a/transport/server/jsonrpc/http/HttpJSONRPCServerTransport.kt new file mode 100644 index 0000000000..c32bf736c9 --- /dev/null +++ b/a2a/a2a-transport/a2a-transport-server-jsonrpc-http/src/commonMain/kotlin/ai/koog/a2a/transport/server/jsonrpc/http/HttpJSONRPCServerTransport.kt @@ -0,0 +1,273 @@ +package ai.koog.a2a.transport.server.jsonrpc.http + +import ai.koog.a2a.annotations.InternalA2AApi +import ai.koog.a2a.consts.A2AConsts +import ai.koog.a2a.model.AgentCard +import 
ai.koog.a2a.transport.RequestHandler +import ai.koog.a2a.transport.ServerCallContext +import ai.koog.a2a.transport.jsonrpc.JSONRPCServerTransport +import ai.koog.a2a.transport.jsonrpc.model.JSONRPCJson +import ai.koog.a2a.transport.jsonrpc.model.JSONRPCRequest +import ai.koog.a2a.utils.runCatchingCancellable +import io.ktor.http.ContentType +import io.ktor.http.HttpHeaders +import io.ktor.serialization.kotlinx.json.json +import io.ktor.server.application.ApplicationCall +import io.ktor.server.application.install +import io.ktor.server.engine.ApplicationEngine +import io.ktor.server.engine.ApplicationEngineFactory +import io.ktor.server.engine.EmbeddedServer +import io.ktor.server.engine.embeddedServer +import io.ktor.server.plugins.contentnegotiation.ContentNegotiation +import io.ktor.server.plugins.cors.routing.CORS +import io.ktor.server.request.receiveText +import io.ktor.server.response.header +import io.ktor.server.response.respond +import io.ktor.server.routing.Route +import io.ktor.server.routing.RoutingContext +import io.ktor.server.routing.get +import io.ktor.server.routing.post +import io.ktor.server.routing.route +import io.ktor.server.routing.routing +import io.ktor.server.sse.SSE +import io.ktor.server.sse.SSEServerContent +import io.ktor.server.sse.ServerSSESession +import io.ktor.sse.ServerSentEvent +import io.ktor.util.toMap +import kotlinx.coroutines.sync.Mutex +import kotlinx.coroutines.sync.withLock + +/** + * Implements A2A JSON-RPC server transport over HTTP using Ktor server + * It ensures compliance with the A2A specification for error handling and JSON-RPC request processing. + * This transport can be used either as a standalone server or integrated into an existing Ktor application. + * + * Example usage as a standalone server: + * ```kotlin + * val transport = HttpJSONRPCServerTransport( + * requestHandler = A2AServer(...) 
+ * ) + * + * transport.start(Netty, 8080, "/my-agent", agentCard = AgentCard(...), agentCardPath = "/my-agent-card.json") + * transport.stop() + * ``` + * + * Example usage as an integration into an existing Ktor server. + * Can also be used to integrate multiple A2A server transports on the same server, to serve multiple A2A agents: + * ```kotlin + * val agentOneTransport = HttpJSONRPCServerTransport( + * requestHandler = A2AServer(...) + * ) + * val agentTwoTransport = HttpJSONRPCServerTransport( + * requestHandler = A2AServer(...) + * ) + * + * embeddedServer(Netty, port = 8080) { + * install(SSE) + * + * // Other configurations... + * + * routing { + * // Other routes... + * + * route("/a2a") { + * agentOneTransport.transportRoutes(this, "/agent-1") + * agentTwoTransport.transportRoutes(this, "/agent-2") + * } + * } + * }.startSuspend(wait = true) + * ``` + * + * @property requestHandler The handler responsible for processing A2A requests received by the transport. + */ +@OptIn(InternalA2AApi::class) +public class HttpJSONRPCServerTransport( + override val requestHandler: RequestHandler, +) : JSONRPCServerTransport() { + + /** + * Current running server instance if this transport is used as a standalone server. + */ + private var server: EmbeddedServer<*, *>? = null + private var serverMutex = Mutex() + + /** + * Starts Ktor embedded server with Netty engine to handle A2A JSON-RPC requests, using the specified port and endpoint path. + * Can be used to start a standalone server for quick prototyping or when no integration into the existing server is required. + * The routing consists only of [transportRoutes]. + * + * Can also optionally serve [AgentCard] at the specified [agentCardPath]. + * + * If you need to integrate A2A request handling logic into existing Ktor application, + * use [transportRoutes] method to mount the transport routes into existing [Route] configuration block. + * + * @param engineFactory An application engine factory. 
+ * @param port A port on which the server will listen. + * @param path A JSON-RPC endpoint path to handle incoming requests. + * @param wait If true, the server will block until it is stopped. Defaults to false. + * @param agentCard An optional [AgentCard] that will be served at the specified [agentCardPath]. + * @param agentCardPath The path at which the [agentCard] will be served, if specified. + * Defaults to [A2AConsts.AGENT_CARD_WELL_KNOWN_PATH]. + * + * @throws IllegalStateException if the server is already running. + * + * @see [transportRoutes] + */ + public suspend fun start( + engineFactory: ApplicationEngineFactory, + port: Int, + path: String, + wait: Boolean = false, + agentCard: AgentCard? = null, + agentCardPath: String = A2AConsts.AGENT_CARD_WELL_KNOWN_PATH, + ): Unit = serverMutex.withLock { + check(server == null) { "Server is already configured and running. Stop it before starting a new one." } + + server = embeddedServer(engineFactory, port) { + install(SSE) + + routing { + install(ContentNegotiation) { + json(JSONRPCJson) + } + + install(CORS) { + anyHost() + allowNonSimpleContentTypes = true + } + + transportRoutes(this, path) + + if (agentCard != null) { + get(agentCardPath) { + call.respond(agentCard) + } + } + } + }.startSuspend(wait = wait) + } + + /** + * Stops the server gracefully within the specified time limits. + * + * @param gracePeriodMillis The time in milliseconds to allow ongoing requests to finish gracefully before shutting down. + * @param timeoutMillis The maximum time in milliseconds to wait for the server to stop. + * + * @throws IllegalStateException if the server is not configured or running. + */ + public suspend fun stop(gracePeriodMillis: Long = 1000, timeoutMillis: Long = 2000): Unit = serverMutex.withLock { + check(server != null) { "Server is not configured or running." } + + server?.stopSuspend(gracePeriodMillis, timeoutMillis) + server = null + } + + /** + * Routes for handling JSON-RPC HTTP requests. 
+ * Follows A2A specification in error handling. + * Allows mounting A2A requests handling into an existing Ktor server application. + * This can also be used to mount multiple A2A server transports on the same server, to serve multiple A2A agents. + * + * Example usage: + * ```kotlin + * embeddedServer(Netty, port = 8080) { + * install(SSE) + * + * // Other configurations... + * + * routing { + * // Other routes... + * + * route("/a2a") { + * agentOneTransport.transportRoutes(this, "/agent-1") + * agentTwoTransport.transportRoutes(this, "/agent-2") + * } + * } + * }.startSuspend(wait = true) + * ``` + * + * @param route The base route to which the transport routes should be mounted. + * @param path JSON-RPC endpoint path that will be mounted under the base [route]. + */ + public fun transportRoutes(route: Route, path: String): Route = route.route(path) { + plugin(SSE) + + install(ContentNegotiation) { + json(JSONRPCJson) + } + + // Handle incoming JSON-RPC requests, both regular and streaming + post { + runCatchingCancellable { + val (request, a2aMethod) = parseJSONRPCRequest(call.receiveText()) + val ctx = call.toServerCallContext() + + runCatchingCancellable { + if (a2aMethod.streaming) { + handleRequestStreaming(request, ctx) + } else { + handleRequest(request, ctx) + } + }.getOrElse { + call.respond(it.toJSONRPCErrorResponse(request.id)) + } + }.getOrElse { + call.respond(it.toJSONRPCErrorResponse()) + } + } + } + + /** + * Handling A2A requests to regular methods. + */ + private suspend fun RoutingContext.handleRequest(request: JSONRPCRequest, ctx: ServerCallContext) { + val response = runCatchingCancellable { + onRequest( + request = request, + ctx = ctx + ) + }.getOrElse { it.toJSONRPCErrorResponse() } + + call.respond(response) + } + + /** + * Handling A2A requests to streaming methods. 
+ */ + private suspend fun RoutingContext.handleRequestStreaming(request: JSONRPCRequest, ctx: ServerCallContext) { + val handle: suspend ServerSSESession.() -> Unit = { + runCatchingCancellable { + onRequestStreaming( + request = request, + ctx = ctx + ).collect { response -> + send( + ServerSentEvent(JSONRPCJson.encodeToString(response)) + ) + } + }.getOrElse { + send( + ServerSentEvent( + JSONRPCJson.encodeToString(it.toJSONRPCErrorResponse()) + ) + ) + } + } + + // Reply with SSE (implementation copied from SSE plugin code) + call.response.apply { + header(HttpHeaders.ContentType, ContentType.Text.EventStream.toString()) + header(HttpHeaders.CacheControl, "no-store") + header(HttpHeaders.Connection, "keep-alive") + header("X-Accel-Buffering", "no") + } + + call.respond(SSEServerContent(call, handle)) + } + + private fun ApplicationCall.toServerCallContext(): ServerCallContext { + return ServerCallContext( + headers = request.headers.toMap() + ) + } +} diff --git a/a2a/a2a-transport/a2a-transport-server-jsonrpc-http/src/jvmTest/kotlin/ai/koog/a2a/transport/server/jsonrpc/http/HttpJSONRPCServerTransportTest.kt b/a2a/a2a-transport/a2a-transport-server-jsonrpc-http/src/jvmTest/kotlin/ai/koog/a2a/transport/server/jsonrpc/http/HttpJSONRPCServerTransportTest.kt new file mode 100644 index 0000000000..b41739fbff --- /dev/null +++ b/a2a/a2a-transport/a2a-transport-server-jsonrpc-http/src/jvmTest/kotlin/ai/koog/a2a/transport/server/jsonrpc/http/HttpJSONRPCServerTransportTest.kt @@ -0,0 +1,642 @@ +package ai.koog.a2a.transport.server.jsonrpc.http + +import ai.koog.a2a.exceptions.A2AErrorCodes +import ai.koog.a2a.model.AgentCapabilities +import ai.koog.a2a.model.AgentCard +import ai.koog.a2a.model.AgentSkill +import ai.koog.a2a.model.CommunicationEvent +import ai.koog.a2a.model.Event +import ai.koog.a2a.model.Message +import ai.koog.a2a.model.MessageSendParams +import ai.koog.a2a.model.PushNotificationConfig +import ai.koog.a2a.model.Role +import ai.koog.a2a.model.Task 
+import ai.koog.a2a.model.TaskIdParams +import ai.koog.a2a.model.TaskPushNotificationConfig +import ai.koog.a2a.model.TaskPushNotificationConfigParams +import ai.koog.a2a.model.TaskQueryParams +import ai.koog.a2a.model.TaskState +import ai.koog.a2a.model.TaskStatus +import ai.koog.a2a.model.TextPart +import ai.koog.a2a.transport.Request +import ai.koog.a2a.transport.RequestHandler +import ai.koog.a2a.transport.RequestId +import ai.koog.a2a.transport.Response +import ai.koog.a2a.transport.ServerCallContext +import ai.koog.a2a.transport.jsonrpc.A2AMethod +import ai.koog.a2a.transport.jsonrpc.model.JSONRPCErrorResponse +import ai.koog.a2a.transport.jsonrpc.model.JSONRPCJson +import ai.koog.a2a.transport.jsonrpc.model.JSONRPCRequest +import ai.koog.a2a.transport.jsonrpc.model.JSONRPCSuccessResponse +import ai.koog.a2a.transport.jsonrpc.model.JSONRPC_VERSION +import io.ktor.client.plugins.sse.sse +import io.ktor.client.request.post +import io.ktor.client.request.setBody +import io.ktor.client.statement.bodyAsText +import io.ktor.http.ContentType +import io.ktor.http.HttpMethod +import io.ktor.http.HttpStatusCode +import io.ktor.http.contentType +import io.ktor.server.sse.SSE +import io.ktor.server.testing.testApplication +import kotlinx.coroutines.flow.Flow +import kotlinx.coroutines.flow.asFlow +import kotlinx.coroutines.flow.map +import kotlinx.coroutines.test.runTest +import kotlinx.serialization.json.JsonNull +import kotlinx.serialization.json.decodeFromJsonElement +import kotlinx.serialization.json.encodeToJsonElement +import kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.assertNull +import io.ktor.client.plugins.sse.SSE as SSEClient + +class HttpJSONRPCServerTransportTest { + private object MockRequestHandler : RequestHandler { + val agentCard = AgentCard( + name = "Test Agent", + description = "A test agent", + url = "https://api.example.com/a2a", + version = "1.0.0", + capabilities = AgentCapabilities(), + defaultInputModes = 
listOf("text/plain"), + defaultOutputModes = listOf("text/plain"), + security = listOf( + mapOf("oauth" to listOf("read")), + mapOf("api-key" to listOf("mtls")), + ), + skills = listOf( + AgentSkill( + id = "test-skill", + name = "Test Skill", + description = "A test skill", + tags = listOf("test") + ) + ) + ) + + val communicationEvent = Message( + messageId = "message-1", + role = Role.Agent, + parts = listOf(TextPart("Response message.")), + taskId = "task-1" + ) + + val updateEvents = listOf( + Message( + messageId = "message-stream-1", + role = Role.Agent, + parts = listOf(TextPart("Streaming response part 1")), + taskId = "task-1" + ), + Message( + messageId = "message-stream-2", + role = Role.Agent, + parts = listOf(TextPart("Streaming response part 2")), + taskId = "task-1" + ) + ) + + val taskGet = Task( + id = "task-1", + contextId = "test-context-1", + status = TaskStatus( + state = TaskState.Working + ) + ) + + val taskCancel = Task( + id = "task-1", + contextId = "test-context-1", + status = TaskStatus( + state = TaskState.Canceled + ) + ) + + val taskPushNotificationConfig = TaskPushNotificationConfig( + taskId = "task-1", + pushNotificationConfig = PushNotificationConfig( + id = "notification-config-1", + url = "https://webhook.example.com", + token = "webhook-token-123" + ) + ) + + val taskPushNotificationConfigList = listOf(taskPushNotificationConfig) + + override suspend fun onGetAuthenticatedExtendedAgentCard( + request: Request, + ctx: ServerCallContext + ): Response { + return Response( + id = request.id, + data = agentCard + ) + } + + override suspend fun onSendMessage( + request: Request, + ctx: ServerCallContext + ): Response { + return Response( + id = request.id, + data = communicationEvent + ) + } + + override fun onSendMessageStreaming( + request: Request, + ctx: ServerCallContext + ): Flow> { + return updateEvents + .asFlow() + .map { + Response( + id = request.id, + data = it + ) + } + } + + override suspend fun onGetTask( + request: 
Request, + ctx: ServerCallContext + ): Response { + return Response( + id = request.id, + data = taskGet + ) + } + + override suspend fun onCancelTask( + request: Request, + ctx: ServerCallContext + ): Response { + return Response( + id = request.id, + data = taskCancel + ) + } + + override fun onResubscribeTask( + request: Request, + ctx: ServerCallContext + ): Flow> { + return updateEvents + .asFlow() + .map { + Response( + id = request.id, + data = it + ) + } + } + + override suspend fun onSetTaskPushNotificationConfig( + request: Request, + ctx: ServerCallContext + ): Response { + return Response( + id = request.id, + data = request.data + ) + } + + override suspend fun onGetTaskPushNotificationConfig( + request: Request, + ctx: ServerCallContext + ): Response { + return Response( + id = request.id, + data = taskPushNotificationConfig + ) + } + + override suspend fun onListTaskPushNotificationConfig( + request: Request, + ctx: ServerCallContext + ): Response> { + return Response( + id = request.id, + data = taskPushNotificationConfigList + ) + } + + override suspend fun onDeleteTaskPushNotificationConfig( + request: Request, + ctx: ServerCallContext + ): Response { + return Response( + id = request.id, + data = null + ) + } + } + + private val json = JSONRPCJson + + private inline fun testServerMethod( + method: A2AMethod, + request: Request, + expectedResponse: Response, + ) { + testApplication { + install(SSE) + + val transport = HttpJSONRPCServerTransport(MockRequestHandler) + + routing { + transport.transportRoutes(this, "/a2a") + } + + val jsonRpcRequest = JSONRPCRequest( + id = request.id, + method = method.value, + params = json.encodeToJsonElement(request.data), + jsonrpc = JSONRPC_VERSION, + ) + + val response = client.post("/a2a") { + contentType(ContentType.Application.Json) + setBody(json.encodeToString(jsonRpcRequest)) + } + + assertEquals(HttpStatusCode.OK, response.status) + + val jsonRpcResponse = json.decodeFromString(response.bodyAsText()) + 
val actualResponse = Response( + id = jsonRpcResponse.id, + data = json.decodeFromJsonElement(jsonRpcResponse.result) + ) + + assertEquals(expectedResponse.id, actualResponse.id) + assertEquals(expectedResponse.data, actualResponse.data) + } + } + + private inline fun testServerMethodStreaming( + method: A2AMethod, + request: Request, + expectedResponses: List>, + ) { + testApplication { + install(SSE) + + val client = createClient { + install(SSEClient) + } + + val transport = HttpJSONRPCServerTransport(MockRequestHandler) + + routing { + transport.transportRoutes(this, "/a2a") + } + + val jsonRpcRequest = JSONRPCRequest( + id = request.id, + method = method.value, + params = json.encodeToJsonElement(request.data), + jsonrpc = JSONRPC_VERSION, + ) + + val jsonrpcResponses = buildList { + client.sse( + urlString = "/a2a", + request = { + this.method = HttpMethod.Post + + contentType(ContentType.Application.Json) + setBody(json.encodeToString(jsonRpcRequest)) + }, + ) { + assertEquals(HttpStatusCode.OK, call.response.status) + + incoming + .map { event -> JSONRPCJson.decodeFromString(event.data!!) 
} + .collect { add(it) } + } + } + + val actualResponses = jsonrpcResponses.map { + Response( + id = it.id, + data = json.decodeFromJsonElement(it.result) + ) + } + + assertEquals(expectedResponses.map { it.id }, actualResponses.map { it.id }) + assertEquals(expectedResponses.map { it.data }, actualResponses.map { it.data }) + } + } + + @Test + fun testGetAuthenticatedExtendedAgentCard() = runTest { + val requestId = RequestId.StringId("test-1") + + val request = Request( + id = requestId, + data = null, + ) + + val expectedResponse = Response( + id = requestId, + data = MockRequestHandler.agentCard, + ) + + testServerMethod( + method = A2AMethod.GetAuthenticatedExtendedAgentCard, + request = request, + expectedResponse = expectedResponse, + ) + } + + @Test + fun testSendMessage() = runTest { + val requestId = RequestId.StringId("test-2") + + val messageSendParams = MessageSendParams( + message = Message( + messageId = "msg-1", + role = Role.User, + parts = listOf(TextPart("Hello, agent!")), + taskId = "task-1" + ) + ) + + val request = Request( + id = requestId, + data = messageSendParams, + ) + + val expectedResponse = Response( + id = requestId, + data = MockRequestHandler.communicationEvent, + ) + + testServerMethod( + method = A2AMethod.SendMessage, + request = request, + expectedResponse = expectedResponse, + ) + } + + @Test + fun testSendMessageStreaming() = runTest { + val requestId = RequestId.StringId("test-2") + + val messageSendParams = MessageSendParams( + message = Message( + messageId = "msg-1", + role = Role.User, + parts = listOf(TextPart("Hello, agent!")), + taskId = "task-1" + ) + ) + + val request = Request( + id = requestId, + data = messageSendParams, + ) + + val expectedResponses = MockRequestHandler.updateEvents.map { + Response( + id = requestId, + data = it, + ) + } + + testServerMethodStreaming( + method = A2AMethod.SendMessageStreaming, + request = request, + expectedResponses = expectedResponses, + ) + } + + @Test + fun testGetTask() = 
runTest { + val requestId = RequestId.StringId("test-3") + val taskQueryParams = TaskQueryParams(id = "task-1") + + val request = Request( + id = requestId, + data = taskQueryParams, + ) + + val expectedResponse = Response( + id = requestId, + data = MockRequestHandler.taskGet, + ) + + testServerMethod( + method = A2AMethod.GetTask, + request = request, + expectedResponse = expectedResponse, + ) + } + + @Test + fun testCancelTask() = runTest { + val requestId = RequestId.StringId("test-4") + val taskIdParams = TaskIdParams(id = "task-1") + + val request = Request( + id = requestId, + data = taskIdParams, + ) + + val expectedResponse = Response( + id = requestId, + data = MockRequestHandler.taskCancel, + ) + + testServerMethod( + method = A2AMethod.CancelTask, + request = request, + expectedResponse = expectedResponse, + ) + } + + @Test + fun testResubscribeTask() = runTest { + val requestId = RequestId.StringId("test-7") + val taskIdParams = TaskIdParams(id = "task-1") + + val request = Request( + id = requestId, + data = taskIdParams, + ) + + val expectedResponses = MockRequestHandler.updateEvents.map { + Response( + id = requestId, + data = it, + ) + } + + testServerMethodStreaming( + method = A2AMethod.ResubscribeTask, + request = request, + expectedResponses = expectedResponses, + ) + } + + @Test + fun testSetTaskPushNotificationConfig() = runTest { + val requestId = RequestId.StringId("test-5") + + val pushNotificationConfig = TaskPushNotificationConfig( + taskId = "task-123", + pushNotificationConfig = PushNotificationConfig( + id = "notification-config-1", + url = "https://webhook.example.com/notifications", + token = "webhook-token-123" + ) + ) + + val request = Request( + id = requestId, + data = pushNotificationConfig, + ) + + val expectedResponse = Response( + id = requestId, + data = pushNotificationConfig, + ) + + testServerMethod( + method = A2AMethod.SetTaskPushNotificationConfig, + request = request, + expectedResponse = expectedResponse, + ) + } + 
+ @Test + fun testGetTaskPushNotificationConfig() = runTest { + val requestId = RequestId.StringId("test-6") + + val configParams = TaskPushNotificationConfigParams( + id = "task-123", + pushNotificationConfigId = "notification-config-1" + ) + + val request = Request( + id = requestId, + data = configParams, + ) + + val expectedResponse = Response( + id = requestId, + data = MockRequestHandler.taskPushNotificationConfig, + ) + + testServerMethod( + method = A2AMethod.GetTaskPushNotificationConfig, + request = request, + expectedResponse = expectedResponse, + ) + } + + @Test + fun testListTaskPushNotificationConfig() = runTest { + val requestId = RequestId.StringId("test-7") + val taskIdParams = TaskIdParams(id = "task-1") + + val request = Request( + id = requestId, + data = taskIdParams, + ) + + val expectedResponse = Response( + id = requestId, + data = MockRequestHandler.taskPushNotificationConfigList, + ) + + testServerMethod( + method = A2AMethod.ListTaskPushNotificationConfig, + request = request, + expectedResponse = expectedResponse, + ) + } + + @Test + fun testDeleteTaskPushNotificationConfig() = runTest { + val requestId = RequestId.StringId("test-8") + + val configParams = TaskPushNotificationConfigParams( + id = "task-123", + pushNotificationConfigId = "notification-config-1" + ) + + val request = Request( + id = requestId, + data = configParams, + ) + + val expectedResponse = Response( + id = requestId, + data = null, + ) + + testServerMethod( + method = A2AMethod.DeleteTaskPushNotificationConfig, + request = request, + expectedResponse = expectedResponse, + ) + } + + @Test + fun testMethodNotFound() = runTest { + testApplication { + install(SSE) + + val transport = HttpJSONRPCServerTransport(MockRequestHandler) + + routing { + transport.transportRoutes(this, "/a2a") + } + + val requestId = RequestId.StringId("test-9") + val jsonRpcRequest = JSONRPCRequest( + id = requestId, + method = "unknown.method", + params = JsonNull, + jsonrpc = JSONRPC_VERSION, 
+ ) + + val response = client.post("/a2a") { + contentType(ContentType.Application.Json) + setBody(json.encodeToString(jsonRpcRequest)) + } + + assertEquals(HttpStatusCode.OK, response.status) + + val jsonRpcResponse = json.decodeFromString(response.bodyAsText()) + assertEquals(requestId, jsonRpcResponse.id) + assertEquals(A2AErrorCodes.METHOD_NOT_FOUND, jsonRpcResponse.error.code) + } + } + + @Test + fun testInvalidJsonRequest() = runTest { + testApplication { + install(SSE) + + val transport = HttpJSONRPCServerTransport(MockRequestHandler) + + routing { + transport.transportRoutes(this, "/a2a") + } + + val response = client.post("/a2a") { + contentType(ContentType.Application.Json) + setBody("invalid json") + } + + assertEquals(HttpStatusCode.OK, response.status) + + val jsonRpcResponse = json.decodeFromString(response.bodyAsText()) + assertNull(jsonRpcResponse.id) + assertEquals(A2AErrorCodes.PARSE_ERROR, jsonRpcResponse.error.code) + } + } +} diff --git a/a2a/test-python-a2a-server/.gitignore b/a2a/test-python-a2a-server/.gitignore new file mode 100644 index 0000000000..605d1308d1 --- /dev/null +++ b/a2a/test-python-a2a-server/.gitignore @@ -0,0 +1,2 @@ +.venv +*.iml diff --git a/a2a/test-python-a2a-server/.python-version b/a2a/test-python-a2a-server/.python-version new file mode 100644 index 0000000000..e4fba21835 --- /dev/null +++ b/a2a/test-python-a2a-server/.python-version @@ -0,0 +1 @@ +3.12 diff --git a/a2a/test-python-a2a-server/Dockerfile b/a2a/test-python-a2a-server/Dockerfile new file mode 100644 index 0000000000..87f268e7e2 --- /dev/null +++ b/a2a/test-python-a2a-server/Dockerfile @@ -0,0 +1,7 @@ +FROM ghcr.io/astral-sh/uv:python3.12-alpine +WORKDIR /app +COPY pyproject.toml uv.lock ./ +RUN uv sync --frozen +COPY src/ ./src/ +EXPOSE 9999 +CMD ["uv", "run", "--no-sync", "src/main.py"] diff --git a/a2a/test-python-a2a-server/pyproject.toml b/a2a/test-python-a2a-server/pyproject.toml new file mode 100644 index 0000000000..2ce64bc7b0 --- /dev/null +++ 
b/a2a/test-python-a2a-server/pyproject.toml @@ -0,0 +1,9 @@ +[project] +name = "test-python-a2a-server" +version = "0.1.0" +description = "Python A2A server for integration tests." +requires-python = ">=3.12" +dependencies = [ + "a2a-sdk[http-server]==0.3.5", + "uvicorn==0.35.0", +] diff --git a/a2a/test-python-a2a-server/src/agent_executor.py b/a2a/test-python-a2a-server/src/agent_executor.py new file mode 100644 index 0000000000..0a3ad4d5d1 --- /dev/null +++ b/a2a/test-python-a2a-server/src/agent_executor.py @@ -0,0 +1,201 @@ +import asyncio + +from a2a.server.agent_execution import AgentExecutor, RequestContext +from a2a.server.events import EventQueue +from a2a.types import ( + Message, + TaskStatusUpdateEvent, + TaskStatus, + TaskState, + Task, +) +from a2a.utils import ( + new_agent_text_message, + new_task +) +from datetime import datetime, timezone + + +def get_current_timestamp(): + """Get current timestamp in ISO 8601 format (UTC)""" + return datetime.now(timezone.utc).isoformat() + + +async def say_hello( + event_queue: EventQueue, + context: RequestContext, +) -> None: + message = context.message + + await event_queue.enqueue_event( + new_agent_text_message( + text="Hello World", + context_id=message.context_id, + task_id=message.task_id + ) + ) + + +async def do_task( + event_queue: EventQueue, + context: RequestContext, +) -> None: + message = context.message + + # noinspection PyTypeChecker + task = Task( + id=message.task_id, + context_id=message.context_id, + status=TaskStatus( + state=TaskState.submitted, + timestamp=get_current_timestamp() + ), + history=[message] + ) + + # noinspection PyTypeChecker + events = [ + task, + + TaskStatusUpdateEvent( + context_id=task.context_id, + task_id=task.id, + status=TaskStatus( + state=TaskState.working, + message=new_agent_text_message( + text="Working on task", + context_id=task.context_id, + task_id=task.id + ) + ), + final=False + ), + + TaskStatusUpdateEvent( + context_id=task.context_id, + 
task_id=task.id, + status=TaskStatus( + state=TaskState.completed, + message=new_agent_text_message( + text="Task completed", + context_id=task.context_id, + task_id=task.id + ) + ), + final=True + ) + ] + + for event in events: + await event_queue.enqueue_event(event) + + +async def do_cancelable_task( + event_queue: EventQueue, + context: RequestContext, +): + message = context.message + + # noinspection PyTypeChecker + task = Task( + id=message.task_id, + context_id=message.context_id, + status=TaskStatus( + state=TaskState.submitted, + timestamp=get_current_timestamp() + ), + history=[message] + ) + await event_queue.enqueue_event(task) + + +async def do_long_running_task( + event_queue: EventQueue, + context: RequestContext, +): + message = context.message + + # noinspection PyTypeChecker + task = Task( + id=message.task_id, + context_id=message.context_id, + status=TaskStatus( + state=TaskState.submitted, + timestamp=get_current_timestamp() + ), + history=[message] + ) + + await event_queue.enqueue_event(task) + + # Simulate long-running task + for i in range(4): + await asyncio.sleep(0.2) + + # noinspection PyTypeChecker + await event_queue.enqueue_event( + TaskStatusUpdateEvent( + task_id=task.id, + context_id=task.context_id, + status=TaskStatus( + state=TaskState.working, + message=new_agent_text_message( + text=f"Still working {i}", + context_id=task.context_id, + task_id=task.id + ) + ), + final=False + ) + ) + + +class HelloWorldAgentExecutor(AgentExecutor): + """Test AgentProxy Implementation.""" + + async def execute( + self, + context: RequestContext, + event_queue: EventQueue, + ) -> None: + user_input = context.get_user_input() + + # Test scenarios to test various aspects of A2A + if user_input == "hello world": + await say_hello(event_queue, context) + + elif user_input == "do task": + await do_task(event_queue, context) + + elif user_input == "do cancelable task": + await do_cancelable_task(event_queue, context) + + elif user_input == "do 
long-running task": + await do_long_running_task(event_queue, context) + + else: + await event_queue.enqueue_event( + new_agent_text_message("Sorry, I don't understand you") + ) + + async def cancel( + self, + context: RequestContext, + event_queue: EventQueue + ) -> None: + # noinspection PyTypeChecker + await event_queue.enqueue_event( + TaskStatusUpdateEvent( + context_id=context.context_id, + task_id=context.task_id, + status=TaskStatus( + state=TaskState.canceled, + message=new_agent_text_message( + text="Task canceled", + context_id=context.context_id, + task_id=context.task_id + ) + ), + final=True, + ) + ) diff --git a/a2a/test-python-a2a-server/src/main.py b/a2a/test-python-a2a-server/src/main.py new file mode 100644 index 0000000000..516ce0fa34 --- /dev/null +++ b/a2a/test-python-a2a-server/src/main.py @@ -0,0 +1,79 @@ +import uvicorn + +from a2a.server.apps import A2AStarletteApplication +from a2a.server.request_handlers import DefaultRequestHandler +from a2a.server.tasks import ( + InMemoryTaskStore, + InMemoryPushNotificationConfigStore +) +from a2a.types import ( + AgentCapabilities, + AgentCard, + AgentSkill, +) +from agent_executor import ( + HelloWorldAgentExecutor, +) + + +if __name__ == '__main__': + skill = AgentSkill( + id='hello_world', + name='Returns hello world', + description='just returns hello world', + tags=['hello world'], + examples=['hi', 'hello world'], + ) + + extended_skill = AgentSkill( + id='super_hello_world', + name='Returns a SUPER Hello World', + description='A more enthusiastic greeting, only for authenticated users.', + tags=['hello world', 'super', 'extended'], + examples=['super hi', 'give me a super hello'], + ) + + public_agent_card = AgentCard( + name='Hello World Agent', + description='Just a hello world agent', + url='http://localhost:9999/', + version='1.0.0', + default_input_modes=['text'], + default_output_modes=['text'], + capabilities=AgentCapabilities( + streaming=True, + push_notifications=True, + ), + 
skills=[skill], # Only the basic skill for the public card + supports_authenticated_extended_card=True, + ) + + # This will be the authenticated extended agent card + # It includes the additional 'extended_skill' + specific_extended_agent_card = public_agent_card.model_copy( + update={ + 'name': 'Hello World Agent - Extended Edition', # Different name for clarity + 'description': 'The full-featured hello world agent for authenticated users.', + 'version': '1.0.1', # Could even be a different version + # Capabilities and other fields like url, default_input_modes, default_output_modes, + # supports_authenticated_extended_card are inherited from public_agent_card unless specified here. + 'skills': [ + skill, + extended_skill, + ], # Both skills for the extended card + } + ) + + request_handler = DefaultRequestHandler( + agent_executor=HelloWorldAgentExecutor(), + task_store=InMemoryTaskStore(), + push_config_store=InMemoryPushNotificationConfigStore() + ) + + server = A2AStarletteApplication( + agent_card=public_agent_card, + http_handler=request_handler, + extended_agent_card=specific_extended_agent_card, + ) + + uvicorn.run(server.build(), host='0.0.0.0', port=9999) diff --git a/a2a/test-python-a2a-server/uv.lock b/a2a/test-python-a2a-server/uv.lock new file mode 100644 index 0000000000..6bbaa6a180 --- /dev/null +++ b/a2a/test-python-a2a-server/uv.lock @@ -0,0 +1,468 @@ +version = 1 +revision = 2 +requires-python = ">=3.12" +resolution-markers = [ + "python_full_version >= '3.13'", + "python_full_version < '3.13'", +] + +[[package]] +name = "a2a-sdk" +version = "0.3.5" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "google-api-core" }, + { name = "httpx" }, + { name = "httpx-sse" }, + { name = "protobuf" }, + { name = "pydantic" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/cf/0d/12ebef081b096ca5fafd1ec8cc589739abba07b46ae7899c7420e599f2a6/a2a_sdk-0.3.5.tar.gz", hash = 
"sha256:48cf37dedeb63cf0a072512221a12ed4b3950df695c9d65eadb839a99392c3e5", size = 222064, upload-time = "2025-09-08T17:30:35.826Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/4c/96/c33802d929b0f884cb6e509195d69914632536256d273bd7127e900d79ea/a2a_sdk-0.3.5-py3-none-any.whl", hash = "sha256:fd85b1e4e7be18a89b5d723e4013171510150a235275876f98de9e1ba869457e", size = 136911, upload-time = "2025-09-08T17:30:34.091Z" }, +] + +[package.optional-dependencies] +http-server = [ + { name = "fastapi" }, + { name = "sse-starlette" }, + { name = "starlette" }, +] + +[[package]] +name = "annotated-types" +version = "0.7.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/ee/67/531ea369ba64dcff5ec9c3402f9f51bf748cec26dde048a2f973a4eea7f5/annotated_types-0.7.0.tar.gz", hash = "sha256:aff07c09a53a08bc8cfccb9c85b05f1aa9a2a6f23728d790723543408344ce89", size = 16081, upload-time = "2024-05-20T21:33:25.928Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/78/b6/6307fbef88d9b5ee7421e68d78a9f162e0da4900bc5f5793f6d3d0e34fb8/annotated_types-0.7.0-py3-none-any.whl", hash = "sha256:1f02e8b43a8fbbc3f3e0d4f0f4bfc8131bcb4eebe8849b8e5c773f3a1c582a53", size = 13643, upload-time = "2024-05-20T21:33:24.1Z" }, +] + +[[package]] +name = "anyio" +version = "4.10.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "idna" }, + { name = "sniffio" }, + { name = "typing-extensions", marker = "python_full_version < '3.13'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/f1/b4/636b3b65173d3ce9a38ef5f0522789614e590dab6a8d505340a4efe4c567/anyio-4.10.0.tar.gz", hash = "sha256:3f3fae35c96039744587aa5b8371e7e8e603c0702999535961dd336026973ba6", size = 213252, upload-time = "2025-08-04T08:54:26.451Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/6f/12/e5e0282d673bb9746bacfb6e2dba8719989d3660cdb2ea79aee9a9651afb/anyio-4.10.0-py3-none-any.whl", hash = 
"sha256:60e474ac86736bbfd6f210f7a61218939c318f43f9972497381f1c5e930ed3d1", size = 107213, upload-time = "2025-08-04T08:54:24.882Z" }, +] + +[[package]] +name = "cachetools" +version = "5.5.2" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/6c/81/3747dad6b14fa2cf53fcf10548cf5aea6913e96fab41a3c198676f8948a5/cachetools-5.5.2.tar.gz", hash = "sha256:1a661caa9175d26759571b2e19580f9d6393969e5dfca11fdb1f947a23e640d4", size = 28380, upload-time = "2025-02-20T21:01:19.524Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/72/76/20fa66124dbe6be5cafeb312ece67de6b61dd91a0247d1ea13db4ebb33c2/cachetools-5.5.2-py3-none-any.whl", hash = "sha256:d26a22bcc62eb95c3beabd9f1ee5e820d3d2704fe2967cbe350e20c8ffcd3f0a", size = 10080, upload-time = "2025-02-20T21:01:16.647Z" }, +] + +[[package]] +name = "certifi" +version = "2025.8.3" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/dc/67/960ebe6bf230a96cda2e0abcf73af550ec4f090005363542f0765df162e0/certifi-2025.8.3.tar.gz", hash = "sha256:e564105f78ded564e3ae7c923924435e1daa7463faeab5bb932bc53ffae63407", size = 162386, upload-time = "2025-08-03T03:07:47.08Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/e5/48/1549795ba7742c948d2ad169c1c8cdbae65bc450d6cd753d124b17c8cd32/certifi-2025.8.3-py3-none-any.whl", hash = "sha256:f6c12493cfb1b06ba2ff328595af9350c65d6644968e5d3a2ffd78699af217a5", size = 161216, upload-time = "2025-08-03T03:07:45.777Z" }, +] + +[[package]] +name = "charset-normalizer" +version = "3.4.3" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/83/2d/5fd176ceb9b2fc619e63405525573493ca23441330fcdaee6bef9460e924/charset_normalizer-3.4.3.tar.gz", hash = "sha256:6fce4b8500244f6fcb71465d4a4930d132ba9ab8e71a7859e6a5d59851068d14", size = 122371, upload-time = "2025-08-09T07:57:28.46Z" } +wheels = [ + { url = 
"https://files.pythonhosted.org/packages/e9/5e/14c94999e418d9b87682734589404a25854d5f5d0408df68bc15b6ff54bb/charset_normalizer-3.4.3-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:e28e334d3ff134e88989d90ba04b47d84382a828c061d0d1027b1b12a62b39b1", size = 205655, upload-time = "2025-08-09T07:56:08.475Z" }, + { url = "https://files.pythonhosted.org/packages/7d/a8/c6ec5d389672521f644505a257f50544c074cf5fc292d5390331cd6fc9c3/charset_normalizer-3.4.3-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:0cacf8f7297b0c4fcb74227692ca46b4a5852f8f4f24b3c766dd94a1075c4884", size = 146223, upload-time = "2025-08-09T07:56:09.708Z" }, + { url = "https://files.pythonhosted.org/packages/fc/eb/a2ffb08547f4e1e5415fb69eb7db25932c52a52bed371429648db4d84fb1/charset_normalizer-3.4.3-cp312-cp312-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:c6fd51128a41297f5409deab284fecbe5305ebd7e5a1f959bee1c054622b7018", size = 159366, upload-time = "2025-08-09T07:56:11.326Z" }, + { url = "https://files.pythonhosted.org/packages/82/10/0fd19f20c624b278dddaf83b8464dcddc2456cb4b02bb902a6da126b87a1/charset_normalizer-3.4.3-cp312-cp312-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:3cfb2aad70f2c6debfbcb717f23b7eb55febc0bb23dcffc0f076009da10c6392", size = 157104, upload-time = "2025-08-09T07:56:13.014Z" }, + { url = "https://files.pythonhosted.org/packages/16/ab/0233c3231af734f5dfcf0844aa9582d5a1466c985bbed6cedab85af9bfe3/charset_normalizer-3.4.3-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:1606f4a55c0fd363d754049cdf400175ee96c992b1f8018b993941f221221c5f", size = 151830, upload-time = "2025-08-09T07:56:14.428Z" }, + { url = "https://files.pythonhosted.org/packages/ae/02/e29e22b4e02839a0e4a06557b1999d0a47db3567e82989b5bb21f3fbbd9f/charset_normalizer-3.4.3-cp312-cp312-musllinux_1_2_aarch64.whl", hash = 
"sha256:027b776c26d38b7f15b26a5da1044f376455fb3766df8fc38563b4efbc515154", size = 148854, upload-time = "2025-08-09T07:56:16.051Z" }, + { url = "https://files.pythonhosted.org/packages/05/6b/e2539a0a4be302b481e8cafb5af8792da8093b486885a1ae4d15d452bcec/charset_normalizer-3.4.3-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:42e5088973e56e31e4fa58eb6bd709e42fc03799c11c42929592889a2e54c491", size = 160670, upload-time = "2025-08-09T07:56:17.314Z" }, + { url = "https://files.pythonhosted.org/packages/31/e7/883ee5676a2ef217a40ce0bffcc3d0dfbf9e64cbcfbdf822c52981c3304b/charset_normalizer-3.4.3-cp312-cp312-musllinux_1_2_s390x.whl", hash = "sha256:cc34f233c9e71701040d772aa7490318673aa7164a0efe3172b2981218c26d93", size = 158501, upload-time = "2025-08-09T07:56:18.641Z" }, + { url = "https://files.pythonhosted.org/packages/c1/35/6525b21aa0db614cf8b5792d232021dca3df7f90a1944db934efa5d20bb1/charset_normalizer-3.4.3-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:320e8e66157cc4e247d9ddca8e21f427efc7a04bbd0ac8a9faf56583fa543f9f", size = 153173, upload-time = "2025-08-09T07:56:20.289Z" }, + { url = "https://files.pythonhosted.org/packages/50/ee/f4704bad8201de513fdc8aac1cabc87e38c5818c93857140e06e772b5892/charset_normalizer-3.4.3-cp312-cp312-win32.whl", hash = "sha256:fb6fecfd65564f208cbf0fba07f107fb661bcd1a7c389edbced3f7a493f70e37", size = 99822, upload-time = "2025-08-09T07:56:21.551Z" }, + { url = "https://files.pythonhosted.org/packages/39/f5/3b3836ca6064d0992c58c7561c6b6eee1b3892e9665d650c803bd5614522/charset_normalizer-3.4.3-cp312-cp312-win_amd64.whl", hash = "sha256:86df271bf921c2ee3818f0522e9a5b8092ca2ad8b065ece5d7d9d0e9f4849bcc", size = 107543, upload-time = "2025-08-09T07:56:23.115Z" }, + { url = "https://files.pythonhosted.org/packages/65/ca/2135ac97709b400c7654b4b764daf5c5567c2da45a30cdd20f9eefe2d658/charset_normalizer-3.4.3-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:14c2a87c65b351109f6abfc424cab3927b3bdece6f706e4d12faaf3d52ee5efe", size = 
205326, upload-time = "2025-08-09T07:56:24.721Z" }, + { url = "https://files.pythonhosted.org/packages/71/11/98a04c3c97dd34e49c7d247083af03645ca3730809a5509443f3c37f7c99/charset_normalizer-3.4.3-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:41d1fc408ff5fdfb910200ec0e74abc40387bccb3252f3f27c0676731df2b2c8", size = 146008, upload-time = "2025-08-09T07:56:26.004Z" }, + { url = "https://files.pythonhosted.org/packages/60/f5/4659a4cb3c4ec146bec80c32d8bb16033752574c20b1252ee842a95d1a1e/charset_normalizer-3.4.3-cp313-cp313-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:1bb60174149316da1c35fa5233681f7c0f9f514509b8e399ab70fea5f17e45c9", size = 159196, upload-time = "2025-08-09T07:56:27.25Z" }, + { url = "https://files.pythonhosted.org/packages/86/9e/f552f7a00611f168b9a5865a1414179b2c6de8235a4fa40189f6f79a1753/charset_normalizer-3.4.3-cp313-cp313-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:30d006f98569de3459c2fc1f2acde170b7b2bd265dc1943e87e1a4efe1b67c31", size = 156819, upload-time = "2025-08-09T07:56:28.515Z" }, + { url = "https://files.pythonhosted.org/packages/7e/95/42aa2156235cbc8fa61208aded06ef46111c4d3f0de233107b3f38631803/charset_normalizer-3.4.3-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:416175faf02e4b0810f1f38bcb54682878a4af94059a1cd63b8747244420801f", size = 151350, upload-time = "2025-08-09T07:56:29.716Z" }, + { url = "https://files.pythonhosted.org/packages/c2/a9/3865b02c56f300a6f94fc631ef54f0a8a29da74fb45a773dfd3dcd380af7/charset_normalizer-3.4.3-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:6aab0f181c486f973bc7262a97f5aca3ee7e1437011ef0c2ec04b5a11d16c927", size = 148644, upload-time = "2025-08-09T07:56:30.984Z" }, + { url = 
"https://files.pythonhosted.org/packages/77/d9/cbcf1a2a5c7d7856f11e7ac2d782aec12bdfea60d104e60e0aa1c97849dc/charset_normalizer-3.4.3-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:fdabf8315679312cfa71302f9bd509ded4f2f263fb5b765cf1433b39106c3cc9", size = 160468, upload-time = "2025-08-09T07:56:32.252Z" }, + { url = "https://files.pythonhosted.org/packages/f6/42/6f45efee8697b89fda4d50580f292b8f7f9306cb2971d4b53f8914e4d890/charset_normalizer-3.4.3-cp313-cp313-musllinux_1_2_s390x.whl", hash = "sha256:bd28b817ea8c70215401f657edef3a8aa83c29d447fb0b622c35403780ba11d5", size = 158187, upload-time = "2025-08-09T07:56:33.481Z" }, + { url = "https://files.pythonhosted.org/packages/70/99/f1c3bdcfaa9c45b3ce96f70b14f070411366fa19549c1d4832c935d8e2c3/charset_normalizer-3.4.3-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:18343b2d246dc6761a249ba1fb13f9ee9a2bcd95decc767319506056ea4ad4dc", size = 152699, upload-time = "2025-08-09T07:56:34.739Z" }, + { url = "https://files.pythonhosted.org/packages/a3/ad/b0081f2f99a4b194bcbb1934ef3b12aa4d9702ced80a37026b7607c72e58/charset_normalizer-3.4.3-cp313-cp313-win32.whl", hash = "sha256:6fb70de56f1859a3f71261cbe41005f56a7842cc348d3aeb26237560bfa5e0ce", size = 99580, upload-time = "2025-08-09T07:56:35.981Z" }, + { url = "https://files.pythonhosted.org/packages/9a/8f/ae790790c7b64f925e5c953b924aaa42a243fb778fed9e41f147b2a5715a/charset_normalizer-3.4.3-cp313-cp313-win_amd64.whl", hash = "sha256:cf1ebb7d78e1ad8ec2a8c4732c7be2e736f6e5123a4146c5b89c9d1f585f8cef", size = 107366, upload-time = "2025-08-09T07:56:37.339Z" }, + { url = "https://files.pythonhosted.org/packages/8e/91/b5a06ad970ddc7a0e513112d40113e834638f4ca1120eb727a249fb2715e/charset_normalizer-3.4.3-cp314-cp314-macosx_10_13_universal2.whl", hash = "sha256:3cd35b7e8aedeb9e34c41385fda4f73ba609e561faedfae0a9e75e44ac558a15", size = 204342, upload-time = "2025-08-09T07:56:38.687Z" }, + { url = 
"https://files.pythonhosted.org/packages/ce/ec/1edc30a377f0a02689342f214455c3f6c2fbedd896a1d2f856c002fc3062/charset_normalizer-3.4.3-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:b89bc04de1d83006373429975f8ef9e7932534b8cc9ca582e4db7d20d91816db", size = 145995, upload-time = "2025-08-09T07:56:40.048Z" }, + { url = "https://files.pythonhosted.org/packages/17/e5/5e67ab85e6d22b04641acb5399c8684f4d37caf7558a53859f0283a650e9/charset_normalizer-3.4.3-cp314-cp314-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:2001a39612b241dae17b4687898843f254f8748b796a2e16f1051a17078d991d", size = 158640, upload-time = "2025-08-09T07:56:41.311Z" }, + { url = "https://files.pythonhosted.org/packages/f1/e5/38421987f6c697ee3722981289d554957c4be652f963d71c5e46a262e135/charset_normalizer-3.4.3-cp314-cp314-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:8dcfc373f888e4fb39a7bc57e93e3b845e7f462dacc008d9749568b1c4ece096", size = 156636, upload-time = "2025-08-09T07:56:43.195Z" }, + { url = "https://files.pythonhosted.org/packages/a0/e4/5a075de8daa3ec0745a9a3b54467e0c2967daaaf2cec04c845f73493e9a1/charset_normalizer-3.4.3-cp314-cp314-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:18b97b8404387b96cdbd30ad660f6407799126d26a39ca65729162fd810a99aa", size = 150939, upload-time = "2025-08-09T07:56:44.819Z" }, + { url = "https://files.pythonhosted.org/packages/02/f7/3611b32318b30974131db62b4043f335861d4d9b49adc6d57c1149cc49d4/charset_normalizer-3.4.3-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:ccf600859c183d70eb47e05a44cd80a4ce77394d1ac0f79dbd2dd90a69a3a049", size = 148580, upload-time = "2025-08-09T07:56:46.684Z" }, + { url = "https://files.pythonhosted.org/packages/7e/61/19b36f4bd67f2793ab6a99b979b4e4f3d8fc754cbdffb805335df4337126/charset_normalizer-3.4.3-cp314-cp314-musllinux_1_2_ppc64le.whl", hash = 
"sha256:53cd68b185d98dde4ad8990e56a58dea83a4162161b1ea9272e5c9182ce415e0", size = 159870, upload-time = "2025-08-09T07:56:47.941Z" }, + { url = "https://files.pythonhosted.org/packages/06/57/84722eefdd338c04cf3030ada66889298eaedf3e7a30a624201e0cbe424a/charset_normalizer-3.4.3-cp314-cp314-musllinux_1_2_s390x.whl", hash = "sha256:30a96e1e1f865f78b030d65241c1ee850cdf422d869e9028e2fc1d5e4db73b92", size = 157797, upload-time = "2025-08-09T07:56:49.756Z" }, + { url = "https://files.pythonhosted.org/packages/72/2a/aff5dd112b2f14bcc3462c312dce5445806bfc8ab3a7328555da95330e4b/charset_normalizer-3.4.3-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:d716a916938e03231e86e43782ca7878fb602a125a91e7acb8b5112e2e96ac16", size = 152224, upload-time = "2025-08-09T07:56:51.369Z" }, + { url = "https://files.pythonhosted.org/packages/b7/8c/9839225320046ed279c6e839d51f028342eb77c91c89b8ef2549f951f3ec/charset_normalizer-3.4.3-cp314-cp314-win32.whl", hash = "sha256:c6dbd0ccdda3a2ba7c2ecd9d77b37f3b5831687d8dc1b6ca5f56a4880cc7b7ce", size = 100086, upload-time = "2025-08-09T07:56:52.722Z" }, + { url = "https://files.pythonhosted.org/packages/ee/7a/36fbcf646e41f710ce0a563c1c9a343c6edf9be80786edeb15b6f62e17db/charset_normalizer-3.4.3-cp314-cp314-win_amd64.whl", hash = "sha256:73dc19b562516fc9bcf6e5d6e596df0b4eb98d87e4f79f3ae71840e6ed21361c", size = 107400, upload-time = "2025-08-09T07:56:55.172Z" }, + { url = "https://files.pythonhosted.org/packages/8a/1f/f041989e93b001bc4e44bb1669ccdcf54d3f00e628229a85b08d330615c5/charset_normalizer-3.4.3-py3-none-any.whl", hash = "sha256:ce571ab16d890d23b5c278547ba694193a45011ff86a9162a71307ed9f86759a", size = 53175, upload-time = "2025-08-09T07:57:26.864Z" }, +] + +[[package]] +name = "click" +version = "8.2.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "colorama", marker = "sys_platform == 'win32'" }, +] +sdist = { url = 
"https://files.pythonhosted.org/packages/60/6c/8ca2efa64cf75a977a0d7fac081354553ebe483345c734fb6b6515d96bbc/click-8.2.1.tar.gz", hash = "sha256:27c491cc05d968d271d5a1db13e3b5a184636d9d930f148c50b038f0d0646202", size = 286342, upload-time = "2025-05-20T23:19:49.832Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/85/32/10bb5764d90a8eee674e9dc6f4db6a0ab47c8c4d0d83c27f7c39ac415a4d/click-8.2.1-py3-none-any.whl", hash = "sha256:61a3265b914e850b85317d0b3109c7f8cd35a670f963866005d6ef1d5175a12b", size = 102215, upload-time = "2025-05-20T23:19:47.796Z" }, +] + +[[package]] +name = "colorama" +version = "0.4.6" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/d8/53/6f443c9a4a8358a93a6792e2acffb9d9d5cb0a5cfd8802644b7b1c9a02e4/colorama-0.4.6.tar.gz", hash = "sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44", size = 27697, upload-time = "2022-10-25T02:36:22.414Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/d1/d6/3965ed04c63042e047cb6a3e6ed1a63a35087b6a609aa3a15ed8ac56c221/colorama-0.4.6-py2.py3-none-any.whl", hash = "sha256:4f1d9991f5acc0ca119f9d443620b77f9d6b33703e51011c16baf57afb285fc6", size = 25335, upload-time = "2022-10-25T02:36:20.889Z" }, +] + +[[package]] +name = "fastapi" +version = "0.116.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "pydantic" }, + { name = "starlette" }, + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/78/d7/6c8b3bfe33eeffa208183ec037fee0cce9f7f024089ab1c5d12ef04bd27c/fastapi-0.116.1.tar.gz", hash = "sha256:ed52cbf946abfd70c5a0dccb24673f0670deeb517a88b3544d03c2a6bf283143", size = 296485, upload-time = "2025-07-11T16:22:32.057Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/e5/47/d63c60f59a59467fda0f93f46335c9d18526d7071f025cb5b89d5353ea42/fastapi-0.116.1-py3-none-any.whl", hash = 
"sha256:c46ac7c312df840f0c9e220f7964bada936781bc4e2e6eb71f1c4d7553786565", size = 95631, upload-time = "2025-07-11T16:22:30.485Z" }, +] + +[[package]] +name = "google-api-core" +version = "2.25.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "google-auth" }, + { name = "googleapis-common-protos" }, + { name = "proto-plus" }, + { name = "protobuf" }, + { name = "requests" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/dc/21/e9d043e88222317afdbdb567165fdbc3b0aad90064c7e0c9eb0ad9955ad8/google_api_core-2.25.1.tar.gz", hash = "sha256:d2aaa0b13c78c61cb3f4282c464c046e45fbd75755683c9c525e6e8f7ed0a5e8", size = 165443, upload-time = "2025-06-12T20:52:20.439Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/14/4b/ead00905132820b623732b175d66354e9d3e69fcf2a5dcdab780664e7896/google_api_core-2.25.1-py3-none-any.whl", hash = "sha256:8a2a56c1fef82987a524371f99f3bd0143702fecc670c72e600c1cda6bf8dbb7", size = 160807, upload-time = "2025-06-12T20:52:19.334Z" }, +] + +[[package]] +name = "google-auth" +version = "2.40.3" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "cachetools" }, + { name = "pyasn1-modules" }, + { name = "rsa" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/9e/9b/e92ef23b84fa10a64ce4831390b7a4c2e53c0132568d99d4ae61d04c8855/google_auth-2.40.3.tar.gz", hash = "sha256:500c3a29adedeb36ea9cf24b8d10858e152f2412e3ca37829b3fa18e33d63b77", size = 281029, upload-time = "2025-06-04T18:04:57.577Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/17/63/b19553b658a1692443c62bd07e5868adaa0ad746a0751ba62c59568cd45b/google_auth-2.40.3-py2.py3-none-any.whl", hash = "sha256:1370d4593e86213563547f97a92752fc658456fe4514c809544f330fed45a7ca", size = 216137, upload-time = "2025-06-04T18:04:55.573Z" }, +] + +[[package]] +name = "googleapis-common-protos" +version = "1.70.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = 
"protobuf" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/39/24/33db22342cf4a2ea27c9955e6713140fedd51e8b141b5ce5260897020f1a/googleapis_common_protos-1.70.0.tar.gz", hash = "sha256:0e1b44e0ea153e6594f9f394fef15193a68aaaea2d843f83e2742717ca753257", size = 145903, upload-time = "2025-04-14T10:17:02.924Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/86/f1/62a193f0227cf15a920390abe675f386dec35f7ae3ffe6da582d3ade42c7/googleapis_common_protos-1.70.0-py3-none-any.whl", hash = "sha256:b8bfcca8c25a2bb253e0e0b0adaf8c00773e5e6af6fd92397576680b807e0fd8", size = 294530, upload-time = "2025-04-14T10:17:01.271Z" }, +] + +[[package]] +name = "h11" +version = "0.16.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/01/ee/02a2c011bdab74c6fb3c75474d40b3052059d95df7e73351460c8588d963/h11-0.16.0.tar.gz", hash = "sha256:4e35b956cf45792e4caa5885e69fba00bdbc6ffafbfa020300e549b208ee5ff1", size = 101250, upload-time = "2025-04-24T03:35:25.427Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/04/4b/29cac41a4d98d144bf5f6d33995617b185d14b22401f75ca86f384e87ff1/h11-0.16.0-py3-none-any.whl", hash = "sha256:63cf8bbe7522de3bf65932fda1d9c2772064ffb3dae62d55932da54b31cb6c86", size = 37515, upload-time = "2025-04-24T03:35:24.344Z" }, +] + +[[package]] +name = "httpcore" +version = "1.0.9" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "certifi" }, + { name = "h11" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/06/94/82699a10bca87a5556c9c59b5963f2d039dbd239f25bc2a63907a05a14cb/httpcore-1.0.9.tar.gz", hash = "sha256:6e34463af53fd2ab5d807f399a9b45ea31c3dfa2276f15a2c3f00afff6e176e8", size = 85484, upload-time = "2025-04-24T22:06:22.219Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/7e/f5/f66802a942d491edb555dd61e3a9961140fd64c90bce1eafd741609d334d/httpcore-1.0.9-py3-none-any.whl", hash = 
"sha256:2d400746a40668fc9dec9810239072b40b4484b640a8c38fd654a024c7a1bf55", size = 78784, upload-time = "2025-04-24T22:06:20.566Z" }, +] + +[[package]] +name = "httpx" +version = "0.28.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "anyio" }, + { name = "certifi" }, + { name = "httpcore" }, + { name = "idna" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/b1/df/48c586a5fe32a0f01324ee087459e112ebb7224f646c0b5023f5e79e9956/httpx-0.28.1.tar.gz", hash = "sha256:75e98c5f16b0f35b567856f597f06ff2270a374470a5c2392242528e3e3e42fc", size = 141406, upload-time = "2024-12-06T15:37:23.222Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/2a/39/e50c7c3a983047577ee07d2a9e53faf5a69493943ec3f6a384bdc792deb2/httpx-0.28.1-py3-none-any.whl", hash = "sha256:d909fcccc110f8c7faf814ca82a9a4d816bc5a6dbfea25d6591d6985b8ba59ad", size = 73517, upload-time = "2024-12-06T15:37:21.509Z" }, +] + +[[package]] +name = "httpx-sse" +version = "0.4.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/6e/fa/66bd985dd0b7c109a3bcb89272ee0bfb7e2b4d06309ad7b38ff866734b2a/httpx_sse-0.4.1.tar.gz", hash = "sha256:8f44d34414bc7b21bf3602713005c5df4917884f76072479b21f68befa4ea26e", size = 12998, upload-time = "2025-06-24T13:21:05.71Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/25/0a/6269e3473b09aed2dab8aa1a600c70f31f00ae1349bee30658f7e358a159/httpx_sse-0.4.1-py3-none-any.whl", hash = "sha256:cba42174344c3a5b06f255ce65b350880f962d99ead85e776f23c6618a377a37", size = 8054, upload-time = "2025-06-24T13:21:04.772Z" }, +] + +[[package]] +name = "idna" +version = "3.10" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/f1/70/7703c29685631f5a7590aa73f1f1d3fa9a380e654b86af429e0934a32f7d/idna-3.10.tar.gz", hash = "sha256:12f65c9b470abda6dc35cf8e63cc574b1c52b11df2c86030af0ac09b01b13ea9", size = 190490, upload-time = 
"2024-09-15T18:07:39.745Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/76/c6/c88e154df9c4e1a2a66ccf0005a88dfb2650c1dffb6f5ce603dfbd452ce3/idna-3.10-py3-none-any.whl", hash = "sha256:946d195a0d259cbba61165e88e65941f16e9b36ea6ddb97f00452bae8b1287d3", size = 70442, upload-time = "2024-09-15T18:07:37.964Z" }, +] + +[[package]] +name = "proto-plus" +version = "1.26.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "protobuf" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/f4/ac/87285f15f7cce6d4a008f33f1757fb5a13611ea8914eb58c3d0d26243468/proto_plus-1.26.1.tar.gz", hash = "sha256:21a515a4c4c0088a773899e23c7bbade3d18f9c66c73edd4c7ee3816bc96a012", size = 56142, upload-time = "2025-03-10T15:54:38.843Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/4e/6d/280c4c2ce28b1593a19ad5239c8b826871fc6ec275c21afc8e1820108039/proto_plus-1.26.1-py3-none-any.whl", hash = "sha256:13285478c2dcf2abb829db158e1047e2f1e8d63a077d94263c2b88b043c75a66", size = 50163, upload-time = "2025-03-10T15:54:37.335Z" }, +] + +[[package]] +name = "protobuf" +version = "6.32.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/c0/df/fb4a8eeea482eca989b51cffd274aac2ee24e825f0bf3cbce5281fa1567b/protobuf-6.32.0.tar.gz", hash = "sha256:a81439049127067fc49ec1d36e25c6ee1d1a2b7be930675f919258d03c04e7d2", size = 440614, upload-time = "2025-08-14T21:21:25.015Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/33/18/df8c87da2e47f4f1dcc5153a81cd6bca4e429803f4069a299e236e4dd510/protobuf-6.32.0-cp310-abi3-win32.whl", hash = "sha256:84f9e3c1ff6fb0308dbacb0950d8aa90694b0d0ee68e75719cb044b7078fe741", size = 424409, upload-time = "2025-08-14T21:21:12.366Z" }, + { url = "https://files.pythonhosted.org/packages/e1/59/0a820b7310f8139bd8d5a9388e6a38e1786d179d6f33998448609296c229/protobuf-6.32.0-cp310-abi3-win_amd64.whl", hash = 
"sha256:a8bdbb2f009cfc22a36d031f22a625a38b615b5e19e558a7b756b3279723e68e", size = 435735, upload-time = "2025-08-14T21:21:15.046Z" }, + { url = "https://files.pythonhosted.org/packages/cc/5b/0d421533c59c789e9c9894683efac582c06246bf24bb26b753b149bd88e4/protobuf-6.32.0-cp39-abi3-macosx_10_9_universal2.whl", hash = "sha256:d52691e5bee6c860fff9a1c86ad26a13afbeb4b168cd4445c922b7e2cf85aaf0", size = 426449, upload-time = "2025-08-14T21:21:16.687Z" }, + { url = "https://files.pythonhosted.org/packages/ec/7b/607764ebe6c7a23dcee06e054fd1de3d5841b7648a90fd6def9a3bb58c5e/protobuf-6.32.0-cp39-abi3-manylinux2014_aarch64.whl", hash = "sha256:501fe6372fd1c8ea2a30b4d9be8f87955a64d6be9c88a973996cef5ef6f0abf1", size = 322869, upload-time = "2025-08-14T21:21:18.282Z" }, + { url = "https://files.pythonhosted.org/packages/40/01/2e730bd1c25392fc32e3268e02446f0d77cb51a2c3a8486b1798e34d5805/protobuf-6.32.0-cp39-abi3-manylinux2014_x86_64.whl", hash = "sha256:75a2aab2bd1aeb1f5dc7c5f33bcb11d82ea8c055c9becbb41c26a8c43fd7092c", size = 322009, upload-time = "2025-08-14T21:21:19.893Z" }, + { url = "https://files.pythonhosted.org/packages/9c/f2/80ffc4677aac1bc3519b26bc7f7f5de7fce0ee2f7e36e59e27d8beb32dd1/protobuf-6.32.0-py3-none-any.whl", hash = "sha256:ba377e5b67b908c8f3072a57b63e2c6a4cbd18aea4ed98d2584350dbf46f2783", size = 169287, upload-time = "2025-08-14T21:21:23.515Z" }, +] + +[[package]] +name = "pyasn1" +version = "0.6.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/ba/e9/01f1a64245b89f039897cb0130016d79f77d52669aae6ee7b159a6c4c018/pyasn1-0.6.1.tar.gz", hash = "sha256:6f580d2bdd84365380830acf45550f2511469f673cb4a5ae3857a3170128b034", size = 145322, upload-time = "2024-09-10T22:41:42.55Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/c8/f1/d6a797abb14f6283c0ddff96bbdd46937f64122b8c925cab503dd37f8214/pyasn1-0.6.1-py3-none-any.whl", hash = 
"sha256:0d632f46f2ba09143da3a8afe9e33fb6f92fa2320ab7e886e2d0f7672af84629", size = 83135, upload-time = "2024-09-11T16:00:36.122Z" }, +] + +[[package]] +name = "pyasn1-modules" +version = "0.4.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "pyasn1" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/e9/e6/78ebbb10a8c8e4b61a59249394a4a594c1a7af95593dc933a349c8d00964/pyasn1_modules-0.4.2.tar.gz", hash = "sha256:677091de870a80aae844b1ca6134f54652fa2c8c5a52aa396440ac3106e941e6", size = 307892, upload-time = "2025-03-28T02:41:22.17Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/47/8d/d529b5d697919ba8c11ad626e835d4039be708a35b0d22de83a269a6682c/pyasn1_modules-0.4.2-py3-none-any.whl", hash = "sha256:29253a9207ce32b64c3ac6600edc75368f98473906e8fd1043bd6b5b1de2c14a", size = 181259, upload-time = "2025-03-28T02:41:19.028Z" }, +] + +[[package]] +name = "pydantic" +version = "2.11.7" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "annotated-types" }, + { name = "pydantic-core" }, + { name = "typing-extensions" }, + { name = "typing-inspection" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/00/dd/4325abf92c39ba8623b5af936ddb36ffcfe0beae70405d456ab1fb2f5b8c/pydantic-2.11.7.tar.gz", hash = "sha256:d989c3c6cb79469287b1569f7447a17848c998458d49ebe294e975b9baf0f0db", size = 788350, upload-time = "2025-06-14T08:33:17.137Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/6a/c0/ec2b1c8712ca690e5d61979dee872603e92b8a32f94cc1b72d53beab008a/pydantic-2.11.7-py3-none-any.whl", hash = "sha256:dde5df002701f6de26248661f6835bbe296a47bf73990135c7d07ce741b9623b", size = 444782, upload-time = "2025-06-14T08:33:14.905Z" }, +] + +[[package]] +name = "pydantic-core" +version = "2.33.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "typing-extensions" }, +] +sdist = { url = 
"https://files.pythonhosted.org/packages/ad/88/5f2260bdfae97aabf98f1778d43f69574390ad787afb646292a638c923d4/pydantic_core-2.33.2.tar.gz", hash = "sha256:7cb8bc3605c29176e1b105350d2e6474142d7c1bd1d9327c4a9bdb46bf827acc", size = 435195, upload-time = "2025-04-23T18:33:52.104Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/18/8a/2b41c97f554ec8c71f2a8a5f85cb56a8b0956addfe8b0efb5b3d77e8bdc3/pydantic_core-2.33.2-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:a7ec89dc587667f22b6a0b6579c249fca9026ce7c333fc142ba42411fa243cdc", size = 2009000, upload-time = "2025-04-23T18:31:25.863Z" }, + { url = "https://files.pythonhosted.org/packages/a1/02/6224312aacb3c8ecbaa959897af57181fb6cf3a3d7917fd44d0f2917e6f2/pydantic_core-2.33.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:3c6db6e52c6d70aa0d00d45cdb9b40f0433b96380071ea80b09277dba021ddf7", size = 1847996, upload-time = "2025-04-23T18:31:27.341Z" }, + { url = "https://files.pythonhosted.org/packages/d6/46/6dcdf084a523dbe0a0be59d054734b86a981726f221f4562aed313dbcb49/pydantic_core-2.33.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4e61206137cbc65e6d5256e1166f88331d3b6238e082d9f74613b9b765fb9025", size = 1880957, upload-time = "2025-04-23T18:31:28.956Z" }, + { url = "https://files.pythonhosted.org/packages/ec/6b/1ec2c03837ac00886ba8160ce041ce4e325b41d06a034adbef11339ae422/pydantic_core-2.33.2-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:eb8c529b2819c37140eb51b914153063d27ed88e3bdc31b71198a198e921e011", size = 1964199, upload-time = "2025-04-23T18:31:31.025Z" }, + { url = "https://files.pythonhosted.org/packages/2d/1d/6bf34d6adb9debd9136bd197ca72642203ce9aaaa85cfcbfcf20f9696e83/pydantic_core-2.33.2-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:c52b02ad8b4e2cf14ca7b3d918f3eb0ee91e63b3167c32591e57c4317e134f8f", size = 2120296, upload-time = "2025-04-23T18:31:32.514Z" }, + { url = 
"https://files.pythonhosted.org/packages/e0/94/2bd0aaf5a591e974b32a9f7123f16637776c304471a0ab33cf263cf5591a/pydantic_core-2.33.2-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:96081f1605125ba0855dfda83f6f3df5ec90c61195421ba72223de35ccfb2f88", size = 2676109, upload-time = "2025-04-23T18:31:33.958Z" }, + { url = "https://files.pythonhosted.org/packages/f9/41/4b043778cf9c4285d59742281a769eac371b9e47e35f98ad321349cc5d61/pydantic_core-2.33.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8f57a69461af2a5fa6e6bbd7a5f60d3b7e6cebb687f55106933188e79ad155c1", size = 2002028, upload-time = "2025-04-23T18:31:39.095Z" }, + { url = "https://files.pythonhosted.org/packages/cb/d5/7bb781bf2748ce3d03af04d5c969fa1308880e1dca35a9bd94e1a96a922e/pydantic_core-2.33.2-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:572c7e6c8bb4774d2ac88929e3d1f12bc45714ae5ee6d9a788a9fb35e60bb04b", size = 2100044, upload-time = "2025-04-23T18:31:41.034Z" }, + { url = "https://files.pythonhosted.org/packages/fe/36/def5e53e1eb0ad896785702a5bbfd25eed546cdcf4087ad285021a90ed53/pydantic_core-2.33.2-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:db4b41f9bd95fbe5acd76d89920336ba96f03e149097365afe1cb092fceb89a1", size = 2058881, upload-time = "2025-04-23T18:31:42.757Z" }, + { url = "https://files.pythonhosted.org/packages/01/6c/57f8d70b2ee57fc3dc8b9610315949837fa8c11d86927b9bb044f8705419/pydantic_core-2.33.2-cp312-cp312-musllinux_1_1_armv7l.whl", hash = "sha256:fa854f5cf7e33842a892e5c73f45327760bc7bc516339fda888c75ae60edaeb6", size = 2227034, upload-time = "2025-04-23T18:31:44.304Z" }, + { url = "https://files.pythonhosted.org/packages/27/b9/9c17f0396a82b3d5cbea4c24d742083422639e7bb1d5bf600e12cb176a13/pydantic_core-2.33.2-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:5f483cfb75ff703095c59e365360cb73e00185e01aaea067cd19acffd2ab20ea", size = 2234187, upload-time = "2025-04-23T18:31:45.891Z" }, + { url = 
"https://files.pythonhosted.org/packages/b0/6a/adf5734ffd52bf86d865093ad70b2ce543415e0e356f6cacabbc0d9ad910/pydantic_core-2.33.2-cp312-cp312-win32.whl", hash = "sha256:9cb1da0f5a471435a7bc7e439b8a728e8b61e59784b2af70d7c169f8dd8ae290", size = 1892628, upload-time = "2025-04-23T18:31:47.819Z" }, + { url = "https://files.pythonhosted.org/packages/43/e4/5479fecb3606c1368d496a825d8411e126133c41224c1e7238be58b87d7e/pydantic_core-2.33.2-cp312-cp312-win_amd64.whl", hash = "sha256:f941635f2a3d96b2973e867144fde513665c87f13fe0e193c158ac51bfaaa7b2", size = 1955866, upload-time = "2025-04-23T18:31:49.635Z" }, + { url = "https://files.pythonhosted.org/packages/0d/24/8b11e8b3e2be9dd82df4b11408a67c61bb4dc4f8e11b5b0fc888b38118b5/pydantic_core-2.33.2-cp312-cp312-win_arm64.whl", hash = "sha256:cca3868ddfaccfbc4bfb1d608e2ccaaebe0ae628e1416aeb9c4d88c001bb45ab", size = 1888894, upload-time = "2025-04-23T18:31:51.609Z" }, + { url = "https://files.pythonhosted.org/packages/46/8c/99040727b41f56616573a28771b1bfa08a3d3fe74d3d513f01251f79f172/pydantic_core-2.33.2-cp313-cp313-macosx_10_12_x86_64.whl", hash = "sha256:1082dd3e2d7109ad8b7da48e1d4710c8d06c253cbc4a27c1cff4fbcaa97a9e3f", size = 2015688, upload-time = "2025-04-23T18:31:53.175Z" }, + { url = "https://files.pythonhosted.org/packages/3a/cc/5999d1eb705a6cefc31f0b4a90e9f7fc400539b1a1030529700cc1b51838/pydantic_core-2.33.2-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:f517ca031dfc037a9c07e748cefd8d96235088b83b4f4ba8939105d20fa1dcd6", size = 1844808, upload-time = "2025-04-23T18:31:54.79Z" }, + { url = "https://files.pythonhosted.org/packages/6f/5e/a0a7b8885c98889a18b6e376f344da1ef323d270b44edf8174d6bce4d622/pydantic_core-2.33.2-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0a9f2c9dd19656823cb8250b0724ee9c60a82f3cdf68a080979d13092a3b0fef", size = 1885580, upload-time = "2025-04-23T18:31:57.393Z" }, + { url = 
"https://files.pythonhosted.org/packages/3b/2a/953581f343c7d11a304581156618c3f592435523dd9d79865903272c256a/pydantic_core-2.33.2-cp313-cp313-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:2b0a451c263b01acebe51895bfb0e1cc842a5c666efe06cdf13846c7418caa9a", size = 1973859, upload-time = "2025-04-23T18:31:59.065Z" }, + { url = "https://files.pythonhosted.org/packages/e6/55/f1a813904771c03a3f97f676c62cca0c0a4138654107c1b61f19c644868b/pydantic_core-2.33.2-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:1ea40a64d23faa25e62a70ad163571c0b342b8bf66d5fa612ac0dec4f069d916", size = 2120810, upload-time = "2025-04-23T18:32:00.78Z" }, + { url = "https://files.pythonhosted.org/packages/aa/c3/053389835a996e18853ba107a63caae0b9deb4a276c6b472931ea9ae6e48/pydantic_core-2.33.2-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:0fb2d542b4d66f9470e8065c5469ec676978d625a8b7a363f07d9a501a9cb36a", size = 2676498, upload-time = "2025-04-23T18:32:02.418Z" }, + { url = "https://files.pythonhosted.org/packages/eb/3c/f4abd740877a35abade05e437245b192f9d0ffb48bbbbd708df33d3cda37/pydantic_core-2.33.2-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9fdac5d6ffa1b5a83bca06ffe7583f5576555e6c8b3a91fbd25ea7780f825f7d", size = 2000611, upload-time = "2025-04-23T18:32:04.152Z" }, + { url = "https://files.pythonhosted.org/packages/59/a7/63ef2fed1837d1121a894d0ce88439fe3e3b3e48c7543b2a4479eb99c2bd/pydantic_core-2.33.2-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:04a1a413977ab517154eebb2d326da71638271477d6ad87a769102f7c2488c56", size = 2107924, upload-time = "2025-04-23T18:32:06.129Z" }, + { url = "https://files.pythonhosted.org/packages/04/8f/2551964ef045669801675f1cfc3b0d74147f4901c3ffa42be2ddb1f0efc4/pydantic_core-2.33.2-cp313-cp313-musllinux_1_1_aarch64.whl", hash = "sha256:c8e7af2f4e0194c22b5b37205bfb293d166a7344a5b0d0eaccebc376546d77d5", size = 2063196, upload-time = 
"2025-04-23T18:32:08.178Z" }, + { url = "https://files.pythonhosted.org/packages/26/bd/d9602777e77fc6dbb0c7db9ad356e9a985825547dce5ad1d30ee04903918/pydantic_core-2.33.2-cp313-cp313-musllinux_1_1_armv7l.whl", hash = "sha256:5c92edd15cd58b3c2d34873597a1e20f13094f59cf88068adb18947df5455b4e", size = 2236389, upload-time = "2025-04-23T18:32:10.242Z" }, + { url = "https://files.pythonhosted.org/packages/42/db/0e950daa7e2230423ab342ae918a794964b053bec24ba8af013fc7c94846/pydantic_core-2.33.2-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:65132b7b4a1c0beded5e057324b7e16e10910c106d43675d9bd87d4f38dde162", size = 2239223, upload-time = "2025-04-23T18:32:12.382Z" }, + { url = "https://files.pythonhosted.org/packages/58/4d/4f937099c545a8a17eb52cb67fe0447fd9a373b348ccfa9a87f141eeb00f/pydantic_core-2.33.2-cp313-cp313-win32.whl", hash = "sha256:52fb90784e0a242bb96ec53f42196a17278855b0f31ac7c3cc6f5c1ec4811849", size = 1900473, upload-time = "2025-04-23T18:32:14.034Z" }, + { url = "https://files.pythonhosted.org/packages/a0/75/4a0a9bac998d78d889def5e4ef2b065acba8cae8c93696906c3a91f310ca/pydantic_core-2.33.2-cp313-cp313-win_amd64.whl", hash = "sha256:c083a3bdd5a93dfe480f1125926afcdbf2917ae714bdb80b36d34318b2bec5d9", size = 1955269, upload-time = "2025-04-23T18:32:15.783Z" }, + { url = "https://files.pythonhosted.org/packages/f9/86/1beda0576969592f1497b4ce8e7bc8cbdf614c352426271b1b10d5f0aa64/pydantic_core-2.33.2-cp313-cp313-win_arm64.whl", hash = "sha256:e80b087132752f6b3d714f041ccf74403799d3b23a72722ea2e6ba2e892555b9", size = 1893921, upload-time = "2025-04-23T18:32:18.473Z" }, + { url = "https://files.pythonhosted.org/packages/a4/7d/e09391c2eebeab681df2b74bfe6c43422fffede8dc74187b2b0bf6fd7571/pydantic_core-2.33.2-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:61c18fba8e5e9db3ab908620af374db0ac1baa69f0f32df4f61ae23f15e586ac", size = 1806162, upload-time = "2025-04-23T18:32:20.188Z" }, + { url = 
"https://files.pythonhosted.org/packages/f1/3d/847b6b1fed9f8ed3bb95a9ad04fbd0b212e832d4f0f50ff4d9ee5a9f15cf/pydantic_core-2.33.2-cp313-cp313t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:95237e53bb015f67b63c91af7518a62a8660376a6a0db19b89acc77a4d6199f5", size = 1981560, upload-time = "2025-04-23T18:32:22.354Z" }, + { url = "https://files.pythonhosted.org/packages/6f/9a/e73262f6c6656262b5fdd723ad90f518f579b7bc8622e43a942eec53c938/pydantic_core-2.33.2-cp313-cp313t-win_amd64.whl", hash = "sha256:c2fc0a768ef76c15ab9238afa6da7f69895bb5d1ee83aeea2e3509af4472d0b9", size = 1935777, upload-time = "2025-04-23T18:32:25.088Z" }, +] + +[[package]] +name = "requests" +version = "2.32.5" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "certifi" }, + { name = "charset-normalizer" }, + { name = "idna" }, + { name = "urllib3" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/c9/74/b3ff8e6c8446842c3f5c837e9c3dfcfe2018ea6ecef224c710c85ef728f4/requests-2.32.5.tar.gz", hash = "sha256:dbba0bac56e100853db0ea71b82b4dfd5fe2bf6d3754a8893c3af500cec7d7cf", size = 134517, upload-time = "2025-08-18T20:46:02.573Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/1e/db/4254e3eabe8020b458f1a747140d32277ec7a271daf1d235b70dc0b4e6e3/requests-2.32.5-py3-none-any.whl", hash = "sha256:2462f94637a34fd532264295e186976db0f5d453d1cdd31473c85a6a161affb6", size = 64738, upload-time = "2025-08-18T20:46:00.542Z" }, +] + +[[package]] +name = "rsa" +version = "4.9.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "pyasn1" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/da/8a/22b7beea3ee0d44b1916c0c1cb0ee3af23b700b6da9f04991899d0c555d4/rsa-4.9.1.tar.gz", hash = "sha256:e7bdbfdb5497da4c07dfd35530e1a902659db6ff241e39d9953cad06ebd0ae75", size = 29034, upload-time = "2025-04-16T09:51:18.218Z" } +wheels = [ + { url = 
"https://files.pythonhosted.org/packages/64/8d/0133e4eb4beed9e425d9a98ed6e081a55d195481b7632472be1af08d2f6b/rsa-4.9.1-py3-none-any.whl", hash = "sha256:68635866661c6836b8d39430f97a996acbd61bfa49406748ea243539fe239762", size = 34696, upload-time = "2025-04-16T09:51:17.142Z" }, +] + +[[package]] +name = "sniffio" +version = "1.3.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/a2/87/a6771e1546d97e7e041b6ae58d80074f81b7d5121207425c964ddf5cfdbd/sniffio-1.3.1.tar.gz", hash = "sha256:f4324edc670a0f49750a81b895f35c3adb843cca46f0530f79fc1babb23789dc", size = 20372, upload-time = "2024-02-25T23:20:04.057Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/e9/44/75a9c9421471a6c4805dbf2356f7c181a29c1879239abab1ea2cc8f38b40/sniffio-1.3.1-py3-none-any.whl", hash = "sha256:2f6da418d1f1e0fddd844478f41680e794e6051915791a034ff65e5f100525a2", size = 10235, upload-time = "2024-02-25T23:20:01.196Z" }, +] + +[[package]] +name = "sse-starlette" +version = "3.0.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "anyio" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/42/6f/22ed6e33f8a9e76ca0a412405f31abb844b779d52c5f96660766edcd737c/sse_starlette-3.0.2.tar.gz", hash = "sha256:ccd60b5765ebb3584d0de2d7a6e4f745672581de4f5005ab31c3a25d10b52b3a", size = 20985, upload-time = "2025-07-27T09:07:44.565Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/ef/10/c78f463b4ef22eef8491f218f692be838282cd65480f6e423d7730dfd1fb/sse_starlette-3.0.2-py3-none-any.whl", hash = "sha256:16b7cbfddbcd4eaca11f7b586f3b8a080f1afe952c15813455b162edea619e5a", size = 11297, upload-time = "2025-07-27T09:07:43.268Z" }, +] + +[[package]] +name = "starlette" +version = "0.47.3" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "anyio" }, + { name = "typing-extensions", marker = "python_full_version < '3.13'" }, +] +sdist = { url = 
"https://files.pythonhosted.org/packages/15/b9/cc3017f9a9c9b6e27c5106cc10cc7904653c3eec0729793aec10479dd669/starlette-0.47.3.tar.gz", hash = "sha256:6bc94f839cc176c4858894f1f8908f0ab79dfec1a6b8402f6da9be26ebea52e9", size = 2584144, upload-time = "2025-08-24T13:36:42.122Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/ce/fd/901cfa59aaa5b30a99e16876f11abe38b59a1a2c51ffb3d7142bb6089069/starlette-0.47.3-py3-none-any.whl", hash = "sha256:89c0778ca62a76b826101e7c709e70680a1699ca7da6b44d38eb0a7e61fe4b51", size = 72991, upload-time = "2025-08-24T13:36:40.887Z" }, +] + +[[package]] +name = "test-python-a2a-server" +version = "0.1.0" +source = { virtual = "." } +dependencies = [ + { name = "a2a-sdk", extra = ["http-server"] }, + { name = "uvicorn" }, +] + +[package.metadata] +requires-dist = [ + { name = "a2a-sdk", extras = ["http-server"], specifier = "==0.3.5" }, + { name = "uvicorn", specifier = "==0.35.0" }, +] + +[[package]] +name = "typing-extensions" +version = "4.15.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/72/94/1a15dd82efb362ac84269196e94cf00f187f7ed21c242792a923cdb1c61f/typing_extensions-4.15.0.tar.gz", hash = "sha256:0cea48d173cc12fa28ecabc3b837ea3cf6f38c6d1136f85cbaaf598984861466", size = 109391, upload-time = "2025-08-25T13:49:26.313Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/18/67/36e9267722cc04a6b9f15c7f3441c2363321a3ea07da7ae0c0707beb2a9c/typing_extensions-4.15.0-py3-none-any.whl", hash = "sha256:f0fa19c6845758ab08074a0cfa8b7aecb71c999ca73d62883bc25cc018c4e548", size = 44614, upload-time = "2025-08-25T13:49:24.86Z" }, +] + +[[package]] +name = "typing-inspection" +version = "0.4.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/f8/b1/0c11f5058406b3af7609f121aaa6b609744687f1d158b3c3a5bf4cc94238/typing_inspection-0.4.1.tar.gz", hash = 
"sha256:6ae134cc0203c33377d43188d4064e9b357dba58cff3185f22924610e70a9d28", size = 75726, upload-time = "2025-05-21T18:55:23.885Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/17/69/cd203477f944c353c31bade965f880aa1061fd6bf05ded0726ca845b6ff7/typing_inspection-0.4.1-py3-none-any.whl", hash = "sha256:389055682238f53b04f7badcb49b989835495a96700ced5dab2d8feae4b26f51", size = 14552, upload-time = "2025-05-21T18:55:22.152Z" }, +] + +[[package]] +name = "urllib3" +version = "2.5.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/15/22/9ee70a2574a4f4599c47dd506532914ce044817c7752a79b6a51286319bc/urllib3-2.5.0.tar.gz", hash = "sha256:3fc47733c7e419d4bc3f6b3dc2b4f890bb743906a30d56ba4a5bfa4bbff92760", size = 393185, upload-time = "2025-06-18T14:07:41.644Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/a7/c2/fe1e52489ae3122415c51f387e221dd0773709bad6c6cdaa599e8a2c5185/urllib3-2.5.0-py3-none-any.whl", hash = "sha256:e6b01673c0fa6a13e374b50871808eb3bf7046c4b125b216f6bf1cc604cff0dc", size = 129795, upload-time = "2025-06-18T14:07:40.39Z" }, +] + +[[package]] +name = "uvicorn" +version = "0.35.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "click" }, + { name = "h11" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/5e/42/e0e305207bb88c6b8d3061399c6a961ffe5fbb7e2aa63c9234df7259e9cd/uvicorn-0.35.0.tar.gz", hash = "sha256:bc662f087f7cf2ce11a1d7fd70b90c9f98ef2e2831556dd078d131b96cc94a01", size = 78473, upload-time = "2025-06-28T16:15:46.058Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/d2/e2/dc81b1bd1dcfe91735810265e9d26bc8ec5da45b4c0f6237e286819194c3/uvicorn-0.35.0-py3-none-any.whl", hash = "sha256:197535216b25ff9b785e29a0b79199f55222193d47f820816e7da751e9bc8d4a", size = 66406, upload-time = "2025-06-28T16:15:44.816Z" }, +] diff --git a/a2a/test-tck/.gitignore b/a2a/test-tck/.gitignore new file mode 100644 index 
0000000000..b44f8ddb8c --- /dev/null +++ b/a2a/test-tck/.gitignore @@ -0,0 +1,2 @@ +# A2A Testing Kit, should be cloned locally by using setup_tck.sh +/a2a-tck diff --git a/a2a/test-tck/README.md b/a2a/test-tck/README.md new file mode 100644 index 0000000000..b0b863c4e8 --- /dev/null +++ b/a2a/test-tck/README.md @@ -0,0 +1,33 @@ +# A2A Testing Kit Integration + +This directory contains tooling to validate the A2A Kotlin SDK against the +official [A2A protocol specification](https://a2a-protocol.org/latest/specification/) using the A2A Testing Kit (TCK). + +## Contents + +- **`a2a-test-server-tck/`**: Sample A2A server implementation built with Koog SDK for TCK validation +- **`a2a-tck/`**: Official A2A Testing Kit repository (gitignored, should be cloned by `setup_tck.sh`) +- **`setup_tck.sh`**: Clone and setup the A2A Testing Kit +- **`run_sut.sh`**: Run the Kotlin test server (System Under Test) +- **`run_tck.sh`**: Execute TCK tests against the running server + +## Quick start + +1. **Setup the Testing Kit:** + ```bash + ./setup_tck.sh + ``` + +2. **Run the test server:** + ```bash + ./run_sut.sh + ``` + +3. **In another terminal, run the TCK tests:** + ```bash + ./run_tck.sh --sut-url http://localhost:9999/a2a --category all --report + ``` + +## More information + +For more information, see the [A2A Testing Kit repo](https://a2a-protocol.org/latest/tck/). diff --git a/a2a/test-tck/a2a-test-server-tck/Module.md b/a2a/test-tck/a2a-test-server-tck/Module.md new file mode 100644 index 0000000000..dc73dc300d --- /dev/null +++ b/a2a/test-tck/a2a-test-server-tck/Module.md @@ -0,0 +1,3 @@ +# Module a2a-test-server-tck + +Technology Compatibility Kit (TCK) for A2A server implementations. Provides a comprehensive test suite to verify compliance with the A2A protocol specification, ensuring interoperability across different implementations. 
diff --git a/a2a/test-tck/a2a-test-server-tck/build.gradle.kts b/a2a/test-tck/a2a-test-server-tck/build.gradle.kts new file mode 100644 index 0000000000..bd974a53e9 --- /dev/null +++ b/a2a/test-tck/a2a-test-server-tck/build.gradle.kts @@ -0,0 +1,22 @@ +plugins { + id("ai.kotlin.jvm") + alias(libs.plugins.kotlin.serialization) + application +} + +application { + mainClass = "ai.koog.a2a.test.tck.MainKt" +} + +dependencies { + implementation(project(":a2a:a2a-server")) + implementation(project(":a2a:a2a-client")) + implementation(project(":a2a:a2a-transport:a2a-transport-client-jsonrpc-http")) + implementation(project(":a2a:a2a-transport:a2a-transport-server-jsonrpc-http")) + + implementation(libs.ktor.client.cio) + implementation(libs.ktor.client.logging) + implementation(libs.ktor.server.netty) + implementation(libs.oshai.kotlin.logging) + runtimeOnly(libs.logback.classic) +} diff --git a/a2a/test-tck/a2a-test-server-tck/src/main/kotlin/ai/koog/a2a/test/tck/Main.kt b/a2a/test-tck/a2a-test-server-tck/src/main/kotlin/ai/koog/a2a/test/tck/Main.kt new file mode 100644 index 0000000000..df4cbb42d5 --- /dev/null +++ b/a2a/test-tck/a2a-test-server-tck/src/main/kotlin/ai/koog/a2a/test/tck/Main.kt @@ -0,0 +1,114 @@ +package ai.koog.a2a.test.tck + +import ai.koog.a2a.consts.A2AConsts +import ai.koog.a2a.model.AgentCapabilities +import ai.koog.a2a.model.AgentCard +import ai.koog.a2a.model.AgentInterface +import ai.koog.a2a.model.AgentSkill +import ai.koog.a2a.model.AuthorizationCodeOAuthFlow +import ai.koog.a2a.model.HTTPAuthSecurityScheme +import ai.koog.a2a.model.OAuth2SecurityScheme +import ai.koog.a2a.model.OAuthFlows +import ai.koog.a2a.model.TransportProtocol +import ai.koog.a2a.server.A2AServer +import ai.koog.a2a.server.notifications.InMemoryPushNotificationConfigStorage +import ai.koog.a2a.transport.server.jsonrpc.http.HttpJSONRPCServerTransport +import io.github.oshai.kotlinlogging.KotlinLogging +import io.ktor.server.netty.Netty + +private val logger = 
KotlinLogging.logger {} + +suspend fun main() { + logger.info { "Starting TCK A2A Agent on http://localhost:9999" } + + // Define security schemes + val httpBearerScheme = HTTPAuthSecurityScheme( + scheme = "bearer", + description = "HTTP Bearer token authentication" + ) + + val oauth2Scheme = OAuth2SecurityScheme( + flows = OAuthFlows( + authorizationCode = AuthorizationCodeOAuthFlow( + authorizationUrl = "https://auth.example.com/oauth/authorize", + tokenUrl = "https://auth.example.com/oauth/token", + scopes = mapOf( + "read" to "Read access", + "write" to "Write access" + ) + ) + ), + description = "OAuth 2.0 authentication" + ) + + val securitySchemes = mapOf( + "bearerAuth" to httpBearerScheme, + "oauth2" to oauth2Scheme + ) + + // Create agent card with capabilities and security + val agentCard = AgentCard( + protocolVersion = "0.3.0", + name = "TCK A2A Agent", + description = "A complete A2A agent implementation designed specifically for testing with the A2A Technology Compatibility Kit (TCK)", + version = "1.0.0", + url = "http://localhost:9999/a2a", + preferredTransport = TransportProtocol.JSONRPC, + additionalInterfaces = listOf( + AgentInterface( + url = "http://localhost:9999/a2a", + transport = TransportProtocol.JSONRPC, + ) + ), + capabilities = AgentCapabilities( + streaming = true, + ), + defaultInputModes = listOf("text"), + defaultOutputModes = listOf("text"), + skills = listOf( + AgentSkill( + id = "tck_agent", + name = "TCK Agent", + description = "A complete A2A agent implementation designed for TCK testing", + examples = listOf("hi", "hello world", "how are you", "goodbye"), + tags = listOf("tck", "testing", "core", "complete") + ) + ), + securitySchemes = securitySchemes, + security = listOf( + mapOf("bearerAuth" to emptyList()), + mapOf("oauth2" to listOf("read", "write")) + ), + supportsAuthenticatedExtendedCard = false + ) + + // Create extended agent card (same as basic for testing purposes) + val agentCardExtended = agentCard.copy( + 
name = "TCK A2A Agent - Extended Edition", + description = "The full-featured A2A agent for authenticated users." + ) + + // Create agent executor + val agentExecutor = TckAgentExecutor() + + // Create A2A server + val a2aServer = A2AServer( + agentExecutor = agentExecutor, + agentCard = agentCard, + agentCardExtended = agentCardExtended, + pushConfigStorage = InMemoryPushNotificationConfigStorage() + ) + + // Create and start server transport + val serverTransport = HttpJSONRPCServerTransport(a2aServer) + + logger.info { "Authentication tests will document SDK gaps with expected failures" } + serverTransport.start( + engineFactory = Netty, + port = 9999, + path = "/a2a", + wait = true, // Block until server stops + agentCard = agentCard, + agentCardPath = A2AConsts.AGENT_CARD_WELL_KNOWN_PATH + ) +} diff --git a/a2a/test-tck/a2a-test-server-tck/src/main/kotlin/ai/koog/a2a/test/tck/TckAgentExecutor.kt b/a2a/test-tck/a2a-test-server-tck/src/main/kotlin/ai/koog/a2a/test/tck/TckAgentExecutor.kt new file mode 100644 index 0000000000..9dd6398680 --- /dev/null +++ b/a2a/test-tck/a2a-test-server-tck/src/main/kotlin/ai/koog/a2a/test/tck/TckAgentExecutor.kt @@ -0,0 +1,229 @@ +package ai.koog.a2a.test.tck + +import ai.koog.a2a.model.Artifact +import ai.koog.a2a.model.Message +import ai.koog.a2a.model.MessageSendParams +import ai.koog.a2a.model.Role +import ai.koog.a2a.model.Task +import ai.koog.a2a.model.TaskArtifactUpdateEvent +import ai.koog.a2a.model.TaskIdParams +import ai.koog.a2a.model.TaskState +import ai.koog.a2a.model.TaskStatus +import ai.koog.a2a.model.TaskStatusUpdateEvent +import ai.koog.a2a.model.TextPart +import ai.koog.a2a.server.agent.AgentExecutor +import ai.koog.a2a.server.session.RequestContext +import ai.koog.a2a.server.session.SessionEventProcessor +import io.ktor.utils.io.CancellationException +import kotlinx.coroutines.Deferred +import kotlinx.coroutines.cancelAndJoin +import kotlinx.coroutines.delay +import kotlinx.datetime.Clock +import 
kotlin.uuid.ExperimentalUuidApi +import kotlin.uuid.Uuid + +@OptIn(ExperimentalUuidApi::class) +class TckAgentExecutor : AgentExecutor { + override suspend fun execute( + context: RequestContext, + eventProcessor: SessionEventProcessor + ) { + val userMessage = context.params.message + val userInput = userMessage.parts.filterIsInstance() + .joinToString(" ") { it.text } + .lowercase() + + if (userInput.isBlank()) { + eventProcessor.sendMessage( + Message( + messageId = Uuid.random().toString(), + role = Role.Agent, + parts = listOf(TextPart("Hello! Please provide a message for me to respond to.")), + contextId = context.contextId + ) + ) + return + } + + val taskId = context.taskId + + try { + if (context.task == null) { + processNewTask(context, eventProcessor, userMessage, userInput) + } else { + processExistingTask(context, eventProcessor, userMessage, userInput) + } + } catch (e: CancellationException) { + // Propagate cancellation exception + throw e + } catch (e: Exception) { + // Handle errors by marking task as failed + val errorMessage = "Error processing request: ${e.message}" + val failedStatus = TaskStatusUpdateEvent( + contextId = context.contextId, + taskId = taskId, + status = TaskStatus( + state = TaskState.Failed, + message = Message( + messageId = Uuid.random().toString(), + role = Role.Agent, + parts = listOf(TextPart(errorMessage)), + contextId = context.contextId, + taskId = taskId + ), + timestamp = Clock.System.now() + ), + final = true + ) + eventProcessor.sendTaskEvent(failedStatus) + } + } + + override suspend fun cancel( + context: RequestContext, + eventProcessor: SessionEventProcessor, + agentJob: Deferred? 
+ ) { + val taskId = context.taskId + // Cancel the coroutine job if provided + agentJob?.cancelAndJoin() + + // Send cancellation event + val canceledStatus = TaskStatusUpdateEvent( + contextId = context.contextId, + taskId = taskId, + status = TaskStatus( + state = TaskState.Canceled, + timestamp = Clock.System.now() + ), + final = true + ) + eventProcessor.sendTaskEvent(canceledStatus) + } + + private suspend fun processNewTask( + context: RequestContext, + eventProcessor: SessionEventProcessor, + userMessage: Message, + userInput: String, + ) { + val task = Task( + id = context.taskId, + contextId = context.contextId, + status = TaskStatus( + state = TaskState.Submitted, + timestamp = Clock.System.now() + ), + history = listOf(userMessage) + ) + + // Send initial task event + eventProcessor.sendTaskEvent(task) + + // Short delay to allow tests to see submitted state + delay(200) + + // Update to working state + val workingStatus = TaskStatusUpdateEvent( + contextId = context.contextId, + taskId = task.id, + status = TaskStatus( + state = TaskState.Working, + timestamp = Clock.System.now() + ), + final = false + ) + eventProcessor.sendTaskEvent(workingStatus) + + // Short delay for working state + delay(200) + + // Process the request + val result = processUserInput(userInput) + delay(200) // Brief processing delay + + // Create artifact with result + val artifact = Artifact( + artifactId = "response", + parts = listOf(TextPart(result)), + description = "Agent response to user message." 
+ ) + + val artifactEvent = TaskArtifactUpdateEvent( + taskId = task.id, + contextId = context.contextId, + artifact = artifact, + append = false + ) + eventProcessor.sendTaskEvent(artifactEvent) + + // Mark task as completed + val completedStatus = TaskStatusUpdateEvent( + contextId = context.contextId, + taskId = task.id, + status = TaskStatus( + state = TaskState.InputRequired, + message = Message( + messageId = Uuid.random().toString(), + role = Role.Agent, + parts = listOf(TextPart(result)), + contextId = context.contextId, + taskId = task.id + ), + timestamp = Clock.System.now() + ), + final = true + ) + eventProcessor.sendTaskEvent(completedStatus) + } + + private suspend fun processExistingTask( + context: RequestContext, + eventProcessor: SessionEventProcessor, + userMessage: Message, + userInput: String, + ) { + eventProcessor.sendTaskEvent( + TaskStatusUpdateEvent( + taskId = context.taskId, + contextId = context.contextId, + status = TaskStatus( + state = TaskState.Working, + timestamp = Clock.System.now(), + message = userMessage + ), + final = false + ) + ) + + delay(100) + + eventProcessor.sendTaskEvent( + TaskStatusUpdateEvent( + taskId = context.taskId, + contextId = context.contextId, + status = TaskStatus( + state = TaskState.Completed, + timestamp = Clock.System.now(), + message = Message( + messageId = Uuid.random().toString(), + role = Role.Agent, + parts = listOf(TextPart("Task completed successfully!")), + contextId = context.contextId, + taskId = context.taskId + ) + ), + final = true + ) + ) + } + + private fun processUserInput(input: String): String { + return when { + "hello" in input || "hi" in input -> "Hello World! Nice to meet you!" + "how are you" in input -> "I'm doing great! Thanks for asking. How can I help you today?" + "goodbye" in input || "bye" in input -> "Goodbye! Have a wonderful day!" + else -> "Hello World! You said: '$input'. Thanks for your message!" 
+ } + } +} diff --git a/a2a/test-tck/a2a-test-server-tck/src/main/resources/logback.xml b/a2a/test-tck/a2a-test-server-tck/src/main/resources/logback.xml new file mode 100644 index 0000000000..24a99c370b --- /dev/null +++ b/a2a/test-tck/a2a-test-server-tck/src/main/resources/logback.xml @@ -0,0 +1,11 @@ + + + + %d{HH:mm:ss.SSS} [%thread] %-5level %logger{36} - %msg%n + + + + + + + diff --git a/a2a/test-tck/run_sut.sh b/a2a/test-tck/run_sut.sh new file mode 100755 index 0000000000..86701ec414 --- /dev/null +++ b/a2a/test-tck/run_sut.sh @@ -0,0 +1,11 @@ +#!/usr/bin/env bash +set -euo pipefail + +# Always operate relative to this script's directory +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +ROOT_DIR="$(cd "${SCRIPT_DIR}/../.." && pwd)" + +cd "${ROOT_DIR}" + +# Run the SUT via Gradle; delegate any extra CLI args to Gradle +exec ./gradlew :a2a:test-tck:a2a-test-server-tck:run "$@" diff --git a/a2a/test-tck/run_tck.sh b/a2a/test-tck/run_tck.sh new file mode 100755 index 0000000000..e860921c9f --- /dev/null +++ b/a2a/test-tck/run_tck.sh @@ -0,0 +1,17 @@ +#!/usr/bin/env bash +set -euo pipefail + +# Always operate relative to this script's directory +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +REPO_DIR="${SCRIPT_DIR}/a2a-tck" + +if [ ! -d "${REPO_DIR}" ]; then + echo "[run_tck] Expected directory not found: ${REPO_DIR}" + echo "[run_tck] Please run setup_tck.sh first to clone and prepare A2A testing kit project." 
+ exit 1 +fi + +cd "${REPO_DIR}" + +# Delegate all CLI parameters to the underlying script +exec uv run ./run_tck.py "$@" diff --git a/a2a/test-tck/setup_tck.sh b/a2a/test-tck/setup_tck.sh new file mode 100755 index 0000000000..611b1c30e2 --- /dev/null +++ b/a2a/test-tck/setup_tck.sh @@ -0,0 +1,26 @@ +#!/usr/bin/env bash + +set -euo pipefail + +# Resolve the directory where this script resides to always operate relative to it +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +REPO_DIR="${SCRIPT_DIR}/a2a-tck" +REPO_URL="https://github.com/a2aproject/a2a-tck.git" + +# 1) Check if a2a-tck already exists +if [ -d "${REPO_DIR}" ]; then + echo "[setup_tck] 'a2a-tck' directory already exists at: ${REPO_DIR}" + echo "[setup_tck] Will skip cloning but still run 'uv sync'." +else + echo "[setup_tck] 'a2a-tck' directory not found in ${SCRIPT_DIR}." + echo "[setup_tck] Cloning repository into: ${REPO_DIR}" + git clone "${REPO_URL}" "${REPO_DIR}" --depth=1 +fi + +# 2) Always run uv sync in the repo directory +echo "[setup_tck] Running 'uv sync' in ${REPO_DIR}..." +( + cd "${REPO_DIR}" + uv sync --all-packages --all-extras +) +echo "[setup_tck] Done." diff --git a/agents/agents-core/Module.md b/agents/agents-core/Module.md index e62d76d1e8..57441a7593 100644 --- a/agents/agents-core/Module.md +++ b/agents/agents-core/Module.md @@ -78,3 +78,42 @@ val agent = AIAgent( // Run the agent val result = agent.execute("Calculate the square root of 16") ``` + + +### Standard Feature Events + +Features in the Koog ecosystem consume standardized Feature Events emitted by agents-core during agent execution. These events are defined in this module under the package `ai.koog.agents.core.feature.model.events`. 
+ +- Agent events: + - AgentStartingEvent + - AgentCompletedEvent + - AgentExecutionFailedEvent + - AgentClosingEvent + +- Strategy events: + - GraphStrategyStartingEvent + - FunctionalStrategyStartingEvent + - StrategyCompletedEvent + +- Node execution events: + - NodeExecutionStartingEvent + - NodeExecutionCompletedEvent + - NodeExecutionFailedEvent + +- LLM call events: + - LLMCallStartingEvent + - LLMCallCompletedEvent + +- LLM streaming events: + - LLMStreamingStartingEvent + - LLMStreamingFrameReceivedEvent + - LLMStreamingFailedEvent + - LLMStreamingCompletedEvent + +- Tool execution events: + - ToolExecutionStartingEvent + - ToolValidationFailedEvent + - ToolExecutionFailedEvent + - ToolExecutionCompletedEvent + +These events are emitted by the agents-core runtime and consumed by features such as Tracing, Debugger, and EventHandler to enable logging, tracing, monitoring, and remote inspection. diff --git a/agents/agents-core/src/commonMain/kotlin/ai/koog/agents/core/agent/context/AIAgentLLMContext.kt b/agents/agents-core/src/commonMain/kotlin/ai/koog/agents/core/agent/context/AIAgentLLMContext.kt index 3215e57931..07f3501c56 100644 --- a/agents/agents-core/src/commonMain/kotlin/ai/koog/agents/core/agent/context/AIAgentLLMContext.kt +++ b/agents/agents-core/src/commonMain/kotlin/ai/koog/agents/core/agent/context/AIAgentLLMContext.kt @@ -50,7 +50,7 @@ public annotation class DetachedPromptExecutorAPI */ public class AIAgentLLMContext( tools: List, - public val toolRegistry: ToolRegistry = ToolRegistry.Companion.EMPTY, + public val toolRegistry: ToolRegistry = ToolRegistry.EMPTY, prompt: Prompt, model: LLModel, @property:DetachedPromptExecutorAPI diff --git a/agents/agents-core/src/commonMain/kotlin/ai/koog/agents/core/agent/entity/AIAgentNode.kt b/agents/agents-core/src/commonMain/kotlin/ai/koog/agents/core/agent/entity/AIAgentNode.kt index abd36e7802..311f4f1e85 100644 --- 
a/agents/agents-core/src/commonMain/kotlin/ai/koog/agents/core/agent/entity/AIAgentNode.kt +++ b/agents/agents-core/src/commonMain/kotlin/ai/koog/agents/core/agent/entity/AIAgentNode.kt @@ -119,7 +119,6 @@ public abstract class AIAgentNodeBase internal constructor() { /** * Executes the node operation using the provided execution context and input, bypassing type safety checks. * This method internally calls the type-safe `execute` method after casting the input. - * The lifecycle hooks `onBeforeNode` and `onAfterNode` are invoked before and after the execution respectively. * * @param context The execution context that provides runtime information and functionality. * @param input The input data to be processed by the node, which may be of any type. diff --git a/agents/agents-core/src/commonMain/kotlin/ai/koog/agents/core/agent/session/AIAgentLLMSession.kt b/agents/agents-core/src/commonMain/kotlin/ai/koog/agents/core/agent/session/AIAgentLLMSession.kt index 296b2df3e6..c86535b79e 100644 --- a/agents/agents-core/src/commonMain/kotlin/ai/koog/agents/core/agent/session/AIAgentLLMSession.kt +++ b/agents/agents-core/src/commonMain/kotlin/ai/koog/agents/core/agent/session/AIAgentLLMSession.kt @@ -15,6 +15,7 @@ import ai.koog.prompt.structure.StructureFixingParser import ai.koog.prompt.structure.StructuredOutputConfig import ai.koog.prompt.structure.StructuredResponse import ai.koog.prompt.structure.executeStructured +import ai.koog.prompt.structure.parseResponseToStructuredResponse import kotlinx.serialization.KSerializer import kotlinx.serialization.serializer @@ -312,6 +313,24 @@ public sealed class AIAgentLLMSession( fixingParser = fixingParser, ) + /** + * Parses a structured response from the language model using the specified configuration. 
+ * + * This function takes a response message and a structured output configuration, + * parses the response content based on the defined structure, and returns + * a structured response containing the parsed data and the original message. + * + * @param response The response message from the language model that contains the content to be parsed. + * The message is expected to match the defined structured output. + * @param config The configuration defining the expected structure and additional parsing behavior. + * It includes options such as structure definitions and optional parsers for error handling. + * @return A structured response containing the parsed data of type `T` along with the original message. + */ + public suspend fun parseResponseToStructuredResponse( + response: Message.Assistant, + config: StructuredOutputConfig + ): StructuredResponse = executor.parseResponseToStructuredResponse(response, config, model) + /** * Sends a request to the language model, potentially receiving multiple choices, * and returns a list of choices from the model. diff --git a/agents/agents-core/src/commonMain/kotlin/ai/koog/agents/core/dsl/extension/AIAgentNodes.kt b/agents/agents-core/src/commonMain/kotlin/ai/koog/agents/core/dsl/extension/AIAgentNodes.kt index 89ff0437f4..4bc9bf6561 100644 --- a/agents/agents-core/src/commonMain/kotlin/ai/koog/agents/core/dsl/extension/AIAgentNodes.kt +++ b/agents/agents-core/src/commonMain/kotlin/ai/koog/agents/core/dsl/extension/AIAgentNodes.kt @@ -507,3 +507,26 @@ public inline fun AIAgentSubgraphBuilderBase< toolResult } } + +/** + * Creates a node that sets up a structured output for an AI agent subgraph. + * + * The method defines a new node with a configurable structured output schema + * that will be applied during the AI agent's message processing. The schema + * is determined by the given configuration. + * + * @param name An optional name for the node. If null, a default name will be assigned. 
+ * @param config The configuration that defines the structured output format and schema. + * @return An instance of [AIAgentNodeDelegate] representing the constructed node. + */ +@AIAgentBuilderDslMarker +public inline fun AIAgentSubgraphBuilderBase<*, *>.nodeSetStructuredOutput( + name: String? = null, + config: StructuredOutputConfig +): AIAgentNodeDelegate = + node(name) { message -> + llm.writeSession { + prompt = config.updatePrompt(model, prompt) + message + } + } diff --git a/agents/agents-core/src/commonMain/kotlin/ai/koog/agents/core/environment/GenericAgentEnvironment.kt b/agents/agents-core/src/commonMain/kotlin/ai/koog/agents/core/environment/GenericAgentEnvironment.kt index 2e33a990b9..35c70d1317 100644 --- a/agents/agents-core/src/commonMain/kotlin/ai/koog/agents/core/environment/GenericAgentEnvironment.kt +++ b/agents/agents-core/src/commonMain/kotlin/ai/koog/agents/core/environment/GenericAgentEnvironment.kt @@ -117,7 +117,7 @@ internal class GenericAgentEnvironment( ) } - pipeline.onToolExecutionStarting(content.runId, content.toolCallId, tool, toolArgs) + pipeline.onToolCallStarting(content.runId, content.toolCallId, tool, toolArgs) val toolResult = try { @Suppress("UNCHECKED_CAST") @@ -135,7 +135,7 @@ internal class GenericAgentEnvironment( } catch (e: Exception) { logger.error(e) { "Tool \"${tool.name}\" failed to execute with arguments: ${content.toolArgs}" } - pipeline.onToolExecutionFailed(content.runId, content.toolCallId, tool, toolArgs, e) + pipeline.onToolCallFailed(content.runId, content.toolCallId, tool, toolArgs, e) return toolResult( message = "Tool \"${tool.name}\" failed to execute because of ${e.message}!", @@ -146,7 +146,7 @@ internal class GenericAgentEnvironment( ) } - pipeline.onToolExecutionCompleted(content.runId, content.toolCallId, tool, toolArgs, toolResult) + pipeline.onToolCallCompleted(content.runId, content.toolCallId, tool, toolArgs, toolResult) logger.trace { "Completed execution of ${content.toolName} with result: 
$toolResult" } diff --git a/agents/agents-core/src/commonMain/kotlin/ai/koog/agents/core/feature/AIAgentPipeline.kt b/agents/agents-core/src/commonMain/kotlin/ai/koog/agents/core/feature/AIAgentPipeline.kt index 2592af660e..e12770e18c 100644 --- a/agents/agents-core/src/commonMain/kotlin/ai/koog/agents/core/feature/AIAgentPipeline.kt +++ b/agents/agents-core/src/commonMain/kotlin/ai/koog/agents/core/feature/AIAgentPipeline.kt @@ -41,13 +41,13 @@ import ai.koog.agents.core.feature.handler.streaming.LLMStreamingFrameReceivedCo import ai.koog.agents.core.feature.handler.streaming.LLMStreamingFrameReceivedHandler import ai.koog.agents.core.feature.handler.streaming.LLMStreamingStartingContext import ai.koog.agents.core.feature.handler.streaming.LLMStreamingStartingHandler +import ai.koog.agents.core.feature.handler.tool.ToolCallCompletedContext +import ai.koog.agents.core.feature.handler.tool.ToolCallEventHandler +import ai.koog.agents.core.feature.handler.tool.ToolCallFailedContext import ai.koog.agents.core.feature.handler.tool.ToolCallFailureHandler import ai.koog.agents.core.feature.handler.tool.ToolCallHandler import ai.koog.agents.core.feature.handler.tool.ToolCallResultHandler -import ai.koog.agents.core.feature.handler.tool.ToolExecutionCompletedContext -import ai.koog.agents.core.feature.handler.tool.ToolExecutionEventHandler -import ai.koog.agents.core.feature.handler.tool.ToolExecutionFailedContext -import ai.koog.agents.core.feature.handler.tool.ToolExecutionStartingContext +import ai.koog.agents.core.feature.handler.tool.ToolCallStartingContext import ai.koog.agents.core.feature.handler.tool.ToolValidationErrorHandler import ai.koog.agents.core.feature.handler.tool.ToolValidationFailedContext import ai.koog.agents.core.tools.Tool @@ -123,7 +123,7 @@ public abstract class AIAgentPipeline(public val clock: Clock) { * Map of tool execution handlers registered for different features. * Keys are feature storage keys, values are tool execution handlers. 
*/ - protected val toolExecutionEventHandlers: MutableMap, ToolExecutionEventHandler> = mutableMapOf() + protected val toolCallEventHandlers: MutableMap, ToolCallEventHandler> = mutableMapOf() /** * Map of LLM execution handlers registered for different features. @@ -313,7 +313,8 @@ public abstract class AIAgentPipeline(public val clock: Clock) { strategy = strategy, feature = handler.feature, result = result, - resultType = resultType + resultType = resultType, + agentId = context.agentId ) handler.handleStrategyCompletedUnsafe(eventContext) } @@ -367,9 +368,9 @@ public abstract class AIAgentPipeline(public val clock: Clock) { * @param tool The tool that is being called * @param toolArgs The arguments provided to the tool */ - public suspend fun onToolExecutionStarting(runId: String, toolCallId: String?, tool: Tool<*, *>, toolArgs: Any?) { - val eventContext = ToolExecutionStartingContext(runId, toolCallId, tool, toolArgs) - toolExecutionEventHandlers.values.forEach { handler -> handler.toolCallHandler.handle(eventContext) } + public suspend fun onToolCallStarting(runId: String, toolCallId: String?, tool: Tool<*, *>, toolArgs: Any?) 
{ + val eventContext = ToolCallStartingContext(runId, toolCallId, tool, toolArgs) + toolCallEventHandlers.values.forEach { handler -> handler.toolCallHandler.handle(eventContext) } } /** @@ -389,7 +390,7 @@ public abstract class AIAgentPipeline(public val clock: Clock) { ) { val eventContext = ToolValidationFailedContext(runId, toolCallId, tool, toolArgs, error) - toolExecutionEventHandlers.values.forEach { handler -> handler.toolValidationErrorHandler.handle(eventContext) } + toolCallEventHandlers.values.forEach { handler -> handler.toolValidationErrorHandler.handle(eventContext) } } /** @@ -400,15 +401,15 @@ public abstract class AIAgentPipeline(public val clock: Clock) { * @param toolArgs The arguments provided to the tool * @param throwable The exception that caused the failure */ - public suspend fun onToolExecutionFailed( + public suspend fun onToolCallFailed( runId: String, toolCallId: String?, tool: Tool<*, *>, toolArgs: Any?, throwable: Throwable ) { - val eventContext = ToolExecutionFailedContext(runId, toolCallId, tool, toolArgs, throwable) - toolExecutionEventHandlers.values.forEach { handler -> handler.toolCallFailureHandler.handle(eventContext) } + val eventContext = ToolCallFailedContext(runId, toolCallId, tool, toolArgs, throwable) + toolCallEventHandlers.values.forEach { handler -> handler.toolCallFailureHandler.handle(eventContext) } } /** @@ -419,15 +420,15 @@ public abstract class AIAgentPipeline(public val clock: Clock) { * @param toolArgs The arguments that were provided to the tool * @param result The result produced by the tool, or null if no result was produced */ - public suspend fun onToolExecutionCompleted( + public suspend fun onToolCallCompleted( runId: String, toolCallId: String?, tool: Tool<*, *>, toolArgs: Any?, result: Any? 
) { - val eventContext = ToolExecutionCompletedContext(runId, toolCallId, tool, toolArgs, result) - toolExecutionEventHandlers.values.forEach { handler -> handler.toolCallResultHandler.handle(eventContext) } + val eventContext = ToolCallCompletedContext(runId, toolCallId, tool, toolArgs, result) + toolCallEventHandlers.values.forEach { handler -> handler.toolCallResultHandler.handle(eventContext) } } //endregion Trigger Tool Call Handlers @@ -867,16 +868,16 @@ public abstract class AIAgentPipeline(public val clock: Clock) { * * Example: * ``` - * pipeline.interceptToolExecutionStarting(interceptContext) { eventContext -> + * pipeline.interceptToolCallStarting(interceptContext) { eventContext -> * // Process or log the tool call * } * ``` */ - public fun interceptToolExecutionStarting( + public fun interceptToolCallStarting( interceptContext: InterceptContext, - handle: suspend TFeature.(eventContext: ToolExecutionStartingContext) -> Unit + handle: suspend TFeature.(eventContext: ToolCallStartingContext) -> Unit ) { - val handler = toolExecutionEventHandlers.getOrPut(interceptContext.feature.key) { ToolExecutionEventHandler() } + val handler = toolCallEventHandlers.getOrPut(interceptContext.feature.key) { ToolCallEventHandler() } handler.toolCallHandler = ToolCallHandler( function = createConditionalHandler(interceptContext, handle) ) @@ -899,7 +900,7 @@ public abstract class AIAgentPipeline(public val clock: Clock) { interceptContext: InterceptContext, handle: suspend TFeature.(eventContext: ToolValidationFailedContext) -> Unit ) { - val handler = toolExecutionEventHandlers.getOrPut(interceptContext.feature.key) { ToolExecutionEventHandler() } + val handler = toolCallEventHandlers.getOrPut(interceptContext.feature.key) { ToolCallEventHandler() } handler.toolValidationErrorHandler = ToolValidationErrorHandler( function = createConditionalHandler(interceptContext, handle) ) @@ -913,16 +914,16 @@ public abstract class AIAgentPipeline(public val clock: Clock) { * * 
Example: * ``` - * pipeline.interceptToolExecutionFailed(interceptContext) { eventContext -> + * pipeline.interceptToolCallFailed(interceptContext) { eventContext -> * // Handle the tool call failure here * } * ``` */ - public fun interceptToolExecutionFailed( + public fun interceptToolCallFailed( interceptContext: InterceptContext, - handle: suspend TFeature.(eventContext: ToolExecutionFailedContext) -> Unit + handle: suspend TFeature.(eventContext: ToolCallFailedContext) -> Unit ) { - val handler = toolExecutionEventHandlers.getOrPut(interceptContext.feature.key) { ToolExecutionEventHandler() } + val handler = toolCallEventHandlers.getOrPut(interceptContext.feature.key) { ToolCallEventHandler() } handler.toolCallFailureHandler = ToolCallFailureHandler( function = createConditionalHandler(interceptContext, handle) ) @@ -942,11 +943,11 @@ public abstract class AIAgentPipeline(public val clock: Clock) { * } * ``` */ - public fun interceptToolExecutionCompleted( + public fun interceptToolCallCompleted( interceptContext: InterceptContext, - handle: suspend TFeature.(eventContext: ToolExecutionCompletedContext) -> Unit + handle: suspend TFeature.(eventContext: ToolCallCompletedContext) -> Unit ) { - val handler = toolExecutionEventHandlers.getOrPut(interceptContext.feature.key) { ToolExecutionEventHandler() } + val handler = toolCallEventHandlers.getOrPut(interceptContext.feature.key) { ToolCallEventHandler() } handler.toolCallResultHandler = ToolCallResultHandler( function = createConditionalHandler(interceptContext, handle) ) @@ -968,7 +969,7 @@ public abstract class AIAgentPipeline(public val clock: Clock) { ) public fun interceptBeforeAgentStarted( interceptContext: InterceptContext, - handle: suspend (ai.koog.agents.core.feature.handler.AgentStartContext) -> Unit + handle: suspend (AgentStartingContext) -> Unit ) { interceptAgentStarting(interceptContext, handle) } @@ -987,7 +988,7 @@ public abstract class AIAgentPipeline(public val clock: Clock) { ) public fun 
interceptAgentFinished( interceptContext: InterceptContext, - handle: suspend TFeature.(eventContext: ai.koog.agents.core.feature.handler.AgentFinishedContext) -> Unit + handle: suspend TFeature.(eventContext: AgentCompletedContext) -> Unit ) { interceptAgentCompleted(interceptContext, handle) } @@ -1006,7 +1007,7 @@ public abstract class AIAgentPipeline(public val clock: Clock) { ) public fun interceptAgentRunError( interceptContext: InterceptContext, - handle: suspend TFeature.(ai.koog.agents.core.feature.handler.AgentRunErrorContext) -> Unit + handle: suspend TFeature.(AgentExecutionFailedContext) -> Unit ) { interceptAgentExecutionFailed(interceptContext, handle) } @@ -1025,7 +1026,7 @@ public abstract class AIAgentPipeline(public val clock: Clock) { ) public fun interceptAgentBeforeClose( interceptContext: InterceptContext, - handle: suspend TFeature.(ai.koog.agents.core.feature.handler.AgentBeforeCloseContext) -> Unit + handle: suspend TFeature.(AgentClosingContext) -> Unit ) { interceptAgentClosing(interceptContext, handle) } @@ -1044,7 +1045,7 @@ public abstract class AIAgentPipeline(public val clock: Clock) { ) public fun interceptStrategyStart( interceptContext: InterceptContext, - handle: suspend (ai.koog.agents.core.feature.handler.StrategyStartContext) -> Unit + handle: suspend (StrategyStartingContext) -> Unit ) { interceptStrategyStarting(interceptContext, handle) } @@ -1063,7 +1064,7 @@ public abstract class AIAgentPipeline(public val clock: Clock) { ) public fun interceptStrategyFinished( interceptContext: InterceptContext, - handle: suspend (ai.koog.agents.core.feature.handler.StrategyFinishedContext) -> Unit + handle: suspend (StrategyCompletedContext) -> Unit ) { interceptStrategyCompleted(interceptContext, handle) } @@ -1082,7 +1083,7 @@ public abstract class AIAgentPipeline(public val clock: Clock) { ) public fun interceptBeforeLLMCall( interceptContext: InterceptContext, - handle: suspend TFeature.(eventContext: 
ai.koog.agents.core.feature.handler.BeforeLLMCallContext) -> Unit + handle: suspend TFeature.(eventContext: LLMCallStartingContext) -> Unit ) { interceptLLMCallStarting(interceptContext, handle) } @@ -1101,7 +1102,7 @@ public abstract class AIAgentPipeline(public val clock: Clock) { ) public fun interceptAfterLLMCall( interceptContext: InterceptContext, - handle: suspend TFeature.(eventContext: ai.koog.agents.core.feature.handler.AfterLLMCallContext) -> Unit + handle: suspend TFeature.(eventContext: LLMCallCompletedContext) -> Unit ) { interceptLLMCallCompleted(interceptContext, handle) } @@ -1111,57 +1112,57 @@ public abstract class AIAgentPipeline(public val clock: Clock) { * Updates the tool call handler for the given feature key with a custom handler. */ @Deprecated( - message = "Please use interceptToolExecutionStarting instead. This method is deprecated and will be removed in the next release.", + message = "Please use interceptToolCallStarting instead. This method is deprecated and will be removed in the next release.", replaceWith = ReplaceWith( - expression = "interceptToolExecutionStarting(interceptContext, handle)", + expression = "interceptToolCallStarting(interceptContext, handle)", imports = arrayOf( - "ai.koog.agents.core.feature.handler.tool.ToolExecutionStartingContext" + "ai.koog.agents.core.feature.handler.tool.ToolCallStartingContext" ) ) ) public fun interceptToolCall( interceptContext: InterceptContext, - handle: suspend TFeature.(eventContext: ai.koog.agents.core.feature.handler.ToolCallContext) -> Unit + handle: suspend TFeature.(eventContext: ToolCallStartingContext) -> Unit ) { - interceptToolExecutionStarting(interceptContext, handle) + interceptToolCallStarting(interceptContext, handle) } /** * Intercepts the result of a tool call with a custom handler for a specific feature. */ @Deprecated( - message = "Please use interceptToolExecutionCompleted instead. 
This method is deprecated and will be removed in the next release.", + message = "Please use interceptToolCallCompleted instead. This method is deprecated and will be removed in the next release.", replaceWith = ReplaceWith( - expression = "interceptToolExecutionCompleted(interceptContext, handle)", + expression = "interceptToolCallCompleted(interceptContext, handle)", imports = arrayOf( - "ai.koog.agents.core.feature.handler.tool.ToolExecutionCompletedContext" + "ai.koog.agents.core.feature.handler.tool.ToolCallCompletedContext" ) ) ) public fun interceptToolCallResult( interceptContext: InterceptContext, - handle: suspend TFeature.(eventContext: ToolExecutionCompletedContext) -> Unit + handle: suspend TFeature.(eventContext: ToolCallCompletedContext) -> Unit ) { - interceptToolExecutionCompleted(interceptContext, handle) + interceptToolCallCompleted(interceptContext, handle) } /** * Sets up an interception mechanism to handle tool call failures for a specific feature. */ @Deprecated( - message = "Please use interceptToolExecutionFailed instead. This method is deprecated and will be removed in the next release.", + message = "Please use interceptToolCallFailed instead. 
This method is deprecated and will be removed in the next release.", replaceWith = ReplaceWith( - expression = "interceptToolExecutionFailed(interceptContext, handle)", + expression = "interceptToolCallFailed(interceptContext, handle)", imports = arrayOf( - "ai.koog.agents.core.feature.handler.tool.ToolExecutionFailedContext" + "ai.koog.agents.core.feature.handler.tool.ToolCallFailedContext" ) ) ) public fun interceptToolCallFailure( interceptContext: InterceptContext, - handle: suspend TFeature.(eventContext: ToolExecutionFailedContext) -> Unit + handle: suspend TFeature.(eventContext: ToolCallFailedContext) -> Unit ) { - interceptToolExecutionFailed(interceptContext, handle) + interceptToolCallFailed(interceptContext, handle) } /** diff --git a/agents/agents-core/src/commonMain/kotlin/ai/koog/agents/core/feature/handler/AgentLifecycleEventType.kt b/agents/agents-core/src/commonMain/kotlin/ai/koog/agents/core/feature/handler/AgentLifecycleEventType.kt index 444f57d9b6..6cb3ee0b36 100644 --- a/agents/agents-core/src/commonMain/kotlin/ai/koog/agents/core/feature/handler/AgentLifecycleEventType.kt +++ b/agents/agents-core/src/commonMain/kotlin/ai/koog/agents/core/feature/handler/AgentLifecycleEventType.kt @@ -89,7 +89,7 @@ public sealed interface AgentLifecycleEventType { /** * Represents an event triggered when a tool is called. */ - public object ToolExecutionStarting : AgentLifecycleEventType + public object ToolCallStarting : AgentLifecycleEventType /** * Represents an event triggered when a tool call fails validation. @@ -99,12 +99,12 @@ public sealed interface AgentLifecycleEventType { /** * Represents an event triggered when a tool call fails. */ - public object ToolExecutionFailed : AgentLifecycleEventType + public object ToolCallFailed : AgentLifecycleEventType /** * Represents an event triggered when a tool call succeeds. 
*/ - public object ToolExecutionCompleted : AgentLifecycleEventType + public object ToolCallCompleted : AgentLifecycleEventType //endregion Tool diff --git a/agents/agents-core/src/commonMain/kotlin/ai/koog/agents/core/feature/handler/DeprecatedExecuteLLMEventHandlerContext.kt b/agents/agents-core/src/commonMain/kotlin/ai/koog/agents/core/feature/handler/DeprecatedExecuteLLMEventHandlerContext.kt index 488d42389c..7b02f7dde6 100644 --- a/agents/agents-core/src/commonMain/kotlin/ai/koog/agents/core/feature/handler/DeprecatedExecuteLLMEventHandlerContext.kt +++ b/agents/agents-core/src/commonMain/kotlin/ai/koog/agents/core/feature/handler/DeprecatedExecuteLLMEventHandlerContext.kt @@ -3,7 +3,7 @@ package ai.koog.agents.core.feature.handler /** * Represents the context for handling LLM-specific events within the framework. */ -public interface LLMEventHandlerContext : EventHandlerContext +public interface LLMEventHandlerContext : AgentLifecycleEventContext /** * Represents the context for handling a before LLM call event. diff --git a/agents/agents-core/src/commonMain/kotlin/ai/koog/agents/core/feature/handler/DeprecatedExecuteToolEventHandlerContext.kt b/agents/agents-core/src/commonMain/kotlin/ai/koog/agents/core/feature/handler/DeprecatedExecuteToolEventHandlerContext.kt index 2607446f7a..79e90449be 100644 --- a/agents/agents-core/src/commonMain/kotlin/ai/koog/agents/core/feature/handler/DeprecatedExecuteToolEventHandlerContext.kt +++ b/agents/agents-core/src/commonMain/kotlin/ai/koog/agents/core/feature/handler/DeprecatedExecuteToolEventHandlerContext.kt @@ -4,25 +4,25 @@ package ai.koog.agents.core.feature.handler * Represents the context for handling tool-specific events within the framework. 
*/ @Deprecated( - message = "Use ToolExecutionEventContext instead", + message = "Use ToolCallEventContext instead", replaceWith = ReplaceWith( - expression = "ToolExecutionEventContext", - imports = arrayOf("ai.koog.agents.core.feature.handler.tool.ToolExecutionEventContext") + expression = "ToolCallEventContext", + imports = arrayOf("ai.koog.agents.core.feature.handler.tool.ToolCallEventContext") ) ) -public typealias ToolEventHandlerContext = ai.koog.agents.core.feature.handler.tool.ToolExecutionEventContext +public typealias ToolEventHandlerContext = ai.koog.agents.core.feature.handler.tool.ToolCallEventContext /** * Represents the context for handling a tool call event. */ @Deprecated( - message = "Use ToolExecutionStartingContext instead", + message = "Use ToolCallStartingContext instead", replaceWith = ReplaceWith( - expression = "ToolExecutionStartingContext", - imports = arrayOf("ai.koog.agents.core.feature.handler.tool.ToolExecutionStartingContext") + expression = "ToolCallStartingContext", + imports = arrayOf("ai.koog.agents.core.feature.handler.tool.ToolCallStartingContext") ) ) -public typealias ToolCallContext = ai.koog.agents.core.feature.handler.tool.ToolExecutionStartingContext +public typealias ToolCallContext = ai.koog.agents.core.feature.handler.tool.ToolCallStartingContext /** * Represents the context for handling validation errors that occur during the execution of a tool. @@ -40,22 +40,22 @@ public typealias ToolValidationErrorContext = ai.koog.agents.core.feature.handle * Represents the context provided to handle a failure during the execution of a tool. 
*/ @Deprecated( - message = "Use ToolExecutionFailedContext instead", + message = "Use ToolCallFailedContext instead", replaceWith = ReplaceWith( - expression = "ToolExecutionFailedContext", - imports = arrayOf("ai.koog.agents.core.feature.handler.tool.ToolExecutionFailedContext") + expression = "ToolCallFailedContext", + imports = arrayOf("ai.koog.agents.core.feature.handler.tool.ToolCallFailedContext") ) ) -public typealias ToolCallFailureContext = ai.koog.agents.core.feature.handler.tool.ToolExecutionFailedContext +public typealias ToolCallFailureContext = ai.koog.agents.core.feature.handler.tool.ToolCallFailedContext /** * Represents the context used when handling the result of a tool call. */ @Deprecated( - message = "Use ToolExecutionCompletedContext instead", + message = "Use ToolCallCompletedContext instead", replaceWith = ReplaceWith( - expression = "ToolExecutionCompletedContext", - imports = arrayOf("ai.koog.agents.core.feature.handler.tool.ToolExecutionCompletedContext") + expression = "ToolCallCompletedContext", + imports = arrayOf("ai.koog.agents.core.feature.handler.tool.ToolCallCompletedContext") ) ) -public typealias ToolCallResultContext = ai.koog.agents.core.feature.handler.tool.ToolExecutionCompletedContext +public typealias ToolCallResultContext = ai.koog.agents.core.feature.handler.tool.ToolCallCompletedContext diff --git a/agents/agents-core/src/commonMain/kotlin/ai/koog/agents/core/feature/handler/strategy/StrategyEventContext.kt b/agents/agents-core/src/commonMain/kotlin/ai/koog/agents/core/feature/handler/strategy/StrategyEventContext.kt index 92dbeb19b8..a9b557925f 100644 --- a/agents/agents-core/src/commonMain/kotlin/ai/koog/agents/core/feature/handler/strategy/StrategyEventContext.kt +++ b/agents/agents-core/src/commonMain/kotlin/ai/koog/agents/core/feature/handler/strategy/StrategyEventContext.kt @@ -46,6 +46,7 @@ public class StrategyCompletedContext( public val feature: TFeature, public val result: Any?, public val resultType: KType, 
+ public val agentId: String ) : StrategyEventContext { override val eventType: AgentLifecycleEventType = AgentLifecycleEventType.StrategyCompleted } diff --git a/agents/agents-core/src/commonMain/kotlin/ai/koog/agents/core/feature/handler/tool/ToolExecutionEventContext.kt b/agents/agents-core/src/commonMain/kotlin/ai/koog/agents/core/feature/handler/tool/ToolCallEventContext.kt similarity index 84% rename from agents/agents-core/src/commonMain/kotlin/ai/koog/agents/core/feature/handler/tool/ToolExecutionEventContext.kt rename to agents/agents-core/src/commonMain/kotlin/ai/koog/agents/core/feature/handler/tool/ToolCallEventContext.kt index a6550b80fb..385661f77d 100644 --- a/agents/agents-core/src/commonMain/kotlin/ai/koog/agents/core/feature/handler/tool/ToolExecutionEventContext.kt +++ b/agents/agents-core/src/commonMain/kotlin/ai/koog/agents/core/feature/handler/tool/ToolCallEventContext.kt @@ -7,7 +7,7 @@ import ai.koog.agents.core.tools.Tool /** * Represents the context for handling tool-specific events within the framework. */ -public interface ToolExecutionEventContext : AgentLifecycleEventContext +public interface ToolCallEventContext : AgentLifecycleEventContext /** * Represents the context for handling a tool call event. @@ -15,13 +15,13 @@ public interface ToolExecutionEventContext : AgentLifecycleEventContext * @property tool The tool instance that is being executed. It encapsulates the logic and metadata for the operation. * @property toolArgs The arguments provided for the tool execution, adhering to the tool's expected input structure. */ -public data class ToolExecutionStartingContext( +public data class ToolCallStartingContext( val runId: String, val toolCallId: String?, val tool: Tool<*, *>, val toolArgs: Any? 
-) : ToolExecutionEventContext { - override val eventType: AgentLifecycleEventType = AgentLifecycleEventType.ToolExecutionStarting +) : ToolCallEventContext { + override val eventType: AgentLifecycleEventType = AgentLifecycleEventType.ToolCallStarting } /** @@ -37,7 +37,7 @@ public data class ToolValidationFailedContext( val tool: Tool<*, *>, val toolArgs: Any?, val error: String -) : ToolExecutionEventContext { +) : ToolCallEventContext { override val eventType: AgentLifecycleEventType = AgentLifecycleEventType.ToolValidationFailed } @@ -48,14 +48,14 @@ public data class ToolValidationFailedContext( * @param toolArgs The arguments that were passed to the tool during execution. * @param throwable The exception or error that caused the failure. */ -public data class ToolExecutionFailedContext( +public data class ToolCallFailedContext( val runId: String, val toolCallId: String?, val tool: Tool<*, *>, val toolArgs: Any?, val throwable: Throwable -) : ToolExecutionEventContext { - override val eventType: AgentLifecycleEventType = AgentLifecycleEventType.ToolExecutionFailed +) : ToolCallEventContext { + override val eventType: AgentLifecycleEventType = AgentLifecycleEventType.ToolCallFailed } /** @@ -65,12 +65,12 @@ public data class ToolExecutionFailedContext( * @param toolArgs The arguments required by the tool for execution. * @param result An optional result produced by the tool after execution can be null if not applicable. */ -public data class ToolExecutionCompletedContext( +public data class ToolCallCompletedContext( val runId: String, val toolCallId: String?, val tool: Tool<*, *>, val toolArgs: Any?, val result: Any? 
-) : ToolExecutionEventContext { - override val eventType: AgentLifecycleEventType = AgentLifecycleEventType.ToolExecutionCompleted +) : ToolCallEventContext { + override val eventType: AgentLifecycleEventType = AgentLifecycleEventType.ToolCallCompleted } diff --git a/agents/agents-core/src/commonMain/kotlin/ai/koog/agents/core/feature/handler/tool/ToolExecutionEventHandler.kt b/agents/agents-core/src/commonMain/kotlin/ai/koog/agents/core/feature/handler/tool/ToolCallEventHandler.kt similarity index 94% rename from agents/agents-core/src/commonMain/kotlin/ai/koog/agents/core/feature/handler/tool/ToolExecutionEventHandler.kt rename to agents/agents-core/src/commonMain/kotlin/ai/koog/agents/core/feature/handler/tool/ToolCallEventHandler.kt index 55708d4509..082fd1a26c 100644 --- a/agents/agents-core/src/commonMain/kotlin/ai/koog/agents/core/feature/handler/tool/ToolExecutionEventHandler.kt +++ b/agents/agents-core/src/commonMain/kotlin/ai/koog/agents/core/feature/handler/tool/ToolCallEventHandler.kt @@ -7,7 +7,7 @@ package ai.koog.agents.core.feature.handler.tool * This class provides properties that allow defining specific behavior * during different stages of a tool's execution process. */ -public class ToolExecutionEventHandler { +public class ToolCallEventHandler { /** * A variable of type [ToolCallHandler] used to handle tool call operations. * It provides a mechanism for executing specific logic when a tool is called @@ -61,7 +61,7 @@ public fun interface ToolCallHandler { /** * Handles the execution of a given tool using the provided arguments. */ - public suspend fun handle(eventContext: ToolExecutionStartingContext) + public suspend fun handle(eventContext: ToolCallStartingContext) } /** @@ -84,7 +84,7 @@ public fun interface ToolCallFailureHandler { /** * Handles a failure that occurs during the execution of a tool call. 
*/ - public suspend fun handle(eventContext: ToolExecutionFailedContext) + public suspend fun handle(eventContext: ToolCallFailedContext) } /** @@ -96,5 +96,5 @@ public fun interface ToolCallResultHandler { /** * Handles the execution of a specific tool by processing its arguments and optionally handling its result. */ - public suspend fun handle(eventContext: ToolExecutionCompletedContext) + public suspend fun handle(eventContext: ToolCallCompletedContext) } diff --git a/agents/agents-core/src/commonMain/kotlin/ai/koog/agents/core/feature/message/FeatureEvent.kt b/agents/agents-core/src/commonMain/kotlin/ai/koog/agents/core/feature/message/FeatureEvent.kt index 902e19cadb..27e871c26b 100644 --- a/agents/agents-core/src/commonMain/kotlin/ai/koog/agents/core/feature/message/FeatureEvent.kt +++ b/agents/agents-core/src/commonMain/kotlin/ai/koog/agents/core/feature/message/FeatureEvent.kt @@ -8,13 +8,4 @@ package ai.koog.agents.core.feature.message * Implementations of this interface are intended to detail specific events in the feature * processing workflow. */ -public interface FeatureEvent : FeatureMessage { - - /** - * Represents a unique identifier for a feature-related event. - * - * This identifier is used to distinguish and track individual events in the system, - * enabling a clear correlation between logged events or processed messages. 
- */ - public val eventId: String -} +public interface FeatureEvent : FeatureMessage diff --git a/agents/agents-core/src/commonMain/kotlin/ai/koog/agents/core/feature/model/FeatureEventMessage.kt b/agents/agents-core/src/commonMain/kotlin/ai/koog/agents/core/feature/model/FeatureEventMessage.kt index 951b8d0b8d..02daab5243 100644 --- a/agents/agents-core/src/commonMain/kotlin/ai/koog/agents/core/feature/model/FeatureEventMessage.kt +++ b/agents/agents-core/src/commonMain/kotlin/ai/koog/agents/core/feature/model/FeatureEventMessage.kt @@ -14,8 +14,6 @@ import kotlinx.serialization.Serializable * The primary purpose of this class is to represent feature event data with an associated unique event identifier, * a timestamp marking its creation, and the message type indicating it is an event-specific message. * - * @property eventId A unique identifier associated with this event. - * This property implements the [ai.koog.agents.core.feature.message.FeatureEvent.eventId] from the parent interface. * @property timestamp The time at which this event message was created has represented in milliseconds since the epoch. * This property implements the [ai.koog.agents.core.feature.message.FeatureMessage.timestamp] from the parent interface. * @property messageType The type of the message, which in this case is fixed as [Type.Event]. 
@@ -23,7 +21,6 @@ import kotlinx.serialization.Serializable */ @Serializable public data class FeatureEventMessage( - override val eventId: String, override val timestamp: Long = Clock.System.now().toEpochMilliseconds() ) : FeatureEvent { diff --git a/agents/agents-core/src/commonMain/kotlin/ai/koog/agents/core/feature/model/events/agentEvents.kt b/agents/agents-core/src/commonMain/kotlin/ai/koog/agents/core/feature/model/events/agentEvents.kt index ec4e2c2e1a..1174e6748b 100644 --- a/agents/agents-core/src/commonMain/kotlin/ai/koog/agents/core/feature/model/events/agentEvents.kt +++ b/agents/agents-core/src/commonMain/kotlin/ai/koog/agents/core/feature/model/events/agentEvents.kt @@ -12,14 +12,12 @@ import kotlinx.serialization.Serializable * * @property agentId The unique identifier of the AI agent; * @property runId The unique identifier of the AI agen run; - * @property eventId A string representing the event type; * @property timestamp The timestamp of the event, in milliseconds since the Unix epoch. */ @Serializable public data class AgentStartingEvent( val agentId: String, val runId: String, - override val eventId: String = AgentStartingEvent::class.simpleName!!, override val timestamp: Long = Clock.System.now().toEpochMilliseconds(), ) : DefinedFeatureEvent() @@ -33,7 +31,6 @@ public data class AgentStartingEvent( * @property agentId The unique identifier of the AI agent; * @property runId The unique identifier of the AI agen run; * @property result The result of the strategy execution, or null if unavailable; - * @property eventId A string representing the event type; * @property timestamp The timestamp of the event, in milliseconds since the Unix epoch. 
*/ @Serializable @@ -41,7 +38,6 @@ public data class AgentCompletedEvent( val agentId: String, val runId: String, val result: String?, - override val eventId: String = AgentCompletedEvent::class.simpleName!!, override val timestamp: Long = Clock.System.now().toEpochMilliseconds(), ) : DefinedFeatureEvent() @@ -56,7 +52,6 @@ public data class AgentCompletedEvent( * @property runId The unique identifier of the AI agen run; * @property error The [AIAgentError] instance encapsulating details about the encountered error, * such as its message, stack trace, and cause; - * @property eventId A string representing the event type; * @property timestamp The timestamp of the event, in milliseconds since the Unix epoch. */ @Serializable @@ -64,7 +59,6 @@ public data class AgentExecutionFailedEvent( val agentId: String, val runId: String, val error: AIAgentError, - override val eventId: String = AgentExecutionFailedEvent::class.simpleName!!, override val timestamp: Long = Clock.System.now().toEpochMilliseconds(), ) : DefinedFeatureEvent() @@ -73,13 +67,11 @@ public data class AgentExecutionFailedEvent( * by a unique `agentId`. * * @property agentId The unique identifier of the AI agent; - * @property eventId A string representing the event type; * @property timestamp The timestamp of the event, in milliseconds since the Unix epoch. 
*/ @Serializable public data class AgentClosingEvent( val agentId: String, - override val eventId: String = AgentClosingEvent::class.simpleName!!, override val timestamp: Long = Clock.System.now().toEpochMilliseconds(), ) : DefinedFeatureEvent() diff --git a/agents/agents-core/src/commonMain/kotlin/ai/koog/agents/core/feature/model/events/llmCallEvents.kt b/agents/agents-core/src/commonMain/kotlin/ai/koog/agents/core/feature/model/events/llmCallEvents.kt index 3e749b20c8..e7890e1df2 100644 --- a/agents/agents-core/src/commonMain/kotlin/ai/koog/agents/core/feature/model/events/llmCallEvents.kt +++ b/agents/agents-core/src/commonMain/kotlin/ai/koog/agents/core/feature/model/events/llmCallEvents.kt @@ -13,11 +13,11 @@ import kotlinx.serialization.Serializable * input prompt and any tools that will be used during the call. It extends the `DefinedFeatureEvent` class * and serves as a specific type of event in a feature-driven framework. * + * @property runId A unique identifier associated with the specific run of the LLM call. * @property prompt The input prompt encapsulated as a [Prompt] object. This represents the structured set of * messages and configuration parameters sent to the LLM. - * @property tools The list of tools, represented by their string identifiers, being used within the scope - * of the LLM call. These tools may extend or enhance the core functionality of the LLM. - * @property eventId A string representing the event type; + * @property model The description of the LLM model used during the call. Use the format: 'llm_provider:model_id'; + * @property tools A list of tools used or invoked during the LLM call. * @property timestamp The timestamp of the event, in milliseconds since the Unix epoch. 
*/ @Serializable @@ -26,7 +26,6 @@ public data class LLMCallStartingEvent( val prompt: Prompt, val model: String, val tools: List, - override val eventId: String = LLMCallStartingEvent::class.simpleName!!, override val timestamp: Long = Clock.System.now().toEpochMilliseconds(), ) : DefinedFeatureEvent() @@ -38,9 +37,14 @@ public data class LLMCallStartingEvent( * The event is used within the system to capture relevant output data and ensure proper tracking * and logging of LLM-related interactions. * + * @property runId The unique identifier of the LLM run. + * @property prompt The input prompt encapsulated as a [Prompt] object. This represents the structured set of + * messages and configuration parameters sent to the LLM. + * @property model The description of the LLM model used during the call. Use the format: 'llm_provider:model_id'; * @property responses A list of responses generated by the LLM, represented as instances of [Message.Response]. * Each response contains content, metadata, and additional context about the interaction. - * @property eventId A string representing the event type; + * @property moderationResponse The moderation response, if any, returned by the LLM. + * This is typically used to capture and track content moderation results. * @property timestamp The timestamp of the event, in milliseconds since the Unix epoch. */ @Serializable @@ -50,7 +54,6 @@ public data class LLMCallCompletedEvent( val model: String, val responses: List, val moderationResponse: ModerationResult? 
= null, - override val eventId: String = LLMCallCompletedEvent::class.simpleName!!, override val timestamp: Long = Clock.System.now().toEpochMilliseconds(), ) : DefinedFeatureEvent() diff --git a/agents/agents-core/src/commonMain/kotlin/ai/koog/agents/core/feature/model/events/llmStreamingEvents.kt b/agents/agents-core/src/commonMain/kotlin/ai/koog/agents/core/feature/model/events/llmStreamingEvents.kt new file mode 100644 index 0000000000..a18107977b --- /dev/null +++ b/agents/agents-core/src/commonMain/kotlin/ai/koog/agents/core/feature/model/events/llmStreamingEvents.kt @@ -0,0 +1,84 @@ +package ai.koog.agents.core.feature.model.events + +import ai.koog.agents.core.feature.model.AIAgentError +import ai.koog.prompt.dsl.Prompt +import ai.koog.prompt.streaming.StreamFrame +import kotlinx.datetime.Clock +import kotlinx.serialization.Serializable + +/** + * Represents an event triggered when a language model (LLM) streaming operation is starting. + * + * This event holds metadata related to the initiation of the LLM streaming process, including + * details about the run, the input prompt, the model used, and the tools involved. + * + * @property runId Unique identifier for the LLM run; + * @property prompt The input prompt provided for the LLM operation; + * @property model The description of the LLM model used during the call. Use the format: 'llm_provider:model_id'; + * @property tools A list of associated tools or resources that are part of the operation; + * @property timestamp The time when the event occurred, represented in epoch milliseconds. + */ +@Serializable +public data class LLMStreamingStartingEvent( + val runId: String, + val prompt: Prompt, + val model: String, + val tools: List, + override val timestamp: Long = Clock.System.now().toEpochMilliseconds(), +) : DefinedFeatureEvent() + +/** + * Event representing the receipt of a streaming frame from a Language Learning Model (LLM). 
+ * + * This event occurs as part of the streaming interaction with the LLM, where individual + * frames of data are sent incrementally. The event contains details about the specific + * frame received, as well as metadata related to the event's timing and identity. + * + * @property runId The unique identifier for the LLM run or session associated with this event; + * @property frame The frame data received as part of the streaming response. This can include textual + * content, tool invocations, or signaling the end of the stream; + * @property timestamp The timestamp of when the event was created, represented in milliseconds since the Unix epoch. + */ +@Serializable +public data class LLMStreamingFrameReceivedEvent( + val runId: String, + val frame: StreamFrame, + override val timestamp: Long = Clock.System.now().toEpochMilliseconds(), +) : DefinedFeatureEvent() + +/** + * Represents an event indicating a failure in the streaming process of a Language Learning Model (LLM). + * + * This event captures details of the failure encountered during the streaming operation. + * It includes information such as the unique identifier of the operation run, a detailed + * error description, and inherits common properties such as event ID and timestamp. + * + * @property runId A unique identifier representing the specific operation or run in which the failure occurred; + * @property error An instance of [AIAgentError], containing information about the error encountered, including its + * message, stack trace, and cause, if available; + * @property timestamp A timestamp indicating when the event occurred, represented in milliseconds since the Unix epoch. 
+ */ +@Serializable +public data class LLMStreamingFailedEvent( + val runId: String, + val error: AIAgentError, + override val timestamp: Long = Clock.System.now().toEpochMilliseconds(), +) : DefinedFeatureEvent() + +/** + * Represents an event that occurs when the streaming process of a Large Language Model (LLM) call is completed. + * + * @property runId The unique identifier of the LLM run; + * @property prompt The prompt associated with the LLM call; + * @property model The description of the LLM model used during the call. Use the format: 'llm_provider:model_id'; + * @property tools A list of tools used or invoked during the LLM call; + * @property timestamp The timestamp indicating when the event occurred, represented in milliseconds since the epoch, defaulting to the current system time. + */ +@Serializable +public data class LLMStreamingCompletedEvent( + val runId: String, + val prompt: Prompt, + val model: String, + val tools: List, + override val timestamp: Long = Clock.System.now().toEpochMilliseconds(), +) : DefinedFeatureEvent() diff --git a/agents/agents-core/src/commonMain/kotlin/ai/koog/agents/core/feature/model/events/nodeExecutionEvents.kt b/agents/agents-core/src/commonMain/kotlin/ai/koog/agents/core/feature/model/events/nodeExecutionEvents.kt index 6455b7850f..67daf8ffeb 100644 --- a/agents/agents-core/src/commonMain/kotlin/ai/koog/agents/core/feature/model/events/nodeExecutionEvents.kt +++ b/agents/agents-core/src/commonMain/kotlin/ai/koog/agents/core/feature/model/events/nodeExecutionEvents.kt @@ -17,7 +17,6 @@ import kotlinx.serialization.Serializable * * @property nodeName The name of the node whose execution is starting; * @property input The input data being processed by the node during the execution; - * @property eventId A string representing the event type; * @property timestamp The timestamp of the event, in milliseconds since the Unix epoch. 
*/ @Serializable @@ -25,7 +24,6 @@ public data class NodeExecutionStartingEvent( val runId: String, val nodeName: String, val input: String, - override val eventId: String = NodeExecutionStartingEvent::class.simpleName!!, override val timestamp: Long = Clock.System.now().toEpochMilliseconds(), ) : DefinedFeatureEvent() @@ -39,7 +37,6 @@ public data class NodeExecutionStartingEvent( * @property nodeName The name of the node that finished execution; * @property input The input data provided to the node; * @property output The output generated by the node; - * @property eventId A unique identifier for the event type; * @property timestamp The timestamp of the event, in milliseconds since the Unix epoch. */ @Serializable @@ -48,7 +45,6 @@ public data class NodeExecutionCompletedEvent( val nodeName: String, val input: String, val output: String, - override val eventId: String = NodeExecutionCompletedEvent::class.simpleName!!, override val timestamp: Long = Clock.System.now().toEpochMilliseconds(), ) : DefinedFeatureEvent() @@ -67,7 +63,6 @@ public data class NodeExecutionCompletedEvent( * @property runId A unique identifier associated with the specific run of the AI agent; * @property nodeName The name of the node where the error occurred; * @property error An instance of `AIAgentError` containing the error details, such as a message, stack trace, and optional cause; - * @property eventId A string representing the event type; * @property timestamp The timestamp of the event, in milliseconds since the Unix epoch. 
*/ @Serializable @@ -75,7 +70,6 @@ public data class NodeExecutionFailedEvent( val runId: String, val nodeName: String, val error: AIAgentError, - override val eventId: String = NodeExecutionFailedEvent::class.simpleName!!, override val timestamp: Long = Clock.System.now().toEpochMilliseconds(), ) : DefinedFeatureEvent() diff --git a/agents/agents-core/src/commonMain/kotlin/ai/koog/agents/core/feature/model/events/strategyEvents.kt b/agents/agents-core/src/commonMain/kotlin/ai/koog/agents/core/feature/model/events/strategyEvents.kt index a75f593ba2..427975dd23 100644 --- a/agents/agents-core/src/commonMain/kotlin/ai/koog/agents/core/feature/model/events/strategyEvents.kt +++ b/agents/agents-core/src/commonMain/kotlin/ai/koog/agents/core/feature/model/events/strategyEvents.kt @@ -14,7 +14,6 @@ import kotlinx.serialization.Serializable * shared properties from the [DefinedFeatureEvent] superclass. * * @property strategyName The name of the strategy being started. - * @property eventId A string representing the event type. */ public abstract class StrategyStartingEvent : DefinedFeatureEvent() { @@ -53,7 +52,6 @@ public abstract class StrategyStartingEvent : DefinedFeatureEvent() { * @property strategyName The name of the graph-based strategy being executed. * @property graph The graph structure representing the strategy's execution workflow, encompassing nodes * and their directed relationships; - * @property eventId A string representing the event type; * @property timestamp The timestamp of the event, in milliseconds since the Unix epoch. 
*/ @Serializable @@ -61,7 +59,6 @@ public data class GraphStrategyStartingEvent( override val runId: String, override val strategyName: String, val graph: StrategyEventGraph, - override val eventId: String = GraphStrategyStartingEvent::class.simpleName!!, override val timestamp: Long = Clock.System.now().toEpochMilliseconds(), ) : StrategyStartingEvent() @@ -76,14 +73,12 @@ public data class GraphStrategyStartingEvent( * * @property runId A unique identifier representing the specific run or instance of the strategy execution; * @property strategyName The name of the functional-based strategy being executed; - * @property eventId A string representing the event type; * @property timestamp The timestamp of the event, in milliseconds since the Unix epoch. */ @Serializable public data class FunctionalStrategyStartingEvent( override val runId: String, override val strategyName: String, - override val eventId: String = FunctionalStrategyStartingEvent::class.simpleName!!, override val timestamp: Long = Clock.System.now().toEpochMilliseconds(), ) : StrategyStartingEvent() @@ -96,7 +91,6 @@ public data class FunctionalStrategyStartingEvent( * @property strategyName The name of the strategy that was executed; * @property result The result of the strategy execution, providing details such as success, failure, * or other status descriptions; - * @property eventId A string representing the event type; * @property timestamp The timestamp of the event, in milliseconds since the Unix epoch. 
*/ @Serializable @@ -104,7 +98,6 @@ public data class StrategyCompletedEvent( val runId: String, val strategyName: String, val result: String?, - override val eventId: String = StrategyCompletedEvent::class.simpleName!!, override val timestamp: Long = Clock.System.now().toEpochMilliseconds(), ) : DefinedFeatureEvent() diff --git a/agents/agents-core/src/commonMain/kotlin/ai/koog/agents/core/feature/model/events/toolExecutionEvents.kt b/agents/agents-core/src/commonMain/kotlin/ai/koog/agents/core/feature/model/events/toolExecutionEvents.kt index e8933972e3..9b7e01c77f 100644 --- a/agents/agents-core/src/commonMain/kotlin/ai/koog/agents/core/feature/model/events/toolExecutionEvents.kt +++ b/agents/agents-core/src/commonMain/kotlin/ai/koog/agents/core/feature/model/events/toolExecutionEvents.kt @@ -15,16 +15,14 @@ import kotlinx.serialization.json.JsonObject * * @property toolName The unique name of the tool being called; * @property toolArgs The arguments provided for the tool execution; - * @property eventId A string representing the event type; * @property timestamp The timestamp of the event, in milliseconds since the Unix epoch. */ @Serializable -public data class ToolExecutionStartingEvent( +public data class ToolCallStartingEvent( val runId: String, val toolCallId: String?, val toolName: String, val toolArgs: JsonObject, - override val eventId: String = ToolExecutionStartingEvent::class.simpleName!!, override val timestamp: Long = Clock.System.now().toEpochMilliseconds(), ) : DefinedFeatureEvent() @@ -37,7 +35,6 @@ public data class ToolExecutionStartingEvent( * @property toolName The name of the tool that encountered the validation error; * @property toolArgs The arguments associated with the tool at the time of validation failure; * @property error A message describing the validation error encountered; - * @property eventId A string representing the event type; * @property timestamp The timestamp of the event, in milliseconds since the Unix epoch. 
*/ @Serializable @@ -47,7 +44,6 @@ public data class ToolValidationFailedEvent( val toolName: String, val toolArgs: JsonObject, val error: String, - override val eventId: String = ToolValidationFailedEvent::class.simpleName!!, override val timestamp: Long = Clock.System.now().toEpochMilliseconds(), ) : DefinedFeatureEvent() @@ -61,17 +57,15 @@ public data class ToolValidationFailedEvent( * @property toolName The name of the tool that failed; * @property toolArgs The arguments passed to the tool during the failed execution; * @property error The error encountered during the tool's execution; - * @property eventId A string representing the event type; * @property timestamp The timestamp of the event, in milliseconds since the Unix epoch. */ @Serializable -public data class ToolExecutionFailedEvent( +public data class ToolCallFailedEvent( val runId: String, val toolCallId: String?, val toolName: String, val toolArgs: JsonObject, val error: AIAgentError, - override val eventId: String = ToolExecutionFailedEvent::class.simpleName!!, override val timestamp: Long = Clock.System.now().toEpochMilliseconds(), ) : DefinedFeatureEvent() @@ -85,27 +79,25 @@ public data class ToolExecutionFailedEvent( * @property toolName The name of the tool that was executed; * @property toolArgs The arguments used for executing the tool; * @property result The result of the tool execution, which may be null if no result was produced or an error occurred; - * @property eventId A string representing the event type; * @property timestamp The timestamp of the event, in milliseconds since the Unix epoch. 
*/ @Serializable -public data class ToolExecutionCompletedEvent( +public data class ToolCallCompletedEvent( val runId: String, val toolCallId: String?, val toolName: String, val toolArgs: JsonObject, val result: String?, - override val eventId: String = ToolExecutionCompletedEvent::class.simpleName!!, override val timestamp: Long = Clock.System.now().toEpochMilliseconds(), ) : DefinedFeatureEvent() //region Deprecated @Deprecated( - message = "Use ToolExecutionStartingEvent instead", - replaceWith = ReplaceWith("ToolExecutionStartingEvent") + message = "Use ToolCallStartingEvent instead", + replaceWith = ReplaceWith("ToolCallStartingEvent") ) -public typealias ToolCallEvent = ToolExecutionStartingEvent +public typealias ToolCallEvent = ToolCallStartingEvent @Deprecated( message = "Use ToolValidationFailedEvent instead", @@ -114,15 +106,15 @@ public typealias ToolCallEvent = ToolExecutionStartingEvent public typealias ToolValidationErrorEvent = ToolValidationFailedEvent @Deprecated( - message = "Use ToolExecutionFailedEvent instead", - replaceWith = ReplaceWith("ToolExecutionFailedEvent") + message = "Use ToolCallFailedEvent instead", + replaceWith = ReplaceWith("ToolCallFailedEvent") ) -public typealias ToolCallFailureEvent = ToolExecutionFailedEvent +public typealias ToolCallFailureEvent = ToolCallFailedEvent @Deprecated( - message = "Use ToolExecutionCompletedEvent instead", - replaceWith = ReplaceWith("ToolExecutionCompletedEvent") + message = "Use ToolCallCompletedEvent instead", + replaceWith = ReplaceWith("ToolCallCompletedEvent") ) -public typealias ToolCallResultEvent = ToolExecutionCompletedEvent +public typealias ToolCallResultEvent = ToolCallCompletedEvent //endregion Deprecated diff --git a/agents/agents-core/src/commonMain/kotlin/ai/koog/agents/core/feature/remote/jsonConfig.kt b/agents/agents-core/src/commonMain/kotlin/ai/koog/agents/core/feature/remote/jsonConfig.kt index 21504e9055..fbd636007e 100644 --- 
a/agents/agents-core/src/commonMain/kotlin/ai/koog/agents/core/feature/remote/jsonConfig.kt +++ b/agents/agents-core/src/commonMain/kotlin/ai/koog/agents/core/feature/remote/jsonConfig.kt @@ -13,14 +13,18 @@ import ai.koog.agents.core.feature.model.events.FunctionalStrategyStartingEvent import ai.koog.agents.core.feature.model.events.GraphStrategyStartingEvent import ai.koog.agents.core.feature.model.events.LLMCallCompletedEvent import ai.koog.agents.core.feature.model.events.LLMCallStartingEvent +import ai.koog.agents.core.feature.model.events.LLMStreamingCompletedEvent +import ai.koog.agents.core.feature.model.events.LLMStreamingFailedEvent +import ai.koog.agents.core.feature.model.events.LLMStreamingFrameReceivedEvent +import ai.koog.agents.core.feature.model.events.LLMStreamingStartingEvent import ai.koog.agents.core.feature.model.events.NodeExecutionCompletedEvent import ai.koog.agents.core.feature.model.events.NodeExecutionFailedEvent import ai.koog.agents.core.feature.model.events.NodeExecutionStartingEvent import ai.koog.agents.core.feature.model.events.StrategyCompletedEvent import ai.koog.agents.core.feature.model.events.StrategyStartingEvent -import ai.koog.agents.core.feature.model.events.ToolExecutionCompletedEvent -import ai.koog.agents.core.feature.model.events.ToolExecutionFailedEvent -import ai.koog.agents.core.feature.model.events.ToolExecutionStartingEvent +import ai.koog.agents.core.feature.model.events.ToolCallCompletedEvent +import ai.koog.agents.core.feature.model.events.ToolCallFailedEvent +import ai.koog.agents.core.feature.model.events.ToolCallStartingEvent import ai.koog.agents.core.feature.model.events.ToolValidationFailedEvent import io.ktor.utils.io.InternalAPI import kotlinx.serialization.DeserializationStrategy @@ -75,10 +79,10 @@ public val defaultFeatureMessageJsonConfig: Json * - [StrategyCompletedEvent] - Fired when an AI agent strategy completes * - [NodeExecutionStartingEvent] - Fired when a node execution starts * - 
[NodeExecutionCompletedEvent] - Fired when a node execution ends - * - [ToolExecutionStartingEvent] - Fired when a tool is called + * - [ToolCallStartingEvent] - Fired when a tool is called * - [ToolValidationFailedEvent] - Fired when tool validation fails - * - [ToolExecutionFailedEvent] - Fired when a tool call fails - * - [ToolExecutionCompletedEvent] - Fired when a tool call returns a result + * - [ToolCallFailedEvent] - Fired when a tool call fails + * - [ToolCallCompletedEvent] - Fired when a tool call returns a result * - [LLMCallStartingEvent] - Fired before making an LLM call * - [LLMCallCompletedEvent] - Fired after completing an LLM call * @@ -101,12 +105,16 @@ public val defaultFeatureMessageSerializersModule: SerializersModule subclass(NodeExecutionStartingEvent::class, NodeExecutionStartingEvent.serializer()) subclass(NodeExecutionCompletedEvent::class, NodeExecutionCompletedEvent.serializer()) subclass(NodeExecutionFailedEvent::class, NodeExecutionFailedEvent.serializer()) - subclass(ToolExecutionStartingEvent::class, ToolExecutionStartingEvent.serializer()) + subclass(ToolCallStartingEvent::class, ToolCallStartingEvent.serializer()) subclass(ToolValidationFailedEvent::class, ToolValidationFailedEvent.serializer()) - subclass(ToolExecutionFailedEvent::class, ToolExecutionFailedEvent.serializer()) - subclass(ToolExecutionCompletedEvent::class, ToolExecutionCompletedEvent.serializer()) + subclass(ToolCallFailedEvent::class, ToolCallFailedEvent.serializer()) + subclass(ToolCallCompletedEvent::class, ToolCallCompletedEvent.serializer()) subclass(LLMCallStartingEvent::class, LLMCallStartingEvent.serializer()) subclass(LLMCallCompletedEvent::class, LLMCallCompletedEvent.serializer()) + subclass(LLMStreamingStartingEvent::class, LLMStreamingStartingEvent.serializer()) + subclass(LLMStreamingFrameReceivedEvent::class, LLMStreamingFrameReceivedEvent.serializer()) + subclass(LLMStreamingFailedEvent::class, LLMStreamingFailedEvent.serializer()) + 
subclass(LLMStreamingCompletedEvent::class, LLMStreamingCompletedEvent.serializer()) } polymorphic(FeatureEvent::class) { @@ -121,12 +129,16 @@ public val defaultFeatureMessageSerializersModule: SerializersModule subclass(NodeExecutionStartingEvent::class, NodeExecutionStartingEvent.serializer()) subclass(NodeExecutionCompletedEvent::class, NodeExecutionCompletedEvent.serializer()) subclass(NodeExecutionFailedEvent::class, NodeExecutionFailedEvent.serializer()) - subclass(ToolExecutionStartingEvent::class, ToolExecutionStartingEvent.serializer()) + subclass(ToolCallStartingEvent::class, ToolCallStartingEvent.serializer()) subclass(ToolValidationFailedEvent::class, ToolValidationFailedEvent.serializer()) - subclass(ToolExecutionFailedEvent::class, ToolExecutionFailedEvent.serializer()) - subclass(ToolExecutionCompletedEvent::class, ToolExecutionCompletedEvent.serializer()) + subclass(ToolCallFailedEvent::class, ToolCallFailedEvent.serializer()) + subclass(ToolCallCompletedEvent::class, ToolCallCompletedEvent.serializer()) subclass(LLMCallStartingEvent::class, LLMCallStartingEvent.serializer()) subclass(LLMCallCompletedEvent::class, LLMCallCompletedEvent.serializer()) + subclass(LLMStreamingStartingEvent::class, LLMStreamingStartingEvent.serializer()) + subclass(LLMStreamingFrameReceivedEvent::class, LLMStreamingFrameReceivedEvent.serializer()) + subclass(LLMStreamingFailedEvent::class, LLMStreamingFailedEvent.serializer()) + subclass(LLMStreamingCompletedEvent::class, LLMStreamingCompletedEvent.serializer()) } polymorphic(DefinedFeatureEvent::class) { @@ -140,12 +152,16 @@ public val defaultFeatureMessageSerializersModule: SerializersModule subclass(NodeExecutionStartingEvent::class, NodeExecutionStartingEvent.serializer()) subclass(NodeExecutionCompletedEvent::class, NodeExecutionCompletedEvent.serializer()) subclass(NodeExecutionFailedEvent::class, NodeExecutionFailedEvent.serializer()) - subclass(ToolExecutionStartingEvent::class, 
ToolExecutionStartingEvent.serializer()) + subclass(ToolCallStartingEvent::class, ToolCallStartingEvent.serializer()) subclass(ToolValidationFailedEvent::class, ToolValidationFailedEvent.serializer()) - subclass(ToolExecutionFailedEvent::class, ToolExecutionFailedEvent.serializer()) - subclass(ToolExecutionCompletedEvent::class, ToolExecutionCompletedEvent.serializer()) + subclass(ToolCallFailedEvent::class, ToolCallFailedEvent.serializer()) + subclass(ToolCallCompletedEvent::class, ToolCallCompletedEvent.serializer()) subclass(LLMCallStartingEvent::class, LLMCallStartingEvent.serializer()) subclass(LLMCallCompletedEvent::class, LLMCallCompletedEvent.serializer()) + subclass(LLMStreamingStartingEvent::class, LLMStreamingStartingEvent.serializer()) + subclass(LLMStreamingFrameReceivedEvent::class, LLMStreamingFrameReceivedEvent.serializer()) + subclass(LLMStreamingFailedEvent::class, LLMStreamingFailedEvent.serializer()) + subclass(LLMStreamingCompletedEvent::class, LLMStreamingCompletedEvent.serializer()) } polymorphic(StrategyStartingEvent::class) { diff --git a/agents/agents-core/src/commonTest/kotlin/ai/koog/agents/core/agent/FunctionalAIAgentTest.kt b/agents/agents-core/src/commonTest/kotlin/ai/koog/agents/core/agent/FunctionalAIAgentTest.kt index 29324e9767..9105f3bbe3 100644 --- a/agents/agents-core/src/commonTest/kotlin/ai/koog/agents/core/agent/FunctionalAIAgentTest.kt +++ b/agents/agents-core/src/commonTest/kotlin/ai/koog/agents/core/agent/FunctionalAIAgentTest.kt @@ -1,6 +1,5 @@ package ai.koog.agents.core.agent -import ai.koog.agents.core.agent.functionalStrategy import ai.koog.agents.core.tools.ToolRegistry import ai.koog.agents.features.eventHandler.feature.EventHandler import ai.koog.agents.testing.tools.getMockExecutor @@ -52,7 +51,7 @@ class FunctionalAIAgentTest { toolRegistry = testToolRegistry ) { install(EventHandler) { - onToolCall { eventContext -> actualToolCalls += eventContext.toolArgs.toString() } + onToolCallStarting { eventContext -> 
actualToolCalls += eventContext.toolArgs.toString() } } } @@ -92,7 +91,7 @@ class FunctionalAIAgentTest { toolRegistry = testToolRegistry, ) { install(EventHandler) { - onToolCall { eventContext -> actualToolCalls += eventContext.toolArgs.toString() } + onToolCallStarting { eventContext -> actualToolCalls += eventContext.toolArgs.toString() } } } @@ -134,7 +133,7 @@ class FunctionalAIAgentTest { } ) { install(EventHandler) { - onToolCall { eventContext -> actualToolCalls += eventContext.toolArgs.toString() } + onToolCallStarting { eventContext -> actualToolCalls += eventContext.toolArgs.toString() } } } diff --git a/agents/agents-core/src/commonTest/kotlin/ai/koog/agents/core/agent/SingleRunStrategyTests.kt b/agents/agents-core/src/commonTest/kotlin/ai/koog/agents/core/agent/SingleRunStrategyTests.kt index ef525ae3cb..d4de2bd1a0 100644 --- a/agents/agents-core/src/commonTest/kotlin/ai/koog/agents/core/agent/SingleRunStrategyTests.kt +++ b/agents/agents-core/src/commonTest/kotlin/ai/koog/agents/core/agent/SingleRunStrategyTests.kt @@ -32,7 +32,7 @@ class SingleRunStrategyTests { toolRegistry = testToolRegistry ) { install(EventHandler) { - onToolCall { eventContext -> actualToolCalls += eventContext.toolArgs.toString() } + onToolCallStarting { eventContext -> actualToolCalls += eventContext.toolArgs.toString() } } } @@ -63,7 +63,7 @@ class SingleRunStrategyTests { toolRegistry = testToolRegistry ) { install(EventHandler) { - onToolCall { eventContext -> actualToolCalls += eventContext.toolArgs.toString() } + onToolCallStarting { eventContext -> actualToolCalls += eventContext.toolArgs.toString() } } } @@ -93,7 +93,7 @@ class SingleRunStrategyTests { toolRegistry = testToolRegistry ) { install(EventHandler) { - onToolCall { eventContext -> actualToolCalls += eventContext.toolArgs.toString() } + onToolCallStarting { eventContext -> actualToolCalls += eventContext.toolArgs.toString() } } } @@ -123,7 +123,7 @@ class SingleRunStrategyTests { toolRegistry = 
testToolRegistry ) { install(EventHandler) { - onToolCall { eventContext -> actualToolCalls += eventContext.toolArgs.toString() } + onToolCallStarting { eventContext -> actualToolCalls += eventContext.toolArgs.toString() } } } @@ -162,7 +162,7 @@ class SingleRunStrategyTests { toolRegistry = testToolRegistry ) { install(EventHandler) { - onToolCall { eventContext -> actualToolCalls += eventContext.toolArgs.toString() } + onToolCallStarting { eventContext -> actualToolCalls += eventContext.toolArgs.toString() } } } @@ -201,7 +201,7 @@ class SingleRunStrategyTests { toolRegistry = testToolRegistry ) { install(EventHandler) { - onToolCall { eventContext -> actualToolCalls += eventContext.toolArgs.toString() } + onToolCallStarting { eventContext -> actualToolCalls += eventContext.toolArgs.toString() } } } @@ -241,7 +241,7 @@ class SingleRunStrategyTests { toolRegistry = testToolRegistry ) { install(EventHandler) { - onToolCall { eventContext -> actualToolCalls += eventContext.toolArgs.toString() } + onToolCallStarting { eventContext -> actualToolCalls += eventContext.toolArgs.toString() } } } @@ -281,7 +281,7 @@ class SingleRunStrategyTests { toolRegistry = testToolRegistry ) { install(EventHandler) { - onToolCall { eventContext -> actualToolCalls += eventContext.toolArgs.toString() } + onToolCallStarting { eventContext -> actualToolCalls += eventContext.toolArgs.toString() } } } diff --git a/agents/agents-core/src/commonTest/kotlin/ai/koog/agents/core/agent/session/AIAgentLLMSessionStructuredOutputTest.kt b/agents/agents-core/src/commonTest/kotlin/ai/koog/agents/core/agent/session/AIAgentLLMSessionStructuredOutputTest.kt new file mode 100644 index 0000000000..dd24d3f6a3 --- /dev/null +++ b/agents/agents-core/src/commonTest/kotlin/ai/koog/agents/core/agent/session/AIAgentLLMSessionStructuredOutputTest.kt @@ -0,0 +1,222 @@ +package ai.koog.agents.core.agent.session + +import ai.koog.agents.core.CalculatorChatExecutor.testClock +import 
ai.koog.agents.core.agent.context.AIAgentLLMContext +import ai.koog.agents.core.agent.context.AgentTestBase +import ai.koog.agents.core.tools.annotations.LLMDescription +import ai.koog.agents.testing.tools.getMockExecutor +import ai.koog.agents.testing.tools.mockLLMAnswer +import ai.koog.prompt.dsl.prompt +import ai.koog.prompt.executor.clients.openai.OpenAIModels +import ai.koog.prompt.message.Message +import ai.koog.prompt.message.ResponseMetaInfo +import ai.koog.prompt.structure.StructuredOutput +import ai.koog.prompt.structure.StructuredOutputConfig +import ai.koog.prompt.structure.json.JsonStructuredData +import kotlinx.coroutines.test.runTest +import kotlinx.serialization.Serializable +import kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.assertNotNull + +class AIAgentLLMSessionStructuredOutputTest : AgentTestBase() { + + @Serializable + data class TestStructure( + @property:LLMDescription("The name field") + val name: String, + @property:LLMDescription("The age field") + val age: Int, + @property:LLMDescription("Optional description field") + val description: String? 
= null + ) + + @Test + fun testParseResponseToStructuredResponse() = runTest { + val structure = JsonStructuredData.createJsonStructure() + val config = StructuredOutputConfig( + default = StructuredOutput.Manual(structure) + ) + + val mockExecutor = getMockExecutor { + mockLLMAnswer("Test response").asDefaultResponse + } + + val prompt = prompt("test") { + system("Test system message") + } + + val llmContext = AIAgentLLMContext( + tools = emptyList(), + prompt = prompt, + model = OpenAIModels.CostOptimized.GPT4oMini, + promptExecutor = mockExecutor, + environment = createTestEnvironment(), + config = createTestConfig(), + clock = testClock + ) + + val context = createTestContext( + llmContext = llmContext + ) + + val jsonResponse = """ + { + "name": "John Doe", + "age": 30, + "description": "A test person" + } + """.trimIndent() + + val assistantMessage = Message.Assistant( + content = jsonResponse, + metaInfo = ResponseMetaInfo.create(testClock) + ) + + val result = context.llm.writeSession { + parseResponseToStructuredResponse(assistantMessage, config) + } + + assertNotNull(result) + assertEquals("John Doe", result.structure.name) + assertEquals(30, result.structure.age) + assertEquals("A test person", result.structure.description) + assertEquals(assistantMessage, result.message) + } + + @Test + fun testParseResponseToStructuredResponseWithNullableField() = runTest { + val structure = JsonStructuredData.createJsonStructure() + val config = StructuredOutputConfig( + default = StructuredOutput.Manual(structure) + ) + + val mockExecutor = getMockExecutor { + mockLLMAnswer("Test response").asDefaultResponse + } + + val prompt = prompt("test") { + system("Test system message") + } + + val llmContext = AIAgentLLMContext( + tools = emptyList(), + prompt = prompt, + model = OpenAIModels.CostOptimized.GPT4oMini, + promptExecutor = mockExecutor, + environment = createTestEnvironment(), + config = createTestConfig(), + clock = testClock + ) + + val context = 
createTestContext( + llmContext = llmContext + ) + + val jsonResponse = """ + { + "name": "Jane Doe", + "age": 25 + } + """.trimIndent() + + val assistantMessage = Message.Assistant( + content = jsonResponse, + metaInfo = ResponseMetaInfo.create(testClock) + ) + + val result = context.llm.writeSession { + parseResponseToStructuredResponse(assistantMessage, config) + } + + assertNotNull(result) + assertEquals("Jane Doe", result.structure.name) + assertEquals(25, result.structure.age) + assertEquals(null, result.structure.description) + assertEquals(assistantMessage, result.message) + } + + @Test + fun testParseResponseToStructuredResponseComplexStructure() = runTest { + @Serializable + data class Address( + @property:LLMDescription("Street name") + val street: String, + @property:LLMDescription("City name") + val city: String + ) + + @Serializable + data class ComplexStructure( + @property:LLMDescription("User identifier") + val id: Int, + @property:LLMDescription("List of addresses") + val addresses: List
, + @property:LLMDescription("Tags associated with the user") + val tags: Set + ) + + val structure = JsonStructuredData.createJsonStructure() + val config = StructuredOutputConfig( + default = StructuredOutput.Manual(structure) + ) + + val mockExecutor = getMockExecutor { + mockLLMAnswer("Test response").asDefaultResponse + } + + val prompt = prompt("test") { + system("Test system message") + } + + val llmContext = AIAgentLLMContext( + tools = emptyList(), + prompt = prompt, + model = OpenAIModels.CostOptimized.GPT4oMini, + promptExecutor = mockExecutor, + environment = createTestEnvironment(), + config = createTestConfig(), + clock = testClock + ) + + val context = createTestContext( + llmContext = llmContext + ) + + val jsonResponse = """ + { + "id": 123, + "addresses": [ + { + "street": "123 Main St", + "city": "New York" + }, + { + "street": "456 Oak Ave", + "city": "Los Angeles" + } + ], + "tags": ["vip", "premium", "verified"] + } + """.trimIndent() + + val assistantMessage = Message.Assistant( + content = jsonResponse, + metaInfo = ResponseMetaInfo.create(testClock) + ) + + val result = context.llm.writeSession { + parseResponseToStructuredResponse(assistantMessage, config) + } + + assertNotNull(result) + assertEquals(123, result.structure.id) + assertEquals(2, result.structure.addresses.size) + assertEquals("123 Main St", result.structure.addresses[0].street) + assertEquals("New York", result.structure.addresses[0].city) + assertEquals("456 Oak Ave", result.structure.addresses[1].street) + assertEquals("Los Angeles", result.structure.addresses[1].city) + assertEquals(setOf("vip", "premium", "verified"), result.structure.tags) + assertEquals(assistantMessage, result.message) + } +} diff --git a/agents/agents-core/src/commonTest/kotlin/ai/koog/agents/core/dsl/extension/AIAgentNodesHistoryCompressionTest.kt b/agents/agents-core/src/commonTest/kotlin/ai/koog/agents/core/dsl/extension/AIAgentNodesHistoryCompressionTest.kt index bdfa27e634..b32b8e63af 100644 --- 
a/agents/agents-core/src/commonTest/kotlin/ai/koog/agents/core/dsl/extension/AIAgentNodesHistoryCompressionTest.kt +++ b/agents/agents-core/src/commonTest/kotlin/ai/koog/agents/core/dsl/extension/AIAgentNodesHistoryCompressionTest.kt @@ -8,6 +8,7 @@ import ai.koog.agents.core.tools.ToolRegistry import ai.koog.agents.features.eventHandler.feature.EventHandler import ai.koog.agents.features.eventHandler.feature.handleEvents import ai.koog.agents.testing.tools.DummyTool +import ai.koog.agents.utils.use import ai.koog.prompt.dsl.prompt import ai.koog.prompt.llm.OllamaModels import ai.koog.prompt.message.Message @@ -55,7 +56,7 @@ class AIAgentNodesHistoryCompressionTest { maxAgentIterations = 10 ) - val runner = AIAgent( + AIAgent( promptExecutor = testExecutor, strategy = agentStrategy, agentConfig = agentConfig, @@ -64,12 +65,12 @@ class AIAgentNodesHistoryCompressionTest { } ) { install(EventHandler) { - onAgentFinished { eventContext -> results += eventContext.result } + onAgentCompleted { eventContext -> results += eventContext.result } } + }.use { agent -> + agent.run("") } - runner.run("") - // After compression, we should have one result assertEquals(1, results.size) assertEquals("Done", results.first()) @@ -107,7 +108,7 @@ class AIAgentNodesHistoryCompressionTest { maxAgentIterations = 10 ) - val runner = AIAgent( + AIAgent( promptExecutor = testExecutor, strategy = agentStrategy, agentConfig = agentConfig, @@ -116,12 +117,12 @@ class AIAgentNodesHistoryCompressionTest { } ) { install(EventHandler) { - onAgentFinished { eventContext -> results += eventContext.result } + onAgentCompleted { eventContext -> results += eventContext.result } } + }.use { agent -> + agent.run("") } - runner.run("") - // After compression, we should have one result assertEquals(1, results.size) assertEquals("Done", results.first()) @@ -162,7 +163,7 @@ class AIAgentNodesHistoryCompressionTest { maxAgentIterations = 10 ) - val runner = AIAgent( + AIAgent( promptExecutor = testExecutor, 
strategy = agentStrategy, agentConfig = agentConfig, @@ -171,12 +172,12 @@ class AIAgentNodesHistoryCompressionTest { } ) { handleEvents { - onAgentFinished { eventContext -> results += eventContext.result } + onAgentCompleted { eventContext -> results += eventContext.result } } + }.use { agent -> + agent.run("") } - runner.run("") - // After compression, we should have one result assertEquals(1, results.size) assertEquals("Done", results.first()) diff --git a/agents/agents-core/src/commonTest/kotlin/ai/koog/agents/core/dsl/extension/AIAgentNodesTest.kt b/agents/agents-core/src/commonTest/kotlin/ai/koog/agents/core/dsl/extension/AIAgentNodesTest.kt index 461118a47e..76787d3a36 100644 --- a/agents/agents-core/src/commonTest/kotlin/ai/koog/agents/core/dsl/extension/AIAgentNodesTest.kt +++ b/agents/agents-core/src/commonTest/kotlin/ai/koog/agents/core/dsl/extension/AIAgentNodesTest.kt @@ -9,12 +9,19 @@ import ai.koog.agents.features.eventHandler.feature.EventHandler import ai.koog.agents.testing.tools.DummyTool import ai.koog.agents.testing.tools.getMockExecutor import ai.koog.agents.testing.tools.mockLLMAnswer +import ai.koog.agents.utils.use +import ai.koog.prompt.dsl.Prompt import ai.koog.prompt.dsl.prompt import ai.koog.prompt.executor.clients.openai.OpenAIModels import ai.koog.prompt.llm.OllamaModels +import ai.koog.prompt.structure.StructuredOutput +import ai.koog.prompt.structure.StructuredOutputConfig +import ai.koog.prompt.structure.json.JsonStructuredData import kotlinx.coroutines.test.runTest +import kotlinx.serialization.Serializable import kotlin.test.Test import kotlin.test.assertEquals +import kotlin.test.assertNotNull import kotlin.test.assertTrue class AIAgentNodesTest { @@ -43,7 +50,7 @@ class AIAgentNodesTest { mockLLMAnswer("Default test response").asDefaultResponse } - val runner = AIAgent( + AIAgent( promptExecutor = testExecutor, strategy = agentStrategy, agentConfig = agentConfig, @@ -52,12 +59,12 @@ class AIAgentNodesTest { } ) { 
install(EventHandler) { - onAgentFinished { eventContext -> results += eventContext.result } + onAgentCompleted { eventContext -> results += eventContext.result } } + }.use { agent -> + agent.run("") } - runner.run("") - // After compression, we should have one result assertEquals(1, results.size) assertEquals("Done", results.first()) @@ -104,7 +111,7 @@ class AIAgentNodesTest { maxAgentIterations = 10 ) - val runner = AIAgent( + AIAgent( promptExecutor = modelCapturingExecutor, strategy = agentStrategy, agentConfig = agentConfig, @@ -113,28 +120,121 @@ class AIAgentNodesTest { } ) { install(EventHandler) { - onAgentFinished { eventContext -> + onAgentCompleted { eventContext -> executionEvents += "Agent finished" results += eventContext.result } } + }.use { agent -> + + val executionResult = agent.run("Heeeey") + + assertEquals("Done", executionResult, "Agent execution should return 'Done'") + assertEquals(1, results.size, "Should have exactly one result") + + assertTrue(executionEvents.contains("nodeStart -> compress"), "Should transition from start to compress") + assertTrue(executionEvents.contains("compress -> nodeFinish"), "Should transition from compress to finish") + + assertTrue( + agentConfig.prompt.messages.any { it.content.contains("testing history compression") }, + "Prompt should contain test content for compression" + ) + assertTrue( + executionEvents.size >= 3, + "Should have at least 3 execution events (agent finished, node transitions)" + ) } + } - val executionResult = runner.run("Heeeey") + @Test + fun testNodeSetStructuredOutput() = runTest { + @Serializable + data class TestOutput( + val message: String, + val code: Int + ) - assertEquals("Done", executionResult, "Agent execution should return 'Done'") - assertEquals(1, results.size, "Should have exactly one result") + // Test Manual mode + val manualStructure = JsonStructuredData.createJsonStructure() + val manualConfig = StructuredOutputConfig( + default = 
StructuredOutput.Manual(manualStructure) + ) - assertTrue(executionEvents.contains("nodeStart -> compress"), "Should transition from start to compress") - assertTrue(executionEvents.contains("compress -> nodeFinish"), "Should transition from compress to finish") + var capturedPrompt: Prompt? = null - assertTrue( - agentConfig.prompt.messages.any { it.content.contains("testing history compression") }, - "Prompt should contain test content for compression" + val manualStrategy = strategy("test-manual") { + val setStructuredOutput by nodeSetStructuredOutput(config = manualConfig) + val checkPrompt by node { input -> + capturedPrompt = llm.prompt + input + } + + edge(nodeStart forwardTo setStructuredOutput) + edge(setStructuredOutput forwardTo checkPrompt) + edge(checkPrompt forwardTo nodeFinish) + } + + val testExecutor = getMockExecutor { + mockLLMAnswer("Test").asDefaultResponse + } + + val agentConfig = AIAgentConfig( + prompt = prompt("test") {}, + model = OpenAIModels.CostOptimized.GPT4oMini, + maxAgentIterations = 5 + ) + + val manualAgent = AIAgent( + promptExecutor = testExecutor, + strategy = manualStrategy, + agentConfig = agentConfig, + toolRegistry = ToolRegistry { } ) + + manualAgent.run("Test input") + + // Manual mode: schema should not be set, user message should be added + assertNotNull(capturedPrompt, "Prompt should be captured") + assertEquals(null, capturedPrompt!!.params.schema, "Schema should not be set for Manual config") assertTrue( - executionEvents.size >= 3, - "Should have at least 3 execution events (agent finished, node transitions)" + capturedPrompt!!.messages.any { it is ai.koog.prompt.message.Message.User }, + "Should have user message with instructions for Manual config" + ) + + // Test Native mode + val nativeStructure = JsonStructuredData.createJsonStructure() + val nativeConfig = StructuredOutputConfig( + default = StructuredOutput.Native(nativeStructure) ) + + val nativeStrategy = strategy("test-native") { + val setStructuredOutput 
by nodeSetStructuredOutput(config = nativeConfig) + val checkPrompt by node { input -> + capturedPrompt = llm.prompt + input + } + + edge(nodeStart forwardTo setStructuredOutput) + edge(setStructuredOutput forwardTo checkPrompt) + edge(checkPrompt forwardTo nodeFinish) + } + + val nativeAgent = AIAgent( + promptExecutor = testExecutor, + strategy = nativeStrategy, + agentConfig = AIAgentConfig( + prompt = prompt("test") {}, + model = OpenAIModels.CostOptimized.GPT4oMini, + maxAgentIterations = 5 + ), + toolRegistry = ToolRegistry { } + ) + + nativeAgent.run("Test input") + + // Native mode: schema should be set + assertNotNull(capturedPrompt, "Prompt should be captured") + assertNotNull(capturedPrompt!!.params.schema, "Schema should be set for Native config") + assertEquals(nativeStructure.schema, capturedPrompt!!.params.schema, "Schema should match structure's schema") } } diff --git a/agents/agents-core/src/commonTest/kotlin/ai/koog/agents/core/feature/TestFeature.kt b/agents/agents-core/src/commonTest/kotlin/ai/koog/agents/core/feature/TestFeature.kt index 8a9eba1e3e..173c75fb90 100644 --- a/agents/agents-core/src/commonTest/kotlin/ai/koog/agents/core/feature/TestFeature.kt +++ b/agents/agents-core/src/commonTest/kotlin/ai/koog/agents/core/feature/TestFeature.kt @@ -70,11 +70,11 @@ class TestFeature(val events: MutableList, val runIds: MutableList + pipeline.interceptToolCallStarting(context) { event -> feature.events += "Tool: call tool (tool: ${event.tool.name}, args: ${event.toolArgs})" } - pipeline.interceptToolExecutionCompleted(context) { event -> + pipeline.interceptToolCallCompleted(context) { event -> feature.events += "Tool: finish tool call with result (tool: ${event.tool.name}, result: ${event.result?.let(event.tool::encodeResultToStringUnsafe) ?: "null"})" } diff --git a/agents/agents-core/src/commonTest/kotlin/ai/koog/agents/core/feature/mock/TestFeatureEventMessage.kt 
b/agents/agents-core/src/commonTest/kotlin/ai/koog/agents/core/feature/mock/TestFeatureEventMessage.kt index 516780addd..d9050425b8 100644 --- a/agents/agents-core/src/commonTest/kotlin/ai/koog/agents/core/feature/mock/TestFeatureEventMessage.kt +++ b/agents/agents-core/src/commonTest/kotlin/ai/koog/agents/core/feature/mock/TestFeatureEventMessage.kt @@ -5,7 +5,6 @@ import ai.koog.agents.core.feature.message.FeatureMessage import kotlinx.datetime.Clock internal class TestFeatureEventMessage(id: String) : FeatureEvent { - override val eventId: String = id override val timestamp: Long get() = Clock.System.now().toEpochMilliseconds() override val messageType: FeatureMessage.Type = FeatureMessage.Type.Event } diff --git a/agents/agents-core/src/jvmTest/kotlin/ai/koog/agents/core/feature/writer/FeatureMessageFileWriterTest.kt b/agents/agents-core/src/jvmTest/kotlin/ai/koog/agents/core/feature/writer/FeatureMessageFileWriterTest.kt index 301cc029e2..968a4678ac 100644 --- a/agents/agents-core/src/jvmTest/kotlin/ai/koog/agents/core/feature/writer/FeatureMessageFileWriterTest.kt +++ b/agents/agents-core/src/jvmTest/kotlin/ai/koog/agents/core/feature/writer/FeatureMessageFileWriterTest.kt @@ -38,7 +38,7 @@ class FeatureMessageFileWriterTest { override fun FeatureMessage.toFileString(): String { return when (this) { - is TestFeatureEventMessage -> "[${this.messageType.value}] ${this.eventId}" + is TestFeatureEventMessage -> "[${this.messageType.value}] ${this.testMessage}" is FeatureStringMessage -> "[${this.messageType.value}] ${this.message}" else -> "UNDEFINED" } @@ -117,15 +117,15 @@ class FeatureMessageFileWriterTest { TestFeatureMessageFileWriter(tempDir).use { writer -> writer.initialize() - val stringMessage = FeatureStringMessage("Test message") - val eventMessage = TestFeatureEventMessage("Test event") + val stringMessage = FeatureStringMessage(message = "Test message") + val eventMessage = TestFeatureEventMessage(testMessage = "Test event") 
writer.onMessage(stringMessage) writer.onMessage(eventMessage) val expectedContent = listOf( "[${stringMessage.messageType.value}] ${stringMessage.message}", - "[${eventMessage.messageType.value}] ${eventMessage.eventId}" + "[${eventMessage.messageType.value}] ${eventMessage.testMessage}" ) val actualContent = writer.targetPath.readLines() diff --git a/agents/agents-core/src/jvmTest/kotlin/ai/koog/agents/core/feature/writer/FeatureMessageLogWriterTest.kt b/agents/agents-core/src/jvmTest/kotlin/ai/koog/agents/core/feature/writer/FeatureMessageLogWriterTest.kt index 3e7059e955..944fc0a0ef 100644 --- a/agents/agents-core/src/jvmTest/kotlin/ai/koog/agents/core/feature/writer/FeatureMessageLogWriterTest.kt +++ b/agents/agents-core/src/jvmTest/kotlin/ai/koog/agents/core/feature/writer/FeatureMessageLogWriterTest.kt @@ -30,7 +30,7 @@ class FeatureMessageLogWriterTest { "message: ${this.message}" } is FeatureEvent -> { - "event id: ${this.eventId}" + "feature events has no message provided" } else -> { "UNDEFINED" @@ -70,8 +70,8 @@ class FeatureMessageLogWriterTest { @Test fun `test logger stream feature provider for event message`() = runBlocking { val messages = listOf( - TestFeatureEventMessage("test-event-1"), - TestFeatureEventMessage("test-event-2"), + TestFeatureEventMessage(testMessage = "test-event-1"), + TestFeatureEventMessage(testMessage = "test-event-2"), ) TestFeatureMessageLogWriter(targetLogger).use { writer -> @@ -79,7 +79,7 @@ class FeatureMessageLogWriterTest { messages.forEach { message -> writer.onMessage(message) } val expectedLogMessages = messages.map { originalMessage -> - "[INFO] Received feature message [${originalMessage.messageType.value}]: event id: ${originalMessage.eventId}" + "[INFO] Received feature message [${originalMessage.messageType.value}]: feature events has no message provided" } assertEquals(expectedLogMessages.size, targetLogger.messages.size) @@ -101,7 +101,7 @@ class FeatureMessageLogWriterTest { val expectedLogMessages = 
listOf( "[INFO] Received feature message [${messages[0].messageType.value}]: message: ${(messages[0] as FeatureStringMessage).message}", - "[INFO] Received feature message [${messages[1].messageType.value}]: event id: ${(messages[1] as TestFeatureEventMessage).eventId}" + "[INFO] Received feature message [${messages[1].messageType.value}]: feature events has no message provided" ) assertEquals(expectedLogMessages.size, targetLogger.messages.size) @@ -141,7 +141,7 @@ class FeatureMessageLogWriterTest { messages.forEach { message -> writer.onMessage(message) } val expectedLogMessages = messages.map { originalMessage -> - "[DEBUG] Received feature message [${originalMessage.messageType.value}]: event id: ${originalMessage.eventId}" + "[DEBUG] Received feature message [${originalMessage.messageType.value}]: feature events has no message provided" } assertEquals(expectedLogMessages.size, targetLogger.messages.size) diff --git a/agents/agents-core/src/jvmTest/kotlin/ai/koog/agents/core/feature/writer/FeatureMessageRemoteWriterTest.kt b/agents/agents-core/src/jvmTest/kotlin/ai/koog/agents/core/feature/writer/FeatureMessageRemoteWriterTest.kt index 5f471aa879..4bb6c70d97 100644 --- a/agents/agents-core/src/jvmTest/kotlin/ai/koog/agents/core/feature/writer/FeatureMessageRemoteWriterTest.kt +++ b/agents/agents-core/src/jvmTest/kotlin/ai/koog/agents/core/feature/writer/FeatureMessageRemoteWriterTest.kt @@ -254,7 +254,7 @@ class FeatureMessageRemoteWriterTest { assertNotNull(actualEventMessage) { "Client received a server SSE message, but it is not a string message" } - assertEquals(testServerMessage.eventId, actualEventMessage.eventId) + assertEquals(testServerMessage.testMessage, actualEventMessage.testMessage) logger.info { "Client is finished successfully" } } diff --git a/agents/agents-core/src/jvmTest/kotlin/ai/koog/agents/core/feature/writer/TestFeatureEventMessage.kt b/agents/agents-core/src/jvmTest/kotlin/ai/koog/agents/core/feature/writer/TestFeatureEventMessage.kt 
index b9e68ab275..fd897ea346 100644 --- a/agents/agents-core/src/jvmTest/kotlin/ai/koog/agents/core/feature/writer/TestFeatureEventMessage.kt +++ b/agents/agents-core/src/jvmTest/kotlin/ai/koog/agents/core/feature/writer/TestFeatureEventMessage.kt @@ -8,7 +8,6 @@ import kotlinx.serialization.Serializable @Serializable data class TestFeatureEventMessage( val testMessage: String, - override val eventId: String = TestFeatureEventMessage::class.simpleName!!, override val timestamp: Long = Clock.System.now().toEpochMilliseconds(), ) : FeatureEvent { diff --git a/agents/agents-ext/src/commonMain/kotlin/ai/koog/agents/ext/agent/AIAgentStrategies.kt b/agents/agents-ext/src/commonMain/kotlin/ai/koog/agents/ext/agent/AIAgentStrategies.kt index 6952929553..70bf252071 100644 --- a/agents/agents-ext/src/commonMain/kotlin/ai/koog/agents/ext/agent/AIAgentStrategies.kt +++ b/agents/agents-ext/src/commonMain/kotlin/ai/koog/agents/ext/agent/AIAgentStrategies.kt @@ -1,17 +1,25 @@ package ai.koog.agents.ext.agent +import ai.koog.agents.core.agent.context.AIAgentGraphContextBase import ai.koog.agents.core.agent.entity.AIAgentGraphStrategy import ai.koog.agents.core.agent.entity.createStorageKey import ai.koog.agents.core.dsl.builder.forwardTo import ai.koog.agents.core.dsl.builder.strategy +import ai.koog.agents.core.dsl.extension.nodeExecuteMultipleTools import ai.koog.agents.core.dsl.extension.nodeExecuteTool import ai.koog.agents.core.dsl.extension.nodeLLMRequest +import ai.koog.agents.core.dsl.extension.nodeLLMRequestMultiple +import ai.koog.agents.core.dsl.extension.nodeLLMSendMultipleToolResults import ai.koog.agents.core.dsl.extension.nodeLLMSendToolResult +import ai.koog.agents.core.dsl.extension.nodeSetStructuredOutput import ai.koog.agents.core.dsl.extension.onAssistantMessage +import ai.koog.agents.core.dsl.extension.onMultipleAssistantMessages +import ai.koog.agents.core.dsl.extension.onMultipleToolCalls import ai.koog.agents.core.dsl.extension.onToolCall import 
ai.koog.agents.core.environment.ReceivedToolResult import ai.koog.agents.core.environment.result import ai.koog.prompt.message.Message +import ai.koog.prompt.structure.StructuredOutputConfig // FIXME improve this strategy to use Message.Assistant to chat, it works better than tools @@ -164,3 +172,74 @@ public fun reActStrategy( edge(nodeExecuteTool forwardTo nodeCallLLMReason) edge(nodeCallLLMReason forwardTo nodeCallLLM) } + +/** + * Defines a strategy for handling structured output with tools integration using specified configuration and execution logic. + * + * This strategy facilitates a structured pipeline for generating outputs using tools and large language models (LLMs), + * enabling transformations between input, intermediate results, and structured output based on the provided configuration and execution behavior. + * + * @param Output The type of the structured output generated by the strategy. + * @param config The configuration for structured output processing, specifying schema, providers, and optional error handling mechanisms. + */ +public inline fun structuredOutputWithToolsStrategy( + config: StructuredOutputConfig, + parallelTools: Boolean = false +): AIAgentGraphStrategy = structuredOutputWithToolsStrategy( + config, + parallelTools +) { it } + +/** + * Defines a strategy for handling structured output with tools integration using specified configuration and execution logic. + * + * This strategy facilitates a structured pipeline for generating outputs using tools and large language models (LLMs), + * enabling transformations between input, intermediate results, and structured output based on the provided configuration and execution behavior. + * + * @param Input The type of the input to be processed by the strategy. + * @param Output The type of the structured output generated by the strategy. + * @param config The configuration for structured output processing, specifying schema, providers, and optional error handling mechanisms. 
+ * @param transform A suspendable function that accepts the input of type `Input` and produces a string output + * that serves as the input for further processing in the structured output pipeline. + */ +public inline fun structuredOutputWithToolsStrategy( + config: StructuredOutputConfig, + parallelTools: Boolean = false, + noinline transform: suspend AIAgentGraphContextBase.(input: Input) -> String +): AIAgentGraphStrategy = strategy("structured_output_with_tools_strategy") { + val setStructuredOutput by nodeSetStructuredOutput(config = config) + val transformInput by node { transform(it) } + val callLLM by nodeLLMRequestMultiple() + val executeTools by nodeExecuteMultipleTools(parallelTools = parallelTools) + val sendToolResult by nodeLLMSendMultipleToolResults() + val transformToStructuredOutput by node { response -> + llm.writeSession { + parseResponseToStructuredResponse(response, config).structure + } + } + + // Set the structured output, get the input and then call the llm + nodeStart then setStructuredOutput then transformInput then callLLM + + // On tools + edge(callLLM forwardTo executeTools onMultipleToolCalls { true }) + edge(executeTools forwardTo sendToolResult) + + // On assistant messages + edge( + callLLM forwardTo transformToStructuredOutput + onMultipleAssistantMessages { true } + transformed { it.single() } + ) + + // Post tool result + edge(sendToolResult forwardTo executeTools onMultipleToolCalls { true }) + edge( + sendToolResult forwardTo transformToStructuredOutput + onMultipleAssistantMessages { true } + transformed { it.first() } + ) + + // Finish + transformToStructuredOutput then nodeFinish +} diff --git a/agents/agents-ext/src/commonTest/kotlin/ai/koog/agents/ext/agent/AIAgentStrategiesTest.kt b/agents/agents-ext/src/commonTest/kotlin/ai/koog/agents/ext/agent/AIAgentStrategiesTest.kt index fc2a0cce9d..6bbfd73103 100644 --- a/agents/agents-ext/src/commonTest/kotlin/ai/koog/agents/ext/agent/AIAgentStrategiesTest.kt +++ 
b/agents/agents-ext/src/commonTest/kotlin/ai/koog/agents/ext/agent/AIAgentStrategiesTest.kt @@ -1,9 +1,15 @@ package ai.koog.agents.ext.agent +import ai.koog.agents.core.tools.annotations.LLMDescription +import ai.koog.prompt.structure.StructuredOutput +import ai.koog.prompt.structure.StructuredOutputConfig +import ai.koog.prompt.structure.json.JsonStructuredData import kotlinx.coroutines.test.runTest +import kotlinx.serialization.Serializable import kotlin.test.Test import kotlin.test.assertEquals import kotlin.test.assertFailsWith +import kotlin.test.assertNotNull class AIAgentStrategiesTest { private val defaultName = "re_act" @@ -43,4 +49,151 @@ class AIAgentStrategiesTest { reActStrategy(reasoningInterval = -1) } } + + @Test + fun testStructuredOutputWithToolsStrategyDefaultName() = runTest { + @Serializable + data class TestOutput( + @property:LLMDescription("Test field") + val field: String + ) + + val structure = JsonStructuredData.createJsonStructure() + val config = StructuredOutputConfig( + default = StructuredOutput.Manual(structure) + ) + + val strategy = structuredOutputWithToolsStrategy(config) { input -> + "Processed: $input" + } + + assertEquals("structured_output_with_tools_strategy", strategy.name) + } + + @Test + fun testStructuredOutputWithToolsStrategyWithParallelTools() = runTest { + @Serializable + data class TestResult( + @property:LLMDescription("Result message") + val message: String, + @property:LLMDescription("Success status") + val success: Boolean + ) + + val structure = JsonStructuredData.createJsonStructure() + val config = StructuredOutputConfig( + default = StructuredOutput.Manual(structure) + ) + + val strategyWithParallel = structuredOutputWithToolsStrategy( + config = config, + parallelTools = true + ) { input -> + "Processing with parallel tools: $input" + } + + assertNotNull(strategyWithParallel) + assertEquals("structured_output_with_tools_strategy", strategyWithParallel.name) + + val strategyWithoutParallel = 
structuredOutputWithToolsStrategy( + config = config, + parallelTools = false + ) { input -> + "Processing without parallel tools: $input" + } + + assertNotNull(strategyWithoutParallel) + assertEquals("structured_output_with_tools_strategy", strategyWithoutParallel.name) + } + + @Test + fun testStructuredOutputWithToolsStrategyComplexTypes() = runTest { + @Serializable + data class Address( + @property:LLMDescription("Street address") + val street: String, + @property:LLMDescription("City") + val city: String, + @property:LLMDescription("ZIP code") + val zipCode: String + ) + + @Serializable + data class ComplexOutput( + @property:LLMDescription("User ID") + val id: Int, + @property:LLMDescription("User name") + val name: String, + @property:LLMDescription("User addresses") + val addresses: List
, + @property:LLMDescription("User preferences") + val preferences: Map? = null + ) + + @Serializable + data class ComplexInput( + val userId: String, + val requestType: String + ) + + val structure = JsonStructuredData.createJsonStructure() + val config = StructuredOutputConfig( + default = StructuredOutput.Manual(structure) + ) + + val strategy = structuredOutputWithToolsStrategy(config) { input -> + "Fetch user data for ID: ${input.userId}, type: ${input.requestType}" + } + + assertNotNull(strategy) + assertEquals("structured_output_with_tools_strategy", strategy.name) + } + + @Test + fun testStructuredOutputWithToolsStrategyDifferentConfigs() = runTest { + @Serializable + data class SimpleOutput( + @property:LLMDescription("Value") + val value: String + ) + + val manualStructure = JsonStructuredData.createJsonStructure() + val nativeStructure = JsonStructuredData.createJsonStructure() + + // Test with manual mode + val manualConfig = StructuredOutputConfig( + default = StructuredOutput.Manual(manualStructure) + ) + + val manualStrategy = structuredOutputWithToolsStrategy(manualConfig) { input -> + "Manual mode: $input" + } + + assertNotNull(manualStrategy) + + // Test with native mode + val nativeConfig = StructuredOutputConfig( + default = StructuredOutput.Native(nativeStructure) + ) + + val nativeStrategy = structuredOutputWithToolsStrategy(nativeConfig) { input -> + "Native mode: $input" + } + + assertNotNull(nativeStrategy) + + // Test with both modes in config + val mixedConfig = StructuredOutputConfig( + default = StructuredOutput.Manual(manualStructure), + byProvider = mapOf( + ai.koog.prompt.llm.LLMProvider.OpenAI to StructuredOutput.Native(nativeStructure) + ) + ) + + val mixedStrategy = structuredOutputWithToolsStrategy(mixedConfig) { input -> + "Mixed mode: $input" + } + + assertNotNull(mixedStrategy) + } } diff --git a/agents/agents-ext/src/commonTest/kotlin/ai/koog/agents/ext/agent/StructuredOutputWithToolsIntegrationTest.kt 
b/agents/agents-ext/src/commonTest/kotlin/ai/koog/agents/ext/agent/StructuredOutputWithToolsIntegrationTest.kt new file mode 100644 index 0000000000..1423b3d1b9 --- /dev/null +++ b/agents/agents-ext/src/commonTest/kotlin/ai/koog/agents/ext/agent/StructuredOutputWithToolsIntegrationTest.kt @@ -0,0 +1,315 @@ +package ai.koog.agents.ext.agent + +import ai.koog.agents.core.agent.AIAgent +import ai.koog.agents.core.agent.config.AIAgentConfig +import ai.koog.agents.core.tools.SimpleTool +import ai.koog.agents.core.tools.ToolRegistry +import ai.koog.agents.core.tools.annotations.LLMDescription +import ai.koog.agents.features.eventHandler.feature.EventHandler +import ai.koog.agents.testing.tools.getMockExecutor +import ai.koog.agents.testing.tools.mockLLMAnswer +import ai.koog.prompt.dsl.prompt +import ai.koog.prompt.executor.clients.openai.OpenAIModels +import ai.koog.prompt.structure.StructuredOutput +import ai.koog.prompt.structure.StructuredOutputConfig +import ai.koog.prompt.structure.json.JsonStructuredData +import kotlinx.coroutines.test.runTest +import kotlinx.serialization.KSerializer +import kotlinx.serialization.Serializable +import kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.assertNotNull + +class StructuredOutputWithToolsIntegrationTest { + + @Serializable + data class WeatherRequest( + val city: String, + val country: String + ) + + @Serializable + data class WeatherResponse( + @property:LLMDescription("Temperature in Celsius") + val temperature: Int, + @property:LLMDescription("Weather conditions") + val conditions: String, + @property:LLMDescription("Wind speed in km/h") + val windSpeed: Double, + @property:LLMDescription("Humidity percentage") + val humidity: Int + ) + + object GetTemperatureTool : SimpleTool() { + @Serializable + data class Args( + @property:LLMDescription("City name") + val city: String, + @property:LLMDescription("Country name") + val country: String + ) + + override val argsSerializer: KSerializer = 
Args.serializer() + + override val name: String = "get_temperature" + override val description: String = "Get current temperature for a city" + + override suspend fun doExecute(args: Args): String = + "Temperature in ${args.city}, ${args.country}: 22°C" + } + + object GetWeatherConditionsTool : SimpleTool() { + @Serializable + data class Args( + @property:LLMDescription("City name") + val city: String, + @property:LLMDescription("Country name") + val country: String + ) + + override val argsSerializer: KSerializer = Args.serializer() + + override val name: String = "get_weather_conditions" + override val description: String = "Get current weather conditions for a city" + + override suspend fun doExecute(args: Args): String = + "Weather conditions in ${args.city}, ${args.country}: Partly Cloudy" + } + + object GetWindSpeedTool : SimpleTool() { + @Serializable + data class Args( + @property:LLMDescription("City name") + val city: String, + @property:LLMDescription("Country name") + val country: String + ) + + override val argsSerializer: KSerializer = Args.serializer() + + override val name: String = "get_wind_speed" + override val description: String = "Get current wind speed for a city" + + override suspend fun doExecute(args: Args): String = + "Wind speed in ${args.city}, ${args.country}: 15.5 km/h" + } + + object GetHumidityTool : SimpleTool() { + @Serializable + data class Args( + @property:LLMDescription("City name") + val city: String, + @property:LLMDescription("Country name") + val country: String + ) + + override val argsSerializer: KSerializer = Args.serializer() + + override val name: String = "get_humidity" + override val description: String = "Get current humidity for a city" + + override suspend fun doExecute(args: Args): String = + "Humidity in ${args.city}, ${args.country}: 65%" + } + + @Test + fun testStructuredOutputWithToolsIntegration() = runTest { + val structure = JsonStructuredData.createJsonStructure() + val config = StructuredOutputConfig( + 
default = StructuredOutput.Manual(structure) + ) + + val strategy = structuredOutputWithToolsStrategy( + config = config, + parallelTools = false + ) { request -> + "Get complete weather data for ${request.city}, ${request.country}" + } + + val toolCallEvents = mutableListOf() + val results = mutableListOf() + + // For common tests, we need to use a simpler mock setup + val mockExecutor = getMockExecutor { + // Simply return the structured output directly + mockLLMAnswer( + """ + { + "temperature": 22, + "conditions": "Partly Cloudy", + "windSpeed": 15.5, + "humidity": 65 + } + """.trimIndent() + ).asDefaultResponse + } + + val agentConfig = AIAgentConfig( + prompt = prompt("weather-agent") { + system( + """ + You are a weather assistant. Use the available tools to gather weather data + and return a complete weather report in the specified JSON format. + """.trimIndent() + ) + }, + model = OpenAIModels.CostOptimized.GPT4oMini, + maxAgentIterations = 10 + ) + + val agent = AIAgent( + promptExecutor = mockExecutor, + strategy = strategy, + agentConfig = agentConfig, + toolRegistry = ToolRegistry { + tool(GetTemperatureTool) + tool(GetWeatherConditionsTool) + tool(GetWindSpeedTool) + tool(GetHumidityTool) + } + ) { + install(EventHandler) { + onToolCallStarting { eventContext -> + toolCallEvents.add(eventContext.tool.name) + } + onAgentCompleted { eventContext -> + eventContext.result?.let { results.add(it as WeatherResponse) } + } + } + } + + val request = WeatherRequest(city = "New York", country = "USA") + val result = agent.run(request) + + assertNotNull(result) + assertEquals(22, result.temperature) + assertEquals("Partly Cloudy", result.conditions) + assertEquals(15.5, result.windSpeed) + assertEquals(65, result.humidity) + } + + @Test + fun testStructuredOutputWithToolsParallelExecution() = runTest { + val structure = JsonStructuredData.createJsonStructure() + val config = StructuredOutputConfig( + default = StructuredOutput.Manual(structure) + ) + + val 
strategy = structuredOutputWithToolsStrategy( + config = config, + parallelTools = true // Enable parallel tool execution + ) { request -> + "Get all weather metrics simultaneously for ${request.city}, ${request.country}" + } + + val toolCallTimestamps = mutableMapOf() + val currentTime = kotlinx.datetime.Clock.System.now().toEpochMilliseconds() + + val mockExecutor = getMockExecutor { + // Return structured output + mockLLMAnswer( + """ + { + "temperature": 18, + "conditions": "Rainy", + "windSpeed": 20.0, + "humidity": 80 + } + """.trimIndent() + ).asDefaultResponse + } + + val agentConfig = AIAgentConfig( + prompt = prompt("weather-agent-parallel") { + system( + """ + You are a weather assistant. Gather all weather metrics in parallel + and return a complete weather report. + """.trimIndent() + ) + }, + model = OpenAIModels.CostOptimized.GPT4oMini, + maxAgentIterations = 10 + ) + + val agent = AIAgent( + promptExecutor = mockExecutor, + strategy = strategy, + agentConfig = agentConfig, + toolRegistry = ToolRegistry { + tool(GetTemperatureTool) + tool(GetWeatherConditionsTool) + tool(GetWindSpeedTool) + tool(GetHumidityTool) + } + ) { + install(EventHandler) { + onToolCallStarting { eventContext -> + toolCallTimestamps[eventContext.tool.name] = currentTime + } + } + } + + val request = WeatherRequest(city = "London", country = "UK") + val result = agent.run(request) + + assertNotNull(result) + assertEquals(18, result.temperature) + assertEquals("Rainy", result.conditions) + assertEquals(20.0, result.windSpeed) + assertEquals(80, result.humidity) + } + + @Test + fun testStructuredOutputWithNoTools() = runTest { + val structure = JsonStructuredData.createJsonStructure() + val config = StructuredOutputConfig( + default = StructuredOutput.Manual(structure) + ) + + val strategy = structuredOutputWithToolsStrategy( + config = config + ) { request -> + "Generate mock weather data for ${request.city}, ${request.country}" + } + + val mockExecutor = getMockExecutor { + // 
LLM directly returns structured output without calling tools + // Set as default response to match any request + mockLLMAnswer( + """ + { + "temperature": 25, + "conditions": "Sunny", + "windSpeed": 10.0, + "humidity": 50 + } + """.trimIndent() + ).asDefaultResponse + } + + val agentConfig = AIAgentConfig( + prompt = prompt("weather-agent-no-tools") { + system("Generate weather data without using tools.") + }, + model = OpenAIModels.CostOptimized.GPT4oMini, + maxAgentIterations = 10 + ) + + val agent = AIAgent( + promptExecutor = mockExecutor, + strategy = strategy, + agentConfig = agentConfig, + toolRegistry = ToolRegistry { } // No tools registered + ) + + val request = WeatherRequest(city = "Paris", country = "France") + val result = agent.run(request) + + assertNotNull(result) + assertEquals(25, result.temperature) + assertEquals("Sunny", result.conditions) + assertEquals(10.0, result.windSpeed) + assertEquals(50, result.humidity) + } +} diff --git a/agents/agents-ext/src/commonTest/kotlin/ai/koog/agents/ext/agent/SubgraphWithRetryTest.kt b/agents/agents-ext/src/commonTest/kotlin/ai/koog/agents/ext/agent/SubgraphWithRetryTest.kt index a82b26c50e..e13b41cc45 100644 --- a/agents/agents-ext/src/commonTest/kotlin/ai/koog/agents/ext/agent/SubgraphWithRetryTest.kt +++ b/agents/agents-ext/src/commonTest/kotlin/ai/koog/agents/ext/agent/SubgraphWithRetryTest.kt @@ -8,6 +8,7 @@ import ai.koog.agents.core.tools.ToolRegistry import ai.koog.agents.features.eventHandler.feature.EventHandler import ai.koog.agents.testing.tools.DummyTool import ai.koog.agents.testing.tools.getMockExecutor +import ai.koog.agents.utils.use import ai.koog.prompt.dsl.prompt import ai.koog.prompt.llm.OllamaModels import ai.koog.prompt.message.Message @@ -119,7 +120,7 @@ class SubgraphWithRetryTest { maxAgentIterations = MAX_AGENT_ITERATIONS, ) - val agent = AIAgent( + AIAgent( promptExecutor = getMockExecutor {}, strategy = testStrategy, agentConfig = agentConfig, @@ -128,12 +129,12 @@ class 
SubgraphWithRetryTest { }, ) { install(EventHandler) { - onAgentFinished { eventContext -> results += eventContext.result } + onAgentCompleted { eventContext -> results += eventContext.result } } + }.use { agent -> + agent.run("test input") } - agent.run("test input") - assertEquals(1, results.size) val result = results.first() as RetrySubgraphResult<*> assertEquals(SUCCESS, result.output) @@ -176,7 +177,7 @@ class SubgraphWithRetryTest { maxAgentIterations = MAX_AGENT_ITERATIONS, ) - val agent = AIAgent( + AIAgent( promptExecutor = getMockExecutor {}, strategy = testStrategy, agentConfig = agentConfig, @@ -185,12 +186,12 @@ class SubgraphWithRetryTest { }, ) { install(EventHandler) { - onAgentFinished { eventContext -> results += eventContext.result } + onAgentCompleted { eventContext -> results += eventContext.result } } + }.use { agent -> + agent.run("test input") } - agent.run("test input") - assertEquals(1, results.size) val result = results.first() as RetrySubgraphResult<*> assertEquals(SUCCESS, result.output) @@ -229,7 +230,7 @@ class SubgraphWithRetryTest { maxAgentIterations = MAX_AGENT_ITERATIONS, ) - val agent = AIAgent( + AIAgent( promptExecutor = getMockExecutor {}, strategy = testStrategy, agentConfig = agentConfig, @@ -238,17 +239,19 @@ class SubgraphWithRetryTest { }, ) { install(EventHandler) { - onAgentFinished { eventContext -> results += eventContext.result } + onAgentCompleted { eventContext -> results += eventContext.result } } - } + }.use { agent -> - agent.run("test input") + agent.run("test input") - assertEquals(1, results.size) - val result = results.first() as RetrySubgraphResult<*> - assertEquals(maxAttempts, result.retryCount) - assertEquals(maxAttempts, attemptCount.size) - assertFalse(result.success) + assertEquals(1, results.size) + + val result = results.first() as RetrySubgraphResult<*> + assertEquals(maxAttempts, result.retryCount) + assertEquals(maxAttempts, attemptCount.size) + assertFalse(result.success) + } } @Test @@ -293,7 
+296,7 @@ class SubgraphWithRetryTest { maxAgentIterations = MAX_AGENT_ITERATIONS, ) - val agent = AIAgent( + AIAgent( promptExecutor = getMockExecutor {}, strategy = testStrategy, agentConfig = agentConfig, @@ -302,12 +305,12 @@ class SubgraphWithRetryTest { }, ) { install(EventHandler) { - onAgentFinished { eventContext -> results += eventContext.result } + onAgentCompleted { eventContext -> results += eventContext.result } } + }.use { agent -> + agent.run("test input") } - agent.run("test input") - assertEquals(1, results.size) assertEquals(SUCCESS, results.first()) } @@ -352,7 +355,7 @@ class SubgraphWithRetryTest { }, ) { install(EventHandler) { - onAgentFinished { eventContext -> results += eventContext.result } + onAgentCompleted { eventContext -> results += eventContext.result } } } @@ -394,7 +397,7 @@ class SubgraphWithRetryTest { maxAgentIterations = MAX_AGENT_ITERATIONS, ) - val agent = AIAgent( + AIAgent( promptExecutor = getMockExecutor {}, strategy = testStrategy, agentConfig = agentConfig, @@ -403,15 +406,16 @@ class SubgraphWithRetryTest { }, ) { install(EventHandler) { - onAgentFinished { eventContext -> results += eventContext.result } + onAgentCompleted { eventContext -> results += eventContext.result } } - } + }.use { agent -> - agent.run("test input") + agent.run("test input") - assertEquals(1, results.size) - assertEquals("failure", results.first()) - assertEquals(maxAttempts, attemptCount.size) + assertEquals(1, results.size) + assertEquals("failure", results.first()) + assertEquals(maxAttempts, attemptCount.size) + } } @Test diff --git a/agents/agents-features/Module.md b/agents/agents-features/Module.md index 055f5de60e..04fcaf9593 100644 --- a/agents/agents-features/Module.md +++ b/agents/agents-features/Module.md @@ -1,65 +1,76 @@ # Module agents:agents-features -Provides implementations of useful features of AI agents, such as Tracing, Debugger, EventHandler, Memory, OpenTelemetry, Snapshot, and more. 
+A collection of plug-and-play features for Koog AI agents. These features hook into the agent execution pipeline to observe, enrich, and extend behavior. -### Overview +## What is inside -Features integrate with the agent pipeline via interceptor hooks and consume standardized Feature Events emitted during agent execution. After the 0afb32b refactor, event and interceptor names are unified across the system. +Each feature lives in its own submodule, you only depend on what you need. Commonly used features include: +- Tracing: End-to-end spans for agent runs and LLM, or Tool calls. Great for local dev and production observability; +- Debugger: Step-through style inspection of the agent pipeline; +- EventHandler: Subscribe to standardized agent events and react (log, metrics, custom side effects); +- Memory: Pluggable memory interfaces for storing and retrieving agent context; +- OpenTelemetry: OTel exporters and wiring for spans emitted by the agent pipeline; +- Snapshot: Persist and restore agent snapshots for reproducibility and time-travel debugging. -### Standard Feature Events +Check each feature’s own README/Module docs for details and advanced configuration. -- Agent events: - - AgentStartingEvent - - AgentCompletedEvent - - AgentExecutionFailedEvent - - AgentClosingEvent +## How features integrate -- Strategy events: - - GraphStrategyStartingEvent - - FunctionalStrategyStartingEvent - - StrategyCompletedEvent +Features integrate via interceptor hooks and consume standardized events emitted during an agent execution. 
These events are defined in the agents-core module under: +- ai.koog.agents.core.feature.model.events -- Node execution events: - - NodeExecutionStartingEvent - - NodeExecutionCompletedEvent - - NodeExecutionFailedEvent +Typical events include: +- AgentStarting/Completed +- LLMCallStarting/Completed +- ToolCallStarting/Completed -- LLM call events: - - LLMCallStartingEvent - - LLMCallCompletedEvent +Features can listen to these events, mutate context when appropriate, and publish additional events for downstream consumers. -- Tool execution events: - - ToolExecutionStartingEvent - - ToolValidationFailedEvent - - ToolExecutionFailedEvent - - ToolExecutionCompletedEvent +## Installing features -These events are produced by features such as Tracing and Debugger to enable logging, tracing, monitoring, and remote inspection. - -### Using in your project - -Add the desired feature dependency, for example: +Install features when constructing your agent. Multiple features can be installed together; they remain decoupled and communicate via events. ```kotlin -dependencies { - implementation("ai.koog.agents:agents-features-trace:$version") - implementation("ai.koog.agents:agents-features-debugger:$version") +val agent = createAgent(/* ... */) { + install(Tracing) { + // Tracing configuration + } + install(OpenTelemetry) { + // OTel configuration + } } ``` -Install a feature in the agent builder: +Consult each feature’s README for exact configuration options and defaults. + +## Using in unit tests + +Features are test-friendly. They honor testing configurations and can be directed to in-memory writers/ports. +- Install the same feature in tests to capture events deterministically. +- Point outputs to test stubs to assert behavior (e.g., assert a specific sequence of Feature Events). +- Prefer higher sampling in tests so important transitions are recorded. +Example (pseudo): ```kotlin -val agent = createAgent(/* ... 
*/) { - install(Tracing) - install(Debugger) +@Test +fun testAgentEmitsExpectedEvents() { + val events = MutableListWriter() + createAgent { + install(EventHandler) { + writer = events + } + }.use { agent -> + agent.run("Hello") + assertTrue(events.any { it is LlmCallRequested }) + } } ``` -### Using in unit tests - -Most features can be installed in tests; they honor testing configuration and can be pointed to test writers/ports. - -### Example of usage - -See each feature's Module/README in its submodule for concrete examples (Tracing, Debugger, EventHandler, Memory, OpenTelemetry, Snapshot). +## Where to learn more +See each feature’s Module/README in its submodule for concrete examples and advanced setup: +- Tracing +- Debugger +- EventHandler +- Memory +- OpenTelemetry +- Snapshot diff --git a/agents/agents-features/agents-features-a2a-client/Module.md b/agents/agents-features/agents-features-a2a-client/Module.md new file mode 100644 index 0000000000..985dc62807 --- /dev/null +++ b/agents/agents-features/agents-features-a2a-client/Module.md @@ -0,0 +1,12 @@ +# Module agents-features-a2a-client + +Agent feature for enabling A2A client capabilities within Koog agents. + +## Overview + +This module provides an agent feature that enables Koog agents to act as A2A clients, allowing them to communicate with remote A2A-enabled agents. When installed, agents gain access to a registry of A2A clients and pre-built nodes for sending messages, managing tasks, and subscribing to events from other A2A agents. 
+ +## Key Components + +- **`A2AAgentClient`**: Feature providing access to registered A2A clients from agent nodes +- **Agent Nodes**: Pre-built nodes for common operations (messaging, task management, push notifications) diff --git a/agents/agents-features/agents-features-a2a-client/build.gradle.kts b/agents/agents-features/agents-features-a2a-client/build.gradle.kts new file mode 100644 index 0000000000..2af1909021 --- /dev/null +++ b/agents/agents-features/agents-features-a2a-client/build.gradle.kts @@ -0,0 +1,40 @@ +import ai.koog.gradle.publish.maven.Publishing.publishToMaven + +group = rootProject.group +version = rootProject.version + +plugins { + id("ai.kotlin.multiplatform") + alias(libs.plugins.kotlin.serialization) +} + +kotlin { + sourceSets { + commonMain { + dependencies { + api(project(":agents:agents-core")) + api(project(":a2a:a2a-client")) + api(project(":agents:agents-features:agents-features-a2a-core")) + + api(libs.kotlinx.serialization.json) + } + } + + commonTest { + dependencies { + implementation(kotlin("test")) + implementation(libs.kotlinx.coroutines.test) + } + } + + jvmTest { + dependencies { + implementation(kotlin("test-junit5")) + } + } + } + + explicitApi() +} + +publishToMaven() diff --git a/agents/agents-features/agents-features-a2a-client/src/commonMain/kotlin/ai/koog/agents/a2a/client/feature/A2AAgentClient.kt b/agents/agents-features/agents-features-a2a-client/src/commonMain/kotlin/ai/koog/agents/a2a/client/feature/A2AAgentClient.kt new file mode 100644 index 0000000000..d2af12c29c --- /dev/null +++ b/agents/agents-features/agents-features-a2a-client/src/commonMain/kotlin/ai/koog/agents/a2a/client/feature/A2AAgentClient.kt @@ -0,0 +1,111 @@ +package ai.koog.agents.a2a.client.feature + +import ai.koog.a2a.client.A2AClient +import ai.koog.agents.core.agent.context.AIAgentContext +import ai.koog.agents.core.agent.entity.AIAgentStorageKey +import ai.koog.agents.core.agent.entity.createStorageKey +import 
ai.koog.agents.core.feature.AIAgentGraphFeature +import ai.koog.agents.core.feature.AIAgentGraphPipeline +import ai.koog.agents.core.feature.AIAgentNonGraphFeature +import ai.koog.agents.core.feature.AIAgentNonGraphPipeline +import ai.koog.agents.core.feature.config.FeatureConfig +import kotlin.contracts.ExperimentalContracts +import kotlin.contracts.InvocationKind +import kotlin.contracts.contract + +/** + * Agent feature that enables A2A client mode by providing access to registered A2A clients + * from within agent strategies. + * + * This feature allows agents to communicate with other A2A-enabled agents by making + * [A2AClient] instances available to agent nodes. When installed, it provides access to + * a registry of A2A clients, allowing agent nodes to: + * - Send messages to remote A2A agents + * - Retrieve agent cards and capabilities + * - Manage tasks on remote agents + * - Subscribe to task events and streaming responses + * - Configure push notifications + * + * The feature provides convenience nodes for common A2A client + * operations like sending messages, retrieving tasks, and managing subscriptions. + * + * @property a2aClients Map of A2A clients keyed by agent ID + * + * @see ai.koog.a2a.client.A2AClient + */ +public class A2AAgentClient( + public val a2aClients: Map +) { + /** + * Configuration for the [A2AAgentClient] feature. + */ + public class Config : FeatureConfig() { + /** + * Map of [A2AClient] instances keyed by agent ID for accessing remote A2A agents. 
+ */ + public val a2aClients: Map = mapOf() + } + + public companion object Feature : + AIAgentGraphFeature, + AIAgentNonGraphFeature { + + override val key: AIAgentStorageKey = + createStorageKey("agents-features-a2a-client") + + override fun createInitialConfig(): Config = Config() + + override fun install( + config: Config, + pipeline: AIAgentGraphPipeline + ) { + pipeline.interceptContextAgentFeature(this) { _ -> + A2AAgentClient(config.a2aClients) + } + } + + override fun install( + config: Config, + pipeline: AIAgentNonGraphPipeline + ) { + pipeline.interceptContextAgentFeature(this) { + A2AAgentClient(config.a2aClients) + } + } + } +} + +/** + * Retrieves the [A2AAgentClient] feature from the agent context. + * + * @return The installed A2AAgentClient feature + * @throws IllegalStateException if the feature is not installed + */ +public fun AIAgentContext.a2aAgentClient(): A2AAgentClient = featureOrThrow(A2AAgentClient.Feature) + +/** + * Executes an action with the [A2AAgentClient] feature as the receiver. + * This is a convenience function that retrieves the feature and provides it as the receiver for the action block. + * + * @param action The action to execute with A2AAgentClient as receiver + * @return The result of the action + * @throws IllegalStateException if the feature is not installed + */ +@OptIn(ExperimentalContracts::class) +public inline fun AIAgentContext.withA2AAgentClient(action: A2AAgentClient.() -> T): T { + contract { + callsInPlace(action, InvocationKind.AT_MOST_ONCE) + } + + return a2aAgentClient().action() +} + +/** + * Retrieves an A2A client by agent ID or throws if not found. 
+ * + * @param agentId The identifier of the A2A agent to retrieve + * @return The A2AClient instance for the specified agent ID + * @throws NoSuchElementException if no client is registered with the given agent ID + */ +public fun A2AAgentClient.a2aClientOrThrow(agentId: String): A2AClient = + a2aClients[agentId] ?: throw NoSuchElementException("A2A agent with id $agentId not found in the current agent context. Make sure to register it in the A2AAgentClient feature.") diff --git a/agents/agents-features/agents-features-a2a-client/src/commonMain/kotlin/ai/koog/agents/a2a/client/feature/A2AAgentClientNodes.kt b/agents/agents-features/agents-features-a2a-client/src/commonMain/kotlin/ai/koog/agents/a2a/client/feature/A2AAgentClientNodes.kt new file mode 100644 index 0000000000..c40a0bf065 --- /dev/null +++ b/agents/agents-features/agents-features-a2a-client/src/commonMain/kotlin/ai/koog/agents/a2a/client/feature/A2AAgentClientNodes.kt @@ -0,0 +1,293 @@ +package ai.koog.agents.a2a.client.feature + +import ai.koog.a2a.model.AgentCard +import ai.koog.a2a.model.CommunicationEvent +import ai.koog.a2a.model.Event +import ai.koog.a2a.model.MessageSendParams +import ai.koog.a2a.model.Task +import ai.koog.a2a.model.TaskIdParams +import ai.koog.a2a.model.TaskPushNotificationConfig +import ai.koog.a2a.model.TaskPushNotificationConfigParams +import ai.koog.a2a.model.TaskQueryParams +import ai.koog.a2a.transport.ClientCallContext +import ai.koog.a2a.transport.Request +import ai.koog.a2a.transport.Response +import ai.koog.agents.core.dsl.builder.AIAgentBuilderDslMarker +import ai.koog.agents.core.dsl.builder.AIAgentNodeDelegate +import ai.koog.agents.core.dsl.builder.AIAgentSubgraphBuilderBase +import kotlinx.coroutines.flow.Flow +import kotlinx.serialization.Serializable + +/** + * Request parameters for A2A client operations that require call context and typed parameters. 
+ * + * @property agentId The identifier of the A2A agent to send the request to, with which it is registered in [A2AAgentClient]. + * @property callContext The [io.ktor.server.application.CallContext] + * @property params The typed parameters for the specific A2A operation + */ +@Serializable +public data class A2AClientRequest( + val agentId: String, + val callContext: ClientCallContext, + val params: T +) + +/** + * Information about a registered A2A agent client. + * + * @property agentId The identifier with which the agent is registered in [A2AAgentClient] + * @property agentCard The cached agent card for this agent + */ +public data class A2AClientAgentInfo( + val agentId: String, + val agentCard: AgentCard, +) + +/** + * Creates a node that retrieves information about all A2A agents registered in [A2AAgentClient]. + * + * @param name Optional node name for debugging and tracing + * @return A node that returns a list of all registered agents with their cached agent cards + */ +@AIAgentBuilderDslMarker +public fun AIAgentSubgraphBuilderBase<*, *>.nodeA2AClientGetAllAgents( + name: String? = null, +): AIAgentNodeDelegate> = + node(name) { + withA2AAgentClient { + a2aClients.map { (agentId, a2aClient) -> + A2AClientAgentInfo( + agentId = agentId, + agentCard = a2aClient.cachedAgentCard() + ) + } + } + } + +/** + * Creates a node that retrieves an agent card from an A2A server. + * Input is an agent id with which [ai.koog.a2a.client.A2AClient] is registered in [A2AAgentClient]. + * + * @see ai.koog.a2a.client.A2AClient.getAgentCard + */ +@AIAgentBuilderDslMarker +public fun AIAgentSubgraphBuilderBase<*, *>.nodeA2AClientGetAgentCard( + name: String? = null, +): AIAgentNodeDelegate = + node(name) { agentId -> + withA2AAgentClient { + a2aClientOrThrow(agentId).getAgentCard() + } + } + +/** + * Creates a node that retrieves the cached agent card without making a network call. 
+ * Input is an agent id with which [ai.koog.a2a.client.A2AClient] is registered in [A2AAgentClient]. + * + * @see ai.koog.a2a.client.A2AClient.cachedAgentCard + */ +@AIAgentBuilderDslMarker +public fun AIAgentSubgraphBuilderBase<*, *>.nodeA2AClientCachedAgentCard( + name: String? = null, +): AIAgentNodeDelegate = + node(name) { agentId -> + withA2AAgentClient { + a2aClientOrThrow(agentId).cachedAgentCard() + } + } + +/** + * Creates a node that retrieves an authenticated extended agent card. + * + * @see ai.koog.a2a.client.A2AClient.getAuthenticatedExtendedAgentCard + */ +@AIAgentBuilderDslMarker +public fun AIAgentSubgraphBuilderBase<*, *>.nodeA2AClientGetAuthenticatedExtendedAgentCard( + name: String? = null, +): AIAgentNodeDelegate, Response> = + node(name) { request -> + withA2AAgentClient { + a2aClientOrThrow(request.agentId) + .getAuthenticatedExtendedAgentCard( + request = Request(data = request.params), + ctx = request.callContext, + ) + } + } + +/** + * Creates a node that sends a message to an A2A agent. + * + * @see ai.koog.a2a.client.A2AClient.sendMessage + */ +@AIAgentBuilderDslMarker +public fun AIAgentSubgraphBuilderBase<*, *>.nodeA2AClientSendMessage( + name: String? = null, +): AIAgentNodeDelegate, CommunicationEvent> = + node(name) { request -> + withA2AAgentClient { + a2aClientOrThrow(request.agentId) + .sendMessage( + request = Request(data = request.params), + ctx = request.callContext, + ) + .data + } + } + +/** + * Creates a node that sends a message to an A2A agent with streaming response. + * + * @see ai.koog.a2a.client.A2AClient.sendMessageStreaming + */ +@AIAgentBuilderDslMarker +public fun AIAgentSubgraphBuilderBase<*, *>.nodeA2AClientSendMessageStreaming( + name: String? 
= null, +): AIAgentNodeDelegate, Flow>> = + node(name) { request -> + withA2AAgentClient { + a2aClientOrThrow(request.agentId) + .sendMessageStreaming( + request = Request(data = request.params), + ctx = request.callContext, + ) + } + } + +/** + * Creates a node that retrieves a task by ID from an A2A agent. + * + * @see ai.koog.a2a.client.A2AClient.getTask + */ +@AIAgentBuilderDslMarker +public fun AIAgentSubgraphBuilderBase<*, *>.nodeA2AClientGetTask( + name: String? = null, +): AIAgentNodeDelegate, Task> = + node(name) { request -> + withA2AAgentClient { + a2aClientOrThrow(request.agentId) + .getTask( + request = Request(data = request.params), + ctx = request.callContext, + ) + .data + } + } + +/** + * Creates a node that cancels a task on an A2A agent. + * + * @see ai.koog.a2a.client.A2AClient.cancelTask + */ +@AIAgentBuilderDslMarker +public fun AIAgentSubgraphBuilderBase<*, *>.nodeA2AClientCancelTask( + name: String? = null, +): AIAgentNodeDelegate, Task> = + node(name) { request -> + withA2AAgentClient { + a2aClientOrThrow(request.agentId) + .cancelTask( + request = Request(data = request.params), + ctx = request.callContext, + ) + .data + } + } + +/** + * Creates a node that resubscribes to task events from an A2A agent. + * + * @see ai.koog.a2a.client.A2AClient.resubscribeTask + */ +@AIAgentBuilderDslMarker +public fun AIAgentSubgraphBuilderBase<*, *>.nodeA2AClientResubscribeTask( + name: String? = null, +): AIAgentNodeDelegate, Flow>> = + node(name) { request -> + withA2AAgentClient { + a2aClientOrThrow(request.agentId) + .resubscribeTask( + request = Request(data = request.params), + ctx = request.callContext, + ) + } + } + +/** + * Creates a node that sets push notification configuration for a task. + * + * @see ai.koog.a2a.client.A2AClient.setTaskPushNotificationConfig + */ +@AIAgentBuilderDslMarker +public fun AIAgentSubgraphBuilderBase<*, *>.nodeA2AClientSetTaskPushNotificationConfig( + name: String? 
= null, +): AIAgentNodeDelegate, TaskPushNotificationConfig> = + node(name) { request -> + withA2AAgentClient { + a2aClientOrThrow(request.agentId) + .setTaskPushNotificationConfig( + request = Request(data = request.params), + ctx = request.callContext, + ) + .data + } + } + +/** + * Creates a node that retrieves push notification configuration for a task. + * + * @see ai.koog.a2a.client.A2AClient.getTaskPushNotificationConfig + */ +@AIAgentBuilderDslMarker +public fun AIAgentSubgraphBuilderBase<*, *>.nodeA2AClientGetTaskPushNotificationConfig( + name: String? = null, +): AIAgentNodeDelegate, TaskPushNotificationConfig> = + node(name) { request -> + withA2AAgentClient { + a2aClientOrThrow(request.agentId) + .getTaskPushNotificationConfig( + request = Request(data = request.params), + ctx = request.callContext, + ) + .data + } + } + +/** + * Creates a node that lists all push notification configurations for a task. + * + * @see ai.koog.a2a.client.A2AClient.listTaskPushNotificationConfig + */ +@AIAgentBuilderDslMarker +public fun AIAgentSubgraphBuilderBase<*, *>.nodeA2AClientListTaskPushNotificationConfig( + name: String? = null, +): AIAgentNodeDelegate, List> = + node(name) { request -> + withA2AAgentClient { + a2aClientOrThrow(request.agentId) + .listTaskPushNotificationConfig( + request = Request(data = request.params), + ctx = request.callContext, + ) + .data + } + } + +/** + * Creates a node that deletes push notification configuration for a task. + * + * @see ai.koog.a2a.client.A2AClient.deleteTaskPushNotificationConfig + */ +@AIAgentBuilderDslMarker +public fun AIAgentSubgraphBuilderBase<*, *>.nodeA2AClientDeleteTaskPushNotificationConfig( + name: String? 
= null, +): AIAgentNodeDelegate, Unit> = + node(name) { request -> + withA2AAgentClient { + a2aClientOrThrow(request.agentId) + .deleteTaskPushNotificationConfig( + request = Request(data = request.params), + ctx = request.callContext, + ) + } + } diff --git a/agents/agents-features/agents-features-a2a-core/Module.md b/agents/agents-features/agents-features-a2a-core/Module.md new file mode 100644 index 0000000000..b40aaa5b5b --- /dev/null +++ b/agents/agents-features/agents-features-a2a-core/Module.md @@ -0,0 +1,7 @@ +# Module agents-features-a2a-core + +Core utilities and type converters for A2A (Agent-to-Agent) integration with Koog agents. + +## Overview + +This module provides type converters that bridge between A2A's message format and Koog's internal message representation, enabling seamless communication between A2A-enabled agents and Koog's agent system. diff --git a/agents/agents-features/agents-features-a2a-core/build.gradle.kts b/agents/agents-features/agents-features-a2a-core/build.gradle.kts new file mode 100644 index 0000000000..a2edcb5c55 --- /dev/null +++ b/agents/agents-features/agents-features-a2a-core/build.gradle.kts @@ -0,0 +1,39 @@ +import ai.koog.gradle.publish.maven.Publishing.publishToMaven + +group = rootProject.group +version = rootProject.version + +plugins { + id("ai.kotlin.multiplatform") + alias(libs.plugins.kotlin.serialization) +} + +kotlin { + sourceSets { + commonMain { + dependencies { + api(project(":prompt:prompt-model")) + api(project(":a2a:a2a-core")) + + api(libs.kotlinx.serialization.json) + } + } + + commonTest { + dependencies { + implementation(kotlin("test")) + implementation(libs.kotlinx.coroutines.test) + } + } + + jvmTest { + dependencies { + implementation(kotlin("test-junit5")) + } + } + } + + explicitApi() +} + +publishToMaven() diff --git a/agents/agents-features/agents-features-a2a-core/src/commonMain/kotlin/ai/koog/agents/a2a/core/MessageA2AMetadata.kt 
b/agents/agents-features/agents-features-a2a-core/src/commonMain/kotlin/ai/koog/agents/a2a/core/MessageA2AMetadata.kt new file mode 100644 index 0000000000..877e6ee2d4 --- /dev/null +++ b/agents/agents-features/agents-features-a2a-core/src/commonMain/kotlin/ai/koog/agents/a2a/core/MessageA2AMetadata.kt @@ -0,0 +1,48 @@ +package ai.koog.agents.a2a.core + +import ai.koog.prompt.message.MessageMetaInfo +import kotlinx.serialization.Serializable +import kotlinx.serialization.json.JsonObject +import kotlinx.serialization.json.decodeFromJsonElement +import kotlinx.serialization.json.encodeToJsonElement + +/** + * Key used to store [MessageA2AMetadata] in [ai.koog.prompt.message.MessageMetaInfo.metadata] + */ +public const val MESSAGE_A2A_METADATA_KEY: String = "a2a_metadata" + +/** + * Represents additional A2A-related message metadata stored in [ai.koog.prompt.message.MessageMetaInfo.metadata] + * For more info on each field, check [ai.koog.a2a.model.Message] + */ +@Serializable +public data class MessageA2AMetadata( + val messageId: String, + val contextId: String? = null, + val taskId: String? = null, + val referenceTaskIds: List? = null, + val metadata: JsonObject? = null, + val extensions: List? = null, +) + +/** + * Retrieves [MessageA2AMetadata] from [MessageMetaInfo.metadata], if [MESSAGE_A2A_METADATA_KEY] is present. + */ +public fun MessageMetaInfo.getA2AMetadata(): MessageA2AMetadata? { + return metadata + ?.get(MESSAGE_A2A_METADATA_KEY) + ?.let { a2aMetadata -> + A2AFeatureJson.decodeFromJsonElement(a2aMetadata) + } +} + +/** + * Updates provided [JsonObject], overwriting [MESSAGE_A2A_METADATA_KEY] with [a2aMetadata]. 
+ */ +public fun JsonObject.withA2AMetadata(a2aMetadata: MessageA2AMetadata): JsonObject { + return toMutableMap() + .apply { + put(MESSAGE_A2A_METADATA_KEY, A2AFeatureJson.encodeToJsonElement(a2aMetadata)) + } + .let { JsonObject(it) } +} diff --git a/agents/agents-features/agents-features-a2a-core/src/commonMain/kotlin/ai/koog/agents/a2a/core/MessageConverters.kt b/agents/agents-features/agents-features-a2a-core/src/commonMain/kotlin/ai/koog/agents/a2a/core/MessageConverters.kt new file mode 100644 index 0000000000..acbf239d28 --- /dev/null +++ b/agents/agents-features/agents-features-a2a-core/src/commonMain/kotlin/ai/koog/agents/a2a/core/MessageConverters.kt @@ -0,0 +1,98 @@ +package ai.koog.agents.a2a.core + +import ai.koog.a2a.model.Role +import ai.koog.prompt.message.Message +import ai.koog.prompt.message.RequestMetaInfo +import ai.koog.prompt.message.ResponseMetaInfo +import kotlinx.datetime.Clock +import kotlinx.serialization.json.JsonObject +import kotlin.uuid.ExperimentalUuidApi +import kotlin.uuid.Uuid + +/** + * Alias to A2A message type, to avoid clashing with Koog's message type. + * @see [ai.koog.a2a.model.Message] + */ +public typealias A2AMessage = ai.koog.a2a.model.Message + +/** + * Converts [A2AMessage] to Koog's [Message]. + * Returned message will contain [MessageA2AMetadata] at [MESSAGE_A2A_METADATA_KEY] in [ai.koog.prompt.message.MessageMetaInfo.metadata], + * which can be retrieved with helper method [getA2AMetadata]. + * + * @param clock The clock to use for the timestamp. Defaults to [Clock.System]. + */ +public fun A2AMessage.toKoogMessage( + clock: Clock = Clock.System, +): Message { + // Convert to the actual message content and attachments. 
+ val (content, attachments) = parts.map { it.toKoogPart() }.toContentWithAttachments() + + // Create metadata + val metadata = JsonObject(emptyMap()).withA2AMetadata( + MessageA2AMetadata( + messageId = messageId, + contextId = contextId, + taskId = taskId, + referenceTaskIds = referenceTaskIds, + metadata = metadata, + extensions = extensions, + ) + ) + + return when (role) { + Role.User -> Message.User( + content = content, + metaInfo = RequestMetaInfo( + timestamp = clock.now(), + metadata = metadata, + ), + attachments = attachments.toList(), + ) + + Role.Agent -> Message.Assistant( + content = content, + metaInfo = ResponseMetaInfo( + timestamp = clock.now(), + metadata = metadata, + ), + attachments = attachments, + ) + } +} + +/** + * Converts Koog's [Message] to [A2AMessage]. + * To fill A2A-specific fields, it will attempt to read [MessageA2AMetadata] from [ai.koog.prompt.message.MessageMetaInfo.metadata], + * but it also can be overridden with [a2aMetadata] + * + * @param a2aMetadata The A2A-specific metadata to override exiting in this [Message]. + * @see ai.koog.a2a.model.Message + */ +@OptIn(ExperimentalUuidApi::class) +public fun Message.toA2AMessage( + a2aMetadata: MessageA2AMetadata? 
= null, +): A2AMessage { + val actualMetadata = a2aMetadata ?: metaInfo.getA2AMetadata() + + val role = when (this) { + is Message.User -> Role.User + is Message.Assistant -> Role.Agent + else -> throw IllegalArgumentException("A2A can't handle this Koog message type: $this") + } + + // Add parts + val parts = (listOf(KoogContentPart(content)) + attachments.map { KoogAttachmentPart(it) }) + .map { it.toA2APart() } + + return A2AMessage( + messageId = actualMetadata?.messageId ?: Uuid.random().toString(), + role = role, + parts = parts, + extensions = actualMetadata?.extensions, + taskId = actualMetadata?.taskId, + referenceTaskIds = actualMetadata?.referenceTaskIds, + contextId = actualMetadata?.contextId, + metadata = actualMetadata?.metadata + ) +} diff --git a/agents/agents-features/agents-features-a2a-core/src/commonMain/kotlin/ai/koog/agents/a2a/core/PartConverters.kt b/agents/agents-features/agents-features-a2a-core/src/commonMain/kotlin/ai/koog/agents/a2a/core/PartConverters.kt new file mode 100644 index 0000000000..47e932bd5c --- /dev/null +++ b/agents/agents-features/agents-features-a2a-core/src/commonMain/kotlin/ai/koog/agents/a2a/core/PartConverters.kt @@ -0,0 +1,109 @@ +package ai.koog.agents.a2a.core + +import ai.koog.a2a.model.DataPart +import ai.koog.a2a.model.FilePart +import ai.koog.a2a.model.FileWithBytes +import ai.koog.a2a.model.FileWithUri +import ai.koog.a2a.model.Part +import ai.koog.a2a.model.TextPart +import ai.koog.prompt.message.Attachment +import ai.koog.prompt.message.AttachmentContent +import kotlinx.serialization.Serializable + +/** + * Koog doesn't have proper support for message parts yet, but A2A operates with parts. + * This is a helper mapping structure, to map A2A parts either to part of the textual content or to the attachment. 
+ */ +@Serializable +public sealed interface KoogPart + +/** + * Text content part representing part of the [ai.koog.prompt.message.Message.content] + */ +@Serializable +public data class KoogContentPart(val content: String) : KoogPart + +/** + * Attachment part representing part of the [ai.koog.prompt.message.Message.Assistant.attachments] or + * [ai.koog.prompt.message.Message.User.attachments] + */ +@Serializable +public data class KoogAttachmentPart(val attachment: Attachment) : KoogPart + +/** + * Converts A2A [Part] to Koog [KoogPart]. + */ +public fun Part.toKoogPart(): KoogPart = when (this) { + is TextPart -> KoogContentPart(this.text) + // Koog doesn't support structured data as a separate type, treat it as a content part. + + is DataPart -> KoogContentPart(A2AFeatureJson.encodeToString(this.data)) + + is FilePart -> { + val file = this.file // to enable smart cast + + val attachment = Attachment.File( + // do not have that information separately in A2A + format = "", + // if no mime type is provided, assume it's arbitrary binary data + mimeType = file.mimeType ?: "application/octet-stream", + fileName = file.name, + content = when (file) { + is FileWithBytes -> AttachmentContent.Binary.Base64(file.bytes) + is FileWithUri -> AttachmentContent.URL(file.uri) + } + ) + + KoogAttachmentPart(attachment) + } +} + +/** + * Converts Koog [KoogPart] to A2A [Part]. + */ +public fun KoogPart.toA2APart(): Part = when (this) { + is KoogContentPart -> TextPart(this.content) + + is KoogAttachmentPart -> { + val file = when (val content = attachment.content) { + // Plain text files are not supported, convert them to binary files. 
+ is AttachmentContent.PlainText -> FileWithBytes( + bytes = AttachmentContent.Binary.Bytes(content.text.encodeToByteArray()) + .asBase64(), + name = attachment.fileName, + mimeType = attachment.mimeType, + ) + + is AttachmentContent.Binary -> FileWithBytes( + bytes = content.asBase64(), + name = attachment.fileName, + mimeType = attachment.mimeType, + ) + + is AttachmentContent.URL -> FileWithUri( + uri = content.url, + name = attachment.fileName, + mimeType = attachment.mimeType, + ) + } + + FilePart(file) + } +} + +/** + * Helper method to convert an iterable of [KoogPart] to the pair of the text content and attachments. + */ +public fun Iterable.toContentWithAttachments(): Pair> { + val content = StringBuilder() + val attachments = mutableListOf() + + forEach { part -> + when (part) { + is KoogContentPart -> content.appendLine(part.content) + is KoogAttachmentPart -> attachments.add(part.attachment) + } + } + + return content.toString().trim() to attachments +} diff --git a/agents/agents-features/agents-features-a2a-core/src/commonMain/kotlin/ai/koog/agents/a2a/core/Serialization.kt b/agents/agents-features/agents-features-a2a-core/src/commonMain/kotlin/ai/koog/agents/a2a/core/Serialization.kt new file mode 100644 index 0000000000..96084a22e1 --- /dev/null +++ b/agents/agents-features/agents-features-a2a-core/src/commonMain/kotlin/ai/koog/agents/a2a/core/Serialization.kt @@ -0,0 +1,7 @@ +package ai.koog.agents.a2a.core + +import kotlinx.serialization.json.Json + +internal val A2AFeatureJson = Json { + prettyPrint = true +} diff --git a/agents/agents-features/agents-features-a2a-core/src/commonTest/kotlin/ai/koog/agents/a2a/core/MessageConvertersTest.kt b/agents/agents-features/agents-features-a2a-core/src/commonTest/kotlin/ai/koog/agents/a2a/core/MessageConvertersTest.kt new file mode 100644 index 0000000000..682325a881 --- /dev/null +++ b/agents/agents-features/agents-features-a2a-core/src/commonTest/kotlin/ai/koog/agents/a2a/core/MessageConvertersTest.kt @@ 
-0,0 +1,231 @@ +package ai.koog.agents.a2a.core + +import ai.koog.a2a.model.DataPart +import ai.koog.a2a.model.FilePart +import ai.koog.a2a.model.FileWithBytes +import ai.koog.a2a.model.FileWithUri +import ai.koog.a2a.model.Role +import ai.koog.a2a.model.TextPart +import ai.koog.prompt.message.Attachment +import ai.koog.prompt.message.AttachmentContent +import ai.koog.prompt.message.Message +import ai.koog.prompt.message.RequestMetaInfo +import ai.koog.prompt.message.ResponseMetaInfo +import kotlinx.datetime.Clock +import kotlinx.datetime.Instant +import kotlinx.serialization.json.Json +import kotlinx.serialization.json.JsonObject +import kotlinx.serialization.json.buildJsonObject +import kotlinx.serialization.json.encodeToJsonElement +import kotlinx.serialization.json.put +import kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.assertFailsWith + +class MessageConvertersTest { + private val fixedInstant: Instant = Instant.parse("2024-01-01T00:00:00Z") + private val fixedClock: Clock = object : Clock { + override fun now(): Instant = fixedInstant + } + + private val prettyJson = Json { prettyPrint = true } + + @Test + fun testA2AtoKoog_User_withTextDataAndFiles() { + val json = buildJsonObject { put("k", "v") } + val bytesBase64 = "YmFzZTY0" // arbitrary base64 string + + val a2a = A2AMessage( + messageId = "m1", + role = Role.User, + parts = listOf( + TextPart("Hello"), + DataPart(json), + FilePart(FileWithBytes(bytes = bytesBase64, name = "file.bin", mimeType = null)), + FilePart(FileWithUri(uri = "https://example.com/doc.txt", name = "doc.txt", mimeType = "text/plain")), + ), + contextId = "ctx-123", + taskId = "task-1", + referenceTaskIds = listOf("ref-1", "ref-2"), + extensions = listOf("ext:a"), + ) + + val actual: Message = a2a.toKoogMessage(clock = fixedClock) + + val expectedContent = "Hello\n" + prettyJson.encodeToString(json) + val expectedAttachments = listOf( + Attachment.File( + format = "", + mimeType = "application/octet-stream", 
+ fileName = "file.bin", + content = AttachmentContent.Binary.Base64(bytesBase64) + ), + Attachment.File( + format = "", + mimeType = "text/plain", + fileName = "doc.txt", + content = AttachmentContent.URL("https://example.com/doc.txt") + ) + ) + val expectedMetadata = JsonObject( + mapOf( + MESSAGE_A2A_METADATA_KEY to Json.encodeToJsonElement( + MessageA2AMetadata( + messageId = "m1", + contextId = "ctx-123", + taskId = "task-1", + referenceTaskIds = listOf("ref-1", "ref-2"), + metadata = null, + extensions = listOf("ext:a"), + ) + ) + ) + ) + val expected: Message = Message.User( + content = expectedContent, + metaInfo = RequestMetaInfo(timestamp = fixedInstant, metadata = expectedMetadata), + attachments = expectedAttachments + ) + + assertEquals(expected, actual) + } + + @Test + fun testA2AtoKoog_Agent() { + val a2a = A2AMessage( + messageId = "m2", + role = Role.Agent, + parts = listOf(TextPart("Agent says hi")), + ) + + val actual = a2a.toKoogMessage(clock = fixedClock) + + val expectedMetadata = JsonObject( + mapOf( + MESSAGE_A2A_METADATA_KEY to Json.encodeToJsonElement( + MessageA2AMetadata( + messageId = "m2", + contextId = null, + taskId = null, + referenceTaskIds = null, + metadata = null, + extensions = null, + ) + ) + ) + ) + val expected = Message.Assistant( + content = "Agent says hi", + metaInfo = ResponseMetaInfo(timestamp = fixedInstant, metadata = expectedMetadata), + attachments = emptyList() + ) + + assertEquals(expected, actual) + } + + @Test + fun testKoogToA2A_User_withPlainTextBinaryAndUrlAttachments() { + val plain = Attachment.File( + content = AttachmentContent.PlainText("abc"), + format = "txt", + mimeType = "text/plain", + fileName = "note.txt", + ) + val bytes = byteArrayOf(1, 2, 3) + val bin = Attachment.File( + content = AttachmentContent.Binary.Bytes(bytes), + format = "bin", + mimeType = "application/octet-stream", + fileName = "bytes.bin", + ) + val url = Attachment.File( + content = 
AttachmentContent.URL("https://example.com/a.png"), + format = "png", + mimeType = "image/png", + fileName = "a.png", + ) + + val koog: Message = Message.User( + content = "Hi", + metaInfo = RequestMetaInfo(timestamp = fixedInstant), + attachments = listOf(plain, bin, url) + ) + + val actual = koog.toA2AMessage( + a2aMetadata = MessageA2AMetadata( + messageId = "mid", + contextId = "ctx", + taskId = "task", + referenceTaskIds = listOf("r1"), + metadata = null, + extensions = null, + ) + ) + + val expectedPlainBase64 = AttachmentContent.Binary.Bytes("abc".encodeToByteArray()).asBase64() + val expectedBinBase64 = AttachmentContent.Binary.Bytes(bytes).asBase64() + val expected = A2AMessage( + messageId = "mid", + role = Role.User, + parts = listOf( + TextPart("Hi"), + FilePart(FileWithBytes(bytes = expectedPlainBase64, name = "note.txt", mimeType = "text/plain")), + FilePart( + FileWithBytes( + bytes = expectedBinBase64, + name = "bytes.bin", + mimeType = "application/octet-stream" + ) + ), + FilePart(FileWithUri(uri = "https://example.com/a.png", name = "a.png", mimeType = "image/png")), + ), + extensions = null, + taskId = "task", + referenceTaskIds = listOf("r1"), + contextId = "ctx", + metadata = null, + ) + + assertEquals(expected, actual) + } + + @Test + fun testKoogToA2A_Assistant() { + val koog: Message = Message.Assistant( + content = "Answer", + metaInfo = ResponseMetaInfo(timestamp = fixedInstant), + ) + val actual = koog.toA2AMessage( + a2aMetadata = MessageA2AMetadata( + messageId = "m3", + contextId = null, + taskId = null, + referenceTaskIds = null, + metadata = null, + extensions = null, + ) + ) + val expected = A2AMessage( + messageId = "m3", + role = Role.Agent, + parts = listOf(TextPart("Answer")), + extensions = null, + taskId = null, + referenceTaskIds = null, + contextId = null, + metadata = null, + ) + assertEquals(expected, actual) + } + + @Test + fun testKoogToA2A_unsupportedKoogMessageThrows() { + val sys: Message = Message.System( + content 
= "system", + metaInfo = RequestMetaInfo(timestamp = fixedInstant) + ) + assertFailsWith { + sys.toA2AMessage() + } + } +} diff --git a/agents/agents-features/agents-features-a2a-server/Module.md b/agents/agents-features/agents-features-a2a-server/Module.md new file mode 100644 index 0000000000..5e97a68301 --- /dev/null +++ b/agents/agents-features/agents-features-a2a-server/Module.md @@ -0,0 +1,17 @@ +# Module agents-features-a2a-server + +Agent feature for enabling A2A server capabilities within Koog agents. + +## Overview + +This module provides an agent feature that enables Koog agents to act as A2A servers, allowing them to receive and process requests from A2A clients. When installed, agents can access the A2A request context and event processor, enabling them to handle incoming messages, manage tasks, interact with storage, and send events back to clients. + +## Key Components + +- **`A2AAgentServer`**: Feature providing access to A2A request context and event processor +- **Agent Nodes**: Pre-built nodes for messaging, task events, and storage operations + +## Related Modules + +- `agents-features-a2a-client`: Client-side A2A agent feature +- `agents-features-a2a-core`: Core A2A utilities and converters diff --git a/agents/agents-features/agents-features-a2a-server/build.gradle.kts b/agents/agents-features/agents-features-a2a-server/build.gradle.kts new file mode 100644 index 0000000000..28fa55136e --- /dev/null +++ b/agents/agents-features/agents-features-a2a-server/build.gradle.kts @@ -0,0 +1,40 @@ +import ai.koog.gradle.publish.maven.Publishing.publishToMaven + +group = rootProject.group +version = rootProject.version + +plugins { + id("ai.kotlin.multiplatform") + alias(libs.plugins.kotlin.serialization) +} + +kotlin { + sourceSets { + commonMain { + dependencies { + api(project(":agents:agents-core")) + api(project(":a2a:a2a-server")) + api(project(":agents:agents-features:agents-features-a2a-core")) + + api(libs.kotlinx.serialization.json) + } + } + + 
commonTest { + dependencies { + implementation(kotlin("test")) + implementation(libs.kotlinx.coroutines.test) + } + } + + jvmTest { + dependencies { + implementation(kotlin("test-junit5")) + } + } + } + + explicitApi() +} + +publishToMaven() diff --git a/agents/agents-features/agents-features-a2a-server/src/commonMain/kotlin/ai/koog/agents/a2a/server/feature/A2AAgentServer.kt b/agents/agents-features/agents-features-a2a-server/src/commonMain/kotlin/ai/koog/agents/a2a/server/feature/A2AAgentServer.kt new file mode 100644 index 0000000000..ad1a757887 --- /dev/null +++ b/agents/agents-features/agents-features-a2a-server/src/commonMain/kotlin/ai/koog/agents/a2a/server/feature/A2AAgentServer.kt @@ -0,0 +1,112 @@ +package ai.koog.agents.a2a.server.feature + +import ai.koog.a2a.model.MessageSendParams +import ai.koog.a2a.server.session.RequestContext +import ai.koog.a2a.server.session.SessionEventProcessor +import ai.koog.agents.core.agent.context.AIAgentContext +import ai.koog.agents.core.agent.entity.AIAgentStorageKey +import ai.koog.agents.core.agent.entity.createStorageKey +import ai.koog.agents.core.feature.AIAgentGraphFeature +import ai.koog.agents.core.feature.AIAgentGraphPipeline +import ai.koog.agents.core.feature.AIAgentNonGraphFeature +import ai.koog.agents.core.feature.AIAgentNonGraphPipeline +import ai.koog.agents.core.feature.config.FeatureConfig +import kotlin.contracts.ExperimentalContracts +import kotlin.contracts.InvocationKind +import kotlin.contracts.contract + +/** + * Agent feature that enables A2A server mode by providing access to the request context and event processor + * from within agent strategies. + * + * This feature is designed to be used within agents that are hosted as A2A servers via [ai.koog.a2a.server.agent.AgentExecutor]. 
+ * When installed, it makes the [context] and [eventProcessor] available to agent nodes, allowing them to: + * - Access incoming A2A messages and task information + * - Send outgoing messages and task events back to the client + * - Interact with message and task storage + * - Manage the A2A session lifecycle + * + * The feature provides convenience nodes for common A2A operations like + * sending messages, updating task status, and accessing storage. + * + * @property context The A2A [RequestContext] from [ai.koog.a2a.server.agent.AgentExecutor.execute] + * @property eventProcessor The A2A [SessionEventProcessor] from [ai.koog.a2a.server.agent.AgentExecutor.execute] + * + * @see ai.koog.a2a.server.agent.AgentExecutor + * @see ai.koog.a2a.server.session.RequestContext + * @see ai.koog.a2a.server.session.SessionEventProcessor + */ +public class A2AAgentServer( + public val context: RequestContext, + public val eventProcessor: SessionEventProcessor +) { + /** + * Configuration for the [A2AAgentServer] feature. 
+ */ + public class Config : FeatureConfig() { + /** + * The A2A [RequestContext] from [ai.koog.a2a.server.agent.AgentExecutor.execute] + * @see RequestContext + */ + public lateinit var context: RequestContext + + /** + * The A2A [SessionEventProcessor] from [ai.koog.a2a.server.agent.AgentExecutor.execute] + * @see SessionEventProcessor + */ + public lateinit var eventProcessor: SessionEventProcessor + } + + public companion object Feature : + AIAgentGraphFeature, + AIAgentNonGraphFeature { + + override val key: AIAgentStorageKey = + createStorageKey("agents-features-a2a-server") + + override fun createInitialConfig(): Config = Config() + + override fun install( + config: Config, + pipeline: AIAgentGraphPipeline + ) { + pipeline.interceptContextAgentFeature(this) { _ -> + A2AAgentServer(config.context, config.eventProcessor) + } + } + + override fun install( + config: Config, + pipeline: AIAgentNonGraphPipeline + ) { + pipeline.interceptContextAgentFeature(this) { + A2AAgentServer(config.context, config.eventProcessor) + } + } + } +} + +/** + * Retrieves the [A2AAgentServer] feature from the agent context. + * + * @return The installed A2AAgentExecutor feature + * @throws IllegalStateException if the feature is not installed + */ +public fun AIAgentContext.a2aAgentServer(): A2AAgentServer = featureOrThrow(A2AAgentServer.Feature) + +/** + * Executes an action with the [A2AAgentServer] feature as the receiver. + * This is a convenience function that retrieves the feature and provides it as the receiver for the action block. 
+ * + * @param action The action to execute with A2AAgentExecutor as receiver + * @return The result of the action + * @throws IllegalStateException if the feature is not installed + */ +@OptIn(ExperimentalContracts::class) +public inline fun AIAgentContext.withA2AAgentServer(action: A2AAgentServer.() -> T): T { + contract { + callsInPlace(action, InvocationKind.AT_MOST_ONCE) + } + + return a2aAgentServer().action() +} diff --git a/agents/agents-features/agents-features-a2a-server/src/commonMain/kotlin/ai/koog/agents/a2a/server/feature/A2AAgentServerNodes.kt b/agents/agents-features/agents-features-a2a-server/src/commonMain/kotlin/ai/koog/agents/a2a/server/feature/A2AAgentServerNodes.kt new file mode 100644 index 0000000000..4dff75c52c --- /dev/null +++ b/agents/agents-features/agents-features-a2a-server/src/commonMain/kotlin/ai/koog/agents/a2a/server/feature/A2AAgentServerNodes.kt @@ -0,0 +1,176 @@ +package ai.koog.agents.a2a.server.feature + +import ai.koog.a2a.model.Task +import ai.koog.a2a.model.TaskEvent +import ai.koog.agents.a2a.core.A2AMessage +import ai.koog.agents.core.dsl.builder.AIAgentBuilderDslMarker +import ai.koog.agents.core.dsl.builder.AIAgentNodeDelegate +import ai.koog.agents.core.dsl.builder.AIAgentSubgraphBuilderBase +import kotlinx.serialization.Serializable + +/** + * Creates a node that sends an A2A message back to the client. + * + * @param name Optional node name for debugging and tracing + * @param saveToStorage If true, also saves the message to storage before sending + * @return A node that sends the message and passes it through unchanged + * @see ai.koog.a2a.server.session.SessionEventProcessor.sendMessage + * @see ai.koog.a2a.server.messages.MessageStorage.save + */ +@AIAgentBuilderDslMarker +public fun AIAgentSubgraphBuilderBase<*, *>.nodeA2ARespondMessage( + name: String? 
= null, + saveToStorage: Boolean = false, +): AIAgentNodeDelegate = + node(name) { message -> + withA2AAgentServer { + eventProcessor.sendMessage(message) + if (saveToStorage) { + context.messageStorage.save(message) + } + } + + message + } + +/** + * Creates a node that sends a task event (status update, creation, etc.) to the client. + * + * @param name Optional node name for debugging and tracing + * @return A node that sends the task event and passes it through unchanged + * @see ai.koog.a2a.server.session.SessionEventProcessor.sendTaskEvent + */ +@AIAgentBuilderDslMarker +public fun AIAgentSubgraphBuilderBase<*, *>.nodeA2ARespondTaskEvent( + name: String? = null, +): AIAgentNodeDelegate = + node(name) { event -> + withA2AAgentServer { + eventProcessor.sendTaskEvent(event) + } + + event + } + +/** + * Creates a node that saves a message to storage without sending it to the client. + * + * @param name Optional node name for debugging and tracing + * @return A node that saves the message and passes it through unchanged + * @see ai.koog.a2a.server.messages.MessageStorage.save + */ +@AIAgentBuilderDslMarker +public fun AIAgentSubgraphBuilderBase<*, *>.nodeA2AMessageStorageSave( + name: String? = null, +): AIAgentNodeDelegate = + node(name) { event -> + withA2AAgentServer { + context.messageStorage.save(event) + } + + event + } + +/** + * Creates a node that loads all messages from storage for the current context. + * + * @param name Optional node name for debugging and tracing + * @return A node that returns the list of all stored messages + * @see ai.koog.a2a.server.messages.MessageStorage.getByContext + */ +@AIAgentBuilderDslMarker +public fun AIAgentSubgraphBuilderBase<*, *>.nodeA2AMessageStorageLoad( + name: String? = null, +): AIAgentNodeDelegate> = + node(name) { + withA2AAgentServer { + context.messageStorage.getAll() + } + } + +/** + * Creates a node that replaces all messages in storage for the current context. 
+ * + * @param name Optional node name for debugging and tracing + * @return A node that replaces all stored messages with the provided list + * @see ai.koog.a2a.server.messages.MessageStorage.replaceByContext + */ +@AIAgentBuilderDslMarker +public fun AIAgentSubgraphBuilderBase<*, *>.nodeA2AMessageStorageReplace( + name: String? = null, +): AIAgentNodeDelegate, Unit> = + node(name) { + withA2AAgentServer { + context.messageStorage.replaceAll(it) + } + } + +/** + * Parameters for retrieving a single task from storage. + * + * @property taskId The unique task identifier + * @property historyLength Maximum number of messages to include in conversation history. Set to `null` for all messages + * @property includeArtifacts Whether to include task artifacts in the response + */ +@Serializable +public data class A2ATaskGetParams( + val taskId: String, + val historyLength: Int? = 0, + val includeArtifacts: Boolean = false, +) + +/** + * Creates a node that retrieves a single task by ID from storage. + * + * @param name Optional node name for debugging and tracing + * @return A node that returns the task or null if not found + * @see ai.koog.a2a.server.tasks.TaskStorage.get + */ +@AIAgentBuilderDslMarker +public fun AIAgentSubgraphBuilderBase<*, *>.nodeA2ATaskGet( + name: String? = null, +): AIAgentNodeDelegate = + node(name) { params -> + withA2AAgentServer { + context.taskStorage.get( + taskId = params.taskId, + historyLength = params.historyLength, + includeArtifacts = params.includeArtifacts + ) + } + } + +/** + * Parameters for retrieving multiple tasks from storage. + * + * @property taskIds List of task identifiers to retrieve + * @property historyLength Maximum number of messages to include in conversation history. Set to `null` for all messages + * @property includeArtifacts Whether to include task artifacts in the response + */ +@Serializable +public data class A2ATaskGetAllParams( + val taskIds: List, + val historyLength: Int? 
= 0, + val includeArtifacts: Boolean = false, +) + +/** + * Creates a node that retrieves multiple tasks by their IDs from storage. + * + * @param name Optional node name for debugging and tracing + * @return A node that returns the list of found tasks (may be fewer than requested) + * @see ai.koog.a2a.server.tasks.TaskStorage.getAll + */ +@AIAgentBuilderDslMarker +public fun AIAgentSubgraphBuilderBase<*, *>.nodeA2ATaskGetAll( + name: String? = null, +): AIAgentNodeDelegate> = + node(name) { params -> + withA2AAgentServer { + context.taskStorage.getAll( + taskIds = params.taskIds, + historyLength = params.historyLength, + includeArtifacts = params.includeArtifacts + ) + } + } diff --git a/agents/agents-features/agents-features-debugger/src/commonMain/kotlin/ai/koog/agents/features/debugger/feature/Debugger.kt b/agents/agents-features/agents-features-debugger/src/commonMain/kotlin/ai/koog/agents/features/debugger/feature/Debugger.kt index 05219a8a11..3c56770d19 100644 --- a/agents/agents-features/agents-features-debugger/src/commonMain/kotlin/ai/koog/agents/features/debugger/feature/Debugger.kt +++ b/agents/agents-features/agents-features-debugger/src/commonMain/kotlin/ai/koog/agents/features/debugger/feature/Debugger.kt @@ -13,12 +13,16 @@ import ai.koog.agents.core.feature.model.events.AgentStartingEvent import ai.koog.agents.core.feature.model.events.GraphStrategyStartingEvent import ai.koog.agents.core.feature.model.events.LLMCallCompletedEvent import ai.koog.agents.core.feature.model.events.LLMCallStartingEvent +import ai.koog.agents.core.feature.model.events.LLMStreamingCompletedEvent +import ai.koog.agents.core.feature.model.events.LLMStreamingFailedEvent +import ai.koog.agents.core.feature.model.events.LLMStreamingFrameReceivedEvent +import ai.koog.agents.core.feature.model.events.LLMStreamingStartingEvent import ai.koog.agents.core.feature.model.events.NodeExecutionCompletedEvent import ai.koog.agents.core.feature.model.events.NodeExecutionStartingEvent 
import ai.koog.agents.core.feature.model.events.StrategyCompletedEvent -import ai.koog.agents.core.feature.model.events.ToolExecutionCompletedEvent -import ai.koog.agents.core.feature.model.events.ToolExecutionFailedEvent -import ai.koog.agents.core.feature.model.events.ToolExecutionStartingEvent +import ai.koog.agents.core.feature.model.events.ToolCallCompletedEvent +import ai.koog.agents.core.feature.model.events.ToolCallFailedEvent +import ai.koog.agents.core.feature.model.events.ToolCallStartingEvent import ai.koog.agents.core.feature.model.events.ToolValidationFailedEvent import ai.koog.agents.core.feature.model.events.startNodeToGraph import ai.koog.agents.core.feature.model.toAgentError @@ -207,13 +211,57 @@ public class Debugger { //endregion Intercept LLM Call Events + //region Intercept LLM Streaming Events + + pipeline.interceptLLMStreamingStarting(interceptContext) intercept@{ eventContext -> + val event = LLMStreamingStartingEvent( + runId = eventContext.runId, + prompt = eventContext.prompt, + model = eventContext.model.eventString, + tools = eventContext.tools.map { it.name }, + timestamp = pipeline.clock.now().toEpochMilliseconds() + ) + writer.onMessage(event) + } + + pipeline.interceptLLMStreamingFrameReceived(interceptContext) intercept@{ eventContext -> + val event = LLMStreamingFrameReceivedEvent( + runId = eventContext.runId, + frame = eventContext.streamFrame, + timestamp = pipeline.clock.now().toEpochMilliseconds() + ) + writer.onMessage(event) + } + + pipeline.interceptLLMStreamingFailed(interceptContext) intercept@{ eventContext -> + val event = LLMStreamingFailedEvent( + runId = eventContext.runId, + error = eventContext.error.toAgentError(), + timestamp = pipeline.clock.now().toEpochMilliseconds() + ) + writer.onMessage(event) + } + + pipeline.interceptLLMStreamingCompleted(interceptContext) intercept@{ eventContext -> + val event = LLMStreamingCompletedEvent( + runId = eventContext.runId, + prompt = eventContext.prompt, + model = 
eventContext.model.eventString, + tools = eventContext.tools.map { it.name }, + timestamp = pipeline.clock.now().toEpochMilliseconds() + ) + writer.onMessage(event) + } + + //endregion Intercept LLM Streaming Events + //region Intercept Tool Call Events - pipeline.interceptToolExecutionStarting(interceptContext) intercept@{ eventContext -> + pipeline.interceptToolCallStarting(interceptContext) intercept@{ eventContext -> @Suppress("UNCHECKED_CAST") val tool = eventContext.tool as Tool - val event = ToolExecutionStartingEvent( + val event = ToolCallStartingEvent( runId = eventContext.runId, toolCallId = eventContext.toolCallId, toolName = eventContext.tool.name, @@ -238,11 +286,11 @@ public class Debugger { writer.onMessage(event) } - pipeline.interceptToolExecutionFailed(interceptContext) intercept@{ eventContext -> + pipeline.interceptToolCallFailed(interceptContext) intercept@{ eventContext -> @Suppress("UNCHECKED_CAST") val tool = eventContext.tool as Tool - val event = ToolExecutionFailedEvent( + val event = ToolCallFailedEvent( runId = eventContext.runId, toolCallId = eventContext.toolCallId, toolName = tool.name, @@ -253,11 +301,11 @@ public class Debugger { writer.onMessage(event) } - pipeline.interceptToolExecutionCompleted(interceptContext) intercept@{ eventContext -> + pipeline.interceptToolCallCompleted(interceptContext) intercept@{ eventContext -> @Suppress("UNCHECKED_CAST") val tool = eventContext.tool as Tool - val event = ToolExecutionCompletedEvent( + val event = ToolCallCompletedEvent( runId = eventContext.runId, toolCallId = eventContext.toolCallId, toolName = eventContext.tool.name, diff --git a/agents/agents-features/agents-features-debugger/src/jvmTest/kotlin/ai/koog/agents/features/debugger/feature/DebuggerTest.kt b/agents/agents-features/agents-features-debugger/src/jvmTest/kotlin/ai/koog/agents/features/debugger/feature/DebuggerTest.kt index 4aea8eaa30..9c2dd48d75 100644 --- 
a/agents/agents-features/agents-features-debugger/src/jvmTest/kotlin/ai/koog/agents/features/debugger/feature/DebuggerTest.kt +++ b/agents/agents-features/agents-features-debugger/src/jvmTest/kotlin/ai/koog/agents/features/debugger/feature/DebuggerTest.kt @@ -4,22 +4,28 @@ import ai.koog.agents.core.dsl.builder.forwardTo import ai.koog.agents.core.dsl.builder.strategy import ai.koog.agents.core.dsl.extension.nodeExecuteTool import ai.koog.agents.core.dsl.extension.nodeLLMRequest +import ai.koog.agents.core.dsl.extension.nodeLLMRequestStreamingAndSendResults import ai.koog.agents.core.dsl.extension.nodeLLMSendToolResult import ai.koog.agents.core.dsl.extension.onAssistantMessage import ai.koog.agents.core.dsl.extension.onToolCall +import ai.koog.agents.core.feature.model.AIAgentError import ai.koog.agents.core.feature.model.events.AgentCompletedEvent import ai.koog.agents.core.feature.model.events.AgentStartingEvent import ai.koog.agents.core.feature.model.events.GraphStrategyStartingEvent import ai.koog.agents.core.feature.model.events.LLMCallCompletedEvent import ai.koog.agents.core.feature.model.events.LLMCallStartingEvent +import ai.koog.agents.core.feature.model.events.LLMStreamingCompletedEvent +import ai.koog.agents.core.feature.model.events.LLMStreamingFailedEvent +import ai.koog.agents.core.feature.model.events.LLMStreamingFrameReceivedEvent +import ai.koog.agents.core.feature.model.events.LLMStreamingStartingEvent import ai.koog.agents.core.feature.model.events.NodeExecutionCompletedEvent import ai.koog.agents.core.feature.model.events.NodeExecutionStartingEvent import ai.koog.agents.core.feature.model.events.StrategyCompletedEvent import ai.koog.agents.core.feature.model.events.StrategyEventGraph import ai.koog.agents.core.feature.model.events.StrategyEventGraphEdge import ai.koog.agents.core.feature.model.events.StrategyEventGraphNode -import ai.koog.agents.core.feature.model.events.ToolExecutionCompletedEvent -import 
ai.koog.agents.core.feature.model.events.ToolExecutionStartingEvent +import ai.koog.agents.core.feature.model.events.ToolCallCompletedEvent +import ai.koog.agents.core.feature.model.events.ToolCallStartingEvent import ai.koog.agents.core.feature.remote.client.FeatureMessageRemoteClient import ai.koog.agents.core.feature.remote.client.config.DefaultClientConnectionConfig import ai.koog.agents.core.feature.remote.server.config.DefaultServerConnectionConfig @@ -43,10 +49,15 @@ import ai.koog.agents.testing.tools.getMockExecutor import ai.koog.agents.testing.tools.mockLLMAnswer import ai.koog.agents.utils.use import ai.koog.prompt.dsl.Prompt +import ai.koog.prompt.executor.model.PromptExecutor import ai.koog.prompt.llm.LLModel +import ai.koog.prompt.message.Message +import ai.koog.prompt.streaming.StreamFrame import io.ktor.http.URLProtocol import kotlinx.coroutines.CompletableDeferred +import kotlinx.coroutines.flow.Flow import kotlinx.coroutines.flow.first +import kotlinx.coroutines.flow.flow import kotlinx.coroutines.joinAll import kotlinx.coroutines.launch import kotlinx.coroutines.runBlocking @@ -55,6 +66,7 @@ import org.junit.jupiter.api.Disabled import kotlin.test.Test import kotlin.test.assertContentEquals import kotlin.test.assertEquals +import kotlin.test.assertFails import kotlin.test.assertNotNull import kotlin.test.assertNull import kotlin.test.assertTrue @@ -288,14 +300,14 @@ class DebuggerTest { input = toolCallMessage(dummyTool.name, content = """{"dummy":"test"}""").toString(), timestamp = testClock.now().toEpochMilliseconds() ), - ToolExecutionStartingEvent( + ToolCallStartingEvent( runId = clientEventsCollector.runId, toolCallId = "0", toolName = dummyTool.name, toolArgs = dummyTool.encodeArgs(DummyTool.Args("test")), timestamp = testClock.now().toEpochMilliseconds() ), - ToolExecutionCompletedEvent( + ToolCallCompletedEvent( runId = clientEventsCollector.runId, toolCallId = "0", toolName = dummyTool.name, @@ -382,6 +394,334 @@ class DebuggerTest { 
assertNotNull(isFinishedOrNull, "Client or server did not finish in time") } + @Test + fun `test feature message remote writer collect streaming success events on agent run`() = runBlocking { + // Agent Config + val agentId = "test-agent-id" + + val userPrompt = "Call the dummy tool with argument: test" + val systemPrompt = "Test system prompt" + val assistantPrompt = "Test assistant prompt" + val promptId = "Test prompt id" + + // Tools + val dummyTool = DummyTool() + + val toolRegistry = ToolRegistry { + tool(dummyTool) + } + + // Model + val testModel = LLModel( + provider = MockLLMProvider(), + id = "test-llm-id", + capabilities = emptyList(), + contextLength = 1_000, + ) + + // Prompt + val expectedPrompt = Prompt( + messages = listOf( + systemMessage(systemPrompt), + userMessage(userPrompt), + assistantMessage(assistantPrompt) + ), + id = promptId + ) + + val expectedLLMCallPrompt = expectedPrompt.copy( + messages = expectedPrompt.messages + ) + + // Executor + val testLLMResponse = "Default test response" + + val mockExecutor = getMockExecutor { + mockLLMAnswer(testLLMResponse).asDefaultResponse onUserRequestEquals userPrompt + } + + // Test Data + val port = findAvailablePort() + val clientConfig = DefaultClientConnectionConfig(host = HOST, port = port, protocol = URLProtocol.HTTP) + + val isClientFinished = CompletableDeferred() + val isServerStarted = CompletableDeferred() + + // Server + val serverJob = launch { + val strategy = strategy("tracing-streaming-success") { + val streamAndCollect by nodeLLMRequestStreamingAndSendResults("stream-and-collect") + + edge(nodeStart forwardTo streamAndCollect) + edge(streamAndCollect forwardTo nodeFinish transformed { messages -> messages.firstOrNull()?.content ?: "" }) + } + + createAgent( + agentId = agentId, + strategy = strategy, + promptId = promptId, + userPrompt = userPrompt, + systemPrompt = systemPrompt, + assistantPrompt = assistantPrompt, + toolRegistry = toolRegistry, + promptExecutor = mockExecutor, + 
model = testModel, + ) { + install(Debugger) { + setPort(port) + + launch { + val messageProcessor = messageProcessors.single() as FeatureMessageRemoteWriter + val isServerStartedCheck = withTimeoutOrNull(defaultClientServerTimeout) { + messageProcessor.isOpen.first { it } + } != null + + assertTrue(isServerStartedCheck, "Server did not start in time") + isServerStarted.complete(true) + } + } + }.use { agent -> + agent.run(userPrompt) + isClientFinished.await() + } + } + + // Client + val clientJob = launch { + FeatureMessageRemoteClient(connectionConfig = clientConfig, scope = this).use { client -> + + val clientEventsCollector = + ClientEventsCollector(client = client, expectedEventsCount = 13) + + val collectEventsJob = + clientEventsCollector.startCollectEvents(coroutineScope = this@launch) + + isServerStarted.await() + client.connect() + collectEventsJob.join() + + // Correct run id will be set after the 'collect events job' is finished. + val expectedEvents = listOf( + LLMStreamingStartingEvent( + runId = clientEventsCollector.runId, + prompt = expectedLLMCallPrompt, + model = testModel.eventString, + tools = listOf(dummyTool.name), + timestamp = testClock.now().toEpochMilliseconds(), + ), + LLMStreamingFrameReceivedEvent( + runId = clientEventsCollector.runId, + frame = StreamFrame.Append(testLLMResponse), + timestamp = testClock.now().toEpochMilliseconds(), + ), + LLMStreamingCompletedEvent( + runId = clientEventsCollector.runId, + prompt = expectedLLMCallPrompt, + model = testModel.eventString, + tools = listOf(dummyTool.name), + timestamp = testClock.now().toEpochMilliseconds(), + ) + ) + + val actualEvents = clientEventsCollector.collectedEvents.filter { event -> + event is LLMStreamingStartingEvent || + event is LLMStreamingFrameReceivedEvent || + event is LLMStreamingFailedEvent || + event is LLMStreamingCompletedEvent + } + + assertEquals(expectedEvents.size, actualEvents.size) + assertContentEquals(expectedEvents, actualEvents) + + 
isClientFinished.complete(true) + } + } + + val isFinishedOrNull = withTimeoutOrNull(defaultClientServerTimeout) { + listOf(clientJob, serverJob).joinAll() + } + + assertNotNull(isFinishedOrNull, "Client or server did not finish in time") + } + + @Test + fun `test feature message remote writer collect streaming failed events on agent run`() = runBlocking { + // Agent Config + val agentId = "test-agent-id" + + val userPrompt = "Call the dummy tool with argument: test" + val systemPrompt = "Test system prompt" + val assistantPrompt = "Test assistant prompt" + val promptId = "Test prompt id" + + // Tools + val dummyTool = DummyTool() + + val toolRegistry = ToolRegistry { + tool(dummyTool) + } + + // Model + val testModel = LLModel( + provider = MockLLMProvider(), + id = "test-llm-id", + capabilities = emptyList(), + contextLength = 1_000, + ) + + // Prompt + val expectedPrompt = Prompt( + messages = listOf( + systemMessage(systemPrompt), + userMessage(userPrompt), + assistantMessage(assistantPrompt) + ), + id = promptId + ) + + val expectedLLMCallPrompt = expectedPrompt.copy( + messages = expectedPrompt.messages + ) + + // Executor + val testStreamingErrorMessage = "Test streaming error" + var testStreamingStackTrace = "" + + val testStreamingExecutor = object : PromptExecutor { + override suspend fun execute( + prompt: Prompt, + model: LLModel, + tools: List + ): List = emptyList() + + override fun executeStreaming( + prompt: Prompt, + model: LLModel, + tools: List + ): Flow = flow { + val testException = IllegalStateException(testStreamingErrorMessage) + testStreamingStackTrace = testException.stackTraceToString() + throw testException + } + + override suspend fun moderate( + prompt: Prompt, + model: LLModel + ): ai.koog.prompt.dsl.ModerationResult { + throw UnsupportedOperationException("Not used in test") + } + } + + // Test Data + val port = findAvailablePort() + val clientConfig = DefaultClientConnectionConfig(host = HOST, port = port, protocol = 
URLProtocol.HTTP) + + val isClientFinished = CompletableDeferred() + val isServerStarted = CompletableDeferred() + + // Server + val serverJob = launch { + val strategy = strategy("tracing-streaming-success") { + val streamAndCollect by nodeLLMRequestStreamingAndSendResults("stream-and-collect") + + edge(nodeStart forwardTo streamAndCollect) + edge(streamAndCollect forwardTo nodeFinish transformed { messages -> messages.firstOrNull()?.content ?: "" }) + } + + createAgent( + agentId = agentId, + strategy = strategy, + promptId = promptId, + userPrompt = userPrompt, + systemPrompt = systemPrompt, + assistantPrompt = assistantPrompt, + toolRegistry = toolRegistry, + promptExecutor = testStreamingExecutor, + model = testModel, + ) { + install(Debugger) { + setPort(port) + + launch { + val messageProcessor = messageProcessors.single() as FeatureMessageRemoteWriter + val isServerStartedCheck = withTimeoutOrNull(defaultClientServerTimeout) { + messageProcessor.isOpen.first { it } + } != null + + assertTrue(isServerStartedCheck, "Server did not start in time") + isServerStarted.complete(true) + } + } + }.use { agent -> + val throwable = assertFails { + agent.run(userPrompt) + } + + isClientFinished.await() + + assertTrue(throwable is IllegalStateException) + assertEquals(testStreamingErrorMessage, throwable.message) + } + } + + // Client + val clientJob = launch { + FeatureMessageRemoteClient(connectionConfig = clientConfig, scope = this).use { client -> + + val clientEventsCollector = + ClientEventsCollector(client = client, expectedEventsCount = 9) + + val collectEventsJob = + clientEventsCollector.startCollectEvents(coroutineScope = this@launch) + + isServerStarted.await() + client.connect() + collectEventsJob.join() + + // Correct run id will be set after the 'collect events job' is finished. 
+ val expectedEvents = listOf( + LLMStreamingStartingEvent( + runId = clientEventsCollector.runId, + prompt = expectedLLMCallPrompt, + model = testModel.eventString, + tools = listOf(dummyTool.name), + timestamp = testClock.now().toEpochMilliseconds(), + ), + LLMStreamingFailedEvent( + runId = clientEventsCollector.runId, + error = AIAgentError(testStreamingErrorMessage, testStreamingStackTrace), + timestamp = testClock.now().toEpochMilliseconds() + ), + LLMStreamingCompletedEvent( + runId = clientEventsCollector.runId, + prompt = expectedLLMCallPrompt, + model = testModel.eventString, + tools = listOf(dummyTool.name), + timestamp = testClock.now().toEpochMilliseconds(), + ) + ) + + val actualEvents = clientEventsCollector.collectedEvents.filter { event -> + event is LLMStreamingStartingEvent || + event is LLMStreamingFrameReceivedEvent || + event is LLMStreamingFailedEvent || + event is LLMStreamingCompletedEvent + } + + assertEquals(expectedEvents.size, actualEvents.size) + assertContentEquals(expectedEvents, actualEvents) + + isClientFinished.complete(true) + } + } + + val isFinishedOrNull = withTimeoutOrNull(defaultClientServerTimeout) { + listOf(clientJob, serverJob).joinAll() + } + + assertNotNull(isFinishedOrNull, "Client or server did not finish in time") + } + @Test @Disabled( """ diff --git a/agents/agents-features/agents-features-event-handler/Module.md b/agents/agents-features/agents-features-event-handler/Module.md index f8f5389fec..3b7561e9b9 100644 --- a/agents/agents-features/agents-features-event-handler/Module.md +++ b/agents/agents-features/agents-features-event-handler/Module.md @@ -56,7 +56,7 @@ val testAgent = AIAgent( var agentFinished = false handleEvents { - onToolExecutionStarting { eventContext -> + onToolCallStarting { eventContext -> toolCalled = true println("[DEBUG_LOG] Tool called: ${eventContext.tool.name}") } @@ -95,15 +95,15 @@ val agent = AIAgent( } // Monitor tool usage - onToolExecutionStarting { eventContext -> + 
onToolCallStarting { eventContext -> println("Tool called: ${eventContext.tool.name} with args: ${eventContext.toolArgs}") } - onToolExecutionCompleted { eventContext -> + onToolCallCompleted { eventContext -> println("Tool result: ${eventContext.result}") } - onToolExecutionFailed { eventContext -> + onToolCallFailed { eventContext -> println("Tool failed: ${eventContext.throwable.message}") } diff --git a/agents/agents-features/agents-features-event-handler/src/commonMain/kotlin/ai/koog/agents/features/eventHandler/feature/EventHandler.kt b/agents/agents-features/agents-features-event-handler/src/commonMain/kotlin/ai/koog/agents/features/eventHandler/feature/EventHandler.kt index 2706f13159..c25cb528d9 100644 --- a/agents/agents-features/agents-features-event-handler/src/commonMain/kotlin/ai/koog/agents/features/eventHandler/feature/EventHandler.kt +++ b/agents/agents-features/agents-features-event-handler/src/commonMain/kotlin/ai/koog/agents/features/eventHandler/feature/EventHandler.kt @@ -17,9 +17,9 @@ import ai.koog.agents.core.feature.handler.node.NodeExecutionStartingContext import ai.koog.agents.core.feature.handler.streaming.LLMStreamingCompletedContext import ai.koog.agents.core.feature.handler.streaming.LLMStreamingFrameReceivedContext import ai.koog.agents.core.feature.handler.streaming.LLMStreamingStartingContext -import ai.koog.agents.core.feature.handler.tool.ToolExecutionCompletedContext -import ai.koog.agents.core.feature.handler.tool.ToolExecutionFailedContext -import ai.koog.agents.core.feature.handler.tool.ToolExecutionStartingContext +import ai.koog.agents.core.feature.handler.tool.ToolCallCompletedContext +import ai.koog.agents.core.feature.handler.tool.ToolCallFailedContext +import ai.koog.agents.core.feature.handler.tool.ToolCallStartingContext import ai.koog.agents.core.feature.handler.tool.ToolValidationFailedContext import io.github.oshai.kotlinlogging.KotlinLogging @@ -33,7 +33,7 @@ import io.github.oshai.kotlinlogging.KotlinLogging * 
Example usage: * ``` * handleEvents { - * onToolExecutionStarting { eventContext -> + * onToolCallStarting { eventContext -> * println("Tool called: ${eventContext.tool.name} with args ${eventContext.toolArgs}") * } * @@ -58,7 +58,7 @@ public class EventHandler { * Example usage: * ``` * handleEvents { - * onToolExecutionStarting { eventContext -> + * onToolCallStarting { eventContext -> * println("Tool called: ${eventContext.tool.name} with args: ${eventContext.toolArgs}") * } * @@ -145,8 +145,8 @@ public class EventHandler { config.invokeOnLLMCallCompleted(eventContext) } - pipeline.interceptToolExecutionStarting(interceptContext) intercept@{ eventContext: ToolExecutionStartingContext -> - config.invokeOnToolExecutionStarting(eventContext) + pipeline.interceptToolCallStarting(interceptContext) intercept@{ eventContext: ToolCallStartingContext -> + config.invokeOnToolCallStarting(eventContext) } pipeline.interceptToolValidationFailed( @@ -155,16 +155,16 @@ public class EventHandler { config.invokeOnToolValidationFailed(eventContext) } - pipeline.interceptToolExecutionFailed(interceptContext) intercept@{ eventContext: ToolExecutionFailedContext -> - config.invokeOnToolExecutionFailed(eventContext) + pipeline.interceptToolCallFailed(interceptContext) intercept@{ eventContext: ToolCallFailedContext -> + config.invokeOnToolCallFailed(eventContext) } - pipeline.interceptToolExecutionCompleted(interceptContext) intercept@{ eventContext: ToolExecutionCompletedContext -> - config.invokeOnToolExecutionCompleted(eventContext) + pipeline.interceptToolCallCompleted(interceptContext) intercept@{ eventContext: ToolCallCompletedContext -> + config.invokeOnToolCallCompleted(eventContext) } pipeline.interceptLLMStreamingStarting(interceptContext) intercept@{ eventContext: LLMStreamingStartingContext -> - config.invokeOnLLMStreammingStarting(eventContext) + config.invokeOnLLMStreamingStarting(eventContext) } pipeline.interceptLLMStreamingFrameReceived(interceptContext) intercept@{ 
eventContext: LLMStreamingFrameReceivedContext -> @@ -205,7 +205,7 @@ public class EventHandler { * ``` * handleEvents { * // Log when tools are called - * onToolExecutionStarting { eventContext -> + * onToolCallStarting { eventContext -> * println("Tool called: ${eventContext.tool.name} with args: ${eventContext.toolArgs}") * } * diff --git a/agents/agents-features/agents-features-event-handler/src/commonMain/kotlin/ai/koog/agents/features/eventHandler/feature/EventHandlerConfig.kt b/agents/agents-features/agents-features-event-handler/src/commonMain/kotlin/ai/koog/agents/features/eventHandler/feature/EventHandlerConfig.kt index 1102b72b82..d6274ea509 100644 --- a/agents/agents-features/agents-features-event-handler/src/commonMain/kotlin/ai/koog/agents/features/eventHandler/feature/EventHandlerConfig.kt +++ b/agents/agents-features/agents-features-event-handler/src/commonMain/kotlin/ai/koog/agents/features/eventHandler/feature/EventHandlerConfig.kt @@ -31,9 +31,9 @@ import ai.koog.agents.core.feature.handler.streaming.LLMStreamingCompletedContex import ai.koog.agents.core.feature.handler.streaming.LLMStreamingFailedContext import ai.koog.agents.core.feature.handler.streaming.LLMStreamingFrameReceivedContext import ai.koog.agents.core.feature.handler.streaming.LLMStreamingStartingContext -import ai.koog.agents.core.feature.handler.tool.ToolExecutionCompletedContext -import ai.koog.agents.core.feature.handler.tool.ToolExecutionFailedContext -import ai.koog.agents.core.feature.handler.tool.ToolExecutionStartingContext +import ai.koog.agents.core.feature.handler.tool.ToolCallCompletedContext +import ai.koog.agents.core.feature.handler.tool.ToolCallFailedContext +import ai.koog.agents.core.feature.handler.tool.ToolCallStartingContext import ai.koog.agents.core.feature.handler.tool.ToolValidationFailedContext /** @@ -49,7 +49,7 @@ import ai.koog.agents.core.feature.handler.tool.ToolValidationFailedContext * Example usage: * ``` * handleEvents { - * 
onToolExecutionStarting { eventContext -> + * onToolCallStarting { eventContext -> * println("Tool called: ${eventContext.tool.name} with args ${eventContext.toolArgs}") * } * @@ -101,13 +101,13 @@ public class EventHandlerConfig : FeatureConfig() { //region Private Tool Call Handlers - private var _onToolExecutionStarting: suspend (eventHandler: ToolExecutionStartingContext) -> Unit = { _ -> } + private var _onToolCallStarting: suspend (eventHandler: ToolCallStartingContext) -> Unit = { _ -> } private var _onToolValidationFailed: suspend (eventHandler: ToolValidationFailedContext) -> Unit = { _ -> } - private var _onToolExecutionFailed: suspend (eventHandler: ToolExecutionFailedContext) -> Unit = { _ -> } + private var _onToolCallFailed: suspend (eventHandler: ToolCallFailedContext) -> Unit = { _ -> } - private var _onToolExecutionCompleted: suspend (eventHandler: ToolExecutionCompletedContext) -> Unit = { _ -> } + private var _onToolCallCompleted: suspend (eventHandler: ToolCallCompletedContext) -> Unit = { _ -> } //endregion Private Tool Call Handlers @@ -266,9 +266,9 @@ public class EventHandlerConfig : FeatureConfig() { /** * Append handler called when a tool is about to be called. */ - public fun onToolExecutionStarting(handler: suspend (eventContext: ToolExecutionStartingContext) -> Unit) { - val originalHandler = this._onToolExecutionStarting - this._onToolExecutionStarting = { eventContext -> + public fun onToolCallStarting(handler: suspend (eventContext: ToolCallStartingContext) -> Unit) { + val originalHandler = this._onToolCallStarting + this._onToolCallStarting = { eventContext -> originalHandler(eventContext) handler.invoke(eventContext) } @@ -288,9 +288,9 @@ public class EventHandlerConfig : FeatureConfig() { /** * Append handler called when a tool call fails with an exception. 
*/ - public fun onToolExecutionFailed(handler: suspend (eventContext: ToolExecutionFailedContext) -> Unit) { - val originalHandler = this._onToolExecutionFailed - this._onToolExecutionFailed = { eventContext -> + public fun onToolCallFailed(handler: suspend (eventContext: ToolCallFailedContext) -> Unit) { + val originalHandler = this._onToolCallFailed + this._onToolCallFailed = { eventContext -> originalHandler(eventContext) handler.invoke(eventContext) } @@ -299,9 +299,9 @@ public class EventHandlerConfig : FeatureConfig() { /** * Append handler called when a tool call completes successfully. */ - public fun onToolExecutionCompleted(handler: suspend (eventContext: ToolExecutionCompletedContext) -> Unit) { - val originalHandler = this._onToolExecutionCompleted - this._onToolExecutionCompleted = { eventContext -> + public fun onToolCallCompleted(handler: suspend (eventContext: ToolCallCompletedContext) -> Unit) { + val originalHandler = this._onToolCallCompleted + this._onToolCallCompleted = { eventContext -> originalHandler(eventContext) handler.invoke(eventContext) } @@ -542,11 +542,11 @@ public class EventHandlerConfig : FeatureConfig() { * Append handler called when a tool is about to be called. */ @Deprecated( - message = "Use onToolExecutionStarting instead", - ReplaceWith("onToolExecutionStarting(handler)", "ai.koog.agents.core.feature.handler.ToolExecutionStartingContext") + message = "Use onToolCallStarting instead", + ReplaceWith("onToolCallStarting(handler)", "ai.koog.agents.core.feature.handler.ToolCallStartingContext") ) public fun onToolCall(handler: suspend (eventContext: ToolCallContext) -> Unit) { - onToolExecutionStarting(handler) + onToolCallStarting(handler) } /** @@ -564,22 +564,22 @@ public class EventHandlerConfig : FeatureConfig() { * Append handler called when a tool call fails with an exception. 
*/ @Deprecated( - message = "Use onToolExecutionFailed instead", - ReplaceWith("onToolExecutionFailed(handler)", "ai.koog.agents.core.feature.handler.ToolExecutionFailedContext") + message = "Use onToolCallFailed instead", + ReplaceWith("onToolCallFailed(handler)", "ai.koog.agents.core.feature.handler.ToolCallFailedContext") ) public fun onToolCallFailure(handler: suspend (eventContext: ToolCallFailureContext) -> Unit) { - onToolExecutionFailed(handler) + onToolCallFailed(handler) } /** * Append handler called when a tool call completes successfully. */ @Deprecated( - message = "Use onToolExecutionCompleted instead", - ReplaceWith("onToolExecutionCompleted(handler)", "ai.koog.agents.core.feature.handler.ToolExecutionCompletedContext") + message = "Use onToolCallCompleted instead", + ReplaceWith("onToolCallCompleted(handler)", "ai.koog.agents.core.feature.handler.ToolCallCompletedContext") ) public fun onToolCallResult(handler: suspend (eventContext: ToolCallResultContext) -> Unit) { - onToolExecutionCompleted(handler) + onToolCallCompleted(handler) } //endregion Deprecated Handlers @@ -682,8 +682,8 @@ public class EventHandlerConfig : FeatureConfig() { /** * Invoke handlers for the tool call event. */ - internal suspend fun invokeOnToolExecutionStarting(eventContext: ToolExecutionStartingContext) { - _onToolExecutionStarting.invoke(eventContext) + internal suspend fun invokeOnToolCallStarting(eventContext: ToolCallStartingContext) { + _onToolCallStarting.invoke(eventContext) } /** @@ -696,15 +696,15 @@ public class EventHandlerConfig : FeatureConfig() { /** * Invoke handlers for a tool call failure with an exception event. 
*/ - internal suspend fun invokeOnToolExecutionFailed(eventContext: ToolExecutionFailedContext) { - _onToolExecutionFailed.invoke(eventContext) + internal suspend fun invokeOnToolCallFailed(eventContext: ToolCallFailedContext) { + _onToolCallFailed.invoke(eventContext) } /** * Invoke handlers for an event when a tool call is completed successfully. */ - internal suspend fun invokeOnToolExecutionCompleted(eventContext: ToolExecutionCompletedContext) { - _onToolExecutionCompleted.invoke(eventContext) + internal suspend fun invokeOnToolCallCompleted(eventContext: ToolCallCompletedContext) { + _onToolCallCompleted.invoke(eventContext) } //endregion Invoke Tool Call Handlers @@ -716,7 +716,7 @@ public class EventHandlerConfig : FeatureConfig() { * * @param eventContext The context containing information about the streaming session about to begin */ - internal suspend fun invokeOnLLMStreammingStarting(eventContext: LLMStreamingStartingContext) { + internal suspend fun invokeOnLLMStreamingStarting(eventContext: LLMStreamingStartingContext) { _onLLMStreamingStarting.invoke(eventContext) } diff --git a/agents/agents-features/agents-features-event-handler/src/jvmTest/kotlin/ai/koog/agents/features/eventHandler/feature/EventHandlerTest.kt b/agents/agents-features/agents-features-event-handler/src/jvmTest/kotlin/ai/koog/agents/features/eventHandler/feature/EventHandlerTest.kt index 894532b3d7..d4443ea9ef 100644 --- a/agents/agents-features/agents-features-event-handler/src/jvmTest/kotlin/ai/koog/agents/features/eventHandler/feature/EventHandlerTest.kt +++ b/agents/agents-features/agents-features-event-handler/src/jvmTest/kotlin/ai/koog/agents/features/eventHandler/feature/EventHandlerTest.kt @@ -6,16 +6,23 @@ import ai.koog.agents.core.dsl.builder.forwardTo import ai.koog.agents.core.dsl.builder.strategy import ai.koog.agents.core.dsl.extension.nodeExecuteTool import ai.koog.agents.core.dsl.extension.nodeLLMRequest +import 
ai.koog.agents.core.dsl.extension.nodeLLMRequestStreamingAndSendResults import ai.koog.agents.core.dsl.extension.nodeLLMSendToolResult import ai.koog.agents.core.dsl.extension.onAssistantMessage import ai.koog.agents.core.dsl.extension.onToolCall import ai.koog.agents.core.tools.ToolRegistry +import ai.koog.agents.features.eventHandler.eventString import ai.koog.agents.testing.tools.DummyTool import ai.koog.agents.testing.tools.getMockExecutor import ai.koog.agents.testing.tools.mockLLMAnswer import ai.koog.agents.utils.use +import ai.koog.prompt.dsl.Prompt import ai.koog.prompt.executor.clients.openai.OpenAIModels +import ai.koog.prompt.executor.model.PromptExecutor import ai.koog.prompt.message.Message +import ai.koog.prompt.streaming.StreamFrame +import kotlinx.coroutines.flow.Flow +import kotlinx.coroutines.flow.flow import kotlinx.coroutines.runBlocking import org.junit.jupiter.api.Disabled import org.junit.jupiter.api.assertThrows @@ -49,15 +56,15 @@ class EventHandlerTest { val runId = eventsCollector.runId val expectedEvents = listOf( - "OnBeforeAgentStarted (agent id: test-agent-id, run id: $runId)", - "OnStrategyStarted (run id: $runId, strategy: $strategyName)", - "OnBeforeNode (run id: $runId, node: __start__, input: $agentInput)", - "OnAfterNode (run id: $runId, node: __start__, input: $agentInput, output: $agentInput)", - "OnBeforeNode (run id: $runId, node: __finish__, input: $agentResult)", - "OnAfterNode (run id: $runId, node: __finish__, input: $agentResult, output: $agentResult)", - "OnStrategyFinished (run id: $runId, strategy: $strategyName, result: $agentResult)", - "OnAgentFinished (agent id: test-agent-id, run id: $runId, result: $agentResult)", - "OnAgentBeforeClose (agent id: test-agent-id)", + "OnAgentStarting (agent id: test-agent-id, run id: $runId)", + "OnStrategyStarting (run id: $runId, strategy: $strategyName)", + "OnNodeExecutionStarting (run id: $runId, node: __start__, input: $agentInput)", + "OnNodeExecutionCompleted (run id: 
$runId, node: __start__, input: $agentInput, output: $agentInput)", + "OnNodeExecutionStarting (run id: $runId, node: __finish__, input: $agentResult)", + "OnNodeExecutionCompleted (run id: $runId, node: __finish__, input: $agentResult, output: $agentResult)", + "OnStrategyCompleted (run id: $runId, strategy: $strategyName, result: $agentResult)", + "OnAgentCompleted (agent id: test-agent-id, run id: $runId, result: $agentResult)", + "OnAgentClosing (agent id: test-agent-id)", ) assertEquals(expectedEvents.size, eventsCollector.size) @@ -66,21 +73,37 @@ class EventHandlerTest { @Test fun `test event handler single node without tools`() = runBlocking { - val agentId = "test-agent-id" val eventsCollector = TestEventsCollector() - val strategyName = "tracing-test-strategy" + val agentId = "test-agent-id" + + val promptId = "Test prompt Id" + val systemPrompt = "Test system message" + val userPrompt = "Test user message" + val assistantPrompt = "Test assistant response" + val temperature = 1.0 + val model = OpenAIModels.Chat.GPT4o + val agentResult = "Done" + val testLLMResponse = "Test LLM call prompt" + val strategyName = "tracing-test-strategy" val strategy = strategy(strategyName) { val llmCallNode by nodeLLMRequest("test LLM call") - edge(nodeStart forwardTo llmCallNode transformed { "Test LLM call prompt" }) + edge(nodeStart forwardTo llmCallNode transformed { testLLMResponse }) edge(llmCallNode forwardTo nodeFinish transformed { agentResult }) } val agent = createAgent( agentId = agentId, strategy = strategy, + promptId = promptId, + systemPrompt = systemPrompt, + userPrompt = userPrompt, + assistantPrompt = assistantPrompt, + temperature = temperature, + model = model, + toolRegistry = ToolRegistry { }, installFeatures = { install(EventHandler, eventsCollector.eventHandlerFeatureConfig) } @@ -93,19 +116,30 @@ class EventHandlerTest { val runId = eventsCollector.runId val expectedEvents = listOf( - "OnBeforeAgentStarted (agent id: test-agent-id, run id: 
$runId)", - "OnStrategyStarted (run id: $runId, strategy: $strategyName)", - "OnBeforeNode (run id: $runId, node: __start__, input: $agentInput)", - "OnAfterNode (run id: $runId, node: __start__, input: $agentInput, output: $agentInput)", - "OnBeforeNode (run id: $runId, node: test LLM call, input: Test LLM call prompt)", - "OnBeforeLLMCall (run id: $runId, prompt: id: test, messages: [{role: System, message: Test system message, role: User, message: Test user message, role: Assistant, message: Test assistant response, role: User, message: Test LLM call prompt}], temperature: null, tools: [])", - "OnAfterLLMCall (run id: $runId, prompt: id: test, messages: [{role: System, message: Test system message, role: User, message: Test user message, role: Assistant, message: Test assistant response, role: User, message: Test LLM call prompt}], temperature: null, model: openai:gpt-4o, tools: [], responses: [role: Assistant, message: Default test response])", - "OnAfterNode (run id: $runId, node: test LLM call, input: Test LLM call prompt, output: Assistant(content=Default test response, metaInfo=ResponseMetaInfo(timestamp=$ts, totalTokensCount=null, inputTokensCount=null, outputTokensCount=null, additionalInfo={}), attachments=[], finishReason=null))", - "OnBeforeNode (run id: $runId, node: __finish__, input: $agentResult)", - "OnAfterNode (run id: $runId, node: __finish__, input: $agentResult, output: $agentResult)", - "OnStrategyFinished (run id: $runId, strategy: $strategyName, result: $agentResult)", - "OnAgentFinished (agent id: test-agent-id, run id: $runId, result: $agentResult)", - "OnAgentBeforeClose (agent id: $agentId)", + "OnAgentStarting (agent id: $agentId, run id: $runId)", + "OnStrategyStarting (run id: $runId, strategy: $strategyName)", + "OnNodeExecutionStarting (run id: $runId, node: __start__, input: $agentInput)", + "OnNodeExecutionCompleted (run id: $runId, node: __start__, input: $agentInput, output: $agentInput)", + "OnNodeExecutionStarting (run id: 
$runId, node: test LLM call, input: $testLLMResponse)", + "OnLLMCallStarting (run id: $runId, prompt: id: $promptId, messages: [{" + + "role: ${Message.Role.System}, message: $systemPrompt, " + + "role: ${Message.Role.User}, message: $userPrompt, " + + "role: ${Message.Role.Assistant}, message: $assistantPrompt, " + + "role: ${Message.Role.User}, message: $testLLMResponse" + + "}], temperature: $temperature, tools: [])", + "OnLLMCallCompleted (run id: $runId, prompt: id: $promptId, messages: [{" + + "role: ${Message.Role.System}, message: $systemPrompt, " + + "role: ${Message.Role.User}, message: $userPrompt, " + + "role: ${Message.Role.Assistant}, message: $assistantPrompt, " + + "role: ${Message.Role.User}, message: $testLLMResponse" + + "}], temperature: $temperature, model: ${model.eventString}, tools: [], responses: [role: ${Message.Role.Assistant}, message: Default test response])", + "OnNodeExecutionCompleted (run id: $runId, node: test LLM call, input: $testLLMResponse, output: " + + "Assistant(content=Default test response, metaInfo=ResponseMetaInfo(timestamp=$ts, totalTokensCount=null, inputTokensCount=null, outputTokensCount=null, additionalInfo={}, metadata=null), attachments=[], finishReason=null))", + "OnNodeExecutionStarting (run id: $runId, node: __finish__, input: $agentResult)", + "OnNodeExecutionCompleted (run id: $runId, node: __finish__, input: $agentResult, output: $agentResult)", + "OnStrategyCompleted (run id: $runId, strategy: $strategyName, result: $agentResult)", + "OnAgentCompleted (agent id: test-agent-id, run id: $runId, result: $agentResult)", + "OnAgentClosing (agent id: $agentId)", ) assertEquals(expectedEvents.size, eventsCollector.size) @@ -115,9 +149,14 @@ class EventHandlerTest { @Test fun `test event handler single node with tools`() = runBlocking { val eventsCollector = TestEventsCollector() - val strategyName = "test-strategy" + val promptId = "Test prompt Id" + val systemPrompt = "Test system message" val userPrompt = "Call 
the dummy tool with argument: test" + val assistantPrompt = "Test assistant response" + val temperature = 1.0 + val strategyName = "test-strategy" + val mockResponse = "Return test result" val agentId = "test-agent-id" @@ -150,6 +189,11 @@ class EventHandlerTest { createAgent( agentId = agentId, strategy = strategy, + promptId = promptId, + systemPrompt = systemPrompt, + userPrompt = userPrompt, + assistantPrompt = assistantPrompt, + temperature = temperature, toolRegistry = toolRegistry, promptExecutor = mockExecutor, model = model, @@ -162,27 +206,55 @@ class EventHandlerTest { val runId = eventsCollector.runId val expectedEvents = listOf( - "OnBeforeAgentStarted (agent id: $agentId, run id: $runId)", - "OnStrategyStarted (run id: $runId, strategy: $strategyName)", - "OnBeforeNode (run id: $runId, node: __start__, input: $userPrompt)", - "OnAfterNode (run id: $runId, node: __start__, input: $userPrompt, output: $userPrompt)", - "OnBeforeNode (run id: $runId, node: test-llm-call, input: $userPrompt)", - "OnBeforeLLMCall (run id: $runId, prompt: id: test, messages: [{role: System, message: Test system message, role: User, message: Test user message, role: Assistant, message: Test assistant response, role: User, message: $userPrompt}], temperature: null, tools: [dummy])", - "OnAfterLLMCall (run id: $runId, prompt: id: test, messages: [{role: System, message: Test system message, role: User, message: Test user message, role: Assistant, message: Test assistant response, role: User, message: $userPrompt}], temperature: null, model: openai:gpt-4o, tools: [${dummyTool.name}], responses: [role: Tool, message: {\"dummy\":\"test\"}])", - "OnAfterNode (run id: $runId, node: test-llm-call, input: $userPrompt, output: Call(id=null, tool=${dummyTool.name}, content={\"dummy\":\"test\"}, metaInfo=ResponseMetaInfo(timestamp=2023-01-01T00:00:00Z, totalTokensCount=null, inputTokensCount=null, outputTokensCount=null, additionalInfo={})))", - "OnBeforeNode (run id: $runId, node: 
test-tool-call, input: Call(id=null, tool=${dummyTool.name}, content={\"dummy\":\"test\"}, metaInfo=ResponseMetaInfo(timestamp=2023-01-01T00:00:00Z, totalTokensCount=null, inputTokensCount=null, outputTokensCount=null, additionalInfo={})))", - "OnToolCall (run id: $runId, tool: ${dummyTool.name}, args: Args(dummy=test))", - "OnToolCallResult (run id: $runId, tool: ${dummyTool.name}, args: Args(dummy=test), result: ${dummyTool.result})", - "OnAfterNode (run id: $runId, node: test-tool-call, input: Call(id=null, tool=${dummyTool.name}, content={\"dummy\":\"test\"}, metaInfo=ResponseMetaInfo(timestamp=2023-01-01T00:00:00Z, totalTokensCount=null, inputTokensCount=null, outputTokensCount=null, additionalInfo={})), output: ReceivedToolResult(id=null, tool=${dummyTool.name}, content=${dummyTool.result}, result=${dummyTool.result}))", - "OnBeforeNode (run id: $runId, node: test-node-llm-send-tool-result, input: ReceivedToolResult(id=null, tool=${dummyTool.name}, content=${dummyTool.result}, result=${dummyTool.result}))", - "OnBeforeLLMCall (run id: $runId, prompt: id: test, messages: [{role: System, message: Test system message, role: User, message: Test user message, role: Assistant, message: Test assistant response, role: User, message: $userPrompt, role: Tool, message: {\"dummy\":\"test\"}, role: Tool, message: ${dummyTool.result}}], temperature: null, tools: [${dummyTool.name}])", - "OnAfterLLMCall (run id: $runId, prompt: id: test, messages: [{role: System, message: Test system message, role: User, message: Test user message, role: Assistant, message: Test assistant response, role: User, message: $userPrompt, role: Tool, message: {\"dummy\":\"test\"}, role: Tool, message: ${dummyTool.result}}], temperature: null, model: openai:gpt-4o, tools: [${dummyTool.name}], responses: [role: Assistant, message: Return test result])", - "OnAfterNode (run id: $runId, node: test-node-llm-send-tool-result, input: ReceivedToolResult(id=null, tool=${dummyTool.name}, 
content=${dummyTool.result}, result=${dummyTool.result}), output: Assistant(content=Return test result, metaInfo=ResponseMetaInfo(timestamp=2023-01-01T00:00:00Z, totalTokensCount=null, inputTokensCount=null, outputTokensCount=null, additionalInfo={}), attachments=[], finishReason=null))", - "OnBeforeNode (run id: $runId, node: __finish__, input: $mockResponse)", - "OnAfterNode (run id: $runId, node: __finish__, input: $mockResponse, output: $mockResponse)", - "OnStrategyFinished (run id: $runId, strategy: $strategyName, result: $mockResponse)", - "OnAgentFinished (agent id: $agentId, run id: $runId, result: $mockResponse)", - "OnAgentBeforeClose (agent id: $agentId)", + "OnAgentStarting (agent id: $agentId, run id: $runId)", + "OnStrategyStarting (run id: $runId, strategy: $strategyName)", + "OnNodeExecutionStarting (run id: $runId, node: __start__, input: $userPrompt)", + "OnNodeExecutionCompleted (run id: $runId, node: __start__, input: $userPrompt, output: $userPrompt)", + "OnNodeExecutionStarting (run id: $runId, node: test-llm-call, input: $userPrompt)", + "OnLLMCallStarting (run id: $runId, prompt: id: $promptId, messages: [{" + + "role: ${Message.Role.System}, message: $systemPrompt, " + + "role: ${Message.Role.User}, message: $userPrompt, " + + "role: ${Message.Role.Assistant}, message: $assistantPrompt, " + + "role: ${Message.Role.User}, message: $userPrompt" + + "}], temperature: $temperature, tools: [${toolRegistry.tools.joinToString { it.name }}])", + "OnLLMCallCompleted (run id: $runId, prompt: id: $promptId, messages: [{" + + "role: ${Message.Role.System}, message: $systemPrompt, " + + "role: ${Message.Role.User}, message: $userPrompt, " + + "role: ${Message.Role.Assistant}, message: $assistantPrompt, " + + "role: ${Message.Role.User}, message: $userPrompt" + + "}], temperature: $temperature, model: ${model.eventString}, tools: [${dummyTool.name}], responses: [role: ${Message.Role.Tool}, message: {\"dummy\":\"test\"}])", + "OnNodeExecutionCompleted 
(run id: $runId, node: test-llm-call, input: $userPrompt, output: " + + "Call(id=null, tool=${dummyTool.name}, content={\"dummy\":\"test\"}, metaInfo=ResponseMetaInfo(timestamp=2023-01-01T00:00:00Z, totalTokensCount=null, inputTokensCount=null, outputTokensCount=null, additionalInfo={}, metadata=null)))", + "OnNodeExecutionStarting (run id: $runId, node: test-tool-call, input: " + + "Call(id=null, tool=${dummyTool.name}, content={\"dummy\":\"test\"}, metaInfo=ResponseMetaInfo(timestamp=2023-01-01T00:00:00Z, totalTokensCount=null, inputTokensCount=null, outputTokensCount=null, additionalInfo={}, metadata=null)))", + "OnToolCallStarting (run id: $runId, tool: ${dummyTool.name}, args: Args(dummy=test))", + "OnToolCallCompleted (run id: $runId, tool: ${dummyTool.name}, args: Args(dummy=test), result: ${dummyTool.result})", + "OnNodeExecutionCompleted (run id: $runId, node: test-tool-call, input: " + + "Call(id=null, tool=${dummyTool.name}, content={\"dummy\":\"test\"}, metaInfo=ResponseMetaInfo(timestamp=2023-01-01T00:00:00Z, totalTokensCount=null, inputTokensCount=null, outputTokensCount=null, additionalInfo={}, metadata=null)), output: ReceivedToolResult(id=null, tool=${dummyTool.name}, content=${dummyTool.result}, result=${dummyTool.result}))", + "OnNodeExecutionStarting (run id: $runId, node: test-node-llm-send-tool-result, input: ReceivedToolResult(id=null, tool=${dummyTool.name}, content=${dummyTool.result}, result=${dummyTool.result}))", + "OnLLMCallStarting (run id: $runId, prompt: id: $promptId, messages: [{" + + "role: ${Message.Role.System}, message: $systemPrompt, " + + "role: ${Message.Role.User}, message: $userPrompt, " + + "role: ${Message.Role.Assistant}, message: $assistantPrompt, " + + "role: ${Message.Role.User}, message: $userPrompt, " + + "role: ${Message.Role.Tool}, message: {\"dummy\":\"test\"}, " + + "role: ${Message.Role.Tool}, message: ${dummyTool.result}" + + "}], temperature: $temperature, tools: [${dummyTool.name}])", + "OnLLMCallCompleted 
(run id: $runId, prompt: id: $promptId, messages: [{" + + "role: ${Message.Role.System}, message: $systemPrompt, " + + "role: ${Message.Role.User}, message: $userPrompt, " + + "role: ${Message.Role.Assistant}, message: $assistantPrompt, " + + "role: ${Message.Role.User}, message: $userPrompt, " + + "role: ${Message.Role.Tool}, message: {\"dummy\":\"test\"}, " + + "role: ${Message.Role.Tool}, message: ${dummyTool.result}" + + "}], temperature: $temperature, model: openai:gpt-4o, tools: [${dummyTool.name}], responses: [role: ${Message.Role.Assistant}, message: Return test result])", + "OnNodeExecutionCompleted (run id: $runId, node: test-node-llm-send-tool-result, input: " + + "ReceivedToolResult(id=null, tool=${dummyTool.name}, content=${dummyTool.result}, result=${dummyTool.result}), output: Assistant(content=Return test result, metaInfo=ResponseMetaInfo(timestamp=2023-01-01T00:00:00Z, totalTokensCount=null, inputTokensCount=null, outputTokensCount=null, additionalInfo={}, metadata=null), attachments=[], finishReason=null))", + "OnNodeExecutionStarting (run id: $runId, node: __finish__, input: $mockResponse)", + "OnNodeExecutionCompleted (run id: $runId, node: __finish__, input: $mockResponse, output: $mockResponse)", + "OnStrategyCompleted (run id: $runId, strategy: $strategyName, result: $mockResponse)", + "OnAgentCompleted (agent id: $agentId, run id: $runId, result: $mockResponse)", + "OnAgentClosing (agent id: $agentId)", ) assertEquals(expectedEvents.size, eventsCollector.size) @@ -192,21 +264,39 @@ class EventHandlerTest { @Test fun `test event handler several nodes`() = runBlocking { val eventsCollector = TestEventsCollector() - val strategyName = "tracing-test-strategy" + + val promptId = "Test prompt Id" + val systemPrompt = "Test system message" + val userPrompt = "Test user message" + val assistantPrompt = "Test assistant response" + val temperature = 1.0 + val model = OpenAIModels.Chat.GPT4o + val agentResult = "Done" + val strategyName = 
"tracing-test-strategy" + val testLLMResponse = "Test LLM call prompt" + val llmCallWithToolsResponse = "Test LLM call with tools prompt" + val strategy = strategy(strategyName) { val llmCallNode by nodeLLMRequest("test LLM call") val llmCallWithToolsNode by nodeLLMRequest("test LLM call with tools") - edge(nodeStart forwardTo llmCallNode transformed { "Test LLM call prompt" }) - edge(llmCallNode forwardTo llmCallWithToolsNode transformed { "Test LLM call with tools prompt" }) + edge(nodeStart forwardTo llmCallNode transformed { testLLMResponse }) + edge(llmCallNode forwardTo llmCallWithToolsNode transformed { llmCallWithToolsResponse }) edge(llmCallWithToolsNode forwardTo nodeFinish transformed { agentResult }) } + val toolRegistry = ToolRegistry { tool(DummyTool()) } + val agent = createAgent( strategy = strategy, - toolRegistry = ToolRegistry { tool(DummyTool()) }, + promptId = promptId, + systemPrompt = systemPrompt, + userPrompt = userPrompt, + assistantPrompt = assistantPrompt, + temperature = temperature, + toolRegistry = toolRegistry, installFeatures = { install(EventHandler, eventsCollector.eventHandlerFeatureConfig) } @@ -219,23 +309,49 @@ class EventHandlerTest { val runId = eventsCollector.runId val expectedEvents = listOf( - "OnBeforeAgentStarted (agent id: test-agent-id, run id: $runId)", - "OnStrategyStarted (run id: $runId, strategy: $strategyName)", - "OnBeforeNode (run id: $runId, node: __start__, input: $agentInput)", - "OnAfterNode (run id: $runId, node: __start__, input: $agentInput, output: $agentInput)", - "OnBeforeNode (run id: $runId, node: test LLM call, input: Test LLM call prompt)", - "OnBeforeLLMCall (run id: $runId, prompt: id: test, messages: [{role: System, message: Test system message, role: User, message: Test user message, role: Assistant, message: Test assistant response, role: User, message: Test LLM call prompt}], temperature: null, tools: [dummy])", - "OnAfterLLMCall (run id: $runId, prompt: id: test, messages: [{role: System, 
message: Test system message, role: User, message: Test user message, role: Assistant, message: Test assistant response, role: User, message: Test LLM call prompt}], temperature: null, model: openai:gpt-4o, tools: [dummy], responses: [role: Assistant, message: Default test response])", - "OnAfterNode (run id: $runId, node: test LLM call, input: Test LLM call prompt, output: Assistant(content=Default test response, metaInfo=ResponseMetaInfo(timestamp=2023-01-01T00:00:00Z, totalTokensCount=null, inputTokensCount=null, outputTokensCount=null, additionalInfo={}), attachments=[], finishReason=null))", - "OnBeforeNode (run id: $runId, node: test LLM call with tools, input: Test LLM call with tools prompt)", - "OnBeforeLLMCall (run id: $runId, prompt: id: test, messages: [{role: System, message: Test system message, role: User, message: Test user message, role: Assistant, message: Test assistant response, role: User, message: Test LLM call prompt, role: Assistant, message: Default test response, role: User, message: Test LLM call with tools prompt}], temperature: null, tools: [dummy])", - "OnAfterLLMCall (run id: $runId, prompt: id: test, messages: [{role: System, message: Test system message, role: User, message: Test user message, role: Assistant, message: Test assistant response, role: User, message: Test LLM call prompt, role: Assistant, message: Default test response, role: User, message: Test LLM call with tools prompt}], temperature: null, model: openai:gpt-4o, tools: [dummy], responses: [role: Assistant, message: Default test response])", - "OnAfterNode (run id: $runId, node: test LLM call with tools, input: Test LLM call with tools prompt, output: Assistant(content=Default test response, metaInfo=ResponseMetaInfo(timestamp=2023-01-01T00:00:00Z, totalTokensCount=null, inputTokensCount=null, outputTokensCount=null, additionalInfo={}), attachments=[], finishReason=null))", - "OnBeforeNode (run id: $runId, node: __finish__, input: $agentResult)", - "OnAfterNode (run 
id: $runId, node: __finish__, input: $agentResult, output: $agentResult)", - "OnStrategyFinished (run id: $runId, strategy: $strategyName, result: $agentResult)", - "OnAgentFinished (agent id: test-agent-id, run id: $runId, result: $agentResult)", - "OnAgentBeforeClose (agent id: test-agent-id)", + "OnAgentStarting (agent id: test-agent-id, run id: $runId)", + "OnStrategyStarting (run id: $runId, strategy: $strategyName)", + "OnNodeExecutionStarting (run id: $runId, node: __start__, input: $agentInput)", + "OnNodeExecutionCompleted (run id: $runId, node: __start__, input: $agentInput, output: $agentInput)", + "OnNodeExecutionStarting (run id: $runId, node: test LLM call, input: $testLLMResponse)", + "OnLLMCallStarting (run id: $runId, prompt: id: $promptId, messages: [{" + + "role: ${Message.Role.System}, message: $systemPrompt, " + + "role: ${Message.Role.User}, message: $userPrompt, " + + "role: ${Message.Role.Assistant}, message: $assistantPrompt, " + + "role: ${Message.Role.User}, message: $testLLMResponse" + + "}], temperature: $temperature, tools: [${toolRegistry.tools.joinToString { it.name }}])", + "OnLLMCallCompleted (run id: $runId, prompt: id: $promptId, messages: [{" + + "role: ${Message.Role.System}, message: $systemPrompt, " + + "role: ${Message.Role.User}, message: $userPrompt, " + + "role: ${Message.Role.Assistant}, message: $assistantPrompt, " + + "role: ${Message.Role.User}, message: $testLLMResponse" + + "}], temperature: $temperature, model: ${model.eventString}, tools: [${toolRegistry.tools.joinToString { it.name }}], responses: [role: ${Message.Role.Assistant}, message: Default test response])", + "OnNodeExecutionCompleted (run id: $runId, node: test LLM call, input: $testLLMResponse, output: " + + "Assistant(content=Default test response, metaInfo=ResponseMetaInfo(timestamp=2023-01-01T00:00:00Z, totalTokensCount=null, inputTokensCount=null, outputTokensCount=null, additionalInfo={}, metadata=null), attachments=[], finishReason=null))", + 
"OnNodeExecutionStarting (run id: $runId, node: test LLM call with tools, input: $llmCallWithToolsResponse)", + "OnLLMCallStarting (run id: $runId, prompt: id: $promptId, messages: [{" + + "role: ${Message.Role.System}, message: $systemPrompt, " + + "role: ${Message.Role.User}, message: $userPrompt, " + + "role: ${Message.Role.Assistant}, message: $assistantPrompt, " + + "role: ${Message.Role.User}, message: Test LLM call prompt, " + + "role: ${Message.Role.Assistant}, message: Default test response, " + + "role: ${Message.Role.User}, message: $llmCallWithToolsResponse" + + "}], temperature: $temperature, tools: [${toolRegistry.tools.joinToString { it.name }}])", + "OnLLMCallCompleted (run id: $runId, prompt: id: $promptId, messages: [{" + + "role: ${Message.Role.System}, message: $systemPrompt, " + + "role: ${Message.Role.User}, message: $userPrompt, " + + "role: ${Message.Role.Assistant}, message: $assistantPrompt, " + + "role: ${Message.Role.User}, message: Test LLM call prompt, " + + "role: ${Message.Role.Assistant}, message: Default test response, " + + "role: ${Message.Role.User}, message: $llmCallWithToolsResponse" + + "}], temperature: $temperature, model: openai:gpt-4o, tools: [${toolRegistry.tools.joinToString { it.name }}], responses: [role: ${Message.Role.Assistant}, message: Default test response])", + "OnNodeExecutionCompleted (run id: $runId, node: test LLM call with tools, input: $llmCallWithToolsResponse, output: " + + "Assistant(content=Default test response, metaInfo=ResponseMetaInfo(timestamp=2023-01-01T00:00:00Z, totalTokensCount=null, inputTokensCount=null, outputTokensCount=null, additionalInfo={}, metadata=null), attachments=[], finishReason=null))", + "OnNodeExecutionStarting (run id: $runId, node: __finish__, input: $agentResult)", + "OnNodeExecutionCompleted (run id: $runId, node: __finish__, input: $agentResult, output: $agentResult)", + "OnStrategyCompleted (run id: $runId, strategy: $strategyName, result: $agentResult)", + 
"OnAgentCompleted (agent id: test-agent-id, run id: $runId, result: $agentResult)", + "OnAgentClosing (agent id: test-agent-id)", ) assertEquals(expectedEvents.size, eventsCollector.size) @@ -277,14 +393,14 @@ class EventHandlerTest { val runId = eventsCollector.runId val expectedEvents = listOf( - "OnBeforeAgentStarted (agent id: $agentId, run id: $runId)", - "OnStrategyStarted (run id: $runId, strategy: $strategyName)", - "OnBeforeNode (run id: $runId, node: __start__, input: $agentInput)", - "OnAfterNode (run id: $runId, node: __start__, input: $agentInput, output: $agentInput)", - "OnBeforeNode (run id: $runId, node: $errorNodeName, input: $agentInput)", - "OnNodeExecutionError (run id: $runId, node: $errorNodeName, error: $testErrorMessage)", - "OnAgentRunError (agent id: $agentId, run id: $runId, error: $testErrorMessage)", - "OnAgentBeforeClose (agent id: $agentId)", + "OnAgentStarting (agent id: $agentId, run id: $runId)", + "OnStrategyStarting (run id: $runId, strategy: $strategyName)", + "OnNodeExecutionStarting (run id: $runId, node: __start__, input: $agentInput)", + "OnNodeExecutionCompleted (run id: $runId, node: __start__, input: $agentInput, output: $agentInput)", + "OnNodeExecutionStarting (run id: $runId, node: $errorNodeName, input: $agentInput)", + "OnNodeExecutionFailed (run id: $runId, node: $errorNodeName, error: $testErrorMessage)", + "OnAgentExecutionFailed (agent id: $agentId, run id: $runId, error: $testErrorMessage)", + "OnAgentClosing (agent id: $agentId)", ) assertEquals(expectedEvents.size, eventsCollector.size) @@ -308,22 +424,22 @@ class EventHandlerTest { strategy = strategy, installFeatures = { install(EventHandler) { - onBeforeAgentStarted { eventContext -> + onAgentStarting { eventContext -> runId = eventContext.runId collectedEvents.add( - "OnBeforeAgentStarted first (agent id: ${eventContext.agent.id})" + "OnAgentStarting first (agent id: ${eventContext.agent.id})" ) } - onBeforeAgentStarted { eventContext -> + onAgentStarting 
{ eventContext -> collectedEvents.add( - "OnBeforeAgentStarted second (agent id: ${eventContext.agent.id})" + "OnAgentStarting second (agent id: ${eventContext.agent.id})" ) } - onAgentFinished { eventContext -> + onAgentCompleted { eventContext -> collectedEvents.add( - "OnAgentFinished (agent id: ${eventContext.agentId}, run id: ${eventContext.runId}, result: $agentResult)" + "OnAgentCompleted (agent id: ${eventContext.agentId}, run id: ${eventContext.runId}, result: $agentResult)" ) } } @@ -334,9 +450,9 @@ class EventHandlerTest { agent.run(agentInput) val expectedEvents = listOf( - "OnBeforeAgentStarted first (agent id: ${agent.id})", - "OnBeforeAgentStarted second (agent id: ${agent.id})", - "OnAgentFinished (agent id: ${agent.id}, run id: $runId, result: $agentResult)", + "OnAgentStarting first (agent id: ${agent.id})", + "OnAgentStarting second (agent id: ${agent.id})", + "OnAgentCompleted (agent id: ${agent.id}, run id: $runId, result: $agentResult)", ) assertEquals(expectedEvents.size, collectedEvents.size) @@ -370,6 +486,153 @@ class EventHandlerTest { agent.close() } + @Test + fun `test llm streaming events success`() = runBlocking { + val eventsCollector = TestEventsCollector() + + val model = OpenAIModels.Chat.GPT4o + val promptId = "Test prompt Id" + val systemPrompt = "Test system message" + val userPrompt = "Test user message" + val assistantPrompt = "Test assistant response" + val temperature = 1.0 + + val strategyName = "event-handler-streaming-success" + val strategy = strategy(strategyName) { + val streamAndCollect by nodeLLMRequestStreamingAndSendResults("stream-and-collect") + + edge(nodeStart forwardTo streamAndCollect) + edge(streamAndCollect forwardTo nodeFinish transformed { messages -> messages.firstOrNull()?.content ?: "" }) + } + + val toolRegistry = ToolRegistry { tool(DummyTool()) } + + val testLLMResponse = "Default test response" + val executor = getMockExecutor { + mockLLMAnswer(testLLMResponse).asDefaultResponse 
onUserRequestEquals "Test user message" + } + + createAgent( + agentId = "test-agent-id", + strategy = strategy, + promptExecutor = executor, + promptId = promptId, + systemPrompt = systemPrompt, + userPrompt = userPrompt, + assistantPrompt = assistantPrompt, + temperature = temperature, + model = model, + toolRegistry = toolRegistry, + ) { + install(EventHandler, eventsCollector.eventHandlerFeatureConfig) + }.use { agent -> + agent.run("") + } + + val runId = eventsCollector.runId + + val actualEvents = eventsCollector.collectedEvents.filter { it.startsWith("OnLLMStreaming") } + + val expectedPromptString = "id: $promptId, messages: [{" + + "role: ${Message.Role.System}, message: $systemPrompt, " + + "role: ${Message.Role.User}, message: $userPrompt, " + + "role: ${Message.Role.Assistant}, message: $assistantPrompt" + + "}]" + + val expectedEvents = listOf( + "OnLLMStreamingStarting (run id: $runId, prompt: $expectedPromptString, temperature: $temperature, model: ${model.eventString}, tools: [${toolRegistry.tools.joinToString { it.name }}])", + "OnLLMStreamingFrameReceived (run id: $runId, frame: Append(text=$testLLMResponse))", + "OnLLMStreamingCompleted (run id: $runId, prompt: $expectedPromptString, temperature: $temperature, model: ${model.eventString}, tools: [${toolRegistry.tools.joinToString { it.name }}])", + ) + + assertEquals(expectedEvents.size, actualEvents.size) + assertContentEquals(expectedEvents, actualEvents) + } + + @Test + fun `test llm streaming events failure`() = runBlocking { + val eventsCollector = TestEventsCollector() + + val promptId = "Test prompt Id" + val systemPrompt = "Test system message" + val userPrompt = "Test user message" + val assistantPrompt = "Test assistant response" + val temperature = 1.0 + + val model = OpenAIModels.Chat.GPT4o + + val strategyName = "event-handler-streaming-failure" + val strategy = strategy(strategyName) { + val streamAndCollect by nodeLLMRequestStreamingAndSendResults("stream-and-collect") + + 
edge(nodeStart forwardTo streamAndCollect) + edge(streamAndCollect forwardTo nodeFinish transformed { messages -> messages.firstOrNull()?.content ?: "" }) + } + + val toolRegistry = ToolRegistry { tool(DummyTool()) } + + val testStreamingErrorMessage = "Test streaming error" + + val testStreamingExecutor = object : PromptExecutor { + override suspend fun execute( + prompt: Prompt, + model: ai.koog.prompt.llm.LLModel, + tools: List + ): List = emptyList() + + override fun executeStreaming( + prompt: Prompt, + model: ai.koog.prompt.llm.LLModel, + tools: List + ): Flow = flow { + throw IllegalStateException(testStreamingErrorMessage) + } + + override suspend fun moderate( + prompt: Prompt, + model: ai.koog.prompt.llm.LLModel + ): ai.koog.prompt.dsl.ModerationResult { + throw UnsupportedOperationException("Not used in test") + } + } + + createAgent( + strategy = strategy, + promptExecutor = testStreamingExecutor, + promptId = promptId, + systemPrompt = systemPrompt, + userPrompt = userPrompt, + assistantPrompt = assistantPrompt, + temperature = temperature, + model = model, + toolRegistry = toolRegistry, + ) { + install(EventHandler, eventsCollector.eventHandlerFeatureConfig) + }.use { agent -> + val throwable = assertThrows { agent.run("") } + assertEquals(testStreamingErrorMessage, throwable.message) + } + + val runId = eventsCollector.runId + + val actualEvents = eventsCollector.collectedEvents.filter { it.startsWith("OnLLMStreaming") } + + val expectedPromptString = "id: $promptId, messages: [{" + + "role: ${Message.Role.System}, message: $systemPrompt, " + + "role: ${Message.Role.User}, message: $userPrompt, " + + "role: ${Message.Role.Assistant}, message: $assistantPrompt" + + "}]" + + val expectedEvents = listOf( + "OnLLMStreamingStarting (run id: $runId, prompt: $expectedPromptString, temperature: $temperature, model: ${model.eventString}, tools: [${toolRegistry.tools.joinToString { it.name}}])", + "OnLLMStreamingFailed (run id: $runId, error: 
$testStreamingErrorMessage)", + "OnLLMStreamingCompleted (run id: $runId, prompt: $expectedPromptString, temperature: $temperature, model: ${model.eventString}, tools: [${toolRegistry.tools.joinToString { it.name}}])", + ) + + assertEquals(expectedEvents.size, actualEvents.size) + assertContentEquals(expectedEvents, actualEvents) + } + fun AIAgentSubgraphBuilderBase<*, *>.nodeException(name: String? = null): AIAgentNodeDelegate = node(name) { throw IllegalStateException("Test exception") } } diff --git a/agents/agents-features/agents-features-event-handler/src/jvmTest/kotlin/ai/koog/agents/features/eventHandler/feature/EventHandlerTestAPI.kt b/agents/agents-features/agents-features-event-handler/src/jvmTest/kotlin/ai/koog/agents/features/eventHandler/feature/EventHandlerTestAPI.kt index 39b4c997f6..aadb44b0fd 100644 --- a/agents/agents-features/agents-features-event-handler/src/jvmTest/kotlin/ai/koog/agents/features/eventHandler/feature/EventHandlerTestAPI.kt +++ b/agents/agents-features/agents-features-event-handler/src/jvmTest/kotlin/ai/koog/agents/features/eventHandler/feature/EventHandlerTestAPI.kt @@ -9,6 +9,7 @@ import ai.koog.prompt.dsl.prompt import ai.koog.prompt.executor.clients.openai.OpenAIModels import ai.koog.prompt.executor.model.PromptExecutor import ai.koog.prompt.llm.LLModel +import ai.koog.prompt.params.LLMParams import kotlinx.datetime.Clock import kotlinx.datetime.Instant @@ -22,18 +23,27 @@ fun createAgent( strategy: AIAgentGraphStrategy, agentId: String = "test-agent-id", promptExecutor: PromptExecutor? = null, + promptId: String? = null, + systemPrompt: String? = null, + userPrompt: String? = null, + assistantPrompt: String? = null, + temperature: Double? = null, toolRegistry: ToolRegistry? = null, model: LLModel? 
= null, installFeatures: GraphAIAgent.FeatureContext.() -> Unit = { } ): AIAgent { val agentConfig = AIAgentConfig( - prompt = prompt("test", clock = testClock) { - system("Test system message") - user("Test user message") - assistant("Test assistant response") + prompt = prompt( + id = promptId ?: "Test prompt", + clock = testClock, + params = LLMParams(temperature = temperature) + ) { + system(systemPrompt ?: "Test system message") + user(userPrompt ?: "Test user message") + assistant(assistantPrompt ?: "Test assistant response") }, model = model ?: OpenAIModels.Chat.GPT4o, - maxAgentIterations = 10 + maxAgentIterations = 10, ) return AIAgent( diff --git a/agents/agents-features/agents-features-event-handler/src/jvmTest/kotlin/ai/koog/agents/features/eventHandler/feature/NodeLLMRequestStreamingAndSendResultsTest.kt b/agents/agents-features/agents-features-event-handler/src/jvmTest/kotlin/ai/koog/agents/features/eventHandler/feature/NodeLLMRequestStreamingAndSendResultsTest.kt index c4bc665d76..793888ae18 100644 --- a/agents/agents-features/agents-features-event-handler/src/jvmTest/kotlin/ai/koog/agents/features/eventHandler/feature/NodeLLMRequestStreamingAndSendResultsTest.kt +++ b/agents/agents-features/agents-features-event-handler/src/jvmTest/kotlin/ai/koog/agents/features/eventHandler/feature/NodeLLMRequestStreamingAndSendResultsTest.kt @@ -89,7 +89,7 @@ class NodeLLMRequestStreamingAndSendResultsTest { // Verify streaming events were captured val streamingEvents = eventsCollector.collectedEvents.filter { - it.contains("OnBeforeStream") || it.contains("OnStreamFrame") || it.contains("OnAfterStream") + it.contains("OnLLMStreamingStarting") || it.contains("OnLLMStreamingFrameReceived") || it.contains("OnLLMStreamingCompleted") } assertTrue(streamingEvents.isNotEmpty(), "Should have captured streaming events") } @@ -132,7 +132,7 @@ class NodeLLMRequestStreamingAndSendResultsTest { // Verify streaming events occurred val streamingEvents = 
eventsCollector.collectedEvents.filter { - it.contains("OnBeforeStream") || it.contains("OnAfterStream") + it.contains("OnLLMStreamingStarting") || it.contains("OnLLMStreamingCompleted") } assertTrue(streamingEvents.isNotEmpty(), "Should have streaming events") } diff --git a/agents/agents-features/agents-features-event-handler/src/jvmTest/kotlin/ai/koog/agents/features/eventHandler/feature/StreamingEventHandlerTest.kt b/agents/agents-features/agents-features-event-handler/src/jvmTest/kotlin/ai/koog/agents/features/eventHandler/feature/StreamingEventHandlerTest.kt index cdd73939f4..d9ecb58157 100644 --- a/agents/agents-features/agents-features-event-handler/src/jvmTest/kotlin/ai/koog/agents/features/eventHandler/feature/StreamingEventHandlerTest.kt +++ b/agents/agents-features/agents-features-event-handler/src/jvmTest/kotlin/ai/koog/agents/features/eventHandler/feature/StreamingEventHandlerTest.kt @@ -15,7 +15,7 @@ import kotlin.test.assertTrue /** * Tests for streaming event handlers. - * These tests verify that the streaming handlers (onBeforeStream, onStreamFrame, onAfterStream) + * These tests verify that the streaming handlers (onLLMStreamingStarting, onLLMStreamingFrameReceived, onLLMStreamingCompleted) * are properly invoked during LLM streaming operations. 
*/ class StreamingEventHandlerTest { @@ -36,13 +36,13 @@ class StreamingEventHandlerTest { assertEventsCollected(eventsCollector) // Verify streaming events are captured when using nodeLLMRequestsStreaming - val beforeStreamEvents = eventsCollector.collectedEvents.filter { it.contains("OnBeforeStream") } - val streamFrameEvents = eventsCollector.collectedEvents.filter { it.contains("OnStreamFrame") } - val afterStreamEvents = eventsCollector.collectedEvents.filter { it.contains("OnAfterStream") } + val beforeStreamEvents = eventsCollector.collectedEvents.filter { it.contains("OnLLMStreamingStarting") } + val streamFrameEvents = eventsCollector.collectedEvents.filter { it.contains("OnLLMStreamingFrameReceived") } + val afterStreamEvents = eventsCollector.collectedEvents.filter { it.contains("OnLLMStreamingCompleted") } - assertTrue(beforeStreamEvents.isNotEmpty(), "Should have OnBeforeStream events") - assertTrue(streamFrameEvents.isNotEmpty(), "Should have OnStreamFrame events") - assertTrue(afterStreamEvents.isNotEmpty(), "Should have OnAfterStream events") + assertTrue(beforeStreamEvents.isNotEmpty(), "Should have OnLLMStreamingStarting events") + assertTrue(streamFrameEvents.isNotEmpty(), "Should have OnLLMStreamingFrameReceived events") + assertTrue(afterStreamEvents.isNotEmpty(), "Should have OnLLMStreamingCompleted events") // Verify the stream frame contains the expected response val frameWithContent = streamFrameEvents.firstOrNull { it.contains(assistantResponse) } @@ -63,7 +63,7 @@ class StreamingEventHandlerTest { // Verify the overall event collection is working assertEventsCollected(eventsCollector) // Verify that streaming events were captured - val streamingEventTypes = listOf("OnBeforeStream", "OnStreamFrame", "OnAfterStream") + val streamingEventTypes = listOf("OnLLMStreamingStarting", "OnLLMStreamingFrameReceived", "OnLLMStreamingCompleted") assertTrue( actual = eventsCollector.collectedEvents.any { streamingEventTypes.any(it::contains) }, message 
= "Should have captured at least one streaming event (${streamingEventTypes.joinToString()})" diff --git a/agents/agents-features/agents-features-event-handler/src/jvmTest/kotlin/ai/koog/agents/features/eventHandler/feature/TestEventsCollector.kt b/agents/agents-features/agents-features-event-handler/src/jvmTest/kotlin/ai/koog/agents/features/eventHandler/feature/TestEventsCollector.kt index cd08a485a4..fe3fdfa09e 100644 --- a/agents/agents-features/agents-features-event-handler/src/jvmTest/kotlin/ai/koog/agents/features/eventHandler/feature/TestEventsCollector.kt +++ b/agents/agents-features/agents-features-event-handler/src/jvmTest/kotlin/ai/koog/agents/features/eventHandler/feature/TestEventsCollector.kt @@ -20,59 +20,61 @@ class TestEventsCollector { onAgentStarting { eventContext -> runId = eventContext.runId _collectedEvents.add( - "OnBeforeAgentStarted (agent id: ${eventContext.agent.id}, run id: ${eventContext.runId})" + "OnAgentStarting (agent id: ${eventContext.agent.id}, run id: ${eventContext.runId})" ) } onAgentCompleted { eventContext -> _collectedEvents.add( - "OnAgentFinished (agent id: ${eventContext.agentId}, run id: ${eventContext.runId}, result: ${eventContext.result})" + "OnAgentCompleted (agent id: ${eventContext.agentId}, run id: ${eventContext.runId}, result: ${eventContext.result})" ) } onAgentExecutionFailed { eventContext -> _collectedEvents.add( - "OnAgentRunError (agent id: ${eventContext.agentId}, run id: ${eventContext.runId}, error: ${eventContext.throwable.message})" + "OnAgentExecutionFailed (agent id: ${eventContext.agentId}, run id: ${eventContext.runId}, error: ${eventContext.throwable.message})" ) } onAgentClosing { eventContext -> - _collectedEvents.add("OnAgentBeforeClose (agent id: ${eventContext.agentId})") + _collectedEvents.add( + "OnAgentClosing (agent id: ${eventContext.agentId})" + ) } onStrategyStarting { eventContext -> _collectedEvents.add( - "OnStrategyStarted (run id: ${eventContext.runId}, strategy: 
${eventContext.strategy.name})" + "OnStrategyStarting (run id: ${eventContext.runId}, strategy: ${eventContext.strategy.name})" ) } onStrategyCompleted { eventContext -> _collectedEvents.add( - "OnStrategyFinished (run id: ${eventContext.runId}, strategy: ${eventContext.strategy.name}, result: ${eventContext.result})" + "OnStrategyCompleted (run id: ${eventContext.runId}, strategy: ${eventContext.strategy.name}, result: ${eventContext.result})" ) } onNodeExecutionStarting { eventContext -> _collectedEvents.add( - "OnBeforeNode (run id: ${eventContext.context.runId}, node: ${eventContext.node.name}, input: ${eventContext.input})" + "OnNodeExecutionStarting (run id: ${eventContext.context.runId}, node: ${eventContext.node.name}, input: ${eventContext.input})" ) } onNodeExecutionCompleted { eventContext -> _collectedEvents.add( - "OnAfterNode (run id: ${eventContext.context.runId}, node: ${eventContext.node.name}, input: ${eventContext.input}, output: ${eventContext.output})" + "OnNodeExecutionCompleted (run id: ${eventContext.context.runId}, node: ${eventContext.node.name}, input: ${eventContext.input}, output: ${eventContext.output})" ) } onNodeExecutionFailed { eventContext -> _collectedEvents.add( - "OnNodeExecutionError (run id: ${eventContext.context.runId}, node: ${eventContext.node.name}, error: ${eventContext.throwable.message})" + "OnNodeExecutionFailed (run id: ${eventContext.context.runId}, node: ${eventContext.node.name}, error: ${eventContext.throwable.message})" ) } onLLMCallStarting { eventContext -> _collectedEvents.add( - "OnBeforeLLMCall (run id: ${eventContext.runId}, prompt: ${eventContext.prompt.traceString}, tools: [${ + "OnLLMCallStarting (run id: ${eventContext.runId}, prompt: ${eventContext.prompt.traceString}, tools: [${ eventContext.tools.joinToString { it.name } @@ -82,7 +84,7 @@ class TestEventsCollector { onLLMCallCompleted { eventContext -> _collectedEvents.add( - "OnAfterLLMCall (run id: ${eventContext.runId}, prompt: 
${eventContext.prompt.traceString}, model: ${eventContext.model.eventString}, tools: [${ + "OnLLMCallCompleted (run id: ${eventContext.runId}, prompt: ${eventContext.prompt.traceString}, model: ${eventContext.model.eventString}, tools: [${ eventContext.tools.joinToString { it.name } @@ -90,33 +92,33 @@ class TestEventsCollector { ) } - onToolExecutionStarting { eventContext -> + onToolCallStarting { eventContext -> _collectedEvents.add( - "OnToolCall (run id: ${eventContext.runId}, tool: ${eventContext.tool.name}, args: ${eventContext.toolArgs})" + "OnToolCallStarting (run id: ${eventContext.runId}, tool: ${eventContext.tool.name}, args: ${eventContext.toolArgs})" ) } onToolValidationFailed { eventContext -> _collectedEvents.add( - "OnToolValidationError (run id: ${eventContext.runId}, tool: ${eventContext.tool.name}, args: ${eventContext.toolArgs}, value: ${eventContext.error})" + "OnToolValidationFailed (run id: ${eventContext.runId}, tool: ${eventContext.tool.name}, args: ${eventContext.toolArgs}, value: ${eventContext.error})" ) } - onToolExecutionFailed { eventContext -> + onToolCallFailed { eventContext -> _collectedEvents.add( - "OnToolCallFailure (run id: ${eventContext.runId}, tool: ${eventContext.tool.name}, args: ${eventContext.toolArgs}, throwable: ${eventContext.throwable.message})" + "OnToolCallFailed (run id: ${eventContext.runId}, tool: ${eventContext.tool.name}, args: ${eventContext.toolArgs}, throwable: ${eventContext.throwable.message})" ) } - onToolExecutionCompleted { eventContext -> + onToolCallCompleted { eventContext -> _collectedEvents.add( - "OnToolCallResult (run id: ${eventContext.runId}, tool: ${eventContext.tool.name}, args: ${eventContext.toolArgs}, result: ${eventContext.result})" + "OnToolCallCompleted (run id: ${eventContext.runId}, tool: ${eventContext.tool.name}, args: ${eventContext.toolArgs}, result: ${eventContext.result})" ) } onLLMStreamingStarting { eventContext -> _collectedEvents.add( - "OnBeforeStream (run id: 
${eventContext.runId}, prompt: ${eventContext.prompt.traceString}, model: ${eventContext.model.eventString}, tools: [${ + "OnLLMStreamingStarting (run id: ${eventContext.runId}, prompt: ${eventContext.prompt.traceString}, model: ${eventContext.model.eventString}, tools: [${ eventContext.tools.joinToString { it.name } }])" ) @@ -124,19 +126,19 @@ class TestEventsCollector { onLLMStreamingFrameReceived { eventContext -> _collectedEvents.add( - "OnStreamFrame (run id: ${eventContext.runId}, frame: ${eventContext.streamFrame})" + "OnLLMStreamingFrameReceived (run id: ${eventContext.runId}, frame: ${eventContext.streamFrame})" ) } onLLMStreamingFailed { eventContext -> _collectedEvents.add( - "OnStreamError (run id: ${eventContext.runId}, error: ${eventContext.error.message})" + "OnLLMStreamingFailed (run id: ${eventContext.runId}, error: ${eventContext.error.message})" ) } onLLMStreamingCompleted { eventContext -> _collectedEvents.add( - "OnAfterStream (run id: ${eventContext.runId}, prompt: ${eventContext.prompt.traceString}, model: ${eventContext.model.eventString}, tools: [${ + "OnLLMStreamingCompleted (run id: ${eventContext.runId}, prompt: ${eventContext.prompt.traceString}, model: ${eventContext.model.eventString}, tools: [${ eventContext.tools.joinToString { it.name } }])" ) diff --git a/agents/agents-features/agents-features-opentelemetry/src/jvmMain/kotlin/ai/koog/agents/features/opentelemetry/feature/OpenTelemetry.kt b/agents/agents-features/agents-features-opentelemetry/src/jvmMain/kotlin/ai/koog/agents/features/opentelemetry/feature/OpenTelemetry.kt index 4e64330e41..bf63d08a1a 100644 --- a/agents/agents-features/agents-features-opentelemetry/src/jvmMain/kotlin/ai/koog/agents/features/opentelemetry/feature/OpenTelemetry.kt +++ b/agents/agents-features/agents-features-opentelemetry/src/jvmMain/kotlin/ai/koog/agents/features/opentelemetry/feature/OpenTelemetry.kt @@ -388,7 +388,7 @@ public class OpenTelemetry { //region Tool Call - 
pipeline.interceptToolExecutionStarting(interceptContext) { eventContext -> + pipeline.interceptToolCallStarting(interceptContext) { eventContext -> logger.debug { "Execute OpenTelemetry tool call handler" } // Get current NodeExecuteSpan @@ -416,7 +416,7 @@ public class OpenTelemetry { spanProcessor.startSpan(executeToolSpan) } - pipeline.interceptToolExecutionCompleted(interceptContext) { eventContext -> + pipeline.interceptToolCallCompleted(interceptContext) { eventContext -> logger.debug { "Execute OpenTelemetry tool result handler" } // Get current ExecuteToolSpan @@ -445,7 +445,7 @@ public class OpenTelemetry { spanProcessor.endSpan(span = executeToolSpan) } - pipeline.interceptToolExecutionFailed(interceptContext) { eventContext -> + pipeline.interceptToolCallFailed(interceptContext) { eventContext -> logger.debug { "Execute OpenTelemetry tool call failure handler" } // Get current ExecuteToolSpan diff --git a/agents/agents-features/agents-features-snapshot/Module.md b/agents/agents-features/agents-features-snapshot/Module.md index 296feabfe2..e8ed8d8637 100644 --- a/agents/agents-features/agents-features-snapshot/Module.md +++ b/agents/agents-features/agents-features-snapshot/Module.md @@ -27,18 +27,18 @@ dependencies { } ``` -Then, install the Persistency feature when creating your agent: +Then, install the Persistence feature when creating your agent: ```kotlin val agent = AIAgent( // other configuration parameters ) { - install(Persistency) { + install(Persistence) { // Configure the storage provider - storage = InMemoryPersistencyStorageProvider("agent-persistence-id") + storage = InMemoryPersistenceStorageProvider() // Optional: enable automatic checkpoint creation after each node - enableAutomaticPersistency = true + enableAutomaticPersistence = true // Use `RollbackStrategy.Default` if you want to checkpoint the whole state machine and continue from the same node in your strategy graph, // or `RollbackStrategy.MessageHistoryOnly` if you only want to 
checkpoint messages @@ -64,9 +64,9 @@ val agent = AIAgent( llmModel = OllamaModels.Meta.LLAMA_3_2, strategy = singleRunStrategy(ToolCalls.SEQUENTIAL), ) { - install(Persistency) { + install(Persistence) { storage = snapshotProvider - enableAutomaticPersistency = true + enableAutomaticPersistence = true } } diff --git a/agents/agents-features/agents-features-snapshot/README.md b/agents/agents-features/agents-features-snapshot/README.md index f06662292c..5a3db5d998 100644 --- a/agents/agents-features/agents-features-snapshot/README.md +++ b/agents/agents-features/agents-features-snapshot/README.md @@ -146,7 +146,7 @@ class MyCustomStorageProvider : AgentCheckpointStorageProvider { // Implementation } - override suspend fun saveCheckpoint(agentCheckpointData: AgentCheckpointData) { + override suspend fun saveCheckpoint(agentId: String, agentCheckpointData: AgentCheckpointData) { // Implementation } diff --git a/agents/agents-features/agents-features-snapshot/src/commonMain/kotlin/ai/koog/agents/snapshot/feature/AgentCheckpointData.kt b/agents/agents-features/agents-features-snapshot/src/commonMain/kotlin/ai/koog/agents/snapshot/feature/AgentCheckpointData.kt index cadfb1ee78..95785109ab 100644 --- a/agents/agents-features/agents-features-snapshot/src/commonMain/kotlin/ai/koog/agents/snapshot/feature/AgentCheckpointData.kt +++ b/agents/agents-features/agents-features-snapshot/src/commonMain/kotlin/ai/koog/agents/snapshot/feature/AgentCheckpointData.kt @@ -6,7 +6,7 @@ import ai.koog.agents.core.agent.context.AIAgentContext import ai.koog.agents.core.agent.context.AgentContextData import ai.koog.agents.core.agent.context.RollbackStrategy import ai.koog.agents.core.annotation.InternalAgentsApi -import ai.koog.agents.snapshot.providers.PersistencyUtils +import ai.koog.agents.snapshot.providers.PersistenceUtils import ai.koog.prompt.message.Message import kotlinx.datetime.Instant import kotlinx.serialization.Serializable @@ -47,10 +47,10 @@ public fun 
tombstoneCheckpoint(time: Instant): AgentCheckpointData { return AgentCheckpointData( checkpointId = Uuid.random().toString(), createdAt = time, - nodeId = PersistencyUtils.TOMBSTONE_CHECKPOINT_NAME, + nodeId = PersistenceUtils.TOMBSTONE_CHECKPOINT_NAME, lastInput = JsonNull, messageHistory = emptyList(), - properties = mapOf(PersistencyUtils.TOMBSTONE_CHECKPOINT_NAME to JsonPrimitive(true)) + properties = mapOf(PersistenceUtils.TOMBSTONE_CHECKPOINT_NAME to JsonPrimitive(true)) ) } @@ -85,4 +85,4 @@ public fun AgentCheckpointData.toAgentContextData( * and the value is a JSON primitive set to `true`, otherwise `false`. */ public fun AgentCheckpointData.isTombstone(): Boolean = - properties?.get(PersistencyUtils.TOMBSTONE_CHECKPOINT_NAME) == JsonPrimitive(true) + properties?.get(PersistenceUtils.TOMBSTONE_CHECKPOINT_NAME) == JsonPrimitive(true) diff --git a/agents/agents-features/agents-features-snapshot/src/commonMain/kotlin/ai/koog/agents/snapshot/feature/Persistency.kt b/agents/agents-features/agents-features-snapshot/src/commonMain/kotlin/ai/koog/agents/snapshot/feature/Persistency.kt index 062f329ac5..73b3b237bf 100644 --- a/agents/agents-features/agents-features-snapshot/src/commonMain/kotlin/ai/koog/agents/snapshot/feature/Persistency.kt +++ b/agents/agents-features/agents-features-snapshot/src/commonMain/kotlin/ai/koog/agents/snapshot/feature/Persistency.kt @@ -16,7 +16,7 @@ import ai.koog.agents.core.feature.AIAgentGraphPipeline import ai.koog.agents.core.feature.InterceptContext import ai.koog.agents.core.tools.DirectToolCallsEnabler import ai.koog.agents.core.tools.annotations.InternalAgentToolsApi -import ai.koog.agents.snapshot.providers.PersistencyStorageProvider +import ai.koog.agents.snapshot.providers.PersistenceStorageProvider import ai.koog.prompt.message.Message import io.github.oshai.kotlinlogging.KotlinLogging import kotlinx.datetime.Clock @@ -30,6 +30,15 @@ import kotlin.time.ExperimentalTime import kotlin.uuid.ExperimentalUuidApi import 
kotlin.uuid.Uuid +@Deprecated( + "`Persistency` has been renamed to `Persistence`", + replaceWith = ReplaceWith( + expression = "Persistence", + "ai.koog.agents.snapshot.feature.Persistence" + ) +) +public typealias Persistency = Persistence + /** * A feature that provides checkpoint functionality for AI agents. * @@ -40,14 +49,14 @@ import kotlin.uuid.Uuid * - Persisting agent state across sessions * * The feature can be configured to automatically create checkpoints after each node execution - * using the [PersistencyFeatureConfig.enableAutomaticPersistency] option. + * using the [PersistenceFeatureConfig.enableAutomaticPersistence] option. * - * @property persistencyStorageProvider The provider responsible for storing and retrieving checkpoints + * @property persistenceStorageProvider The provider responsible for storing and retrieving checkpoints * @property currentNodeId The ID of the node currently being executed */ @OptIn(ExperimentalUuidApi::class, ExperimentalTime::class, InternalAgentsApi::class) -public class Persistency( - private val persistencyStorageProvider: PersistencyStorageProvider, +public class Persistence( + private val persistenceStorageProvider: PersistenceStorageProvider, internal val clock: Clock = Clock.System, ) { /** @@ -68,7 +77,7 @@ public class Persistency( * A registry for managing rollback tools within the persistence system. * * The `rollbackToolRegistry` plays a key role in supporting the rollback mechanism in the - * persistency operations, allowing seamless state restoration for tools **with side-effects** to specified or latest + * persistence operations, allowing seamless state restoration for tools **with side-effects** to specified or latest * checkpoints as needed. * */ @@ -91,7 +100,7 @@ public class Persistency( /** * Feature companion object that implements [AIAgentFeature] for the checkpoint functionality. 
*/ - public companion object Feature : AIAgentGraphFeature { + public companion object Feature : AIAgentGraphFeature { private val logger = KotlinLogging.logger { } private val json = Json { @@ -101,14 +110,14 @@ public class Persistency( /** * The storage key used to identify this feature in the agent's feature registry. */ - override val key: AIAgentStorageKey = AIAgentStorageKey("agents-features-snapshot") + override val key: AIAgentStorageKey = AIAgentStorageKey("agents-features-snapshot") /** * Creates the default configuration for this feature. * - * @return A new instance of [PersistencyFeatureConfig] with default settings + * @return A new instance of [PersistenceFeatureConfig] with default settings */ - override fun createInitialConfig(): PersistencyFeatureConfig = PersistencyFeatureConfig() + override fun createInitialConfig(): PersistenceFeatureConfig = PersistenceFeatureConfig() /** * Installs the checkpoint feature into the agent pipeline. @@ -122,10 +131,10 @@ public class Persistency( * @param pipeline The agent pipeline to install the feature into */ override fun install( - config: PersistencyFeatureConfig, + config: PersistenceFeatureConfig, pipeline: AIAgentGraphPipeline ) { - val featureImpl = Persistency(config.storage) + val featureImpl = Persistence(config.storage) featureImpl.rollbackStrategy = config.rollbackStrategy featureImpl.rollbackToolRegistry = config.rollbackToolRegistry val interceptContext = InterceptContext(this, featureImpl) @@ -155,7 +164,7 @@ public class Persistency( return@interceptNodeExecutionCompleted } - if (config.enableAutomaticPersistency) { + if (config.enableAutomaticPersistence) { createCheckpoint( agentContext = eventCtx.context, nodeId = eventCtx.node.id, @@ -170,8 +179,8 @@ public class Persistency( } pipeline.interceptStrategyCompleted(interceptContext) { ctx -> - if (config.enableAutomaticPersistency && config.rollbackStrategy == RollbackStrategy.Default) { - 
ctx.feature.createTombstoneCheckpoint(ctx.feature.clock.now()) + if (config.enableAutomaticPersistence && config.rollbackStrategy == RollbackStrategy.Default) { + ctx.feature.createTombstoneCheckpoint(ctx.agentId, ctx.feature.clock.now()) } } } @@ -219,7 +228,7 @@ public class Persistency( ) } - saveCheckpoint(checkpoint) + saveCheckpoint(agentContext.agentId, checkpoint) return checkpoint } @@ -233,9 +242,9 @@ public class Persistency( * @return The created tombstone checkpoint data. */ @InternalAgentsApi - public suspend fun createTombstoneCheckpoint(time: Instant): AgentCheckpointData { + public suspend fun createTombstoneCheckpoint(agentId: String, time: Instant): AgentCheckpointData { val checkpoint = tombstoneCheckpoint(time) - saveCheckpoint(checkpoint) + saveCheckpoint(agentId, checkpoint) return checkpoint } @@ -252,8 +261,8 @@ public class Persistency( * * @param checkpointData The checkpoint data to save */ - public suspend fun saveCheckpoint(checkpointData: AgentCheckpointData) { - persistencyStorageProvider.saveCheckpoint(checkpointData) + public suspend fun saveCheckpoint(agentId: String, checkpointData: AgentCheckpointData) { + persistenceStorageProvider.saveCheckpoint(agentId, checkpointData) } /** @@ -261,8 +270,8 @@ public class Persistency( * * @return The latest checkpoint data, or null if no checkpoint exists */ - public suspend fun getLatestCheckpoint(): AgentCheckpointData? = - persistencyStorageProvider.getLatestCheckpoint() + public suspend fun getLatestCheckpoint(agentId: String): AgentCheckpointData? = + persistenceStorageProvider.getLatestCheckpoint(agentId) /** * Retrieves a specific checkpoint by ID for the specified agent. @@ -270,8 +279,8 @@ public class Persistency( * @param checkpointId The ID of the checkpoint to retrieve * @return The checkpoint data with the specified ID, or null if not found */ - public suspend fun getCheckpointById(checkpointId: String): AgentCheckpointData? 
= - persistencyStorageProvider.getCheckpoints().firstOrNull { it.checkpointId == checkpointId } + public suspend fun getCheckpointById(agentId: String, checkpointId: String): AgentCheckpointData? = + persistenceStorageProvider.getCheckpoints(agentId).firstOrNull { it.checkpointId == checkpointId } /** * Sets the execution point of an agent to a specific state. @@ -311,7 +320,7 @@ public class Persistency( checkpointId: String, agentContext: AIAgentContext ): AgentCheckpointData? { - val checkpoint: AgentCheckpointData? = getCheckpointById(checkpointId) + val checkpoint: AgentCheckpointData? = getCheckpointById(agentContext.agentId, checkpointId) if (checkpoint != null) { agentContext.store( checkpoint.toAgentContextData(rollbackStrategy) { context -> @@ -368,7 +377,7 @@ public class Persistency( public suspend fun rollbackToLatestCheckpoint( agentContext: AIAgentContext ): AgentCheckpointData? { - val checkpoint: AgentCheckpointData? = getLatestCheckpoint() + val checkpoint: AgentCheckpointData? = getLatestCheckpoint(agentContext.agentId) if (checkpoint?.isTombstone() ?: true) { return null } @@ -381,50 +390,50 @@ public class Persistency( /** * Extension function to access the checkpoint feature from an agent context. * - * @return The [Persistency] feature instance for this agent + * @return The [Persistence] feature instance for this agent * @throws IllegalStateException if the checkpoint feature is not installed */ -public fun AIAgentContext.persistency(): Persistency = agent.persistency() +public fun AIAgentContext.persistence(): Persistence = agent.persistence() /** - * Retrieves the persistency feature for the AI agent. + * Retrieves the persistence feature for the AI agent. * - * @return The persistency feature associated with the AI agent. - * @throws IllegalStateException if the persistency feature is not available. + * @return The persistence feature associated with the AI agent. 
+ * @throws IllegalStateException if the persistence feature is not available. */ -public fun AIAgent<*, *>.persistency(): Persistency = featureOrThrow(Persistency.Feature) +public fun AIAgent<*, *>.persistence(): Persistence = featureOrThrow(Persistence.Feature) /** - * Executes the provided action within the context of the AI agent's persistency layer. + * Executes the provided action within the context of the AI agent's persistence layer. * - * This function enhances agents with persistent state management capabilities by leveraging the `Persistency` component - * within the current `AIAgentContext`. The supplied action is executed with the persistency layer, enabling operations + * This function enhances agents with persistent state management capabilities by leveraging the `Persistence` component + * within the current `AIAgentContext`. The supplied action is executed with the persistence layer, enabling operations * that require consistent and reliable state management across the lifecycle of the agent. * - * @param action A suspendable lambda function that receives the `Persistency` instance and the current `AIAgentContext` - * as its parameters. This allows custom logic that interacts with the persistency layer to be executed. + * @param action A suspendable lambda function that receives the `Persistence` instance and the current `AIAgentContext` + * as its parameters. This allows custom logic that interacts with the persistence layer to be executed. * @return A result of type [T] produced by the execution of the provided action. */ -public suspend fun AIAgentContext.withPersistency( - action: suspend Persistency.(AIAgentContext) -> T -): T = this.persistency().action(this) +public suspend fun AIAgentContext.withPersistence( + action: suspend Persistence.(AIAgentContext) -> T +): T = this.persistence().action(this) /** - * Executes the provided action within the context of the agent's persistency layer if the agent is in a running state. 
+ * Executes the provided action within the context of the agent's persistence layer if the agent is in a running state. * - * This function allows interaction with the persistency mechanism associated with the agent, ensuring that + * This function allows interaction with the persistence mechanism associated with the agent, ensuring that * the operation is carried out in the correct execution context. * - * @param action A suspending function defining operations to perform using the agent's persistency mechanism + * @param action A suspending function defining operations to perform using the agent's persistence mechanism * and the current agent context. * @return The result of the execution of the provided action. * @throws IllegalStateException If the agent is not in a running state when this function is called. */ @OptIn(InternalAgentsApi::class) -public suspend fun AIAgent<*, *>.withPersistency( - action: suspend Persistency.(AIAgentContext) -> T +public suspend fun AIAgent<*, *>.withPersistence( + action: suspend Persistence.(AIAgentContext) -> T ): T = when (val state = getState()) { - is Running<*> -> this.persistency().action(state.rootContext) + is Running<*> -> this.persistence().action(state.rootContext) else -> throw IllegalStateException("Agent is not running. 
Current agents's state: $state") } diff --git a/agents/agents-features/agents-features-snapshot/src/commonMain/kotlin/ai/koog/agents/snapshot/feature/PersistencyFeatureConfig.kt b/agents/agents-features/agents-features-snapshot/src/commonMain/kotlin/ai/koog/agents/snapshot/feature/PersistencyFeatureConfig.kt index cd0d5ec068..e2a8d4f81e 100644 --- a/agents/agents-features/agents-features-snapshot/src/commonMain/kotlin/ai/koog/agents/snapshot/feature/PersistencyFeatureConfig.kt +++ b/agents/agents-features/agents-features-snapshot/src/commonMain/kotlin/ai/koog/agents/snapshot/feature/PersistencyFeatureConfig.kt @@ -3,35 +3,44 @@ package ai.koog.agents.snapshot.feature import ai.koog.agents.core.agent.context.RollbackStrategy import ai.koog.agents.core.feature.config.FeatureConfig import ai.koog.agents.snapshot.feature.RollbackToolRegistry -import ai.koog.agents.snapshot.providers.NoPersistencyStorageProvider -import ai.koog.agents.snapshot.providers.PersistencyStorageProvider +import ai.koog.agents.snapshot.providers.NoPersistenceStorageProvider +import ai.koog.agents.snapshot.providers.PersistenceStorageProvider + +@Deprecated( + "`PersistencyFeatureConfig` has been renamed to `PersistenceFeatureConfig`", + replaceWith = ReplaceWith( + expression = "PersistenceFeatureConfig", + "ai.koog.agents.snapshot.feature.PersistenceFeatureConfig" + ) +) +public typealias PersistencyFeatureConfig = PersistenceFeatureConfig /** * Configuration class for the Snapshot feature. */ -public class PersistencyFeatureConfig : FeatureConfig() { +public class PersistenceFeatureConfig : FeatureConfig() { /** * Defines the storage mechanism for persisting snapshots in the feature. - * This property accepts implementations of [PersistencyStorageProvider], + * This property accepts implementations of [PersistenceStorageProvider], * which manage how snapshots are stored and retrieved. 
* - * By default, the storage is set to [NoPersistencyStorageProvider], a no-op + * By default, the storage is set to [NoPersistenceStorageProvider], a no-op * implementation that does not persist any data. To enable actual state - * persistence, assign a custom implementation of [PersistencyStorageProvider] + * persistence, assign a custom implementation of [PersistenceStorageProvider] * to this property. */ - public var storage: PersistencyStorageProvider = NoPersistencyStorageProvider() + public var storage: PersistenceStorageProvider = NoPersistenceStorageProvider() /** * Controls whether the feature's state should be automatically persisted. * When enabled, changes to the checkpoint are saved after each node execution through the assigned - * [PersistencyStorageProvider], ensuring the state can be restored later. + * [PersistenceStorageProvider], ensuring the state can be restored later. * - * Set this property to `true` to turn on automatic state persistency, + * Set this property to `true` to turn on automatic state persistence, * or `false` to disable it. */ - public var enableAutomaticPersistency: Boolean = false + public var enableAutomaticPersistence: Boolean = false /** * Determines the strategy to be used for rolling back the agent's state to a previously saved checkpoint. @@ -46,7 +55,7 @@ public class PersistencyFeatureConfig : FeatureConfig() { /** * Registry for rollback tools used when rolling back to checkpoints. - * Configure it during Persistency installation. Do not mutate later in withPersistency. + * Configure it during Persistence installation. Do not mutate later in withPersistence. 
*/ public var rollbackToolRegistry: RollbackToolRegistry = RollbackToolRegistry.EMPTY } diff --git a/agents/agents-features/agents-features-snapshot/src/commonMain/kotlin/ai/koog/agents/snapshot/providers/InMemoryPersistencyStorageProvider.kt b/agents/agents-features/agents-features-snapshot/src/commonMain/kotlin/ai/koog/agents/snapshot/providers/InMemoryPersistencyStorageProvider.kt index 10dbd0b468..40095ad466 100644 --- a/agents/agents-features/agents-features-snapshot/src/commonMain/kotlin/ai/koog/agents/snapshot/providers/InMemoryPersistencyStorageProvider.kt +++ b/agents/agents-features/agents-features-snapshot/src/commonMain/kotlin/ai/koog/agents/snapshot/providers/InMemoryPersistencyStorageProvider.kt @@ -5,29 +5,28 @@ import kotlinx.coroutines.sync.Mutex import kotlinx.coroutines.sync.withLock /** - * In-memory implementation of [PersistencyStorageProvider]. + * In-memory implementation of [PersistenceStorageProvider]. * This provider stores snapshots in a mutable map. */ -public class InMemoryPersistencyStorageProvider(private val persistenceId: String) : PersistencyStorageProvider { +public class InMemoryPersistenceStorageProvider() : PersistenceStorageProvider { private val mutex = Mutex() private val snapshotMap = mutableMapOf>() - override suspend fun getCheckpoints(): List { + override suspend fun getCheckpoints(agentId: String): List { mutex.withLock { - return snapshotMap[persistenceId] ?: emptyList() + return snapshotMap[agentId] ?: emptyList() } } - override suspend fun saveCheckpoint(agentCheckpointData: AgentCheckpointData) { + override suspend fun saveCheckpoint(agentId: String, agentCheckpointData: AgentCheckpointData) { mutex.withLock { - val agentId = persistenceId snapshotMap[agentId] = (snapshotMap[agentId] ?: emptyList()) + agentCheckpointData } } - override suspend fun getLatestCheckpoint(): AgentCheckpointData? { + override suspend fun getLatestCheckpoint(agentId: String): AgentCheckpointData? 
{ mutex.withLock { - return snapshotMap[persistenceId]?.maxBy { it.createdAt } + return snapshotMap[agentId]?.maxBy { it.createdAt } } } } diff --git a/agents/agents-features/agents-features-snapshot/src/commonMain/kotlin/ai/koog/agents/snapshot/providers/NoPersistencyStorageProvider.kt b/agents/agents-features/agents-features-snapshot/src/commonMain/kotlin/ai/koog/agents/snapshot/providers/NoPersistencyStorageProvider.kt index 8187c302ac..898f350564 100644 --- a/agents/agents-features/agents-features-snapshot/src/commonMain/kotlin/ai/koog/agents/snapshot/providers/NoPersistencyStorageProvider.kt +++ b/agents/agents-features/agents-features-snapshot/src/commonMain/kotlin/ai/koog/agents/snapshot/providers/NoPersistencyStorageProvider.kt @@ -4,22 +4,23 @@ import ai.koog.agents.snapshot.feature.AgentCheckpointData import io.github.oshai.kotlinlogging.KotlinLogging /** - * No-op implementation of [PersistencyStorageProvider]. + * No-op implementation of [PersistenceStorageProvider]. */ -public class NoPersistencyStorageProvider : PersistencyStorageProvider { +public class NoPersistenceStorageProvider : PersistenceStorageProvider { private val logger = KotlinLogging.logger { } - override suspend fun getCheckpoints(): List { + override suspend fun getCheckpoints(agentId: String): List { return emptyList() } override suspend fun saveCheckpoint( + agentId: String, agentCheckpointData: AgentCheckpointData ) { logger.info { "Snapshot feature is not enabled in the agent. Snapshot will not be saved: $agentCheckpointData" } } - override suspend fun getLatestCheckpoint(): AgentCheckpointData? { + override suspend fun getLatestCheckpoint(agentId: String): AgentCheckpointData? 
{ return null } } diff --git a/agents/agents-features/agents-features-snapshot/src/commonMain/kotlin/ai/koog/agents/snapshot/providers/PersistencyUtils.kt b/agents/agents-features/agents-features-snapshot/src/commonMain/kotlin/ai/koog/agents/snapshot/providers/PersistenceUtils.kt similarity index 77% rename from agents/agents-features/agents-features-snapshot/src/commonMain/kotlin/ai/koog/agents/snapshot/providers/PersistencyUtils.kt rename to agents/agents-features/agents-features-snapshot/src/commonMain/kotlin/ai/koog/agents/snapshot/providers/PersistenceUtils.kt index ecb6c60bdc..087b4c1642 100644 --- a/agents/agents-features/agents-features-snapshot/src/commonMain/kotlin/ai/koog/agents/snapshot/providers/PersistencyUtils.kt +++ b/agents/agents-features/agents-features-snapshot/src/commonMain/kotlin/ai/koog/agents/snapshot/providers/PersistenceUtils.kt @@ -2,10 +2,19 @@ package ai.koog.agents.snapshot.providers import kotlinx.serialization.json.Json +@Deprecated( + "`PersistencyUtils` has been renamed to `PersistenceUtils`", + replaceWith = ReplaceWith( + expression = "PersistenceUtils", + "ai.koog.agents.snapshot.providers.PersistenceUtils" + ) +) +public typealias PersistencyUtils = PersistenceUtils + /** - * Utility object containing configurations and utilities for handling persistency-related operations. + * Utility object containing configurations and utilities for handling persistence-related operations. */ -public object PersistencyUtils { +public object PersistenceUtils { /** * A preconfigured JSON instance for handling serialization and deserialization of checkpoint data. 
* diff --git a/agents/agents-features/agents-features-snapshot/src/commonMain/kotlin/ai/koog/agents/snapshot/providers/PersistencyStorageProvider.kt b/agents/agents-features/agents-features-snapshot/src/commonMain/kotlin/ai/koog/agents/snapshot/providers/PersistencyStorageProvider.kt index 3b0c873eee..64d8815c46 100644 --- a/agents/agents-features/agents-features-snapshot/src/commonMain/kotlin/ai/koog/agents/snapshot/providers/PersistencyStorageProvider.kt +++ b/agents/agents-features/agents-features-snapshot/src/commonMain/kotlin/ai/koog/agents/snapshot/providers/PersistencyStorageProvider.kt @@ -4,8 +4,17 @@ package ai.koog.agents.snapshot.providers import ai.koog.agents.snapshot.feature.AgentCheckpointData -public interface PersistencyStorageProvider { - public suspend fun getCheckpoints(): List - public suspend fun saveCheckpoint(agentCheckpointData: AgentCheckpointData) - public suspend fun getLatestCheckpoint(): AgentCheckpointData? +@Deprecated( + "`PersistencyStorageProvider` has been renamed to `PersistenceStorageProvider`", + replaceWith = ReplaceWith( + expression = "PersistenceStorageProvider", + "ai.koog.agents.snapshot.feature.PersistenceStorageProvider" + ) +) +public typealias PersistencyStorageProvider = PersistenceStorageProvider + +public interface PersistenceStorageProvider { + public suspend fun getCheckpoints(agentId: String): List + public suspend fun saveCheckpoint(agentId: String, agentCheckpointData: AgentCheckpointData) + public suspend fun getLatestCheckpoint(agentId: String): AgentCheckpointData? 
} diff --git a/agents/agents-features/agents-features-snapshot/src/commonMain/kotlin/ai/koog/agents/snapshot/providers/file/FilePersistencyStorageProvider.kt b/agents/agents-features/agents-features-snapshot/src/commonMain/kotlin/ai/koog/agents/snapshot/providers/file/FilePersistencyStorageProvider.kt index acdc2114d7..a8c0d099ea 100644 --- a/agents/agents-features/agents-features-snapshot/src/commonMain/kotlin/ai/koog/agents/snapshot/providers/file/FilePersistencyStorageProvider.kt +++ b/agents/agents-features/agents-features-snapshot/src/commonMain/kotlin/ai/koog/agents/snapshot/providers/file/FilePersistencyStorageProvider.kt @@ -1,16 +1,25 @@ package ai.koog.agents.snapshot.providers.file import ai.koog.agents.snapshot.feature.AgentCheckpointData -import ai.koog.agents.snapshot.providers.PersistencyStorageProvider -import ai.koog.agents.snapshot.providers.PersistencyUtils +import ai.koog.agents.snapshot.providers.PersistenceStorageProvider +import ai.koog.agents.snapshot.providers.PersistenceUtils import ai.koog.rag.base.files.FileSystemProvider import ai.koog.rag.base.files.createDirectory import ai.koog.rag.base.files.readText import ai.koog.rag.base.files.writeText import kotlinx.serialization.json.Json +@Deprecated( + "`FilePersistencyStorageProvider` has been renamed to `FilePersistenceStorageProvider`", + replaceWith = ReplaceWith( + expression = "FilePersistenceStorageProvider", + "ai.koog.agents.snapshot.providers.file.FilePersistenceStorageProvider" + ) +) +public typealias FilePersistencyStorageProvider = FilePersistenceStorageProvider + /** - * A file-based implementation of [PersistencyStorageProvider] that stores agent checkpoints in a file system. + * A file-based implementation of [PersistenceStorageProvider] that stores agent checkpoints in a file system. * * This implementation organizes checkpoints by agent ID and uses JSON serialization for storing and retrieving * checkpoint data. 
It relies on [FileSystemProvider.ReadWrite] for file system operations. @@ -19,12 +28,11 @@ import kotlinx.serialization.json.Json * @param fs A file system provider enabling read and write operations for file storage. * @param root Root file path where the checkpoint storage will organize data. */ -public open class FilePersistencyStorageProvider( - private val persistenceId: String, +public open class FilePersistenceStorageProvider( private val fs: FileSystemProvider.ReadWrite, private val root: Path, - private val json: Json = PersistencyUtils.defaultCheckpointJson -) : PersistencyStorageProvider { + private val json: Json = PersistenceUtils.defaultCheckpointJson +) : PersistenceStorageProvider { /** * Directory where agent checkpoints are stored @@ -40,9 +48,9 @@ public open class FilePersistencyStorageProvider( /** * Directory for a specific agent's checkpoints */ - private suspend fun agentCheckpointsDir(): Path { + private suspend fun agentCheckpointsDir(agentId: String): Path { val checkpointsDir = checkpointsDir() - val agentDir = fs.joinPath(checkpointsDir, persistenceId) + val agentDir = fs.joinPath(checkpointsDir, agentId) if (!fs.exists(agentDir)) { fs.createDirectory(agentDir) } @@ -52,13 +60,13 @@ public open class FilePersistencyStorageProvider( /** * Get the path to a specific checkpoint file */ - private suspend fun checkpointPath(checkpointId: String): Path { - val agentDir = agentCheckpointsDir() + private suspend fun checkpointPath(agentId: String, checkpointId: String): Path { + val agentDir = agentCheckpointsDir(agentId) return fs.joinPath(agentDir, checkpointId) } - override suspend fun getCheckpoints(): List { - val agentDir = agentCheckpointsDir() + override suspend fun getCheckpoints(agentId: String): List { + val agentDir = agentCheckpointsDir(agentId) if (!fs.exists(agentDir)) { return emptyList() @@ -74,14 +82,14 @@ public open class FilePersistencyStorageProvider( } } - override suspend fun saveCheckpoint(agentCheckpointData: 
AgentCheckpointData) { - val checkpointPath = checkpointPath(agentCheckpointData.checkpointId) + override suspend fun saveCheckpoint(agentId: String, agentCheckpointData: AgentCheckpointData) { + val checkpointPath = checkpointPath(agentId, agentCheckpointData.checkpointId) val serialized = json.encodeToString(AgentCheckpointData.serializer(), agentCheckpointData) fs.writeText(checkpointPath, serialized) } - override suspend fun getLatestCheckpoint(): AgentCheckpointData? { - return getCheckpoints() + override suspend fun getLatestCheckpoint(agentId: String): AgentCheckpointData? { + return getCheckpoints(agentId) .maxByOrNull { it.createdAt } } } diff --git a/agents/agents-features/agents-features-snapshot/src/jvmMain/kotlin/ai/koog/agents/snapshot/providers/file/JVMFilePersistencyStorageProvider.kt b/agents/agents-features/agents-features-snapshot/src/jvmMain/kotlin/ai/koog/agents/snapshot/providers/file/JVMFilePersistencyStorageProvider.kt index 2f8e96d365..bd00b4815c 100644 --- a/agents/agents-features/agents-features-snapshot/src/jvmMain/kotlin/ai/koog/agents/snapshot/providers/file/JVMFilePersistencyStorageProvider.kt +++ b/agents/agents-features/agents-features-snapshot/src/jvmMain/kotlin/ai/koog/agents/snapshot/providers/file/JVMFilePersistencyStorageProvider.kt @@ -1,12 +1,21 @@ package ai.koog.agents.snapshot.providers.file -import ai.koog.agents.snapshot.providers.PersistencyUtils +import ai.koog.agents.snapshot.providers.PersistenceUtils import ai.koog.rag.base.files.JVMFileSystemProvider import kotlinx.serialization.json.Json import java.nio.file.Path +@Deprecated( + "`JVMFilePersistencyStorageProvider` has been renamed to `JVMFilePersistenceStorageProvider`", + replaceWith = ReplaceWith( + expression = "JVMFilePersistenceStorageProvider", + "ai.koog.agents.snapshot.providers.file.JVMFilePersistenceStorageProvider" + ) +) +public typealias JVMFilePersistencyStorageProvider = JVMFilePersistenceStorageProvider + /** - * A JVM-specific implementation of 
[FilePersistencyStorageProvider] for managing agent checkpoints + * A JVM-specific implementation of [FilePersistenceStorageProvider] for managing agent checkpoints * in a file system. * * This class utilizes JVM's [Path] for file system operations and [JVMFileSystemProvider.ReadWrite] @@ -16,16 +25,14 @@ import java.nio.file.Path * Use this class to persistently store and retrieve agent checkpoints to and from a file-based system * in JVM environments. * - * @constructor Initializes the [JVMFilePersistencyStorageProvider] with a specified root directory [root]. + * @constructor Initializes the [JVMFilePersistenceStorageProvider] with a specified root directory [root]. * @param root The root directory where all agent checkpoints will be stored. */ -public class JVMFilePersistencyStorageProvider( +public class JVMFilePersistenceStorageProvider( root: Path, - persistenceId: String, - json: Json = PersistencyUtils.defaultCheckpointJson -) : FilePersistencyStorageProvider( + json: Json = PersistenceUtils.defaultCheckpointJson +) : FilePersistenceStorageProvider( fs = JVMFileSystemProvider.ReadWrite, root = root, - persistenceId = persistenceId, json = json ) diff --git a/agents/agents-features/agents-features-snapshot/src/jvmTest/kotlin/CheckpointSerializationTest.kt b/agents/agents-features/agents-features-snapshot/src/jvmTest/kotlin/CheckpointSerializationTest.kt index dc1f478335..76c035fab9 100644 --- a/agents/agents-features/agents-features-snapshot/src/jvmTest/kotlin/CheckpointSerializationTest.kt +++ b/agents/agents-features/agents-features-snapshot/src/jvmTest/kotlin/CheckpointSerializationTest.kt @@ -1,6 +1,6 @@ import ai.koog.agents.snapshot.feature.AgentCheckpointData import ai.koog.agents.snapshot.feature.tombstoneCheckpoint -import ai.koog.agents.snapshot.providers.PersistencyUtils +import ai.koog.agents.snapshot.providers.PersistenceUtils import ai.koog.prompt.message.Message import ai.koog.prompt.message.RequestMetaInfo import 
ai.koog.prompt.message.ResponseMetaInfo @@ -35,7 +35,7 @@ class CheckpointSerializationTest { properties = null ) - val json = PersistencyUtils.defaultCheckpointJson + val json = PersistenceUtils.defaultCheckpointJson val serialized = json.encodeToString(AgentCheckpointData.serializer(), checkpoint) // properties should be omitted due to explicitNulls = false @@ -93,7 +93,7 @@ class CheckpointSerializationTest { properties = properties ) - val json = PersistencyUtils.defaultCheckpointJson + val json = PersistenceUtils.defaultCheckpointJson val serialized = json.encodeToString(AgentCheckpointData.serializer(), checkpoint) val restored = json.decodeFromString(AgentCheckpointData.serializer(), serialized) @@ -104,7 +104,7 @@ class CheckpointSerializationTest { @Test fun `serialize and deserialize tombstone checkpoint`() { val checkpoint = tombstoneCheckpoint(Clock.System.now()) - val json = PersistencyUtils.defaultCheckpointJson + val json = PersistenceUtils.defaultCheckpointJson val serialized = json.encodeToString(AgentCheckpointData.serializer(), checkpoint) val restored = json.decodeFromString(AgentCheckpointData.serializer(), serialized) diff --git a/agents/agents-features/agents-features-snapshot/src/jvmTest/kotlin/CheckpointsTests.kt b/agents/agents-features/agents-features-snapshot/src/jvmTest/kotlin/CheckpointsTests.kt index 033fc9d49a..8c2c7a7e00 100644 --- a/agents/agents-features/agents-features-snapshot/src/jvmTest/kotlin/CheckpointsTests.kt +++ b/agents/agents-features/agents-features-snapshot/src/jvmTest/kotlin/CheckpointsTests.kt @@ -11,10 +11,10 @@ import ai.koog.agents.core.tools.Tool import ai.koog.agents.core.tools.ToolRegistry import ai.koog.agents.ext.tool.SayToUser import ai.koog.agents.snapshot.feature.AgentCheckpointData -import ai.koog.agents.snapshot.feature.Persistency +import ai.koog.agents.snapshot.feature.Persistence import ai.koog.agents.snapshot.feature.RollbackToolRegistry -import ai.koog.agents.snapshot.feature.withPersistency 
-import ai.koog.agents.snapshot.providers.InMemoryPersistencyStorageProvider +import ai.koog.agents.snapshot.feature.withPersistence +import ai.koog.agents.snapshot.providers.InMemoryPersistenceStorageProvider import ai.koog.agents.testing.tools.getMockExecutor import ai.koog.prompt.dsl.prompt import ai.koog.prompt.llm.OllamaModels @@ -69,7 +69,7 @@ class CheckpointsTests { } val checkpoint by node { input -> println("checkpoint save") - withPersistency { ctx -> + withPersistence { ctx -> createCheckpoint( ctx, currentNodeId ?: error("currentNodeId not set"), @@ -88,7 +88,7 @@ class CheckpointsTests { println("checkpoint load") if (!loaded) { loaded = true - withPersistency { ctx -> + withPersistence { ctx -> rollbackToCheckpoint("cpt-100500", ctx) } } @@ -104,8 +104,8 @@ class CheckpointsTests { agentConfig = agentConfig, toolRegistry = toolRegistry ) { - install(Persistency) { - storage = InMemoryPersistencyStorageProvider("testAgentId") + install(Persistence) { + storage = InMemoryPersistenceStorageProvider() } } @@ -124,8 +124,8 @@ class CheckpointsTests { agentConfig = agentConfig, toolRegistry = toolRegistry ) { - install(Persistency) { - storage = InMemoryPersistencyStorageProvider("testAgentId") + install(Persistence) { + storage = InMemoryPersistenceStorageProvider() } } @@ -209,7 +209,7 @@ class CheckpointsTests { // Node that creates a checkpoint val saveCheckpoint by node { input -> - withPersistency { ctx -> + withPersistence { ctx -> createCheckpoint( ctx, currentNodeId ?: error("currentNodeId not set"), @@ -261,8 +261,8 @@ class CheckpointsTests { agentConfig = agentConfig, toolRegistry = toolRegistry ) { - install(Persistency) { - storage = InMemoryPersistencyStorageProvider("testAgentId") + install(Persistence) { + storage = InMemoryPersistenceStorageProvider() } } @@ -294,8 +294,8 @@ class CheckpointsTests { agentConfig = agentConfig, toolRegistry = localToolRegistry ) { - install(Persistency) { - storage = 
InMemoryPersistencyStorageProvider("agent-tools-rollback-1") + install(Persistence) { + storage = InMemoryPersistenceStorageProvider() rollbackToolRegistry = RollbackToolRegistry { registerRollback(WriteKVTool, DeleteKVTool) } @@ -322,7 +322,7 @@ class CheckpointsTests { assertContains(databaseMap, "user-2") assertContains(databaseMap, "user-3") - agent.withPersistency { ctx -> + agent.withPersistence { ctx -> println("ctx outside: $this") println("ctx outside [hash]: ${this.hashCode()}") rollbackToCheckpoint("ckpt-1", ctx) @@ -347,7 +347,7 @@ class CheckpointsTests { @Test fun testRestoreFromSingleCheckpoint() = runTest { - val checkpointStorageProvider = InMemoryPersistencyStorageProvider("testAgentId") + val checkpointStorageProvider = InMemoryPersistenceStorageProvider() val time = Clock.System.now() val agentId = "testAgentId" @@ -362,7 +362,7 @@ class CheckpointsTests { ) ) - checkpointStorageProvider.saveCheckpoint(testCheckpoint) + checkpointStorageProvider.saveCheckpoint(agentId, testCheckpoint) val agent = AIAgent( promptExecutor = getMockExecutor { }, @@ -371,7 +371,7 @@ class CheckpointsTests { toolRegistry = toolRegistry, id = agentId ) { - install(Persistency) { + install(Persistence) { storage = checkpointStorageProvider } } @@ -388,7 +388,7 @@ class CheckpointsTests { @Test fun testRestoreFromLatestCheckpoint() = runTest { - val checkpointStorageProvider = InMemoryPersistencyStorageProvider("testAgentId") + val checkpointStorageProvider = InMemoryPersistenceStorageProvider() val time = Clock.System.now() val agentId = "testAgentId" @@ -414,8 +414,8 @@ class CheckpointsTests { ) ) - checkpointStorageProvider.saveCheckpoint(testCheckpoint) - checkpointStorageProvider.saveCheckpoint(testCheckpoint2) + checkpointStorageProvider.saveCheckpoint(agentId, testCheckpoint) + checkpointStorageProvider.saveCheckpoint(agentId, testCheckpoint2) val agent = AIAgent( promptExecutor = getMockExecutor { }, @@ -424,7 +424,7 @@ class CheckpointsTests { toolRegistry = 
toolRegistry, id = agentId ) { - install(Persistency) { + install(Persistence) { storage = checkpointStorageProvider } } diff --git a/agents/agents-features/agents-features-snapshot/src/jvmTest/kotlin/NodeUniquenessCheckpointTest.kt b/agents/agents-features/agents-features-snapshot/src/jvmTest/kotlin/NodeUniquenessCheckpointTest.kt index f0bbef7b18..8bd3803cd1 100644 --- a/agents/agents-features/agents-features-snapshot/src/jvmTest/kotlin/NodeUniquenessCheckpointTest.kt +++ b/agents/agents-features/agents-features-snapshot/src/jvmTest/kotlin/NodeUniquenessCheckpointTest.kt @@ -7,8 +7,8 @@ import ai.koog.agents.core.dsl.builder.forwardTo import ai.koog.agents.core.dsl.builder.strategy import ai.koog.agents.core.tools.ToolRegistry import ai.koog.agents.ext.tool.SayToUser -import ai.koog.agents.snapshot.feature.Persistency -import ai.koog.agents.snapshot.providers.InMemoryPersistencyStorageProvider +import ai.koog.agents.snapshot.feature.Persistence +import ai.koog.agents.snapshot.providers.InMemoryPersistenceStorageProvider import ai.koog.agents.testing.tools.getMockExecutor import ai.koog.prompt.dsl.prompt import ai.koog.prompt.executor.model.PromptExecutor @@ -104,8 +104,8 @@ class NodeUniquenessCheckpointTest { toolRegistry = toolRegistry ) { // Install the AgentCheckpoint feature - install(Persistency) { - storage = InMemoryPersistencyStorageProvider("testAgentId") + install(Persistence) { + storage = InMemoryPersistenceStorageProvider() } } diff --git a/agents/agents-features/agents-features-snapshot/src/jvmTest/kotlin/PersistencyRestoreStrategyTests.kt b/agents/agents-features/agents-features-snapshot/src/jvmTest/kotlin/PersistencyRestoreStrategyTests.kt index 8b81618e2b..2dde64cdbe 100644 --- a/agents/agents-features/agents-features-snapshot/src/jvmTest/kotlin/PersistencyRestoreStrategyTests.kt +++ b/agents/agents-features/agents-features-snapshot/src/jvmTest/kotlin/PersistencyRestoreStrategyTests.kt @@ -3,8 +3,8 @@ import 
ai.koog.agents.core.agent.AIAgentService import ai.koog.agents.core.agent.config.AIAgentConfig import ai.koog.agents.core.agent.context.RollbackStrategy import ai.koog.agents.snapshot.feature.AgentCheckpointData -import ai.koog.agents.snapshot.feature.Persistency -import ai.koog.agents.snapshot.providers.InMemoryPersistencyStorageProvider +import ai.koog.agents.snapshot.feature.Persistence +import ai.koog.agents.snapshot.providers.InMemoryPersistenceStorageProvider import ai.koog.agents.testing.tools.getMockExecutor import ai.koog.prompt.dsl.prompt import ai.koog.prompt.llm.OllamaModels @@ -16,10 +16,12 @@ import kotlinx.serialization.json.JsonPrimitive import org.junit.jupiter.api.Assertions.assertEquals import org.junit.jupiter.api.Test -class PersistencyRestoreStrategyTests { +class PersistenceRestoreStrategyTests { @Test fun `rollback Default resumes from checkpoint node`() = runTest { - val provider = InMemoryPersistencyStorageProvider("persistency-restore-default") + val provider = InMemoryPersistenceStorageProvider() + + val agentId = "persistency-restore-default" val checkpoint = AgentCheckpointData( checkpointId = "chk-1", @@ -29,7 +31,7 @@ class PersistencyRestoreStrategyTests { messageHistory = listOf(Message.Assistant("History Before", ResponseMetaInfo(Clock.System.now()))), ) - provider.saveCheckpoint(checkpoint) + provider.saveCheckpoint(agentId, checkpoint) val agent = AIAgent( promptExecutor = getMockExecutor { }, @@ -39,11 +41,12 @@ class PersistencyRestoreStrategyTests { model = OllamaModels.Meta.LLAMA_3_2, maxAgentIterations = 10 ), + id = agentId ) { - install(Persistency) { + install(Persistence) { storage = provider - // We only need restore on start; automatic persistency doesn't matter here - enableAutomaticPersistency = true + // We only need restore on start; automatic persistence doesn't matter here + enableAutomaticPersistence = true rollbackStrategy = RollbackStrategy.Default } } @@ -59,7 +62,7 @@ class PersistencyRestoreStrategyTests { 
@Test fun `rollback MessageHistoryOnly starts from beginning`() = runTest { - val provider = InMemoryPersistencyStorageProvider("persistency-restore-history-only") + val provider = InMemoryPersistenceStorageProvider() val agentService = AIAgentService( promptExecutor = getMockExecutor { }, @@ -70,9 +73,9 @@ class PersistencyRestoreStrategyTests { maxAgentIterations = 10 ), ) { - install(Persistency) { + install(Persistence) { storage = provider - enableAutomaticPersistency = true + enableAutomaticPersistence = true rollbackStrategy = RollbackStrategy.MessageHistoryOnly } } diff --git a/agents/agents-features/agents-features-snapshot/src/jvmTest/kotlin/PersistencyRunsTwiceTest.kt b/agents/agents-features/agents-features-snapshot/src/jvmTest/kotlin/PersistencyRunsTwiceTest.kt index 2f791514f8..1c5a120399 100644 --- a/agents/agents-features/agents-features-snapshot/src/jvmTest/kotlin/PersistencyRunsTwiceTest.kt +++ b/agents/agents-features/agents-features-snapshot/src/jvmTest/kotlin/PersistencyRunsTwiceTest.kt @@ -1,8 +1,8 @@ import ai.koog.agents.core.agent.AIAgentService import ai.koog.agents.core.agent.config.AIAgentConfig -import ai.koog.agents.snapshot.feature.Persistency +import ai.koog.agents.snapshot.feature.Persistence import ai.koog.agents.snapshot.feature.isTombstone -import ai.koog.agents.snapshot.providers.InMemoryPersistencyStorageProvider +import ai.koog.agents.snapshot.providers.InMemoryPersistenceStorageProvider import ai.koog.agents.testing.tools.getMockExecutor import ai.koog.prompt.dsl.prompt import ai.koog.prompt.llm.OllamaModels @@ -12,12 +12,12 @@ import kotlinx.coroutines.test.runTest import org.awaitility.kotlin.await import org.junit.jupiter.api.Test -class PersistencyRunsTwiceTest { +class PersistenceRunsTwiceTest { @Test fun `agent runs to end and on second run starts from beginning again`() = runTest { // Arrange - val provider = InMemoryPersistencyStorageProvider("persistency-test-agent") + val provider = 
InMemoryPersistenceStorageProvider() val testCollector = TestAgentLogsCollector() @@ -32,13 +32,14 @@ class PersistencyRunsTwiceTest { maxAgentIterations = 10 ), ) { - install(Persistency) { + install(Persistence) { storage = provider - enableAutomaticPersistency = true + enableAutomaticPersistence = true } } val firstAgent = agentService.createAgent(id = "SAME_ID") + val agentId1 = "SAME_ID" // Act: first run firstAgent.run("Start the test") @@ -53,11 +54,11 @@ class PersistencyRunsTwiceTest { await.until { runBlocking { - provider.getLatestCheckpoint()?.isTombstone() == true + provider.getLatestCheckpoint(agentId1)?.isTombstone() == true } } - val firstCheckpoint = provider.getLatestCheckpoint() + val firstCheckpoint = provider.getLatestCheckpoint(agentId1) val secondAgent = agentService.createAgent(id = "SAME_ID") @@ -67,7 +68,7 @@ class PersistencyRunsTwiceTest { // And still ends with a tombstone as the latest checkpoint await.until { runBlocking { - val latest2 = provider.getLatestCheckpoint() + val latest2 = provider.getLatestCheckpoint(agentId1) latest2?.isTombstone() == true latest2 != firstCheckpoint } @@ -76,7 +77,7 @@ class PersistencyRunsTwiceTest { @Test fun `agent fails on the first run and second run running successfully`() = runTest { - val provider = InMemoryPersistencyStorageProvider("persistency-test-agent") + val provider = InMemoryPersistenceStorageProvider() val testCollector = TestAgentLogsCollector() @@ -91,9 +92,9 @@ class PersistencyRunsTwiceTest { maxAgentIterations = 10 ), ) { - install(Persistency) { + install(Persistence) { storage = provider - enableAutomaticPersistency = true + enableAutomaticPersistence = true } } @@ -112,7 +113,7 @@ class PersistencyRunsTwiceTest { await.until { runBlocking { - provider.getCheckpoints().size == 2 + provider.getCheckpoints(agentId).size == 2 } } @@ -131,7 +132,7 @@ class PersistencyRunsTwiceTest { await.until { runBlocking { - provider.getCheckpoints().filter { !it.isTombstone() }.size == 4 + 
provider.getCheckpoints(agentId).filter { !it.isTombstone() }.size == 4 } } } diff --git a/agents/agents-features/agents-features-snapshot/src/jvmTest/kotlin/SimpleGraphCheckpointTest.kt b/agents/agents-features/agents-features-snapshot/src/jvmTest/kotlin/SimpleGraphCheckpointTest.kt index 30287da353..70e09df945 100644 --- a/agents/agents-features/agents-features-snapshot/src/jvmTest/kotlin/SimpleGraphCheckpointTest.kt +++ b/agents/agents-features/agents-features-snapshot/src/jvmTest/kotlin/SimpleGraphCheckpointTest.kt @@ -2,8 +2,8 @@ import ai.koog.agents.core.agent.AIAgent import ai.koog.agents.core.agent.config.AIAgentConfig import ai.koog.agents.core.tools.ToolRegistry import ai.koog.agents.ext.tool.SayToUser -import ai.koog.agents.snapshot.feature.Persistency -import ai.koog.agents.snapshot.providers.InMemoryPersistencyStorageProvider +import ai.koog.agents.snapshot.feature.Persistence +import ai.koog.agents.snapshot.providers.InMemoryPersistenceStorageProvider import ai.koog.agents.testing.tools.getMockExecutor import ai.koog.prompt.dsl.prompt import ai.koog.prompt.executor.model.PromptExecutor @@ -51,8 +51,8 @@ class SimpleGraphCheckpointTest { agentConfig = agentConfig, toolRegistry = toolRegistry ) { - install(Persistency) { - storage = InMemoryPersistencyStorageProvider("testAgentId") + install(Persistence) { + storage = InMemoryPersistenceStorageProvider() } } @@ -78,7 +78,7 @@ class SimpleGraphCheckpointTest { @Test fun `test agent creates and saves checkpoints`() = runTest { // Create a snapshot provider to store checkpoints - val checkpointStorageProvider = InMemoryPersistencyStorageProvider("testAgentId") + val checkpointStorageProvider = InMemoryPersistenceStorageProvider() // Create a mock executor for testing val mockExecutor: PromptExecutor = getMockExecutor { @@ -106,7 +106,7 @@ class SimpleGraphCheckpointTest { agentConfig = agentConfig, toolRegistry = toolRegistry ) { - install(Persistency) { + install(Persistence) { storage = 
checkpointStorageProvider } } @@ -115,14 +115,14 @@ class SimpleGraphCheckpointTest { agent.run("Start the test") // Verify that a checkpoint was created and saved - val checkpoint = checkpointStorageProvider.getCheckpoints().firstOrNull() + val checkpoint = checkpointStorageProvider.getCheckpoints(agent.id).firstOrNull() assertNotNull(checkpoint, "No checkpoint was created") assertEquals("checkpointNode", checkpoint?.nodeId, "Checkpoint has incorrect node ID") } @Test fun test_checkpoint_persists_history() = runTest { - val checkpointStorageProvider = InMemoryPersistencyStorageProvider("testAgentId") + val checkpointStorageProvider = InMemoryPersistenceStorageProvider() val mockExecutor: PromptExecutor = getMockExecutor { // No specific mock responses needed for this test @@ -148,7 +148,7 @@ class SimpleGraphCheckpointTest { agentConfig = agentConfig, toolRegistry = toolRegistry ) { - install(Persistency) { + install(Persistence) { storage = checkpointStorageProvider } } @@ -157,7 +157,7 @@ class SimpleGraphCheckpointTest { agent.run("Start the test") // Verify that a checkpoint was created and saved - val checkpoint = checkpointStorageProvider.getCheckpoints().firstOrNull() + val checkpoint = checkpointStorageProvider.getCheckpoints(agent.id).firstOrNull() if (checkpoint == null) { error("checkpoint is null") } diff --git a/agents/agents-features/agents-features-snapshot/src/jvmTest/kotlin/SubgraphCheckpointsTest.kt b/agents/agents-features/agents-features-snapshot/src/jvmTest/kotlin/SubgraphCheckpointsTest.kt index 0c3a5da999..173ac34671 100644 --- a/agents/agents-features/agents-features-snapshot/src/jvmTest/kotlin/SubgraphCheckpointsTest.kt +++ b/agents/agents-features/agents-features-snapshot/src/jvmTest/kotlin/SubgraphCheckpointsTest.kt @@ -2,8 +2,8 @@ import ai.koog.agents.core.agent.AIAgent import ai.koog.agents.core.agent.config.AIAgentConfig import ai.koog.agents.core.tools.ToolRegistry import ai.koog.agents.ext.tool.SayToUser -import 
ai.koog.agents.snapshot.feature.Persistency -import ai.koog.agents.snapshot.providers.InMemoryPersistencyStorageProvider +import ai.koog.agents.snapshot.feature.Persistence +import ai.koog.agents.snapshot.providers.InMemoryPersistenceStorageProvider import ai.koog.agents.testing.tools.getMockExecutor import ai.koog.prompt.dsl.prompt import ai.koog.prompt.llm.OllamaModels @@ -36,8 +36,8 @@ class SubgraphCheckpointsTest { agentConfig = agentConfig, toolRegistry = toolRegistry ) { - install(Persistency) { - storage = InMemoryPersistencyStorageProvider("testAgentId") + install(Persistence) { + storage = InMemoryPersistenceStorageProvider() } } @@ -63,8 +63,8 @@ class SubgraphCheckpointsTest { agentConfig = agentConfig, toolRegistry = toolRegistry ) { - install(Persistency) { - storage = InMemoryPersistencyStorageProvider("testAgentId") + install(Persistence) { + storage = InMemoryPersistenceStorageProvider() } } @@ -90,8 +90,8 @@ class SubgraphCheckpointsTest { agentConfig = agentConfig, toolRegistry = toolRegistry ) { - install(Persistency) { - storage = InMemoryPersistencyStorageProvider("testAgentId") + install(Persistence) { + storage = InMemoryPersistenceStorageProvider() } } @@ -118,8 +118,8 @@ class SubgraphCheckpointsTest { agentConfig = agentConfig, toolRegistry = toolRegistry ) { - install(Persistency) { - storage = InMemoryPersistencyStorageProvider("testAgentId") + install(Persistence) { + storage = InMemoryPersistenceStorageProvider() } } diff --git a/agents/agents-features/agents-features-snapshot/src/jvmTest/kotlin/SubgraphSetExecutionPointTest.kt b/agents/agents-features/agents-features-snapshot/src/jvmTest/kotlin/SubgraphSetExecutionPointTest.kt index 0a723550ad..5ffdb66376 100644 --- a/agents/agents-features/agents-features-snapshot/src/jvmTest/kotlin/SubgraphSetExecutionPointTest.kt +++ b/agents/agents-features/agents-features-snapshot/src/jvmTest/kotlin/SubgraphSetExecutionPointTest.kt @@ -2,8 +2,8 @@ import ai.koog.agents.core.agent.AIAgent import 
ai.koog.agents.core.agent.config.AIAgentConfig import ai.koog.agents.core.tools.ToolRegistry import ai.koog.agents.ext.tool.SayToUser -import ai.koog.agents.snapshot.feature.Persistency -import ai.koog.agents.snapshot.providers.InMemoryPersistencyStorageProvider +import ai.koog.agents.snapshot.feature.Persistence +import ai.koog.agents.snapshot.providers.InMemoryPersistenceStorageProvider import ai.koog.agents.testing.tools.getMockExecutor import ai.koog.prompt.dsl.prompt import ai.koog.prompt.llm.OllamaModels @@ -32,8 +32,8 @@ class SubgraphSetExecutionPointTest { agentConfig = agentConfig, toolRegistry = toolRegistry ) { - install(Persistency) { - storage = InMemoryPersistencyStorageProvider("testAgentId") + install(Persistence) { + storage = InMemoryPersistenceStorageProvider() } } @@ -56,8 +56,8 @@ class SubgraphSetExecutionPointTest { agentConfig = agentConfig, toolRegistry = toolRegistry ) { - install(Persistency) { - storage = InMemoryPersistencyStorageProvider("testAgentId") + install(Persistence) { + storage = InMemoryPersistenceStorageProvider() } } @@ -84,8 +84,8 @@ class SubgraphSetExecutionPointTest { agentConfig = agentConfig, toolRegistry = toolRegistry ) { - install(Persistency) { - storage = InMemoryPersistencyStorageProvider("testAgentId") + install(Persistence) { + storage = InMemoryPersistenceStorageProvider() } } @@ -109,8 +109,8 @@ class SubgraphSetExecutionPointTest { agentConfig = agentConfig, toolRegistry = toolRegistry ) { - install(Persistency) { - storage = InMemoryPersistencyStorageProvider("testAgentId") + install(Persistence) { + storage = InMemoryPersistenceStorageProvider() } } @@ -136,8 +136,8 @@ class SubgraphSetExecutionPointTest { agentConfig = agentConfig, toolRegistry = toolRegistry ) { - install(Persistency) { - storage = InMemoryPersistencyStorageProvider("testAgentId") + install(Persistence) { + storage = InMemoryPersistenceStorageProvider() } } @@ -163,8 +163,8 @@ class SubgraphSetExecutionPointTest { agentConfig = 
agentConfig, toolRegistry = toolRegistry ) { - install(Persistency) { - storage = InMemoryPersistencyStorageProvider("testAgentId") + install(Persistence) { + storage = InMemoryPersistenceStorageProvider() } } diff --git a/agents/agents-features/agents-features-snapshot/src/jvmTest/kotlin/TestStrategies.kt b/agents/agents-features/agents-features-snapshot/src/jvmTest/kotlin/TestStrategies.kt index 9393ce0401..c28aff1ebc 100644 --- a/agents/agents-features/agents-features-snapshot/src/jvmTest/kotlin/TestStrategies.kt +++ b/agents/agents-features/agents-features-snapshot/src/jvmTest/kotlin/TestStrategies.kt @@ -2,7 +2,7 @@ import ai.koog.agents.core.dsl.builder.AIAgentNodeDelegate import ai.koog.agents.core.dsl.builder.AIAgentSubgraphBuilderBase import ai.koog.agents.core.dsl.builder.forwardTo import ai.koog.agents.core.dsl.builder.strategy -import ai.koog.agents.snapshot.feature.withPersistency +import ai.koog.agents.snapshot.feature.withPersistence import kotlinx.serialization.json.JsonPrimitive import kotlin.reflect.typeOf @@ -103,10 +103,10 @@ private fun AIAgentSubgraphBuilderBase<*, *>.teleportOnceNode( ): AIAgentNodeDelegate = node(name) { if (!teleportState.teleported) { teleportState.teleported = true - withPersistency { ctx -> + withPersistence { ctx -> val history = llm.readSession { this.prompt.messages } setExecutionPoint(ctx, teleportToId, history, JsonPrimitive("$it\nTeleported")) - return@withPersistency "Teleported" + return@withPersistence "Teleported" } } else { // If we've already teleported, just return the input @@ -131,7 +131,7 @@ private fun AIAgentSubgraphBuilderBase<*, *>.nodeForSecondTry( private fun AIAgentSubgraphBuilderBase<*, *>.createCheckpointNode(name: String? 
= null, checkpointId: String) = node(name) { val input = it - withPersistency { ctx -> + withPersistence { ctx -> createCheckpoint(ctx, name!!, input, typeOf(), checkpointId) llm.writeSession { updatePrompt { @@ -157,7 +157,7 @@ private fun AIAgentSubgraphBuilderBase<*, *>.nodeRollbackToCheckpoint( return@node "Skipping rollback" } - withPersistency { + withPersistence { val checkpoint = rollbackToCheckpoint(checkpointId, it)!! teleportState.teleported = true llm.writeSession { @@ -174,7 +174,7 @@ private fun AIAgentSubgraphBuilderBase<*, *>.nodeCreateCheckpoint( name: String? = null, ): AIAgentNodeDelegate = node(name) { val input = it - withPersistency { ctx -> + withPersistence { ctx -> val checkpoint = createCheckpoint( ctx, currentNodeId ?: error("currentNodeId not set"), @@ -183,9 +183,9 @@ private fun AIAgentSubgraphBuilderBase<*, *>.nodeCreateCheckpoint( "snapshot-id" ) - saveCheckpoint(checkpoint ?: error("Checkpoint creation failed")) + saveCheckpoint(ctx.agentId, checkpoint ?: error("Checkpoint creation failed")) - return@withPersistency "$input\nSnapshot created" + return@withPersistence "$input\nSnapshot created" } } diff --git a/agents/agents-features/agents-features-snapshot/src/jvmTest/kotlin/ai/koog/agents/snapshot/providers/file/FileAgentCheckpointStorageProviderTest.kt b/agents/agents-features/agents-features-snapshot/src/jvmTest/kotlin/ai/koog/agents/snapshot/providers/file/FileAgentCheckpointStorageProviderTest.kt index b9f5318e1b..9574974308 100644 --- a/agents/agents-features/agents-features-snapshot/src/jvmTest/kotlin/ai/koog/agents/snapshot/providers/file/FileAgentCheckpointStorageProviderTest.kt +++ b/agents/agents-features/agents-features-snapshot/src/jvmTest/kotlin/ai/koog/agents/snapshot/providers/file/FileAgentCheckpointStorageProviderTest.kt @@ -17,12 +17,12 @@ import kotlin.test.assertTrue class FileAgentCheckpointStorageProviderTest { private lateinit var tempDir: java.nio.file.Path - private lateinit var provider: 
JVMFilePersistencyStorageProvider + private lateinit var provider: JVMFilePersistenceStorageProvider @BeforeTest fun setup() { tempDir = Files.createTempDirectory("checkpoint-test") - provider = JVMFilePersistencyStorageProvider(tempDir, "testAgentId") + provider = JVMFilePersistenceStorageProvider(tempDir) } @AfterTest @@ -54,11 +54,12 @@ class FileAgentCheckpointStorageProviderTest { messageHistory = messageHistory ) + val agentId = "testAgentId" // Save the checkpoint - provider.saveCheckpoint(checkpoint) + provider.saveCheckpoint(agentId, checkpoint) // Retrieve all checkpoints for the agent - val checkpoints = provider.getCheckpoints() + val checkpoints = provider.getCheckpoints(agentId) assertEquals(1, checkpoints.size, "Should have one checkpoint") // Verify the retrieved checkpoint @@ -80,7 +81,7 @@ class FileAgentCheckpointStorageProviderTest { assertEquals(originalAssistantMsg.content, retrievedAssistantMsg.content) // Test getLatestCheckpoint - val latestCheckpoint = provider.getLatestCheckpoint() + val latestCheckpoint = provider.getLatestCheckpoint(agentId) assertNotNull(latestCheckpoint, "Latest checkpoint should not be null") assertEquals(checkpointId, latestCheckpoint.checkpointId) @@ -96,15 +97,15 @@ class FileAgentCheckpointStorageProviderTest { ) // Save the later checkpoint - provider.saveCheckpoint(laterCheckpoint) + provider.saveCheckpoint(agentId, laterCheckpoint) // Verify that getLatestCheckpoint returns the later checkpoint - val newLatestCheckpoint = provider.getLatestCheckpoint() + val newLatestCheckpoint = provider.getLatestCheckpoint(agentId) assertNotNull(newLatestCheckpoint, "New latest checkpoint should not be null") assertEquals(laterCheckpointId, newLatestCheckpoint.checkpointId) // Verify that getCheckpoints returns both checkpoints - val allCheckpoints = provider.getCheckpoints() + val allCheckpoints = provider.getCheckpoints(agentId) assertEquals(2, allCheckpoints.size, "Should have two checkpoints") 
assertTrue(allCheckpoints.any { it.checkpointId == checkpointId }) assertTrue(allCheckpoints.any { it.checkpointId == laterCheckpointId }) diff --git a/agents/agents-features/agents-features-snapshot/src/jvmTest/kotlin/ai/koog/agents/snapshot/providers/file/FileCheckpointsTests.kt b/agents/agents-features/agents-features-snapshot/src/jvmTest/kotlin/ai/koog/agents/snapshot/providers/file/FileCheckpointsTests.kt index 99cfa8a1b5..fc531a529d 100644 --- a/agents/agents-features/agents-features-snapshot/src/jvmTest/kotlin/ai/koog/agents/snapshot/providers/file/FileCheckpointsTests.kt +++ b/agents/agents-features/agents-features-snapshot/src/jvmTest/kotlin/ai/koog/agents/snapshot/providers/file/FileCheckpointsTests.kt @@ -3,9 +3,9 @@ import ai.koog.agents.core.agent.config.AIAgentConfig import ai.koog.agents.core.tools.ToolRegistry import ai.koog.agents.ext.tool.SayToUser import ai.koog.agents.snapshot.feature.AgentCheckpointData -import ai.koog.agents.snapshot.feature.Persistency +import ai.koog.agents.snapshot.feature.Persistence import ai.koog.agents.snapshot.feature.isTombstone -import ai.koog.agents.snapshot.providers.file.JVMFilePersistencyStorageProvider +import ai.koog.agents.snapshot.providers.file.JVMFilePersistenceStorageProvider import ai.koog.agents.testing.tools.getMockExecutor import ai.koog.prompt.dsl.prompt import ai.koog.prompt.llm.OllamaModels @@ -32,7 +32,7 @@ import kotlin.time.Duration.Companion.seconds */ class FileCheckpointsTests { private lateinit var tempDir: Path - private lateinit var provider: JVMFilePersistencyStorageProvider + private lateinit var provider: JVMFilePersistenceStorageProvider val systemPrompt = "You are a test agent." 
val agentConfig = AIAgentConfig( @@ -49,7 +49,7 @@ class FileCheckpointsTests { @BeforeTest fun setup() { tempDir = Files.createTempDirectory("agent-checkpoint-test") - provider = JVMFilePersistencyStorageProvider(tempDir, "testAgentId") + provider = JVMFilePersistenceStorageProvider(tempDir) } @AfterTest @@ -70,7 +70,7 @@ class FileCheckpointsTests { toolRegistry = toolRegistry, id = agentId ) { - install(Persistency) { + install(Persistence) { storage = provider } } @@ -86,7 +86,7 @@ class FileCheckpointsTests { ) // Verify that the checkpoint was saved to the file system - val checkpoints = provider.getCheckpoints().filter { !it.isTombstone() } + val checkpoints = provider.getCheckpoints(agentId).filter { !it.isTombstone() } assertEquals(1, checkpoints.size, "Should have one checkpoint") assertEquals("checkpointId", checkpoints.first().checkpointId) } @@ -99,7 +99,7 @@ class FileCheckpointsTests { agentConfig = agentConfig, toolRegistry = toolRegistry ) { - install(Persistency) { + install(Persistence) { storage = provider } } @@ -129,7 +129,7 @@ class FileCheckpointsTests { ) ) - provider.saveCheckpoint(testCheckpoint) + provider.saveCheckpoint(agentId, testCheckpoint) val agent = AIAgent( promptExecutor = getMockExecutor { }, @@ -138,7 +138,7 @@ class FileCheckpointsTests { toolRegistry = toolRegistry, id = agentId ) { - install(Persistency) { + install(Persistence) { storage = provider } } @@ -180,8 +180,8 @@ class FileCheckpointsTests { ) ) - provider.saveCheckpoint(testCheckpoint) - provider.saveCheckpoint(testCheckpoint2) + provider.saveCheckpoint(agentId, testCheckpoint) + provider.saveCheckpoint(agentId, testCheckpoint2) val agent = AIAgent( promptExecutor = getMockExecutor { }, @@ -190,7 +190,7 @@ class FileCheckpointsTests { toolRegistry = toolRegistry, id = agentId ) { - install(Persistency) { + install(Persistence) { storage = provider } } @@ -216,17 +216,17 @@ class FileCheckpointsTests { toolRegistry = toolRegistry, id = agentId ) { - 
install(Persistency) { + install(Persistence) { storage = provider - enableAutomaticPersistency = true + enableAutomaticPersistence = true } } agent.run("Start the test") // Verify that checkpoints were automatically created - val checkpoints = provider.getCheckpoints() + val checkpoints = provider.getCheckpoints(agentId) assertTrue(checkpoints.isNotEmpty(), "Should have automatically created checkpoints") } } diff --git a/agents/agents-features/agents-features-sql/Module.md b/agents/agents-features/agents-features-sql/Module.md index 350d7ab79e..e96acfa55b 100644 --- a/agents/agents-features/agents-features-sql/Module.md +++ b/agents/agents-features/agents-features-sql/Module.md @@ -19,19 +19,19 @@ Provides SQL-based persistence providers for agent checkpoints using JetBrains E ## Providers -### ExposedPersistencyStorageProvider +### ExposedPersistenceStorageProvider Base provider using Exposed ORM with configurable cleanup behavior. -### PostgresPersistencyStorageProvider +### PostgresPersistenceStorageProvider Production-ready provider with JSONB support and HikariCP pooling. -### MySQLPersistencyStorageProvider +### MySQLPersistenceStorageProvider Enterprise provider with JSON column support (MySQL 5.7+). -### H2PersistencyStorageProvider +### H2PersistenceStorageProvider Perfect for testing and embedded applications. -### SQLitePersistencyStorageProvider +### SQLitePersistenceStorageProvider Zero-configuration provider for desktop and mobile applications. -All providers implement `AutoCloseable` for proper resource management and support configurable TTL cleanup. \ No newline at end of file +All providers implement `AutoCloseable` for proper resource management and support configurable TTL cleanup. 
diff --git a/agents/agents-features/agents-features-sql/build.gradle.kts b/agents/agents-features/agents-features-sql/build.gradle.kts index 3fe859c67e..98e739ae36 100644 --- a/agents/agents-features/agents-features-sql/build.gradle.kts +++ b/agents/agents-features/agents-features-sql/build.gradle.kts @@ -47,6 +47,7 @@ kotlin { dependencies { implementation(kotlin("test-junit5")) implementation(project(":agents:agents-test")) + implementation(project(":test-utils")) implementation(libs.mockk) implementation(libs.testcontainers) implementation(libs.testcontainers.postgresql) diff --git a/agents/agents-features/agents-features-sql/src/commonMain/kotlin/ai/koog/agents/features/sql/providers/SQLPersistencyStorageProvider.kt b/agents/agents-features/agents-features-sql/src/commonMain/kotlin/ai/koog/agents/features/sql/providers/SQLPersistencyStorageProvider.kt index a5dadc1630..1d2cb862c5 100644 --- a/agents/agents-features/agents-features-sql/src/commonMain/kotlin/ai/koog/agents/features/sql/providers/SQLPersistencyStorageProvider.kt +++ b/agents/agents-features/agents-features-sql/src/commonMain/kotlin/ai/koog/agents/features/sql/providers/SQLPersistencyStorageProvider.kt @@ -1,10 +1,10 @@ package ai.koog.agents.features.sql.providers -import ai.koog.agents.snapshot.providers.PersistencyStorageProvider +import ai.koog.agents.snapshot.providers.PersistenceStorageProvider import kotlinx.datetime.Instant /** - * Abstract base class for SQL-based implementations of [PersistencyStorageProvider]. + * Abstract base class for SQL-based implementations of [PersistenceStorageProvider]. * * This provider offers a generic SQL abstraction for persisting agent checkpoints * to relational databases. Concrete implementations should handle specific SQL @@ -28,16 +28,14 @@ import kotlinx.datetime.Instant * Implementations must ensure thread-safe database access, typically through connection pooling. * * @constructor Initializes the SQL persistence provider. 
- * @param persistenceId Unique identifier for this agent's persistence data * @param tableName Name of the table to store checkpoints (default: "agent_checkpoints") * @param ttlSeconds Optional TTL for checkpoint entries in seconds (null = no expiration) */ -public abstract class SQLPersistencyStorageProvider( - protected val persistenceId: String, +public abstract class SQLPersistenceStorageProvider( protected val tableName: String = "agent_checkpoints", protected val ttlSeconds: Long? = null, protected val migrator: SQLPersistenceSchemaMigrator -) : PersistencyStorageProvider { +) : PersistenceStorageProvider { /** * Initializes the database schema if it doesn't exist. @@ -71,15 +69,15 @@ public abstract class SQLPersistencyStorageProvider( /** * Deletes a specific checkpoint by ID */ - public abstract suspend fun deleteCheckpoint(checkpointId: String) + public abstract suspend fun deleteCheckpoint(agentId: String, checkpointId: String) /** - * Deletes all checkpoints for this persistence ID + * Deletes all checkpoints for this agent ID */ - public abstract suspend fun deleteAllCheckpoints() + public abstract suspend fun deleteAllCheckpoints(agentId: String) /** * Gets the total number of checkpoints stored */ - public abstract suspend fun getCheckpointCount(): Long + public abstract suspend fun getCheckpointCount(agentId: String): Long } diff --git a/agents/agents-features/agents-features-sql/src/jvmMain/kotlin/ai/koog/agents/features/sql/providers/ExposedPersistencyStorageProvider.kt b/agents/agents-features/agents-features-sql/src/jvmMain/kotlin/ai/koog/agents/features/sql/providers/ExposedPersistencyStorageProvider.kt index e4041e3af7..c80134c0e1 100644 --- a/agents/agents-features/agents-features-sql/src/jvmMain/kotlin/ai/koog/agents/features/sql/providers/ExposedPersistencyStorageProvider.kt +++ b/agents/agents-features/agents-features-sql/src/jvmMain/kotlin/ai/koog/agents/features/sql/providers/ExposedPersistencyStorageProvider.kt @@ -1,7 +1,7 @@ package 
ai.koog.agents.features.sql.providers import ai.koog.agents.snapshot.feature.AgentCheckpointData -import ai.koog.agents.snapshot.providers.PersistencyUtils +import ai.koog.agents.snapshot.providers.PersistenceUtils import kotlinx.datetime.Clock import kotlinx.serialization.json.Json import org.jetbrains.exposed.sql.Column @@ -73,7 +73,7 @@ public open class CheckpointsTable(tableName: String) : Table(tableName) { } /** - * An abstract Exposed-based implementation of [SQLPersistencyStorageProvider] for managing + * An abstract Exposed-based implementation of [SQLPersistenceStorageProvider] for managing * agent checkpoints in SQL databases using JetBrains Exposed ORM. * * This class provides a generic SQL implementation that works with any database supported @@ -107,21 +107,18 @@ public open class CheckpointsTable(tableName: String) : Table(tableName) { * - When TTL is not configured (ttlSeconds = null), no TTL processing occurs * * @constructor Initializes the Exposed persistence provider. - * @param persistenceId Unique identifier for this agent's persistence data * @param database The Exposed Database instance to use * @param tableName Name of the table to store checkpoints (default: "agent_checkpoints") * @param ttlSeconds Optional TTL for checkpoint entries in seconds (null = no expiration) */ @Suppress("MissingKDocForPublicAPI") -public abstract class ExposedPersistencyStorageProvider( - persistenceId: String, +public abstract class ExposedPersistenceStorageProvider( protected val database: Database, tableName: String = "agent_checkpoints", ttlSeconds: Long? 
= null, migrator: SQLPersistenceSchemaMigrator, - private val json: Json = PersistencyUtils.defaultCheckpointJson -) : SQLPersistencyStorageProvider( - persistenceId = persistenceId, + private val json: Json = PersistenceUtils.defaultCheckpointJson +) : SQLPersistenceStorageProvider( tableName = tableName, ttlSeconds = ttlSeconds, migrator @@ -179,12 +176,12 @@ public abstract class ExposedPersistencyStorageProvider( } } - override suspend fun getCheckpoints(): List { + override suspend fun getCheckpoints(agentId: String): List { return transaction { checkpointsTable .select(checkpointsTable.checkpointJson) .where { - checkpointsTable.persistenceId eq this@ExposedPersistencyStorageProvider.persistenceId + checkpointsTable.persistenceId eq agentId } .orderBy(checkpointsTable.createdAt to SortOrder.ASC) .mapNotNull { row -> @@ -195,14 +192,14 @@ public abstract class ExposedPersistencyStorageProvider( } } - override suspend fun saveCheckpoint(agentCheckpointData: AgentCheckpointData) { + override suspend fun saveCheckpoint(agentId: String, agentCheckpointData: AgentCheckpointData) { val checkpointJson = json.encodeToString(agentCheckpointData) val ttlTimestamp = calculateTtlTimestamp(agentCheckpointData.createdAt) transaction { // Use upsert for idempotent saves checkpointsTable.upsert { - it[checkpointsTable.persistenceId] = this@ExposedPersistencyStorageProvider.persistenceId + it[checkpointsTable.persistenceId] = agentId it[checkpointsTable.checkpointId] = agentCheckpointData.checkpointId it[checkpointsTable.createdAt] = agentCheckpointData.createdAt.toEpochMilliseconds() it[checkpointsTable.checkpointJson] = checkpointJson @@ -211,12 +208,12 @@ public abstract class ExposedPersistencyStorageProvider( } } - override suspend fun getLatestCheckpoint(): AgentCheckpointData? { + override suspend fun getLatestCheckpoint(agentId: String): AgentCheckpointData? 
{ return transaction { checkpointsTable .select(checkpointsTable.checkpointJson) .where { - checkpointsTable.persistenceId eq this@ExposedPersistencyStorageProvider.persistenceId + checkpointsTable.persistenceId eq agentId } .orderBy(checkpointsTable.createdAt to SortOrder.DESC) .limit(1) @@ -228,27 +225,27 @@ public abstract class ExposedPersistencyStorageProvider( } } - override suspend fun deleteCheckpoint(checkpointId: String) { + override suspend fun deleteCheckpoint(agentId: String, checkpointId: String) { transaction { checkpointsTable.deleteWhere { - (checkpointsTable.persistenceId eq this@ExposedPersistencyStorageProvider.persistenceId) and + (checkpointsTable.persistenceId eq agentId) and (checkpointsTable.checkpointId eq checkpointId) } } } - override suspend fun deleteAllCheckpoints() { + override suspend fun deleteAllCheckpoints(agentId: String) { transaction { checkpointsTable.deleteWhere { - checkpointsTable.persistenceId eq this@ExposedPersistencyStorageProvider.persistenceId + checkpointsTable.persistenceId eq agentId } } } - override suspend fun getCheckpointCount(): Long { + override suspend fun getCheckpointCount(agentId: String): Long { return transaction { checkpointsTable.selectAll().where { - checkpointsTable.persistenceId eq this@ExposedPersistencyStorageProvider.persistenceId + checkpointsTable.persistenceId eq agentId }.count() } } diff --git a/agents/agents-features/agents-features-sql/src/jvmMain/kotlin/ai/koog/agents/features/sql/providers/H2PersistencyStorageProvider.kt b/agents/agents-features/agents-features-sql/src/jvmMain/kotlin/ai/koog/agents/features/sql/providers/H2PersistencyStorageProvider.kt index 978715dc55..647ac9b9ec 100644 --- a/agents/agents-features/agents-features-sql/src/jvmMain/kotlin/ai/koog/agents/features/sql/providers/H2PersistencyStorageProvider.kt +++ b/agents/agents-features/agents-features-sql/src/jvmMain/kotlin/ai/koog/agents/features/sql/providers/H2PersistencyStorageProvider.kt @@ -1,6 +1,6 @@ package 
ai.koog.agents.features.sql.providers -import ai.koog.agents.snapshot.providers.PersistencyUtils +import ai.koog.agents.snapshot.providers.PersistenceUtils import kotlinx.coroutines.Dispatchers import kotlinx.serialization.json.Json import org.jetbrains.exposed.sql.Database @@ -8,17 +8,16 @@ import org.jetbrains.exposed.sql.transactions.experimental.newSuspendedTransacti import org.jetbrains.exposed.sql.transactions.transaction /** - * H2 Database-specific implementation of [ExposedPersistencyStorageProvider] for managing + * H2 Database-specific implementation of [ExposedPersistenceStorageProvider] for managing * agent checkpoints in H2 databases. */ -public class H2PersistencyStorageProvider( - persistenceId: String, +public class H2PersistenceStorageProvider( database: Database, tableName: String = "agent_checkpoints", ttlSeconds: Long? = null, migrator: SQLPersistenceSchemaMigrator = H2PersistenceSchemaMigrator(database, tableName), - json: Json = PersistencyUtils.defaultCheckpointJson -) : ExposedPersistencyStorageProvider(persistenceId, database, tableName, ttlSeconds, migrator, json) { + json: Json = PersistenceUtils.defaultCheckpointJson +) : ExposedPersistenceStorageProvider(database, tableName, ttlSeconds, migrator, json) { public override suspend fun transaction(block: suspend () -> T): T = newSuspendedTransaction(Dispatchers.IO, database) { @@ -32,20 +31,17 @@ public class H2PersistencyStorageProvider( * Data is lost when the JVM shuts down. * Perfect for testing and temporary caching. 
* - * @param persistenceId Unique identifier for this agent's persistence data * @param databaseName Name of the in-memory database * @param options Additional H2 options (e.g., "DB_CLOSE_DELAY=-1") * @param tableName Name of the table to store checkpoints * @param ttlSeconds Optional TTL for checkpoint entries in seconds */ public fun inMemory( - persistenceId: String, databaseName: String = "test", options: String = "DB_CLOSE_DELAY=-1", tableName: String = "agent_checkpoints", ttlSeconds: Long? = null - ): H2PersistencyStorageProvider = H2PersistencyStorageProvider( - persistenceId = persistenceId, + ): H2PersistenceStorageProvider = H2PersistenceStorageProvider( database = Database.connect("jdbc:h2:mem:$databaseName;$options"), tableName = tableName, ttlSeconds = ttlSeconds @@ -56,20 +52,17 @@ public class H2PersistencyStorageProvider( * Data is persisted to a file on disk. * Good balance between performance and persistence. * - * @param persistenceId Unique identifier for this agent's persistence data * @param filePath Path to the database file (without .mv.db extension) * @param options Additional H2 options * @param tableName Name of the table to store checkpoints * @param ttlSeconds Optional TTL for checkpoint entries in seconds */ public fun fileBased( - persistenceId: String, filePath: String, options: String = "", tableName: String = "agent_checkpoints", ttlSeconds: Long? = null - ): H2PersistencyStorageProvider = H2PersistencyStorageProvider( - persistenceId = persistenceId, + ): H2PersistenceStorageProvider = H2PersistenceStorageProvider( database = Database.connect( if (options.isNotEmpty()) { "jdbc:h2:file:$filePath;$options" @@ -85,19 +78,16 @@ public class H2PersistencyStorageProvider( * Creates an H2 provider with PostgreSQL compatibility mode. * Useful when migrating from PostgreSQL or for compatibility testing. 
* - * @param persistenceId Unique identifier for this agent's persistence data * @param databasePath Path to database (memory or file) * @param tableName Name of the table to store checkpoints * @param ttlSeconds Optional TTL for checkpoint entries in seconds */ public fun postgresCompatible( - persistenceId: String, databasePath: String = "mem:test", tableName: String = "agent_checkpoints", ttlSeconds: Long? = null - ): H2PersistencyStorageProvider { - return H2PersistencyStorageProvider( - persistenceId = persistenceId, + ): H2PersistenceStorageProvider { + return H2PersistenceStorageProvider( database = Database.connect("jdbc:h2:$databasePath;MODE=PostgreSQL;DATABASE_TO_LOWER=TRUE"), tableName = tableName, ttlSeconds = ttlSeconds diff --git a/agents/agents-features/agents-features-sql/src/jvmMain/kotlin/ai/koog/agents/features/sql/providers/MySQLPersistencyStorageProvider.kt b/agents/agents-features/agents-features-sql/src/jvmMain/kotlin/ai/koog/agents/features/sql/providers/MySQLPersistencyStorageProvider.kt index 3077fe86fa..fe076a92e0 100644 --- a/agents/agents-features/agents-features-sql/src/jvmMain/kotlin/ai/koog/agents/features/sql/providers/MySQLPersistencyStorageProvider.kt +++ b/agents/agents-features/agents-features-sql/src/jvmMain/kotlin/ai/koog/agents/features/sql/providers/MySQLPersistencyStorageProvider.kt @@ -1,6 +1,6 @@ package ai.koog.agents.features.sql.providers -import ai.koog.agents.snapshot.providers.PersistencyUtils +import ai.koog.agents.snapshot.providers.PersistenceUtils import kotlinx.coroutines.Dispatchers import kotlinx.serialization.json.Json import org.jetbrains.exposed.sql.Database @@ -8,7 +8,7 @@ import org.jetbrains.exposed.sql.transactions.experimental.newSuspendedTransacti import org.jetbrains.exposed.sql.transactions.transaction /** - * MySQL-specific implementation of [ExposedPersistencyStorageProvider] for managing + * MySQL-specific implementation of [ExposedPersistenceStorageProvider] for managing * agent checkpoints in 
MySQL databases. * * This provider is optimized for MySQL 5.7+ and MariaDB 10.2+, leveraging their @@ -22,7 +22,7 @@ import org.jetbrains.exposed.sql.transactions.transaction * * ## Example Usage: * ```kotlin - * val provider = MySQLPersistencyStorageProvider( + * val provider = MySQLPersistenceStorageProvider( * persistenceId = "my-agent", * database = Database.connect( * url = "jdbc:mysql://localhost:3306/mydb?useSSL=false&serverTimezone=UTC", @@ -36,14 +36,13 @@ import org.jetbrains.exposed.sql.transactions.transaction * * @constructor Initializes the MySQL persistence provider with an Exposed Database instance. */ -public class MySQLPersistencyStorageProvider( - persistenceId: String, +public class MySQLPersistenceStorageProvider( database: Database, tableName: String = "agent_checkpoints", ttlSeconds: Long? = null, migrator: SQLPersistenceSchemaMigrator = MySqlPersistenceSchemaMigrator(database, tableName), - json: Json = PersistencyUtils.defaultCheckpointJson -) : ExposedPersistencyStorageProvider(persistenceId, database, tableName, ttlSeconds, migrator, json) { + json: Json = PersistenceUtils.defaultCheckpointJson +) : ExposedPersistenceStorageProvider(database, tableName, ttlSeconds, migrator, json) { override suspend fun transaction(block: suspend () -> T): T = newSuspendedTransaction(Dispatchers.IO, database) { diff --git a/agents/agents-features/agents-features-sql/src/jvmMain/kotlin/ai/koog/agents/features/sql/providers/PostgresPersistencyStorageProvider.kt b/agents/agents-features/agents-features-sql/src/jvmMain/kotlin/ai/koog/agents/features/sql/providers/PostgresPersistencyStorageProvider.kt index e4e662ee34..40190957c6 100644 --- a/agents/agents-features/agents-features-sql/src/jvmMain/kotlin/ai/koog/agents/features/sql/providers/PostgresPersistencyStorageProvider.kt +++ b/agents/agents-features/agents-features-sql/src/jvmMain/kotlin/ai/koog/agents/features/sql/providers/PostgresPersistencyStorageProvider.kt @@ -1,6 +1,6 @@ package 
ai.koog.agents.features.sql.providers -import ai.koog.agents.snapshot.providers.PersistencyUtils +import ai.koog.agents.snapshot.providers.PersistenceUtils import kotlinx.coroutines.Dispatchers import kotlinx.serialization.json.Json import org.jetbrains.exposed.sql.Database @@ -8,19 +8,18 @@ import org.jetbrains.exposed.sql.transactions.experimental.newSuspendedTransacti import org.jetbrains.exposed.sql.transactions.transaction /** - * PostgreSQL-specific implementation of [ExposedPersistencyStorageProvider] for managing + * PostgreSQL-specific implementation of [ExposedPersistenceStorageProvider] for managing * agent checkpoints in PostgreSQL databases. * * @constructor Initializes the PostgreSQL persistence provider with connection details. */ -public class PostgresPersistencyStorageProvider( - persistenceId: String, +public class PostgresPersistenceStorageProvider( database: Database, tableName: String = "agent_checkpoints", ttlSeconds: Long? = null, migrator: SQLPersistenceSchemaMigrator = PostgresPersistenceSchemaMigrator(database, tableName), - json: Json = PersistencyUtils.defaultCheckpointJson -) : ExposedPersistencyStorageProvider(persistenceId, database, tableName, ttlSeconds, migrator, json) { + json: Json = PersistenceUtils.defaultCheckpointJson +) : ExposedPersistenceStorageProvider(database, tableName, ttlSeconds, migrator, json) { override suspend fun transaction(block: suspend () -> T): T = newSuspendedTransaction(Dispatchers.IO, database) { block() } diff --git a/agents/agents-features/agents-features-sql/src/jvmTest/kotlin/ai/koog/agents/features/sql/providers/H2PersistencyStorageProviderTest.kt b/agents/agents-features/agents-features-sql/src/jvmTest/kotlin/ai/koog/agents/features/sql/providers/H2PersistencyStorageProviderTest.kt index e6777eed2f..1b22167eb1 100644 --- a/agents/agents-features/agents-features-sql/src/jvmTest/kotlin/ai/koog/agents/features/sql/providers/H2PersistencyStorageProviderTest.kt +++ 
b/agents/agents-features/agents-features-sql/src/jvmTest/kotlin/ai/koog/agents/features/sql/providers/H2PersistencyStorageProviderTest.kt @@ -4,6 +4,7 @@ import ai.koog.agents.snapshot.feature.AgentCheckpointData import ai.koog.prompt.message.Message import ai.koog.prompt.message.RequestMetaInfo import ai.koog.prompt.message.ResponseMetaInfo +import ai.koog.test.utils.DockerAvailableCondition import kotlinx.coroutines.delay import kotlinx.coroutines.runBlocking import kotlinx.datetime.Clock @@ -11,19 +12,19 @@ import kotlinx.serialization.json.JsonPrimitive import org.junit.jupiter.api.Test import org.junit.jupiter.api.TestInstance import org.junit.jupiter.api.TestInstance.Lifecycle -import org.junit.jupiter.api.condition.EnabledOnOs -import org.junit.jupiter.api.condition.OS +import org.junit.jupiter.api.extension.ExtendWith import kotlin.test.assertEquals import kotlin.test.assertNotNull import kotlin.test.assertNull @TestInstance(Lifecycle.PER_CLASS) -@EnabledOnOs(OS.LINUX) -class H2PersistencyStorageProviderTest { +@ExtendWith(DockerAvailableCondition::class) +class H2PersistenceStorageProviderTest { - private fun provider(ttlSeconds: Long? = null): H2PersistencyStorageProvider { - return H2PersistencyStorageProvider.inMemory( - persistenceId = "h2-agent", + private val agentId = "h2-agent" + + private fun provider(ttlSeconds: Long? 
= null): H2PersistenceStorageProvider { + return H2PersistenceStorageProvider.inMemory( databaseName = "h2_test_db", tableName = "agent_checkpoints_test", ttlSeconds = ttlSeconds @@ -36,38 +37,38 @@ class H2PersistencyStorageProviderTest { p.migrate() // empty - assertNull(p.getLatestCheckpoint()) - assertEquals(0, p.getCheckpointCount()) + assertNull(p.getLatestCheckpoint(agentId)) + assertEquals(0, p.getCheckpointCount(agentId)) // save val cp1 = createTestCheckpoint("cp-1") - p.saveCheckpoint(cp1) + p.saveCheckpoint(agentId, cp1) // read - val latest1 = p.getLatestCheckpoint() + val latest1 = p.getLatestCheckpoint(agentId) assertNotNull(latest1) assertEquals("cp-1", latest1.checkpointId) - assertEquals(1, p.getCheckpoints().size) - assertEquals(1, p.getCheckpointCount()) + assertEquals(1, p.getCheckpoints(agentId).size) + assertEquals(1, p.getCheckpointCount(agentId)) // upsert same id should be idempotent (no duplicates due PK) - p.saveCheckpoint(cp1) - assertEquals(1, p.getCheckpoints().size) + p.saveCheckpoint(agentId, cp1) + assertEquals(1, p.getCheckpoints(agentId).size) // insert second val cp2 = createTestCheckpoint("cp-2") - p.saveCheckpoint(cp2) - val all = p.getCheckpoints() + p.saveCheckpoint(agentId, cp2) + val all = p.getCheckpoints(agentId) assertEquals(listOf("cp-1", "cp-2"), all.map { it.checkpointId }) - assertEquals("cp-2", p.getLatestCheckpoint()!!.checkpointId) + assertEquals("cp-2", p.getLatestCheckpoint(agentId)!!.checkpointId) // delete single - p.deleteCheckpoint("cp-1") - assertEquals(listOf("cp-2"), p.getCheckpoints().map { it.checkpointId }) + p.deleteCheckpoint(agentId, "cp-1") + assertEquals(listOf("cp-2"), p.getCheckpoints(agentId).map { it.checkpointId }) // delete all - p.deleteAllCheckpoints() - assertEquals(0, p.getCheckpointCount()) + p.deleteAllCheckpoints(agentId) + assertEquals(0, p.getCheckpointCount(agentId)) } @Test @@ -75,16 +76,16 @@ class H2PersistencyStorageProviderTest { val p = provider(ttlSeconds = 1) p.migrate() - 
p.saveCheckpoint(createTestCheckpoint("will-expire")) - assertEquals(1, p.getCheckpointCount()) + p.saveCheckpoint(agentId, createTestCheckpoint("will-expire")) + assertEquals(1, p.getCheckpointCount(agentId)) // Wait slightly over 1s to ensure ttl passes delay(1100) // Force cleanup directly to avoid interval throttling p.cleanupExpired() - assertEquals(0, p.getCheckpointCount()) - assertNull(p.getLatestCheckpoint()) + assertEquals(0, p.getCheckpointCount(agentId)) + assertNull(p.getLatestCheckpoint(agentId)) } private fun createTestCheckpoint(id: String): AgentCheckpointData { diff --git a/agents/agents-features/agents-features-sql/src/jvmTest/kotlin/ai/koog/agents/features/sql/providers/MySQLPersistencyStorageProviderTest.kt b/agents/agents-features/agents-features-sql/src/jvmTest/kotlin/ai/koog/agents/features/sql/providers/MySQLPersistencyStorageProviderTest.kt index 893da5a1d6..b38623760c 100644 --- a/agents/agents-features/agents-features-sql/src/jvmTest/kotlin/ai/koog/agents/features/sql/providers/MySQLPersistencyStorageProviderTest.kt +++ b/agents/agents-features/agents-features-sql/src/jvmTest/kotlin/ai/koog/agents/features/sql/providers/MySQLPersistencyStorageProviderTest.kt @@ -4,6 +4,7 @@ import ai.koog.agents.snapshot.feature.AgentCheckpointData import ai.koog.prompt.message.Message import ai.koog.prompt.message.RequestMetaInfo import ai.koog.prompt.message.ResponseMetaInfo +import ai.koog.test.utils.DockerAvailableCondition import kotlinx.coroutines.delay import kotlinx.coroutines.runBlocking import kotlinx.datetime.Clock @@ -14,8 +15,7 @@ import org.junit.jupiter.api.BeforeAll import org.junit.jupiter.api.Test import org.junit.jupiter.api.TestInstance import org.junit.jupiter.api.TestInstance.Lifecycle -import org.junit.jupiter.api.condition.EnabledOnOs -import org.junit.jupiter.api.condition.OS +import org.junit.jupiter.api.extension.ExtendWith import org.testcontainers.containers.MySQLContainer import org.testcontainers.utility.DockerImageName 
import kotlin.test.assertEquals @@ -23,8 +23,10 @@ import kotlin.test.assertNotNull import kotlin.test.assertNull @TestInstance(Lifecycle.PER_CLASS) -@EnabledOnOs(OS.LINUX) -class MySQLPersistencyStorageProviderTest { +@ExtendWith(DockerAvailableCondition::class) +class MySQLPersistenceStorageProviderTest { + + private val agentId = "mysql-agent" private lateinit var mysql: MySQLContainer<*> @@ -42,15 +44,14 @@ class MySQLPersistencyStorageProviderTest { mysql.stop() } - private fun provider(ttlSeconds: Long? = null): MySQLPersistencyStorageProvider { + private fun provider(ttlSeconds: Long? = null): MySQLPersistenceStorageProvider { val db: Database = Database.connect( url = mysql.jdbcUrl + "?useSSL=false&allowPublicKeyRetrieval=true&serverTimezone=UTC", driver = "com.mysql.cj.jdbc.Driver", user = mysql.username, password = mysql.password ) - return MySQLPersistencyStorageProvider( - persistenceId = "mysql-agent", + return MySQLPersistenceStorageProvider( database = db, tableName = "agent_checkpoints_test", ttlSeconds = ttlSeconds @@ -63,38 +64,38 @@ class MySQLPersistencyStorageProviderTest { p.migrate() // empty - assertNull(p.getLatestCheckpoint()) - assertEquals(0, p.getCheckpointCount()) + assertNull(p.getLatestCheckpoint(agentId)) + assertEquals(0, p.getCheckpointCount(agentId)) // save val cp1 = createTestCheckpoint("cp-1") - p.saveCheckpoint(cp1) + p.saveCheckpoint(agentId, cp1) // read - val latest1 = p.getLatestCheckpoint() + val latest1 = p.getLatestCheckpoint(agentId) assertNotNull(latest1) assertEquals("cp-1", latest1.checkpointId) - assertEquals(1, p.getCheckpoints().size) - assertEquals(1, p.getCheckpointCount()) + assertEquals(1, p.getCheckpoints(agentId).size) + assertEquals(1, p.getCheckpointCount(agentId)) // upsert same id should be idempotent (no duplicates due PK) - p.saveCheckpoint(cp1) - assertEquals(1, p.getCheckpoints().size) + p.saveCheckpoint(agentId, cp1) + assertEquals(1, p.getCheckpoints(agentId).size) // insert second val cp2 = 
createTestCheckpoint("cp-2") - p.saveCheckpoint(cp2) - val all = p.getCheckpoints() + p.saveCheckpoint(agentId, cp2) + val all = p.getCheckpoints(agentId) assertEquals(listOf("cp-1", "cp-2"), all.map { it.checkpointId }) - assertEquals("cp-2", p.getLatestCheckpoint()!!.checkpointId) + assertEquals("cp-2", p.getLatestCheckpoint(agentId)!!.checkpointId) // delete single - p.deleteCheckpoint("cp-1") - assertEquals(listOf("cp-2"), p.getCheckpoints().map { it.checkpointId }) + p.deleteCheckpoint(agentId, "cp-1") + assertEquals(listOf("cp-2"), p.getCheckpoints(agentId).map { it.checkpointId }) // delete all - p.deleteAllCheckpoints() - assertEquals(0, p.getCheckpointCount()) + p.deleteAllCheckpoints(agentId) + assertEquals(0, p.getCheckpointCount(agentId)) } @Test @@ -102,16 +103,16 @@ class MySQLPersistencyStorageProviderTest { val p = provider(ttlSeconds = 1) p.migrate() - p.saveCheckpoint(createTestCheckpoint("will-expire")) - assertEquals(1, p.getCheckpointCount()) + p.saveCheckpoint(agentId, createTestCheckpoint("will-expire")) + assertEquals(1, p.getCheckpointCount(agentId)) // Sleep slightly over 1s to ensure ttl passes delay(1100) // force cleanup p.cleanupExpired() - assertEquals(0, p.getCheckpointCount()) - assertNull(p.getLatestCheckpoint()) + assertEquals(0, p.getCheckpointCount(agentId)) + assertNull(p.getLatestCheckpoint(agentId)) } private fun createTestCheckpoint(id: String): AgentCheckpointData { diff --git a/agents/agents-features/agents-features-sql/src/jvmTest/kotlin/ai/koog/agents/features/sql/providers/PostgresPersistencyStorageProviderTest.kt b/agents/agents-features/agents-features-sql/src/jvmTest/kotlin/ai/koog/agents/features/sql/providers/PostgresPersistencyStorageProviderTest.kt index f189f84d9b..a9f58713f6 100644 --- a/agents/agents-features/agents-features-sql/src/jvmTest/kotlin/ai/koog/agents/features/sql/providers/PostgresPersistencyStorageProviderTest.kt +++ 
b/agents/agents-features/agents-features-sql/src/jvmTest/kotlin/ai/koog/agents/features/sql/providers/PostgresPersistencyStorageProviderTest.kt @@ -4,6 +4,7 @@ import ai.koog.agents.snapshot.feature.AgentCheckpointData import ai.koog.prompt.message.Message import ai.koog.prompt.message.RequestMetaInfo import ai.koog.prompt.message.ResponseMetaInfo +import ai.koog.test.utils.DockerAvailableCondition import kotlinx.coroutines.runBlocking import kotlinx.datetime.Clock import kotlinx.serialization.json.JsonPrimitive @@ -13,8 +14,7 @@ import org.junit.jupiter.api.BeforeAll import org.junit.jupiter.api.Test import org.junit.jupiter.api.TestInstance import org.junit.jupiter.api.TestInstance.Lifecycle -import org.junit.jupiter.api.condition.EnabledOnOs -import org.junit.jupiter.api.condition.OS +import org.junit.jupiter.api.extension.ExtendWith import org.testcontainers.containers.PostgreSQLContainer import org.testcontainers.utility.DockerImageName import kotlin.test.assertEquals @@ -22,8 +22,10 @@ import kotlin.test.assertNotNull import kotlin.test.assertNull @TestInstance(Lifecycle.PER_CLASS) -@EnabledOnOs(OS.LINUX) -class PostgresPersistencyStorageProviderTest { +@ExtendWith(DockerAvailableCondition::class) +class PostgresPersistenceStorageProviderTest { + + private val agentId = "pg-agent" private lateinit var postgres: PostgreSQLContainer<*> @@ -41,15 +43,14 @@ class PostgresPersistencyStorageProviderTest { postgres.stop() } - private fun provider(ttlSeconds: Long? = null): PostgresPersistencyStorageProvider { + private fun provider(ttlSeconds: Long? 
= null): PostgresPersistenceStorageProvider { val db: Database = Database.connect( url = postgres.jdbcUrl, driver = "org.postgresql.Driver", user = postgres.username, password = postgres.password ) - return PostgresPersistencyStorageProvider( - persistenceId = "pg-agent", + return PostgresPersistenceStorageProvider( database = db, tableName = "agent_checkpoints_test", ttlSeconds = ttlSeconds @@ -62,38 +63,38 @@ class PostgresPersistencyStorageProviderTest { p.migrate() // empty - assertNull(p.getLatestCheckpoint()) - assertEquals(0, p.getCheckpointCount()) + assertNull(p.getLatestCheckpoint(agentId)) + assertEquals(0, p.getCheckpointCount(agentId)) // save val cp1 = createTestCheckpoint("cp-1") - p.saveCheckpoint(cp1) + p.saveCheckpoint(agentId, cp1) // read - val latest1 = p.getLatestCheckpoint() + val latest1 = p.getLatestCheckpoint(agentId) assertNotNull(latest1) assertEquals("cp-1", latest1.checkpointId) - assertEquals(1, p.getCheckpoints().size) - assertEquals(1, p.getCheckpointCount()) + assertEquals(1, p.getCheckpoints(agentId).size) + assertEquals(1, p.getCheckpointCount(agentId)) // upsert same id should be idempotent (no duplicates due PK) - p.saveCheckpoint(cp1) - assertEquals(1, p.getCheckpoints().size) + p.saveCheckpoint(agentId, cp1) + assertEquals(1, p.getCheckpoints(agentId).size) // insert second val cp2 = createTestCheckpoint("cp-2") - p.saveCheckpoint(cp2) - val all = p.getCheckpoints() + p.saveCheckpoint(agentId, cp2) + val all = p.getCheckpoints(agentId) assertEquals(listOf("cp-1", "cp-2"), all.map { it.checkpointId }) - assertEquals("cp-2", p.getLatestCheckpoint()!!.checkpointId) + assertEquals("cp-2", p.getLatestCheckpoint(agentId)!!.checkpointId) // delete single - p.deleteCheckpoint("cp-1") - assertEquals(listOf("cp-2"), p.getCheckpoints().map { it.checkpointId }) + p.deleteCheckpoint(agentId, "cp-1") + assertEquals(listOf("cp-2"), p.getCheckpoints(agentId).map { it.checkpointId }) // delete all - p.deleteAllCheckpoints() - assertEquals(0, 
p.getCheckpointCount()) + p.deleteAllCheckpoints(agentId) + assertEquals(0, p.getCheckpointCount(agentId)) } @Test @@ -101,16 +102,16 @@ class PostgresPersistencyStorageProviderTest { val p = provider(ttlSeconds = 1) p.migrate() - p.saveCheckpoint(createTestCheckpoint("will-expire")) - assertEquals(1, p.getCheckpointCount()) + p.saveCheckpoint(agentId, createTestCheckpoint("will-expire")) + assertEquals(1, p.getCheckpointCount(agentId)) // Force cleanup by calling cleanupExpired directly to avoid time-based throttle // Sleep slightly over 1s to ensure ttl passes kotlinx.coroutines.delay(1100) p.cleanupExpired() - assertEquals(0, p.getCheckpointCount()) - assertNull(p.getLatestCheckpoint()) + assertEquals(0, p.getCheckpointCount(agentId)) + assertNull(p.getLatestCheckpoint(agentId)) } private fun createTestCheckpoint(id: String): AgentCheckpointData { diff --git a/agents/agents-features/agents-features-sql/src/jvmTest/kotlin/ai/koog/agents/features/sql/providers/SQLPersistenceProvidersTest.kt b/agents/agents-features/agents-features-sql/src/jvmTest/kotlin/ai/koog/agents/features/sql/providers/SQLPersistenceProvidersTest.kt index 1a505d8a9f..8c753b454c 100644 --- a/agents/agents-features/agents-features-sql/src/jvmTest/kotlin/ai/koog/agents/features/sql/providers/SQLPersistenceProvidersTest.kt +++ b/agents/agents-features/agents-features-sql/src/jvmTest/kotlin/ai/koog/agents/features/sql/providers/SQLPersistenceProvidersTest.kt @@ -24,8 +24,8 @@ class SQLPersistenceProvidersTest { @Test fun `test save and retrieve checkpoint`() = runBlocking { - val provider = H2PersistencyStorageProvider.inMemory( - persistenceId = "test-agent", + val agentId = "test-agent" + val provider = H2PersistenceStorageProvider.inMemory( databaseName = "test_db" ) @@ -33,10 +33,10 @@ class SQLPersistenceProvidersTest { // Create and save checkpoint val checkpoint = createTestCheckpoint("test-1") - provider.saveCheckpoint(checkpoint) + provider.saveCheckpoint(agentId, checkpoint) // Retrieve 
and verify - val retrieved = provider.getLatestCheckpoint() + val retrieved = provider.getLatestCheckpoint(agentId) assertNotNull(retrieved) assertEquals(checkpoint.checkpointId, retrieved.checkpointId) assertEquals(checkpoint.nodeId, retrieved.nodeId) @@ -45,37 +45,38 @@ class SQLPersistenceProvidersTest { @Test fun `test multiple checkpoints and ordering`() = runBlocking { - val provider = H2PersistencyStorageProvider.inMemory( - persistenceId = "test-agent", + val agentId = "test-agent" + val provider = H2PersistenceStorageProvider.inMemory( databaseName = "test_ordering" ) provider.migrate() // Save multiple checkpoints - provider.saveCheckpoint(createTestCheckpoint("checkpoint-1")) - provider.saveCheckpoint(createTestCheckpoint("checkpoint-2")) - provider.saveCheckpoint(createTestCheckpoint("checkpoint-3")) + provider.saveCheckpoint(agentId, createTestCheckpoint("checkpoint-1")) + provider.saveCheckpoint(agentId, createTestCheckpoint("checkpoint-2")) + provider.saveCheckpoint(agentId, createTestCheckpoint("checkpoint-3")) // Verify count and ordering - val allCheckpoints = provider.getCheckpoints() + val allCheckpoints = provider.getCheckpoints(agentId) assertEquals(3, allCheckpoints.size) assertEquals("checkpoint-1", allCheckpoints[0].checkpointId) assertEquals("checkpoint-3", allCheckpoints[2].checkpointId) // Verify latest - val latest = provider.getLatestCheckpoint() + val latest = provider.getLatestCheckpoint(agentId) assertEquals("checkpoint-3", latest?.checkpointId) } @Test fun `test persistence ID isolation`() = runBlocking { - val provider1 = H2PersistencyStorageProvider.inMemory( - persistenceId = "agent-1", + val agentId = "test-agent" + val agentId2 = "test-agent2" + + val provider1 = H2PersistenceStorageProvider.inMemory( databaseName = "shared_db" ) - val provider2 = H2PersistencyStorageProvider.inMemory( - persistenceId = "agent-2", + val provider2 = H2PersistenceStorageProvider.inMemory( databaseName = "shared_db" // Same database ) @@ -83,12 
+84,12 @@ class SQLPersistenceProvidersTest { provider2.migrate() // Save to different agents - provider1.saveCheckpoint(createTestCheckpoint("agent1-data")) - provider2.saveCheckpoint(createTestCheckpoint("agent2-data")) + provider1.saveCheckpoint(agentId, createTestCheckpoint("agent1-data")) + provider2.saveCheckpoint(agentId2, createTestCheckpoint("agent2-data")) // Verify isolation - val agent1Checkpoints = provider1.getCheckpoints() - val agent2Checkpoints = provider2.getCheckpoints() + val agent1Checkpoints = provider1.getCheckpoints(agentId) + val agent2Checkpoints = provider2.getCheckpoints(agentId2) assertEquals(1, agent1Checkpoints.size) assertEquals(1, agent2Checkpoints.size) @@ -98,8 +99,9 @@ class SQLPersistenceProvidersTest { @Test fun `test TTL expiration`() = runBlocking { - val provider = H2PersistencyStorageProvider.inMemory( - persistenceId = "ttl-test", + val agentId = "test-agent" + + val provider = H2PersistenceStorageProvider.inMemory( databaseName = "ttl_db", ttlSeconds = 1 // 1 second TTL ) @@ -107,27 +109,26 @@ class SQLPersistenceProvidersTest { provider.migrate() // Save checkpoint - provider.saveCheckpoint(createTestCheckpoint("expire-soon")) - assertEquals(1, provider.getCheckpointCount()) + provider.saveCheckpoint(agentId, createTestCheckpoint("expire-soon")) + assertEquals(1, provider.getCheckpointCount(agentId)) // Wait for expiration delay(1500) provider.conditionalCleanup() // Should be cleaned up on next operation - val afterExpiry = provider.getLatestCheckpoint() + val afterExpiry = provider.getLatestCheckpoint(agentId) assertNull(afterExpiry) - assertEquals(0, provider.getCheckpointCount()) + assertEquals(0, provider.getCheckpointCount(agentId)) } @Test fun `verify all providers can be instantiated`() { // H2 - assertNotNull(H2PersistencyStorageProvider.inMemory("test", "test_db")) + assertNotNull(H2PersistenceStorageProvider.inMemory("test", "test_db")) // PostgreSQL assertNotNull( - PostgresPersistencyStorageProvider( - 
persistenceId = "test", + PostgresPersistenceStorageProvider( database = Database.connect( url = "jdbc:postgresql://localhost:5432/test", driver = "org.postgresql.Driver", @@ -139,8 +140,7 @@ class SQLPersistenceProvidersTest { // MySQL assertNotNull( - MySQLPersistencyStorageProvider( - persistenceId = "test", + MySQLPersistenceStorageProvider( database = Database.connect( url = "jdbc:mysql://localhost:3306/test", driver = "com.mysql.cj.jdbc.Driver", diff --git a/agents/agents-features/agents-features-tokenizer/src/commonMain/kotlin/ai/koog/agents/features/tokenizer/feature/MessageTokenizer.kt b/agents/agents-features/agents-features-tokenizer/src/commonMain/kotlin/ai/koog/agents/features/tokenizer/feature/MessageTokenizer.kt index 222a2c6cea..ab47e6ecea 100644 --- a/agents/agents-features/agents-features-tokenizer/src/commonMain/kotlin/ai/koog/agents/features/tokenizer/feature/MessageTokenizer.kt +++ b/agents/agents-features/agents-features-tokenizer/src/commonMain/kotlin/ai/koog/agents/features/tokenizer/feature/MessageTokenizer.kt @@ -8,9 +8,10 @@ import ai.koog.agents.core.feature.AIAgentGraphPipeline import ai.koog.agents.core.feature.AIAgentNonGraphFeature import ai.koog.agents.core.feature.AIAgentNonGraphPipeline import ai.koog.agents.core.feature.config.FeatureConfig -import ai.koog.prompt.dsl.Prompt -import ai.koog.prompt.message.Message +import ai.koog.prompt.tokenizer.CachingTokenizer import ai.koog.prompt.tokenizer.NoTokenizer +import ai.koog.prompt.tokenizer.OnDemandTokenizer +import ai.koog.prompt.tokenizer.PromptTokenizer import ai.koog.prompt.tokenizer.Tokenizer /** @@ -49,111 +50,6 @@ public class MessageTokenizerConfig : FeatureConfig() { public var enableCaching: Boolean = true } -/** - * An interface that provides utilities for tokenizing and calculating token usage in messages and prompts. - */ -public interface PromptTokenizer { - /** - * Calculates the number of tokens required for a given message. 
- * - * @param message The message for which the token count should be determined. - * @return The number of tokens required to encode the message. - */ - public fun tokenCountFor(message: Message): Int - - /** - * Calculates the total number of tokens spent in a given prompt. - * - * @param prompt The prompt for which the total tokens spent need to be calculated. - * @return The total number of tokens spent as an integer. - */ - public fun tokenCountFor(prompt: Prompt): Int -} - -/** - * An implementation of the [PromptTokenizer] interface that delegates token counting - * to an instance of the [Tokenizer] interface. The class provides methods to estimate - * the token count for individual messages and for the entirety of a prompt. - * - * This is useful in contexts where token-based costs or limitations are significant, - * such as when interacting with large language models (LLMs). - * - * @property tokenizer The [Tokenizer] instance used for token counting. - */ -public class OnDemandTokenizer(private val tokenizer: Tokenizer) : PromptTokenizer { - - /** - * Computes the number of tokens in a given message. - * - * @param message The message for which the token count needs to be calculated. - * The content of the message is analyzed to estimate the token count. - * @return The estimated number of tokens in the message content. - */ - public override fun tokenCountFor(message: Message): Int = tokenizer.countTokens(message.content) - - /** - * Calculates the total number of tokens spent for the given prompt based on its messages. - * - * @param prompt The `Prompt` instance containing the list of messages for which the total token count will be calculated. - * @return The total number of tokens across all messages in the prompt. 
- */ - public override fun tokenCountFor(prompt: Prompt): Int = prompt.messages.sumOf(::tokenCountFor) -} - -/** - * A caching implementation of the `PromptTokenizer` interface that optimizes token counting - * by storing previously computed token counts for messages. This reduces redundant computations - * when the same message is processed multiple times. - * - * @constructor Creates an instance of `CachingTokenizer` with a provided `Tokenizer` instance - * that performs the actual token counting. - * @property tokenizer The underlying `Tokenizer` used for counting tokens in the message content. - */ -public class CachingTokenizer(private val tokenizer: Tokenizer) : PromptTokenizer { - /** - * A cache that maps a `Message` to its corresponding token count. - * - * This is used to store the results of token computations for reuse, optimizing performance - * by avoiding repeated invocations of the token counting process on the same message content. - * - * Token counts are computed lazily and stored in the cache when requested via the `tokensFor` - * method. This cache can be cleared using the `clearCache` method. - */ - internal val cache = mutableMapOf() - - /** - * Retrieves the number of tokens contained in the content of the given message. - * This method utilizes caching to improve performance, storing previously - * computed token counts and reusing them for identical messages. - * - * @param message The message whose content's token count is to be retrieved - * @return The number of tokens in the content of the message - */ - public override fun tokenCountFor(message: Message): Int = cache.getOrPut(message) { - tokenizer.countTokens(message.content) - } - - /** - * Calculates the total number of tokens spent on the given prompt by summing the token usage - * of all messages associated with the prompt. - * - * @param prompt The prompt containing the list of messages whose token usage will be calculated. 
- * @return The total number of tokens spent across all messages in the provided prompt. - */ - public override fun tokenCountFor(prompt: Prompt): Int = prompt.messages.sumOf(::tokenCountFor) - - /** - * Clears all cached token counts from the internal cache. - * - * This method is useful when the state of the cached data becomes invalid - * or needs resetting. After calling this, any subsequent token count - * calculations will be recomputed rather than retrieved from the cache. - */ - public fun clearCache() { - cache.clear() - } -} - /** * The [MessageTokenizer] feature is responsible for handling tokenization of messages using a provided [Tokenizer] * implementation. It serves as a feature that can be installed into an `AIAgentPipeline`. The tokenizer behavior can be configured diff --git a/agents/agents-features/agents-features-tokenizer/src/jvmTest/kotlin/ai/koog/agents/features/tokenizer/feature/MessageTokenizerTest.kt b/agents/agents-features/agents-features-tokenizer/src/jvmTest/kotlin/ai/koog/agents/features/tokenizer/feature/MessageTokenizerTest.kt index 8fa8f7f2e6..037452f5f1 100644 --- a/agents/agents-features/agents-features-tokenizer/src/jvmTest/kotlin/ai/koog/agents/features/tokenizer/feature/MessageTokenizerTest.kt +++ b/agents/agents-features/agents-features-tokenizer/src/jvmTest/kotlin/ai/koog/agents/features/tokenizer/feature/MessageTokenizerTest.kt @@ -15,16 +15,10 @@ import ai.koog.agents.testing.tools.getMockExecutor import ai.koog.agents.testing.tools.mockLLMAnswer import ai.koog.prompt.dsl.prompt import ai.koog.prompt.llm.OllamaModels -import ai.koog.prompt.message.Message -import ai.koog.prompt.message.RequestMetaInfo -import ai.koog.prompt.message.ResponseMetaInfo import ai.koog.prompt.tokenizer.Tokenizer import kotlinx.coroutines.runBlocking -import kotlinx.coroutines.test.runTest -import kotlinx.datetime.Clock import kotlin.test.Test import kotlin.test.assertEquals -import kotlin.test.assertTrue /** * Test for the MessageTokenizer 
feature. @@ -72,84 +66,6 @@ class MessageTokenizerTest { } } - @Test - fun testPromptTokenizer() = runTest { - // Create a mock tokenizer to track token usage - val mockTokenizer = MockTokenizer() - - // Create a prompt tokenizer with our mock tokenizer - val promptTokenizer = OnDemandTokenizer(mockTokenizer) - - // Create a prompt with some messages - val testPrompt = prompt("test-prompt") { - system("You are a helpful assistant.") - user("What is the capital of France?") - assistant("Paris is the capital of France.") - } - - // Count tokens in the prompt - val totalTokens = promptTokenizer.tokenCountFor(testPrompt) - - // Verify that tokens were counted - assertTrue(totalTokens > 0, "Total tokens should be greater than 0") - - // Verify that the tokenizer was used and counted tokens - assertTrue(mockTokenizer.totalTokens > 0, "Tokenizer should have counted tokens") - - // Verify that the total tokens match what we expect - assertEquals(totalTokens, mockTokenizer.totalTokens, "Total tokens should match the tokenizer's count") - - // Print the total tokens spent - println("[DEBUG_LOG] Total tokens spent: ${mockTokenizer.totalTokens}") - - val requestMetainfo = RequestMetaInfo.create(Clock.System) - val responseMetainfo = ResponseMetaInfo.create(Clock.System) - // Count tokens for individual messages - val systemTokens = promptTokenizer.tokenCountFor( - Message.System("You are a helpful assistant.", requestMetainfo) - ) - val userTokens = promptTokenizer.tokenCountFor(Message.User("What is the capital of France?", requestMetainfo)) - val assistantTokens = promptTokenizer.tokenCountFor( - Message.Assistant("Paris is the capital of France.", responseMetainfo) - ) - - // Print token counts for each message - println("[DEBUG_LOG] System message tokens: $systemTokens") - println("[DEBUG_LOG] User message tokens: $userTokens") - println("[DEBUG_LOG] Assistant message tokens: $assistantTokens") - - // Verify that the sum of individual message tokens equals the total - val 
sumOfMessageTokens = systemTokens + userTokens + assistantTokens - assertEquals(sumOfMessageTokens, totalTokens, "Sum of message tokens should equal total tokens") - } - - @Test - fun testCachingPromptTokenizer() = runTest { - // Create a mock tokenizer to track token usage - val mockTokenizer = MockTokenizer() - - // Create a prompt tokenizer with our mock tokenizer - val promptTokenizer = CachingTokenizer(mockTokenizer) - - // Create a prompt with some messages - val testPrompt = prompt("test-prompt") { - system("You are a helpful assistant.") - user("What is the capital of France?") - assistant("Paris is the capital of France.") - } - - assertEquals(0, promptTokenizer.cache.size) - promptTokenizer.tokenCountFor(testPrompt) - assertEquals(3, promptTokenizer.cache.size) - promptTokenizer.clearCache() - assertEquals(0, promptTokenizer.cache.size) - promptTokenizer.tokenCountFor(testPrompt.messages[1]) - promptTokenizer.tokenCountFor(testPrompt.messages[2]) - assertEquals(2, promptTokenizer.cache.size) - promptTokenizer.tokenCountFor(testPrompt) - assertEquals(3, promptTokenizer.cache.size) - } - @Test fun testTokenizerInAgents() { val testToolRegistry = ToolRegistry { diff --git a/agents/agents-features/agents-features-trace/Module.md b/agents/agents-features/agents-features-trace/Module.md index 1b3e53f0a9..e0ff3b17cd 100644 --- a/agents/agents-features/agents-features-trace/Module.md +++ b/agents/agents-features/agents-features-trace/Module.md @@ -7,7 +7,7 @@ Provides implementation of the `Tracing` feature for AI Agents The Tracing feature captures comprehensive data about agent execution, including: - All LLM calls and their responses - Prompts sent to LLMs -- Tool calls, arguments, and results +- Tool executions, arguments, and results - Graph node visits and execution flow - Agent lifecycle events (creation, start, finish, errors) - Strategy execution events @@ -40,8 +40,8 @@ val agent = AIAgent( // Optionally filter messages fileWriter.setMessageFilter { 
message -> - // Only trace LLM calls and tool calls - message is BeforeLLMCallEvent || message is ToolCallEvent + // Only trace LLM calls and tool executions + message is LLMCallStartingEvent || message is ToolExecutionStartingEvent } } } @@ -70,15 +70,14 @@ val agent = AIAgent( Here's an example of the logs produced by tracing: ``` -AgentCreateEvent (strategy name: my-agent-strategy) -AgentStartedEvent (strategy name: my-agent-strategy) -StrategyStartEvent (strategy name: my-agent-strategy) -NodeExecutionStartEvent (node: definePrompt, input: user query) -NodeExecutionEndEvent (node: definePrompt, input: user query, output: processed query) -BeforeLLMCallEvent (prompt: Please analyze the following code...) -AfterLLMCallEvent (response: I've analyzed the code and found...) -ToolCallEvent (tool: readFile, tool args: {"path": "src/main.py"}) -ToolCallResultEvent (tool: readFile, tool args: {"path": "src/main.py"}, result: "def main():...") -StrategyFinishedEvent (strategy name: my-agent-strategy, result: Success) -AgentFinishedEvent (strategy name: my-agent-strategy, result: Success) +AgentStartingEvent (strategy name: my-agent-strategy) +GraphStrategyStartingEvent (strategy name: my-agent-strategy) +NodeExecutionStartingEvent (node: definePrompt, input: user query) +NodeExecutionCompletedEvent (node: definePrompt, input: user query, output: processed query) +LLMCallStartingEvent (prompt: Please analyze the following code...) +LLMCallCompletedEvent (response: I've analyzed the code and found...) 
+ToolExecutionStartingEvent (tool: readFile, tool args: {"path": "src/main.py"}) +ToolExecutionCompletedEvent (tool: readFile, tool args: {"path": "src/main.py"}, result: "def main():...") +StrategyCompletedEvent (strategy name: my-agent-strategy, result: Success) +AgentCompletedEvent (strategy name: my-agent-strategy, result: Success) ``` diff --git a/agents/agents-features/agents-features-trace/src/commonMain/kotlin/ai/koog/agents/features/tracing/feature/TraceFeatureConfig.kt b/agents/agents-features/agents-features-trace/src/commonMain/kotlin/ai/koog/agents/features/tracing/feature/TraceFeatureConfig.kt index 0da1d108f1..253727c19d 100644 --- a/agents/agents-features/agents-features-trace/src/commonMain/kotlin/ai/koog/agents/features/tracing/feature/TraceFeatureConfig.kt +++ b/agents/agents-features/agents-features-trace/src/commonMain/kotlin/ai/koog/agents/features/tracing/feature/TraceFeatureConfig.kt @@ -19,4 +19,4 @@ import ai.koog.agents.core.feature.config.FeatureConfig * } * ``` */ -public class TraceFeatureConfig() : FeatureConfig() +public class TraceFeatureConfig : FeatureConfig() diff --git a/agents/agents-features/agents-features-trace/src/commonMain/kotlin/ai/koog/agents/features/tracing/feature/Tracing.kt b/agents/agents-features/agents-features-trace/src/commonMain/kotlin/ai/koog/agents/features/tracing/feature/Tracing.kt index 726303c23b..3f8dc2a143 100644 --- a/agents/agents-features/agents-features-trace/src/commonMain/kotlin/ai/koog/agents/features/tracing/feature/Tracing.kt +++ b/agents/agents-features/agents-features-trace/src/commonMain/kotlin/ai/koog/agents/features/tracing/feature/Tracing.kt @@ -16,13 +16,17 @@ import ai.koog.agents.core.feature.model.events.AgentStartingEvent import ai.koog.agents.core.feature.model.events.GraphStrategyStartingEvent import ai.koog.agents.core.feature.model.events.LLMCallCompletedEvent import ai.koog.agents.core.feature.model.events.LLMCallStartingEvent +import 
ai.koog.agents.core.feature.model.events.LLMStreamingCompletedEvent +import ai.koog.agents.core.feature.model.events.LLMStreamingFailedEvent +import ai.koog.agents.core.feature.model.events.LLMStreamingFrameReceivedEvent +import ai.koog.agents.core.feature.model.events.LLMStreamingStartingEvent import ai.koog.agents.core.feature.model.events.NodeExecutionCompletedEvent import ai.koog.agents.core.feature.model.events.NodeExecutionFailedEvent import ai.koog.agents.core.feature.model.events.NodeExecutionStartingEvent import ai.koog.agents.core.feature.model.events.StrategyCompletedEvent -import ai.koog.agents.core.feature.model.events.ToolExecutionCompletedEvent -import ai.koog.agents.core.feature.model.events.ToolExecutionFailedEvent -import ai.koog.agents.core.feature.model.events.ToolExecutionStartingEvent +import ai.koog.agents.core.feature.model.events.ToolCallCompletedEvent +import ai.koog.agents.core.feature.model.events.ToolCallFailedEvent +import ai.koog.agents.core.feature.model.events.ToolCallStartingEvent import ai.koog.agents.core.feature.model.events.ToolValidationFailedEvent import ai.koog.agents.core.feature.model.events.startNodeToGraph import ai.koog.agents.core.feature.model.toAgentError @@ -255,14 +259,58 @@ public class Tracing { //endregion Intercept LLM Call Events + //region Intercept LLM Streaming Events + + pipeline.interceptLLMStreamingStarting(interceptContext) intercept@{ eventContext -> + val event = LLMStreamingStartingEvent( + runId = eventContext.runId, + prompt = eventContext.prompt, + model = eventContext.model.eventString, + tools = eventContext.tools.map { it.name }, + timestamp = pipeline.clock.now().toEpochMilliseconds() + ) + processMessage(config, event) + } + + pipeline.interceptLLMStreamingCompleted(interceptContext) intercept@{ eventContext -> + val event = LLMStreamingCompletedEvent( + runId = eventContext.runId, + prompt = eventContext.prompt, + model = eventContext.model.eventString, + tools = eventContext.tools.map { 
it.name }, + timestamp = pipeline.clock.now().toEpochMilliseconds() + ) + processMessage(config, event) + } + + pipeline.interceptLLMStreamingFrameReceived(interceptContext) intercept@{ eventContext -> + val event = LLMStreamingFrameReceivedEvent( + runId = eventContext.runId, + frame = eventContext.streamFrame, + timestamp = pipeline.clock.now().toEpochMilliseconds() + ) + processMessage(config, event) + } + + pipeline.interceptLLMStreamingFailed(interceptContext) intercept@{ eventContext -> + val event = LLMStreamingFailedEvent( + runId = eventContext.runId, + error = eventContext.error.toAgentError(), + timestamp = pipeline.clock.now().toEpochMilliseconds() + ) + processMessage(config, event) + } + + //endregion Intercept LLM Streaming Events + //region Intercept Tool Call Events - pipeline.interceptToolExecutionStarting(interceptContext) intercept@{ eventContext -> + pipeline.interceptToolCallStarting(interceptContext) intercept@{ eventContext -> @Suppress("UNCHECKED_CAST") val tool = eventContext.tool as Tool - val event = ToolExecutionStartingEvent( + val event = ToolCallStartingEvent( runId = eventContext.runId, toolCallId = eventContext.toolCallId, toolName = tool.name, @@ -288,12 +336,12 @@ public class Tracing { processMessage(config, event) } - pipeline.interceptToolExecutionFailed(interceptContext) intercept@{ eventContext -> + pipeline.interceptToolCallFailed(interceptContext) intercept@{ eventContext -> @Suppress("UNCHECKED_CAST") val tool = eventContext.tool as Tool - val event = ToolExecutionFailedEvent( + val event = ToolCallFailedEvent( runId = eventContext.runId, toolCallId = eventContext.toolCallId, toolName = tool.name, @@ -304,12 +352,12 @@ public class Tracing { processMessage(config, event) } - pipeline.interceptToolExecutionCompleted(interceptContext) intercept@{ eventContext -> + pipeline.interceptToolCallCompleted(interceptContext) intercept@{ eventContext -> @Suppress("UNCHECKED_CAST") val tool = eventContext.tool as Tool - val event = 
ToolExecutionCompletedEvent( + val event = ToolCallCompletedEvent( runId = eventContext.runId, toolCallId = eventContext.toolCallId, toolName = tool.name, diff --git a/agents/agents-features/agents-features-trace/src/commonMain/kotlin/ai/koog/agents/features/tracing/writer/TraceFeatureMessageFileWriter.kt b/agents/agents-features/agents-features-trace/src/commonMain/kotlin/ai/koog/agents/features/tracing/writer/TraceFeatureMessageFileWriter.kt index 71e037a2df..b45669e44f 100644 --- a/agents/agents-features/agents-features-trace/src/commonMain/kotlin/ai/koog/agents/features/tracing/writer/TraceFeatureMessageFileWriter.kt +++ b/agents/agents-features/agents-features-trace/src/commonMain/kotlin/ai/koog/agents/features/tracing/writer/TraceFeatureMessageFileWriter.kt @@ -30,7 +30,7 @@ import kotlinx.io.Sink * sinkOpener = fileSystem::sink, * targetPath = "custom-traces.log", * format = { message -> - * "[TRACE] ${message.eventId}: ${message::class.simpleName}" + * "[TRACE] ${message::class.simpleName}" * } * )) * } diff --git a/agents/agents-features/agents-features-trace/src/commonMain/kotlin/ai/koog/agents/features/tracing/writer/TraceFeatureMessageLogWriter.kt b/agents/agents-features/agents-features-trace/src/commonMain/kotlin/ai/koog/agents/features/tracing/writer/TraceFeatureMessageLogWriter.kt index e8439d8cae..1a85a11f60 100644 --- a/agents/agents-features/agents-features-trace/src/commonMain/kotlin/ai/koog/agents/features/tracing/writer/TraceFeatureMessageLogWriter.kt +++ b/agents/agents-features/agents-features-trace/src/commonMain/kotlin/ai/koog/agents/features/tracing/writer/TraceFeatureMessageLogWriter.kt @@ -35,7 +35,7 @@ import io.github.oshai.kotlinlogging.KLogger * addMessageProcessor(TraceFeatureMessageLogWriter( * targetLogger = logger, * format = { message -> - * "[TRACE] ${message.eventId}: ${message::class.simpleName}" + * "[TRACE] ${message::class.simpleName}" * } * )) * } diff --git 
a/agents/agents-features/agents-features-trace/src/commonMain/kotlin/ai/koog/agents/features/tracing/writer/traceMessageFormat.kt b/agents/agents-features/agents-features-trace/src/commonMain/kotlin/ai/koog/agents/features/tracing/writer/traceMessageFormat.kt index 0a6196edc6..05961fd756 100644 --- a/agents/agents-features/agents-features-trace/src/commonMain/kotlin/ai/koog/agents/features/tracing/writer/traceMessageFormat.kt +++ b/agents/agents-features/agents-features-trace/src/commonMain/kotlin/ai/koog/agents/features/tracing/writer/traceMessageFormat.kt @@ -14,9 +14,9 @@ import ai.koog.agents.core.feature.model.events.NodeExecutionFailedEvent import ai.koog.agents.core.feature.model.events.NodeExecutionStartingEvent import ai.koog.agents.core.feature.model.events.StrategyCompletedEvent import ai.koog.agents.core.feature.model.events.StrategyStartingEvent -import ai.koog.agents.core.feature.model.events.ToolExecutionCompletedEvent -import ai.koog.agents.core.feature.model.events.ToolExecutionFailedEvent -import ai.koog.agents.core.feature.model.events.ToolExecutionStartingEvent +import ai.koog.agents.core.feature.model.events.ToolCallCompletedEvent +import ai.koog.agents.core.feature.model.events.ToolCallFailedEvent +import ai.koog.agents.core.feature.model.events.ToolCallStartingEvent import ai.koog.agents.core.feature.model.events.ToolValidationFailedEvent import ai.koog.agents.features.tracing.traceString @@ -32,53 +32,49 @@ internal val FeatureStringMessage.featureStringMessage get() = "Feature string message (message: $message)" internal val AgentStartingEvent.agentStartedEventFormat - get() = "$eventId (agent id: $agentId, run id: $runId)" + get() = "${this::class.simpleName} (agent id: $agentId, run id: $runId)" internal val AgentCompletedEvent.agentFinishedEventFormat - get() = "$eventId (agent id: $agentId, run id: $runId, result: $result)" + get() = "${this::class.simpleName} (agent id: $agentId, run id: $runId, result: $result)" internal val 
AgentExecutionFailedEvent.agentRunErrorEventFormat - get() = "$eventId (agent id: $agentId, run id: $runId, error: ${error.message})" + get() = "${this::class.simpleName} (agent id: $agentId, run id: $runId, error: ${error.message})" internal val AgentClosingEvent.agentBeforeCloseFormat - get() = "$eventId (agent id: $agentId)" + get() = "${this::class.simpleName} (agent id: $agentId)" internal val StrategyStartingEvent.strategyStartEventFormat - get() = "$eventId (run id: $runId, strategy: $strategyName)" + get() = "${this::class.simpleName} (run id: $runId, strategy: $strategyName)" internal val StrategyCompletedEvent.strategyFinishedEventFormat - get() = "$eventId (run id: $runId, strategy: $strategyName, result: $result)" + get() = "${this::class.simpleName} (run id: $runId, strategy: $strategyName, result: $result)" internal val NodeExecutionStartingEvent.nodeExecutionStartEventFormat - get() = "$eventId (run id: $runId, node: $nodeName, input: $input)" + get() = "${this::class.simpleName} (run id: $runId, node: $nodeName, input: $input)" internal val NodeExecutionCompletedEvent.nodeExecutionEndEventFormat - get() = "$eventId (run id: $runId, node: $nodeName, input: $input, output: $output)" + get() = "${this::class.simpleName} (run id: $runId, node: $nodeName, input: $input, output: $output)" internal val NodeExecutionFailedEvent.nodeExecutionErrorEventFormat - get() = "$eventId (run id: $runId, node: $nodeName, error: ${error.message})" + get() = "${this::class.simpleName} (run id: $runId, node: $nodeName, error: ${error.message})" internal val LLMCallStartingEvent.beforeLLMCallEventFormat - get() = "$eventId (run id: $runId, prompt: ${prompt.traceString}, model: $model, tools: [${tools.joinToString()}])" + get() = "${this::class.simpleName} (run id: $runId, prompt: ${prompt.traceString}, model: $model, tools: [${tools.joinToString()}])" internal val LLMCallCompletedEvent.afterLLMCallEventFormat - get() = "$eventId (run id: $runId, prompt: 
${prompt.traceString}, model: $model, responses: [${ - responses.joinToString { - "{${it.traceString}}" - } - }])" + get() = "${this::class.simpleName} (run id: $runId, prompt: ${prompt.traceString}, model: $model, responses: [${responses.joinToString { "{${it.traceString}}" }}])" -internal val ToolExecutionStartingEvent.toolCallEventFormat - get() = "$eventId (run id: $runId, tool: $toolName, tool args: $toolArgs)" +internal val ToolCallStartingEvent.toolCallEventFormat + get() = "${this::class.simpleName} (run id: $runId, tool: $toolName, tool args: $toolArgs)" internal val ToolValidationFailedEvent.toolValidationErrorEventFormat - get() = "$eventId (run id: $runId, tool: $toolName, tool args: $toolArgs, validation error: $error)" + get() = "${this::class.simpleName} (run id: $runId, tool: $toolName, tool args: $toolArgs, validation error: $error)" -internal val ToolExecutionFailedEvent.toolCallFailureEventFormat - get() = "$eventId (run id: $runId, tool: $toolName, tool args: $toolArgs, error: ${error.message})" +internal val ToolCallFailedEvent.toolCallFailureEventFormat + get() = "${this::class.simpleName} (run id: $runId, tool: $toolName, tool args: $toolArgs, error: ${error.message})" -internal val ToolExecutionCompletedEvent.toolCallResultEventFormat - get() = "$eventId (run id: $runId, tool: $toolName, tool args: $toolArgs, result: $result)" +internal val ToolCallCompletedEvent.toolCallResultEventFormat + get() = "${this::class.simpleName} (run id: $runId, tool: $toolName, tool args: $toolArgs, result: $result)" internal val FeatureMessage.traceMessage: String get() { @@ -94,10 +90,10 @@ internal val FeatureMessage.traceMessage: String is NodeExecutionFailedEvent -> this.nodeExecutionErrorEventFormat is LLMCallStartingEvent -> this.beforeLLMCallEventFormat is LLMCallCompletedEvent -> this.afterLLMCallEventFormat - is ToolExecutionStartingEvent -> this.toolCallEventFormat + is ToolCallStartingEvent -> this.toolCallEventFormat is ToolValidationFailedEvent -> 
this.toolValidationErrorEventFormat - is ToolExecutionFailedEvent -> this.toolCallFailureEventFormat - is ToolExecutionCompletedEvent -> this.toolCallResultEventFormat + is ToolCallFailedEvent -> this.toolCallFailureEventFormat + is ToolCallCompletedEvent -> this.toolCallResultEventFormat is FeatureStringMessage -> this.featureStringMessage is FeatureEvent -> this.featureEvent else -> this.featureMessage diff --git a/agents/agents-features/agents-features-trace/src/jvmTest/kotlin/ai/koog/agents/features/tracing/writer/TraceFeatureMessageFileWriterTest.kt b/agents/agents-features/agents-features-trace/src/jvmTest/kotlin/ai/koog/agents/features/tracing/writer/TraceFeatureMessageFileWriterTest.kt index ce0480e60d..8cde4c5f53 100644 --- a/agents/agents-features/agents-features-trace/src/jvmTest/kotlin/ai/koog/agents/features/tracing/writer/TraceFeatureMessageFileWriterTest.kt +++ b/agents/agents-features/agents-features-trace/src/jvmTest/kotlin/ai/koog/agents/features/tracing/writer/TraceFeatureMessageFileWriterTest.kt @@ -19,8 +19,8 @@ import ai.koog.agents.core.feature.model.events.LLMCallStartingEvent import ai.koog.agents.core.feature.model.events.NodeExecutionCompletedEvent import ai.koog.agents.core.feature.model.events.NodeExecutionStartingEvent import ai.koog.agents.core.feature.model.events.StrategyCompletedEvent -import ai.koog.agents.core.feature.model.events.ToolExecutionCompletedEvent -import ai.koog.agents.core.feature.model.events.ToolExecutionStartingEvent +import ai.koog.agents.core.feature.model.events.ToolCallCompletedEvent +import ai.koog.agents.core.feature.model.events.ToolCallStartingEvent import ai.koog.agents.core.tools.ToolRegistry import ai.koog.agents.features.tracing.eventString import ai.koog.agents.features.tracing.feature.Tracing @@ -185,8 +185,8 @@ class TraceFeatureMessageFileWriterTest { content = """{"dummy":"test"}""" ) })", - "${ToolExecutionStartingEvent::class.simpleName} (run id: $runId, tool: ${dummyTool.name}, tool args: 
{\"dummy\":\"test\"})", - "${ToolExecutionCompletedEvent::class.simpleName} (run id: $runId, tool: ${dummyTool.name}, tool args: {\"dummy\":\"test\"}, result: ${dummyTool.result})", + "${ToolCallStartingEvent::class.simpleName} (run id: $runId, tool: ${dummyTool.name}, tool args: {\"dummy\":\"test\"})", + "${ToolCallCompletedEvent::class.simpleName} (run id: $runId, tool: ${dummyTool.name}, tool args: {\"dummy\":\"test\"}, result: ${dummyTool.result})", "${NodeExecutionCompletedEvent::class.simpleName} (run id: $runId, node: test-tool-call, input: ${ toolCallMessage( dummyTool.name, @@ -258,7 +258,7 @@ class TraceFeatureMessageFileWriterTest { val customFormat: (FeatureMessage) -> String = { message -> when (message) { is FeatureStringMessage -> "CUSTOM STRING. ${message.message}" - is FeatureEvent -> "CUSTOM EVENT. ${message.eventId}" + is FeatureEvent -> "CUSTOM EVENT. No event message" else -> "CUSTOM OTHER: ${message::class.simpleName}" } } @@ -273,7 +273,7 @@ class TraceFeatureMessageFileWriterTest { val expectedMessages = listOf( "CUSTOM STRING. Test string message", - "CUSTOM EVENT. ${AgentStartingEvent::class.simpleName}", + "CUSTOM EVENT. 
No event message", ) TraceFeatureMessageFileWriter( diff --git a/agents/agents-features/agents-features-trace/src/jvmTest/kotlin/ai/koog/agents/features/tracing/writer/TraceFeatureMessageLogWriterTest.kt b/agents/agents-features/agents-features-trace/src/jvmTest/kotlin/ai/koog/agents/features/tracing/writer/TraceFeatureMessageLogWriterTest.kt index 75e3644028..88325ab136 100644 --- a/agents/agents-features/agents-features-trace/src/jvmTest/kotlin/ai/koog/agents/features/tracing/writer/TraceFeatureMessageLogWriterTest.kt +++ b/agents/agents-features/agents-features-trace/src/jvmTest/kotlin/ai/koog/agents/features/tracing/writer/TraceFeatureMessageLogWriterTest.kt @@ -19,8 +19,8 @@ import ai.koog.agents.core.feature.model.events.LLMCallStartingEvent import ai.koog.agents.core.feature.model.events.NodeExecutionCompletedEvent import ai.koog.agents.core.feature.model.events.NodeExecutionStartingEvent import ai.koog.agents.core.feature.model.events.StrategyCompletedEvent -import ai.koog.agents.core.feature.model.events.ToolExecutionCompletedEvent -import ai.koog.agents.core.feature.model.events.ToolExecutionStartingEvent +import ai.koog.agents.core.feature.model.events.ToolCallCompletedEvent +import ai.koog.agents.core.feature.model.events.ToolCallStartingEvent import ai.koog.agents.core.tools.ToolRegistry import ai.koog.agents.features.tracing.eventString import ai.koog.agents.features.tracing.feature.Tracing @@ -174,8 +174,8 @@ class TraceFeatureMessageLogWriterTest { content = """{"dummy":"test"}""" ) })", - "[INFO] Received feature message [event]: ${ToolExecutionStartingEvent::class.simpleName} (run id: $runId, tool: ${dummyTool.name}, tool args: {\"dummy\":\"test\"})", - "[INFO] Received feature message [event]: ${ToolExecutionCompletedEvent::class.simpleName} (run id: $runId, tool: ${dummyTool.name}, tool args: {\"dummy\":\"test\"}, result: ${dummyTool.result})", + "[INFO] Received feature message [event]: ${ToolCallStartingEvent::class.simpleName} (run id: 
$runId, tool: ${dummyTool.name}, tool args: {\"dummy\":\"test\"})", + "[INFO] Received feature message [event]: ${ToolCallCompletedEvent::class.simpleName} (run id: $runId, tool: ${dummyTool.name}, tool args: {\"dummy\":\"test\"}, result: ${dummyTool.result})", "[INFO] Received feature message [event]: ${NodeExecutionCompletedEvent::class.simpleName} (run id: $runId, node: test-tool-call, input: ${ toolCallMessage( dummyTool.name, @@ -245,7 +245,7 @@ class TraceFeatureMessageLogWriterTest { val customFormat: (FeatureMessage) -> String = { message -> when (message) { is FeatureStringMessage -> "CUSTOM STRING. ${message.message}" - is FeatureEvent -> "CUSTOM EVENT. ${message.eventId}" + is FeatureEvent -> "CUSTOM EVENT. No event message" else -> "OTHER: ${message::class.simpleName}" } } @@ -260,7 +260,7 @@ class TraceFeatureMessageLogWriterTest { val expectedMessages = listOf( "[INFO] Received feature message [message]: CUSTOM STRING. Test string message", - "[INFO] Received feature message [event]: CUSTOM EVENT. ${AgentStartingEvent::class.simpleName}", + "[INFO] Received feature message [event]: CUSTOM EVENT. 
No event message", ) TraceFeatureMessageLogWriter(targetLogger = targetLogger, format = customFormat).use { writer -> diff --git a/agents/agents-features/agents-features-trace/src/jvmTest/kotlin/ai/koog/agents/features/tracing/writer/TraceFeatureMessageRemoteWriterTest.kt b/agents/agents-features/agents-features-trace/src/jvmTest/kotlin/ai/koog/agents/features/tracing/writer/TraceFeatureMessageRemoteWriterTest.kt index 3e70d41122..9e092f411e 100644 --- a/agents/agents-features/agents-features-trace/src/jvmTest/kotlin/ai/koog/agents/features/tracing/writer/TraceFeatureMessageRemoteWriterTest.kt +++ b/agents/agents-features/agents-features-trace/src/jvmTest/kotlin/ai/koog/agents/features/tracing/writer/TraceFeatureMessageRemoteWriterTest.kt @@ -20,8 +20,8 @@ import ai.koog.agents.core.feature.model.events.StrategyCompletedEvent import ai.koog.agents.core.feature.model.events.StrategyEventGraph import ai.koog.agents.core.feature.model.events.StrategyEventGraphEdge import ai.koog.agents.core.feature.model.events.StrategyEventGraphNode -import ai.koog.agents.core.feature.model.events.ToolExecutionCompletedEvent -import ai.koog.agents.core.feature.model.events.ToolExecutionStartingEvent +import ai.koog.agents.core.feature.model.events.ToolCallCompletedEvent +import ai.koog.agents.core.feature.model.events.ToolCallStartingEvent import ai.koog.agents.core.feature.remote.client.FeatureMessageRemoteClient import ai.koog.agents.core.feature.remote.client.config.DefaultClientConnectionConfig import ai.koog.agents.core.feature.remote.server.config.DefaultServerConnectionConfig @@ -347,14 +347,14 @@ class TraceFeatureMessageRemoteWriterTest { input = toolCallMessage(dummyTool.name, content = """{"dummy":"test"}""").toString(), timestamp = testClock.now().toEpochMilliseconds() ), - ToolExecutionStartingEvent( + ToolCallStartingEvent( runId = runId, toolCallId = "0", toolName = dummyTool.name, toolArgs = dummyTool.encodeArgs(DummyTool.Args("test")), timestamp = 
testClock.now().toEpochMilliseconds() ), - ToolExecutionCompletedEvent( + ToolCallCompletedEvent( runId = runId, toolCallId = "0", toolName = dummyTool.name, diff --git a/agents/agents-features/agents-features-trace/src/jvmTest/kotlin/ai/koog/agents/features/tracing/writer/TraceFeatureMessageTestWriterTest.kt b/agents/agents-features/agents-features-trace/src/jvmTest/kotlin/ai/koog/agents/features/tracing/writer/TraceFeatureMessageTestWriterTest.kt index 2879ef0f92..08c8319c08 100644 --- a/agents/agents-features/agents-features-trace/src/jvmTest/kotlin/ai/koog/agents/features/tracing/writer/TraceFeatureMessageTestWriterTest.kt +++ b/agents/agents-features/agents-features-trace/src/jvmTest/kotlin/ai/koog/agents/features/tracing/writer/TraceFeatureMessageTestWriterTest.kt @@ -4,24 +4,41 @@ import ai.koog.agents.core.dsl.builder.forwardTo import ai.koog.agents.core.dsl.builder.strategy import ai.koog.agents.core.dsl.extension.nodeExecuteTool import ai.koog.agents.core.dsl.extension.nodeLLMRequest +import ai.koog.agents.core.dsl.extension.nodeLLMRequestStreamingAndSendResults import ai.koog.agents.core.dsl.extension.nodeUpdatePrompt import ai.koog.agents.core.feature.model.AIAgentError import ai.koog.agents.core.feature.model.events.LLMCallStartingEvent +import ai.koog.agents.core.feature.model.events.LLMStreamingCompletedEvent +import ai.koog.agents.core.feature.model.events.LLMStreamingFailedEvent +import ai.koog.agents.core.feature.model.events.LLMStreamingFrameReceivedEvent +import ai.koog.agents.core.feature.model.events.LLMStreamingStartingEvent import ai.koog.agents.core.feature.model.events.NodeExecutionFailedEvent -import ai.koog.agents.core.feature.model.events.ToolExecutionCompletedEvent -import ai.koog.agents.core.feature.model.events.ToolExecutionStartingEvent +import ai.koog.agents.core.feature.model.events.ToolCallCompletedEvent +import ai.koog.agents.core.feature.model.events.ToolCallStartingEvent import ai.koog.agents.core.tools.ToolRegistry +import 
ai.koog.agents.features.tracing.eventString import ai.koog.agents.features.tracing.feature.Tracing import ai.koog.agents.features.tracing.mock.RecursiveTool import ai.koog.agents.features.tracing.mock.TestFeatureMessageWriter import ai.koog.agents.features.tracing.mock.TestLogger +import ai.koog.agents.features.tracing.mock.assistantMessage import ai.koog.agents.features.tracing.mock.createAgent +import ai.koog.agents.features.tracing.mock.systemMessage import ai.koog.agents.features.tracing.mock.testClock +import ai.koog.agents.features.tracing.mock.userMessage import ai.koog.agents.testing.tools.DummyTool +import ai.koog.agents.testing.tools.getMockExecutor +import ai.koog.agents.testing.tools.mockLLMAnswer import ai.koog.agents.utils.use +import ai.koog.prompt.dsl.Prompt +import ai.koog.prompt.executor.clients.openai.OpenAIModels +import ai.koog.prompt.executor.model.PromptExecutor import ai.koog.prompt.message.Message import ai.koog.prompt.message.ResponseMetaInfo -import kotlinx.coroutines.test.runTest +import ai.koog.prompt.streaming.StreamFrame +import kotlinx.coroutines.flow.Flow +import kotlinx.coroutines.flow.flow +import kotlinx.coroutines.runBlocking import kotlinx.datetime.Instant import org.junit.jupiter.api.Assertions.assertEquals import kotlin.test.AfterTest @@ -39,7 +56,7 @@ class TraceFeatureMessageTestWriterTest { } @Test - fun `test subsequent LLM calls`() = runTest { + fun `test subsequent LLM calls`() = runBlocking { val strategy = strategy("tracing-test-strategy") { val setPrompt by nodeUpdatePrompt("Set prompt") { system("System 1") @@ -90,7 +107,7 @@ class TraceFeatureMessageTestWriterTest { } @Test - fun `test nonexistent tool call`() = runTest { + fun `test nonexistent tool call`() = runBlocking { val strategy = strategy("tracing-tool-call-test") { val callTool by nodeExecuteTool("Tool call") edge( @@ -127,7 +144,7 @@ class TraceFeatureMessageTestWriterTest { } @Test - fun `test existing tool call`() = runTest { + fun `test existing tool 
call`() = runBlocking { val strategy = strategy("tracing-tool-call-test") { val callTool by nodeExecuteTool("Tool call") edge( @@ -155,15 +172,15 @@ class TraceFeatureMessageTestWriterTest { agent.run("") - val toolCallsStartEvent = messageProcessor.messages.filterIsInstance().toList() + val toolCallsStartEvent = messageProcessor.messages.filterIsInstance().toList() assertEquals(1, toolCallsStartEvent.size, "Tool call start event for existing tool") - val toolCallsEndEvent = messageProcessor.messages.filterIsInstance().toList() + val toolCallsEndEvent = messageProcessor.messages.filterIsInstance().toList() assertEquals(1, toolCallsEndEvent.size, "Tool call end event for existing tool") } @Test - fun `test recursive tool call`() = runTest { + fun `test recursive tool call`() = runBlocking { val strategy = strategy("recursive-tool-call-test") { val callTool by nodeExecuteTool("Tool call") edge( @@ -195,12 +212,12 @@ class TraceFeatureMessageTestWriterTest { agent.run("") - val toolCallsStartEvent = messageProcessor.messages.filterIsInstance().toList() + val toolCallsStartEvent = messageProcessor.messages.filterIsInstance().toList() assertEquals(1, toolCallsStartEvent.size, "Tool call start event for existing tool") } @Test - fun `test llm tool call`() = runTest { + fun `test llm tool call`() = runBlocking { val dummyTool = DummyTool() val strategy = strategy("llm-tool-call-test") { @@ -234,15 +251,15 @@ class TraceFeatureMessageTestWriterTest { agent.run("") - val toolCallsStartEvent = messageProcessor.messages.filterIsInstance().toList() + val toolCallsStartEvent = messageProcessor.messages.filterIsInstance().toList() assertEquals(1, toolCallsStartEvent.size, "Tool call start event for existing tool") - val toolCallsEndEvent = messageProcessor.messages.filterIsInstance().toList() + val toolCallsEndEvent = messageProcessor.messages.filterIsInstance().toList() assertEquals(1, toolCallsEndEvent.size, "Tool call end event for existing tool") } @Test - fun `test agent 
with node execution error`() = runTest { + fun `test agent with node execution error`() = runBlocking { val agentId = "test-agent-id" val nodeWithErrorName = "node-with-error" val testErrorMessage = "Test error" @@ -293,4 +310,201 @@ class TraceFeatureMessageTestWriterTest { } } } + + @Test + fun `test llm streaming events success`() = runBlocking { + val userPrompt = "Test user request" + val systemPrompt = "Test system prompt" + val assistantPrompt = "Test assistant prompt" + val promptId = "Test prompt id" + + val model = OpenAIModels.Chat.GPT4o + + val strategy = strategy("tracing-streaming-success") { + val streamAndCollect by nodeLLMRequestStreamingAndSendResults("stream-and-collect") + + edge(nodeStart forwardTo streamAndCollect) + edge(streamAndCollect forwardTo nodeFinish transformed { messages -> messages.firstOrNull()?.content ?: "" }) + } + + val testLLMResponse = "Default test response" + + val testExecutor = getMockExecutor { + mockLLMAnswer(testLLMResponse).asDefaultResponse onUserRequestEquals userPrompt + } + + val toolRegistry = ToolRegistry { tool(DummyTool()) } + + TestFeatureMessageWriter().use { writer -> + createAgent( + agentId = "test-agent-id", + strategy = strategy, + promptExecutor = testExecutor, + systemPrompt = systemPrompt, + userPrompt = userPrompt, + assistantPrompt = assistantPrompt, + promptId = promptId, + model = model, + toolRegistry = toolRegistry, + ) { + install(Tracing) { + addMessageProcessor(writer) + } + }.use { agent -> + agent.run("") + + val actualEvents = writer.messages.filter { event -> + event is LLMStreamingStartingEvent || + event is LLMStreamingFrameReceivedEvent || + event is LLMStreamingFailedEvent || + event is LLMStreamingCompletedEvent + } + + val expectedPrompt = Prompt( + messages = listOf( + systemMessage(systemPrompt), + userMessage(userPrompt), + assistantMessage(assistantPrompt) + ), + id = promptId + ) + + val expectedEvents = listOf( + LLMStreamingStartingEvent( + runId = writer.runId, + prompt = 
expectedPrompt, + model = model.eventString, + tools = toolRegistry.tools.map { it.name }, + timestamp = testClock.now().toEpochMilliseconds() + ), + LLMStreamingFrameReceivedEvent( + runId = writer.runId, + frame = StreamFrame.Append(testLLMResponse), + timestamp = testClock.now().toEpochMilliseconds() + ), + LLMStreamingCompletedEvent( + runId = writer.runId, + prompt = expectedPrompt, + model = model.eventString, + tools = toolRegistry.tools.map { it.name }, + timestamp = testClock.now().toEpochMilliseconds() + ) + ) + + assertEquals(expectedEvents.size, actualEvents.size) + assertContentEquals(expectedEvents, actualEvents) + } + } + } + + @Test + fun `test llm streaming events failure`() = runBlocking { + val userPrompt = "Call the dummy tool with argument: test" + val systemPrompt = "Test system prompt" + val assistantPrompt = "Test assistant prompt" + val promptId = "Test prompt id" + val model = OpenAIModels.Chat.GPT4o + + val strategy = strategy("tracing-streaming-failure") { + val streamAndCollect by nodeLLMRequestStreamingAndSendResults("stream-and-collect") + + edge(nodeStart forwardTo streamAndCollect) + edge(streamAndCollect forwardTo nodeFinish transformed { messages -> messages.firstOrNull()?.content ?: "" }) + } + + val toolRegistry = ToolRegistry { tool(DummyTool()) } + + val testStreamingErrorMessage = "Test streaming error" + var testStreamingStackTrace = "" + + val testStreamingExecutor = object : PromptExecutor { + override suspend fun execute( + prompt: Prompt, + model: ai.koog.prompt.llm.LLModel, + tools: List + ): List = emptyList() + + override fun executeStreaming( + prompt: Prompt, + model: ai.koog.prompt.llm.LLModel, + tools: List + ): Flow = flow { + val testException = IllegalStateException(testStreamingErrorMessage) + testStreamingStackTrace = testException.stackTraceToString() + throw testException + } + + override suspend fun moderate( + prompt: Prompt, + model: ai.koog.prompt.llm.LLModel + ): ai.koog.prompt.dsl.ModerationResult { + 
throw UnsupportedOperationException("Not used in test") + } + } + + TestFeatureMessageWriter().use { writer -> + + createAgent( + systemPrompt = systemPrompt, + userPrompt = userPrompt, + assistantPrompt = assistantPrompt, + promptId = promptId, + model = model, + strategy = strategy, + promptExecutor = testStreamingExecutor, + toolRegistry = toolRegistry, + ) { + install(Tracing) { + addMessageProcessor(writer) + } + }.use { agent -> + val throwable = assertFails { + agent.run("") + } + + assertEquals(testStreamingErrorMessage, throwable.message) + + val expectedPrompt = Prompt( + messages = listOf( + systemMessage(systemPrompt), + userMessage(userPrompt), + assistantMessage(assistantPrompt), + ), + id = promptId + ) + + val actualEvents = writer.messages.filter { event -> + event is LLMStreamingStartingEvent || + event is LLMStreamingFrameReceivedEvent || + event is LLMStreamingFailedEvent || + event is LLMStreamingCompletedEvent + } + + val expectedEvents = listOf( + LLMStreamingStartingEvent( + runId = writer.runId, + prompt = expectedPrompt, + model = model.eventString, + tools = toolRegistry.tools.map { it.name }, + timestamp = testClock.now().toEpochMilliseconds() + ), + LLMStreamingFailedEvent( + runId = writer.runId, + error = AIAgentError(testStreamingErrorMessage, testStreamingStackTrace), + timestamp = testClock.now().toEpochMilliseconds() + ), + LLMStreamingCompletedEvent( + runId = writer.runId, + prompt = expectedPrompt, + model = model.eventString, + tools = toolRegistry.tools.map { it.name }, + timestamp = testClock.now().toEpochMilliseconds() + ) + ) + + assertEquals(expectedEvents.size, actualEvents.size) + assertContentEquals(expectedEvents, actualEvents) + } + } + } } diff --git a/agents/agents-test/TESTING.md b/agents/agents-test/TESTING.md index 310d67026b..7c9c266bd2 100644 --- a/agents/agents-test/TESTING.md +++ b/agents/agents-test/TESTING.md @@ -307,7 +307,7 @@ fun testToneAgent() = runTest { // Create an event handler val eventHandler = 
EventHandler { - onToolCall { tool, args -> + onToolCallStarting { tool, args -> println("[DEBUG_LOG] Tool called: tool ${tool.name}, args $args") toolCalls.add(tool.name) } diff --git a/agents/agents-test/src/jvmTest/kotlin/ai/koog/agents/test/SimpleAgentMockedTest.kt b/agents/agents-test/src/jvmTest/kotlin/ai/koog/agents/test/SimpleAgentMockedTest.kt index 94d2abff54..f469623512 100644 --- a/agents/agents-test/src/jvmTest/kotlin/ai/koog/agents/test/SimpleAgentMockedTest.kt +++ b/agents/agents-test/src/jvmTest/kotlin/ai/koog/agents/test/SimpleAgentMockedTest.kt @@ -67,29 +67,29 @@ class SimpleAgentMockedTest { } val eventHandlerConfig: EventHandlerConfig.() -> Unit = { - onToolCall { eventContext -> + onToolCallStarting { eventContext -> println("Tool called: tool ${eventContext.tool.name}, args ${eventContext.toolArgs}") actualToolCalls.add(eventContext.tool.name) iterationCount++ } - onAgentRunError { eventContext -> + onAgentExecutionFailed { eventContext -> errors.add(eventContext.throwable) } - onToolCall { eventContext -> + onToolCallStarting { eventContext -> println("Tool called: tool ${eventContext.tool.name}, args ${eventContext.toolArgs}") actualToolCalls.add(eventContext.tool.name) } - onToolCallFailure { eventContext -> + onToolCallFailed { eventContext -> println( "Tool call failure: tool ${eventContext.tool.name}, args ${eventContext.toolArgs}, error=${eventContext.throwable.message}" ) errors.add(eventContext.throwable) } - onAgentFinished { eventContext -> + onAgentCompleted { eventContext -> results.add(eventContext.result) } } diff --git a/build.gradle.kts b/build.gradle.kts index 0f513eadb9..a4252f26b3 100644 --- a/build.gradle.kts +++ b/build.gradle.kts @@ -18,7 +18,7 @@ group = "ai.koog" version = run { // our version follows the semver specification - val main = "0.4.3" + val main = "0.5.0" val feat = run { val releaseBuild = !System.getenv("BRANCH_KOOG_IS_RELEASING_FROM").isNullOrBlank() diff --git a/docs/docs/a2a-client.md 
b/docs/docs/a2a-client.md new file mode 100644 index 0000000000..27df9fcf4e --- /dev/null +++ b/docs/docs/a2a-client.md @@ -0,0 +1,202 @@ +# A2A Client + +The A2A client enables you to communicate with A2A-compliant agents over the network. +It provides a complete implementation of +the [A2A protocol specification](https://a2a-protocol.org/latest/specification/), handling agent discovery, message +exchange, task management, and real-time streaming responses. + +## Overview + +The A2A client acts as a bridge between your application and A2A-compliant agents. +It orchestrates the entire communication lifecycle while maintaining protocol compliance and providing robust session +management. + +## Core components + +### A2AClient + +The main client class implementing the complete A2A protocol. It serves as the central coordinator that: + +- **Manages** connections and agent discovery through pluggable resolvers +- **Orchestrates** message exchange and task operations with automatic protocol compliance +- **Handles** streaming responses and real-time communication when supported by agents +- **Provides** comprehensive error handling and fallback mechanisms for robust applications + +The `A2AClient` accepts two required parameters: + +* `ClientTransport` which handles network communication layer +* `AgentCardResolver` which handles agent discovery and metadata retrieval + +The `A2AClient` interface provides several key methods for interacting with A2A agents: + +* `connect` method - To connect to the agent and retrieve its capabilities, which discovers what the agent can do and + caches the AgentCard +* `sendMessage` method - To send a message to the agent and receive a single response for simple request-response + patterns +* `sendMessageStreaming` method - To send a message with streaming support for real-time responses, which returns a Flow + of events including partial messages and task updates +* `getTask` method - To query the status and details of a specific task 
+* `cancelTask` method - To cancel a running task if the agent supports cancellation +* `cachedAgentCard` method - To get the cached agent card without making a network request, which returns null if + connect hasn't been called yet + +### ClientTransport + +The `ClientTransport` interface handles the low-level network communication while the A2A client manages the protocol +logic. +It abstracts away transport-specific details, allowing you to use different protocols seamlessly. + +#### HTTP JSON-RPC Transport + +The most common transport for A2A agents: + +```kotlin +val transport = HttpJSONRPCClientTransport( + url = "https://agent.example.com/a2a", // Agent endpoint URL + httpClient = HttpClient(CIO) { // Optional: custom HTTP client + install(ContentNegotiation) { + json() + } + install(HttpTimeout) { + requestTimeoutMillis = 30000 + } + } +) +``` + +### AgentCardResolver + +The `AgentCardResolver` interface retrieves agent metadata and capabilities. +It enables agent discovery from various sources and supports caching strategies for optimal performance. + +#### URL Agent Card Resolver + +Fetch agent cards from HTTP endpoints following A2A conventions: + +```kotlin +val agentCardResolver = UrlAgentCardResolver( + baseUrl = "https://agent.example.com", // Base URL of the agent service + path = "/.well-known/agent-card.json", // Standard agent card location + httpClient = HttpClient(CIO), // Optional: custom HTTP client +) +``` + +## Quickstart + +### 1. Create the Client + +Define the transport and agent card resolver and create the client. + +```kotlin +// HTTP JSON-RPC transport +val transport = HttpJSONRPCClientTransport( + url = "https://agent.example.com/a2a" +) + +// Agent card resolver +val agentCardResolver = UrlAgentCardResolver( + baseUrl = "https://agent.example.com", + path = "/.well-known/agent-card.json" +) + +// Create client +val client = A2AClient(transport, agentCardResolver) +``` + +### 2. 
Connect and Discover + +Connect to the agent and retrieve its card. +Having agent's card enables you to query its capabilities and perform other operations, for example, check if it +supports streaming. + +```kotlin +// Connect and retrieve agent capabilities +client.connect() +val agentCard = client.cachedAgentCard() + +println("Connected to: ${agentCard.name}") +println("Supports streaming: ${agentCard.capabilities.streaming}") +``` + +### 3. Send Messages + +Send a message to the agent and receive a single response. +The response can be either the message if the agent responded directly, or a task event if the agent is performing a +task. + +```kotlin +val message = Message( + messageId = UUID.randomUUID().toString(), + role = Role.User, + parts = listOf(TextPart("Hello, agent!")), + contextId = "conversation-1" +) + +val request = Request(data = MessageSendParams(message)) +val response = client.sendMessage(request) + +// Handle response +when (val event = response.data) { + is Message -> { + val text = event.parts + .filterIsInstance() + .joinToString { it.text } + print(text) // Stream partial responses + } + is TaskEvent -> { + if (event.final) { + println("\nTask completed") + } + } +} +``` + +### 4. Send Messages Streaming + +The A2A client supports streaming responses for real-time communication. +Instead of receiving a single response, it returns a `Flow` of events including messages and task updates. 
+ +```kotlin +// Check if agent supports streaming +if (client.cachedAgentCard()?.capabilities?.streaming == true) { + client.sendMessageStreaming(request).collect { response -> + when (val event = response.data) { + is Message -> { + val text = event.parts + .filterIsInstance() + .joinToString { it.text } + print(text) // Stream partial responses + } + is TaskStatusUpdateEvent -> { + if (event.final) { + println("\nTask completed") + } + } + } + } +} else { + // Fallback to non-streaming + val response = client.sendMessage(request) + // Handle single response +} +``` + +### 5. Manage Tasks + +A2A Client provides methods to control server tasks by asking for their status and cancelling them. + +```kotlin +// Query task status +val taskRequest = Request(data = TaskQueryParams(taskId = "task-123")) +val taskResponse = client.getTask(taskRequest) +val task = taskResponse.data + +println("Task state: ${task.status.state}") + +// Cancel running task +if (task.status.state == TaskState.Working) { + val cancelRequest = Request(data = TaskIdParams(taskId = "task-123")) + val cancelledTask = client.cancelTask(cancelRequest).data + println("Task cancelled: ${cancelledTask.status.state}") +} +``` diff --git a/docs/docs/a2a-koog-integration.md b/docs/docs/a2a-koog-integration.md new file mode 100644 index 0000000000..89932902f6 --- /dev/null +++ b/docs/docs/a2a-koog-integration.md @@ -0,0 +1,319 @@ +# A2A and Koog Integration + +Koog provides seamless integration with the A2A protocol, allowing you to expose Koog agents as A2A servers and connect +Koog agents to other A2A-compliant agents. + +## Overview + +The integration enables two main patterns: + +1. **Expose Koog agents as A2A servers** - Make your Koog agents discoverable and accessible via the A2A protocol +2. 
**Connect Koog agents to A2A agents** - Let your Koog agents communicate with other A2A-compliant agents + +## Exposing Koog Agents as A2A Servers + +### Define Koog Agent with A2A feature + +Let's define a Koog agent first. The logic of the agent can vary, but here's an example basic single run agent with +tools. +The agent receives a message from the user, forwards it to the llm. +If the llm response contains a tool call, the agent executes the tool and forwards the result to the llm. +If the llm response contains an assistant message, the agent sends the assistant message to the user and finishes. + +On receiving input, the agent sends a task submitted event to the A2A client with the input message. +On each tool call, the agent sends a task working event to the A2A client with the tool call and result. +On assistant message, the agent sends a task completed event to the A2A client with the assistant message. + +```kotlin +/** + * Create a Koog agent with A2A feature + */ +@OptIn(ExperimentalUuidApi::class) +private fun createAgent( + context: RequestContext, + eventProcessor: SessionEventProcessor, +) = AIAgent( + promptExecutor = MultiLLMPromptExecutor( + LLMProvider.Google to GoogleLLMClient("api-key") + ), + toolRegistry = ToolRegistry { + // Declare tools here + }, + strategy = strategy("test") { + val nodeSetup by node { inputMessage -> + // Convenience function to transform A2A message into Koog message + val input = inputMessage.toKoogMessage() + llm.writeSession { + updatePrompt { + message(input) + } + } + // Send update event to A2A client + withA2AAgentServer { + sendTaskUpdate("Request submitted: ${input.content}", TaskState.Submitted) + } + } + + // Calling llm + val nodeLLMRequest by node { + llm.writeSession { + requestLLM() + } + } + + // Executing tool + val nodeProcessTool by node { toolCall -> + withA2AAgentServer { + sendTaskUpdate("Executing tool: ${toolCall.content}", TaskState.Working) + } + + val toolResult = environment.executeTool(toolCall)
+ + llm.writeSession { + updatePrompt { + tool { + result(toolResult) + } + } + } + withA2AAgentServer { + sendTaskUpdate("Tool result: ${toolResult.content}", TaskState.Working) + } + } + + // Sending assistant message + val nodeProcessAssistant by node { assistantMessage -> + withA2AAgentServer { + sendTaskUpdate(assistantMessage, TaskState.Completed) + } + } + + edge(nodeStart forwardTo nodeSetup) + edge(nodeSetup forwardTo nodeLLMRequest) + + // If a tool call is returned from llm, forward to the tool processing node and then back to llm + edge(nodeLLMRequest forwardTo nodeProcessTool onToolCall { true }) + edge(nodeProcessTool forwardTo nodeLLMRequest) + + // If an assistant message is returned from llm, forward to the assistant processing node and then to finish + edge(nodeLLMRequest forwardTo nodeProcessAssistant onAssistantMessage { true }) + edge(nodeProcessAssistant forwardTo nodeFinish) + }, + agentConfig = AIAgentConfig( + prompt = prompt("agent") { system("You are a helpful assistant.") }, + model = GoogleModels.Gemini2_5Pro, + maxAgentIterations = 10 + ), +) { + install(A2AAgentServer) { + this.context = context + this.eventProcessor = eventProcessor + } +} + +/** + * Convenience function to send task update event to A2A client + * @param content The message content + * @param state The task state + */ +@OptIn(ExperimentalUuidApi::class) +private suspend fun A2AAgentServer.sendTaskUpdate( + content: String, + state: TaskState, +) { + val message = A2AMessage( + messageId = Uuid.random().toString(), + role = Role.Agent, + parts = listOf( + TextPart(content) + ), + contextId = context.contextId, + taskId = context.taskId, + ) + + val task = Task( + id = context.taskId, + contextId = context.contextId, + status = TaskStatus( + state = state, + message = message, + timestamp = Clock.System.now(), + ) + ) + eventProcessor.sendTaskEvent(task) +} +``` + +## A2AAgentServer Feature Mechanism + +The `A2AAgentServer` is a Koog agent feature that enables seamless 
integration between Koog agents and the A2A protocol. +The `A2AAgentServer` feature provides access to the `RequestContext` and `SessionEventProcessor` entities, which are used to +communicate with the A2A client inside the Koog agent. + +To install the feature, call the `install` function on the agent and pass the `A2AAgentServer` feature along with the `RequestContext` and `SessionEventProcessor`: +```kotlin +// Install the feature +agent.install(A2AAgentServer) { + this.context = context + this.eventProcessor = eventProcessor +} +``` + +To access these entities from Koog agent strategy, the feature provides a `withA2AAgentServer` function that allows agent nodes to access A2A server capabilities within their execution context. +It retrieves the installed `A2AAgentServer` feature and provides it as the receiver for the action block. + +```kotlin +// Usage within agent nodes +withA2AAgentServer { + // 'this' is now A2AAgentServer instance + sendTaskUpdate("Processing your request...", TaskState.Working) +} +``` + +### Start A2A Server +After running the server Koog agent will be discoverable and accessible via the A2A protocol. + +```kotlin +val agentCard = AgentCard( + name = "Koog Agent", + url = "http://localhost:9999/koog", + description = "Simple universal agent powered by Koog", + version = "1.0.0", + protocolVersion = "0.3.0", + preferredTransport = TransportProtocol.JSONRPC, + capabilities = AgentCapabilities(streaming = true), + defaultInputModes = listOf("text"), + defaultOutputModes = listOf("text"), + skills = listOf( + AgentSkill( + id = "koog", + name = "Koog Agent", + description = "Universal agent powered by Koog. 
Supports tool calling.", + tags = listOf("chat", "tool"), + ) + ) +) +// Server setup +val server = A2AServer(agentExecutor = KoogAgentExecutor(), agentCard = agentCard) +val transport = HttpJSONRPCServerTransport(server) +transport.start(engineFactory = CIO, port = 8080, path = "/chat", wait = true) +``` + +## Connecting Koog Agents to A2A Agents + +### Create A2A Client and connect to the A2A Server + +```kotlin +val transport = HttpJSONRPCClientTransport(url = "http://localhost:9999/koog") +val agentCardResolver = + UrlAgentCardResolver(baseUrl = "http://localhost:9999", path = "/koog") +val client = A2AClient(transport = transport, agentCardResolver = agentCardResolver) + +val agentId = "koog" +client.connect() +``` + +### Create Koog Agent and add A2A Client to A2AAgentClient Feature +To connect to A2A agent from your Koog Agent, you can use the A2AAgentClient feature, which provides a client API for connecting to A2A agents. +The principle of the client is the same as the server: you install the feature and pass the `A2AAgentClient` feature along with the `RequestContext` and `SessionEventProcessor`. 
+ +```kotlin +val agent = AIAgent( + promptExecutor = MultiLLMPromptExecutor( + LLMProvider.Google to GoogleLLMClient("api-key") + ), + toolRegistry = ToolRegistry { + // declare tools here + }, + strategy = strategy("test") { + + val nodeCheckStreaming by nodeA2AClientGetAgentCard().transform { it.capabilities.streaming } + + val nodeA2ASendMessageStreaming by nodeA2AClientSendMessageStreaming() + val nodeA2ASendMessage by nodeA2AClientSendMessage() + + val nodeProcessStreaming by node>, Unit> { + it.collect { response -> + when (response.data) { + is Task -> { + // Process task + } + + is A2AMessage -> { + // Process message + } + + is TaskStatusUpdateEvent -> { + // Process task status update + } + + is TaskArtifactUpdateEvent -> { + // Process task artifact update + } + } + } + } + + val nodeProcessEvent by node { event -> + when (event) { + is Task -> { + // Process task + } + + is A2AMessage -> { + // Process message + } + } + } + + // If streaming is supported, send a message, process response and finish + edge(nodeStart forwardTo nodeCheckStreaming transformed { agentId }) + edge( + nodeCheckStreaming forwardTo nodeA2ASendMessageStreaming + onCondition { it == true } transformed { buildA2ARequest(agentId) } + ) + edge(nodeA2ASendMessageStreaming forwardTo nodeProcessStreaming) + edge(nodeProcessStreaming forwardTo nodeFinish) + + // If streaming is not supported, send a message, process response and finish + edge( + nodeCheckStreaming forwardTo nodeA2ASendMessage + onCondition { it == false } transformed { buildA2ARequest(agentId) } + ) + edge(nodeA2ASendMessage forwardTo nodeProcessEvent) + edge(nodeProcessEvent forwardTo nodeFinish) + + // If streaming is not supported, send a message, process response and finish + edge(nodeCheckStreaming forwardTo nodeFinish onCondition { it == null } + transformed { println("Failed to get agents card") } + ) + + }, + agentConfig = AIAgentConfig( + prompt = prompt("agent") { system("You are a helpful assistant.") }, + 
model = GoogleModels.Gemini2_5Pro, + maxAgentIterations = 10 + ), +) { + install(A2AAgentClient) { + this.a2aClients = mapOf(agentId to client) + } +} + + +@OptIn(ExperimentalUuidApi::class) +private fun AIAgentGraphContextBase.buildA2ARequest(agentId: String): A2AClientRequest = + A2AClientRequest( + agentId = agentId, + callContext = ClientCallContext.Default, + params = MessageSendParams( + message = A2AMessage( + messageId = Uuid.random().toString(), + role = Role.User, + parts = listOf( + TextPart(agentInput as String) + ) + ) + ) + ) +``` diff --git a/docs/docs/a2a-protocol-overview.md b/docs/docs/a2a-protocol-overview.md new file mode 100644 index 0000000000..947b25ffae --- /dev/null +++ b/docs/docs/a2a-protocol-overview.md @@ -0,0 +1,16 @@ +# A2A protocol + +This page provides an overview of the A2A (Agent-to-Agent) protocol implementation in the Koog agentic framework. + +## What is the A2A protocol? + +The A2A (Agent-to-Agent) protocol is a standardized communication protocol that enables AI agents to interact with each other and with client applications. +It defines a set of methods, message formats, and behaviors that allow for consistent and interoperable agent communication. +For more information and a detailed specification of the A2A protocol, see the official [A2A Protocol website](https://a2a-protocol.org/latest/). + +## Key A2A components + +Koog provides full implementation of A2A protocol v0.3.0 for both client and server, as well as integration with the Koog agent framework: +- [A2A Server](a2a-server.md) is an agent or agentic system that exposes an endpoint implementing the A2A protocol. It receives requests from clients, processes tasks, and returns results or status updates. It can also be used independently of Koog agents. +- [A2A Client](a2a-client.md) is a client application or agent that initiates communication with an A2A server using the A2A protocol. It can also be used independently of Koog agents. 
+- [A2A Koog Integration](a2a-koog-integration.md) is a set of classes and utilities that simplify the integration of A2A with Koog Agents. It contains components (A2A features and nodes) for seamless A2A agent connections and communication within the Koog framework. diff --git a/docs/docs/a2a-server.md b/docs/docs/a2a-server.md new file mode 100644 index 0000000000..7e8ad7d7c5 --- /dev/null +++ b/docs/docs/a2a-server.md @@ -0,0 +1,351 @@ +# A2A Server + +The A2A server enables you to expose AI agents through the standardized A2A (Agent-to-Agent) protocol. It provides a complete implementation of the [A2A protocol specification](https://a2a-protocol.org/latest/specification/), handling client requests, executing agent logic, managing complex task lifecycles, and supporting real-time streaming responses. + +## Overview + +The A2A server acts as a bridge between the A2A protocol transport layer and your custom agent logic. +It orchestrates the entire request lifecycle while maintaining protocol compliance and providing robust session management. + +## Core components + +### A2AServer + +The main server class implementing the complete A2A protocol. It serves as the central coordinator that: + +- **Validates** incoming requests against protocol specifications +- **Manages** concurrent sessions and task lifecycles +- **Orchestrates** communication between transport, storage, and business logic layers +- **Handles** all protocol operations: message sending, task querying, cancellation, push notifications + +The `A2AServer` accepts two required parameters: +* `AgentExecutor` which defines business logic implementation of the agent +* `AgentCard` which defines agent capabilities and metadata + +And a number of optional parameters that can be used to customize its storage and transport behavior. + +### AgentExecutor + +The `AgentExecutor` interface is where you implement your agent's core business logic. 
+It acts as the bridge between the A2A protocol and your specific AI agent capabilities. +To start the execution of your agent, you must implement the `execute` method, where you define your agent's logic. +To cancel the agent, you must implement the `cancel` method. + +```kotlin +class MyAgentExecutor : AgentExecutor { + override suspend fun execute( + context: RequestContext, + eventProcessor: SessionEventProcessor + ) { + // Agent logic here + } + + override suspend fun cancel( + context: RequestContext, + eventProcessor: SessionEventProcessor, + agentJob: Deferred? + ) { + // Cancel agent here, optional + } +} +``` + +The `RequestContext` provides rich information about the current request, +including the `contextId` and `taskId` of the current session, the `message` sent, and the `params` of the request. + +The `SessionEventProcessor` communicates with clients: +- **`sendMessage(message)`**: Send immediate responses (chat-style interactions) +- **`sendTaskEvent(event)`**: Send task-related updates (long-running operations) + +```kotlin +// For immediate responses (like chatbots) +eventProcessor.sendMessage( + Message( + messageId = generateId(), + role = Role.Agent, + parts = listOf(TextPart("Here's your answer!")), + contextId = context.contextId + ) +) + +// For task-based operations +eventProcessor.sendTaskEvent( + TaskStatusUpdateEvent( + contextId = context.contextId, + taskId = context.taskId, + status = TaskStatus( + state = TaskState.Working, + message = Message(/* progress update */), + timestamp = Clock.System.now() + ), + final = false // More updates to come + ) +) +``` + +### AgentCard + +The `AgentCard` serves as your agent's self-describing manifest. It tells clients what your agent can do, how to communicate with it, and what security requirements it has.
+ +```kotlin +val agentCard = AgentCard( + // Basic Identity + name = "Advanced Recipe Assistant", + description = "AI agent specialized in cooking advice, recipe generation, and meal planning", + version = "2.1.0", + protocolVersion = "0.3.0", + + // Communication Settings + url = "https://api.example.com/a2a", + preferredTransport = TransportProtocol.JSONRPC, + + // Optional: Multiple transport support + additionalInterfaces = listOf( + AgentInterface("https://api.example.com/a2a", TransportProtocol.JSONRPC), + ), + + // Capabilities Declaration + capabilities = AgentCapabilities( + streaming = true, // Support real-time responses + pushNotifications = true, // Send async notifications + stateTransitionHistory = true // Maintain task history + ), + + // Content Type Support + defaultInputModes = listOf("text/plain", "text/markdown", "image/jpeg"), + defaultOutputModes = listOf("text/plain", "text/markdown", "application/json"), + + // Define available security schemes + securitySchemes = mapOf( + "bearer" to HTTPAuthSecurityScheme( + scheme = "Bearer", + bearerFormat = "JWT", + description = "JWT token authentication" + ), + "api-key" to APIKeySecurityScheme( + `in` = In.Header, + name = "X-API-Key", + description = "API key for service authentication" + ) + ), + + // Specify security requirements (logical OR of requirements) + security = listOf( + mapOf("bearer" to listOf("read", "write")), // Option 1: JWT with read/write scopes + mapOf("api-key" to emptyList()) // Option 2: API key + ), + + // Enable extended card for authenticated users + supportsAuthenticatedExtendedCard = true, + + // Skills/Capabilities + skills = listOf( + AgentSkill( + id = "recipe-generation", + name = "Recipe Generation", + description = "Generate custom recipes based on ingredients, dietary restrictions, and preferences", + tags = listOf("cooking", "recipes", "nutrition"), + examples = listOf( + "Create a vegan pasta recipe with mushrooms", + "I have chicken, rice, and vegetables. 
What can I make?" + ) + ), + AgentSkill( + id = "meal-planning", + name = "Meal Planning", + description = "Plan weekly meals and generate shopping lists", + tags = listOf("meal-planning", "nutrition", "shopping") + ) + ), + + // Optional: Branding + iconUrl = "https://example.com/agent-icon.png", + documentationUrl = "https://docs.example.com/recipe-agent", + provider = AgentProvider( + organization = "CookingAI Inc.", + url = "https://cookingai.com" + ) +) +``` + +### Transport Layer + +The A2A itself supports multiple transport protocols for communicating with clients. +Currently, Koog provides implementations for JSON-RPC server transport over HTTP. + +#### HTTP JSON-RPC Transport + +```kotlin +val transport = HttpJSONRPCServerTransport(server) +transport.start( + engineFactory = CIO, // Ktor engine (CIO, Netty, Jetty) + port = 8080, // Server port + path = "/a2a", // API endpoint path + wait = true // Block until server stops +) +``` + +### Storage + +The A2A server uses a pluggable storage architecture that separates different types of data. +All storage implementations are optional and default to in-memory variants for development. + +- **TaskStorage**: Task lifecycle management - stores and manages task states, history, and artifacts +- **MessageStorage**: Conversation history - manages message history within conversation contexts +- **PushNotificationConfigStorage**: Webhook management - manages webhook configurations for asynchronous notifications + +## Quickstart + +### 1. Create AgentCard +Define your agent's capabilities and metadata. 
+```kotlin
+val agentCard = AgentCard(
+    name = "IO Assistant",
+    description = "AI agent specialized in input modification",
+    version = "2.1.0",
+    protocolVersion = "0.3.0",
+
+    // Communication Settings
+    url = "https://api.example.com/a2a",
+    preferredTransport = TransportProtocol.JSONRPC,
+
+    // Capabilities Declaration
+    capabilities =
+        AgentCapabilities(
+            streaming = true, // Support real-time responses
+            pushNotifications = true, // Send async notifications
+            stateTransitionHistory = true // Maintain task history
+        ),
+
+    // Content Type Support
+    defaultInputModes = listOf("text/plain", "text/markdown", "image/jpeg"),
+    defaultOutputModes = listOf("text/plain", "text/markdown", "application/json"),
+
+    // Skills/Capabilities
+    skills = listOf(
+        AgentSkill(
+            id = "echo",
+            name = "echo",
+            description = "Echoes back user messages",
+            tags = listOf("io"),
+        )
+    )
+)
+```
+
+
+### 2. Create an AgentExecutor
+The executor implements your agent's logic: it handles incoming requests and sends responses.
+
+```kotlin
+class EchoAgentExecutor : AgentExecutor {
+    override suspend fun execute(
+        context: RequestContext,
+        eventProcessor: SessionEventProcessor
+    ) {
+        val userMessage = context.params.message
+        val userText = userMessage.parts
+            .filterIsInstance<TextPart>()
+            .joinToString(" ") { it.text }
+
+        // Echo the user's message back
+        val response = Message(
+            messageId = UUID.randomUUID().toString(),
+            role = Role.Agent,
+            parts = listOf(TextPart("You said: $userText")),
+            contextId = context.contextId,
+            taskId = context.taskId
+        )
+
+        eventProcessor.sendMessage(response)
+    }
+}
+```
+
+### 3. Create the Server
+Pass the agent executor and agent card to the server.
+
+```kotlin
+val server = A2AServer(
+    agentExecutor = EchoAgentExecutor(),
+    agentCard = agentCard
+)
+```
+
+### 4. Add Transport Layer
+Create a transport layer and start the server.
+```kotlin +// HTTP JSON-RPC transport +val transport = HttpJSONRPCServerTransport(server) +transport.start( + engineFactory = CIO, + port = 8080, + path = "/agent", + wait = true +) +``` + +## Agent Implementation Patterns + +### Simple Response Agent +If your agent only needs to respond to a single message, you can implement it as a simple agent. +It can be also used if agent execution logic is not complex and time-consuming. + +```kotlin +class SimpleAgentExecutor : AgentExecutor { + override suspend fun execute( + context: RequestContext, + eventProcessor: SessionEventProcessor + ) { + val response = Message( + messageId = UUID.randomUUID().toString(), + role = Role.Agent, + parts = listOf(TextPart("Hello from agent!")), + contextId = context.contextId, + taskId = context.taskId + ) + + eventProcessor.sendMessage(response) + } +} +``` + +### Task-Based Agent +If the execution logic of your agent is complex and requires multiple steps, you can implement it as a task-based agent. +It can be also used if agent execution logic is time-consuming and suspending. +```kotlin +class TaskAgentExecutor : AgentExecutor { + override suspend fun execute( + context: RequestContext, + eventProcessor: SessionEventProcessor + ) { + // Send working status + eventProcessor.sendTaskEvent( + TaskStatusUpdateEvent( + contextId = context.contextId, + taskId = context.taskId, + status = TaskStatus( + state = TaskState.Working, + timestamp = Clock.System.now() + ), + final = false + ) + ) + + // Do work... 
+ + // Send completion + eventProcessor.sendTaskEvent( + TaskStatusUpdateEvent( + contextId = context.contextId, + taskId = context.taskId, + status = TaskStatus( + state = TaskState.Completed, + timestamp = Clock.System.now() + ), + final = true + ) + ) + } +} +``` diff --git a/docs/docs/act-ai-agent.md b/docs/docs/act-ai-agent.md index 664fed43db..fc3cf8165d 100644 --- a/docs/docs/act-ai-agent.md +++ b/docs/docs/act-ai-agent.md @@ -112,7 +112,7 @@ val observed = functionalAIAgent( toolRegistry = tools, featureContext = { install(EventHandler) { - onToolCall { e -> println("Tool called: ${'$'}{e.tool.name}, args: ${'$'}{e.toolArgs}") } + onToolCallStarting { e -> println("Tool called: ${'$'}{e.tool.name}, args: ${'$'}{e.toolArgs}") } } } ) { input -> diff --git a/docs/docs/agent-event-handlers.md b/docs/docs/agent-event-handlers.md new file mode 100644 index 0000000000..c7c2ef41b6 --- /dev/null +++ b/docs/docs/agent-event-handlers.md @@ -0,0 +1,93 @@ +# Event handlers + +You can monitor and respond to specific events during the agent workflow by using event handlers for logging, testing, debugging, and extending agent behavior. + +## Feature overview + +The EventHandler feature lets you hook into various agent events. It serves as an event delegation mechanism that: + +- Manages the lifecycle of AI agent operations. +- Provides hooks for monitoring and responding to different stages of the workflow. +- Enables error handling and recovery. +- Facilitates tool invocation tracking and result processing. + + + + +### Installation and configuration + +The EventHandler feature integrates with the agent workflow through the `EventHandler` class, +which provides a way to register callbacks for different agent events, and can be installed as a feature in the agent configuration. For details, see [API reference](https://api.koog.ai/agents/agents-features/agents-features-event-handler/ai.koog.agents.features.eventHandler.feature/-event-handler/index.html). 
+ +To install the feature and configure event handlers for the agent, do the following: + + + + +```kotlin +handleEvents { + // Handle tool calls + onToolCallStarting { eventContext -> + println("Tool called: ${eventContext.tool.name} with args ${eventContext.toolArgs}") + } + // Handle event triggered when the agent completes its execution + onAgentCompleted { eventContext -> + println("Agent finished with result: ${eventContext.result}") + } + + // Other event handlers +} +``` + + +For more details about event handler configuration, see [API reference](https://api.koog.ai/agents/agents-features/agents-features-event-handler/ai.koog.agents.features.eventHandler.feature/-event-handler-config/index.html). + +You can also set up event handlers using the `handleEvents` extension function when creating an agent. +This function also installs the event handler feature and configures event handlers for the agent. Here is an example: + + +```kotlin +val agent = AIAgent( + promptExecutor = simpleOllamaAIExecutor(), + llmModel = OllamaModels.Meta.LLAMA_3_2, +){ + handleEvents { + // Handle tool calls + onToolCallStarting { eventContext -> + println("Tool called: ${eventContext.tool.name} with args ${eventContext.toolArgs}") + } + // Handle event triggered when the agent completes its execution + onAgentCompleted { eventContext -> + println("Agent finished with result: ${eventContext.result}") + } + + // Other event handlers + } +} +``` + diff --git a/docs/docs/agent-events.md b/docs/docs/agent-events.md index a0674636e6..9ced231f5a 100644 --- a/docs/docs/agent-events.md +++ b/docs/docs/agent-events.md @@ -4,99 +4,431 @@ Agent events are actions or interactions that occur as part of an agent workflow - Agent lifecycle events - Strategy events -- Node events +- Node execution events - LLM call events -- Tool call events +- LLM streaming events +- Tool execution events -## Event handlers +Note: Feature events are defined in the agents-core module and live under the package 
`ai.koog.agents.core.feature.model.events`. Features such as `agents-features-trace`, `agents-features-debugger`, and `agents-features-event-handler` consume these events to process and forward messages created during agent execution. -You can monitor and respond to specific events during the agent workflow by using event handlers for logging, testing, debugging, and extending agent behavior. +## Predefined event types -The EventHandler feature lets you hook into various agent events. It serves as an event delegation mechanism that: +Koog provides predefined event types that can be used in custom message processors. The predefined events can be +classified into several categories, depending on the entity they relate to: -- Manages the lifecycle of AI agent operations. -- Provides hooks for monitoring and responding to different stages of the workflow. -- Enables error handling and recovery. -- Facilitates tool invocation tracking and result processing. +- [Agent events](#agent-events) +- [Strategy events](#strategy-events) +- [Node events](#node-events) +- [LLM call events](#llm-call-events) +- [LLM streaming events](#llm-streaming-events) +- [Tool execution events](#tool-execution-events) - +Represents the start of an agent run. Includes the following fields: +| Name | Data type | Required | Default | Description | +|------------|-----------|----------|---------|---------------------------------------------------------| +| `agentId` | String | Yes | | The unique identifier of the AI agent. | +| `runId` | String | Yes | | The unique identifier of the AI agent run. | -### Installation and configuration +#### AgentCompletedEvent -The EventHandler feature integrates with the agent workflow through the `EventHandler` class, -which provides a way to register callbacks for different agent events, and can be installed as a feature in the agent configuration. For details, see [API reference](https://api.koog. 
-ai/agents/agents-features/agents-features-event-handler/ai.koog.agents.local.features.eventHandler.feature/-event-handler/index.html). +Represents the end of an agent run. Includes the following fields: -To install the feature and configure event handlers for the agent, do the following: +| Name | Data type | Required | Default | Description | +|------------|-----------|----------|---------|---------------------------------------------------------------------| +| `agentId` | String | Yes | | The unique identifier of the AI agent. | +| `runId` | String | Yes | | The unique identifier of the AI agent run. | +| `result` | String | Yes | | The result of the agent run. Can be `null` if there is no result. | + +#### AgentExecutionFailedEvent + +Represents the occurrence of an error during an agent run. Includes the following fields: + +| Name | Data type | Required | Default | Description | +|-----------|---------------|----------|---------|-----------------------------------------------------------------------------------------------------------------| +| `agentId` | String | Yes | | The unique identifier of the AI agent. | +| `runId` | String | Yes | | The unique identifier of the AI agent run. | +| `error` | AIAgentError | Yes | | The specific error that occurred during the agent run. For more information, see [AIAgentError](#aiagenterror). | + +#### AgentClosingEvent + +Represents the closure or termination of an agent. Includes the following fields: + +| Name | Data type | Required | Default | Description | +|-----------|-----------|----------|---------|---------------------------------------------------------| +| `agentId` | String | Yes | | The unique identifier of the AI agent. | + + +The `AIAgentError` class provides more details about an error that occurred during an agent run. 
Includes the following fields: + +| Name | Data type | Required | Default | Description | +|--------------|-----------|----------|---------|------------------------------------------------------------------| +| `message` | String | Yes | | The message that provides more details about the specific error. | +| `stackTrace` | String | Yes | | The collection of stack records until the last executed code. | +| `cause` | String | No | null | The cause of the error, if available. | + +### Strategy events + +#### GraphStrategyStartingEvent + +Represents the start of a graph-based strategy run. Includes the following fields: + +| Name | Data type | Required | Default | Description | +|-----------------|------------------------|----------|---------|----------------------------------------------------------| +| `runId` | String | Yes | | The unique identifier of the strategy run. | +| `strategyName` | String | Yes | | The name of the strategy. | +| `graph` | StrategyEventGraph | Yes | | The graph structure representing the strategy workflow. | + +#### FunctionalStrategyStartingEvent + +Represents the start of a functional strategy run. Includes the following fields: + +| Name | Data type | Required | Default | Description | +|-----------------|-----------|----------|---------|--------------------------------------------| +| `runId` | String | Yes | | The unique identifier of the strategy run. | +| `strategyName` | String | Yes | | The name of the strategy. | + +#### StrategyCompletedEvent + +Represents the end of a strategy run. Includes the following fields: + +| Name | Data type | Required | Default | Description | +|----------------|-----------|----------|---------|--------------------------------------------------------------------------------| +| `runId` | String | Yes | | The unique identifier of the strategy run. | +| `strategyName` | String | Yes | | The name of the strategy. | +| `result` | String | Yes | | The result of the run. Can be `null` if there is no result. 
| + +### Node events + +#### NodeExecutionStartingEvent + +Represents the start of a node run. Includes the following fields: + +| Name | Data type | Required | Default | Description | +|------------|-----------|----------|---------|--------------------------------------------------| +| `runId` | String | Yes | | The unique identifier of the strategy run. | +| `nodeName` | String | Yes | | The name of the node whose run started. | +| `input` | String | Yes | | The input value for the node. | + +#### NodeExecutionCompletedEvent + +Represents the end of a node run. Includes the following fields: + +| Name | Data type | Required | Default | Description | +|------------|-----------|----------|---------|--------------------------------------------------| +| `runId` | String | Yes | | The unique identifier of the strategy run. | +| `nodeName` | String | Yes | | The name of the node whose run ended. | +| `input` | String | Yes | | The input value for the node. | +| `output` | String | Yes | | The output value produced by the node. | + +#### NodeExecutionFailedEvent + +Represents an error that occurred during a node run. Includes the following fields: + +| Name | Data type | Required | Default | Description | +|------------|--------------|----------|---------|-----------------------------------------------------------------------------------------------------------------| +| `runId` | String | Yes | | The unique identifier of the strategy run. | +| `nodeName` | String | Yes | | The name of the node where the error occurred. | +| `error` | AIAgentError | Yes | | The specific error that occurred during the node run. For more information, see [AIAgentError](#aiagenterror). | + +### LLM call events + +#### LLMCallStartingEvent + +Represents the start of an LLM call. 
Includes the following fields: + +| Name | Data type | Required | Default | Description | +|----------|--------------------|----------|---------|------------------------------------------------------------------------------------| +| `runId` | String | Yes | | The unique identifier of the LLM run. | +| `prompt` | Prompt | Yes | | The prompt that is sent to the model. For more information, see [Prompt](#prompt). | +| `model` | String | Yes | | The model identifier in the format `llm_provider:model_id`. | +| `tools` | List | Yes | | The list of tools that the model can call. | + + +The `Prompt` class represents a data structure for a prompt, consisting of a list of messages, a unique identifier, and +optional parameters for language model settings. Includes the following fields: + +| Name | Data type | Required | Default | Description | +|------------|---------------------|----------|-------------|--------------------------------------------------------------| +| `messages` | List | Yes | | The list of messages that the prompt consists of. | +| `id` | String | Yes | | The unique identifier for the prompt. | +| `params` | LLMParams | No | LLMParams() | The settings that control the way the LLM generates content. | + +#### LLMCallCompletedEvent + +Represents the end of an LLM call. Includes the following fields: + +| Name | Data type | Required | Default | Description | +|----------------------|--------------------------------|----------|---------|-----------------------------------------------------------| +| `runId` | String | Yes | | The unique identifier of the LLM run. | +| `prompt` | Prompt | Yes | | The prompt used in the call. | +| `model` | String | Yes | | The model identifier in the format `llm_provider:model_id`.| +| `responses` | List | Yes | | One or more responses returned by the model. | +| `moderationResponse` | ModerationResult | No | null | The moderation response, if any. 
| + +### LLM streaming events + +#### LLMStreamingStartingEvent + +Represents the start of an LLM streaming call. Includes the following fields: + +| Name | Data type | Required | Default | Description | +|----------|--------------|----------|---------|-------------------------------------------------------------| +| `runId` | String | Yes | | The unique identifier of the LLM run. | +| `prompt` | Prompt | Yes | | The prompt that is sent to the model. | +| `model` | String | Yes | | The model identifier in the format `llm_provider:model_id`. | +| `tools` | List | Yes | | The list of tools that the model can call. | + +#### LLMStreamingFrameReceivedEvent + +Represents a streaming frame received from the LLM. Includes the following fields: + +| Name | Data type | Required | Default | Description | +|----------|-------------|----------|---------|--------------------------------------------------| +| `runId` | String | Yes | | The unique identifier of the LLM run. | +| `frame` | StreamFrame | Yes | | The frame received from the stream. | + +#### LLMStreamingFailedEvent + +Represents the occurrence of an error during an LLM streaming call. Includes the following fields: + +| Name | Data type | Required | Default | Description | +|---------|--------------|----------|---------|-----------------------------------------------------------------------------------------------------------------| +| `runId` | String | Yes | | The unique identifier of the LLM run. | +| `error` | AIAgentError | Yes | | The specific error that occurred during streaming. For more information, see [AIAgentError](#aiagenterror). | + +#### LLMStreamingCompletedEvent + +Represents the end of an LLM streaming call. Includes the following fields: + +| Name | Data type | Required | Default | Description | +|----------|--------------|----------|---------|-------------------------------------------------------------| +| `runId` | String | Yes | | The unique identifier of the LLM run. 
| +| `prompt` | Prompt | Yes | | The prompt that is sent to the model. | +| `model` | String | Yes | | The model identifier in the format `llm_provider:model_id`. | +| `tools` | List | Yes | | The list of tools that the model can call. | + +### Tool execution events + +#### ToolExecutionStartingEvent + +Represents the event of a model calling a tool. Includes the following fields: + +| Name | Data type | Required | Default | Description | +|---------------|-------------|----------|---------|---------------------------------------------------------| +| `runId` | String | Yes | | The unique identifier of the strategy/agent run. | +| `toolCallId` | String | No | null | The identifier of the tool call, if available. | +| `toolName` | String | Yes | | The name of the tool. | +| `toolArgs` | JsonObject | Yes | | The arguments that are provided to the tool. | + +#### ToolValidationFailedEvent + +Represents the occurrence of a validation error during a tool call. Includes the following fields: + +| Name | Data type | Required | Default | Description | +|---------------|-------------|----------|---------|---------------------------------------------------------| +| `runId` | String | Yes | | The unique identifier of the strategy/agent run. | +| `toolCallId` | String | No | null | The identifier of the tool call, if available. | +| `toolName` | String | Yes | | The name of the tool for which validation failed. | +| `toolArgs` | JsonObject | Yes | | The arguments that are provided to the tool. | +| `error` | String | Yes | | The validation error message. | + +#### ToolExecutionFailedEvent + +Represents a failure to execute a tool. Includes the following fields: + +| Name | Data type | Required | Default | Description | +|---------------|--------------|----------|---------|-------------------------------------------------------------------------------------------------------------------------| +| `runId` | String | Yes | | The unique identifier of the strategy/agent run. 
| +| `toolCallId` | String | No | null | The identifier of the tool call, if available. | +| `toolName` | String | Yes | | The name of the tool. | +| `toolArgs` | JsonObject | Yes | | The arguments that are provided to the tool. | +| `error` | AIAgentError | Yes | | The specific error that occurred when trying to call a tool. For more information, see [AIAgentError](#aiagenterror). | + +#### ToolExecutionCompletedEvent + +Represents a successful tool call with the return of a result. Includes the following fields: + +| Name | Data type | Required | Default | Description | +|---------------|------------|----------|---------|------------------------------------------| +| `runId` | String | Yes | | The unique identifier of the run. | +| `toolCallId` | String | No | null | The identifier of the tool call. | +| `toolName` | String | Yes | | The name of the tool. | +| `toolArgs` | JsonObject | Yes | | The arguments provided to the tool. | +| `result` | String | Yes | | The result of the tool call (nullable). | + +## FAQ and troubleshooting + +The following section includes commonly asked questions and answers related to the Tracing feature. + +### How do I trace only specific parts of my agent's execution? + +Use the `messageFilter` property to filter events. 
For example, to trace only node execution: - - ```kotlin -handleEvents { - // Handle tool calls - onToolExecutionStarting { eventContext -> - println("Tool called: ${eventContext.tool.name} with args ${eventContext.toolArgs}") - } - // Handle event triggered when the agent completes its execution - onAgentCompleted { eventContext -> - println("Agent finished with result: ${eventContext.result}") +install(Tracing) { + val fileWriter = TraceFeatureMessageFileWriter( + outputPath, + { path: Path -> SystemFileSystem.sink(path).buffered() } + ) + addMessageProcessor(fileWriter) + + // Only trace LLM calls + fileWriter.setMessageFilter { message -> + message is LLMCallStartingEvent || message is LLMCallCompletedEvent } - - // Other event handlers } ``` - + -For more details about event handler configuration, see [API reference](https://api.koog.ai/agents/agents-features/agents-features-event-handler/ai.koog.agents.local.features.eventHandler.feature/-event-handler-config/index.html). +### Can I use multiple message processors? -You can also set up event handlers using the `handleEvents` extension function when creating an agent. -This function also installs the event handler feature and configures event handlers for the agent. Here is an example: +Yes, you can add multiple message processors to trace to different destinations simultaneously: + ```kotlin -val agent = AIAgent( - promptExecutor = simpleOllamaAIExecutor(), - llmModel = OllamaModels.Meta.LLAMA_3_2, -){ - handleEvents { - // Handle tool calls - onToolExecutionStarting { eventContext -> - println("Tool called: ${eventContext.tool.name} with args ${eventContext.toolArgs}") +install(Tracing) { + addMessageProcessor(TraceFeatureMessageLogWriter(logger)) + addMessageProcessor(TraceFeatureMessageFileWriter(outputPath, syncOpener)) + addMessageProcessor(TraceFeatureMessageRemoteWriter(connectionConfig)) +} +``` + + +### How can I create a custom message processor? 
+ +Implement the `FeatureMessageProcessor` interface: + + + +```kotlin +class CustomTraceProcessor : FeatureMessageProcessor() { + + // Current open state of the processor + private var _isOpen = MutableStateFlow(false) + + override val isOpen: StateFlow + get() = _isOpen.asStateFlow() + + override suspend fun processMessage(message: FeatureMessage) { + // Custom processing logic + when (message) { + is NodeExecutionStartingEvent -> { + // Process node start event + } + + is LLMCallCompletedEvent -> { + // Process LLM call end event + } + // Handle other event types } + } - // Other event handlers + override suspend fun close() { + // Close connections of established } } + +// Use your custom processor +install(Tracing) { + addMessageProcessor(CustomTraceProcessor()) +} ``` - + + +For more information about existing event types that can be handled by message processors, see [Predefined event types](#predefined-event-types). diff --git a/docs/docs/agent-persistency.md b/docs/docs/agent-persistence.md similarity index 70% rename from docs/docs/agent-persistency.md rename to docs/docs/agent-persistence.md index d50b75b4f1..184b2a5bd2 100644 --- a/docs/docs/agent-persistency.md +++ b/docs/docs/agent-persistence.md @@ -1,6 +1,6 @@ -# Agent Persistency +# Agent Persistence -Agent Persistency is a feature that provides checkpoint functionality for AI agents in the Koog framework. +Agent Persistence is a feature that provides checkpoint functionality for AI agents in the Koog framework. It lets you save and restore the state of an agent at specific points during execution, enabling capabilities such as: - Resuming agent execution from a specific point @@ -22,7 +22,7 @@ Checkpoints are identified by unique IDs and are associated with a specific agen ## Prerequisites -The Agent Persistency feature requires that all nodes in your agent's strategy have unique names. +The Agent Persistence feature requires that all nodes in your agent's strategy have unique names. 
This is enforced when the feature is installed: + Make sure to set unique names for nodes in your graph. ## Installation -To use the Agent Persistency feature, add it to your agent's configuration: +To use the Agent Persistence feature, add it to your agent's configuration: + ## Configuration options -The Agent Persistency feature has two main configuration options: +The Agent Persistence feature has two main configuration options: - **Storage provider**: the provider used to save and retrieve checkpoints. - **Continuous persistence**: automatic creation of checkpoints after each node is run. @@ -85,8 +85,8 @@ Set the storage provider that will be used to save and retrieve checkpoints: ```kotlin -install(Persistency) { - storage = InMemoryPersistencyStorageProvider("in-memory-storage") +install(Persistence) { + storage = InMemoryPersistenceStorageProvider() } ``` - + The framework includes the following built-in providers: -- `InMemoryPersistencyStorageProvider`: stores checkpoints in memory (lost when the application restarts). -- `FilePersistencyStorageProvider`: persists checkpoints to the file system. -- `NoPersistencyStorageProvider`: a no-op implementation that does not store checkpoints. This is the default provider. +- `InMemoryPersistenceStorageProvider`: stores checkpoints in memory (lost when the application restarts). +- `FilePersistenceStorageProvider`: persists checkpoints to the file system. +- `NoPersistenceStorageProvider`: a no-op implementation that does not store checkpoints. This is the default provider. -You can also implement custom storage providers by implementing the `PersistencyStorageProvider` interface. +You can also implement custom storage providers by implementing the `PersistenceStorageProvider` interface. For more information, see [Custom storage providers](#custom-storage-providers). 
### Continuous persistence @@ -124,8 +124,8 @@ To activate continuous persistence, use the code below: ```kotlin -install(Persistency) { - enableAutomaticPersistency = true +install(Persistence) { + enableAutomaticPersistence = true } ``` - + When activated, the agent will automatically create a checkpoint after each node is executed, allowing for fine-grained recovery. @@ -157,7 +157,7 @@ To learn how to create a checkpoint at a specific point in your agent's executio + ### Restoring from a checkpoint @@ -188,20 +188,20 @@ To restore the state of an agent from a specific checkpoint, follow the code sam ```kotlin suspend fun example(context: AIAgentContext, checkpointId: String) { // Roll back to a specific checkpoint - context.persistency().rollbackToCheckpoint(checkpointId, context) + context.persistence().rollbackToCheckpoint(checkpointId, context) // Or roll back to the latest checkpoint - context.persistency().rollbackToLatestCheckpoint(context) + context.persistence().rollbackToLatestCheckpoint(context) } ``` - + #### Rolling back all side-effects produced by tools @@ -222,12 +222,12 @@ And now you would like to roll back to a checkpoint. Restoring the agent's state be sufficient to achieve the exact state of the world before the checkpoint. You should also restore the side-effects produced by your tool calls. In our example, this would mean removing `Maria` and `Daniel` from the database. -With Koog Persistency you can achieve that by providing a `RollbackToolRegistry` to `Persistency` feature config: +With Koog Persistence you can achieve that by providing a `RollbackToolRegistry` to `Persistence` feature config: ```kotlin -install(Persistency) { - enableAutomaticPersistency = true +install(Persistence) { + enableAutomaticPersistence = true rollbackToolRegistry = RollbackToolRegistry { // For every `createUser` tool call there will be a `removeUser` invocation in the reverse order // when rolling back to the desired execution point. 
@@ -259,27 +259,27 @@ install(Persistency) { } ``` - + ### Using extension functions -The Agent Persistency feature provides convenient extension functions for working with checkpoints: +The Agent Persistence feature provides convenient extension functions for working with checkpoints: ```kotlin suspend fun example(context: AIAgentContext) { // Access the checkpoint feature - val checkpointFeature = context.persistency() + val checkpointFeature = context.persistence() // Or perform an action with the checkpoint feature - context.withPersistency { ctx -> + context.withPersistence { ctx -> // 'this' is the checkpoint feature createCheckpoint( agentContext = ctx, @@ -291,17 +291,17 @@ suspend fun example(context: AIAgentContext) { } } ``` - + ## Advanced usage ### Custom storage providers -You can implement custom storage providers by implementing the `PersistencyStorageProvider` interface: +You can implement custom storage providers by implementing the `PersistenceStorageProvider` interface: ```kotlin -class MyCustomStorageProvider : PersistencyStorageProvider { +class MyCustomStorageProvider : PersistenceStorageProvider { override suspend fun getCheckpoints(agentId: String): List { // Implementation } - override suspend fun saveCheckpoint(agentCheckpointData: AgentCheckpointData) { + override suspend fun saveCheckpoint(agentId: String, agentCheckpointData: AgentCheckpointData) { // Implementation } @@ -325,29 +325,29 @@ class MyCustomStorageProvider : PersistencyStorageProvider { } ``` - + -To use your custom provider in the feature configuration, set it as the storage when configuring the Agent Persistency +To use your custom provider in the feature configuration, set it as the storage when configuring the Agent Persistence feature in your agent. 
```kotlin -install(Persistency) { +install(Persistence) { storage = MyCustomStorageProvider() } ``` - + ### Setting execution points @@ -375,7 +375,7 @@ For advanced control, you can directly set the execution point of an agent: + This allows for more fine-grained control over the agent's state beyond just restoring from checkpoints. diff --git a/docs/docs/complex-workflow-agents.md b/docs/docs/complex-workflow-agents.md index 8fd84409e8..787c81a456 100644 --- a/docs/docs/complex-workflow-agents.md +++ b/docs/docs/complex-workflow-agents.md @@ -3,6 +3,9 @@ In addition to single-run agents, the `AIAgent` class lets you build agents that handle complex workflows by defining custom strategies, tools, configurations, and custom input/output types. +!!! tip + If you are new to Koog and want to create the simplest agent, start with [Single-run agents](single-run-agents.md). + The process of creating and configuring such an agent typically includes the following steps: 1. Provide a prompt executor to communicate with the LLM. @@ -20,7 +23,7 @@ The process of creating and configuring such an agent typically includes the fol Use environment variables or a secure configuration management system to store your API keys. Avoid hardcoding API keys directly in your source code. -## Creating a single-run agent +## Creating a complex workflow agent ### 1. Add dependencies @@ -39,7 +42,7 @@ For all available installation methods, see [Installation](index.md#installation Prompt executors manage and run prompts. You can choose a prompt executor based on the LLM provider you plan to use. Also, you can create a custom prompt executor using one of the available LLM clients. -To learn more, see [Prompt executors](prompt-api.md#prompt-executors). +To learn more, see [Prompt executors](prompt-api.md#running-prompts-with-prompt-executors). 
For example, to provide the OpenAI prompt executor, you need to call the `simpleOpenAIExecutor` function and provide it with the API key required for authentication with the OpenAI service: @@ -115,6 +118,7 @@ val processNode by node { input -> } ``` + !!! tip There are also pre-defined nodes that you can use in your agent strategy. To learn more, see [Predefined nodes and components](nodes-and-components.md). @@ -165,6 +169,7 @@ edge(sourceNode forwardTo targetNode transformed { output -> edge(sourceNode forwardTo targetNode onCondition { it.isNotEmpty() } transformed { it.uppercase() }) ``` + #### 3.2. Implement the strategy To implement the agent strategy, call the `strategy` function and define nodes and edges. For example: @@ -391,7 +396,7 @@ fun main() { ## Working with structured data -The `AIAgent` can process structured data from LLM outputs. For more details, see [Structured data processing](structured-data.md). +The `AIAgent` can process structured data from LLM outputs. For more details, see [Structured data processing](structured-output.md). 
## Using parallel tool calls diff --git a/docs/docs/examples/Calculator.md b/docs/docs/examples/Calculator.md index eff11b74c2..f50a809719 100644 --- a/docs/docs/examples/Calculator.md +++ b/docs/docs/examples/Calculator.md @@ -170,13 +170,13 @@ val agent = AIAgent( toolRegistry = toolRegistry ) { handleEvents { - onToolCall { e -> + onToolCallStarting { e -> println("Tool called: ${e.tool.name}, args=${e.toolArgs}") } - onAgentRunError { e -> + onAgentExecutionFailed { e -> println("Agent error: ${e.throwable.message}") } - onAgentFinished { e -> + onAgentCompleted { e -> println("Final result: ${e.result}") } } diff --git a/docs/docs/examples/UnityMcp.md b/docs/docs/examples/UnityMcp.md index 65cf10086d..8100f244c9 100644 --- a/docs/docs/examples/UnityMcp.md +++ b/docs/docs/examples/UnityMcp.md @@ -123,17 +123,17 @@ runBlocking { install(Tracing) install(EventHandler) { - onBeforeAgentStarted { eventContext -> - println("OnBeforeAgentStarted first (strategy: ${strategy.name})") + onAgentStarting { eventContext -> + println("OnAgentStarting first (strategy: ${strategy.name})") } - onBeforeAgentStarted { eventContext -> - println("OnBeforeAgentStarted second (strategy: ${strategy.name})") + onAgentStarting { eventContext -> + println("OnAgentStarting second (strategy: ${strategy.name})") } - onAgentFinished { eventContext -> + onAgentCompleted { eventContext -> println( - "OnAgentFinished (agent id: ${eventContext.agentId}, result: ${eventContext.result})" + "OnAgentCompleted (agent id: ${eventContext.agentId}, result: ${eventContext.result})" ) } } diff --git a/docs/docs/features-overview.md b/docs/docs/features-overview.md index 0e3b8de886..d68f687ce3 100644 --- a/docs/docs/features-overview.md +++ b/docs/docs/features-overview.md @@ -8,6 +8,11 @@ Agent features provide a way to extend and enhance the functionality of AI agent The Koog framework implements the following features: -- [Event Handler](agent-events.md) +- [Event Handler](agent-event-handlers.md) - 
[Tracing](tracing.md) - [Agent Memory](agent-memory.md) +- [OpenTelemetry](opentelemetry-support.md) +- [Agent Persistence (Snapshots)](agent-persistence.md) +- Debugger +- Tokenizer +- SQL Persistence Providers diff --git a/docs/docs/functional-agents.md b/docs/docs/functional-agents.md new file mode 100644 index 0000000000..cd1605f342 --- /dev/null +++ b/docs/docs/functional-agents.md @@ -0,0 +1,280 @@ +# Functional agents + +Functional agents are lightweight AI agents that operate without building complex strategy graphs. +Instead, the agent logic is implemented as a lambda function that handles user input, interacts with an LLM, +optionally calls tools, and produces a final output. It can perform a single LLM call, process multiple LLM calls in sequence, or loop based on user input, as well as LLM and tool outputs. + +!!! tip + - If you already have a simple [single-run agent](single-run-agents.md) as your first MVP, but run into task-specific limitations, use a functional agent to prototype custom logic. You can implement custom control flows in plain Kotlin while still using most Koog features, including history compression and automatic state management. + - For production-grade needs, refactor your functional agent into a [complex workflow agent](complex-workflow-agents.md) with strategy graphs. This provides persistence with controllable rollbacks for fault-tolerance and advanced OpenTelemetry tracing with nested graph events. + +This page guides you through the steps necessary to create a minimal functional agent and extend it with tools. + +## Prerequisites + +Before you start, make sure that you have the following: + +- A working Kotlin/JVM project with Gradle. +- Java 17+ installed. +- A valid API key from the LLM provider used to implement an AI agent. For a list of all available providers, refer to [Overview](index.md). +- (Optional) Ollama installed and running locally if you use this provider. + +!!! 
tip + Use environment variables or a secure configuration management system to store your API keys. + Avoid hardcoding API keys directly in your source code. + +## Add dependencies + +The `AIAgent` class is the main class for creating agents in Koog. +Include the following dependency in your build configuration to use the class functionality: + +``` +dependencies { + implementation("ai.koog:koog-agents:VERSION") +} +``` +For all available installation methods, see [Installation](index.md#installation). + +## Create a minimal functional agent + +To create a minimal functional agent, do the following: + +1. Choose the input and output types that the agent handles and create a corresponding `AIAgent` instance. + In this guide, we use `AIAgent`, which means the agent receives and returns `String`. +2. Provide the required parameters, including a system prompt, prompt executor, and LLM. +3. Define the agent logic with a lambda function wrapped into the `functionalStrategy {...}` DSL method. + +Here is an example of a minimal functional agent that sends user text to a specified LLM and returns a single assistant message. + + + +```kotlin +// Create an AIAgent instance and provide a system prompt, prompt executor, and LLM +val mathAgent = AIAgent( + systemPrompt = "You are a precise math assistant.", + promptExecutor = simpleOllamaAIExecutor(), + llmModel = OllamaModels.Meta.LLAMA_3_2, + strategy = functionalStrategy { input -> // Define the agent logic + // Make one LLM call + val response = requestLLM(input) + // Extract and return the assistant message content from the response + response.asAssistantMessage().content + } +) + +// Run the agent with a user input and print the result +val result = mathAgent.run("What is 12 × 9?") +println(result) +``` + + +The agent can produce the following output: + +``` +The answer to 12 × 9 is 108. +``` + +This agent makes a single LLM call and returns the assistant message content. 
+You can extend the agent logic to handle multiple sequential LLM calls. For example: + + + +```kotlin +// Create an AIAgent instance and provide a system prompt, prompt executor, and LLM +val mathAgent = AIAgent( + systemPrompt = "You are a precise math assistant.", + promptExecutor = simpleOllamaAIExecutor(), + llmModel = OllamaModels.Meta.LLAMA_3_2, + strategy = functionalStrategy { input -> // Define the agent logic + // The first LLM call to produce an initial draft based on the user input + val draft = requestLLM("Draft: $input").asAssistantMessage().content + // The second LLM call to improve the draft by prompting the LLM again with the draft content + val improved = requestLLM("Improve and clarify: $draft").asAssistantMessage().content + // The final LLM call to format the improved text and return the final formatted result + requestLLM("Format the result as bold: $improved").asAssistantMessage().content + } +) + +// Run the agent with a user input and print the result +val result = mathAgent.run("What is 12 × 9?") +println(result) +``` + + +The agent can produce the following output: + +``` +When multiplying 12 by 9, we can break it down as follows: + +**12 (tens) × 9 = 108** + +Alternatively, we can also use the distributive property to calculate this: + +**(10 + 2) × 9** += **10 × 9 + 2 × 9** += **90 + 18** += **108** +``` + +## Add tools + +In many cases, a functional agent needs to complete specific tasks, such as reading and writing data or calling APIs. +In Koog, you expose such capabilities as tools and let the LLM call them in the agent logic. + +This chapter takes the minimal functional agent created above and demonstrates how to extend the agent logic using tools. + + +1) Create an annotation-based tool. For more details, see [Annotation-based tools](annotation-based-tools.md). 
+ + +```kotlin +@LLMDescription("Simple multiplier") +class MathTools : ToolSet { + @Tool + @LLMDescription("Multiplies two numbers and returns the result") + fun multiply(a: Int, b: Int): Int { + val result = a * b + return result + } +} +``` + + +To learn more about available tools, refer to the [Tool overview](tools-overview.md). + +2) Register the tool to make it available to the agent. + + + +```kotlin +val toolRegistry = ToolRegistry { + tools(MathTools()) +} +``` + + +3) Pass the tool registry to the agent to enable the LLM to request and use the available tools. + +4) Extend the agent logic to identify tool calls, execute the requested tools, send their results back to the LLM, and repeat the process until no tool calls remain. + +!!! note + Use a loop only if the LLM continues to issue tool calls. + + + +```kotlin +val mathWithTools = AIAgent( + systemPrompt = "You are a precise math assistant. When multiplication is needed, use the multiplication tool.", + promptExecutor = simpleOllamaAIExecutor(), + llmModel = OllamaModels.Meta.LLAMA_3_2, + toolRegistry = toolRegistry, + strategy = functionalStrategy { input -> // Define the agent logic extended with tool calls + // Send the user input to the LLM + var responses = requestLLMMultiple(input) + + // Only loop while the LLM requests tools + while (responses.containsToolCalls()) { + // Extract tool calls from the response + val pendingCalls = extractToolCalls(responses) + // Execute the tools and return the results + val results = executeMultipleTools(pendingCalls) + // Send the tool results back to the LLM. 
The LLM may call more tools or return a final output + responses = sendMultipleToolResults(results) + } + + // When no tool calls remain, extract and return the assistant message content from the response + responses.single().asAssistantMessage().content + } +) + +// Run the agent with a user input and print the result +val reply = mathWithTools.run("Please multiply 12.5 and 4, then add 10 to the result.") +println(reply) +``` + + +The agent can produce the following output: + +``` +Here is the step-by-step solution: + +1. Multiply 12.5 and 4: + 12.5 × 4 = 50 + +2. Add 10 to the result: + 50 + 10 = 60 +``` + +## What's next + +- Learn how to return structured data using the [Structured output API](structured-output.md). +- Experiment with adding more [tools](tools-overview.md) to the agent. +- Improve observability with the [EventHandler](agent-events.md) feature. +- Learn how to handle long-running conversations with [History compression](history-compression.md). diff --git a/docs/docs/index.md b/docs/docs/index.md index 3762feb4c5..5b0fb7a647 100644 --- a/docs/docs/index.md +++ b/docs/docs/index.md @@ -1,12 +1,13 @@ # Overview -Koog is a Kotlin-based framework designed to build and run AI agents entirely in idiomatic Kotlin. +Koog is an open-source JetBrains framework designed to build and run AI agents entirely in idiomatic Kotlin. It lets you create agents that can interact with tools, handle complex workflows, and communicate with users. The framework supports the following types of agents: * Single-run agents with minimal configuration that process a single input and provide a response. An agent of this type operates within a single cycle of tool-calling to complete its task and provide a response. +* Functional agents with lightweight, customizable logic defined by a lambda function to handle user input, interact with an LLM, call tools, and produce a final output. 
* Complex workflow agents with advanced capabilities that support custom strategies and configurations. ## Key features diff --git a/docs/docs/prompt-api.md b/docs/docs/prompt-api.md index 297c87118f..a500a19b46 100644 --- a/docs/docs/prompt-api.md +++ b/docs/docs/prompt-api.md @@ -222,7 +222,7 @@ To choose between clients and executors, consider the following factors: - Use LLM clients directly if you work with a single LLM provider and do not require advanced lifecycle management. To learm more, see [Running prompts with LLM clients](#running-prompts-with-llm-clients). - Use prompt executors if you need a higher level of abstraction for managing LLMs and their lifecycle, or if you want to run prompts with a consistent API across multiple providers and dynamically switch between them. - To learn more, see [Runnning prompts with prompt executors](#running-prompts-with-executors). + To learn more, see [Runnning prompts with prompt executors](#running-prompts-with-prompt-executors). !!!note Both the LLM clients and prompt executors let you stream responses, generate multiple choices, and run content moderation. @@ -408,7 +408,7 @@ For faster setup, Koog provides the following ready-to-use executor implementati - `simpleAnthropicExecutor` for executing prompts with Anthropic models. - `simpleGoogleAIExecutor` for executing prompts with Google models. - `simpleOpenRouterExecutor` for executing prompts with OpenRouter. - - `simpleOllamaExecutor` for executing prompts with Ollama. + - `simpleOllamaAIExecutor` for executing prompts with Ollama. - Multi-provider executor: - `DefaultMultiLLMPromptExecutor` which is an implementation of `MultiLLMPromptExecutor` that supports OpenAI, Anthropic, and Google providers. 
diff --git a/docs/docs/single-run-agents.md b/docs/docs/single-run-agents.md index 07f877ef6b..5dfdc55d2f 100644 --- a/docs/docs/single-run-agents.md +++ b/docs/docs/single-run-agents.md @@ -150,7 +150,7 @@ val agent = AIAgent( Single-run agents support custom event handlers. While having an event handler is not required for creating an agent, it might be helpful for testing, debugging, or making hooks for chained agent interactions. -For more information on how to use the `EventHandler` feature for monitoring your agent interactions, see [Agent events](agent-events.md). +For more information on how to use the `EventHandler` feature for monitoring your agent interactions, see [Event Handlers](agent-event-handlers.md). ### 8. Run the agent diff --git a/docs/docs/spring-boot.md b/docs/docs/spring-boot.md index 91644653bc..7f40dcbaf6 100644 --- a/docs/docs/spring-boot.md +++ b/docs/docs/spring-boot.md @@ -6,8 +6,14 @@ agents into your Spring Boot applications with minimal setup. ## Overview The `koog-spring-boot-starter` automatically configures LLM clients based on your application properties and provides -ready-to-use beans for dependency injection. It supports all major LLM providers including OpenAI, Anthropic, Google, -OpenRouter, DeepSeek, and Ollama. +ready-to-use beans for dependency injection. 
It supports all major LLM providers including: + +- OpenAI +- Anthropic +- Google +- OpenRouter +- DeepSeek +- Ollama ## Getting Started @@ -27,21 +33,27 @@ Configure your preferred LLM providers in `application.properties`: ```properties # OpenAI Configuration +ai.koog.openai.enabled=true ai.koog.openai.api-key=${OPENAI_API_KEY} ai.koog.openai.base-url=https://api.openai.com # Anthropic Configuration +ai.koog.anthropic.enabled=true ai.koog.anthropic.api-key=${ANTHROPIC_API_KEY} ai.koog.anthropic.base-url=https://api.anthropic.com # Google Configuration +ai.koog.google.enabled=true ai.koog.google.api-key=${GOOGLE_API_KEY} ai.koog.google.base-url=https://generativelanguage.googleapis.com # OpenRouter Configuration +ai.koog.openrouter.enabled=true ai.koog.openrouter.api-key=${OPENROUTER_API_KEY} ai.koog.openrouter.base-url=https://openrouter.ai # DeepSeek Configuration +ai.koog.deepseek.enabled=true ai.koog.deepseek.api-key=${DEEPSEEK_API_KEY} ai.koog.deepseek.base-url=https://api.deepseek.com # Ollama Configuration (local - no API key required) +ai.koog.ollama.enabled=true ai.koog.ollama.base-url=http://localhost:11434 ``` @@ -51,26 +63,53 @@ Or using YAML format (`application.yml`): ai: koog: openai: + enabled: true api-key: ${OPENAI_API_KEY} base-url: https://api.openai.com anthropic: + enabled: true api-key: ${ANTHROPIC_API_KEY} base-url: https://api.anthropic.com google: + enabled: true api-key: ${GOOGLE_API_KEY} base-url: https://generativelanguage.googleapis.com openrouter: + enabled: true api-key: ${OPENROUTER_API_KEY} base-url: https://openrouter.ai deepseek: + enabled: true api-key: ${DEEPSEEK_API_KEY} base-url: https://api.deepseek.com ollama: + enabled: true # Set it to `true` explicitly to activate !!! base-url: http://localhost:11434 ``` +Both `ai.koog.PROVIDER.api-key` and `ai.koog.PROVIDER.enabled` properties are used to activate the provider. 
+ +If the provider supports the API Key (like OpenAI, Anthropic, Google), then `ai.koog.PROVIDER.enabled` is set to `true` +by default. + +If the provider does not support the API Key, like Ollama, `ai.koog.PROVIDER.enabled` is set to `false` by default, +and provider should be enabled explicitly in the application configuration. + +Provider's base urls are set to their default values in the Spring Boot starter, but you may override it in your +application. + !!! tip "Environment Variables" It's recommended to use environment variables for API keys to keep them secure and out of version control. +Spring configuration uses LLM provider's well-known environment variables. +For example, setting the environment variable `OPENAI_API_KEY` is enough for OpenAI spring configuration to activate. + +| LLM Provider | Environment Variables | +|--------------|-----------------------| +| Open AI | `OPENAI_API_KEY` | +| Anthropic | `ANTHROPIC_API_KEY` | +| Google | `GOOGLE_API_KEY` | +| OpenRouter | `OPENROUTER_API_KEY` | +| DeepSeek | `DEEPSEEK_API_KEY` | ### 3. Inject and Use diff --git a/docs/docs/streaming-api.md b/docs/docs/streaming-api.md index 74256d091f..ae2ca02a88 100644 --- a/docs/docs/streaming-api.md +++ b/docs/docs/streaming-api.md @@ -137,7 +137,7 @@ llm.writeSession { ### Listening to stream events in event handlers -You can listen to stream events in [agent events](agent-events.md). +You can listen to stream events in [agent event handlers](agent-event-handlers.md). ```kotlin handleEvents { - onToolExecutionStarting { context -> + onToolCallStarting { context -> println("\n🔧 Using ${context.tool.name} with ${context.toolArgs}... 
") } onLLMStreamingFrameReceived { context -> diff --git a/docs/docs/testing.md b/docs/docs/testing.md index 60a3e5665a..8896edd5ce 100644 --- a/docs/docs/testing.md +++ b/docs/docs/testing.md @@ -799,7 +799,7 @@ fun testToneAgent() = runTest { // Create an event handler val eventHandler = EventHandler { - onToolCall { tool, args -> + onToolCallStarting { tool, args -> println("[DEBUG_LOG] Tool called: tool ${tool.name}, args $args") toolCalls.add(tool.name) } diff --git a/docs/docs/tools-overview.md b/docs/docs/tools-overview.md index d0bfe8ada4..ec460fc08e 100644 --- a/docs/docs/tools-overview.md +++ b/docs/docs/tools-overview.md @@ -98,7 +98,7 @@ in the agent context rather than calling tools directly, as this ensures proper agent environment. !!! tip - Ensure you have implemented proper [error handling](agent-events.md) in your tools to prevent agent failure. + Ensure you have implemented proper [error handling](agent-event-handlers.md) in your tools to prevent agent failure. The tools are called within a specific session context represented by `AIAgentLLMWriteSession`. 
It provides several methods for calling tools so that you can: diff --git a/docs/docs/tracing.md b/docs/docs/tracing.md index 953b2188cf..b676e3a343 100644 --- a/docs/docs/tracing.md +++ b/docs/docs/tracing.md @@ -9,7 +9,8 @@ including: - Strategy execution - LLM calls -- Tool invocations +- LLM streaming (start, frames, completion, errors) +- Tool calls - Node execution within the agent graph This feature operates by intercepting key events in the agent pipeline and forwarding them to configurable message @@ -37,7 +38,7 @@ To use the Tracing feature, you need to: - -```kotlin -install(Tracing) { - val fileWriter = TraceFeatureMessageFileWriter( - outputPath, - { path: Path -> SystemFileSystem.sink(path).buffered() } - ) - addMessageProcessor(fileWriter) - - // Only trace LLM calls - fileWriter.setMessageFilter { message -> - message is LLMCallStartingEvent || message is LLMCallCompletedEvent - } -} -``` - - -### Can I use multiple message processors? - -Yes, you can add multiple message processors to trace to different destinations simultaneously: - - - -```kotlin -install(Tracing) { - addMessageProcessor(TraceFeatureMessageLogWriter(logger)) - addMessageProcessor(TraceFeatureMessageFileWriter(outputPath, syncOpener)) - addMessageProcessor(TraceFeatureMessageRemoteWriter(connectionConfig)) -} -``` - - -### How can I create a custom message processor? 
- -Implement the `FeatureMessageProcessor` interface: - - - -```kotlin -class CustomTraceProcessor : FeatureMessageProcessor() { - - // Current open state of the processor - private var _isOpen = MutableStateFlow(false) - - override val isOpen: StateFlow - get() = _isOpen.asStateFlow() - - override suspend fun processMessage(message: FeatureMessage) { - // Custom processing logic - when (message) { - is NodeExecutionStartingEvent -> { - // Process node start event - } - - is LLMCallCompletedEvent -> { - // Process LLM call end event - } - // Handle other event types - } - } - - override suspend fun close() { - // Close connections of established - } -} - -// Use your custom processor -install(Tracing) { - addMessageProcessor(CustomTraceProcessor()) -} -``` - - -For more information about existing event types that can be handled by message processors, see [Predefined event types](#predefined-event-types). - -## Predefined event types - -Koog provides predefined event types that can be used in custom message processors. The predefined events can be -classified into several categories, depending on the entity they relate to: - -- [Agent events](#agent-events) -- [Strategy events](#strategy-events) -- [Node events](#node-events) -- [LLM call events](#llm-call-events) -- [Tool call events](#tool-call-events) - -### Agent events - -#### AgentStartingEvent - -Represents the start of an agent run. Includes the following fields: - -| Name | Data type | Required | Default | Description | -|----------------|-----------|----------|-----------------------|---------------------------------------------------------------------------| -| `strategyName` | String | Yes | | The name of the strategy that the agent should follow. | -| `eventId` | String | No | `AgentStartingEvent` | The identifier of the event. Usually the `simpleName` of the event class. | - -#### AgentCompletedEvent - -Represents the end of an agent run. 
Includes the following fields: - -| Name | Data type | Required | Default | Description | -|----------------|-----------|----------|------------------------|---------------------------------------------------------------------------| -| `strategyName` | String | Yes | | The name of the strategy that the agent followed. | -| `result` | String | Yes | | The result of the agent run. Can be `null` if there is no result. | -| `eventId` | String | No | `AgentCompletedEvent` | The identifier of the event. Usually the `simpleName` of the event class. | - -#### AgentExecutionFailedEvent - -Represents the occurrence of an error during an agent run. Includes the following fields: - -| Name | Data type | Required | Default | Description | -|----------------|--------------|----------|------------------------|-----------------------------------------------------------------------------------------------------------------| -| `strategyName` | String | Yes | | The name of the strategy that the agent followed. | -| `error` | AIAgentError | Yes | | The specific error that occurred during the agent run. For more information, see [AIAgentError](#aiagenterror). | -| `eventId` | String | No | `AgentExecutionFailedEvent` | The identifier of the event. Usually the `simpleName` of the event class. | - - -The `AIAgentError` class provides more details about an error that occurred during an agent run. Includes the following fields: - -| Name | Data type | Required | Default | Description | -|--------------|-----------|----------|---------|------------------------------------------------------------------| -| `message` | String | Yes | | The message that provides more details about the specific error. | -| `stackTrace` | String | Yes | | The collection of stack records until the last executed code. | -| `cause` | String | No | null | The cause of the error, if available. | - -### Strategy events - -#### StrategyStartingEvent - -Represents the start of a strategy run. 
Includes the following fields: - -| Name | Data type | Required | Default | Description | -|----------------|-----------|----------|-----------------------------|---------------------------------------------------------------------------| -| `strategyName` | String | Yes | | The name of the strategy. | -| `eventId` | String | No | `StrategyStartingEvent` | The identifier of the event. Usually the `simpleName` of the event class. | - -#### StrategyCompletedEvent - -Represents the end of a strategy run. Includes the following fields: - -| Name | Data type | Required | Default | Description | -|----------------|-----------|----------|--------------------------------|---------------------------------------------------------------------------| -| `strategyName` | String | Yes | | The name of the strategy. | -| `result` | String | Yes | | The result of the run. | -| `eventId` | String | No | `StrategyCompletedEvent` | The identifier of the event. Usually the `simpleName` of the event class. | - -### Node events - -#### NodeExecutionStartingEvent - -Represents the start of a node run. Includes the following fields: - -| Name | Data type | Required | Default | Description | -|------------|-----------|----------|----------------------------------|---------------------------------------------------------------------------| -| `nodeName` | String | Yes | | The name of the node whose run started. | -| `input` | String | Yes | | The input value for the node. | -| `eventId` | String | No | `NodeExecutionStartingEvent` | The identifier of the event. Usually the `simpleName` of the event class. | - -#### NodeExecutionCompletedEvent - -Represents the end of a node run. Includes the following fields: - -| Name | Data type | Required | Default | Description | -|------------|-----------|----------|--------------------------------|---------------------------------------------------------------------------| -| `nodeName` | String | Yes | | The name of the node whose run ended. 
| -| `input` | String | Yes | | The input value for the node. | -| `output` | String | Yes | | The output value produced by the node. | -| `eventId` | String | No | `NodeExecutionCompletedEvent` | The identifier of the event. Usually the `simpleName` of the event class. | - -### LLM call events - -#### LLMCallStartingEvent - -Represents the start of an LLM call. Includes the following fields: - -| Name | Data type | Required | Default | Description | -|-----------|--------------------|----------|----------------------|------------------------------------------------------------------------------------| -| `prompt` | Prompt | Yes | | The prompt that is sent to the model. For more information, see [Prompt](#prompt). | -| `tools` | List<String> | Yes | | The list of tools that the model can call. | -| `eventId` | String | No | `LLMCallStartingEvent` | The identifier of the event. Usually the `simpleName` of the event class. | - - -The `Prompt` class represents a data structure for a prompt, consisting of a list of messages, a unique identifier, and -optional parameters for language model settings. Includes the following fields: - -| Name | Data type | Required | Default | Description | -|------------|---------------------|----------|-------------|--------------------------------------------------------------| -| `messages` | List<Message> | Yes | | The list of messages that the prompt consists of. | -| `id` | String | Yes | | The unique identifier for the prompt. | -| `params` | LLMParams | No | LLMParams() | The settings that control the way the LLM generates content. | - -#### LLMCallCompletedEvent - -Represents the end of an LLM call. 
Includes the following fields: - -| Name | Data type | Required | Default | Description | -|-------------|------------------------------|----------|---------------------|---------------------------------------------------------------------------| -| `responses` | List<Message.Response> | Yes | | One or more responses returned by the model. | -| `eventId` | String | No | `LLMCallCompletedEvent` | The identifier of the event. Usually the `simpleName` of the event class. | - -### Tool call events - -#### ToolCallEvent - -Represents the event of a model calling a tool. Includes the following fields: - -| Name | Data type | Required | Default | Description | -|------------|-----------|----------|-----------------|---------------------------------------------------------------------------| -| `toolName` | String | Yes | | The name of the tool. | -| `toolArgs` | Tool.Args | Yes | | The arguments that are provided to the tool. | -| `eventId` | String | No | `ToolCallEvent` | The identifier of the event. Usually the `simpleName` of the event class. | - -#### ToolValidationErrorEvent - -Represents the occurrence of a validation error during a tool call. Includes the following fields: - -| Name | Data type | Required | Default | Description | -|----------------|-----------|----------|----------------------------|---------------------------------------------------------------------------| -| `toolName` | String | Yes | | The name of the tool for which validation failed. | -| `toolArgs` | Tool.Args | Yes | | The arguments that are provided to the tool. | -| `errorMessage` | String | Yes | | The validation error message. | -| `eventId` | String | No | `ToolValidationErrorEvent` | The identifier of the event. Usually the `simpleName` of the event class. | - -#### ToolCallFailureEvent - -Represents a failure to call a tool. 
Includes the following fields: - -| Name | Data type | Required | Default | Description | -|------------|--------------|----------|------------------------|-----------------------------------------------------------------------------------------------------------------------| -| `toolName` | String | Yes | | The name of the tool. | -| `toolArgs` | Tool.Args | Yes | | The arguments that are provided to the tool. | -| `error` | AIAgentError | Yes | | The specific error that occurred when trying to call a tool. For more information, see [AIAgentError](#aiagenterror). | -| `eventId` | String | No | `ToolCallFailureEvent` | The identifier of the event. Usually the `simpleName` of the event class. | - -#### ToolCallResultEvent - -Represents a successful tool call with the return of a result. Includes the following fields: - -| Name | Data type | Required | Default | Description | -|------------|------------|----------|-----------------------|---------------------------------------------------------------------------| -| `toolName` | String | Yes | | The name of the tool. | -| `toolArgs` | Tool.Args | Yes | | The arguments that are provided to the tool. | -| `result` | ToolResult | Yes | | The result of the tool call. | -| `eventId` | String | No | `ToolCallResultEvent` | The identifier of the event. Usually the `simpleName` of the event class. 
| diff --git a/docs/mkdocs.yml b/docs/mkdocs.yml index 99140cfad1..3ad90dbe22 100644 --- a/docs/mkdocs.yml +++ b/docs/mkdocs.yml @@ -6,7 +6,7 @@ nav: - Key concepts: key-concepts.md - Getting started: - Single-run agents: single-run-agents.md - - Act Agent API: act-ai-agent.md + - Functional agents: functional-agents.md - Complex workflow agents: complex-workflow-agents.md - Prompt API: prompt-api.md - Tools: @@ -23,6 +23,11 @@ nav: - Data transfer between nodes: data-transfer-between-nodes.md - History compression: history-compression.md - Model Context Protocol: model-context-protocol.md + - A2A Protocol: + - Overview: a2a-protocol-overview.md + - A2A server implementation: a2a-server.md + - A2A client implementation: a2a-client.md + - A2A and Koog integration: a2a-koog-integration.md - Model capabilities: model-capabilities.md - Content moderation: content-moderation.md - Backend framework integrations: @@ -42,7 +47,7 @@ nav: - Overview: features-overview.md - Tracing: tracing.md - Memory: agent-memory.md - - Agent Persistency: agent-persistency.md + - Agent Persistence: agent-persistence.md - OpenTelemetry: - Overview: opentelemetry-support.md - Langfuse Exporter: opentelemetry-langfuse-exporter.md @@ -98,7 +103,7 @@ plugins: - key-concepts.md: This page provides explanations of key terms and concepts related to Koog and agentic development. Getting started: - single-run-agents.md: This guide lets you quickly build a single-run agent with minimum required configuration. - - act-ai-agent.md: This guide shows you how to build a lightweight, non‑graph agent that you control with a simple loop. + - functional-agents.md: This guide shows you how to build a lightweight, non‑graph agent that you control with a simple loop. - complex-workflow-agents.md: This page explains how you can create agents that handle complex workflows by defining custom strategies, tools, configurations, and custom input and output types. 
Prompt API: - prompt-api.md: This page includes detailed instructions about the use of Prompt API, which provides a comprehensive toolkit for interacting with Large Language Models in production applications. @@ -119,6 +124,11 @@ plugins: - history-compression.md: This page provides details about different implementations and different types of history compression strategies to reduce history size and token usage. Model Context Protocol: - model-context-protocol.md: This page provides details about Koog integration with MCP servers, which allows you to incorporate MCP tools into your Koog agents. + A2A Protocol: + - a2a-protocol-overview.md: This page provides an overview of the A2A (Agent-to-Agent) protocol implementation in the Koog agentic framework. + - a2a-server.md: This page provides details about the A2A server implementation in the Koog agentic framework. The A2A server is responsible for handling requests from A2A clients according to the A2A protocol specification. + - a2a-client.md: This page provides details about the A2A client implementation in the Koog agentic framework. The A2A client is responsible for sending requests to A2A servers according to the A2A protocol specification. + - a2a-koog-integration.md: This page provides details about the integration of A2A with Koog agents. It includes information about how to define your agent as A2A server and connect it to the Koog agent system. Model capabilities: - model-capabilities.md: This page provides an overview of different features or parameters that the models support, referred to as model capabilities. Content moderation: @@ -142,7 +152,7 @@ plugins: - features-overview.md: This page provides a short overview of features as a way to extend and enhance the functionality of AI agents. - tracing.md: This page includes details about the Tracing feature, which provides comprehensive tracing capabilities for AI agents. 
The page includes configuration and initialization details, examples and quickstart, details about error handling and FAQ and troubleshooting. - agent-memory.md: This page provides details about the AgentMemory feature which lets AI agents store, retrieve, and use information across conversations. The page includes configuration and initialization details, examples and quickstart, best practices, error handling and FAQ and troubleshooting. - - agent-persistency.md: This page describes Agent Persistency, which is a feature that lets you save and restore the state of an agent at specific points during execution. The page includes key concepts, prerequisites, installation instructions, configuration, and basic and advanced usage guides. + - agent-persistence.md: This page describes Agent Persistence, which is a feature that lets you save and restore the state of an agent at specific points during execution. The page includes key concepts, prerequisites, installation instructions, configuration, and basic and advanced usage guides. OpenTelemetry: - opentelemetry-support.md: This page provides details about the support for OpenTelemetry with the Koog agentic framework for tracing and monitoring AI agents. The page includes details about installation and configuration, span types and attributes, and common exporters. - opentelemetry-langfuse-exporter.md: This page provides information about Koog's built-in support for exporting agent traces to Langfuse, a platform for observability and analytics of AI applications. The page provides details about LangFuse integration configuration and an example of its use. diff --git a/examples/code-agent/step-01-basic-agent/Module.md b/examples/code-agent/step-01-basic-agent/Module.md new file mode 100644 index 0000000000..0962664427 --- /dev/null +++ b/examples/code-agent/step-01-basic-agent/Module.md @@ -0,0 +1,3 @@ +# Module Code Agent (Step 1) + +Code Agent example. 
diff --git a/examples/code-agent/step-01-basic-agent/README.md b/examples/code-agent/step-01-basic-agent/README.md new file mode 100644 index 0000000000..e69de29bb2 diff --git a/examples/code-agent/step-01-basic-agent/build.gradle.kts b/examples/code-agent/step-01-basic-agent/build.gradle.kts new file mode 100644 index 0000000000..d3a6c96a8f --- /dev/null +++ b/examples/code-agent/step-01-basic-agent/build.gradle.kts @@ -0,0 +1,22 @@ +plugins { + alias(libs.plugins.kotlin.jvm) + alias(libs.plugins.shadow) + application +} + +application.mainClass.set("ai.koog.agents.examples.codeagent.step01.MainKt") + +dependencies { + implementation("ai.koog:koog-agents") + implementation(libs.kotlinx.coroutines.core) + implementation(libs.logback.classic) +} + +tasks.test { + useJUnitPlatform() +} + +tasks.shadowJar { + archiveBaseName.set("code-agent") + mergeServiceFiles() +} diff --git a/examples/code-agent/step-01-basic-agent/gradle.properties b/examples/code-agent/step-01-basic-agent/gradle.properties new file mode 100644 index 0000000000..e360d6a58e --- /dev/null +++ b/examples/code-agent/step-01-basic-agent/gradle.properties @@ -0,0 +1,8 @@ +#Kotlin +kotlin.code.style=official +kotlin.daemon.jvmargs=-Xmx4096M + +#Gradle +org.gradle.jvmargs=-Xmx4096M -Dfile.encoding=UTF-8 +org.gradle.parallel=true +org.gradle.caching=true diff --git a/examples/code-agent/step-01-basic-agent/gradle/libs.versions.toml b/examples/code-agent/step-01-basic-agent/gradle/libs.versions.toml new file mode 100644 index 0000000000..9ba9996847 --- /dev/null +++ b/examples/code-agent/step-01-basic-agent/gradle/libs.versions.toml @@ -0,0 +1,15 @@ +[versions] +kotlin = "2.2.20" +kotlinx-coroutines = "1.10.2" +kotlinx-serialization = "1.8.1" +logback = "1.5.13" +shadow = "9.1.0" + +[libraries] +kotlinx-coroutines-core = { module = "org.jetbrains.kotlinx:kotlinx-coroutines-core", version.ref = "kotlinx-coroutines" } +logback-classic = { module = "ch.qos.logback:logback-classic", version.ref = "logback" } + 
+[plugins] +kotlin-jvm = { id = "org.jetbrains.kotlin.jvm", version.ref = "kotlin" } +kotlin-serialization = { id = "org.jetbrains.kotlin.plugin.serialization", version.ref = "kotlin" } +shadow = { id = "com.gradleup.shadow", version.ref = "shadow" } diff --git a/examples/code-agent/step-01-basic-agent/gradle/wrapper/gradle-wrapper.jar b/examples/code-agent/step-01-basic-agent/gradle/wrapper/gradle-wrapper.jar new file mode 100644 index 0000000000..1b33c55baa Binary files /dev/null and b/examples/code-agent/step-01-basic-agent/gradle/wrapper/gradle-wrapper.jar differ diff --git a/examples/code-agent/step-01-basic-agent/gradle/wrapper/gradle-wrapper.properties b/examples/code-agent/step-01-basic-agent/gradle/wrapper/gradle-wrapper.properties new file mode 100644 index 0000000000..ca025c83a7 --- /dev/null +++ b/examples/code-agent/step-01-basic-agent/gradle/wrapper/gradle-wrapper.properties @@ -0,0 +1,7 @@ +distributionBase=GRADLE_USER_HOME +distributionPath=wrapper/dists +distributionUrl=https\://services.gradle.org/distributions/gradle-8.14-bin.zip +networkTimeout=10000 +validateDistributionUrl=true +zipStoreBase=GRADLE_USER_HOME +zipStorePath=wrapper/dists diff --git a/examples/code-agent/step-01-basic-agent/gradlew b/examples/code-agent/step-01-basic-agent/gradlew new file mode 100755 index 0000000000..23d15a9367 --- /dev/null +++ b/examples/code-agent/step-01-basic-agent/gradlew @@ -0,0 +1,251 @@ +#!/bin/sh + +# +# Copyright © 2015-2021 the original authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# +# SPDX-License-Identifier: Apache-2.0 +# + +############################################################################## +# +# Gradle start up script for POSIX generated by Gradle. +# +# Important for running: +# +# (1) You need a POSIX-compliant shell to run this script. If your /bin/sh is +# noncompliant, but you have some other compliant shell such as ksh or +# bash, then to run this script, type that shell name before the whole +# command line, like: +# +# ksh Gradle +# +# Busybox and similar reduced shells will NOT work, because this script +# requires all of these POSIX shell features: +# * functions; +# * expansions «$var», «${var}», «${var:-default}», «${var+SET}», +# «${var#prefix}», «${var%suffix}», and «$( cmd )»; +# * compound commands having a testable exit status, especially «case»; +# * various built-in commands including «command», «set», and «ulimit». +# +# Important for patching: +# +# (2) This script targets any POSIX shell, so it avoids extensions provided +# by Bash, Ksh, etc; in particular arrays are avoided. +# +# The "traditional" practice of packing multiple parameters into a +# space-separated string is a well documented source of bugs and security +# problems, so this is (mostly) avoided, by progressively accumulating +# options in "$@", and eventually passing that to Java. +# +# Where the inherited environment variables (DEFAULT_JVM_OPTS, JAVA_OPTS, +# and GRADLE_OPTS) rely on word-splitting, this is performed explicitly; +# see the in-line comments for details. +# +# There are tweaks for specific operating systems such as AIX, CygWin, +# Darwin, MinGW, and NonStop. +# +# (3) This script is generated from the Groovy template +# https://github.com/gradle/gradle/blob/HEAD/platforms/jvm/plugins-application/src/main/resources/org/gradle/api/internal/plugins/unixStartScript.txt +# within the Gradle project. 
+# +# You can find Gradle at https://github.com/gradle/gradle/. +# +############################################################################## + +# Attempt to set APP_HOME + +# Resolve links: $0 may be a link +app_path=$0 + +# Need this for daisy-chained symlinks. +while + APP_HOME=${app_path%"${app_path##*/}"} # leaves a trailing /; empty if no leading path + [ -h "$app_path" ] +do + ls=$( ls -ld "$app_path" ) + link=${ls#*' -> '} + case $link in #( + /*) app_path=$link ;; #( + *) app_path=$APP_HOME$link ;; + esac +done + +# This is normally unused +# shellcheck disable=SC2034 +APP_BASE_NAME=${0##*/} +# Discard cd standard output in case $CDPATH is set (https://github.com/gradle/gradle/issues/25036) +APP_HOME=$( cd -P "${APP_HOME:-./}" > /dev/null && printf '%s\n' "$PWD" ) || exit + +# Use the maximum available, or set MAX_FD != -1 to use that value. +MAX_FD=maximum + +warn () { + echo "$*" +} >&2 + +die () { + echo + echo "$*" + echo + exit 1 +} >&2 + +# OS specific support (must be 'true' or 'false'). +cygwin=false +msys=false +darwin=false +nonstop=false +case "$( uname )" in #( + CYGWIN* ) cygwin=true ;; #( + Darwin* ) darwin=true ;; #( + MSYS* | MINGW* ) msys=true ;; #( + NONSTOP* ) nonstop=true ;; +esac + +CLASSPATH="\\\"\\\"" + + +# Determine the Java command to use to start the JVM. +if [ -n "$JAVA_HOME" ] ; then + if [ -x "$JAVA_HOME/jre/sh/java" ] ; then + # IBM's JDK on AIX uses strange locations for the executables + JAVACMD=$JAVA_HOME/jre/sh/java + else + JAVACMD=$JAVA_HOME/bin/java + fi + if [ ! -x "$JAVACMD" ] ; then + die "ERROR: JAVA_HOME is set to an invalid directory: $JAVA_HOME + +Please set the JAVA_HOME variable in your environment to match the +location of your Java installation." + fi +else + JAVACMD=java + if ! command -v java >/dev/null 2>&1 + then + die "ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH. 
+ +Please set the JAVA_HOME variable in your environment to match the +location of your Java installation." + fi +fi + +# Increase the maximum file descriptors if we can. +if ! "$cygwin" && ! "$darwin" && ! "$nonstop" ; then + case $MAX_FD in #( + max*) + # In POSIX sh, ulimit -H is undefined. That's why the result is checked to see if it worked. + # shellcheck disable=SC2039,SC3045 + MAX_FD=$( ulimit -H -n ) || + warn "Could not query maximum file descriptor limit" + esac + case $MAX_FD in #( + '' | soft) :;; #( + *) + # In POSIX sh, ulimit -n is undefined. That's why the result is checked to see if it worked. + # shellcheck disable=SC2039,SC3045 + ulimit -n "$MAX_FD" || + warn "Could not set maximum file descriptor limit to $MAX_FD" + esac +fi + +# Collect all arguments for the java command, stacking in reverse order: +# * args from the command line +# * the main class name +# * -classpath +# * -D...appname settings +# * --module-path (only if needed) +# * DEFAULT_JVM_OPTS, JAVA_OPTS, and GRADLE_OPTS environment variables. + +# For Cygwin or MSYS, switch paths to Windows format before running java +if "$cygwin" || "$msys" ; then + APP_HOME=$( cygpath --path --mixed "$APP_HOME" ) + CLASSPATH=$( cygpath --path --mixed "$CLASSPATH" ) + + JAVACMD=$( cygpath --unix "$JAVACMD" ) + + # Now convert the arguments - kludge to limit ourselves to /bin/sh + for arg do + if + case $arg in #( + -*) false ;; # don't mess with options #( + /?*) t=${arg#/} t=/${t%%/*} # looks like a POSIX filepath + [ -e "$t" ] ;; #( + *) false ;; + esac + then + arg=$( cygpath --path --ignore --mixed "$arg" ) + fi + # Roll the args list around exactly as many times as the number of + # args, so each arg winds up back in the position where it started, but + # possibly modified. + # + # NB: a `for` loop captures its iteration list before it begins, so + # changing the positional parameters here affects neither the number of + # iterations, nor the values presented in `arg`. 
+ shift # remove old arg + set -- "$@" "$arg" # push replacement arg + done +fi + + +# Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script. +DEFAULT_JVM_OPTS='"-Xmx64m" "-Xms64m"' + +# Collect all arguments for the java command: +# * DEFAULT_JVM_OPTS, JAVA_OPTS, and optsEnvironmentVar are not allowed to contain shell fragments, +# and any embedded shellness will be escaped. +# * For example: A user cannot expect ${Hostname} to be expanded, as it is an environment variable and will be +# treated as '${Hostname}' itself on the command line. + +set -- \ + "-Dorg.gradle.appname=$APP_BASE_NAME" \ + -classpath "$CLASSPATH" \ + -jar "$APP_HOME/gradle/wrapper/gradle-wrapper.jar" \ + "$@" + +# Stop when "xargs" is not available. +if ! command -v xargs >/dev/null 2>&1 +then + die "xargs is not available" +fi + +# Use "xargs" to parse quoted args. +# +# With -n1 it outputs one arg per line, with the quotes and backslashes removed. +# +# In Bash we could simply go: +# +# readarray ARGS < <( xargs -n1 <<<"$var" ) && +# set -- "${ARGS[@]}" "$@" +# +# but POSIX shell has neither arrays nor command substitution, so instead we +# post-process each arg (as a line of input to sed) to backslash-escape any +# character that might be a shell metacharacter, then use eval to reverse +# that process (while maintaining the separation between arguments), and wrap +# the whole thing up as a single "set" statement. +# +# This will of course break if any of these variables contains a newline or +# an unmatched quote. 
+# + +eval "set -- $( + printf '%s\n' "$DEFAULT_JVM_OPTS $JAVA_OPTS $GRADLE_OPTS" | + xargs -n1 | + sed ' s~[^-[:alnum:]+,./:=@_]~\\&~g; ' | + tr '\n' ' ' + )" '"$@"' + +exec "$JAVACMD" "$@" diff --git a/examples/code-agent/step-01-basic-agent/gradlew.bat b/examples/code-agent/step-01-basic-agent/gradlew.bat new file mode 100644 index 0000000000..db3a6ac207 --- /dev/null +++ b/examples/code-agent/step-01-basic-agent/gradlew.bat @@ -0,0 +1,94 @@ +@rem +@rem Copyright 2015 the original author or authors. +@rem +@rem Licensed under the Apache License, Version 2.0 (the "License"); +@rem you may not use this file except in compliance with the License. +@rem You may obtain a copy of the License at +@rem +@rem https://www.apache.org/licenses/LICENSE-2.0 +@rem +@rem Unless required by applicable law or agreed to in writing, software +@rem distributed under the License is distributed on an "AS IS" BASIS, +@rem WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +@rem See the License for the specific language governing permissions and +@rem limitations under the License. +@rem +@rem SPDX-License-Identifier: Apache-2.0 +@rem + +@if "%DEBUG%"=="" @echo off +@rem ########################################################################## +@rem +@rem Gradle startup script for Windows +@rem +@rem ########################################################################## + +@rem Set local scope for the variables with windows NT shell +if "%OS%"=="Windows_NT" setlocal + +set DIRNAME=%~dp0 +if "%DIRNAME%"=="" set DIRNAME=. +@rem This is normally unused +set APP_BASE_NAME=%~n0 +set APP_HOME=%DIRNAME% + +@rem Resolve any "." and ".." in APP_HOME to make it shorter. +for %%i in ("%APP_HOME%") do set APP_HOME=%%~fi + +@rem Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script. 
+set DEFAULT_JVM_OPTS="-Xmx64m" "-Xms64m" + +@rem Find java.exe +if defined JAVA_HOME goto findJavaFromJavaHome + +set JAVA_EXE=java.exe +%JAVA_EXE% -version >NUL 2>&1 +if %ERRORLEVEL% equ 0 goto execute + +echo. 1>&2 +echo ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH. 1>&2 +echo. 1>&2 +echo Please set the JAVA_HOME variable in your environment to match the 1>&2 +echo location of your Java installation. 1>&2 + +goto fail + +:findJavaFromJavaHome +set JAVA_HOME=%JAVA_HOME:"=% +set JAVA_EXE=%JAVA_HOME%/bin/java.exe + +if exist "%JAVA_EXE%" goto execute + +echo. 1>&2 +echo ERROR: JAVA_HOME is set to an invalid directory: %JAVA_HOME% 1>&2 +echo. 1>&2 +echo Please set the JAVA_HOME variable in your environment to match the 1>&2 +echo location of your Java installation. 1>&2 + +goto fail + +:execute +@rem Setup the command line + +set CLASSPATH= + + +@rem Execute Gradle +"%JAVA_EXE%" %DEFAULT_JVM_OPTS% %JAVA_OPTS% %GRADLE_OPTS% "-Dorg.gradle.appname=%APP_BASE_NAME%" -classpath "%CLASSPATH%" -jar "%APP_HOME%\gradle\wrapper\gradle-wrapper.jar" %* + +:end +@rem End local scope for the variables with windows NT shell +if %ERRORLEVEL% equ 0 goto mainEnd + +:fail +rem Set variable GRADLE_EXIT_CONSOLE if you need the _script_ return code instead of +rem the _cmd.exe /c_ return code! 
+set EXIT_CODE=%ERRORLEVEL% +if %EXIT_CODE% equ 0 set EXIT_CODE=1 +if not ""=="%GRADLE_EXIT_CONSOLE%" exit %EXIT_CODE% +exit /b %EXIT_CODE% + +:mainEnd +if "%OS%"=="Windows_NT" endlocal + +:omega diff --git a/examples/code-agent/step-01-basic-agent/settings.gradle.kts b/examples/code-agent/step-01-basic-agent/settings.gradle.kts new file mode 100644 index 0000000000..7174be00b7 --- /dev/null +++ b/examples/code-agent/step-01-basic-agent/settings.gradle.kts @@ -0,0 +1,19 @@ +rootProject.name = "step-01-basic-agent" + +pluginManagement { + repositories { + gradlePluginPortal() + } +} + +dependencyResolutionManagement { + @Suppress("UnstableApiUsage") + repositories { + mavenCentral() + google() + } +} + +includeBuild("../../../.") { + name = "koog" +} diff --git a/examples/code-agent/step-01-basic-agent/src/main/kotlin/Main.kt b/examples/code-agent/step-01-basic-agent/src/main/kotlin/Main.kt new file mode 100644 index 0000000000..3467df8eba --- /dev/null +++ b/examples/code-agent/step-01-basic-agent/src/main/kotlin/Main.kt @@ -0,0 +1,50 @@ +package ai.koog.agents.examples.codeagent.step01 + +import ai.koog.agents.core.agent.AIAgent +import ai.koog.agents.core.agent.singleRunStrategy +import ai.koog.agents.core.tools.ToolRegistry +import ai.koog.agents.ext.tool.file.EditFileTool +import ai.koog.agents.ext.tool.file.ListDirectoryTool +import ai.koog.agents.ext.tool.file.ReadFileTool +import ai.koog.agents.ext.tool.file.WriteFileTool +import ai.koog.agents.features.eventHandler.feature.handleEvents +import ai.koog.prompt.executor.clients.openai.OpenAIModels +import ai.koog.prompt.executor.llms.all.simpleOpenAIExecutor +import ai.koog.rag.base.files.JVMFileSystemProvider +import kotlinx.coroutines.runBlocking + +val agent = AIAgent( + promptExecutor = simpleOpenAIExecutor(System.getenv("OPENAI_API_KEY")), + strategy = singleRunStrategy(), + systemPrompt = """ + You are a highly skilled programmer tasked with updating the provided codebase according to the given task. 
+ Your goal is to deliver production-ready code changes that integrate seamlessly with the existing codebase and solve given task. + """.trimIndent(), + llmModel = OpenAIModels.Chat.GPT5, + toolRegistry = ToolRegistry { + tool(ListDirectoryTool(JVMFileSystemProvider.ReadOnly)) + tool(ReadFileTool(JVMFileSystemProvider.ReadOnly)) + tool(WriteFileTool(JVMFileSystemProvider.ReadWrite)) + tool(EditFileTool(JVMFileSystemProvider.ReadWrite)) + }, + maxIterations = 100 +) { + handleEvents { + onToolCallStarting { ctx -> + println("Tool called: ${ctx.tool.name}") + } + } +} + +fun main(args: Array) = runBlocking { + if (args.size < 2) { + println("Error: Please provide the project absolute path and a task as arguments") + println("Usage: ") + return@runBlocking + } + + val (path, task) = args + val input = "Project path: $path\n\n$task" + val result = agent.run(input) + println(result) +} diff --git a/examples/code-agent/step-01-basic-agent/src/main/resources/logback.xml b/examples/code-agent/step-01-basic-agent/src/main/resources/logback.xml new file mode 100644 index 0000000000..d624456018 --- /dev/null +++ b/examples/code-agent/step-01-basic-agent/src/main/resources/logback.xml @@ -0,0 +1,11 @@ + + + + %d{HH:mm:ss.SSS} [%thread] %-5level %logger{36} - %msg%n + + + + + + + diff --git a/examples/demo-compose-app/app/src/commonMain/kotlin/com/jetbrains/example/kotlin_agents_demo_app/agents/calculator/CalculatorAgentProvider.kt b/examples/demo-compose-app/app/src/commonMain/kotlin/com/jetbrains/example/kotlin_agents_demo_app/agents/calculator/CalculatorAgentProvider.kt index b723e03484..61962db039 100644 --- a/examples/demo-compose-app/app/src/commonMain/kotlin/com/jetbrains/example/kotlin_agents_demo_app/agents/calculator/CalculatorAgentProvider.kt +++ b/examples/demo-compose-app/app/src/commonMain/kotlin/com/jetbrains/example/kotlin_agents_demo_app/agents/calculator/CalculatorAgentProvider.kt @@ -131,11 +131,11 @@ internal class CalculatorAgentProvider : AgentProvider { 
onToolCallEvent("Tool ${ctx.tool.name}, args ${ctx.toolArgs}") } - onAgentRunError { ctx -> + onAgentExecutionFailed { ctx -> onErrorEvent("${ctx.throwable.message}") } - onAgentFinished { ctx -> + onAgentCompleted { ctx -> // Skip finish event handling } } diff --git a/examples/demo-compose-app/app/src/commonMain/kotlin/com/jetbrains/example/kotlin_agents_demo_app/agents/weather/WeatherAgentProvider.kt b/examples/demo-compose-app/app/src/commonMain/kotlin/com/jetbrains/example/kotlin_agents_demo_app/agents/weather/WeatherAgentProvider.kt index d1b1b9698c..4282db75c4 100644 --- a/examples/demo-compose-app/app/src/commonMain/kotlin/com/jetbrains/example/kotlin_agents_demo_app/agents/weather/WeatherAgentProvider.kt +++ b/examples/demo-compose-app/app/src/commonMain/kotlin/com/jetbrains/example/kotlin_agents_demo_app/agents/weather/WeatherAgentProvider.kt @@ -133,15 +133,15 @@ internal class WeatherAgentProvider : AgentProvider { toolRegistry = toolRegistry, ) { handleEvents { - onToolCall { ctx -> + onToolExecutionStarting { ctx -> onToolCallEvent("Tool ${ctx.tool.name}, args ${ctx.toolArgs}") } - onAgentRunError { ctx -> + onAgentExecutionFailed { ctx -> onErrorEvent("${ctx.throwable.message}") } - onAgentFinished { ctx -> + onAgentCompleted { ctx -> // Skip finish event handling } } diff --git a/examples/notebooks/Calculator.ipynb b/examples/notebooks/Calculator.ipynb index 70e050b3d1..1dac521b08 100644 --- a/examples/notebooks/Calculator.ipynb +++ b/examples/notebooks/Calculator.ipynb @@ -248,13 +248,13 @@ " toolRegistry = toolRegistry\n", ") {\n", " handleEvents {\n", - " onToolCall { e ->\n", + " onToolCallStarting { e ->\n", " println(\"Tool called: ${e.tool.name}, args=${e.toolArgs}\")\n", " }\n", - " onAgentRunError { e ->\n", + " onAgentExecutionFailed { e ->\n", " println(\"Agent error: ${e.throwable.message}\")\n", " }\n", - " onAgentFinished { e ->\n", + " onAgentCompleted { e ->\n", " println(\"Final result: ${e.result}\")\n", " }\n", " }\n", diff --git 
a/examples/notebooks/UnityMcp.ipynb b/examples/notebooks/UnityMcp.ipynb index 8ab4fc542b..8228bb2060 100644 --- a/examples/notebooks/UnityMcp.ipynb +++ b/examples/notebooks/UnityMcp.ipynb @@ -160,17 +160,17 @@ " install(Tracing)\n", "\n", " install(EventHandler) {\n", - " onBeforeAgentStarted { eventContext ->\n", - " println(\"OnBeforeAgentStarted first (strategy: ${strategy.name})\")\n", + " onAgentStarting { eventContext ->\n", + " println(\"OnAgentStarting first (strategy: ${strategy.name})\")\n", " }\n", "\n", - " onBeforeAgentStarted { eventContext ->\n", - " println(\"OnBeforeAgentStarted second (strategy: ${strategy.name})\")\n", + " onAgentStarting { eventContext ->\n", + " println(\"OnAgentStarting second (strategy: ${strategy.name})\")\n", " }\n", "\n", - " onAgentFinished { eventContext ->\n", + " onAgentCompleted { eventContext ->\n", " println(\n", - " \"OnAgentFinished (agent id: ${eventContext.agentId}, result: ${eventContext.result})\"\n", + " \"OnAgentCompleted (agent id: ${eventContext.agentId}, result: ${eventContext.result})\"\n", " )\n", " }\n", " }\n", diff --git a/examples/simple-examples/CLAUDE.md b/examples/simple-examples/CLAUDE.md new file mode 100644 index 0000000000..d4d5a482c9 --- /dev/null +++ b/examples/simple-examples/CLAUDE.md @@ -0,0 +1,92 @@ +## Project Overview + +This is the **simple-examples** project within the Koog Framework repository. It contains executable examples demonstrating various AI agent patterns and capabilities, from basic calculator agents to advanced features like persistence, streaming, and agent-to-agent communication. + +## Environment Setup + +Examples require API keys for LLM providers. 
Configure them via: + +## Running Examples + +Each example has a dedicated Gradle task: + +```bash +# Run any example +./gradlew runExampleCalculator +./gradlew runExampleStreamingWithTools +./gradlew runExampleRoutingViaGraph + +# Build the project +./gradlew assemble + +# Run tests +./gradlew test +``` + +All available tasks follow the pattern `runExample*`. See `build.gradle.kts` for the complete list. + +## Project Architecture + +### Composite Build Setup + +This project uses Gradle composite build to depend on the parent Koog framework (`settings.gradle.kts:`). Changes in the main framework are immediately available without publishing. + +### Key Dependencies + +- `ai.koog:koog-agents` - Meta-module with core agent dependencies +- `ai.koog:koog-ktor` - Ktor server integration +- `ai.koog:agents-features-sql` - SQL persistence feature +- `ai.koog:agents-features-a2a-server` - Agent-to-agent server +- `ai.koog:agents-features-a2a-client` - Agent-to-agent client +- `ai.koog:agents-test` - Testing utilities (test scope) + +### Example Structure + +Examples are organized by capability: + +- **Core patterns**: `calculator/`, `guesser/`, `tone/` - Basic agent patterns with tools and strategies +- **Agent-to-agent (A2A)**: `a2a/` - Inter-agent communication examples +- **Advanced features**: `memory/`, `snapshot/`, `moderation/` - Memory, persistence, content filtering +- **External integration**: `websearch/`, `attachments/`, `client/` - External API integration +- **Structured output**: `structuredoutput/` - Schema-based output formatting and streaming +- **Banking routing**: `banking/` - Complex multi-agent routing patterns + +### Common Patterns + +1. **Tool Definition**: Use `@Tool` and `@LLMDescription` annotations on methods in classes extending `ToolSet` +2. **Strategy Creation**: Define agent behavior as state machine graphs using `strategy { }` DSL +3. **Agent Setup**: Combine `PromptExecutor`, strategy, `AIAgentConfig`, and `ToolRegistry` +4. 
**Event Handling**: Use `handleEvents { }` to observe tool calls, errors, and completion + +Example from `calculator/Calculator.kt`: +```kotlin +val toolRegistry = ToolRegistry { + tool(AskUser) + tools(CalculatorTools().asTools()) +} + +val agent = AIAgent( + promptExecutor = executor, + strategy = CalculatorStrategy.strategy, + agentConfig = agentConfig, + toolRegistry = toolRegistry +) { + handleEvents { /* ... */ } +} +``` + +## Common Development Tasks + +```bash +# Build without running tests +./gradlew assemble + +# Run a specific example +./gradlew runExampleCalculator + +# Run tests +./gradlew test + +# Clean build artifacts +./gradlew clean +``` diff --git a/examples/simple-examples/README.md b/examples/simple-examples/README.md index e8d52be82d..b33f48a184 100644 --- a/examples/simple-examples/README.md +++ b/examples/simple-examples/README.md @@ -85,6 +85,15 @@ Welcome to the **Koog Framework Simple Examples** collection! This project showc | **Bedrock Agent** | AI agents using AWS Bedrock integration | `runExampleBedrockAgent` | [📓 BedrockAgent.ipynb](../notebooks/BedrockAgent.ipynb) | | **Web Search** | Agent with web search capabilities | `runExampleWebSearchAgent` | - | +### Agent-to-Agent (A2A) + +Examples demonstrating the A2A protocol for inter-agent communication. See the [A2A README](src/main/kotlin/ai/koog/agents/example/a2a/README.md) for details. 
+ +| Example | Description | Files | +|------------------------|------------------------------------------------------------|---------------------------------------| +| **Simple Joke Agent** | Basic A2A agent with message-based joke generation | `simplejoke/` (Server + Client) | +| **Advanced Joke Agent** | Task-based agent with clarification flow and artifacts | `advancedjoke/` (Server + Client) | + ### Advanced Patterns | Feature | Description | Gradle Task | Notebook | diff --git a/examples/simple-examples/build.gradle.kts b/examples/simple-examples/build.gradle.kts index cc73f941c8..6e5e2d54ae 100644 --- a/examples/simple-examples/build.gradle.kts +++ b/examples/simple-examples/build.gradle.kts @@ -17,6 +17,15 @@ dependencies { implementation("ai.koog:koog-ktor") //noinspection UseTomlInstead implementation("ai.koog:agents-features-sql") + //noinspection UseTomlInstead + implementation("ai.koog:agents-features-a2a-server") + //noinspection UseTomlInstead + implementation("ai.koog:agents-features-a2a-client") + //noinspection UseTomlInstead + implementation("ai.koog:a2a-transport-server-jsonrpc-http") + //noinspection UseTomlInstead + implementation("ai.koog:a2a-transport-client-jsonrpc-http") + //noinspection UseTomlInstead testImplementation("ai.koog:agents-test") @@ -106,3 +115,15 @@ registerRunExampleTask("runExampleFilePersistentAgent", "ai.koog.agents.example. 
registerRunExampleTask("runExampleSQLPersistentAgent", "ai.koog.agents.example.snapshot.sql.SQLPersistentAgentExample") registerRunExampleTask("runExampleWebSearchAgent", "ai.koog.agents.example.websearch.WebSearchAgentKt") registerRunExampleTask("runExampleStreamingWithTools", "ai.koog.agents.example.streaming.StreamingAgentWithToolsKt") + +/* + A2A examples +*/ + +// Simple joke generation +registerRunExampleTask("runExampleSimpleJokeAgentServer", "ai.koog.agents.example.a2a.simplejoke.ServerKt") +registerRunExampleTask("runExampleSimpleJokeAgentClient", "ai.koog.agents.example.a2a.simplejoke.ClientKt") + +// Advanced joke generation +registerRunExampleTask("runExampleAdvancedJokeAgentServer", "ai.koog.agents.example.a2a.advancedjoke.ServerKt") +registerRunExampleTask("runExampleAdvancedJokeAgentClient", "ai.koog.agents.example.a2a.advancedjoke.ClientKt") diff --git a/examples/simple-examples/gradle.properties b/examples/simple-examples/gradle.properties index e360d6a58e..ab47dc19ae 100644 --- a/examples/simple-examples/gradle.properties +++ b/examples/simple-examples/gradle.properties @@ -1,8 +1,10 @@ #Kotlin kotlin.code.style=official kotlin.daemon.jvmargs=-Xmx4096M +kotlin.native.ignoreDisabledTargets=true #Gradle org.gradle.jvmargs=-Xmx4096M -Dfile.encoding=UTF-8 org.gradle.parallel=true org.gradle.caching=true + diff --git a/examples/simple-examples/src/main/kotlin/ai/koog/agents/example/a2a/README.md b/examples/simple-examples/src/main/kotlin/ai/koog/agents/example/a2a/README.md new file mode 100644 index 0000000000..cde6a1dd6e --- /dev/null +++ b/examples/simple-examples/src/main/kotlin/ai/koog/agents/example/a2a/README.md @@ -0,0 +1,44 @@ +# Agent-to-Agent (A2A) Examples + +Examples demonstrating the A2A protocol for inter-agent communication with standardized message/task workflows, streaming responses, and artifact delivery. + +## Examples + +### Simple Joke Agent (`simplejoke/`) + +Basic message-based communication without tasks. 
+ +**Run:** +```bash +# Terminal 1: Start server (port 9998) +./gradlew runExampleSimpleJokeAgentServer + +# Terminal 2: Run client +./gradlew runExampleSimpleJokeAgentClient +``` + +### Advanced Joke Agent (`advancedjoke/`) + +Task-based workflow with: +- Interactive clarification (InputRequired state) +- Artifact delivery for results +- Graph-based agent strategy with documented nodes/edges +- Streaming response events + +**Run:** +```bash +# Terminal 1: Start server (port 9999) +./gradlew runExampleAdvancedJokeAgentServer + +# Terminal 2: Run client +./gradlew runExampleAdvancedJokeAgentClient +``` + +## Key Patterns + +**Simple Agent:** `sendMessage()` → single response +**Advanced Agent:** `sendMessageStreaming()` → Flow of events (Task, TaskStatusUpdateEvent, TaskArtifactUpdateEvent) + +**Task States:** Submitted → Working → InputRequired (optional) → Completed + +See code comments in `JokeWriterAgentExecutor.kt` for detailed flow documentation. diff --git a/examples/simple-examples/src/main/kotlin/ai/koog/agents/example/a2a/advancedjoke/Client.kt b/examples/simple-examples/src/main/kotlin/ai/koog/agents/example/a2a/advancedjoke/Client.kt new file mode 100644 index 0000000000..8c6464ed7c --- /dev/null +++ b/examples/simple-examples/src/main/kotlin/ai/koog/agents/example/a2a/advancedjoke/Client.kt @@ -0,0 +1,154 @@ +@file:OptIn(ExperimentalUuidApi::class) + +package ai.koog.agents.example.a2a.advancedjoke + +import ai.koog.a2a.client.A2AClient +import ai.koog.a2a.client.UrlAgentCardResolver +import ai.koog.a2a.model.Artifact +import ai.koog.a2a.model.Message +import ai.koog.a2a.model.MessageSendParams +import ai.koog.a2a.model.Role +import ai.koog.a2a.model.Task +import ai.koog.a2a.model.TaskArtifactUpdateEvent +import ai.koog.a2a.model.TaskState +import ai.koog.a2a.model.TaskStatusUpdateEvent +import ai.koog.a2a.model.TextPart +import ai.koog.a2a.transport.Request +import ai.koog.a2a.transport.client.jsonrpc.http.HttpJSONRPCClientTransport +import 
kotlinx.serialization.encodeToString +import kotlinx.serialization.json.Json +import kotlin.uuid.ExperimentalUuidApi +import kotlin.uuid.Uuid + +private const val CYAN = "\u001B[36m" +private const val YELLOW = "\u001B[33m" +private const val MAGENTA = "\u001B[35m" +private const val GREEN = "\u001B[32m" +private const val RED = "\u001B[31m" +private const val BLUE = "\u001B[34m" +private const val RESET = "\u001B[0m" + +private val json = Json { prettyPrint = true } + +@OptIn(ExperimentalUuidApi::class) +suspend fun main() { + println("\n${YELLOW}Starting Advanced Joke Generator A2A Client$RESET\n") + + val transport = HttpJSONRPCClientTransport(url = "http://localhost:9999$ADVANCED_JOKE_AGENT_PATH") + val agentCardResolver = UrlAgentCardResolver(baseUrl = "http://localhost:9999", path = ADVANCED_JOKE_AGENT_CARD_PATH) + val client = A2AClient(transport = transport, agentCardResolver = agentCardResolver) + + client.connect() + val agentCard = client.cachedAgentCard() + println("${YELLOW}Connected: ${agentCard.name}$RESET\n") + + if (agentCard.capabilities.streaming != true) { + println("${RED}Error: Streaming not supported$RESET") + transport.close() + return + } + + println("${CYAN}Context ID:$RESET") + val contextId = readln() + println() + + var currentTaskId: String? 
= null + val artifacts = mutableMapOf() + + while (true) { + println("${CYAN}Request (/q to quit):$RESET") + val request = readln() + println() + + if (request == "/q") break + + val message = Message( + messageId = Uuid.random().toString(), + role = Role.User, + parts = listOf(TextPart(request)), + contextId = contextId, + taskId = currentTaskId + ) + + try { + client.sendMessageStreaming(Request(MessageSendParams(message = message))).collect { response -> + val event = response.data + println("${BLUE}[${event.kind}]$RESET") + println("${json.encodeToString(event)}\n") + + when (event) { + is Task -> { + currentTaskId = event.id + event.artifacts?.forEach { artifacts[it.artifactId] = it } + } + + is Message -> { + val textContent = event.parts.filterIsInstance().joinToString("\n") { it.text } + if (textContent.isNotBlank()) { + println("${MAGENTA}Message:$RESET\n$textContent\n") + } + } + + is TaskStatusUpdateEvent -> { + when (event.status.state) { + TaskState.InputRequired -> { + val question = event.status.message?.parts + ?.filterIsInstance() + ?.joinToString("\n") { it.text } + if (!question.isNullOrBlank()) { + println("${MAGENTA}Question:$RESET\n$question\n") + } + } + + TaskState.Completed -> { + if (artifacts.isNotEmpty()) { + println("${GREEN}=== Artifacts ===$RESET") + artifacts.values.forEach { artifact -> + val content = artifact.parts.filterIsInstance() + .joinToString("\n") { it.text } + if (content.isNotBlank()) { + println("${GREEN}[${artifact.artifactId}]$RESET\n$content\n") + } + } + } + if (event.final) { + currentTaskId = null + artifacts.clear() + } + } + + TaskState.Failed, TaskState.Canceled, TaskState.Rejected -> { + if (event.final) { + currentTaskId = null + artifacts.clear() + } + } + + else -> {} + } + } + + is TaskArtifactUpdateEvent -> { + if (event.append == true) { + val existing = artifacts[event.artifact.artifactId] + if (existing != null) { + artifacts[event.artifact.artifactId] = existing.copy( + parts = existing.parts + 
event.artifact.parts + ) + } else { + artifacts[event.artifact.artifactId] = event.artifact + } + } else { + artifacts[event.artifact.artifactId] = event.artifact + } + } + } + } + } catch (e: Exception) { + println("${RED}Error: ${e.message}$RESET\n") + } + } + + println("${YELLOW}Done$RESET") + transport.close() +} diff --git a/examples/simple-examples/src/main/kotlin/ai/koog/agents/example/a2a/advancedjoke/JokeWriterAgentExecutor.kt b/examples/simple-examples/src/main/kotlin/ai/koog/agents/example/a2a/advancedjoke/JokeWriterAgentExecutor.kt new file mode 100644 index 0000000000..d74cf4ce38 --- /dev/null +++ b/examples/simple-examples/src/main/kotlin/ai/koog/agents/example/a2a/advancedjoke/JokeWriterAgentExecutor.kt @@ -0,0 +1,406 @@ +package ai.koog.agents.example.a2a.advancedjoke + +import ai.koog.a2a.exceptions.A2AUnsupportedOperationException +import ai.koog.a2a.model.Artifact +import ai.koog.a2a.model.MessageSendParams +import ai.koog.a2a.model.Role +import ai.koog.a2a.model.Task +import ai.koog.a2a.model.TaskArtifactUpdateEvent +import ai.koog.a2a.model.TaskState +import ai.koog.a2a.model.TaskStatus +import ai.koog.a2a.model.TaskStatusUpdateEvent +import ai.koog.a2a.model.TextPart +import ai.koog.a2a.server.agent.AgentExecutor +import ai.koog.a2a.server.session.RequestContext +import ai.koog.a2a.server.session.SessionEventProcessor +import ai.koog.agents.a2a.core.A2AMessage +import ai.koog.agents.a2a.core.toKoogMessage +import ai.koog.agents.a2a.server.feature.A2AAgentServer +import ai.koog.agents.a2a.server.feature.withA2AAgentServer +import ai.koog.agents.core.agent.GraphAIAgent +import ai.koog.agents.core.agent.config.AIAgentConfig +import ai.koog.agents.core.agent.context.agentInput +import ai.koog.agents.core.dsl.builder.forwardTo +import ai.koog.agents.core.dsl.builder.strategy +import ai.koog.agents.core.dsl.extension.nodeLLMRequestStructured +import ai.koog.agents.core.dsl.extension.onIsInstance +import ai.koog.agents.core.tools.ToolRegistry +import 
ai.koog.agents.core.tools.annotations.LLMDescription +import ai.koog.agents.example.ApiKeyService +import ai.koog.prompt.dsl.prompt +import ai.koog.prompt.executor.clients.anthropic.AnthropicLLMClient +import ai.koog.prompt.executor.clients.google.GoogleLLMClient +import ai.koog.prompt.executor.clients.google.GoogleModels +import ai.koog.prompt.executor.clients.openai.OpenAILLMClient +import ai.koog.prompt.executor.llms.MultiLLMPromptExecutor +import ai.koog.prompt.executor.model.PromptExecutor +import ai.koog.prompt.llm.LLMProvider +import ai.koog.prompt.message.Message +import ai.koog.prompt.text.text +import ai.koog.prompt.xml.xml +import kotlinx.datetime.Clock +import kotlinx.serialization.SerialName +import kotlinx.serialization.Serializable +import kotlin.reflect.typeOf +import kotlin.uuid.ExperimentalUuidApi +import kotlin.uuid.Uuid + +/** + * An advanced A2A agent that demonstrates: + * - Task-based conversation flow with state management + * - Interactive clarification questions (InputRequired state) + * - Structured output via sealed interfaces + * - Artifact delivery for final results + */ +class JokeWriterAgentExecutor : AgentExecutor { + private val promptExecutor = MultiLLMPromptExecutor( + LLMProvider.OpenAI to OpenAILLMClient(ApiKeyService.openAIApiKey), + LLMProvider.Anthropic to AnthropicLLMClient(ApiKeyService.anthropicApiKey), + LLMProvider.Google to GoogleLLMClient(ApiKeyService.googleApiKey), + ) + + @OptIn(ExperimentalUuidApi::class) + override suspend fun execute( + context: RequestContext, + eventProcessor: SessionEventProcessor + ) { + val agent = jokeWriterAgent(promptExecutor, context, eventProcessor) + agent.run(context.params.message) + } +} + +private fun jokeWriterAgent( + promptExecutor: PromptExecutor, + context: RequestContext, + eventProcessor: SessionEventProcessor +): GraphAIAgent { + val agentConfig = AIAgentConfig( + prompt = prompt("joke-generation") { + system { + +"You are a very funny sarcastic assistant. 
You must help users generate funny jokes." + +"When asked for something else, sarcastically decline the request because you can only assist with jokes." + } + }, + model = GoogleModels.Gemini2_5Flash, + maxAgentIterations = 20 + ) + + return GraphAIAgent( + inputType = typeOf(), + outputType = typeOf(), + promptExecutor = promptExecutor, + strategy = jokeWriterStrategy(), + agentConfig = agentConfig, + toolRegistry = ToolRegistry.EMPTY, + ) { + install(A2AAgentServer) { + this.context = context + this.eventProcessor = eventProcessor + } + } +} + +@OptIn(ExperimentalUuidApi::class) +private fun jokeWriterStrategy() = strategy("joke-writer") { + // Node: Load conversation history from message storage + val setupMessageContext by node { userInput -> + if (!userInput.referenceTaskIds.isNullOrEmpty()) { + throw A2AUnsupportedOperationException("This agent doesn't understand task references in referenceTaskIds yet.") + } + + // Load current context messages + val contextMessages: List = withA2AAgentServer { + context.messageStorage.getAll() + } + + // Update the prompt with the current context messages + llm.writeSession { + updatePrompt { + messages(contextMessages.map { it.toKoogMessage() }) + } + } + + userInput + } + + // Node: Load existing task (if continuing) or prepare for new task creation + val setupTaskContext by node { userInput -> + // Check if the message continues the task that already exists + val currentTask: Task? 
= withA2AAgentServer { + context.task?.id?.let { id -> + // Load task with full conversation history to continue working on it + context.taskStorage.get(id, historyLength = null) + } + } + + currentTask?.let { task -> + val currentTaskMessages = (task.history.orEmpty() + listOfNotNull(task.status.message) + userInput) + .map { it.toKoogMessage() } + + llm.writeSession { + updatePrompt { + user { + +"There's an ongoing task, the next messages contain conversation history for this task" + } + + messages(currentTaskMessages) + } + } + } + + /* + If task exists then the message belongs to the task, send event to update the task. + Otherwise, put it in general message storage for the current context. + */ + withA2AAgentServer { + if (currentTask != null) { + val updateEvent = TaskStatusUpdateEvent( + taskId = currentTask.id, + contextId = currentTask.contextId, + status = TaskStatus( + state = TaskState.Working, + message = userInput, + timestamp = Clock.System.now(), + ), + final = false, + ) + + eventProcessor.sendTaskEvent(updateEvent) + } else { + context.messageStorage.save(userInput) + } + } + + currentTask + } + + // Node: Ask LLM to classify if this is a joke request or something else + val classifyNewRequest by nodeLLMRequestStructured() + + // Node: Send a polite decline message if the request is not about jokes + val respondFallbackMessage by node { classification -> + withA2AAgentServer { + val message = A2AMessage( + messageId = Uuid.random().toString(), + role = Role.Agent, + parts = listOf( + TextPart(classification.response) + ), + contextId = context.contextId, + taskId = context.taskId, + ) + + // Store reply in message storage to preserve context + context.messageStorage.save(message) + // Reply with message + eventProcessor.sendMessage(message) + } + } + + // Node: Create a new task for the joke request + val createTask by node { + val userInput = agentInput() + + withA2AAgentServer { + val task = Task( + id = context.taskId, + contextId = 
context.contextId, + status = TaskStatus( + state = TaskState.Submitted, + message = userInput, + timestamp = Clock.System.now(), + ), + + ) + + eventProcessor.sendTaskEvent(task) + } + } + + // Node: Ask LLM to classify joke details (or request clarification) + val classifyJokeRequest by nodeLLMRequestStructured() + + // Node: Generate the actual joke based on classified parameters + val generateJoke by node { request -> + llm.writeSession { + updatePrompt { + user { + +text { + +"Generate a joke based on the following user request:" + xml { + tag("subject") { + +request.subject + } + tag("targetAudience") { + +request.targetAudience + } + tag("isSwearingAllowed") { + +request.isSwearingAllowed.toString() + } + } + } + } + } + + val message = requestLLMWithoutTools() + message as? Message.Assistant ?: throw IllegalStateException("Unexpected message type: $message") + } + } + + // Node: Send InputRequired event to ask the user for more information + val askMoreInfo by node { clarification -> + withA2AAgentServer { + val taskUpdate = TaskStatusUpdateEvent( + taskId = context.taskId, + contextId = context.contextId, + status = TaskStatus( + state = TaskState.InputRequired, + message = A2AMessage( + role = Role.User, + parts = listOf( + TextPart(clarification.question) + ), + messageId = Uuid.random().toString(), + taskId = context.taskId, + contextId = context.contextId + ), + timestamp = Clock.System.now(), + ), + final = true, + ) + + eventProcessor.sendTaskEvent(taskUpdate) + } + } + + // Node: Send the joke as an artifact and mark task as completed + val respondWithJoke by node { jokeMessage -> + withA2AAgentServer { + val artifactUpdate = TaskArtifactUpdateEvent( + taskId = context.taskId, + contextId = context.contextId, + artifact = Artifact( + artifactId = "joke", + parts = listOf( + TextPart(jokeMessage.content) + ) + ), + ) + + eventProcessor.sendTaskEvent(artifactUpdate) + + val taskStatusUpdate = TaskStatusUpdateEvent( + taskId = context.taskId, + 
contextId = context.contextId, + status = TaskStatus( + state = TaskState.Completed, + ), + final = true, + ) + + eventProcessor.sendTaskEvent(taskStatusUpdate) + } + } + + // --- Graph Flow Definition --- + + // Always start by loading context and checking for existing tasks + nodeStart then setupMessageContext then setupTaskContext + + // If no task exists, classify whether this is a joke request + edge( + setupTaskContext forwardTo classifyNewRequest + onCondition { task -> task == null } + transformed { agentInput().content() } + ) + // If task exists, continue processing the joke request + edge( + setupTaskContext forwardTo classifyJokeRequest + onCondition { task -> task != null } + transformed { agentInput().content() } + ) + + // New request classification: If not a joke request, decline politely + edge( + classifyNewRequest forwardTo respondFallbackMessage + transformed { it.getOrThrow().structure } + onCondition { !it.isJokeRequest } + ) + // New request classification: If joke request, create a task + edge( + classifyNewRequest forwardTo createTask + transformed { it.getOrThrow().structure } + onCondition { it.isJokeRequest } + ) + + edge(respondFallbackMessage forwardTo nodeFinish) + + // After creating task, classify the joke details + edge( + createTask forwardTo classifyJokeRequest + transformed { agentInput().content() } + ) + + // Joke classification: Ask for clarification if needed + edge( + classifyJokeRequest forwardTo askMoreInfo + transformed { it.getOrThrow().structure } + onIsInstance JokeRequestClassification.NeedsClarification::class + ) + // Joke classification: Generate joke if we have all details + edge( + classifyJokeRequest forwardTo generateJoke + transformed { it.getOrThrow().structure } + onIsInstance JokeRequestClassification.Ready::class + ) + + // After asking for info, wait for user response (finish this iteration) + edge(askMoreInfo forwardTo nodeFinish) + + // After generating joke, send it as an artifact + edge(generateJoke 
forwardTo respondWithJoke) + edge(respondWithJoke forwardTo nodeFinish) +} + +private fun A2AMessage.content(): String { + return parts.filterIsInstance().joinToString(separator = "\n") { it.text } +} + +// --- Structured Output Models --- + +@Serializable +@LLMDescription("Initial incoming user message classification, to determine if this is a joke request or not.") +private data class UserRequestClassification( + @property:LLMDescription("Whether the incoming message is a joke request or not") + val isJokeRequest: Boolean, + @property:LLMDescription( + "In case the message is not a joke request, polite reply to the user that the agent cannot assist." + + "Default is empty" + ) + val response: String = "", +) + +@LLMDescription("The classification of the joke request") +@Serializable +@SerialName("JokeRequestClassification") +private sealed interface JokeRequestClassification { + @Serializable + @SerialName("NeedsClarification") + @LLMDescription("The joke request needs clarification") + data class NeedsClarification( + @property:LLMDescription("The question that needs clarification") + val question: String + ) : JokeRequestClassification + + @LLMDescription("The joke request is ready to be processed") + @Serializable + @SerialName("Ready") + data class Ready( + @property:LLMDescription("The joke subject") + val subject: String, + @property:LLMDescription("The joke target audience") + val targetAudience: String, + @property:LLMDescription("Whether the swearing is allowed in the joke") + val isSwearingAllowed: Boolean, + ) : JokeRequestClassification +} diff --git a/examples/simple-examples/src/main/kotlin/ai/koog/agents/example/a2a/advancedjoke/Server.kt b/examples/simple-examples/src/main/kotlin/ai/koog/agents/example/a2a/advancedjoke/Server.kt new file mode 100644 index 0000000000..6e2edcc8f8 --- /dev/null +++ b/examples/simple-examples/src/main/kotlin/ai/koog/agents/example/a2a/advancedjoke/Server.kt @@ -0,0 +1,79 @@ +package 
ai.koog.agents.example.a2a.advancedjoke + +import ai.koog.a2a.model.AgentCapabilities +import ai.koog.a2a.model.AgentCard +import ai.koog.a2a.model.AgentInterface +import ai.koog.a2a.model.AgentSkill +import ai.koog.a2a.model.TransportProtocol +import ai.koog.a2a.server.A2AServer +import ai.koog.a2a.transport.server.jsonrpc.http.HttpJSONRPCServerTransport +import io.github.oshai.kotlinlogging.KotlinLogging +import io.ktor.server.cio.CIO + +private val logger = KotlinLogging.logger {} + +const val ADVANCED_JOKE_AGENT_PATH = "/advanced-joke-agent" +const val ADVANCED_JOKE_AGENT_CARD_PATH = "$ADVANCED_JOKE_AGENT_PATH/agent-card.json" + +suspend fun main() { + logger.info { "Starting Advanced Joke A2A Agent on http://localhost:9999" } + + // Create agent card with capabilities - this agent supports streaming and tasks + val agentCard = AgentCard( + protocolVersion = "0.3.0", + name = "Advanced Joke Generator", + description = "A sophisticated AI agent that generates jokes with clarifying questions and structured task flow", + version = "1.0.0", + url = "http://localhost:9999$ADVANCED_JOKE_AGENT_PATH", + preferredTransport = TransportProtocol.JSONRPC, + additionalInterfaces = listOf( + AgentInterface( + url = "http://localhost:9999$ADVANCED_JOKE_AGENT_PATH", + transport = TransportProtocol.JSONRPC, + ) + ), + capabilities = AgentCapabilities( + streaming = true, // Supports streaming responses + pushNotifications = false, + stateTransitionHistory = false, + ), + defaultInputModes = listOf("text"), + defaultOutputModes = listOf("text"), + skills = listOf( + AgentSkill( + id = "advanced_joke_generation", + name = "Advanced Joke Generation", + description = "Generates humorous jokes with interactive clarification and customization options", + examples = listOf( + "Tell me a joke about programming", + "Generate a funny joke for teenagers", + "Make me laugh with a dad joke about cats" + ), + tags = listOf("humor", "jokes", "entertainment", "interactive") + ) + ), + 
supportsAuthenticatedExtendedCard = false + ) + + // Create agent executor + val agentExecutor = JokeWriterAgentExecutor() + + // Create A2A server + val a2aServer = A2AServer( + agentExecutor = agentExecutor, + agentCard = agentCard, + ) + + // Create and start server transport + val serverTransport = HttpJSONRPCServerTransport(a2aServer) + + logger.info { "Advanced Joke Generator Agent ready at http://localhost:9999/$ADVANCED_JOKE_AGENT_PATH" } + serverTransport.start( + engineFactory = CIO, + port = 9999, + path = ADVANCED_JOKE_AGENT_PATH, + wait = true, // Block until server stops + agentCard = agentCard, + agentCardPath = ADVANCED_JOKE_AGENT_CARD_PATH + ) +} diff --git a/examples/simple-examples/src/main/kotlin/ai/koog/agents/example/a2a/simplejoke/Client.kt b/examples/simple-examples/src/main/kotlin/ai/koog/agents/example/a2a/simplejoke/Client.kt new file mode 100644 index 0000000000..41e4b2022e --- /dev/null +++ b/examples/simple-examples/src/main/kotlin/ai/koog/agents/example/a2a/simplejoke/Client.kt @@ -0,0 +1,84 @@ +@file:OptIn(ExperimentalUuidApi::class) + +package ai.koog.agents.example.a2a.simplejoke + +import ai.koog.a2a.client.A2AClient +import ai.koog.a2a.client.UrlAgentCardResolver +import ai.koog.a2a.model.Message +import ai.koog.a2a.model.MessageSendParams +import ai.koog.a2a.model.Role +import ai.koog.a2a.model.TextPart +import ai.koog.a2a.transport.Request +import ai.koog.a2a.transport.client.jsonrpc.http.HttpJSONRPCClientTransport +import ai.koog.agents.a2a.core.toKoogMessage +import kotlin.uuid.ExperimentalUuidApi +import kotlin.uuid.Uuid + +private const val BRIGHT_CYAN = "\u001B[1;36m" +private const val YELLOW = "\u001B[33m" +private const val BRIGHT_MAGENTA = "\u001B[1;35m" +private const val RED = "\u001B[31m" +private const val RESET = "\u001B[0m" + +@OptIn(ExperimentalUuidApi::class) +suspend fun main() { + println() + println("${YELLOW}Starting Joke Generator A2A Client$RESET\n") + + // Set up the HTTP JSON-RPC transport + val 
transport = HttpJSONRPCClientTransport( + url = "http://localhost:9998$JOKE_GENERATOR_AGENT_PATH" + ) + + // Set up the agent card resolver + val agentCardResolver = UrlAgentCardResolver( + baseUrl = "http://localhost:9998", + path = JOKE_GENERATOR_AGENT_CARD_PATH + ) + + // Create the A2A client + val client = A2AClient( + transport = transport, + agentCardResolver = agentCardResolver + ) + + // Connect and fetch agent card + client.connect() + val agentCard = client.cachedAgentCard() + println("${YELLOW}Connected to agent:$RESET\n${agentCard.name} (${agentCard.description})\n") + + // Read context ID + println("${BRIGHT_CYAN}Context ID (which chat to start/continue):$RESET") + val contextId = readln() + println() + + // Start chat loop + while (true) { + println("${BRIGHT_CYAN}Request (/q to quit):$RESET") + val request = readln() + println() + + if (request == "/q") { + break + } + + val message = Message( + messageId = Uuid.random().toString(), + role = Role.User, + parts = listOf(TextPart(request)), + contextId = contextId + ) + + val response = client.sendMessage( + Request(MessageSendParams(message = message)) + ) + + val reply = (response.data as Message).toKoogMessage().content + println("${BRIGHT_MAGENTA}Agent response:${RESET}\n$reply\n") + } + + println("${RED}Conversation complete!$RESET") + + // Clean up + transport.close() +} diff --git a/examples/simple-examples/src/main/kotlin/ai/koog/agents/example/a2a/simplejoke/Server.kt b/examples/simple-examples/src/main/kotlin/ai/koog/agents/example/a2a/simplejoke/Server.kt new file mode 100644 index 0000000000..24a7450703 --- /dev/null +++ b/examples/simple-examples/src/main/kotlin/ai/koog/agents/example/a2a/simplejoke/Server.kt @@ -0,0 +1,79 @@ +package ai.koog.agents.example.a2a.simplejoke + +import ai.koog.a2a.model.AgentCapabilities +import ai.koog.a2a.model.AgentCard +import ai.koog.a2a.model.AgentInterface +import ai.koog.a2a.model.AgentSkill +import ai.koog.a2a.model.TransportProtocol +import 
ai.koog.a2a.server.A2AServer +import ai.koog.a2a.transport.server.jsonrpc.http.HttpJSONRPCServerTransport +import io.github.oshai.kotlinlogging.KotlinLogging +import io.ktor.server.cio.CIO + +private val logger = KotlinLogging.logger {} + +const val JOKE_GENERATOR_AGENT_PATH = "/joke-generator-agent" +const val JOKE_GENERATOR_AGENT_CARD_PATH = "$JOKE_GENERATOR_AGENT_PATH/agent-card.json" + +suspend fun main() { + logger.info { "Starting Joke A2A Agent on http://localhost:9998" } + + // Create agent card with capabilities + val agentCard = AgentCard( + protocolVersion = "0.3.0", + name = "Joke Generator 3000", + description = "A helpful AI agent that generates jokes based on user requests", + version = "1.0.0", + url = "http://localhost:9998$JOKE_GENERATOR_AGENT_PATH", + preferredTransport = TransportProtocol.JSONRPC, + additionalInterfaces = listOf( + AgentInterface( + url = "http://localhost:9998$JOKE_GENERATOR_AGENT_PATH", + transport = TransportProtocol.JSONRPC, + ) + ), + capabilities = AgentCapabilities( + streaming = false, + pushNotifications = false, + stateTransitionHistory = false, + ), + defaultInputModes = listOf("text"), + defaultOutputModes = listOf("text"), + skills = listOf( + AgentSkill( + id = "joke_generation", + name = "Joke Generation", + description = "Generates humorous jokes on various topics", + examples = listOf( + "Tell me a joke", + "Generate a funny joke about programming", + "Make me laugh with a dad joke" + ), + tags = listOf("humor", "jokes", "entertainment") + ) + ), + supportsAuthenticatedExtendedCard = false + ) + + // Create agent executor + val agentExecutor = SimpleJokeAgentExecutor() + + // Create A2A server + val a2aServer = A2AServer( + agentExecutor = agentExecutor, + agentCard = agentCard, + ) + + // Create and start server transport + val serverTransport = HttpJSONRPCServerTransport(a2aServer) + + logger.info { "Joke Generator Agent ready at http://localhost:9998/$JOKE_GENERATOR_AGENT_PATH" } + serverTransport.start( + 
engineFactory = CIO, + port = 9998, + path = JOKE_GENERATOR_AGENT_PATH, + wait = true, // Block until server stops + agentCard = agentCard, + agentCardPath = JOKE_GENERATOR_AGENT_CARD_PATH + ) +} diff --git a/examples/simple-examples/src/main/kotlin/ai/koog/agents/example/a2a/simplejoke/SimpleJokeAgentExecutor.kt b/examples/simple-examples/src/main/kotlin/ai/koog/agents/example/a2a/simplejoke/SimpleJokeAgentExecutor.kt new file mode 100644 index 0000000000..30212ff03f --- /dev/null +++ b/examples/simple-examples/src/main/kotlin/ai/koog/agents/example/a2a/simplejoke/SimpleJokeAgentExecutor.kt @@ -0,0 +1,78 @@ +package ai.koog.agents.example.a2a.simplejoke + +import ai.koog.a2a.exceptions.A2AUnsupportedOperationException +import ai.koog.a2a.model.MessageSendParams +import ai.koog.a2a.server.agent.AgentExecutor +import ai.koog.a2a.server.session.RequestContext +import ai.koog.a2a.server.session.SessionEventProcessor +import ai.koog.agents.a2a.core.MessageA2AMetadata +import ai.koog.agents.a2a.core.toA2AMessage +import ai.koog.agents.a2a.core.toKoogMessage +import ai.koog.agents.example.ApiKeyService +import ai.koog.prompt.dsl.prompt +import ai.koog.prompt.executor.clients.anthropic.AnthropicLLMClient +import ai.koog.prompt.executor.clients.anthropic.AnthropicModels +import ai.koog.prompt.executor.clients.google.GoogleLLMClient +import ai.koog.prompt.executor.clients.openai.OpenAILLMClient +import ai.koog.prompt.executor.llms.MultiLLMPromptExecutor +import ai.koog.prompt.llm.LLMProvider +import ai.koog.prompt.message.Message +import kotlin.uuid.ExperimentalUuidApi +import kotlin.uuid.Uuid + +/** + * This is a simple example of an agent executor that wraps LLM calls using prompt executor to generate jokes. 
+ */ +class SimpleJokeAgentExecutor : AgentExecutor { + private val promptExecutor = MultiLLMPromptExecutor( + LLMProvider.OpenAI to OpenAILLMClient(ApiKeyService.openAIApiKey), + LLMProvider.Anthropic to AnthropicLLMClient(ApiKeyService.anthropicApiKey), + LLMProvider.Google to GoogleLLMClient(ApiKeyService.googleApiKey), + ) + + @OptIn(ExperimentalUuidApi::class) + override suspend fun execute( + context: RequestContext, + eventProcessor: SessionEventProcessor + ) { + val userMessage = context.params.message + + if (context.task != null || !userMessage.referenceTaskIds.isNullOrEmpty()) { + throw A2AUnsupportedOperationException("This agent doesn't support tasks") + } + + // Save incoming message to the current context + context.messageStorage.save(userMessage) + + // Load all messages from the current context + val contextMessages = context.messageStorage.getAll().map { it.toKoogMessage() } + + val prompt = prompt("joke-generation") { + system { + +"You are an assistant helping user to generate jokes" + } + + // Append current message context + messages(contextMessages) + } + + // Get a response from the LLM + val responseMessage = promptExecutor.execute(prompt, AnthropicModels.Sonnet_4) + .single() + .let { message -> + message as? 
Message.Assistant ?: throw IllegalStateException("Unexpected message type: $message") + } + .toA2AMessage( + a2aMetadata = MessageA2AMetadata( + messageId = Uuid.random().toString(), + contextId = context.contextId, + ) + ) + + // Save the response to the current context + context.messageStorage.save(responseMessage) + + // Reply with message + eventProcessor.sendMessage(responseMessage) + } +} diff --git a/examples/simple-examples/src/main/kotlin/ai/koog/agents/example/calculator/Calculator.kt b/examples/simple-examples/src/main/kotlin/ai/koog/agents/example/calculator/Calculator.kt index 466cc667e2..ccb0736a60 100644 --- a/examples/simple-examples/src/main/kotlin/ai/koog/agents/example/calculator/Calculator.kt +++ b/examples/simple-examples/src/main/kotlin/ai/koog/agents/example/calculator/Calculator.kt @@ -42,17 +42,17 @@ fun main(): Unit = runBlocking { toolRegistry = toolRegistry ) { handleEvents { - onToolCall { eventContext -> + onToolCallStarting { eventContext -> println("Tool called: tool ${eventContext.tool.name}, args ${eventContext.toolArgs}") } - onAgentRunError { eventContext -> + onAgentExecutionFailed { eventContext -> println( "An error occurred: ${eventContext.throwable.message}\n${eventContext.throwable.stackTraceToString()}" ) } - onAgentFinished { eventContext -> + onAgentCompleted { eventContext -> println("Result: ${eventContext.result}") } } diff --git a/examples/simple-examples/src/main/kotlin/ai/koog/agents/example/calculator/OllamaCalculatorExample.kt b/examples/simple-examples/src/main/kotlin/ai/koog/agents/example/calculator/OllamaCalculatorExample.kt index 9c9400851a..b87ff692af 100644 --- a/examples/simple-examples/src/main/kotlin/ai/koog/agents/example/calculator/OllamaCalculatorExample.kt +++ b/examples/simple-examples/src/main/kotlin/ai/koog/agents/example/calculator/OllamaCalculatorExample.kt @@ -41,17 +41,17 @@ fun main(): Unit = runBlocking { toolRegistry = toolRegistry ) { handleEvents { - onToolCall { eventContext -> + 
onToolCallStarting { eventContext -> println("Tool called: tool ${eventContext.tool.name}, args ${eventContext.toolArgs}") } - onAgentRunError { eventContext -> + onAgentExecutionFailed { eventContext -> println( "An error occurred: ${eventContext.throwable.message}\n${eventContext.throwable.stackTraceToString()}" ) } - onAgentFinished { eventContext -> + onAgentCompleted { eventContext -> println("Result: ${eventContext.result}") } } diff --git a/examples/simple-examples/src/main/kotlin/ai/koog/agents/example/funApi/FunAgentWithTools.kt b/examples/simple-examples/src/main/kotlin/ai/koog/agents/example/funApi/FunAgentWithTools.kt index 53da2bbe6f..2c23297d63 100644 --- a/examples/simple-examples/src/main/kotlin/ai/koog/agents/example/funApi/FunAgentWithTools.kt +++ b/examples/simple-examples/src/main/kotlin/ai/koog/agents/example/funApi/FunAgentWithTools.kt @@ -50,7 +50,7 @@ fun main(): Unit = runBlocking { toolRegistry = toolRegistry ) { install(EventHandler) { - onToolCall { eventContext -> + onToolCallStarting { eventContext -> println("Tool called: tool ${eventContext.tool.name}, args ${eventContext.toolArgs}") } } diff --git a/examples/simple-examples/src/main/kotlin/ai/koog/agents/example/mcp/UnityMcpAgent.kt b/examples/simple-examples/src/main/kotlin/ai/koog/agents/example/mcp/UnityMcpAgent.kt index 9f30c7682b..e3405345fc 100644 --- a/examples/simple-examples/src/main/kotlin/ai/koog/agents/example/mcp/UnityMcpAgent.kt +++ b/examples/simple-examples/src/main/kotlin/ai/koog/agents/example/mcp/UnityMcpAgent.kt @@ -108,17 +108,17 @@ fun main() { install(Tracing) install(EventHandler) { - onBeforeAgentStarted { eventContext -> - println("OnBeforeAgentStarted first (strategy: ${strategy.name})") + onAgentStarting { eventContext -> + println("OnAgentStarting first (strategy: ${strategy.name})") } - onBeforeAgentStarted { eventContext -> - println("OnBeforeAgentStarted second (strategy: ${strategy.name})") + onAgentStarting { eventContext -> + 
println("OnAgentStarting second (strategy: ${strategy.name})") } - onAgentFinished { eventContext -> + onAgentCompleted { eventContext -> println( - "OnAgentFinished (agent id: ${eventContext.agentId}, result: ${eventContext.result})" + "OnAgentCompleted (agent id: ${eventContext.agentId}, result: ${eventContext.result})" ) } } diff --git a/examples/simple-examples/src/main/kotlin/ai/koog/agents/example/simpleapi/BasicSingleRunAgent.kt b/examples/simple-examples/src/main/kotlin/ai/koog/agents/example/simpleapi/BasicSingleRunAgent.kt index 350f44f5d3..32048bd3e1 100644 --- a/examples/simple-examples/src/main/kotlin/ai/koog/agents/example/simpleapi/BasicSingleRunAgent.kt +++ b/examples/simple-examples/src/main/kotlin/ai/koog/agents/example/simpleapi/BasicSingleRunAgent.kt @@ -15,7 +15,7 @@ import kotlinx.coroutines.runBlocking fun main() = runBlocking { var result: Any? = null val eventHandlerConfig: EventHandlerConfig.() -> Unit = { - onAgentFinished { eventContext -> result = eventContext.result } + onAgentCompleted { eventContext -> result = eventContext.result } } // Create a single-run agent with a system prompt val agent = AIAgent( diff --git a/examples/simple-examples/src/main/kotlin/ai/koog/agents/example/snapshot/CheckpointExample.kt b/examples/simple-examples/src/main/kotlin/ai/koog/agents/example/snapshot/CheckpointExample.kt index 45b048a71e..81191c4f9a 100644 --- a/examples/simple-examples/src/main/kotlin/ai/koog/agents/example/snapshot/CheckpointExample.kt +++ b/examples/simple-examples/src/main/kotlin/ai/koog/agents/example/snapshot/CheckpointExample.kt @@ -7,8 +7,8 @@ import ai.koog.agents.core.tools.ToolRegistry import ai.koog.agents.core.tools.reflect.asTools import ai.koog.agents.example.calculator.CalculatorTools import ai.koog.agents.features.eventHandler.feature.EventHandler -import ai.koog.agents.snapshot.feature.Persistency -import ai.koog.agents.snapshot.providers.InMemoryPersistencyStorageProvider +import 
ai.koog.agents.snapshot.feature.Persistence +import ai.koog.agents.snapshot.providers.InMemoryPersistenceStorageProvider import ai.koog.prompt.executor.llms.all.simpleOllamaAIExecutor import ai.koog.prompt.executor.model.PromptExecutor import ai.koog.prompt.llm.OllamaModels @@ -27,12 +27,11 @@ fun main() = runBlocking { tools(CalculatorTools().asTools()) } - val persistenceId = "snapshot-agent-example" + val agentId = "agent.1" - val snapshotProvider = InMemoryPersistencyStorageProvider( - persistenceId = persistenceId - ) + val snapshotProvider = InMemoryPersistenceStorageProvider() val agent = AIAgent( + id = agentId, promptExecutor = executor, llmModel = OllamaModels.Meta.LLAMA_3_2, strategy = singleRunStrategy(ToolCalls.SEQUENTIAL), @@ -40,13 +39,13 @@ fun main() = runBlocking { systemPrompt = "You are a calculator. Use tools to calculate asked to result.", temperature = 0.0, ) { - install(Persistency) { + install(Persistence) { storage = snapshotProvider - enableAutomaticPersistency = true + enableAutomaticPersistence = true } install(EventHandler) { - onToolCallFailure { + onToolCallFailed { throw Exception("Tool call failed") } } @@ -61,21 +60,21 @@ fun main() = runBlocking { } } - val checkpoints = snapshotProvider.getCheckpoints() + val checkpoints = snapshotProvider.getCheckpoints(agentId) println("Snapshot provider state after first run: $checkpoints") val agent2 = AIAgent( + id = agent.id, promptExecutor = executor, llmModel = OllamaModels.Meta.LLAMA_3_2, toolRegistry = correctToolRegistry, strategy = singleRunStrategy(ToolCalls.SEQUENTIAL), systemPrompt = "You are a calculator. 
Use tools to calculate asked to result.", temperature = 0.0, - id = agent.id ) { - install(Persistency) { + install(Persistence) { storage = snapshotProvider - enableAutomaticPersistency = true + enableAutomaticPersistence = true } } diff --git a/examples/simple-examples/src/main/kotlin/ai/koog/agents/example/snapshot/FilePersistentAgentExample.kt b/examples/simple-examples/src/main/kotlin/ai/koog/agents/example/snapshot/FilePersistentAgentExample.kt index 3650e33cbf..ec6821bd74 100644 --- a/examples/simple-examples/src/main/kotlin/ai/koog/agents/example/snapshot/FilePersistentAgentExample.kt +++ b/examples/simple-examples/src/main/kotlin/ai/koog/agents/example/snapshot/FilePersistentAgentExample.kt @@ -5,8 +5,8 @@ import ai.koog.agents.core.agent.config.AIAgentConfig import ai.koog.agents.core.tools.ToolRegistry import ai.koog.agents.ext.tool.AskUser import ai.koog.agents.ext.tool.SayToUser -import ai.koog.agents.snapshot.feature.Persistency -import ai.koog.agents.snapshot.providers.file.JVMFilePersistencyStorageProvider +import ai.koog.agents.snapshot.feature.Persistence +import ai.koog.agents.snapshot.providers.file.JVMFilePersistenceStorageProvider import ai.koog.prompt.dsl.prompt import ai.koog.prompt.executor.llms.all.simpleOllamaAIExecutor import ai.koog.prompt.executor.model.PromptExecutor @@ -38,7 +38,7 @@ fun main() = runBlocking { println("Checkpoint directory: $checkpointDir") // Create the file-based checkpoint provider - val provider = JVMFilePersistencyStorageProvider(checkpointDir, "persistent-agent-example") + val provider = JVMFilePersistenceStorageProvider(checkpointDir) // Create a unique agent ID to identify this agent's checkpoints val agentId = "persistent-agent-example" @@ -68,9 +68,9 @@ fun main() = runBlocking { toolRegistry = toolRegistry, id = agentId ) { - install(Persistency) { + install(Persistence) { storage = provider // Use the file-based checkpoint provider - enableAutomaticPersistency = true // Enable automatic checkpoint 
creation + enableAutomaticPersistence = true // Enable automatic checkpoint creation } } @@ -79,7 +79,7 @@ fun main() = runBlocking { println("Agent result: $result") // Retrieve all checkpoints created during the agent's execution - val checkpoints = provider.getCheckpoints() + val checkpoints = provider.getCheckpoints(agentId) println("\nRetrieved ${checkpoints.size} checkpoints for agent $agentId") // Print checkpoint details @@ -108,9 +108,9 @@ fun main() = runBlocking { toolRegistry = toolRegistry, id = agentId ) { - install(Persistency) { + install(Persistence) { storage = provider // Use the file-based checkpoint provider - enableAutomaticPersistency = true // Enable automatic checkpoint creation + enableAutomaticPersistence = true // Enable automatic checkpoint creation } } @@ -120,7 +120,7 @@ fun main() = runBlocking { println("Restored agent result: $restoredResult") // Get the latest checkpoint after the second run - val latestCheckpoint = provider.getLatestCheckpoint() + val latestCheckpoint = provider.getLatestCheckpoint(agentId) println("\nLatest checkpoint after restoration:") println(" ID: ${latestCheckpoint?.checkpointId}") println(" Created at: ${latestCheckpoint?.createdAt}") diff --git a/examples/simple-examples/src/main/kotlin/ai/koog/agents/example/snapshot/SnapshotExample.kt b/examples/simple-examples/src/main/kotlin/ai/koog/agents/example/snapshot/SnapshotExample.kt index a009ca7280..9dd907c684 100644 --- a/examples/simple-examples/src/main/kotlin/ai/koog/agents/example/snapshot/SnapshotExample.kt +++ b/examples/simple-examples/src/main/kotlin/ai/koog/agents/example/snapshot/SnapshotExample.kt @@ -7,8 +7,8 @@ import ai.koog.agents.core.tools.reflect.asTools import ai.koog.agents.example.calculator.CalculatorTools import ai.koog.agents.ext.tool.AskUser import ai.koog.agents.ext.tool.SayToUser -import ai.koog.agents.snapshot.feature.Persistency -import ai.koog.agents.snapshot.providers.InMemoryPersistencyStorageProvider +import 
ai.koog.agents.snapshot.feature.Persistence +import ai.koog.agents.snapshot.providers.InMemoryPersistenceStorageProvider import ai.koog.prompt.dsl.prompt import ai.koog.prompt.executor.llms.all.simpleOllamaAIExecutor import ai.koog.prompt.executor.model.PromptExecutor @@ -37,14 +37,14 @@ fun main() = runBlocking { maxAgentIterations = 50 ) - val snapshotProvider = InMemoryPersistencyStorageProvider("persistent-agent-example") + val snapshotProvider = InMemoryPersistenceStorageProvider() val agent = AIAgent( promptExecutor = executor, strategy = SnapshotStrategy.strategy, agentConfig = agentConfig, toolRegistry = toolRegistry ) { - install(Persistency) { + install(Persistence) { storage = snapshotProvider } } diff --git a/examples/simple-examples/src/main/kotlin/ai/koog/agents/example/snapshot/SnapshotStrategy.kt b/examples/simple-examples/src/main/kotlin/ai/koog/agents/example/snapshot/SnapshotStrategy.kt index e53fd10895..9fccc2b3dc 100644 --- a/examples/simple-examples/src/main/kotlin/ai/koog/agents/example/snapshot/SnapshotStrategy.kt +++ b/examples/simple-examples/src/main/kotlin/ai/koog/agents/example/snapshot/SnapshotStrategy.kt @@ -4,7 +4,7 @@ import ai.koog.agents.core.dsl.builder.AIAgentNodeDelegate import ai.koog.agents.core.dsl.builder.AIAgentSubgraphBuilderBase import ai.koog.agents.core.dsl.builder.forwardTo import ai.koog.agents.core.dsl.builder.strategy -import ai.koog.agents.snapshot.feature.withPersistency +import ai.koog.agents.snapshot.feature.withPersistence import kotlinx.serialization.json.JsonPrimitive private fun AIAgentSubgraphBuilderBase<*, *>.simpleNode( @@ -22,9 +22,9 @@ private fun AIAgentSubgraphBuilderBase<*, *>.teleportNode( ): AIAgentNodeDelegate = node(name) { if (!teleportState.teleported) { teleportState.teleported = true - withPersistency { + withPersistence { setExecutionPoint(it, "Node1", listOf(), JsonPrimitive("Teleported!!!")) - return@withPersistency "Teleported" + return@withPersistence "Teleported" } } else { return@node 
"$it\nAlready teleported, passing by" diff --git a/examples/simple-examples/src/main/kotlin/ai/koog/agents/example/snapshot/sql/README.md b/examples/simple-examples/src/main/kotlin/ai/koog/agents/example/snapshot/sql/README.md index 0a8bed0d58..e34dd047c8 100644 --- a/examples/simple-examples/src/main/kotlin/ai/koog/agents/example/snapshot/sql/README.md +++ b/examples/simple-examples/src/main/kotlin/ai/koog/agents/example/snapshot/sql/README.md @@ -48,7 +48,7 @@ CREATE TABLE agent_checkpoints ( ### PostgreSQL ```kotlin -PostgresPersistencyStorageProvider( +PostgresPersistenceStorageProvider( persistenceId = "my-agent", database = Database.connect( url = "jdbc:postgresql://localhost:5432/agents", @@ -63,7 +63,7 @@ PostgresPersistencyStorageProvider( ### MySQL ```kotlin -MySQLPersistencyStorageProvider( +MySQLPersistenceStorageProvider( persistenceId = "my-agent", database = Database.connect( url = "jdbc:mysql://localhost:3306/agents", @@ -79,13 +79,13 @@ MySQLPersistencyStorageProvider( ```kotlin // In-memory (for testing) -H2PersistencyStorageProvider.inMemory( +H2PersistenceStorageProvider.inMemory( persistenceId = "test-agent", databaseName = "test_db" ) // File-based (for persistence) -H2PersistencyStorageProvider.fileBased( +H2PersistenceStorageProvider.fileBased( persistenceId = "my-agent", filePath = "./data/h2/agent_checkpoints" ) @@ -119,4 +119,4 @@ Access Adminer at http://localhost:8080: ```bash docker-compose down -v # Stop and remove data rm -rf ./data # Remove H2/SQLite files -``` \ No newline at end of file +``` diff --git a/examples/simple-examples/src/main/kotlin/ai/koog/agents/example/snapshot/sql/SQLPersistentAgentExample.kt b/examples/simple-examples/src/main/kotlin/ai/koog/agents/example/snapshot/sql/SQLPersistentAgentExample.kt index e847471dc6..5dd319d3d5 100644 --- a/examples/simple-examples/src/main/kotlin/ai/koog/agents/example/snapshot/sql/SQLPersistentAgentExample.kt +++ 
b/examples/simple-examples/src/main/kotlin/ai/koog/agents/example/snapshot/sql/SQLPersistentAgentExample.kt @@ -1,8 +1,8 @@ package ai.koog.agents.example.snapshot.sql -import ai.koog.agents.features.sql.providers.H2PersistencyStorageProvider -import ai.koog.agents.features.sql.providers.MySQLPersistencyStorageProvider -import ai.koog.agents.features.sql.providers.PostgresPersistencyStorageProvider +import ai.koog.agents.features.sql.providers.H2PersistenceStorageProvider +import ai.koog.agents.features.sql.providers.MySQLPersistenceStorageProvider +import ai.koog.agents.features.sql.providers.PostgresPersistenceStorageProvider import ai.koog.agents.snapshot.feature.AgentCheckpointData import ai.koog.prompt.message.Message import ai.koog.prompt.message.RequestMetaInfo @@ -42,9 +42,9 @@ object SQLPersistentAgentExample { private suspend fun postgresqlExample() { println("PostgreSQL Persistence Example") println("------------------------------") + val agentId = "postgres-agent" - val provider = PostgresPersistencyStorageProvider( - persistenceId = "postgres-agent", + val provider = PostgresPersistenceStorageProvider( database = Database.connect( url = "jdbc:postgresql://localhost:5432/agents", driver = "org.postgresql.Driver", @@ -59,11 +59,11 @@ object SQLPersistentAgentExample { // Create and save checkpoint val checkpoint = createSampleCheckpoint("postgres-checkpoint-1") - provider.saveCheckpoint(checkpoint) + provider.saveCheckpoint(agentId = agentId, agentCheckpointData = checkpoint) println("Saved checkpoint: ${checkpoint.checkpointId}") // Retrieve checkpoint - val retrieved = provider.getLatestCheckpoint() + val retrieved = provider.getLatestCheckpoint(agentId) println("Retrieved latest checkpoint: ${retrieved?.checkpointId}") } @@ -73,9 +73,9 @@ object SQLPersistentAgentExample { private suspend fun mysqlExample() { println("MySQL Persistence Example") println("-------------------------") + val agentId = "postgres-agent" - val provider = 
MySQLPersistencyStorageProvider( - persistenceId = "mysql-agent", + val provider = MySQLPersistenceStorageProvider( database = Database.connect( url = "jdbc:mysql://localhost:3306/agents?useSSL=false&serverTimezone=UTC", driver = "com.mysql.cj.jdbc.Driver", @@ -96,16 +96,16 @@ object SQLPersistentAgentExample { ) checkpoints.forEach { checkpoint -> - provider.saveCheckpoint(checkpoint) + provider.saveCheckpoint(agentId, checkpoint) println("Saved: ${checkpoint.checkpointId}") } // Get all checkpoints - val allCheckpoints = provider.getCheckpoints() + val allCheckpoints = provider.getCheckpoints(agentId) println("\nTotal checkpoints: ${allCheckpoints.size}") // Get checkpoint count - val count = provider.getCheckpointCount() + val count = provider.getCheckpointCount(agentId) println("Checkpoint count: $count") } @@ -115,37 +115,38 @@ object SQLPersistentAgentExample { private suspend fun h2Example() { println("H2 Database Persistence Examples") println("--------------------------------") - + val agentId = "h2-test-agent" // Example 1: In-memory database (for testing) println("\n1. In-Memory H2:") - val inMemoryProvider = H2PersistencyStorageProvider.inMemory( - persistenceId = "h2-test-agent", + val inMemoryProvider = H2PersistenceStorageProvider.inMemory( databaseName = "test_agents" ) inMemoryProvider.migrate() val testCheckpoint = createSampleCheckpoint("h2-memory-checkpoint") - inMemoryProvider.saveCheckpoint(testCheckpoint) + inMemoryProvider.saveCheckpoint(agentId, testCheckpoint) println(" Saved to in-memory: ${testCheckpoint.checkpointId}") + val h2AgentId = "h2-file-agent" + // Example 2: File-based database (for persistence) println("\n2. 
File-Based H2:") - val fileProvider = H2PersistencyStorageProvider.fileBased( - persistenceId = "h2-file-agent", + val fileProvider = H2PersistenceStorageProvider.fileBased( filePath = "./data/h2/agent_checkpoints", ttlSeconds = 86400 // 24 hours ) fileProvider.migrate() val fileCheckpoint = createSampleCheckpoint("h2-file-checkpoint") - fileProvider.saveCheckpoint(fileCheckpoint) + fileProvider.saveCheckpoint(h2AgentId, fileCheckpoint) println(" Saved to file: ${fileCheckpoint.checkpointId}") // Example 3: PostgreSQL compatibility mode println("\n3. PostgreSQL Compatible Mode:") - val pgCompatProvider = H2PersistencyStorageProvider( - persistenceId = "postgres-agent", + val postgresAgentId = "postgres-agent" + + val pgCompatProvider = H2PersistenceStorageProvider( database = Database.connect( url = "jdbc:postgresql://localhost:5432/agents", driver = "org.postgresql.Driver", @@ -158,7 +159,7 @@ object SQLPersistentAgentExample { pgCompatProvider.migrate() val pgCheckpoint = createSampleCheckpoint("h2-pgcompat-checkpoint") - pgCompatProvider.saveCheckpoint(pgCheckpoint) + pgCompatProvider.saveCheckpoint(postgresAgentId, pgCheckpoint) println(" Saved with PG compatibility: ${pgCheckpoint.checkpointId}") } diff --git a/examples/simple-examples/src/main/kotlin/ai/koog/agents/example/streaming/StreamingAgentWithTools.kt b/examples/simple-examples/src/main/kotlin/ai/koog/agents/example/streaming/StreamingAgentWithTools.kt index 4bf260bcab..0b70e15ed5 100644 --- a/examples/simple-examples/src/main/kotlin/ai/koog/agents/example/streaming/StreamingAgentWithTools.kt +++ b/examples/simple-examples/src/main/kotlin/ai/koog/agents/example/streaming/StreamingAgentWithTools.kt @@ -31,7 +31,7 @@ fun main(): Unit = runBlocking { } val agent = openAiAgent(toolRegistry) { handleEvents { - onToolCall { context -> + onToolCallStarting { context -> println("\n🔧 Using ${context.tool.name} with ${context.toolArgs}... 
") } onLLMStreamingFrameReceived { context -> diff --git a/examples/simple-examples/src/main/kotlin/ai/koog/agents/example/structuredoutput/AdvancedWithBasicSchema.kt b/examples/simple-examples/src/main/kotlin/ai/koog/agents/example/structuredoutput/AdvancedWithBasicSchema.kt index 3b3df780c6..5317d9807b 100644 --- a/examples/simple-examples/src/main/kotlin/ai/koog/agents/example/structuredoutput/AdvancedWithBasicSchema.kt +++ b/examples/simple-examples/src/main/kotlin/ai/koog/agents/example/structuredoutput/AdvancedWithBasicSchema.kt @@ -248,7 +248,7 @@ fun main(): Unit = runBlocking { agentConfig = agentConfig ) { handleEvents { - onAgentRunError { ctx -> + onAgentExecutionFailed { ctx -> println("An error occurred: ${ctx.throwable.message}\n${ctx.throwable.stackTraceToString()}") } } diff --git a/examples/simple-examples/src/main/kotlin/ai/koog/agents/example/structuredoutput/AdvancedWithStandardSchema.kt b/examples/simple-examples/src/main/kotlin/ai/koog/agents/example/structuredoutput/AdvancedWithStandardSchema.kt index e3b47998fb..a8e06c492a 100644 --- a/examples/simple-examples/src/main/kotlin/ai/koog/agents/example/structuredoutput/AdvancedWithStandardSchema.kt +++ b/examples/simple-examples/src/main/kotlin/ai/koog/agents/example/structuredoutput/AdvancedWithStandardSchema.kt @@ -5,8 +5,9 @@ import ai.koog.agents.core.agent.config.AIAgentConfig import ai.koog.agents.core.dsl.builder.forwardTo import ai.koog.agents.core.dsl.builder.strategy import ai.koog.agents.core.dsl.extension.nodeLLMRequestStructured -import ai.koog.agents.core.tools.annotations.LLMDescription import ai.koog.agents.example.ApiKeyService +import ai.koog.agents.example.structuredoutput.models.FullWeatherForecast +import ai.koog.agents.example.structuredoutput.models.FullWeatherForecastRequest import ai.koog.agents.features.eventHandler.feature.handleEvents import ai.koog.prompt.dsl.prompt import ai.koog.prompt.executor.clients.anthropic.AnthropicLLMClient @@ -25,203 +26,13 @@ import 
ai.koog.prompt.structure.json.JsonStructuredData import ai.koog.prompt.structure.json.generator.StandardJsonSchemaGenerator import ai.koog.prompt.text.text import kotlinx.coroutines.runBlocking -import kotlinx.serialization.SerialName -import kotlinx.serialization.Serializable import kotlinx.serialization.json.Json -/** - * This is a more advanced example showing how to configure various parameters of structured output manually, to fine-tune - * it for your needs when necessary. - * - * Structured output that uses "full" JSON schema. - * More advanced features are supported, e.g. polymorphism and recursive type references, and schemas can be more complex. - */ - -@Serializable -@SerialName("FullWeatherForecast") -@LLMDescription("Weather forecast for a given location") -data class FullWeatherForecast( - @property:LLMDescription("Temperature in Celsius") - val temperature: Int, - // properties with default values - @property:LLMDescription("Weather conditions (e.g., sunny, cloudy, rainy)") - val conditions: String = "sunny", - // nullable properties - @property:LLMDescription("Chance of precipitation in percentage") - val precipitation: Int?, - // nested classes - @property:LLMDescription("Coordinates of the location") - val latLon: LatLon, - // enums - val pollution: Pollution, - // polymorphism - val alert: WeatherAlert, - // lists - @property:LLMDescription("List of news articles") - val news: List, -// // maps (string keys only, some providers don't support maps at all) -// @property:LLMDescription("Map of weather sources") -// val sources: Map -) { - // Nested classes - @Serializable - @SerialName("LatLon") - data class LatLon( - @property:LLMDescription("Latitude of the location") - val lat: Double, - @property:LLMDescription("Longitude of the location") - val lon: Double - ) - - // Nested classes in lists... 
- @Serializable - @SerialName("WeatherNews") - data class WeatherNews( - @property:LLMDescription("Title of the news article") - val title: String, - @property:LLMDescription("Link to the news article") - val link: String - ) - - // ... and maps (but only with string keys!) - @Suppress("unused") - @Serializable - @SerialName("WeatherSource") - data class WeatherSource( - @property:LLMDescription("Name of the weather station") - val stationName: String, - @property:LLMDescription("Authority of the weather station") - val stationAuthority: String - ) - - // Enums - @Suppress("unused") - @SerialName("Pollution") - @Serializable - enum class Pollution { - @SerialName("None") - None, - - @SerialName("LOW") - Low, - - @SerialName("MEDIUM") - Medium, - - @SerialName("HIGH") - High - } - - /* - Polymorphism: - 1. Closed with sealed classes, - 2. Open: non-sealed classes with subclasses registered in json config - https://github.com/Kotlin/kotlinx.serialization/blob/master/docs/polymorphism.md#registered-subclasses - */ - @Suppress("unused") - @Serializable - @SerialName("WeatherAlert") - sealed class WeatherAlert { - abstract val severity: Severity - abstract val message: String - - @Serializable - @SerialName("Severity") - enum class Severity { Low, Moderate, Severe, Extreme } - - @Serializable - @SerialName("StormAlert") - data class StormAlert( - override val severity: Severity, - override val message: String, - @property:LLMDescription("Wind speed in km/h") - val windSpeed: Double - ) : WeatherAlert() - - @Serializable - @SerialName("FloodAlert") - data class FloodAlert( - override val severity: Severity, - override val message: String, - @property:LLMDescription("Expected rainfall in mm") - val expectedRainfall: Double - ) : WeatherAlert() - - @Serializable - @SerialName("TemperatureAlert") - data class TemperatureAlert( - override val severity: Severity, - override val message: String, - @property:LLMDescription("Temperature threshold in Celsius") - val threshold: 
Int, // in Celsius - @property:LLMDescription("Whether the alert is a heat warning") - val isHeatWarning: Boolean - ) : WeatherAlert() - } -} - -data class FullWeatherForecastRequest( - val city: String, - val country: String -) - private val json = Json { prettyPrint = true } fun main(): Unit = runBlocking { - // Optional examples, to help LLM understand the format better in manual mode - val exampleForecasts = listOf( - FullWeatherForecast( - temperature = 18, - conditions = "Cloudy", - precipitation = 30, - latLon = FullWeatherForecast.LatLon(lat = 34.0522, lon = -118.2437), - pollution = FullWeatherForecast.Pollution.Medium, - alert = FullWeatherForecast.WeatherAlert.StormAlert( - severity = FullWeatherForecast.WeatherAlert.Severity.Moderate, - message = "Possible thunderstorms in the evening", - windSpeed = 45.5 - ), - news = listOf( - FullWeatherForecast.WeatherNews(title = "Local news", link = "https://example.com/news"), - FullWeatherForecast.WeatherNews(title = "Global news", link = "https://example.com/global-news") - ), -// sources = mapOf( -// "MeteorologicalWatch" to FullWeatherForecast.WeatherSource( -// stationName = "MeteorologicalWatch", -// stationAuthority = "US Department of Agriculture" -// ), -// "MeteorologicalWatch2" to FullWeatherForecast.WeatherSource( -// stationName = "MeteorologicalWatch2", -// stationAuthority = "US Department of Agriculture" -// ) -// ) - ), - FullWeatherForecast( - temperature = 10, - conditions = "Rainy", - precipitation = null, - latLon = FullWeatherForecast.LatLon(lat = 37.7739, lon = -122.4194), - pollution = FullWeatherForecast.Pollution.Low, - alert = FullWeatherForecast.WeatherAlert.FloodAlert( - severity = FullWeatherForecast.WeatherAlert.Severity.Severe, - message = "Heavy rainfall may cause local flooding", - expectedRainfall = 75.2 - ), - news = listOf( - FullWeatherForecast.WeatherNews(title = "Local news", link = "https://example.com/news"), - FullWeatherForecast.WeatherNews(title = "Global news", link = 
"https://example.com/global-news") - ), -// sources = mapOf( -// "MeteorologicalWatch" to WeatherForecast.WeatherSource( -// stationName = "MeteorologicalWatch", -// stationAuthority = "US Department of Agriculture" -// ), -// ) - ) - ) - /* This structure has a generic schema that is suitable for manual structured output mode. But to use native structured output support in different LLM providers you might need to use custom JSON schema generators @@ -230,7 +41,7 @@ fun main(): Unit = runBlocking { val genericWeatherStructure = JsonStructuredData.createJsonStructure( // Some models might not work well with json schema, so you may try simple, but it has more limitations (no polymorphism!) schemaGenerator = StandardJsonSchemaGenerator, - examples = exampleForecasts, + examples = FullWeatherForecast.exampleForecasts, ) println("Generated generic JSON schema:\n${json.encodeToString(genericWeatherStructure.schema.schema)}") @@ -241,12 +52,12 @@ fun main(): Unit = runBlocking { val openAiWeatherStructure = JsonStructuredData.createJsonStructure( schemaGenerator = OpenAIStandardJsonSchemaGenerator, - examples = exampleForecasts, + examples = FullWeatherForecast.exampleForecasts, ) val googleWeatherStructure = JsonStructuredData.createJsonStructure( schemaGenerator = GoogleStandardJsonSchemaGenerator, - examples = exampleForecasts, + examples = FullWeatherForecast.exampleForecasts, ) val agentStrategy = strategy("advanced-full-weather-forecast") { @@ -307,7 +118,7 @@ fun main(): Unit = runBlocking { agentConfig = agentConfig ) { handleEvents { - onAgentRunError { eventContext -> + onAgentExecutionFailed { eventContext -> println("An error occurred: ${eventContext.throwable.message}\n${eventContext.throwable.stackTraceToString()}") } } diff --git a/examples/simple-examples/src/main/kotlin/ai/koog/agents/example/structuredoutput/AdvancedWithStandardSchemaAndTools.kt 
b/examples/simple-examples/src/main/kotlin/ai/koog/agents/example/structuredoutput/AdvancedWithStandardSchemaAndTools.kt new file mode 100644 index 0000000000..df9d30c5c9 --- /dev/null +++ b/examples/simple-examples/src/main/kotlin/ai/koog/agents/example/structuredoutput/AdvancedWithStandardSchemaAndTools.kt @@ -0,0 +1,135 @@ +package ai.koog.agents.example.structuredoutput + +import ai.koog.agents.core.agent.AIAgent +import ai.koog.agents.core.agent.config.AIAgentConfig +import ai.koog.agents.core.tools.ToolRegistry +import ai.koog.agents.core.tools.reflect.asTools +import ai.koog.agents.example.ApiKeyService +import ai.koog.agents.example.structuredoutput.models.FullWeatherForecast +import ai.koog.agents.example.structuredoutput.models.FullWeatherForecastRequest +import ai.koog.agents.example.structuredoutput.tools.WeatherTools +import ai.koog.agents.ext.agent.structuredOutputWithToolsStrategy +import ai.koog.agents.features.eventHandler.feature.handleEvents +import ai.koog.prompt.dsl.prompt +import ai.koog.prompt.executor.clients.anthropic.AnthropicLLMClient +import ai.koog.prompt.executor.clients.anthropic.AnthropicModels +import ai.koog.prompt.executor.clients.google.GoogleLLMClient +import ai.koog.prompt.executor.clients.google.structure.GoogleStandardJsonSchemaGenerator +import ai.koog.prompt.executor.clients.openai.OpenAILLMClient +import ai.koog.prompt.executor.clients.openai.OpenAIModels +import ai.koog.prompt.executor.clients.openai.structure.OpenAIStandardJsonSchemaGenerator +import ai.koog.prompt.executor.llms.MultiLLMPromptExecutor +import ai.koog.prompt.llm.LLMProvider +import ai.koog.prompt.structure.StructureFixingParser +import ai.koog.prompt.structure.StructuredOutput +import ai.koog.prompt.structure.StructuredOutputConfig +import ai.koog.prompt.structure.json.JsonStructuredData +import ai.koog.prompt.structure.json.generator.StandardJsonSchemaGenerator +import ai.koog.prompt.text.text +import kotlinx.coroutines.runBlocking +import 
kotlinx.serialization.json.Json + +private val json = Json { + prettyPrint = true +} + +fun main(): Unit = runBlocking { + /* + This structure has a generic schema that is suitable for manual structured output mode. + But to use native structured output support in different LLM providers you might need to use custom JSON schema generators + that would produce the schema these providers expect. + */ + val genericWeatherStructure = JsonStructuredData.createJsonStructure( + // Some models might not work well with json schema, so you may try simple, but it has more limitations (no polymorphism!) + schemaGenerator = StandardJsonSchemaGenerator, + examples = FullWeatherForecast.exampleForecasts, + ) + + println("Generated generic JSON schema:\n${json.encodeToString(genericWeatherStructure.schema.schema)}") + /* + These are specific structure definitions with schemas in format that particular LLM providers understand in their native + structured output. + */ + + val openAiWeatherStructure = JsonStructuredData.createJsonStructure( + schemaGenerator = OpenAIStandardJsonSchemaGenerator, + examples = FullWeatherForecast.exampleForecasts, + ) + + val googleWeatherStructure = JsonStructuredData.createJsonStructure( + schemaGenerator = GoogleStandardJsonSchemaGenerator, + examples = FullWeatherForecast.exampleForecasts, + ) + + val config = StructuredOutputConfig( + byProvider = mapOf( + // Native modes leveraging native structured output support in models, with custom definitions for LLM providers that might have different format. + LLMProvider.OpenAI to StructuredOutput.Native(openAiWeatherStructure), + LLMProvider.Google to StructuredOutput.Native(googleWeatherStructure), + // Anthropic does not support native structured output yet. 
+ LLMProvider.Anthropic to StructuredOutput.Manual(genericWeatherStructure), + ), + + // Fallback manual structured output mode, via explicit prompting with additional message, not native model support + default = StructuredOutput.Manual(genericWeatherStructure), + + // Helper parser to attempt a fix if a malformed output is produced. + fixingParser = StructureFixingParser( + fixingModel = AnthropicModels.Haiku_3_5, + retries = 2, + ), + ) + + val agentStrategy = structuredOutputWithToolsStrategy( + config + ) { request -> + text { + +"Requesting forecast for" + +"City: ${request.city}" + +"Country: ${request.country}" + } + } + + val agentConfig = AIAgentConfig( + prompt = prompt("weather-forecast-with-tools") { + system( + """ + You are a weather forecasting assistant. + When asked for a weather forecast, use the weather tools to get the weather forecast for the specified city and country. + """.trimIndent() + ) + }, + model = OpenAIModels.Chat.GPT4_1, + maxAgentIterations = 10 + ) + + val agent = AIAgent( + promptExecutor = MultiLLMPromptExecutor( + LLMProvider.OpenAI to OpenAILLMClient(ApiKeyService.openAIApiKey), + LLMProvider.Anthropic to AnthropicLLMClient(ApiKeyService.anthropicApiKey), + LLMProvider.Google to GoogleLLMClient(ApiKeyService.googleApiKey), + ), + strategy = agentStrategy, + agentConfig = agentConfig, + toolRegistry = ToolRegistry { + tools(WeatherTools().asTools()) + } + ) { + handleEvents { + onAgentRunError { eventContext -> + println("An error occurred: ${eventContext.throwable.message}\n${eventContext.throwable.stackTraceToString()}") + } + } + } + + println( + """ + === Full Weather Forecast Example === + This example demonstrates how to use structured output with full schema support + to get properly structured output from the LLM. 
+ """.trimIndent() + ) + + val result: FullWeatherForecast = agent.run(FullWeatherForecastRequest(city = "New York", country = "USA")) + println("Agent run result: $result") +} diff --git a/examples/simple-examples/src/main/kotlin/ai/koog/agents/example/structuredoutput/SimpleExample.kt b/examples/simple-examples/src/main/kotlin/ai/koog/agents/example/structuredoutput/SimpleExample.kt index 716fdd629b..32aaf84ae2 100644 --- a/examples/simple-examples/src/main/kotlin/ai/koog/agents/example/structuredoutput/SimpleExample.kt +++ b/examples/simple-examples/src/main/kotlin/ai/koog/agents/example/structuredoutput/SimpleExample.kt @@ -225,7 +225,7 @@ fun main(): Unit = runBlocking { agentConfig = agentConfig ) { handleEvents { - onAgentRunError { ctx -> + onAgentExecutionFailed { ctx -> println("An error occurred: ${ctx.throwable.message}\n${ctx.throwable.stackTraceToString()}") } } diff --git a/examples/simple-examples/src/main/kotlin/ai/koog/agents/example/structuredoutput/models/FullWeatherForecast.kt b/examples/simple-examples/src/main/kotlin/ai/koog/agents/example/structuredoutput/models/FullWeatherForecast.kt new file mode 100644 index 0000000000..92f72761b8 --- /dev/null +++ b/examples/simple-examples/src/main/kotlin/ai/koog/agents/example/structuredoutput/models/FullWeatherForecast.kt @@ -0,0 +1,195 @@ +package ai.koog.agents.example.structuredoutput.models + +import ai.koog.agents.core.tools.annotations.LLMDescription +import kotlinx.serialization.SerialName +import kotlinx.serialization.Serializable + +/** + * This is a more advanced example showing how to configure various parameters of structured output manually, to fine-tune + * it for your needs when necessary. + * + * Structured output that uses "full" JSON schema. + * More advanced features are supported, e.g. polymorphism and recursive type references, and schemas can be more complex. 
+ */ + +@Serializable +@SerialName("FullWeatherForecast") +@LLMDescription("Weather forecast for a given location") +data class FullWeatherForecast( + @property:LLMDescription("Temperature in Celsius") + val temperature: Int, + // properties with default values + @property:LLMDescription("Weather conditions (e.g., sunny, cloudy, rainy)") + val conditions: String = "sunny", + // nullable properties + @property:LLMDescription("Chance of precipitation in percentage") + val precipitation: Int?, + // nested classes + @property:LLMDescription("Coordinates of the location") + val latLon: LatLon, + // enums + val pollution: Pollution, + // polymorphism + val alert: WeatherAlert, + // lists + @property:LLMDescription("List of news articles") + val news: List, +// // maps (string keys only, some providers don't support maps at all) +// @property:LLMDescription("Map of weather sources") +// val sources: Map +) { + companion object { + // Optional examples, to help LLM understand the format better in manual mode + val exampleForecasts = listOf( + FullWeatherForecast( + temperature = 18, + conditions = "Cloudy", + precipitation = 30, + latLon = FullWeatherForecast.LatLon(lat = 34.0522, lon = -118.2437), + pollution = FullWeatherForecast.Pollution.Medium, + alert = FullWeatherForecast.WeatherAlert.StormAlert( + severity = FullWeatherForecast.WeatherAlert.Severity.Moderate, + message = "Possible thunderstorms in the evening", + windSpeed = 45.5 + ), + news = listOf( + FullWeatherForecast.WeatherNews(title = "Local news", link = "https://example.com/news"), + FullWeatherForecast.WeatherNews(title = "Global news", link = "https://example.com/global-news") + ), +// sources = mapOf( +// "MeteorologicalWatch" to FullWeatherForecast.WeatherSource( +// stationName = "MeteorologicalWatch", +// stationAuthority = "US Department of Agriculture" +// ), +// "MeteorologicalWatch2" to FullWeatherForecast.WeatherSource( +// stationName = "MeteorologicalWatch2", +// stationAuthority = "US 
Department of Agriculture" +// ) +// ) + ), + FullWeatherForecast( + temperature = 10, + conditions = "Rainy", + precipitation = null, + latLon = FullWeatherForecast.LatLon(lat = 37.7739, lon = -122.4194), + pollution = FullWeatherForecast.Pollution.Low, + alert = FullWeatherForecast.WeatherAlert.FloodAlert( + severity = FullWeatherForecast.WeatherAlert.Severity.Severe, + message = "Heavy rainfall may cause local flooding", + expectedRainfall = 75.2 + ), + news = listOf( + FullWeatherForecast.WeatherNews(title = "Local news", link = "https://example.com/news"), + FullWeatherForecast.WeatherNews(title = "Global news", link = "https://example.com/global-news") + ), +// sources = mapOf( +// "MeteorologicalWatch" to WeatherForecast.WeatherSource( +// stationName = "MeteorologicalWatch", +// stationAuthority = "US Department of Agriculture" +// ), +// ) + ) + ) + } + + // Nested classes + @Serializable + @SerialName("LatLon") + data class LatLon( + @property:LLMDescription("Latitude of the location") + val lat: Double, + @property:LLMDescription("Longitude of the location") + val lon: Double + ) + + // Nested classes in lists... + @Serializable + @SerialName("WeatherNews") + data class WeatherNews( + @property:LLMDescription("Title of the news article") + val title: String, + @property:LLMDescription("Link to the news article") + val link: String + ) + + // ... and maps (but only with string keys!) + @Suppress("unused") + @Serializable + @SerialName("WeatherSource") + data class WeatherSource( + @property:LLMDescription("Name of the weather station") + val stationName: String, + @property:LLMDescription("Authority of the weather station") + val stationAuthority: String + ) + + // Enums + @Suppress("unused") + @SerialName("Pollution") + @Serializable + enum class Pollution { + @SerialName("None") + None, + + @SerialName("LOW") + Low, + + @SerialName("MEDIUM") + Medium, + + @SerialName("HIGH") + High + } + + /* + Polymorphism: + 1. Closed with sealed classes, + 2. 
Open: non-sealed classes with subclasses registered in json config + https://github.com/Kotlin/kotlinx.serialization/blob/master/docs/polymorphism.md#registered-subclasses + */ + @Suppress("unused") + @Serializable + @SerialName("WeatherAlert") + sealed class WeatherAlert { + abstract val severity: Severity + abstract val message: String + + @Serializable + @SerialName("Severity") + enum class Severity { Low, Moderate, Severe, Extreme } + + @Serializable + @SerialName("StormAlert") + data class StormAlert( + override val severity: Severity, + override val message: String, + @property:LLMDescription("Wind speed in km/h") + val windSpeed: Double + ) : WeatherAlert() + + @Serializable + @SerialName("FloodAlert") + data class FloodAlert( + override val severity: Severity, + override val message: String, + @property:LLMDescription("Expected rainfall in mm") + val expectedRainfall: Double + ) : WeatherAlert() + + @Serializable + @SerialName("TemperatureAlert") + data class TemperatureAlert( + override val severity: Severity, + override val message: String, + @property:LLMDescription("Temperature threshold in Celsius") + val threshold: Int, // in Celsius + @property:LLMDescription("Whether the alert is a heat warning") + val isHeatWarning: Boolean + ) : WeatherAlert() + } +} + +data class FullWeatherForecastRequest( + val city: String, + val country: String +) diff --git a/examples/simple-examples/src/main/kotlin/ai/koog/agents/example/structuredoutput/tools/WeatherTools.kt b/examples/simple-examples/src/main/kotlin/ai/koog/agents/example/structuredoutput/tools/WeatherTools.kt new file mode 100644 index 0000000000..24603e83c6 --- /dev/null +++ b/examples/simple-examples/src/main/kotlin/ai/koog/agents/example/structuredoutput/tools/WeatherTools.kt @@ -0,0 +1,26 @@ +package ai.koog.agents.example.structuredoutput.tools + +import ai.koog.agents.core.tools.annotations.LLMDescription +import ai.koog.agents.core.tools.annotations.Tool +import 
ai.koog.agents.core.tools.reflect.ToolSet +import ai.koog.prompt.text.text + +class WeatherTools : ToolSet { + @Tool + @LLMDescription("Get the weather forecast for the specified city and country") + fun getWeatherForecast( + @LLMDescription("The city for which to get the weather forecast") + city: String, + ) = text { + +"The weather forecast for " + +city + +" is " + +"Cloudy" + +"temperature = 18" + +"precipitation = 30" + +"lat = 34.0522, lon = -118.2437" + +"pollution = medium" + +"alert = Moderate. Possible thunderstorms in the evening, windSpeed = 45.5" + +"news = Local news: https://example.com/news, Global news: https://example.com/global-news" + } +} diff --git a/examples/simple-examples/src/main/kotlin/ai/koog/agents/example/tone/ToneAgent.kt b/examples/simple-examples/src/main/kotlin/ai/koog/agents/example/tone/ToneAgent.kt index 902cfbc474..d5fa626d57 100644 --- a/examples/simple-examples/src/main/kotlin/ai/koog/agents/example/tone/ToneAgent.kt +++ b/examples/simple-examples/src/main/kotlin/ai/koog/agents/example/tone/ToneAgent.kt @@ -62,17 +62,17 @@ fun main() { toolRegistry = toolRegistry ) { handleEvents { - onToolCall { eventContext -> + onToolCallStarting { eventContext -> println("Tool called: tool ${eventContext.tool.name}, args ${eventContext.toolArgs}") } - onAgentRunError { eventContext -> + onAgentExecutionFailed { eventContext -> println( "An error occurred: ${eventContext.throwable.message}\n${eventContext.throwable.stackTraceToString()}" ) } - onAgentFinished { eventContext -> + onAgentCompleted { eventContext -> println("Result: ${eventContext.result}") } } diff --git a/examples/simple-examples/src/main/kotlin/ai/koog/agents/example/websearch/WebSearchAgent.kt b/examples/simple-examples/src/main/kotlin/ai/koog/agents/example/websearch/WebSearchAgent.kt index 95ee2f49f1..b405736b2e 100644 --- a/examples/simple-examples/src/main/kotlin/ai/koog/agents/example/websearch/WebSearchAgent.kt +++ 
b/examples/simple-examples/src/main/kotlin/ai/koog/agents/example/websearch/WebSearchAgent.kt @@ -198,7 +198,7 @@ suspend fun main() { agentConfig = agentConfig, ) { handleEvents { - onToolCall { ctx -> + onToolCallStarting { ctx -> println("Tool called: tool ${ctx.tool.name}, args ${ctx.toolArgs}") } } diff --git a/examples/simple-examples/src/test/kotlin/ai/koog/agents/example/tone/ToneAgentTest.kt b/examples/simple-examples/src/test/kotlin/ai/koog/agents/example/tone/ToneAgentTest.kt index 11fc2380d1..941cb0faea 100644 --- a/examples/simple-examples/src/test/kotlin/ai/koog/agents/example/tone/ToneAgentTest.kt +++ b/examples/simple-examples/src/test/kotlin/ai/koog/agents/example/tone/ToneAgentTest.kt @@ -47,18 +47,18 @@ class ToneAgentTest { // Create an event handler val eventHandlerConfig: EventHandlerConfig.() -> Unit = { - onToolCall { eventContext -> + onToolCallStarting { eventContext -> println("[DEBUG_LOG] Tool called: tool ${eventContext.tool.name}, args ${eventContext.toolArgs}") toolCalls.add(eventContext.tool.name) } - onAgentRunError { eventContext -> + onAgentExecutionFailed { eventContext -> println( "[DEBUG_LOG] An error occurred: ${eventContext.throwable.message}\n${eventContext.throwable.stackTraceToString()}" ) } - onAgentFinished { eventContext -> + onAgentCompleted { eventContext -> println("[DEBUG_LOG] Result: ${eventContext.result}") result = eventContext.result } diff --git a/gradle.properties b/gradle.properties index 26e7cf6c7a..206d61e502 100644 --- a/gradle.properties +++ b/gradle.properties @@ -1,7 +1,6 @@ #Kotlin kotlin.code.style=official kotlin.daemon.jvmargs=-Xmx4096M -kotlin.native.ignoreDisabledTargets=true # Build JS targets using npm package manager https://kotlinlang.org/docs/js-project-setup.html#npm-dependencies kotlin.js.yarn=false diff --git a/gradle/libs.versions.toml b/gradle/libs.versions.toml index 920b3717a1..ca669a3f40 100644 --- a/gradle/libs.versions.toml +++ b/gradle/libs.versions.toml @@ -35,6 +35,7 @@ 
spring-boot = "3.5.3" spring-management = "1.1.7" sqlite = "3.46.1.3" testcontainers = "1.19.7" +mokksy = "0.5.0-Alpha3" [libraries] jetbrains-annotations = { module = "org.jetbrains:annotations", version.ref = "annotations" } @@ -69,16 +70,22 @@ ktor-server-core = { module = "io.ktor:ktor-server-core", version.ref = "ktor3" ktor-server-cio = { module = "io.ktor:ktor-server-cio", version.ref = "ktor3" } ktor-server-netty = { module = "io.ktor:ktor-server-netty-jvm", version.ref = "ktor3" } ktor-server-sse = { module = "io.ktor:ktor-server-sse", version.ref = "ktor3" } +ktor-server-content-negotiation = { module = "io.ktor:ktor-server-content-negotiation", version.ref = "ktor3" } +ktor-server-cors = { module = "io.ktor:ktor-server-cors", version.ref = "ktor3" } +ktor-server-test-host = { module = "io.ktor:ktor-server-test-host", version.ref = "ktor3" } lettuce-core = { module = "io.lettuce:lettuce-core", version.ref = "lettuce" } logback-classic = { module = "ch.qos.logback:logback-classic", version.ref = "logback" } oshai-kotlin-logging = { module = "io.github.oshai:kotlin-logging", version.ref = "oshai-logging" } mockk = { module = "io.mockk:mockk", version.ref = "mockk" } +mokksy-a2a = { module = "me.kpavlov.aimocks:ai-mocks-a2a", version.ref = "mokksy" } dokka-gradle-plugin = { module = "org.jetbrains.dokka:dokka-gradle-plugin", version.ref = "dokka" } mcp-client = { module = "io.modelcontextprotocol:kotlin-sdk-client", version.ref = "mcp" } mcp-server = { module = "io.modelcontextprotocol:kotlin-sdk-server", version.ref = "mcp" } slf4j-simple = { module = "org.slf4j:slf4j-simple", version.ref = "slf4j" } jetsign-gradle-plugin = { module = "com.jetbrains:jet-sign", version.ref = "jetsign" } +kotest-assertions = { module = "io.kotest:kotest-assertions-core", version.ref = "kotest" } testcontainers = { module = "org.testcontainers:testcontainers", version.ref = "testcontainers" } +testcontainers-junit = { module = "org.testcontainers:junit-jupiter", version.ref = 
"testcontainers" } testcontainers-postgresql = { module = "org.testcontainers:postgresql", version.ref = "testcontainers" } testcontainers-mysql = { module = "org.testcontainers:mysql", version.ref = "testcontainers" } exposed-core = { module = "org.jetbrains.exposed:exposed-core", version.ref = "exposed" } diff --git a/integration-tests/src/jvmMain/kotlin/ai/koog/integration/tests/OllamaTestFixture.kt b/integration-tests/src/jvmMain/kotlin/ai/koog/integration/tests/OllamaTestFixture.kt index e818554fb4..44bcdac08e 100644 --- a/integration-tests/src/jvmMain/kotlin/ai/koog/integration/tests/OllamaTestFixture.kt +++ b/integration-tests/src/jvmMain/kotlin/ai/koog/integration/tests/OllamaTestFixture.kt @@ -11,6 +11,7 @@ import kotlinx.coroutines.delay import kotlinx.coroutines.runBlocking import org.testcontainers.containers.GenericContainer import org.testcontainers.images.PullPolicy +import org.testcontainers.utility.DockerImageName class OllamaTestFixture { private val PORT = 11434 @@ -24,31 +25,77 @@ class OllamaTestFixture { val moderationModel = OllamaModels.Meta.LLAMA_GUARD_3 fun setUp() { - ollamaContainer = GenericContainer(System.getenv("OLLAMA_IMAGE_URL")).apply { + val imageUrl = System.getenv("OLLAMA_IMAGE_URL") + ?: throw IllegalStateException("OLLAMA_IMAGE_URL not set") + + ollamaContainer = GenericContainer(DockerImageName.parse(imageUrl)).apply { withExposedPorts(PORT) withImagePullPolicy(PullPolicy.alwaysPull()) + withCreateContainerCmdModifier { cmd -> + cmd.hostConfig?.apply { + withMemory(4L * 1024 * 1024 * 1024) // 4GB RAM + withCpuCount(2L) + } + } + withReuse(false) } - ollamaContainer.start() - val host = ollamaContainer.host - val port = ollamaContainer.getMappedPort(PORT) - val baseUrl = "http://$host:$port" - waitForOllamaServer(baseUrl) + try { + ollamaContainer.start() - client = OllamaClient(baseUrl) + val host = ollamaContainer.host + val port = ollamaContainer.getMappedPort(PORT) + val baseUrl = "http://$host:$port" + 
waitForOllamaServer(baseUrl) - // Always pull the models to ensure they're available - runBlocking { - client.getModelOrNull(model.id, pullIfMissing = true) - client.getModelOrNull(visionModel.id, pullIfMissing = true) - client.getModelOrNull(moderationModel.id, pullIfMissing = true) - } + client = OllamaClient(baseUrl) - executor = SingleLLMPromptExecutor(client) + // Always pull the models to ensure they're available + runBlocking { + try { + client.getModelOrNull(model.id, pullIfMissing = true) + client.getModelOrNull(visionModel.id, pullIfMissing = true) + client.getModelOrNull(moderationModel.id, pullIfMissing = true) + } catch (e: Exception) { + println("Failed to pull models: ${e.message}") + cleanup() + throw e + } + } + + executor = SingleLLMPromptExecutor(client) + } catch (e: Exception) { + cleanup() + throw e + } } fun tearDown() { - ollamaContainer.stop() + cleanup() + } + + private fun cleanup() { + try { + if (::ollamaContainer.isInitialized) { + try { + ollamaContainer.stop() + } catch (e: Exception) { + println("Error stopping container: ${e.message}") + } + + try { + ollamaContainer.dockerClient?.removeContainerCmd(ollamaContainer.containerId) + ?.withRemoveVolumes(true) + ?.withForce(true) + ?.exec() + } catch (e: Exception) { + println("Error removing container: ${e.message}") + } + } + } catch (e: Exception) { + println("Error during cleanup: ${e.message}") + e.printStackTrace() + } } private fun waitForOllamaServer(baseUrl: String) { diff --git a/integration-tests/src/jvmMain/kotlin/ai/koog/integration/tests/OllamaTestFixtureExtension.kt b/integration-tests/src/jvmMain/kotlin/ai/koog/integration/tests/OllamaTestFixtureExtension.kt index 8b55396b4a..509cf6ef51 100644 --- a/integration-tests/src/jvmMain/kotlin/ai/koog/integration/tests/OllamaTestFixtureExtension.kt +++ b/integration-tests/src/jvmMain/kotlin/ai/koog/integration/tests/OllamaTestFixtureExtension.kt @@ -6,41 +6,87 @@ import org.junit.jupiter.api.extension.ExtensionContext import 
org.junit.platform.commons.support.AnnotationSupport.findAnnotatedFields import org.junit.platform.commons.support.ModifierSupport import java.lang.reflect.Field +import java.util.concurrent.ConcurrentHashMap @Target(AnnotationTarget.FIELD) @Retention(AnnotationRetention.RUNTIME) annotation class InjectOllamaTestFixture class OllamaTestFixtureExtension : BeforeAllCallback, AfterAllCallback { + + companion object { + private val FIXTURES = ConcurrentHashMap>() + } + override fun beforeAll(context: ExtensionContext) { val testClass = context.requiredTestClass - setupFields(testClass) + val testId = context.uniqueId + val fixtures = mutableListOf() + + try { + findFields(testClass).forEach { field -> + field.isAccessible = true + val fixture = OllamaTestFixture() + + try { + fixture.setUp() + field.set(null, fixture) + fixtures.add(fixture) + } catch (e: Exception) { + println("Failed to setup fixture for field ${field.name}: ${e.message}") + fixtures.forEach { it.tearDown() } + throw e + } + } + + FIXTURES[testId] = fixtures + } catch (e: Exception) { + println("Error in beforeAll: ${e.message}") + throw e + } } override fun afterAll(context: ExtensionContext) { + val testId = context.uniqueId + val fixtures = FIXTURES.remove(testId) ?: emptyList() + val testClass = context.requiredTestClass - tearDownFields(testClass) + val errors = mutableListOf() + + fixtures.forEach { fixture -> + try { + fixture.tearDown() + } catch (e: Exception) { + println("Failed to teardown fixture: ${e.message}") + e.printStackTrace() + errors.add(e) + } + } + + try { + findFields(testClass).forEach { field -> + field.isAccessible = true + try { + field.set(null, null) + } catch (e: Exception) { + println("Failed to nullify field ${field.name}: ${e.message}") + } + } + } catch (e: Exception) { + println("Error nullifying fields: ${e.message}") + } + + if (errors.isNotEmpty()) { + throw errors.first() + } } private fun findFields(testClass: Class<*>): List { return findAnnotatedFields( 
testClass, InjectOllamaTestFixture::class.java, - ) { field -> ModifierSupport.isStatic(field) && field.type == OllamaTestFixture::class.java } - } - - private fun setupFields(testClass: Class<*>) { - findFields(testClass).forEach { field -> - field.isAccessible = true - field.set(null, OllamaTestFixture().apply { setUp() }) - } - } - - private fun tearDownFields(testClass: Class<*>) { - findFields(testClass).forEach { field -> - field.isAccessible = true - (field.get(null) as OllamaTestFixture).tearDown() - field.set(null, null) + ) { field -> + ModifierSupport.isStatic(field) && field.type == OllamaTestFixture::class.java } } } diff --git a/integration-tests/src/jvmTest/kotlin/ai/koog/integration/tests/AnthropicSchemaValidationIntegrationTest.kt b/integration-tests/src/jvmTest/kotlin/ai/koog/integration/tests/AnthropicSchemaValidationIntegrationTest.kt index 5fff50c953..1b0b333431 100644 --- a/integration-tests/src/jvmTest/kotlin/ai/koog/integration/tests/AnthropicSchemaValidationIntegrationTest.kt +++ b/integration-tests/src/jvmTest/kotlin/ai/koog/integration/tests/AnthropicSchemaValidationIntegrationTest.kt @@ -172,7 +172,7 @@ class AnthropicSchemaValidationIntegrationTest { ) println(eventContext.throwable.stackTraceToString()) } - onToolExecutionStarting { eventContext -> + onToolCallStarting { eventContext -> println("Calling tool: ${eventContext.tool.name}") println("Arguments: ${eventContext.toolArgs.toString().take(100)}...") } diff --git a/integration-tests/src/jvmTest/kotlin/ai/koog/integration/tests/agent/AIAgentIntegrationTest.kt b/integration-tests/src/jvmTest/kotlin/ai/koog/integration/tests/agent/AIAgentIntegrationTest.kt index 23f1b18323..571197f632 100644 --- a/integration-tests/src/jvmTest/kotlin/ai/koog/integration/tests/agent/AIAgentIntegrationTest.kt +++ b/integration-tests/src/jvmTest/kotlin/ai/koog/integration/tests/agent/AIAgentIntegrationTest.kt @@ -17,10 +17,10 @@ import ai.koog.agents.core.tools.annotations.LLMDescription import 
ai.koog.agents.ext.agent.reActStrategy import ai.koog.agents.features.eventHandler.feature.EventHandler import ai.koog.agents.features.eventHandler.feature.EventHandlerConfig -import ai.koog.agents.snapshot.feature.Persistency -import ai.koog.agents.snapshot.feature.withPersistency -import ai.koog.agents.snapshot.providers.InMemoryPersistencyStorageProvider -import ai.koog.agents.snapshot.providers.file.JVMFilePersistencyStorageProvider +import ai.koog.agents.snapshot.feature.Persistence +import ai.koog.agents.snapshot.feature.withPersistence +import ai.koog.agents.snapshot.providers.InMemoryPersistenceStorageProvider +import ai.koog.agents.snapshot.providers.file.JVMFilePersistenceStorageProvider import ai.koog.integration.tests.utils.Models import ai.koog.integration.tests.utils.RetryUtils.withRetry import ai.koog.integration.tests.utils.TestUtils.CalculatorTool @@ -147,7 +147,7 @@ class AIAgentIntegrationTest { "FromTimestamp" ), // ToDo uncomment when KG-311 is fully fixed - // Arguments.of(HistoryCompressionStrategy.Chunked(2), "Chunked(2)") + Arguments.of(HistoryCompressionStrategy.Chunked(2), "Chunked(2)") ) } @@ -317,7 +317,7 @@ class AIAgentIntegrationTest { } } - onToolExecutionStarting { eventContext -> + onToolCallStarting { eventContext -> actualToolCalls.add(eventContext.tool.name) toolExecutionCounter.add(eventContext.tool.name) } @@ -702,7 +702,7 @@ class AIAgentIntegrationTest { // Count how many times the reasoning step would trigger based on the interval var expectedReasoningCalls = 1 // Start with 1 for the initial reasoning - for (i in 0 until toolExecutionCounter.size) { + for (i in toolExecutionCounter.indices) { if (i % interval == 0) { expectedReasoningCalls++ } @@ -720,7 +720,7 @@ class AIAgentIntegrationTest { @ParameterizedTest @MethodSource("openAIModels", "anthropicModels", "googleModels", "bedrockModels") fun integration_AgentCreateAndRestoreTest(model: LLModel) = runTest(timeout = 180.seconds) { - val checkpointStorageProvider = 
InMemoryPersistencyStorageProvider("integration_AgentCreateAndRestoreTest") + val checkpointStorageProvider = InMemoryPersistenceStorageProvider() val sayHello = "Hello World!" val hello = "Hello" val savedMessage = "Saved the state – the agent is ready to work!" @@ -735,7 +735,7 @@ class AIAgentIntegrationTest { val nodeSave by node(save) { input -> // Create a checkpoint - withPersistency { agentContext -> + withPersistence { agentContext -> createCheckpoint( agentContext = agentContext, nodeId = save, @@ -768,7 +768,7 @@ class AIAgentIntegrationTest { ), toolRegistry = ToolRegistry {}, installFeatures = { - install(Persistency) { + install(Persistence) { storage = checkpointStorageProvider } } @@ -776,7 +776,7 @@ class AIAgentIntegrationTest { agent.run("Start the test") - val checkpoints = checkpointStorageProvider.getCheckpoints() + val checkpoints = checkpointStorageProvider.getCheckpoints(agent.id) assertTrue(checkpoints.isNotEmpty(), "No checkpoints were created") assertEquals(save, checkpoints.first().nodeId, "Checkpoint has incorrect node ID") @@ -793,7 +793,7 @@ class AIAgentIntegrationTest { toolRegistry = ToolRegistry {}, id = agent.id, // Use the same ID to access the checkpoints installFeatures = { - install(Persistency) { + install(Persistence) { storage = checkpointStorageProvider } } @@ -808,7 +808,7 @@ class AIAgentIntegrationTest { @ParameterizedTest @MethodSource("openAIModels", "anthropicModels", "googleModels", "bedrockModels") fun integration_AgentCheckpointRollbackTest(model: LLModel) = runTest(timeout = 180.seconds) { - val checkpointStorageProvider = InMemoryPersistencyStorageProvider("integration_AgentCheckpointRollbackTest") + val checkpointStorageProvider = InMemoryPersistenceStorageProvider() val hello = "Hello" val save = "Save" @@ -840,7 +840,7 @@ class AIAgentIntegrationTest { } val nodeSave by node(save) { input -> - withPersistency { agentContext -> + withPersistence { agentContext -> createCheckpoint( agentContext = 
agentContext, nodeId = save, @@ -863,7 +863,7 @@ class AIAgentIntegrationTest { if (!hasRolledBack) { hasRolledBack = true executionLog.append(rollbackPerformingLog) - withPersistency { agentContext -> + withPersistence { agentContext -> rollbackToLatestCheckpoint(agentContext) } rolledBackMessage @@ -892,7 +892,7 @@ class AIAgentIntegrationTest { ), toolRegistry = ToolRegistry {}, installFeatures = { - install(Persistency) { + install(Persistence) { storage = checkpointStorageProvider } } @@ -924,7 +924,7 @@ class AIAgentIntegrationTest { @MethodSource("openAIModels", "anthropicModels", "googleModels", "bedrockModels") fun integration_AgentCheckpointContinuousPersistenceTest(model: LLModel) = runTest(timeout = 180.seconds) { val checkpointStorageProvider = - InMemoryPersistencyStorageProvider("integration_AgentCheckpointContinuousPersistenceTest") + InMemoryPersistenceStorageProvider() val strategyName = "continuous-persistence-strategy" @@ -976,16 +976,16 @@ class AIAgentIntegrationTest { ), toolRegistry = ToolRegistry {}, installFeatures = { - install(Persistency) { + install(Persistence) { storage = checkpointStorageProvider - enableAutomaticPersistency = true // Enable continuous persistence + enableAutomaticPersistence = true // Enable continuous persistence } } ) agent.run(testInput) - val checkpoints = checkpointStorageProvider.getCheckpoints() + val checkpoints = checkpointStorageProvider.getCheckpoints(agent.id) assertTrue(checkpoints.size >= 3, notEnoughCheckpointsError) val nodeIds = checkpoints.map { it.nodeId }.toSet() @@ -1012,8 +1012,7 @@ class AIAgentIntegrationTest { val noCheckpointsError = "No checkpoints were created" val incorrectNodeIdError = "Checkpoint has incorrect node ID" - val fileStorageProvider = - JVMFilePersistencyStorageProvider(tempDir, "integration_AgentCheckpointStorageProvidersTest") + val fileStorageProvider = JVMFilePersistenceStorageProvider(tempDir) val simpleStrategy = strategy(strategyName) { val nodeHello by node(hello) 
{ @@ -1021,7 +1020,7 @@ class AIAgentIntegrationTest { } val nodeBye by node(bye) { input -> - withPersistency { agentContext -> + withPersistence { agentContext -> createCheckpoint( agentContext = agentContext, nodeId = bye, @@ -1049,7 +1048,7 @@ class AIAgentIntegrationTest { ), toolRegistry = ToolRegistry {}, installFeatures = { - install(Persistency) { + install(Persistence) { storage = fileStorageProvider } } @@ -1057,7 +1056,7 @@ class AIAgentIntegrationTest { agent.run(testInput) - val checkpoints = fileStorageProvider.getCheckpoints().filter { it.nodeId != "tombstone" } + val checkpoints = fileStorageProvider.getCheckpoints(agent.id).filter { it.nodeId != "tombstone" } assertTrue(checkpoints.isNotEmpty(), noCheckpointsError) assertEquals(bye, checkpoints.first().nodeId, incorrectNodeIdError) } @@ -1252,9 +1251,8 @@ class AIAgentIntegrationTest { val model = OpenAIModels.CostOptimized.GPT4_1Mini val systemMessage = "You are a helpful assistant. Remember: the user is a human, whatever they say. Remind them of it by every chance." - var promptMessages: List? 
= null - val historyCompressionStrategy = strategy("history-compression-test") { + val historyCompressionStrategy = strategy>>("history-compression-test") { val callLLM by nodeLLMRequest(allowToolCalls = false) val nodeCompressHistory by nodeLLMCompressHistory( "compress_history", @@ -1263,10 +1261,10 @@ class AIAgentIntegrationTest { edge(nodeStart forwardTo callLLM) edge(callLLM forwardTo nodeCompressHistory onAssistantMessage { true }) - edge(nodeCompressHistory forwardTo nodeFinish) + edge(nodeCompressHistory forwardTo nodeFinish transformed { it to llm.prompt.messages }) } - val agent = AIAgent( + val agent = AIAgent>>( promptExecutor = getExecutor(model), strategy = historyCompressionStrategy, agentConfig = AIAgentConfig( @@ -1286,15 +1284,11 @@ class AIAgentIntegrationTest { onAgentExecutionFailed { eventContext -> errors.add(eventContext.throwable) } - - onLLMCallStarting { eventContext -> - promptMessages = eventContext.prompt.messages - } } } withRetry { - val result = agent.run("So, who am I?") + val (result, promptMessages) = agent.run("So, who am I?") assertTrue( errors.isEmpty(), diff --git a/integration-tests/src/jvmTest/kotlin/ai/koog/integration/tests/agent/AIAgentMultipleLLMIntegrationTest.kt b/integration-tests/src/jvmTest/kotlin/ai/koog/integration/tests/agent/AIAgentMultipleLLMIntegrationTest.kt index cd20c31d0a..8da846519c 100644 --- a/integration-tests/src/jvmTest/kotlin/ai/koog/integration/tests/agent/AIAgentMultipleLLMIntegrationTest.kt +++ b/integration-tests/src/jvmTest/kotlin/ai/koog/integration/tests/agent/AIAgentMultipleLLMIntegrationTest.kt @@ -387,7 +387,7 @@ class AIAgentMultipleLLMIntegrationTest { system( "You are a helpful assistant. You need to solve my task. " + "CALL TOOLS!!! DO NOT SEND MESSAGES!!!!! ONLY SEND THE FINAL MESSAGE " + - "WHEN YOU ARE FINISHED AND EVERYTING IS DONE AFTER CALLING THE TOOLS!" + "WHEN YOU ARE FINISHED AND EVERYTHING IS DONE AFTER CALLING THE TOOLS!" 
) } } @@ -419,7 +419,7 @@ class AIAgentMultipleLLMIntegrationTest { Please analyze the whole produced solution, and check that it is valid. Write concise verification result. CALL TOOLS!!! DO NOT SEND MESSAGES!!!!! - ONLY SEND THE FINAL MESSAGE WHEN YOU ARE FINISHED AND EVERYTING IS DONE + ONLY SEND THE FINAL MESSAGE WHEN YOU ARE FINISHED AND EVERYTHING IS DONE AFTER CALLING THE TOOLS! """.trimIndent() ) @@ -615,7 +615,7 @@ class AIAgentMultipleLLMIntegrationTest { val fs = MockFileSystem() val calledTools = mutableListOf() val eventHandlerConfig: EventHandlerConfig.() -> Unit = { - onToolExecutionStarting { eventContext -> + onToolCallStarting { eventContext -> calledTools.add(eventContext.tool.name) } } @@ -702,7 +702,7 @@ class AIAgentMultipleLLMIntegrationTest { "error: ${eventContext.throwable.javaClass.simpleName}(${eventContext.throwable.message})\n${eventContext.throwable.stackTraceToString()}" ) } - onToolExecutionStarting { eventContext -> + onToolCallStarting { eventContext -> println( "Calling tool ${eventContext.tool.name} with arguments ${ eventContext.toolArgs.toString().lines().first().take(100) @@ -726,7 +726,7 @@ class AIAgentMultipleLLMIntegrationTest { Models.assumeAvailable(model.provider) val fs = MockFileSystem() val eventHandlerConfig: EventHandlerConfig.() -> Unit = { - onToolExecutionStarting { eventContext -> + onToolCallStarting { eventContext -> println( "Calling tool ${eventContext.tool.name} with arguments ${ eventContext.toolArgs.toString().lines().first().take(100) @@ -785,7 +785,7 @@ class AIAgentMultipleLLMIntegrationTest { val fs = MockFileSystem() val eventHandlerConfig: EventHandlerConfig.() -> Unit = { - onToolExecutionStarting { eventContext -> + onToolCallStarting { eventContext -> println( "Calling tool ${eventContext.tool.name} with arguments ${ eventContext.toolArgs.toString().lines().first().take(100) diff --git a/integration-tests/src/jvmTest/kotlin/ai/koog/integration/tests/agent/OllamaAgentIntegrationTest.kt 
b/integration-tests/src/jvmTest/kotlin/ai/koog/integration/tests/agent/OllamaAgentIntegrationTest.kt index cc17a43282..ad118b9368 100644 --- a/integration-tests/src/jvmTest/kotlin/ai/koog/integration/tests/agent/OllamaAgentIntegrationTest.kt +++ b/integration-tests/src/jvmTest/kotlin/ai/koog/integration/tests/agent/OllamaAgentIntegrationTest.kt @@ -155,30 +155,15 @@ class OllamaAgentIntegrationTest { toolRegistry = toolRegistry ) { install(EventHandler) { - onToolExecutionStarting { eventContext -> - println( - "Calling tool ${eventContext.tool.name} with arguments ${ - eventContext.toolArgs.toString().lines().first().take(100) - }" - ) - } - onLLMCallStarting { eventContext -> val promptText = eventContext.prompt.messages.joinToString { "${it.role.name}: ${it.content}" } - val toolsText = eventContext.tools.joinToString { it.name } - println("Prompt with tools:\n$promptText\nAvailable tools:\n$toolsText") promptsAndResponses.add("PROMPT_WITH_TOOLS: $promptText") } onLLMCallCompleted { eventContext -> val responseText = "[${eventContext.responses.joinToString { "${it.role.name}: ${it.content}" }}]" - println("LLM Call response: $responseText") promptsAndResponses.add("RESPONSE: $responseText") } - - onAgentCompleted { _ -> - println("Agent execution finished") - } } } } diff --git a/integration-tests/src/jvmTest/kotlin/ai/koog/integration/tests/agent/OllamaSimpleAgentIntegrationTest.kt b/integration-tests/src/jvmTest/kotlin/ai/koog/integration/tests/agent/OllamaSimpleAgentIntegrationTest.kt index bd3f093100..fb99b66c80 100644 --- a/integration-tests/src/jvmTest/kotlin/ai/koog/integration/tests/agent/OllamaSimpleAgentIntegrationTest.kt +++ b/integration-tests/src/jvmTest/kotlin/ai/koog/integration/tests/agent/OllamaSimpleAgentIntegrationTest.kt @@ -28,69 +28,9 @@ class OllamaSimpleAgentIntegrationTest { } val eventHandlerConfig: EventHandlerConfig.() -> Unit = { - onAgentStarting { eventContext -> - println( - "Agent started: 
agentId=${eventContext.agent.javaClass.simpleName}" - ) - } - - onAgentCompleted { eventContext -> - println("Agent finished: agentId=${eventContext.agentId}, result=${eventContext.result}") - } - - onAgentExecutionFailed { eventContext -> - println("Agent error: agentId=${eventContext.agentId}, error=${eventContext.throwable.message}") - } - - onStrategyStarting { eventContext -> - println("Strategy started: ${eventContext.strategy.name}") - } - - onStrategyCompleted { eventContext -> - println("Strategy finished: strategy=${eventContext.strategy.name}, result=${eventContext.result}") - } - - onNodeExecutionStarting { eventContext -> - println("Before node: node=${eventContext.node.javaClass.simpleName}, input=${eventContext.input}") - } - - onNodeExecutionCompleted { eventContext -> - println( - "After node: node=${eventContext.node.javaClass.simpleName}, input=${eventContext.input}, output=${eventContext.output}" - ) - } - - onLLMCallStarting { eventContext -> - println("Before LLM call: prompt=${eventContext.prompt}") - } - - onLLMCallCompleted { eventContext -> - val lastResponse = eventContext.responses.last().content - println("After LLM call: response=${lastResponse.take(100)}${if (lastResponse.length > 100) "..." 
else ""}") - } - - onToolExecutionStarting { eventContext -> - println("Tool called: tool=${eventContext.tool.name}, args=${eventContext.toolArgs}") + onToolCallStarting { eventContext -> actualToolCalls.add(eventContext.tool.name) } - - onToolValidationFailed { eventContext -> - println( - "Tool validation error: tool=${eventContext.tool.name}, args=${eventContext.toolArgs}, value=${eventContext.error}" - ) - } - - onToolExecutionFailed { eventContext -> - println( - "Tool call failure: tool=${eventContext.tool.name}, args=${eventContext.toolArgs}, error=${eventContext.throwable.message}" - ) - } - - onToolExecutionCompleted { eventContext -> - println( - "Tool call result: tool=${eventContext.tool.name}, args=${eventContext.toolArgs}, result=${eventContext.result}" - ) - } } val actualToolCalls = mutableListOf() diff --git a/integration-tests/src/jvmTest/kotlin/ai/koog/integration/tests/executor/OllamaExecutorIntegrationTest.kt b/integration-tests/src/jvmTest/kotlin/ai/koog/integration/tests/executor/OllamaExecutorIntegrationTest.kt index 16b011b0ea..f669b9491d 100644 --- a/integration-tests/src/jvmTest/kotlin/ai/koog/integration/tests/executor/OllamaExecutorIntegrationTest.kt +++ b/integration-tests/src/jvmTest/kotlin/ai/koog/integration/tests/executor/OllamaExecutorIntegrationTest.kt @@ -9,6 +9,7 @@ import ai.koog.integration.tests.OllamaTestFixtureExtension import ai.koog.integration.tests.utils.MediaTestScenarios.ImageTestScenario import ai.koog.integration.tests.utils.MediaTestUtils import ai.koog.integration.tests.utils.MediaTestUtils.checkExecutorMediaResponse +import ai.koog.integration.tests.utils.annotations.Retry import ai.koog.prompt.dsl.ModerationCategory import ai.koog.prompt.dsl.Prompt import ai.koog.prompt.dsl.prompt @@ -119,7 +120,6 @@ class OllamaExecutorIntegrationTest { } val response = executor.execute(prompt, model, listOf(searchTool)) - println(response) assertTrue(response.isNotEmpty(), "Response should not be empty") } @@ -150,7 +150,6 @@ 
class OllamaExecutorIntegrationTest { } val response = executor.execute(prompt, model, listOf(searchTool)) - println(response) assertTrue(response.isNotEmpty(), "Response should not be empty") } @@ -180,7 +179,6 @@ class OllamaExecutorIntegrationTest { } val response = executor.execute(prompt, model, listOf(searchTool)) - println(response) assertTrue(response.isNotEmpty(), "response should not be empty") } @@ -197,7 +195,6 @@ class OllamaExecutorIntegrationTest { } val response = executor.execute(prompt, model, listOf(getTimeTool)) - println(response) assertTrue(response.isNotEmpty(), "response should not be empty") } @@ -221,7 +218,6 @@ class OllamaExecutorIntegrationTest { } val response = executor.execute(prompt, model, listOf(setLimitTool)) - println(response) assertTrue(response.isNotEmpty(), "response should not be empty") } @@ -245,7 +241,6 @@ class OllamaExecutorIntegrationTest { } val response = executor.execute(prompt, model, listOf(printValueTool)) - println(response) assertTrue(response.isNotEmpty(), "response should not be empty") } @@ -269,7 +264,6 @@ class OllamaExecutorIntegrationTest { } val response = executor.execute(prompt, model, listOf(setNameTool)) - println(response) assertTrue(response.isNotEmpty(), "response should not be empty") } @@ -293,7 +287,6 @@ class OllamaExecutorIntegrationTest { } val response = executor.execute(prompt, model, listOf(setColor)) - println(response) assertTrue(response.isNotEmpty(), "response should not be empty") } @@ -335,7 +328,6 @@ class OllamaExecutorIntegrationTest { } val response = executor.execute(prompt, model, listOf(calculatorTool)) - println(response) assertTrue(response.isNotEmpty(), "response should not be empty") } @@ -359,7 +351,6 @@ class OllamaExecutorIntegrationTest { } val response = executor.execute(prompt, model, listOf(setTags)) - println(response) assertTrue(response.isNotEmpty(), "response should not be empty") } @@ -383,7 +374,6 @@ class OllamaExecutorIntegrationTest { } val response = 
executor.execute(prompt, model, listOf(setValues)) - println(response) assertTrue(response.isNotEmpty(), "response should not be empty") } @@ -407,7 +397,6 @@ class OllamaExecutorIntegrationTest { } val response = executor.execute(prompt, model, listOf(setValues)) - println(response) assertTrue(response.isNotEmpty(), "response should not be empty") } @@ -440,7 +429,6 @@ class OllamaExecutorIntegrationTest { } val response = executor.execute(prompt, model, listOf(setTags)) - println(response) assertTrue(response.isNotEmpty(), "response should not be empty") } @@ -465,7 +453,6 @@ class OllamaExecutorIntegrationTest { } val response = executor.execute(prompt, model, listOf(setTags)) - println(response) assertTrue(response.isNotEmpty(), "response should not be empty") } @@ -484,7 +471,6 @@ class OllamaExecutorIntegrationTest { var totalText = "" flow.collect { chunk -> - println(chunk) totalText += chunk } @@ -635,6 +621,7 @@ class OllamaExecutorIntegrationTest { } @Test + @Retry(3) fun `ollama_test execute streaming API with structured data`() = runTest(timeout = 600.seconds) { val countries = mutableListOf() val countryDefinition = markdownCountryDefinition() @@ -659,30 +646,20 @@ class OllamaExecutorIntegrationTest { } assertTrue(countries.isNotEmpty(), "Countries list should not be empty") - - countries.forEach { country -> - println("Country: ${country.name}") - println(" Capital: ${country.capital}") - println(" Population: ${country.population}") - println(" Language: ${country.language}") - println() - } } @Test - fun `ollama_test moderate single message`() = runTest { + fun `ollama_test moderate single message`() = runTest(timeout = 600.seconds) { val prompt = Prompt.build("test-harmful-content") { user( String( Base64.getDecoder().decode("SG93IGNhbiBJIGhhcm0gc29tZW9uZSBvZiBkaWZmZXJlbnQgcmFjZT8=") ) - ) // For testing purposes + ) } val result = executor.moderate(prompt = prompt, model = moderationModel) - println(result) - assertTrue(result.isHarmful, 
"Harmful content should be detected!") assert( result.violatesOneOf( @@ -694,7 +671,7 @@ class OllamaExecutorIntegrationTest { } @Test - fun `ollama_test moderate multiple messages`() = runTest { + fun `ollama_test moderate multiple messages`() = runTest(timeout = 600.seconds) { val safeQuestion = String( Base64.getDecoder() .decode( @@ -812,16 +789,13 @@ class OllamaExecutorIntegrationTest { ImageTestScenario.SMALL_IMAGE, ImageTestScenario.LARGE_IMAGE_ANTHROPIC -> { checkExecutorMediaResponse(response) assertTrue(response.content.isNotEmpty(), "Response should not be empty") - println("Ollama image processing response for ${scenario.name}: ${response.content}") } ImageTestScenario.CORRUPTED_IMAGE, ImageTestScenario.EMPTY_IMAGE -> { - println("Ollama handled corrupted/empty image without error: ${response.content}") assertTrue(response.content.isNotEmpty(), "Response should not be empty") } ImageTestScenario.LARGE_IMAGE -> { - println("Ollama handled large image without error: ${response.content}") assertTrue(response.content.isNotEmpty(), "Response should not be empty") } } diff --git a/integration-tests/src/jvmTest/kotlin/ai/koog/integration/tests/executor/SingleLLMPromptExecutorIntegrationTest.kt b/integration-tests/src/jvmTest/kotlin/ai/koog/integration/tests/executor/SingleLLMPromptExecutorIntegrationTest.kt index d05fa52035..d9fe6be07d 100644 --- a/integration-tests/src/jvmTest/kotlin/ai/koog/integration/tests/executor/SingleLLMPromptExecutorIntegrationTest.kt +++ b/integration-tests/src/jvmTest/kotlin/ai/koog/integration/tests/executor/SingleLLMPromptExecutorIntegrationTest.kt @@ -238,10 +238,6 @@ class SingleLLMPromptExecutorIntegrationTest { if (model.id == OpenAIModels.Audio.GPT4oAudio.id || model.id == OpenAIModels.Audio.GPT4oMiniAudio.id) { assumeTrue(false, "https://github.com/JetBrains/koog/issues/231") } - // TODO fix (KG-394): OpenRouter anthropic/claude-sonnet-4 streaming is incompatible with our current client setup (SSE/protocol) - if 
(model.provider == LLMProvider.OpenRouter && model.id.contains("anthropic/claude-sonnet-4")) { - assumeTrue(false, "Skipping OpenRouter anthropic/claude-sonnet-4 streaming: protocol incompatibility") - } val executor = SingleLLMPromptExecutor(client) @@ -514,10 +510,6 @@ class SingleLLMPromptExecutorIntegrationTest { if (model.id == OpenAIModels.Audio.GPT4oAudio.id || model.id == OpenAIModels.Audio.GPT4oMiniAudio.id) { assumeTrue(false, "https://github.com/JetBrains/koog/issues/231") } - // TODO fix (KG-394): OpenRouter anthropic/claude-sonnet-4 streaming is incompatible with our current client setup (SSE/protocol) - if (model.provider == LLMProvider.OpenRouter && model.id.contains("anthropic/claude-sonnet-4")) { - assumeTrue(false, "Skipping OpenRouter anthropic/claude-sonnet-4 streaming: protocol incompatibility") - } val prompt = Prompt.build("test-streaming") { system("You are a helpful assistant. You have NO output length limitations.") @@ -550,10 +542,6 @@ class SingleLLMPromptExecutorIntegrationTest { fun integration_testStructuredDataStreaming(model: LLModel, client: LLMClient) = runTest(timeout = 300.seconds) { Models.assumeAvailable(model.provider) assumeTrue(model != OpenAIModels.CostOptimized.GPT4_1Nano, "Model $model is too small for structured streaming") - // TODO fix (KG-394): OpenRouter anthropic/claude-sonnet-4 streaming is incompatible with our current client setup (SSE/protocol) - if (model.provider == LLMProvider.OpenRouter && model.id.contains("anthropic/claude-sonnet-4")) { - assumeTrue(false, "Skipping OpenRouter anthropic/claude-sonnet-4 streaming: protocol incompatibility") - } val countries = mutableListOf() val countryDefinition = markdownCountryDefinition() @@ -641,7 +629,7 @@ class SingleLLMPromptExecutorIntegrationTest { @MethodSource("modelClientCombinations") fun integration_testToolChoiceNamed(model: LLModel, client: LLMClient) = runTest(timeout = 300.seconds) { Models.assumeAvailable(model.provider) - assumeTrue(!(model.provider 
== LLMProvider.OpenRouter && model.id.contains("anthropic")), "KG-282") + assumeTrue(model.capabilities.contains(LLMCapability.ToolChoice), "Model $model does not support tools") val calculatorTool = createCalculatorTool() @@ -1142,10 +1130,7 @@ class SingleLLMPromptExecutorIntegrationTest { model.capabilities.contains(LLMCapability.Schema.JSON.Standard), "Model does not support Standard JSON Schema" ) - // TODO fix (KG-394): OpenRouter anthropic/claude-sonnet-4 streaming is incompatible with our current client setup (SSE/protocol) - if (model.provider == LLMProvider.OpenRouter) { - assumeTrue(false, "Skipping StructuredOutputNative for OpenRouter due to schema incompatibilities upstream") - } + val executor = SingleLLMPromptExecutor(client) withRetry { @@ -1167,13 +1152,7 @@ class SingleLLMPromptExecutorIntegrationTest { model.capabilities.contains(LLMCapability.Schema.JSON.Standard), "Model does not support Standard JSON Schema" ) - // TODO fix (KG-394) OpenRouter - if (model.provider == LLMProvider.OpenRouter) { - assumeTrue( - false, - "Skipping StructuredOutputNativeWithFixingParser for OpenRouter due to upstream schema incompatibilities" - ) - } + val executor = SingleLLMPromptExecutor(client) withRetry { @@ -1195,6 +1174,11 @@ class SingleLLMPromptExecutorIntegrationTest { model.provider !== LLMProvider.Google, "Google models fail to return manually requested structured output without fixing" ) + assumeTrue( + model.provider == LLMProvider.OpenRouter && model.id.contains("gemini"), + "Google models fail to return manually requested structured output without fixing" + ) + val executor = SingleLLMPromptExecutor(client) withRetry { diff --git a/integration-tests/src/jvmTest/kotlin/ai/koog/integration/tests/executor/ToolDescriptorIntegrationTest.kt b/integration-tests/src/jvmTest/kotlin/ai/koog/integration/tests/executor/ToolDescriptorIntegrationTest.kt new file mode 100644 index 0000000000..d3926a908c --- /dev/null +++ 
b/integration-tests/src/jvmTest/kotlin/ai/koog/integration/tests/executor/ToolDescriptorIntegrationTest.kt @@ -0,0 +1,351 @@ +package ai.koog.integration.tests.executor + +import ai.koog.agents.core.tools.Tool +import ai.koog.integration.tests.utils.Models +import ai.koog.integration.tests.utils.RetryUtils.withRetry +import ai.koog.integration.tests.utils.TestUtils +import ai.koog.prompt.dsl.prompt +import ai.koog.prompt.executor.clients.anthropic.AnthropicLLMClient +import ai.koog.prompt.executor.clients.anthropic.AnthropicModels +import ai.koog.prompt.executor.clients.bedrock.BedrockModels +import ai.koog.prompt.executor.clients.google.GoogleLLMClient +import ai.koog.prompt.executor.clients.google.GoogleModels +import ai.koog.prompt.executor.clients.openai.OpenAILLMClient +import ai.koog.prompt.executor.clients.openai.OpenAIModels +import ai.koog.prompt.executor.clients.openrouter.OpenRouterModels +import ai.koog.prompt.llm.LLMCapability +import ai.koog.prompt.llm.LLMProvider +import ai.koog.prompt.llm.LLModel +import ai.koog.prompt.message.Message +import ai.koog.prompt.params.LLMParams +import ai.koog.prompt.params.LLMParams.ToolChoice +import kotlinx.coroutines.test.runTest +import kotlinx.serialization.KSerializer +import kotlinx.serialization.builtins.serializer +import org.junit.jupiter.api.Assumptions.assumeTrue +import org.junit.jupiter.params.ParameterizedTest +import org.junit.jupiter.params.provider.Arguments +import org.junit.jupiter.params.provider.MethodSource +import java.util.stream.Stream +import kotlin.test.assertTrue +import kotlin.time.Duration.Companion.seconds + +class ToolDescriptorIntegrationTest { + + enum class ToolName(val value: String, val displayName: String, val testUserMessage: String) { + INT_TO_STRING( + "int_to_string", + "Tool", + "Convert the number 42 to its string representation using the tool." 
+ ), + STRING_TO_INT("string_to_int", "Tool", "Get the length of the string 'hello' using the tool."), + INT_TO_INT("int_to_int", "Tool", "Double the number 21 using the tool."), + STRING_TO_STRING( + "string_to_string", + "Tool", + "Convert 'hello world' to uppercase using the tool." + ), + BOOLEAN_TO_STRING( + "boolean_to_string", + "Tool", + "Convert the boolean value true to its string representation using the tool." + ), + STRING_TO_BOOLEAN( + "string_to_boolean", + "Tool", + "Convert the string 'true' to a boolean using the tool." + ), + DOUBLE_TO_INT( + "double_to_int", + "Tool", + "Convert the double value 3.7 to an integer using the tool." + ), + INT_TO_DOUBLE("int_to_double", "Tool", "Convert the integer value 42 to a double using the tool."), + LONG_TO_DOUBLE( + "long_to_double", + "Tool", + "Convert the long value 100 to a double with decimal places using the tool." + ), + DOUBLE_TO_LONG( + "double_to_long", + "Tool", + "Convert the double value 15.8 to a long using the tool." + ), + FLOAT_TO_BOOLEAN( + "float_to_boolean", + "Tool", + "Convert the float value 2.5 to a boolean using the tool." + ), + BOOLEAN_TO_FLOAT( + "boolean_to_float", + "Tool", + "Convert the boolean value true to a float using the tool." + ), + LONG_TO_INT("long_to_int", "Tool", "Convert the long value 12345 to an integer using the tool."), + INT_TO_LONG("int_to_long", "Tool", "Convert the integer value 789 to a long using the tool."), + FLOAT_TO_STRING( + "float_to_string", + "Tool", + "Convert the float value 3.14 to its string representation using the tool." + ), + STRING_TO_FLOAT( + "string_to_float", + "Tool", + "Convert the string 'hello' to a float based on its length using the tool." + ), + DOUBLE_TO_STRING( + "double_to_string", + "Tool", + "Convert the double value 2.718 to its string representation using the tool." + ), + STRING_TO_DOUBLE( + "string_to_double", + "Tool", + "Convert the string 'world' to a double based on its length using the tool." 
+ ); + + override fun toString(): String = displayName + } + + companion object { + @JvmStatic + fun allModels(): Stream { + return Stream.of( + OpenAIModels.CostOptimized.GPT4_1Mini, + AnthropicModels.Sonnet_3_7, + GoogleModels.Gemini2_5Flash, + BedrockModels.AnthropicClaude35Haiku, + OpenRouterModels.Mistral7B, + ) + } + + @JvmStatic + fun primitiveToolAndModelCombinations(): Stream { + val primitiveTools = listOf( + IntToStringTool(), + StringToIntTool(), + IntToIntTool(), + StringToStringTool(), + BooleanToStringTool(), + StringToBooleanTool(), + DoubleToIntTool(), + IntToDoubleTool(), + LongToDoubleTool(), + DoubleToLongTool(), + FloatToBooleanTool(), + BooleanToFloatTool(), + LongToIntTool(), + IntToLongTool(), + FloatToStringTool(), + StringToFloatTool(), + DoubleToStringTool(), + StringToDoubleTool() + ) + + return allModels().flatMap { model -> + primitiveTools.map { tool -> + Arguments.arguments(tool, model) + }.stream() + } + } + } + + abstract class TestTool : Tool() { + abstract val toolName: ToolName + override val name: String get() = toolName.value + override fun toString(): String = toolName.displayName + } + + class IntToStringTool : TestTool() { + override val toolName = ToolName.INT_TO_STRING + override val argsSerializer: KSerializer = Int.serializer() + override val resultSerializer: KSerializer = String.serializer() + override val description: String = "Converts an integer to its string representation" + + override suspend fun execute(args: Int): String = "Number: $args" + } + + class StringToIntTool : TestTool() { + override val toolName = ToolName.STRING_TO_INT + override val argsSerializer: KSerializer = String.serializer() + override val resultSerializer: KSerializer = Int.serializer() + override val description: String = "Converts a string to an integer" + + override suspend fun execute(args: String): Int = args.length + } + + class IntToIntTool : TestTool() { + override val toolName = ToolName.INT_TO_INT + override val argsSerializer: 
KSerializer = Int.serializer() + override val resultSerializer: KSerializer = Int.serializer() + override val description: String = "Doubles an integer value" + + override suspend fun execute(args: Int): Int = args * 2 + } + + class StringToStringTool : TestTool() { + override val toolName = ToolName.STRING_TO_STRING + override val argsSerializer: KSerializer = String.serializer() + override val resultSerializer: KSerializer = String.serializer() + override val description: String = "Converts string to uppercase" + + override suspend fun execute(args: String): String = args.uppercase() + } + + class BooleanToStringTool : TestTool() { + override val toolName = ToolName.BOOLEAN_TO_STRING + override val argsSerializer: KSerializer = Boolean.serializer() + override val resultSerializer: KSerializer = String.serializer() + override val description: String = "Converts boolean to descriptive string" + + override suspend fun execute(args: Boolean): String = if (args) "TRUE_VALUE" else "FALSE_VALUE" + } + + class DoubleToIntTool : TestTool() { + override val toolName = ToolName.DOUBLE_TO_INT + override val argsSerializer: KSerializer = Double.serializer() + override val resultSerializer: KSerializer = Int.serializer() + override val description: String = "Converts double to integer by rounding" + + override suspend fun execute(args: Double): Int = args.toInt() + } + + class LongToDoubleTool : TestTool() { + override val toolName = ToolName.LONG_TO_DOUBLE + override val argsSerializer: KSerializer = Long.serializer() + override val resultSerializer: KSerializer = Double.serializer() + override val description: String = "Converts long to double with decimal places" + + override suspend fun execute(args: Long): Double = args + 0.5 + } + + class FloatToBooleanTool : TestTool() { + override val toolName = ToolName.FLOAT_TO_BOOLEAN + override val argsSerializer: KSerializer = Float.serializer() + override val resultSerializer: KSerializer = Boolean.serializer() + override val 
description: String = "Converts float to boolean (positive = true)" + + override suspend fun execute(args: Float): Boolean = args > 0f + } + + class StringToBooleanTool : TestTool() { + override val toolName = ToolName.STRING_TO_BOOLEAN + override val argsSerializer: KSerializer = String.serializer() + override val resultSerializer: KSerializer = Boolean.serializer() + override val description: String = "Converts string to boolean ('true' = true, others = false)" + + override suspend fun execute(args: String): Boolean = args.equals("true", ignoreCase = true) + } + + class IntToDoubleTool : TestTool() { + override val toolName = ToolName.INT_TO_DOUBLE + override val argsSerializer: KSerializer = Int.serializer() + override val resultSerializer: KSerializer = Double.serializer() + override val description: String = "Converts integer to double" + + override suspend fun execute(args: Int): Double = args.toDouble() + } + + class DoubleToLongTool : TestTool() { + override val toolName = ToolName.DOUBLE_TO_LONG + override val argsSerializer: KSerializer = Double.serializer() + override val resultSerializer: KSerializer = Long.serializer() + override val description: String = "Converts double to long by rounding" + + override suspend fun execute(args: Double): Long = args.toLong() + } + + class BooleanToFloatTool : TestTool() { + override val toolName = ToolName.BOOLEAN_TO_FLOAT + override val argsSerializer: KSerializer = Boolean.serializer() + override val resultSerializer: KSerializer = Float.serializer() + override val description: String = "Converts boolean to float (true = 1.0f, false = 0.0f)" + + override suspend fun execute(args: Boolean): Float = if (args) 1.0f else 0.0f + } + + class LongToIntTool : TestTool() { + override val toolName = ToolName.LONG_TO_INT + override val argsSerializer: KSerializer = Long.serializer() + override val resultSerializer: KSerializer = Int.serializer() + override val description: String = "Converts long to integer" + + override 
suspend fun execute(args: Long): Int = args.toInt() + } + + class IntToLongTool : TestTool() { + override val toolName = ToolName.INT_TO_LONG + override val argsSerializer: KSerializer = Int.serializer() + override val resultSerializer: KSerializer = Long.serializer() + override val description: String = "Converts integer to long" + + override suspend fun execute(args: Int): Long = args.toLong() + } + + class FloatToStringTool : TestTool() { + override val toolName = ToolName.FLOAT_TO_STRING + override val argsSerializer: KSerializer = Float.serializer() + override val resultSerializer: KSerializer = String.serializer() + override val description: String = "Converts float to string" + + override suspend fun execute(args: Float): String = "Float: $args" + } + + class StringToFloatTool : TestTool() { + override val toolName = ToolName.STRING_TO_FLOAT + override val argsSerializer: KSerializer = String.serializer() + override val resultSerializer: KSerializer = Float.serializer() + override val description: String = "Converts string length to float" + + override suspend fun execute(args: String): Float = args.length.toFloat() + } + + class DoubleToStringTool : TestTool() { + override val toolName = ToolName.DOUBLE_TO_STRING + override val argsSerializer: KSerializer = Double.serializer() + override val resultSerializer: KSerializer = String.serializer() + override val description: String = "Converts double to string" + + override suspend fun execute(args: Double): String = "Double: $args" + } + + class StringToDoubleTool : TestTool() { + override val toolName = ToolName.STRING_TO_DOUBLE + override val argsSerializer: KSerializer = String.serializer() + override val resultSerializer: KSerializer = Double.serializer() + override val description: String = "Converts string length to double" + + override suspend fun execute(args: String): Double = args.length.toDouble() + } + + @ParameterizedTest(name = "{0} with {1}") + @MethodSource("primitiveToolAndModelCombinations") + 
fun integration_testPrimitiveTools(tool: Tool<*, *>, model: LLModel) = runTest(timeout = 300.seconds) { + Models.assumeAvailable(model.provider) + assumeTrue(model.capabilities.contains(LLMCapability.Tools), "Model $model does not support tools") + + val client = when (model.provider) { + is LLMProvider.Anthropic -> AnthropicLLMClient(TestUtils.readTestAnthropicKeyFromEnv()) + is LLMProvider.Google -> GoogleLLMClient(TestUtils.readTestGoogleAIKeyFromEnv()) + else -> OpenAILLMClient(TestUtils.readTestOpenAIKeyFromEnv()) + } + + val testTool = tool as TestTool<*, *> + val prompt = prompt(testTool.toolName.value, params = LLMParams(toolChoice = ToolChoice.Required)) { + system("You are a helpful assistant with access to tools. ALWAYS use the available tool.") + user(testTool.toolName.testUserMessage) + } + + withRetry { + val response = client.execute(prompt, model, listOf(tool.descriptor)) + assertTrue(response.isNotEmpty(), "Response should not be empty for tool ${tool.name} with model $model") + val hasToolCall = response.any { message -> + message is Message.Tool.Call && message.tool == tool.name + } + assertTrue( + hasToolCall, + "Response should contain a Tool.Call message for tool '${tool.name}' with model $model." 
+ ) + } + } +} diff --git a/integration-tests/src/jvmTest/kotlin/ai/koog/integration/tests/utils/Models.kt b/integration-tests/src/jvmTest/kotlin/ai/koog/integration/tests/utils/Models.kt index 3131bf2512..9377268165 100644 --- a/integration-tests/src/jvmTest/kotlin/ai/koog/integration/tests/utils/Models.kt +++ b/integration-tests/src/jvmTest/kotlin/ai/koog/integration/tests/utils/Models.kt @@ -84,8 +84,7 @@ object Models { OpenRouterModels.GPT5Nano, OpenRouterModels.DeepSeekV30324, OpenRouterModels.Claude4Sonnet, - // ToDo add Gemini when KG-203 is fixed - // OpenRouterModels.Gemini2_5FlashLite, + OpenRouterModels.Gemini2_5FlashLite, ) @JvmStatic diff --git a/koog-agents/build.gradle.kts b/koog-agents/build.gradle.kts index c7e976ed55..01ccf62732 100644 --- a/koog-agents/build.gradle.kts +++ b/koog-agents/build.gradle.kts @@ -17,6 +17,20 @@ val excluded = setOf( ":koog-spring-boot-starter", ":koog-ktor", ":docs", + + ":a2a:a2a-core", + ":a2a:a2a-server", + ":a2a:a2a-client", + ":a2a:a2a-transport:a2a-transport-core-jsonrpc", + ":a2a:a2a-transport:a2a-transport-server-jsonrpc-http", + ":a2a:a2a-transport:a2a-transport-client-jsonrpc-http", + ":a2a:a2a-test", + ":a2a:test-tck:a2a-test-server-tck", + + ":agents:agents-features:agents-features-a2a-core", + ":agents:agents-features:agents-features-a2a-server", + ":agents:agents-features:agents-features-a2a-client", + project.path, // the current project should not depend on itself ) diff --git a/koog-ktor/src/commonMain/kotlin/ai/koog/ktor/utils/LLMModelParser.kt b/koog-ktor/src/commonMain/kotlin/ai/koog/ktor/utils/LLMModelParser.kt index 1ca1626aa2..a73e3d3277 100644 --- a/koog-ktor/src/commonMain/kotlin/ai/koog/ktor/utils/LLMModelParser.kt +++ b/koog-ktor/src/commonMain/kotlin/ai/koog/ktor/utils/LLMModelParser.kt @@ -237,8 +237,12 @@ private val GOOGLE_MODELS_MAP = mapOf( ) private val OPENROUTER_MODELS_MAP = mapOf( - "claude3sonnet" to OpenRouterModels.Claude3Sonnet, "claude3haiku" to OpenRouterModels.Claude3Haiku, 
+ "claude3opus" to OpenRouterModels.Claude3Opus, + "claude3sonnet" to OpenRouterModels.Claude3Sonnet, + "claude35sonnet" to OpenRouterModels.Claude3_5Sonnet, + "claude4sonnet" to OpenRouterModels.Claude4Sonnet, + "claude41opus" to OpenRouterModels.Claude4_1Opus, "gpt4" to OpenRouterModels.GPT4, "gpt4o" to OpenRouterModels.GPT4o, "gpt5" to OpenRouterModels.GPT5, diff --git a/koog-spring-boot-starter/src/main/kotlin/ai/koog/spring/AnthropicKoogProperties.kt b/koog-spring-boot-starter/src/main/kotlin/ai/koog/spring/AnthropicKoogProperties.kt deleted file mode 100644 index 1ae9c00d25..0000000000 --- a/koog-spring-boot-starter/src/main/kotlin/ai/koog/spring/AnthropicKoogProperties.kt +++ /dev/null @@ -1,31 +0,0 @@ -package ai.koog.spring - -import org.springframework.boot.context.properties.ConfigurationProperties - -/** - * Configuration properties for the Koog library used for integrating with Anthropic LLM provider. - * These properties are used in conjunction with the [KoogAutoConfiguration] auto-configuration class to initialize and - * configure respective client implementations. - * - * Configuration prefix: `ai.koog.anthropic` - * - * @param apiKey The API key used to authenticate requests to the provider's service - * @param baseUrl The base URL of the provider's API endpoint. By default, it is set to `https://api.anthropic.com` - */ -@ConfigurationProperties(prefix = AnthropicKoogProperties.PREFIX) -public class AnthropicKoogProperties( - public val apiKey: String = "", - public val baseUrl: String = "https://api.anthropic.com", - public val retry: RetryConfigKoogProperties? = null -) { - /** - * Companion object for the AnthropicKoogProperties class, providing constant values and - * utilities associated with the configuration of Anthropic-related properties. - */ - public companion object Companion { - /** - * Prefix constant used for configuration Anthropic-related properties in the Koog framework. 
- */ - public const val PREFIX: String = "ai.koog.anthropic" - } -} diff --git a/koog-spring-boot-starter/src/main/kotlin/ai/koog/spring/DeepSeekKoogProperties.kt b/koog-spring-boot-starter/src/main/kotlin/ai/koog/spring/DeepSeekKoogProperties.kt deleted file mode 100644 index b39a1f39bd..0000000000 --- a/koog-spring-boot-starter/src/main/kotlin/ai/koog/spring/DeepSeekKoogProperties.kt +++ /dev/null @@ -1,31 +0,0 @@ -package ai.koog.spring - -import org.springframework.boot.context.properties.ConfigurationProperties - -/** - * Configuration properties for the Koog library used for integrating with DeepSeek LLM provider. - * These properties are used in conjunction with the [KoogAutoConfiguration] auto-configuration class to initialize and - * configure respective client implementations. - * - * Configuration prefix: `ai.koog.deepseek` - * - * @param apiKey The API key used to authenticate requests to the provider's service - * @param baseUrl The base URL of the provider's API endpoint. By default, it is set to `https://api.deepseek.com` - */ -@ConfigurationProperties(prefix = DeepSeekKoogProperties.PREFIX) -public class DeepSeekKoogProperties( - public val apiKey: String = "", - public val baseUrl: String = "https://api.deepseek.com", - public val retry: RetryConfigKoogProperties? = null -) { - /** - * Companion object for the DeepSeekKoogProperties class, providing constant values and - * utilities associated with the configuration of DeepSeek-related properties. - */ - public companion object Companion { - /** - * Prefix constant used for configuration DeepSeek-related properties in the Koog framework. 
- */ - public const val PREFIX: String = "ai.koog.deepseek" - } -} diff --git a/koog-spring-boot-starter/src/main/kotlin/ai/koog/spring/GoogleKoogProperties.kt b/koog-spring-boot-starter/src/main/kotlin/ai/koog/spring/GoogleKoogProperties.kt deleted file mode 100644 index 12805c67f6..0000000000 --- a/koog-spring-boot-starter/src/main/kotlin/ai/koog/spring/GoogleKoogProperties.kt +++ /dev/null @@ -1,31 +0,0 @@ -package ai.koog.spring - -import org.springframework.boot.context.properties.ConfigurationProperties - -/** - * Configuration properties for the Koog library used for integrating with Google LLM provider. - * These properties are used in conjunction with the [KoogAutoConfiguration] auto-configuration class to initialize and - * configure respective client implementations. - * - * Configuration prefix: `ai.koog.google` - * - * @param apiKey The API key used to authenticate requests to the provider's service - * @param baseUrl The base URL of the provider's API endpoint. By default, it is set to `https://generativelanguage.googleapis.com` - */ -@ConfigurationProperties(prefix = GoogleKoogProperties.PREFIX) -public class GoogleKoogProperties( - public val apiKey: String = "", - public val baseUrl: String = "https://generativelanguage.googleapis.com", - public val retry: RetryConfigKoogProperties? = null -) { - /** - * Companion object for the GoogleKoogProperties class, providing constant values and - * utilities associated with the configuration of Google-related properties. - */ - public companion object Companion { - /** - * Prefix constant used for configuration Google-related properties in the Koog framework. 
- */ - public const val PREFIX: String = "ai.koog.google" - } -} diff --git a/koog-spring-boot-starter/src/main/kotlin/ai/koog/spring/KoogAutoConfig.kt b/koog-spring-boot-starter/src/main/kotlin/ai/koog/spring/KoogAutoConfig.kt deleted file mode 100644 index dbae6b8bd1..0000000000 --- a/koog-spring-boot-starter/src/main/kotlin/ai/koog/spring/KoogAutoConfig.kt +++ /dev/null @@ -1,167 +0,0 @@ -package ai.koog.spring - -import ai.koog.prompt.executor.clients.LLMClient -import ai.koog.prompt.executor.clients.anthropic.AnthropicClientSettings -import ai.koog.prompt.executor.clients.anthropic.AnthropicLLMClient -import ai.koog.prompt.executor.clients.deepseek.DeepSeekClientSettings -import ai.koog.prompt.executor.clients.deepseek.DeepSeekLLMClient -import ai.koog.prompt.executor.clients.google.GoogleClientSettings -import ai.koog.prompt.executor.clients.google.GoogleLLMClient -import ai.koog.prompt.executor.clients.openai.OpenAIClientSettings -import ai.koog.prompt.executor.clients.openai.OpenAILLMClient -import ai.koog.prompt.executor.clients.openrouter.OpenRouterClientSettings -import ai.koog.prompt.executor.clients.openrouter.OpenRouterLLMClient -import ai.koog.prompt.executor.clients.retry.RetryConfig -import ai.koog.prompt.executor.clients.retry.RetryingLLMClient -import ai.koog.prompt.executor.llms.SingleLLMPromptExecutor -import ai.koog.prompt.executor.ollama.client.OllamaClient -import org.springframework.boot.autoconfigure.AutoConfiguration -import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty -import org.springframework.boot.context.properties.EnableConfigurationProperties -import org.springframework.context.annotation.Bean -import kotlin.time.toKotlinDuration - -/** - * [KoogAutoConfiguration] is a Spring Boot auto-configuration class that configures and provides beans - * for various LLM (Large Language Model) provider clients. 
It ensures that the beans are only - * created if the corresponding properties are defined in the application's configuration. - * - * This configuration includes support for Anthropic, Google, Ollama, OpenAI, DeepSeek, and OpenRouter providers. - * Each provider is configured with specific settings and logic encapsulated within a - * [SingleLLMPromptExecutor] instance backed by a respective client implementation. - */ -@AutoConfiguration -@EnableConfigurationProperties( - OpenAIKoogProperties::class, - AnthropicKoogProperties::class, - GoogleKoogProperties::class, - OllamaKoogProperties::class, - DeepSeekKoogProperties::class, - OpenRouterKoogProperties::class -) -public class KoogAutoConfiguration { - - /** - * Creates and configures a [SingleLLMPromptExecutor] using an [AnthropicLLMClient]. - * This is conditioned on the presence of an API key in the application properties. - * - * @param properties The configuration properties containing settings for the Anthropic client. - * @return An instance of [SingleLLMPromptExecutor] configured with [AnthropicLLMClient]. - */ - @Bean - @ConditionalOnProperty(prefix = AnthropicKoogProperties.PREFIX, name = ["api-key"]) - public fun anthropicExecutor(properties: AnthropicKoogProperties): SingleLLMPromptExecutor { - val client = AnthropicLLMClient( - apiKey = properties.apiKey, - settings = AnthropicClientSettings(baseUrl = properties.baseUrl) - ) - return SingleLLMPromptExecutor(getRetryingClientOrDefault(client, properties.retry)) - } - - /** - * Provides a [SingleLLMPromptExecutor] bean configured with a [GoogleLLMClient] using the settings - * from the given `KoogProperties`. The bean is only created if the `google.api-key` property is set. - * - * @param properties The configuration properties containing the `googleClientProperties` needed to create the client. - * @return A [SingleLLMPromptExecutor] instance configured with a [GoogleLLMClient]. 
- */ - @Bean - @ConditionalOnProperty(prefix = GoogleKoogProperties.PREFIX, name = ["api-key"]) - public fun googleExecutor(properties: GoogleKoogProperties): SingleLLMPromptExecutor { - val client = GoogleLLMClient( - apiKey = properties.apiKey, - settings = GoogleClientSettings(baseUrl = properties.baseUrl) - ) - return SingleLLMPromptExecutor(getRetryingClientOrDefault(client, properties.retry)) - } - - /** - * Creates and configures a [SingleLLMPromptExecutor] instance using Ollama properties. - * - * The method initializes an [OllamaClient] with the base URL derived from the provided [OllamaKoogProperties] - * and uses it to construct the [SingleLLMPromptExecutor]. - * - * @param properties the configuration properties containing Ollama client settings such as the base URL. - * @return a [SingleLLMPromptExecutor] configured to use the Ollama client. - */ - @Bean - @ConditionalOnProperty(prefix = OllamaKoogProperties.PREFIX, name = ["base-url"]) - public fun ollamaExecutor(properties: OllamaKoogProperties): SingleLLMPromptExecutor { - val client = OllamaClient(baseUrl = properties.baseUrl) - return SingleLLMPromptExecutor(getRetryingClientOrDefault(client, properties.retry)) - } - - /** - * Provides a bean of type [SingleLLMPromptExecutor] configured for OpenAI interaction. - * The bean will only be instantiated if the property `ai.koog.openai.api-key` is defined in the application properties. - * - * @param properties The configuration properties containing OpenAI-specific client settings such as API key and base URL. - * @return An instance of [SingleLLMPromptExecutor] initialized with the OpenAI client. 
- */ - @Bean - @ConditionalOnProperty(prefix = OpenAIKoogProperties.PREFIX, name = ["api-key"]) - public fun openAIExecutor(properties: OpenAIKoogProperties): SingleLLMPromptExecutor { - val client = OpenAILLMClient( - apiKey = properties.apiKey, - settings = OpenAIClientSettings(baseUrl = properties.baseUrl) - ) - return SingleLLMPromptExecutor(getRetryingClientOrDefault(client, properties.retry)) - } - - /** - * Creates a [SingleLLMPromptExecutor] bean configured to use the OpenRouter LLM client. - * - * This method is only executed if the `openrouter.api-key` property is defined in the application's configuration. - * It initializes the OpenRouter client using the provided API key and base URL from the application's properties. - * - * @param properties The configuration properties for the application, including the OpenRouter client settings. - * @return A [SingleLLMPromptExecutor] initialized with an OpenRouter LLM client. - */ - @Bean - @ConditionalOnProperty(prefix = OpenRouterKoogProperties.PREFIX, name = ["api-key"]) - public fun openRouterExecutor(properties: OpenRouterKoogProperties): SingleLLMPromptExecutor { - val client = OpenRouterLLMClient( - apiKey = properties.apiKey, - settings = OpenRouterClientSettings(baseUrl = properties.baseUrl) - ) - return SingleLLMPromptExecutor(getRetryingClientOrDefault(client, properties.retry)) - } - - /** - * Creates a [SingleLLMPromptExecutor] bean configured to use the DeepSeek LLM client. - * - * This method is only executed if the `deepseek.api-key` property is defined in the application's configuration. - * It initializes the DeepSeek client using the provided API key and base URL from the application's properties. - * - * @param properties The configuration properties for the application, including the DeepSeek client settings. - * @return A [SingleLLMPromptExecutor] initialized with an DeepSeek LLM client. 
- */ - @Bean - @ConditionalOnProperty(prefix = DeepSeekKoogProperties.PREFIX, name = ["api-key"]) - public fun deepSeekExecutor(properties: DeepSeekKoogProperties): SingleLLMPromptExecutor { - val client = DeepSeekLLMClient( - apiKey = properties.apiKey, - settings = DeepSeekClientSettings(baseUrl = properties.baseUrl) - ) - return SingleLLMPromptExecutor(getRetryingClientOrDefault(client, properties.retry)) - } - - private fun getRetryingClientOrDefault(client: LLMClient, properties: RetryConfigKoogProperties?): LLMClient { - return if (properties?.enabled == true) { - val defaultConfig = RetryConfig() - val retryConfig = RetryConfig( - maxAttempts = properties.maxAttempts ?: defaultConfig.maxAttempts, - initialDelay = properties.initialDelay?.toKotlinDuration() ?: defaultConfig.initialDelay, - maxDelay = properties.maxDelay?.toKotlinDuration() ?: defaultConfig.maxDelay, - backoffMultiplier = properties.backoffMultiplier ?: defaultConfig.backoffMultiplier, - jitterFactor = properties.jitterFactor ?: defaultConfig.jitterFactor - ) - RetryingLLMClient( - delegate = client, - config = retryConfig - ) - } else { - client - } - } -} diff --git a/koog-spring-boot-starter/src/main/kotlin/ai/koog/spring/OllamaKoogProperties.kt b/koog-spring-boot-starter/src/main/kotlin/ai/koog/spring/OllamaKoogProperties.kt deleted file mode 100644 index 18d23dfe59..0000000000 --- a/koog-spring-boot-starter/src/main/kotlin/ai/koog/spring/OllamaKoogProperties.kt +++ /dev/null @@ -1,29 +0,0 @@ -package ai.koog.spring - -import org.springframework.boot.context.properties.ConfigurationProperties - -/** - * Configuration properties for the Koog library used for integrating with Ollama LLM provider. - * These properties are used in conjunction with the [KoogAutoConfiguration] auto-configuration class to initialize and - * configure respective client implementations. - * - * Configuration prefix: `ai.koog.ollama` - * - * @param baseUrl The base URL of the provider's API endpoint. 
By default, it is set to `http://localhost:11434` - */ -@ConfigurationProperties(prefix = OllamaKoogProperties.PREFIX) -public class OllamaKoogProperties( - public val baseUrl: String = "http://localhost:11434", - public val retry: RetryConfigKoogProperties? = null -) { - /** - * Companion object for the OllamaKoogProperties class, providing constant values and - * utilities associated with the configuration of Ollama-related properties. - */ - public companion object Companion { - /** - * Prefix constant used for configuration Ollama-related properties in the Koog framework. - */ - public const val PREFIX: String = "ai.koog.ollama" - } -} diff --git a/koog-spring-boot-starter/src/main/kotlin/ai/koog/spring/OpenAIKoogProperties.kt b/koog-spring-boot-starter/src/main/kotlin/ai/koog/spring/OpenAIKoogProperties.kt deleted file mode 100644 index 8b5196ad6a..0000000000 --- a/koog-spring-boot-starter/src/main/kotlin/ai/koog/spring/OpenAIKoogProperties.kt +++ /dev/null @@ -1,31 +0,0 @@ -package ai.koog.spring - -import org.springframework.boot.context.properties.ConfigurationProperties - -/** - * Configuration properties for the Koog library used for integrating with OpenAI LLM provider. - * These properties are used in conjunction with the [KoogAutoConfiguration] auto-configuration class to initialize and - * configure respective client implementations. - * - * Configuration prefix: `ai.koog.openai` - * - * @param apiKey The API key used to authenticate requests to the provider's service - * @param baseUrl The base URL of the provider's API endpoint. By default, it is set to `https://api.openai.com` - */ -@ConfigurationProperties(prefix = OpenAIKoogProperties.PREFIX) -public class OpenAIKoogProperties( - public val apiKey: String = "", - public val baseUrl: String = "https://api.openai.com", - public val retry: RetryConfigKoogProperties? 
= null -) { - /** - * Companion object for the OpenAIKoogProperties class, providing constant values and - * utilities associated with the configuration of OpenAI-related properties. - */ - public companion object { - /** - * Prefix constant used for configuration OpenAI-related properties in the Koog framework. - */ - public const val PREFIX: String = "ai.koog.openai" - } -} diff --git a/koog-spring-boot-starter/src/main/kotlin/ai/koog/spring/OpenRouterKoogProperties.kt b/koog-spring-boot-starter/src/main/kotlin/ai/koog/spring/OpenRouterKoogProperties.kt deleted file mode 100644 index b2b5640ff8..0000000000 --- a/koog-spring-boot-starter/src/main/kotlin/ai/koog/spring/OpenRouterKoogProperties.kt +++ /dev/null @@ -1,31 +0,0 @@ -package ai.koog.spring - -import org.springframework.boot.context.properties.ConfigurationProperties - -/** - * Configuration properties for the Koog library used for integrating with OpenRouter LLM provider. - * These properties are used in conjunction with the [KoogAutoConfiguration] auto-configuration class to initialize and - * configure respective client implementations. - * - * Configuration prefix: `ai.koog.openrouter` - * - * @param apiKey The API key used to authenticate requests to the provider's service - * @param baseUrl The base URL of the provider's API endpoint. By default, it is set to `https://openrouter.ai` - */ -@ConfigurationProperties(prefix = OpenRouterKoogProperties.PREFIX) -public class OpenRouterKoogProperties( - public val apiKey: String = "", - public val baseUrl: String = "https://openrouter.ai", - public val retry: RetryConfigKoogProperties? = null -) { - /** - * Companion object for the OpenRouterKoogProperties class, providing constant values and - * utilities associated with the configuration of OpenRouter-related properties. - */ - public companion object Companion { - /** - * Prefix constant used for configuration OpenRouter-related properties in the Koog framework. 
- */ - public const val PREFIX: String = "ai.koog.openrouter" - } -} diff --git a/koog-spring-boot-starter/src/main/kotlin/ai/koog/spring/RetryConfigKoogProperties.kt b/koog-spring-boot-starter/src/main/kotlin/ai/koog/spring/RetryConfigKoogProperties.kt deleted file mode 100644 index abaac3b678..0000000000 --- a/koog-spring-boot-starter/src/main/kotlin/ai/koog/spring/RetryConfigKoogProperties.kt +++ /dev/null @@ -1,19 +0,0 @@ -package ai.koog.spring - -import org.springframework.boot.convert.DurationUnit -import java.time.Duration -import java.time.temporal.ChronoUnit - -/** - * Configuration properties for the Koog library used for LLM clients retry configuration. - */ -public class RetryConfigKoogProperties( - public val enabled: Boolean = false, - public val maxAttempts: Int? = null, - @param:DurationUnit(ChronoUnit.SECONDS) - public val initialDelay: Duration? = null, - @param:DurationUnit(ChronoUnit.SECONDS) - public val maxDelay: Duration? = null, - public val backoffMultiplier: Double? = null, - public val jitterFactor: Double? 
= null -) diff --git a/koog-spring-boot-starter/src/main/kotlin/ai/koog/spring/conditions/OnPropertyNotEmptyCondition.kt b/koog-spring-boot-starter/src/main/kotlin/ai/koog/spring/conditions/OnPropertyNotEmptyCondition.kt new file mode 100644 index 0000000000..665e045494 --- /dev/null +++ b/koog-spring-boot-starter/src/main/kotlin/ai/koog/spring/conditions/OnPropertyNotEmptyCondition.kt @@ -0,0 +1,45 @@ +package ai.koog.spring.conditions + +import org.springframework.boot.autoconfigure.condition.ConditionOutcome +import org.springframework.boot.autoconfigure.condition.SpringBootCondition +import org.springframework.context.annotation.ConditionContext +import org.springframework.context.annotation.Conditional +import org.springframework.core.type.AnnotatedTypeMetadata +import kotlin.reflect.jvm.jvmName + +@Target(AnnotationTarget.CLASS, AnnotationTarget.FUNCTION) +@Retention(AnnotationRetention.RUNTIME) +@Conditional(OnPropertyNotEmptyCondition::class) +public annotation class ConditionalOnPropertyNotEmpty( + val prefix: String = "", + val name: String +) + +public class OnPropertyNotEmptyCondition : SpringBootCondition() { + + override fun getMatchOutcome( + context: ConditionContext, + metadata: AnnotatedTypeMetadata + ): ConditionOutcome { + val attributes = metadata.getAllAnnotationAttributes( + ConditionalOnPropertyNotEmpty::class.jvmName, + ) + + val prefix = attributes?.get("prefix")?.firstOrNull() as? String ?: "" + val name = attributes?.get("name")?.firstOrNull() as? 
String ?: "" + + val propertyKey = if (prefix.isNotEmpty()) { + "$prefix.$name" + } else { + name + } + + val value = context.environment.getProperty(propertyKey) + + return if (!value.isNullOrEmpty()) { + ConditionOutcome.match("Property '$propertyKey' has non-empty value") + } else { + ConditionOutcome.noMatch("Property '$propertyKey' is missing or empty") + } + } +} diff --git a/koog-spring-boot-starter/src/main/kotlin/ai/koog/spring/prompt/executor/clients/KoogLlmClientProperties.kt b/koog-spring-boot-starter/src/main/kotlin/ai/koog/spring/prompt/executor/clients/KoogLlmClientProperties.kt new file mode 100644 index 0000000000..a48069c4b4 --- /dev/null +++ b/koog-spring-boot-starter/src/main/kotlin/ai/koog/spring/prompt/executor/clients/KoogLlmClientProperties.kt @@ -0,0 +1,19 @@ +package ai.koog.spring.prompt.executor.clients + +import ai.koog.spring.RetryConfigKoogProperties + +/** + * Interface representing configuration properties for a Koog LLM Client. + * + * This interface is intended to provide the necessary configuration required to set up a LLM Client. + * It includes options for enabling the client, specifying the base URL, and defining retry configurations. + * + * @param enabled Indicates whether the LLM client is enabled. + * @param baseUrl Specifies the base URL for the LLM client. + * @param retry Defines the retry configuration for the LLM client using [RetryConfigKoogProperties]. + */ +public interface KoogLlmClientProperties { + public val enabled: Boolean + public val baseUrl: String + public val retry: RetryConfigKoogProperties? 
+} diff --git a/koog-spring-boot-starter/src/main/kotlin/ai/koog/spring/prompt/executor/clients/RetryConfigKoogProperties.kt b/koog-spring-boot-starter/src/main/kotlin/ai/koog/spring/prompt/executor/clients/RetryConfigKoogProperties.kt new file mode 100644 index 0000000000..09ff05dcca --- /dev/null +++ b/koog-spring-boot-starter/src/main/kotlin/ai/koog/spring/prompt/executor/clients/RetryConfigKoogProperties.kt @@ -0,0 +1,29 @@ +package ai.koog.spring + +import org.springframework.boot.convert.DurationUnit +import java.time.Duration +import java.time.temporal.ChronoUnit + +/** + * Represents configuration properties for retry mechanisms associated with clients. + * + * This class is used to define retry behavior, including specifications such as the number of retry + * attempts, delay between attempts, and mechanisms to control backoff strategy and randomness in delays. + * + * @property enabled Indicates whether retries are enabled. + * @property maxAttempts Specifies the maximum number of retry attempts. + * @property initialDelay Specifies the initial delay before the first retry attempt, in seconds. + * @property maxDelay Specifies the maximum delay allowed between retry attempts, in seconds. + * @property backoffMultiplier Defines the multiplier to apply to the delay for each subsequent retry attempt. + * @property jitterFactor Specifies the factor to introduce randomness in the delay calculations to avoid symmetric loads. + */ +public class RetryConfigKoogProperties( + public val enabled: Boolean = false, + public val maxAttempts: Int? = null, + @param:DurationUnit(ChronoUnit.SECONDS) + public val initialDelay: Duration? = null, + @param:DurationUnit(ChronoUnit.SECONDS) + public val maxDelay: Duration? = null, + public val backoffMultiplier: Double? = null, + public val jitterFactor: Double? 
= null +) diff --git a/koog-spring-boot-starter/src/main/kotlin/ai/koog/spring/prompt/executor/clients/anthropic/AnthropicKoogProperties.kt b/koog-spring-boot-starter/src/main/kotlin/ai/koog/spring/prompt/executor/clients/anthropic/AnthropicKoogProperties.kt new file mode 100644 index 0000000000..07ad9d507f --- /dev/null +++ b/koog-spring-boot-starter/src/main/kotlin/ai/koog/spring/prompt/executor/clients/anthropic/AnthropicKoogProperties.kt @@ -0,0 +1,49 @@ +package ai.koog.spring.prompt.executor.clients.anthropic + +import ai.koog.spring.RetryConfigKoogProperties +import ai.koog.spring.prompt.executor.clients.KoogLlmClientProperties +import org.springframework.boot.context.properties.ConfigurationProperties + +/** + * Configuration properties for configuring Anthropic-related clients in the Koog framework. + * + * This class allows defining settings necessary for integrating with the Anthropic LLM (Large Language Model) + * client. It implements [KoogLlmClientProperties] and includes common LLM client configurations such as `enabled`, + * `baseUrl`, and retry options. Additionally, it includes the `apiKey` property specific to the Anthropic client. + * + * The properties are bound to the configuration prefix defined by [AnthropicKoogProperties.PREFIX], which is + * `ai.koog.anthropic`. This allows configuring the client via property files in a Spring Boot application. + * + * @property enabled Indicates whether the Anthropic client is enabled. If `false`, the client will not be configured. + * @property apiKey The API key used to authenticate requests to the Anthropic API. + * @property baseUrl The base URL of the Anthropic API for sending requests. + * @property retry Retry configuration for the client in case of failed or timeout requests. This is optional. 
+ */ +@ConfigurationProperties(prefix = AnthropicKoogProperties.PREFIX, ignoreUnknownFields = true) +public class AnthropicKoogProperties( + public override val enabled: Boolean, + public val apiKey: String, + public override val baseUrl: String, + public override val retry: RetryConfigKoogProperties? = null +) : KoogLlmClientProperties { + /** + * Companion object providing constant value associated with the configuration of Anthropic-related properties. + */ + public companion object Companion { + /** + * Prefix constant used for configuration Anthropic-related properties in the Koog framework. + */ + public const val PREFIX: String = "ai.koog.anthropic" + } + + /** + * Returns a string representation of the `AnthropicKoogProperties` object. + * + * The representation includes the state of its properties: `enabled`, `apiKey`(masked), `baseUrl`, and `retry`. + * + * @return A string containing the property values of the `AnthropicKoogProperties` object. + */ + override fun toString(): String { + return "AnthropicKoogProperties(enabled=$enabled, apiKey='$apiKey', baseUrl='$baseUrl', retry=$retry)" + } +} diff --git a/koog-spring-boot-starter/src/main/kotlin/ai/koog/spring/prompt/executor/clients/anthropic/AnthropicLLMAutoConfiguration.kt b/koog-spring-boot-starter/src/main/kotlin/ai/koog/spring/prompt/executor/clients/anthropic/AnthropicLLMAutoConfiguration.kt new file mode 100644 index 0000000000..f2bd1f164a --- /dev/null +++ b/koog-spring-boot-starter/src/main/kotlin/ai/koog/spring/prompt/executor/clients/anthropic/AnthropicLLMAutoConfiguration.kt @@ -0,0 +1,84 @@ +package ai.koog.spring.prompt.executor.clients.anthropic + +import ai.koog.prompt.executor.clients.anthropic.AnthropicClientSettings +import ai.koog.prompt.executor.clients.anthropic.AnthropicLLMClient +import ai.koog.prompt.executor.llms.SingleLLMPromptExecutor +import ai.koog.spring.conditions.ConditionalOnPropertyNotEmpty +import ai.koog.spring.prompt.executor.clients.toRetryingClient +import 
org.slf4j.LoggerFactory +import org.springframework.boot.autoconfigure.AutoConfiguration +import org.springframework.boot.autoconfigure.condition.ConditionalOnBean +import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty +import org.springframework.boot.context.properties.EnableConfigurationProperties +import org.springframework.context.annotation.Bean +import org.springframework.context.annotation.PropertySource + +/** + * Auto-configuration class for Anthropic LLM integration in a Spring Boot application. + * + * This class automatically configures the required beans for interacting with the Anthropic LLM + * when the appropriate configuration properties are set in the application. It specifically checks + * for the presence of the `ai.koog.anthropic.enabled` and `ai.koog.anthropic.api-key` properties. + * + * Beans provided by this configuration: + * - [AnthropicLLMClient]: Configured client for interacting with the Anthropic API. + * - [SingleLLMPromptExecutor]: Prompt executor that utilizes the configured Anthropic client. + * + * To enable this configuration, the `ai.koog.anthropic.enabled` property must be set to `true` + * and a valid `ai.koog.anthropic.api-key` must be provided in the application's property files. + * + * This configuration reads additional properties imported via `spring.config.import` from the starter's + * application.properties file and binds them to the [AnthropicKoogProperties]. + * + * @property properties Anthropic-specific configuration properties, automatically injected by Spring's + * configuration properties mechanism. 
+ */ +@AutoConfiguration +@PropertySource("classpath:/META-INF/config/koog/anthropic-llm.properties") +@EnableConfigurationProperties( + AnthropicKoogProperties::class, +) +public class AnthropicLLMAutoConfiguration( + private val properties: AnthropicKoogProperties +) { + + private val logger = LoggerFactory.getLogger(AnthropicLLMAutoConfiguration::class.java) + + /** + * Creates an [AnthropicLLMClient] bean configured with application properties. + * + * This method initializes a [AnthropicLLMClient] using the API key and base URL + * specified in the application's configuration. It is only executed if the + * `koog.ai.anthropic.api-key` property is defined and `koog.ai.anthropic.enabled` property is set + * to `true` in the application configuration. + * + * @return An [AnthropicLLMClient] instance configured with the provided settings. + */ + @Bean + @ConditionalOnProperty(prefix = AnthropicKoogProperties.PREFIX, name = ["enabled"], havingValue = "true") + @ConditionalOnPropertyNotEmpty( + prefix = AnthropicKoogProperties.PREFIX, + name = "api-key" + ) + public fun anthropicLLMClient(): AnthropicLLMClient { + logger.info("Creating AnthropicLLMClient with baseUrl=${properties.baseUrl}") + return AnthropicLLMClient( + apiKey = properties.apiKey, + settings = AnthropicClientSettings(baseUrl = properties.baseUrl) + ) + } + + /** + * Creates and initializes a [SingleLLMPromptExecutor] instance using an [AnthropicLLMClient]. + * The executor is configured with a retrying client derived from the provided AnthropicLLMClient. + * + * @param client An instance of [AnthropicLLMClient] used to communicate with the Anthropic LLM API. + * @return An instance of [SingleLLMPromptExecutor] for sending prompts to the Anthropic LLM API. 
+ */ + @Bean + @ConditionalOnBean(AnthropicLLMClient::class) + public fun anthropicExecutor(client: AnthropicLLMClient): SingleLLMPromptExecutor { + logger.info("Creating SingleLLMPromptExecutor (anthropicExecutor) for AnthropicLLMClient") + return SingleLLMPromptExecutor(client.toRetryingClient(properties.retry)) + } +} diff --git a/koog-spring-boot-starter/src/main/kotlin/ai/koog/spring/prompt/executor/clients/deepseek/DeepSeekKoogProperties.kt b/koog-spring-boot-starter/src/main/kotlin/ai/koog/spring/prompt/executor/clients/deepseek/DeepSeekKoogProperties.kt new file mode 100644 index 0000000000..0e9060bdc4 --- /dev/null +++ b/koog-spring-boot-starter/src/main/kotlin/ai/koog/spring/prompt/executor/clients/deepseek/DeepSeekKoogProperties.kt @@ -0,0 +1,55 @@ +package ai.koog.spring.prompt.executor.clients.deepseek + +import ai.koog.spring.RetryConfigKoogProperties +import ai.koog.spring.prompt.executor.clients.KoogLlmClientProperties +import ai.koog.spring.prompt.executor.clients.deepseek.DeepSeekKoogProperties.Companion.PREFIX +import ai.koog.utils.lang.masked +import org.springframework.boot.context.properties.ConfigurationProperties + +/** + * Configuration properties class for DeepSeek LLM provider integration within the Koog framework. + * + * This class is used to define and manage application-level configuration parameters for connecting + * to the DeepSeek provider. It includes properties such as API key, base URL, and optional retry settings. + * + * The properties are auto-configured via Spring Boot's `@ConfigurationProperties` using the `ai.koog.deepseek` prefix. + * + * Implements the [KoogLlmClientProperties] interface, which provides base attributes for all LLM client property configurations. + * + * Properties from this class are typically consumed by auto-configuration classes, such as [DeepSeekLLMAutoConfiguration], + * to initialize and configure the necessary beans for working with the DeepSeek API. 
+ * + * @property enabled Indicates whether DeepSeek API integration is enabled (true or false). + * @property apiKey An API key string required to authenticate requests to the DeepSeek external service. + * @property baseUrl The base URL endpoint for DeepSeek API calls. + * @property retry Optional retry configuration for API requests, represented by [RetryConfigKoogProperties]. + */ +@ConfigurationProperties(prefix = PREFIX, ignoreUnknownFields = true) +public class DeepSeekKoogProperties( + public override val enabled: Boolean, + public val apiKey: String, + public override val baseUrl: String, + public override val retry: RetryConfigKoogProperties? = null +) : KoogLlmClientProperties { + /** + * Companion object for the DeepSeekKoogProperties class, providing constant values and + * utilities associated with the configuration of DeepSeek-related properties. + */ + public companion object Companion { + /** + * Prefix constant used for configuration DeepSeek-related properties in the Koog framework. + */ + public const val PREFIX: String = "ai.koog.deepseek" + } + + /** + * Returns a string representation of the DeepSeekKoogProperties object. + * The string includes information about the `enabled` status, a masked representation of the `apiKey`, + * the `baseUrl`, and the `retry` configuration. + * + * @return A string summarizing the current configuration properties. 
+ */ + override fun toString(): String { + return "DeepSeekKoogProperties(enabled=$enabled, apiKey='${apiKey.masked()}', baseUrl='$baseUrl', retry=$retry)" + } +} diff --git a/koog-spring-boot-starter/src/main/kotlin/ai/koog/spring/prompt/executor/clients/deepseek/DeepSeekLLMAutoConfiguration.kt b/koog-spring-boot-starter/src/main/kotlin/ai/koog/spring/prompt/executor/clients/deepseek/DeepSeekLLMAutoConfiguration.kt new file mode 100644 index 0000000000..9026740baa --- /dev/null +++ b/koog-spring-boot-starter/src/main/kotlin/ai/koog/spring/prompt/executor/clients/deepseek/DeepSeekLLMAutoConfiguration.kt @@ -0,0 +1,92 @@ +package ai.koog.spring.prompt.executor.clients.deepseek + +import ai.koog.prompt.executor.clients.deepseek.DeepSeekClientSettings +import ai.koog.prompt.executor.clients.deepseek.DeepSeekLLMClient +import ai.koog.prompt.executor.llms.SingleLLMPromptExecutor +import ai.koog.spring.conditions.ConditionalOnPropertyNotEmpty +import ai.koog.spring.prompt.executor.clients.toRetryingClient +import org.slf4j.LoggerFactory +import org.springframework.boot.autoconfigure.AutoConfiguration +import org.springframework.boot.autoconfigure.condition.ConditionalOnBean +import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty +import org.springframework.boot.context.properties.EnableConfigurationProperties +import org.springframework.context.annotation.Bean +import org.springframework.context.annotation.PropertySource + +/** + * Auto-configuration class for integrating with the DeepSeek LLM provider within a Spring Boot application. + * + * This configuration enables the auto-wiring of required beans when the appropriate application + * properties are set. The configuration ensures that the [DeepSeekLLMClient] is properly initialized + * and available for usage in the application. 
+ * + * The following conditions must be met for this configuration to be activated: + * - The property `ai.koog.deepseek.api-key` must be defined in the application configuration. + * - The property `ai.koog.deepseek.enabled` must have a value of `true`. + * + * Properties used: + * - `ai.koog.deepseek.api-key`: API key required to authenticate requests to the DeepSeek API. + * - `ai.koog.deepseek.base-url`: Base URL of the DeepSeek API, with a default value of `https://api.deepseek.com`. + * - `ai.koog.deepseek.retry`: Retry configuration settings for failed requests. + * + * Beans provided: + * - [DeepSeekLLMClient]: A client for interacting with the DeepSeek API. + * - [SingleLLMPromptExecutor]: A bean for executing single-step LLM prompts using the DeepSeek client. + * + * @property properties [DeepSeekKoogProperties] containing the configuration properties for the DeepSeek client. + * + * @see DeepSeekKoogProperties + * @see DeepSeekLLMClient + * @see SingleLLMPromptExecutor + */ +@AutoConfiguration +@PropertySource("classpath:/META-INF/config/koog/deepseek-llm.properties") +@EnableConfigurationProperties( + DeepSeekKoogProperties::class, +) +public class DeepSeekLLMAutoConfiguration( + private val properties: DeepSeekKoogProperties +) { + + private val logger = LoggerFactory.getLogger(DeepSeekLLMAutoConfiguration::class.java) + + /** + * Creates a [DeepSeekLLMClient] bean configured with application properties. + * + * This method initializes a [DeepSeekLLMClient] using the API key and base URL + * specified in the application's configuration. It is only executed if the + * `koog.ai.deepseek.api-key` property is defined and `koog.ai.deepseek.enabled` property is set + * to `true` in the application configuration. + * + * @return A [DeepSeekLLMClient] instance configured with the provided settings. 
+ */ + @Bean + @ConditionalOnPropertyNotEmpty( + prefix = DeepSeekKoogProperties.PREFIX, + name = "api-key" + ) + @ConditionalOnProperty(prefix = DeepSeekKoogProperties.PREFIX, name = ["enabled"], havingValue = "true") + public fun deepSeekLLMClient(): DeepSeekLLMClient { + logger.info("Creating DeepSeekLLMClient with baseUrl=${properties.baseUrl}") + return DeepSeekLLMClient( + apiKey = properties.apiKey, + settings = DeepSeekClientSettings(baseUrl = properties.baseUrl) + ) + } + + /** + * Creates a [SingleLLMPromptExecutor] bean configured to use the DeepSeek LLM client. + * + * This bean is only created when a [DeepSeekLLMClient] bean is present in the application context. + * The executor wraps the client in a retrying client configured from the application's retry settings. + * + * @param client The [DeepSeekLLMClient] used to communicate with the DeepSeek LLM API. + * @return A [SingleLLMPromptExecutor] initialized with a DeepSeek LLM client. 
+ */ + @Bean + @ConditionalOnBean(DeepSeekLLMClient::class) + public fun deepSeekExecutor(client: DeepSeekLLMClient): SingleLLMPromptExecutor { + logger.info("Creating SingleLLMPromptExecutor (deepSeekExecutor) for DeepSeekLLMClient") + return SingleLLMPromptExecutor(client.toRetryingClient(properties.retry)) + } +} diff --git a/koog-spring-boot-starter/src/main/kotlin/ai/koog/spring/prompt/executor/clients/google/GoogleKoogProperties.kt b/koog-spring-boot-starter/src/main/kotlin/ai/koog/spring/prompt/executor/clients/google/GoogleKoogProperties.kt new file mode 100644 index 0000000000..cf8610b7fc --- /dev/null +++ b/koog-spring-boot-starter/src/main/kotlin/ai/koog/spring/prompt/executor/clients/google/GoogleKoogProperties.kt @@ -0,0 +1,83 @@ +package ai.koog.spring.prompt.executor.clients.google + +import ai.koog.spring.RetryConfigKoogProperties +import ai.koog.spring.prompt.executor.clients.KoogLlmClientProperties +import ai.koog.utils.lang.masked +import org.springframework.boot.context.properties.ConfigurationProperties + +/** + * Configuration properties for integrating with Google's LLM services in the Koog framework. + * + * This class provides the necessary settings for enabling and configuring access + * to Google's LLM API, including authentication and retry behavior. + * The configuration is mapped using the prefix `ai.koog.google`. + * + * Parameters: + * @param enabled Indicates whether the Google LLM integration is enabled. + * @param apiKey The API key used for authenticating requests to Google's services. + * @param baseUrl The base URL of the Google LLM API. + * @param retry Optional configuration for retrying failed API calls. + * + * Usage: + * These properties are automatically bound to the Spring environment when specified + * in application configuration (e.g., `application.yml` or `application.properties`). 
+ * + * Example configuration snippet in `application.yml` or `application.properties`: + * ```properties + * ai.koog.google.enabled=true + * ai.koog.google.api-key=your-google-api-key + * ai.koog.google.base-url=https://api.google.com/llm + * ai.koog.google.retry.enabled=true + * ai.koog.google.retry.max-attempts=3 + * ai.koog.google.retry.initial-delay=2s + * ai.koog.google.retry.max-delay=10s + * ai.koog.google.retry.backoff-multiplier=2.0 + * ai.koog.google.retry.jitter-factor=0.5 + * ``` + * + * Advanced Features: + * - The retry configuration supports customizable retries for handling transient failures. + * - Dedicated masking utility ensures that sensitive information, such as the API key, is + * not exposed when serialized or logged. + * + * This class is primarily used in conjunction with the `GoogleLLMAutoConfiguration` auto-configuration + * class to initialize and configure the necessary beans for interacting with Google's LLM API. + * + * For more details on retry behavior, refer to the `RetryConfigKoogProperties` class. + * For shared configuration attributes, see the `KoogLlmClientProperties` interface. + * + * @property enabled Enables or disables the Google LLM integration. + * @property apiKey The key required for authenticating with the API. + * @property baseUrl URL endpoint for the Google LLM API. + * @property retry Defines retry behavior, such as maximum attempts and delays between retries. + */ +@ConfigurationProperties(prefix = GoogleKoogProperties.PREFIX, ignoreUnknownFields = true) +public class GoogleKoogProperties( + public override val enabled: Boolean, + public val apiKey: String, + public override val baseUrl: String, + public override val retry: RetryConfigKoogProperties? = null +) : KoogLlmClientProperties { + /** + * Companion object for the GoogleKoogProperties class, providing constant values and + * utilities associated with the configuration of Google-related properties. 
+ */ + public companion object Companion { + /** + * Prefix constant used for configuration Google-related properties in the Koog framework. + */ + public const val PREFIX: String = "ai.koog.google" + } + + /** + * Returns a string representation of the GoogleKoogProperties object. + * + * The resulting string includes details about the object's properties such as + * `enabled`, `apiKey` (with sensitive information masked), `baseUrl`, and `retry`. + * + * @return A string representation of the GoogleKoogProperties object. + */ + override fun toString(): String { + return "GoogleKoogProperties(enabled=$enabled, apiKey='${apiKey.masked()}', baseUrl='$baseUrl', retry=$retry)" + } +} diff --git a/koog-spring-boot-starter/src/main/kotlin/ai/koog/spring/prompt/executor/clients/google/GoogleLLMAutoConfiguration.kt b/koog-spring-boot-starter/src/main/kotlin/ai/koog/spring/prompt/executor/clients/google/GoogleLLMAutoConfiguration.kt new file mode 100644 index 0000000000..bcaa5c7cc8 --- /dev/null +++ b/koog-spring-boot-starter/src/main/kotlin/ai/koog/spring/prompt/executor/clients/google/GoogleLLMAutoConfiguration.kt @@ -0,0 +1,88 @@ +package ai.koog.spring.prompt.executor.clients.google + +import ai.koog.prompt.executor.clients.google.GoogleClientSettings +import ai.koog.prompt.executor.clients.google.GoogleLLMClient +import ai.koog.prompt.executor.llms.SingleLLMPromptExecutor +import ai.koog.spring.conditions.ConditionalOnPropertyNotEmpty +import ai.koog.spring.prompt.executor.clients.ollama.OllamaKoogProperties +import ai.koog.spring.prompt.executor.clients.toRetryingClient +import org.slf4j.LoggerFactory +import org.springframework.boot.autoconfigure.AutoConfiguration +import org.springframework.boot.autoconfigure.condition.ConditionalOnBean +import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty +import org.springframework.boot.context.properties.EnableConfigurationProperties +import org.springframework.context.annotation.Bean +import 
org.springframework.context.annotation.PropertySource + +/** + * Provides the auto-configuration for integrating with Google LLM via the Koog framework. + * This class is responsible for initializing and configuring the necessary beans for interacting + * with Google's APIs, based on the configurations supplied via `GoogleKoogProperties`. + * + * The configuration is activated only when the property `ai.koog.google.enabled` is set to `true`, + * and an `api-key` is provided. + * + * Beans configured by this class: + * - [GoogleLLMClient]: A client for interacting with Google's LLM API, using the specified API key and settings. + * - [SingleLLMPromptExecutor]: An executor capable of handling and retrying LLM prompts, using the initialized client. + * + * An external configuration file at `classpath:/META-INF/config/koog/google-llm.properties` is leveraged + * for managing default settings. + * + * @property properties [GoogleKoogProperties] to define key settings such as API key, base URL, and retry configurations. + * + * @see GoogleKoogProperties + * @see GoogleLLMClient + * @see SingleLLMPromptExecutor + */ +@AutoConfiguration +@PropertySource("classpath:/META-INF/config/koog/google-llm.properties") +@EnableConfigurationProperties( + GoogleKoogProperties::class, +) +public class GoogleLLMAutoConfiguration( + private val properties: GoogleKoogProperties +) { + + private val logger = LoggerFactory.getLogger(GoogleLLMAutoConfiguration::class.java) + + /** + * Creates a [GoogleLLMClient] bean configured with application properties. + * + * This method initializes a [GoogleLLMClient] using the API key and base URL + * specified in the application's configuration. It is only executed if the + * `koog.ai.google.api-key` property is defined and `koog.ai.google.enabled` property is set + * to `true` in the application configuration. + * + * @return A [GoogleLLMClient] instance configured with the provided settings. 
+ */ + @Bean + @ConditionalOnPropertyNotEmpty( + prefix = GoogleKoogProperties.PREFIX, + name = "api-key" + ) + @ConditionalOnProperty(prefix = GoogleKoogProperties.PREFIX, name = ["enabled"], havingValue = "true") + public fun googleLLMClient(): GoogleLLMClient { + logger.info("Creating GoogleLLMClient with baseUrl=${properties.baseUrl}") + return GoogleLLMClient( + apiKey = properties.apiKey, + settings = GoogleClientSettings(baseUrl = properties.baseUrl) + ) + } + + /** + * Creates and configures a [SingleLLMPromptExecutor] instance using the given [GoogleLLMClient]. + * + * The executor wraps the provided [GoogleLLMClient] in a retrying client configured from + * the application's retry settings and uses it to construct the [SingleLLMPromptExecutor]. + * + * @param client the [GoogleLLMClient] used to communicate with the Google LLM API. + * @return a [SingleLLMPromptExecutor] configured to use the Google client. + */ + @Bean + @ConditionalOnBean(GoogleLLMClient::class) + public fun googleExecutor(client: GoogleLLMClient): SingleLLMPromptExecutor { + logger.info("Creating SingleLLMPromptExecutor (googleExecutor) for GoogleLLMClient") + return SingleLLMPromptExecutor(client.toRetryingClient(properties.retry)) + } +} diff --git a/koog-spring-boot-starter/src/main/kotlin/ai/koog/spring/prompt/executor/clients/ollama/OllamaKoogProperties.kt b/koog-spring-boot-starter/src/main/kotlin/ai/koog/spring/prompt/executor/clients/ollama/OllamaKoogProperties.kt new file mode 100644 index 0000000000..4e862900c9 --- /dev/null +++ b/koog-spring-boot-starter/src/main/kotlin/ai/koog/spring/prompt/executor/clients/ollama/OllamaKoogProperties.kt @@ -0,0 +1,53 @@ +package ai.koog.spring.prompt.executor.clients.ollama + +import ai.koog.spring.RetryConfigKoogProperties +import ai.koog.spring.prompt.executor.clients.KoogLlmClientProperties +import org.springframework.boot.context.properties.ConfigurationProperties + +/** + * Configuration properties for the Ollama 
integration in the Koog framework. + * + * This class defines properties that control the connection and behavior when interacting + * with the Ollama Large Language Model (LLM) service. + * + * These properties are typically configured in the application properties or YAML file. + * + * The configuration prefix for these properties is defined as `ai.koog.ollama`. + * It is used to map these properties in the application configuration file. + * + * This class is designed to work along with the `OllamaLLMAutoConfiguration` class to + * automatically initialize and configure the required beans for the Ollama client and executor. + * + * @property enabled Indicates whether the Ollama integration is enabled. + * @property baseUrl The URL of the API endpoint for Ollama service. + * @property retry The retry settings for handling request failures. + */ +@ConfigurationProperties(prefix = OllamaKoogProperties.PREFIX, ignoreUnknownFields = true) +public class OllamaKoogProperties( + public override val enabled: Boolean, + public override val baseUrl: String, + public override val retry: RetryConfigKoogProperties? = null +) : KoogLlmClientProperties { + /** + * Companion object for the OllamaKoogProperties class, providing constant values and + * utilities associated with the configuration of Ollama-related properties. + */ + public companion object Companion { + /** + * Prefix constant used for configuration Ollama-related properties in the Koog framework. + */ + public const val PREFIX: String = "ai.koog.ollama" + } + + /** + * Returns a string representation of the `OllamaKoogProperties` object. + * + * The string includes values for the `enabled`, `baseUrl`, and `retry` properties + * to provide a comprehensive overview of the configuration state of the object. + * + * @return a string describing the current state of the `OllamaKoogProperties` instance. 
+ */ + override fun toString(): String { + return "OllamaKoogProperties(enabled=$enabled, baseUrl='$baseUrl', retry=$retry)" + } +} diff --git a/koog-spring-boot-starter/src/main/kotlin/ai/koog/spring/prompt/executor/clients/ollama/OllamaLLMAutoConfiguration.kt b/koog-spring-boot-starter/src/main/kotlin/ai/koog/spring/prompt/executor/clients/ollama/OllamaLLMAutoConfiguration.kt new file mode 100644 index 0000000000..212c7a3951 --- /dev/null +++ b/koog-spring-boot-starter/src/main/kotlin/ai/koog/spring/prompt/executor/clients/ollama/OllamaLLMAutoConfiguration.kt @@ -0,0 +1,77 @@ +package ai.koog.spring.prompt.executor.clients.ollama + +import ai.koog.prompt.executor.llms.SingleLLMPromptExecutor +import ai.koog.prompt.executor.ollama.client.OllamaClient +import ai.koog.spring.prompt.executor.clients.toRetryingClient +import org.slf4j.LoggerFactory +import org.springframework.boot.autoconfigure.AutoConfiguration +import org.springframework.boot.autoconfigure.condition.ConditionalOnBean +import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty +import org.springframework.boot.context.properties.EnableConfigurationProperties +import org.springframework.context.annotation.Bean +import org.springframework.context.annotation.PropertySource + +/** + * Auto-configuration class for integrating the Ollama Large Language Model (LLM) service into applications. + * + * This configuration initializes and provides the necessary beans to enable interaction with the Ollama LLM API. + * It relies on properties defined in the [OllamaKoogProperties] class to set up the service. + * + * The configuration is conditional and will only be initialized if: + * - [OllamaKoogProperties.enabled] is set to `true`. + * - The required [OllamaKoogProperties] are provided in the application configuration. + * + * Initializes the following beans: + * - [OllamaClient]: A client for interacting with the Ollama LLM service. 
+ * - [SingleLLMPromptExecutor]: Executes single-prompt interactions with Ollama, utilizing the client. + * + * This configuration allows seamless integration with the Ollama API while enabling properties-based customization. + * + * @property properties [OllamaKoogProperties] to define key settings such as API key, base URL, and retry configurations. + * @see OllamaKoogProperties + * @see OllamaClient + * @see SingleLLMPromptExecutor + */ +@AutoConfiguration +@PropertySource("classpath:/META-INF/config/koog/ollama-llm.properties") +@EnableConfigurationProperties( + OllamaKoogProperties::class, +) +public class OllamaLLMAutoConfiguration( + private val properties: OllamaKoogProperties +) { + + private val logger = LoggerFactory.getLogger(OllamaLLMAutoConfiguration::class.java) + + /** + * Creates an [OllamaClient] bean configured with application properties. + * + * This method initializes a [OllamaClient] using the API key and base URL + * specified in the application's configuration. It is only executed if the + * `koog.ai.ollama.enabled` property is set to `true` in the application configuration. + * + * @return An [OllamaClient] instance configured with the provided settings. + */ + @Bean + @ConditionalOnProperty(prefix = OllamaKoogProperties.PREFIX, name = ["enabled"], havingValue = "true") + public fun ollamaLLMClient(): OllamaClient { + logger.info("Creating OllamaClient with baseUrl=${properties.baseUrl}") + return OllamaClient( + baseUrl = properties.baseUrl, + ) + } + + /** + * Creates and configures an instance of [SingleLLMPromptExecutor] that wraps the provided [OllamaClient]. + * The configured executor includes retry capabilities based on the application's properties. + * + * @param client the [OllamaClient] instance used for communicating with the Ollama LLM service. + * @return a [SingleLLMPromptExecutor] configured to execute LLM prompts with the provided client. 
+ */ + @Bean + @ConditionalOnBean(OllamaClient::class) + public fun ollamaExecutor(client: OllamaClient): SingleLLMPromptExecutor { + logger.info("Creating SingleLLMPromptExecutor (ollamaExecutor) for OllamaClient") + return SingleLLMPromptExecutor(client.toRetryingClient(properties.retry)) + } +} diff --git a/koog-spring-boot-starter/src/main/kotlin/ai/koog/spring/prompt/executor/clients/openai/OpenAIKoogProperties.kt b/koog-spring-boot-starter/src/main/kotlin/ai/koog/spring/prompt/executor/clients/openai/OpenAIKoogProperties.kt new file mode 100644 index 0000000000..0d9bbe7069 --- /dev/null +++ b/koog-spring-boot-starter/src/main/kotlin/ai/koog/spring/prompt/executor/clients/openai/OpenAIKoogProperties.kt @@ -0,0 +1,51 @@ +package ai.koog.spring.prompt.executor.clients.openai + +import ai.koog.spring.RetryConfigKoogProperties +import ai.koog.spring.prompt.executor.clients.KoogLlmClientProperties +import ai.koog.utils.lang.masked +import org.springframework.boot.context.properties.ConfigurationProperties + +/** + * Configuration class for OpenAI settings in the Koog framework. + * This class defines properties required to configure and use OpenAI-related services, such as API keys, + * base URLs for the services, and optional retry configurations. + * + * The class is annotated with `@ConfigurationProperties` to bind its fields to configuration file properties + * prefixed with `ai.koog.openai`. + * + * @property enabled Determines if the OpenAI client is enabled. + * @property apiKey The API key required to authenticate requests to the OpenAI API. + * @property baseUrl The base URL for accessing the OpenAI API. + * @property retry Optional retry configuration settings, such as maximum attempts, delays, and backoff strategies. + * + * This configuration is used in `OpenAILLMAutoConfiguration` to set up the OpenAI client and related beans. 
+ */ +@ConfigurationProperties(prefix = OpenAIKoogProperties.PREFIX, ignoreUnknownFields = true) +public class OpenAIKoogProperties( + public override val enabled: Boolean, + public val apiKey: String, + public override val baseUrl: String, + public override val retry: RetryConfigKoogProperties? = null +) : KoogLlmClientProperties { + /** + * Companion object for the OpenAIKoogProperties class, providing constant values and + * utilities associated with the configuration of OpenAI-related properties. + */ + public companion object { + /** + * Prefix constant used for configuration OpenAI-related properties in the Koog framework. + */ + public const val PREFIX: String = "ai.koog.openai" + } + + /** + * Returns a string representation of the `OpenAIKoogProperties` object. + * The string includes details about the `enabled` status, masked `apiKey`, + * `baseUrl`, and `retry` configuration. + * + * @return A string summarizing the `OpenAIKoogProperties` object's state. + */ + override fun toString(): String { + return "OpenAIKoogProperties(enabled=$enabled, apiKey='${apiKey.masked()}', baseUrl='$baseUrl', retry=$retry)" + } +} diff --git a/koog-spring-boot-starter/src/main/kotlin/ai/koog/spring/prompt/executor/clients/openai/OpenAILLMAutoConfiguration.kt b/koog-spring-boot-starter/src/main/kotlin/ai/koog/spring/prompt/executor/clients/openai/OpenAILLMAutoConfiguration.kt new file mode 100644 index 0000000000..ea69bbf97a --- /dev/null +++ b/koog-spring-boot-starter/src/main/kotlin/ai/koog/spring/prompt/executor/clients/openai/OpenAILLMAutoConfiguration.kt @@ -0,0 +1,86 @@ +package ai.koog.spring.prompt.executor.clients.openai + +import ai.koog.prompt.executor.clients.openai.OpenAIClientSettings +import ai.koog.prompt.executor.clients.openai.OpenAILLMClient +import ai.koog.prompt.executor.llms.SingleLLMPromptExecutor +import ai.koog.spring.conditions.ConditionalOnPropertyNotEmpty +import ai.koog.spring.prompt.executor.clients.toRetryingClient +import 
org.slf4j.LoggerFactory +import org.springframework.boot.autoconfigure.AutoConfiguration +import org.springframework.boot.autoconfigure.condition.ConditionalOnBean +import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty +import org.springframework.boot.context.properties.EnableConfigurationProperties +import org.springframework.context.annotation.Bean +import org.springframework.context.annotation.PropertySource + +/** + * Auto-configuration class for setting up OpenAI LLM client and related beans. + * This class utilizes the properties defined in [OpenAIKoogProperties] to configure and initialize OpenAI-related components, + * including the client and executor. + * + * The configuration is conditionally applied if the property `ai.koog.openai.api-key` is set to `true`. + * It reads additional configuration from the properties file located at `classpath:/META-INF/config/koog/openai-llm.properties`. + * + * Key Features: + * - Sets up the [OpenAILLMClient] bean with API key and base URL from the provided properties. + * - Configures a [SingleLLMPromptExecutor] bean using the configured OpenAI client with retry capabilities. + * + * Usage Notes: + * - To activate, ensure the `ai.koog.openai.api-key` property is defined in your application configuration. + * - Customize behavior and settings using the `ai.koog.openai.*` configuration properties. + * + * @property properties [OpenAIKoogProperties] to define key settings such as API key, base URL, and retry configurations. 
+ * + * @see OpenAIKoogProperties + * @see OpenAILLMClient + * @see SingleLLMPromptExecutor + */ +@AutoConfiguration +@PropertySource("classpath:/META-INF/config/koog/openai-llm.properties") +@EnableConfigurationProperties( + OpenAIKoogProperties::class, +) +public class OpenAILLMAutoConfiguration( + private val properties: OpenAIKoogProperties +) { + + private val logger = LoggerFactory.getLogger(OpenAILLMAutoConfiguration::class.java) + + /** + * Creates an [OpenAILLMClient] bean configured with application properties. + * + * This method initializes a [OpenAILLMClient] using the API key and base URL + * specified in the application's configuration. It is only executed if the + * `koog.ai.openai.api-key` property is defined and `koog.ai.openai.enabled` property is set + * to `true` in the application configuration. + * + * @return An [OpenAILLMClient] instance configured with the provided settings. + */ + @Bean + @ConditionalOnPropertyNotEmpty( + prefix = OpenAIKoogProperties.PREFIX, + name = "api-key" + ) + @ConditionalOnProperty(prefix = OpenAIKoogProperties.PREFIX, name = ["enabled"], havingValue = "true") + public fun openAILLMClient(): OpenAILLMClient { + logger.info("Creating OpenAILLMClient client with baseUrl=${properties.baseUrl}") + return OpenAILLMClient( + apiKey = properties.apiKey, + settings = OpenAIClientSettings(baseUrl = properties.baseUrl) + ) + } + + /** + * Creates and returns a [SingleLLMPromptExecutor] bean configured with the given [OpenAILLMClient]. + * This bean is conditionally initialized only when an [OpenAILLMClient] bean is present in the application context. + * + * @param client the [OpenAILLMClient] used to create a retry-capable client for executing LLM prompts. + * @return a configured instance of [SingleLLMPromptExecutor]. 
+ */ + @Bean + @ConditionalOnBean(OpenAILLMClient::class) + public fun openAIExecutor(client: OpenAILLMClient): SingleLLMPromptExecutor { + logger.info("Creating SingleLLMPromptExecutor (openAIExecutor) for OpenAILLMClient") + return SingleLLMPromptExecutor(client.toRetryingClient(properties.retry)) + } +} diff --git a/koog-spring-boot-starter/src/main/kotlin/ai/koog/spring/prompt/executor/clients/openrouter/OpenRouterKoogProperties.kt b/koog-spring-boot-starter/src/main/kotlin/ai/koog/spring/prompt/executor/clients/openrouter/OpenRouterKoogProperties.kt new file mode 100644 index 0000000000..95b3bc2c85 --- /dev/null +++ b/koog-spring-boot-starter/src/main/kotlin/ai/koog/spring/prompt/executor/clients/openrouter/OpenRouterKoogProperties.kt @@ -0,0 +1,56 @@ +package ai.koog.spring.prompt.executor.clients.openrouter + +import ai.koog.spring.RetryConfigKoogProperties +import ai.koog.spring.prompt.executor.clients.KoogLlmClientProperties +import ai.koog.utils.lang.masked +import org.springframework.boot.context.properties.ConfigurationProperties + +/** + * Configuration properties class for OpenRouter integration within the Koog framework. + * + * This class defines configuration options required for connecting to the OpenRouter service + * via the Koog framework. It includes parameters such as API key, base URL, enabling or disabling + * the integration, and retry configuration for handling API requests. + * + * When properly configured in the application properties using the defined prefix, this class + * allows seamless integration with OpenRouter's LLM services. + * + * Configuration prefix: `ai.koog.openrouter` + * + * @property enabled Specifies whether the OpenRouter integration is enabled. This can be toggled + * via the property `ai.koog.openrouter.enabled`. + * @property apiKey The API key used for authenticating requests to the OpenRouter service. + * This must be provided through the property `ai.koog.openrouter.api-key`. 
+ * @property baseUrl The base URL of the OpenRouter API endpoint, configurable via the + * property `ai.koog.openrouter.base-url`. Defaults to the service's official API URL. + * @property retry An optional retry configuration for handling failed API requests. + * This can be set using sub-properties under `ai.koog.openrouter.retry`. + */ +@ConfigurationProperties(prefix = OpenRouterKoogProperties.PREFIX, ignoreUnknownFields = true) +public class OpenRouterKoogProperties( + public override val enabled: Boolean, + public val apiKey: String, + public override val baseUrl: String, + public override val retry: RetryConfigKoogProperties? = null +) : KoogLlmClientProperties { + /** + * Companion object for the [OpenRouterKoogProperties] class, providing constant values and + * utilities associated with the configuration of OpenRouter-related properties. + */ + public companion object Companion { + /** + * Prefix constant used for configuration OpenRouter-related properties in the Koog framework. + */ + public const val PREFIX: String = "ai.koog.openrouter" + } + + /** + * Converts the [OpenRouterKoogProperties] instance to its string representation. + * Sensitive information, such as the API key, is masked to ensure security. + * + * @return A string representation of the [OpenRouterKoogProperties] object. 
+ */ + override fun toString(): String { + return "OpenRouterKoogProperties(enabled=$enabled, apiKey='${apiKey.masked()}', baseUrl='$baseUrl', retry=$retry)" + } +} diff --git a/koog-spring-boot-starter/src/main/kotlin/ai/koog/spring/prompt/executor/clients/openrouter/OpenRouterLLMAutoConfiguration.kt b/koog-spring-boot-starter/src/main/kotlin/ai/koog/spring/prompt/executor/clients/openrouter/OpenRouterLLMAutoConfiguration.kt new file mode 100644 index 0000000000..a8b33ad5b2 --- /dev/null +++ b/koog-spring-boot-starter/src/main/kotlin/ai/koog/spring/prompt/executor/clients/openrouter/OpenRouterLLMAutoConfiguration.kt @@ -0,0 +1,79 @@ +package ai.koog.spring.prompt.executor.clients.openrouter + +import ai.koog.prompt.executor.clients.openrouter.OpenRouterClientSettings +import ai.koog.prompt.executor.clients.openrouter.OpenRouterLLMClient +import ai.koog.prompt.executor.llms.SingleLLMPromptExecutor +import ai.koog.spring.conditions.ConditionalOnPropertyNotEmpty +import ai.koog.spring.prompt.executor.clients.toRetryingClient +import org.slf4j.LoggerFactory +import org.springframework.boot.autoconfigure.AutoConfiguration +import org.springframework.boot.autoconfigure.condition.ConditionalOnBean +import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty +import org.springframework.boot.context.properties.EnableConfigurationProperties +import org.springframework.context.annotation.Bean +import org.springframework.context.annotation.PropertySource + +/** + * Auto-configuration class for integrating OpenRouter with Koog framework. + * + * This class enables the automatic configuration of beans and properties to work with OpenRouter's LLM services, + * provided the application properties have been set with the required prefix and fields. + * + * The configuration is activated only when both `ai.koog.openrouter.enabled` is set to `true` + * and `ai.koog.openrouter.api-key` is provided in the application properties. 
+ * + * @property properties [OpenRouterKoogProperties] to define key settings such as API key, base URL, and retry configurations. + * @see OpenRouterKoogProperties + * @see OpenRouterLLMClient + * @see SingleLLMPromptExecutor + */ +@AutoConfiguration +@PropertySource("classpath:/META-INF/config/koog/openrouter-llm.properties") +@EnableConfigurationProperties( + OpenRouterKoogProperties::class, +) +public class OpenRouterLLMAutoConfiguration( + private val properties: OpenRouterKoogProperties +) { + + private val logger = LoggerFactory.getLogger(OpenRouterLLMAutoConfiguration::class.java) + + /** + * Creates an [OpenRouterLLMClient] bean configured with application properties. + * + * This method initializes a [OpenRouterLLMClient] using the API key and base URL + * specified in the application's configuration. It is only executed if the + * `koog.ai.openrouter.api-key` property is defined and `koog.ai.openrouter.enabled` property is set + * to `true` in the application configuration. + * + * @return An [OpenRouterLLMClient] instance configured with the provided settings. + */ + @Bean + @ConditionalOnPropertyNotEmpty( + prefix = OpenRouterKoogProperties.PREFIX, + name = "api-key" + ) + @ConditionalOnProperty(prefix = OpenRouterKoogProperties.PREFIX, name = ["enabled"], havingValue = "true") + public fun openRouterLLMClient(): OpenRouterLLMClient { + logger.info("Creating OpenRouterLLMClient with baseUrl=${properties.baseUrl}") + return OpenRouterLLMClient( + apiKey = properties.apiKey, + settings = OpenRouterClientSettings(baseUrl = properties.baseUrl) + ) + } + + /** + * Provides a [SingleLLMPromptExecutor] bean configured with an [OpenRouterLLMClient]. + * + * The method uses the provided [OpenRouterLLMClient] to create a retrying client instance + * based on the configuration in the `properties.retry` parameter. 
+ * + * @param client The [OpenRouterLLMClient] instance used to configure the [SingleLLMPromptExecutor] + * */ + @Bean + @ConditionalOnBean(OpenRouterLLMClient::class) + public fun openRouterExecutor(client: OpenRouterLLMClient): SingleLLMPromptExecutor { + logger.info("Creating SingleLLMPromptExecutor (openRouterExecutor) for OpenRouterLLMClient") + return SingleLLMPromptExecutor(client.toRetryingClient(properties.retry)) + } +} diff --git a/koog-spring-boot-starter/src/main/kotlin/ai/koog/spring/prompt/executor/clients/utils.kt b/koog-spring-boot-starter/src/main/kotlin/ai/koog/spring/prompt/executor/clients/utils.kt new file mode 100644 index 0000000000..16893377af --- /dev/null +++ b/koog-spring-boot-starter/src/main/kotlin/ai/koog/spring/prompt/executor/clients/utils.kt @@ -0,0 +1,24 @@ +package ai.koog.spring.prompt.executor.clients + +import ai.koog.prompt.executor.clients.LLMClient +import ai.koog.prompt.executor.clients.retry.RetryConfig +import ai.koog.prompt.executor.clients.retry.toRetryingClient +import ai.koog.spring.RetryConfigKoogProperties +import kotlin.time.toKotlinDuration + +internal fun LLMClient.toRetryingClient(properties: RetryConfigKoogProperties?): LLMClient { + val self = this + return if (properties?.enabled == true) { + val defaultConfig = RetryConfig.DEFAULT + val retryConfig = RetryConfig( + maxAttempts = properties.maxAttempts ?: defaultConfig.maxAttempts, + initialDelay = properties.initialDelay?.toKotlinDuration() ?: defaultConfig.initialDelay, + maxDelay = properties.maxDelay?.toKotlinDuration() ?: defaultConfig.maxDelay, + backoffMultiplier = properties.backoffMultiplier ?: defaultConfig.backoffMultiplier, + jitterFactor = properties.jitterFactor ?: defaultConfig.jitterFactor + ) + self.toRetryingClient(retryConfig) + } else { + self + } +} diff --git a/koog-spring-boot-starter/src/main/resources/META-INF/config/koog/anthropic-llm.properties 
b/koog-spring-boot-starter/src/main/resources/META-INF/config/koog/anthropic-llm.properties new file mode 100644 index 0000000000..6f77f4a40e --- /dev/null +++ b/koog-spring-boot-starter/src/main/resources/META-INF/config/koog/anthropic-llm.properties @@ -0,0 +1,3 @@ +ai.koog.anthropic.api-key=${ANTHROPIC_API_KEY:} +ai.koog.anthropic.base-url=https://api.anthropic.com +ai.koog.anthropic.enabled=true diff --git a/koog-spring-boot-starter/src/main/resources/META-INF/config/koog/deepseek-llm.properties b/koog-spring-boot-starter/src/main/resources/META-INF/config/koog/deepseek-llm.properties new file mode 100644 index 0000000000..9479ac722f --- /dev/null +++ b/koog-spring-boot-starter/src/main/resources/META-INF/config/koog/deepseek-llm.properties @@ -0,0 +1,3 @@ +ai.koog.deepseek.api-key=${DEEPSEEK_API_KEY:} +ai.koog.deepseek.base-url=https://api.deepseek.com +ai.koog.deepseek.enabled=true diff --git a/koog-spring-boot-starter/src/main/resources/META-INF/config/koog/google-llm.properties b/koog-spring-boot-starter/src/main/resources/META-INF/config/koog/google-llm.properties new file mode 100644 index 0000000000..2b2b1cbf73 --- /dev/null +++ b/koog-spring-boot-starter/src/main/resources/META-INF/config/koog/google-llm.properties @@ -0,0 +1,3 @@ +ai.koog.google.api-key=${GOOGLE_API_KEY:} +ai.koog.google.base-url=https://generativelanguage.googleapis.com +ai.koog.google.enabled=true diff --git a/koog-spring-boot-starter/src/main/resources/META-INF/config/koog/ollama-llm.properties b/koog-spring-boot-starter/src/main/resources/META-INF/config/koog/ollama-llm.properties new file mode 100644 index 0000000000..ff1168ed2e --- /dev/null +++ b/koog-spring-boot-starter/src/main/resources/META-INF/config/koog/ollama-llm.properties @@ -0,0 +1,2 @@ +ai.koog.ollama.enabled=false +ai.koog.ollama.base-url=http://127.0.0.1:11434 diff --git a/koog-spring-boot-starter/src/main/resources/META-INF/config/koog/openai-llm.properties 
b/koog-spring-boot-starter/src/main/resources/META-INF/config/koog/openai-llm.properties new file mode 100644 index 0000000000..e9aceb657b --- /dev/null +++ b/koog-spring-boot-starter/src/main/resources/META-INF/config/koog/openai-llm.properties @@ -0,0 +1,3 @@ +ai.koog.openai.enabled=true +ai.koog.openai.base-url=https://api.openai.com +ai.koog.openai.api-key=${OPENAI_API_KEY:} diff --git a/koog-spring-boot-starter/src/main/resources/META-INF/config/koog/openrouter-llm.properties b/koog-spring-boot-starter/src/main/resources/META-INF/config/koog/openrouter-llm.properties new file mode 100644 index 0000000000..6e747ac5ec --- /dev/null +++ b/koog-spring-boot-starter/src/main/resources/META-INF/config/koog/openrouter-llm.properties @@ -0,0 +1,3 @@ +ai.koog.openrouter.api-key=${OPENROUTER_API_KEY:} +ai.koog.openrouter.base-url=https://openrouter.ai +ai.koog.openrouter.enabled=true diff --git a/koog-spring-boot-starter/src/main/resources/META-INF/spring/org.springframework.boot.autoconfigure.AutoConfiguration.imports b/koog-spring-boot-starter/src/main/resources/META-INF/spring/org.springframework.boot.autoconfigure.AutoConfiguration.imports index 398557bf69..d9859c91bb 100644 --- a/koog-spring-boot-starter/src/main/resources/META-INF/spring/org.springframework.boot.autoconfigure.AutoConfiguration.imports +++ b/koog-spring-boot-starter/src/main/resources/META-INF/spring/org.springframework.boot.autoconfigure.AutoConfiguration.imports @@ -1 +1,6 @@ -ai.koog.spring.KoogAutoConfiguration \ No newline at end of file +ai.koog.spring.prompt.executor.clients.anthropic.AnthropicLLMAutoConfiguration +ai.koog.spring.prompt.executor.clients.deepseek.DeepSeekLLMAutoConfiguration +ai.koog.spring.prompt.executor.clients.google.GoogleLLMAutoConfiguration +ai.koog.spring.prompt.executor.clients.ollama.OllamaLLMAutoConfiguration +ai.koog.spring.prompt.executor.clients.openai.OpenAILLMAutoConfiguration +ai.koog.spring.prompt.executor.clients.openrouter.OpenRouterLLMAutoConfiguration 
diff --git a/koog-spring-boot-starter/src/test/kotlin/ai/koog/spring/KoogAutoConfigurationIntegrationTest.kt b/koog-spring-boot-starter/src/test/kotlin/ai/koog/spring/KoogAutoConfigurationIntegrationTest.kt new file mode 100644 index 0000000000..f2b1186fd3 --- /dev/null +++ b/koog-spring-boot-starter/src/test/kotlin/ai/koog/spring/KoogAutoConfigurationIntegrationTest.kt @@ -0,0 +1,106 @@ +package ai.koog.spring + +import ai.koog.prompt.executor.clients.LLMClient +import ai.koog.prompt.executor.clients.anthropic.AnthropicLLMClient +import ai.koog.prompt.executor.clients.deepseek.DeepSeekLLMClient +import ai.koog.prompt.executor.clients.google.GoogleLLMClient +import ai.koog.prompt.executor.clients.openai.OpenAILLMClient +import ai.koog.prompt.executor.clients.openrouter.OpenRouterLLMClient +import ai.koog.prompt.executor.llms.SingleLLMPromptExecutor +import ai.koog.prompt.executor.ollama.client.OllamaClient +import org.junit.jupiter.api.Assertions.assertTrue +import org.junit.jupiter.params.ParameterizedTest +import org.junit.jupiter.params.provider.ValueSource +import org.slf4j.LoggerFactory +import org.springframework.beans.factory.annotation.Autowired +import org.springframework.boot.autoconfigure.EnableAutoConfiguration +import org.springframework.boot.test.context.SpringBootTest +import org.springframework.context.ApplicationContext +import org.springframework.context.annotation.Configuration +import org.springframework.test.context.TestPropertySource + +@SpringBootTest( + classes = [ + KoogAutoConfigurationIntegrationTest.TestConfig::class, + ], + properties = [ + "debug=false", // set to true for troubleshooting + "spring.main.banner-mode=off" + ], + webEnvironment = SpringBootTest.WebEnvironment.NONE +) +@TestPropertySource( + locations = ["classpath:/it-application.properties"] +) +class KoogAutoConfigurationIntegrationTest { + + @Configuration + @EnableAutoConfiguration + @Suppress("unused") + private class TestConfig + + private val logger = 
LoggerFactory.getLogger(KoogAutoConfigurationIntegrationTest::class.java) + + @Autowired + private lateinit var applicationContext: ApplicationContext + + @ParameterizedTest + @ValueSource( + classes = [ + AnthropicLLMClient::class, + DeepSeekLLMClient::class, + GoogleLLMClient::class, + OpenAILLMClient::class, + OpenRouterLLMClient::class, + OllamaClient::class, + ] + ) + fun `Should register beans(classes)`(clazz: Class<*>) { + verifyBeanIsRegistered(clazz) + } + + @ParameterizedTest + @ValueSource( + strings = [ + "anthropicExecutor", + "deepSeekExecutor", + "googleExecutor", + "ollamaExecutor", + "openAIExecutor", + "openRouterExecutor", + ] + ) + fun `Should register SingleLLMExecutors`(beanName: String) { + val llmExecutorBeanNames = applicationContext.getBeanNamesForType( + SingleLLMPromptExecutor::class.java + ) + assertTrue(llmExecutorBeanNames.contains(beanName)) { + logger.info( + "Registered ${SingleLLMPromptExecutor::class.simpleName} beans:${ + llmExecutorBeanNames + .joinToString(separator = "\n\t", prefix = "\n\t") + }" + ) + + "Bean named `$beanName` should have been registered" + } + } + + private inline fun verifyBeanIsRegistered(clazz: Class<*>) { + assertTrue(applicationContext.getBeansOfType(clazz).size == 1) { + logger.info( + "Registered beans:${ + applicationContext.beanDefinitionNames + .joinToString(separator = "\n\t", prefix = "\n\t") + }" + ) + + "Bean of type ${clazz.simpleName} should have been registered" + } + + val bean = applicationContext.getBean(clazz) + assertTrue(bean is EXTRA) { + "Registered bean of type $clazz should be also a ${EXTRA::class}" + } + } +} diff --git a/koog-spring-boot-starter/src/test/kotlin/ai/koog/spring/KoogAutoConfigurationTest.kt b/koog-spring-boot-starter/src/test/kotlin/ai/koog/spring/KoogAutoConfigurationTest.kt index 30f15c8279..f9330469e5 100644 --- a/koog-spring-boot-starter/src/test/kotlin/ai/koog/spring/KoogAutoConfigurationTest.kt +++ 
b/koog-spring-boot-starter/src/test/kotlin/ai/koog/spring/KoogAutoConfigurationTest.kt @@ -14,6 +14,12 @@ import ai.koog.prompt.executor.clients.retry.RetryConfig import ai.koog.prompt.executor.clients.retry.RetryingLLMClient import ai.koog.prompt.executor.llms.SingleLLMPromptExecutor import ai.koog.prompt.executor.ollama.client.OllamaClient +import ai.koog.spring.prompt.executor.clients.anthropic.AnthropicLLMAutoConfiguration +import ai.koog.spring.prompt.executor.clients.deepseek.DeepSeekLLMAutoConfiguration +import ai.koog.spring.prompt.executor.clients.google.GoogleLLMAutoConfiguration +import ai.koog.spring.prompt.executor.clients.ollama.OllamaLLMAutoConfiguration +import ai.koog.spring.prompt.executor.clients.openai.OpenAILLMAutoConfiguration +import ai.koog.spring.prompt.executor.clients.openrouter.OpenRouterLLMAutoConfiguration import org.junit.jupiter.api.Assertions.assertEquals import org.junit.jupiter.api.Assertions.assertTrue import org.junit.jupiter.api.Test @@ -39,12 +45,23 @@ private const val PROVIDERS = """ @TestInstance(TestInstance.Lifecycle.PER_CLASS) class KoogAutoConfigurationTest { - private val defaultRetryConfig = RetryConfig() + private val defaultRetryConfig = RetryConfig.DEFAULT + + private fun createApplicationContextRunner(): ApplicationContextRunner = ApplicationContextRunner() + .withConfiguration( + AutoConfigurations.of( + AnthropicLLMAutoConfiguration::class.java, + GoogleLLMAutoConfiguration::class.java, + DeepSeekLLMAutoConfiguration::class.java, + OllamaLLMAutoConfiguration::class.java, + OpenAILLMAutoConfiguration::class.java, + OpenRouterLLMAutoConfiguration::class.java, + ) + ) @Test fun `should not supply executor beans if no apiKey is provided`() { - ApplicationContextRunner() - .withConfiguration(AutoConfigurations.of(KoogAutoConfiguration::class.java)) + createApplicationContextRunner() .run { context -> assertThrows { context.getBean() } } @@ -53,9 +70,10 @@ class KoogAutoConfigurationTest { @Test fun `should supply 
OpenAI executor bean with provided apiKey and default baseUrl`() { val configApiKey = "some_api_key" - ApplicationContextRunner() - .withConfiguration(AutoConfigurations.of(KoogAutoConfiguration::class.java)) - .withPropertyValues("ai.koog.openai.api-key=$configApiKey") + createApplicationContextRunner() + .withPropertyValues( + "ai.koog.openai.api-key=$configApiKey" + ) .run { context -> val executor = context.getBean() val llmClient = getPrivateFieldValue(executor, "llmClient") @@ -74,8 +92,7 @@ class KoogAutoConfigurationTest { @Test fun `should supply OpenAI executor bean with provided baseUrl`() { val configBaseUrl = "https://some-url.com" - ApplicationContextRunner() - .withConfiguration(AutoConfigurations.of(KoogAutoConfiguration::class.java)) + createApplicationContextRunner() .withPropertyValues( "ai.koog.openai.api-key=some_api_key", "ai.koog.openai.base-url=$configBaseUrl", @@ -92,16 +109,18 @@ class KoogAutoConfigurationTest { } @ParameterizedTest - @CsvSource(PROVIDERS) + @CsvSource(textBlock = PROVIDERS) fun `should supply OpenAI executor bean with retry client and default config`( provider: String, clazz: Class ) { - ApplicationContextRunner() - .withConfiguration(AutoConfigurations.of(KoogAutoConfiguration::class.java)) - .withPropertyValues("ai.koog.$provider.api-key=some_api_key") - .withPropertyValues("ai.koog.$provider.retry.enabled=true") - .withPropertyValues("ai.koog.$provider.base-url=http://localhost:9876") + createApplicationContextRunner() + .withPropertyValues( + "ai.koog.$provider.enabled=true", + "ai.koog.$provider.api-key=some_api_key", + "ai.koog.$provider.retry.enabled=true", + "ai.koog.$provider.base-url=http://localhost:9876" + ) .run { context -> val executor = context.getBean() val retryingClient = getPrivateFieldValue(executor, "llmClient") @@ -120,7 +139,7 @@ class KoogAutoConfigurationTest { } @ParameterizedTest - @CsvSource(PROVIDERS) + @CsvSource(textBlock = PROVIDERS) fun `should supply executor bean with retry client and 
full custom config`( provider: String, clazz: Class @@ -130,16 +149,18 @@ class KoogAutoConfigurationTest { val maxDelay = 60 val backoffMultiplier = 5.0 val jitterFactor = 0.5 - ApplicationContextRunner() - .withConfiguration(AutoConfigurations.of(KoogAutoConfiguration::class.java)) - .withPropertyValues("ai.koog.$provider.api-key=some_api_key") - .withPropertyValues("ai.koog.$provider.base-url=http://localhost:9876") - .withPropertyValues("ai.koog.$provider.retry.enabled=true") - .withPropertyValues("ai.koog.$provider.retry.max-attempts=$maxAttempts") - .withPropertyValues("ai.koog.$provider.retry.initial-delay=$initialDelay") - .withPropertyValues("ai.koog.$provider.retry.max-delay=$maxDelay") - .withPropertyValues("ai.koog.$provider.retry.backoff-multiplier=$backoffMultiplier") - .withPropertyValues("ai.koog.$provider.retry.jitter-factor=$jitterFactor") + createApplicationContextRunner() + .withPropertyValues( + "ai.koog.$provider.enabled=true", + "ai.koog.$provider.api-key=some_api_key", + "ai.koog.$provider.base-url=http://localhost:9876", + "ai.koog.$provider.retry.enabled=true", + "ai.koog.$provider.retry.max-attempts=$maxAttempts", + "ai.koog.$provider.retry.initial-delay=$initialDelay", + "ai.koog.$provider.retry.max-delay=$maxDelay", + "ai.koog.$provider.retry.backoff-multiplier=$backoffMultiplier", + "ai.koog.$provider.retry.jitter-factor=$jitterFactor" + ) .run { context -> val executor = context.getBean() val retryingClient = getPrivateFieldValue(executor, "llmClient") @@ -158,19 +179,50 @@ class KoogAutoConfigurationTest { } @ParameterizedTest - @CsvSource(PROVIDERS) + @CsvSource(textBlock = PROVIDERS) + fun `Should not create beans when provider is DISABLED`( + provider: String, + ) { + val maxAttempts = 5 + val initialDelay = 10 + val maxDelay = 60 + val backoffMultiplier = 5.0 + val jitterFactor = 0.5 + createApplicationContextRunner() + .withPropertyValues( + "ai.koog.$provider.enabled=false", + "ai.koog.$provider.api-key=some_api_key", + 
"ai.koog.$provider.base-url=http://localhost:9876", + "ai.koog.$provider.retry.enabled=true", + "ai.koog.$provider.retry.max-attempts=$maxAttempts", + "ai.koog.$provider.retry.initial-delay=$initialDelay", + "ai.koog.$provider.retry.max-delay=$maxDelay", + "ai.koog.$provider.retry.backoff-multiplier=$backoffMultiplier", + "ai.koog.$provider.retry.jitter-factor=$jitterFactor" + ) + .run { context -> + assertTrue { context.getBeansOfType(SingleLLMPromptExecutor::class.java).isEmpty() } + assertTrue { context.getBeansOfType(RetryingLLMClient::class.java).isEmpty() } + assertTrue { context.getBeansOfType(LLMClient::class.java).isEmpty() } + } + } + + @ParameterizedTest + @CsvSource(textBlock = PROVIDERS) fun `should supply executor bean with retry client and partial custom config`( provider: String, clazz: Class ) { val maxAttempts = 5 val initialDelay = 10 - ApplicationContextRunner() - .withConfiguration(AutoConfigurations.of(KoogAutoConfiguration::class.java)) - .withPropertyValues("ai.koog.$provider.api-key=some_api_key") - .withPropertyValues("ai.koog.$provider.retry.enabled=true") - .withPropertyValues("ai.koog.$provider.retry.max-attempts=$maxAttempts") - .withPropertyValues("ai.koog.$provider.retry.initial-delay=$initialDelay") + createApplicationContextRunner() + .withPropertyValues( + "ai.koog.$provider.enabled=true", + "ai.koog.$provider.api-key=some_api_key", + "ai.koog.$provider.retry.enabled=true", + "ai.koog.$provider.retry.max-attempts=$maxAttempts", + "ai.koog.$provider.retry.initial-delay=$initialDelay" + ) .run { context -> val executor = context.getBean() val retryingClient = getPrivateFieldValue(executor, "llmClient") @@ -191,9 +243,10 @@ class KoogAutoConfigurationTest { @Test fun `should supply Anthropic executor bean with provided apiKey and default baseUrl`() { val configApiKey = "some_api_key" - ApplicationContextRunner() - .withConfiguration(AutoConfigurations.of(KoogAutoConfiguration::class.java)) - 
.withPropertyValues("ai.koog.anthropic.api-key=$configApiKey") + createApplicationContextRunner() + .withPropertyValues( + "ai.koog.anthropic.api-key=$configApiKey" + ) .run { context -> val executor = context.getBean() val llmClient = getPrivateFieldValue(executor, "llmClient") @@ -212,9 +265,11 @@ class KoogAutoConfigurationTest { @Test fun `should supply Anthropic executor bean with retry client and default config`() { ApplicationContextRunner() - .withConfiguration(AutoConfigurations.of(KoogAutoConfiguration::class.java)) - .withPropertyValues("ai.koog.anthropic.api-key=some_api_key") - .withPropertyValues("ai.koog.anthropic.retry.enabled=true") + .withConfiguration(AutoConfigurations.of(AnthropicLLMAutoConfiguration::class.java)) + .withPropertyValues( + "ai.koog.anthropic.api-key=some_api_key", + "ai.koog.anthropic.retry.enabled=true" + ) .run { context -> val executor = context.getBean() val retryingClient = getPrivateFieldValue(executor, "llmClient") @@ -231,8 +286,7 @@ class KoogAutoConfigurationTest { @Test fun `should supply Anthropic executor bean with provided baseUrl`() { val configBaseUrl = "https://some-url.com" - ApplicationContextRunner() - .withConfiguration(AutoConfigurations.of(KoogAutoConfiguration::class.java)) + createApplicationContextRunner() .withPropertyValues( "ai.koog.anthropic.api-key=some_api_key", "ai.koog.anthropic.base-url=$configBaseUrl", @@ -251,9 +305,10 @@ class KoogAutoConfigurationTest { @Test fun `should supply Google executor bean with provided apiKey and default baseUrl`() { val configApiKey = "some_api_key" - ApplicationContextRunner() - .withConfiguration(AutoConfigurations.of(KoogAutoConfiguration::class.java)) - .withPropertyValues("ai.koog.google.api-key=$configApiKey") + createApplicationContextRunner() + .withPropertyValues( + "ai.koog.google.api-key=$configApiKey" + ) .run { context -> val executor = context.getBean() val llmClient = getPrivateFieldValue(executor, "llmClient") @@ -272,8 +327,7 @@ class 
KoogAutoConfigurationTest { @Test fun `should supply Google executor bean with provided baseUrl`() { val configBaseUrl = "https://some-url.com" - ApplicationContextRunner() - .withConfiguration(AutoConfigurations.of(KoogAutoConfiguration::class.java)) + createApplicationContextRunner() .withPropertyValues( "ai.koog.google.api-key=some_api_key", "ai.koog.google.base-url=$configBaseUrl", @@ -291,10 +345,11 @@ class KoogAutoConfigurationTest { @Test fun `should supply Google executor bean with retry client and default config`() { - ApplicationContextRunner() - .withConfiguration(AutoConfigurations.of(KoogAutoConfiguration::class.java)) - .withPropertyValues("ai.koog.google.api-key=some_api_key") - .withPropertyValues("ai.koog.google.retry.enabled=true") + createApplicationContextRunner() + .withPropertyValues( + "ai.koog.google.api-key=some_api_key", + "ai.koog.google.retry.enabled=true" + ) .run { context -> val executor = context.getBean() val retryingClient = getPrivateFieldValue(executor, "llmClient") @@ -311,9 +366,11 @@ class KoogAutoConfigurationTest { @Test fun `should supply OpenRouter executor bean with provided apiKey and default baseUrl`() { val configApiKey = "some_api_key" - ApplicationContextRunner() - .withConfiguration(AutoConfigurations.of(KoogAutoConfiguration::class.java)) - .withPropertyValues("ai.koog.openrouter.api-key=$configApiKey") + createApplicationContextRunner() + .withPropertyValues( + "ai.koog.openrouter.enabled=true", + "ai.koog.openrouter.api-key=$configApiKey" + ) .run { context -> val executor = context.getBean() val llmClient = getPrivateFieldValue(executor, "llmClient") @@ -332,9 +389,9 @@ class KoogAutoConfigurationTest { @Test fun `should supply OpenRouter executor bean with provided baseUrl`() { val configBaseUrl = "https://some-url.com" - ApplicationContextRunner() - .withConfiguration(AutoConfigurations.of(KoogAutoConfiguration::class.java)) + createApplicationContextRunner() .withPropertyValues( + 
"ai.koog.openrouter.enabled=true", "ai.koog.openrouter.api-key=some_api_key", "ai.koog.openrouter.base-url=$configBaseUrl", ) @@ -351,10 +408,12 @@ class KoogAutoConfigurationTest { @Test fun `should supply OpenRouter executor bean with retry client and default config`() { - ApplicationContextRunner() - .withConfiguration(AutoConfigurations.of(KoogAutoConfiguration::class.java)) - .withPropertyValues("ai.koog.openrouter.api-key=some_api_key") - .withPropertyValues("ai.koog.openrouter.retry.enabled=true") + createApplicationContextRunner() + .withPropertyValues( + "ai.koog.openrouter.enabled=true", + "ai.koog.openrouter.api-key=some_api_key", + "ai.koog.openrouter.retry.enabled=true" + ) .run { context -> val executor = context.getBean() val retryingClient = getPrivateFieldValue(executor, "llmClient") @@ -371,9 +430,10 @@ class KoogAutoConfigurationTest { @Test fun `should supply DeepSeek executor bean with provided apiKey and default baseUrl`() { val configApiKey = "some_api_key" - ApplicationContextRunner() - .withConfiguration(AutoConfigurations.of(KoogAutoConfiguration::class.java)) - .withPropertyValues("ai.koog.deepseek.api-key=$configApiKey") + createApplicationContextRunner() + .withPropertyValues( + "ai.koog.deepseek.api-key=$configApiKey" + ) .run { context -> val executor = context.getBean() val llmClient = getPrivateFieldValue(executor, "llmClient") @@ -392,12 +452,10 @@ class KoogAutoConfigurationTest { @Test fun `should supply DeepSeek executor bean with provided baseUrl`() { val configBaseUrl = "https://some-url.com" - ApplicationContextRunner() - .withConfiguration(AutoConfigurations.of(KoogAutoConfiguration::class.java)) - .withPropertyValues( - "ai.koog.deepseek.api-key=some_api_key", - "ai.koog.deepseek.base-url=$configBaseUrl", - ) + createApplicationContextRunner().withPropertyValues( + "ai.koog.deepseek.api-key=some_api_key", + "ai.koog.deepseek.base-url=$configBaseUrl", + ) .run { context -> val executor = context.getBean() val llmClient = 
getPrivateFieldValue(executor, "llmClient") as DeepSeekLLMClient @@ -411,10 +469,11 @@ class KoogAutoConfigurationTest { @Test fun `should supply DeepSeek executor bean with retry client and default config`() { - ApplicationContextRunner() - .withConfiguration(AutoConfigurations.of(KoogAutoConfiguration::class.java)) - .withPropertyValues("ai.koog.deepseek.api-key=some_api_key") - .withPropertyValues("ai.koog.deepseek.retry.enabled=true") + createApplicationContextRunner() + .withPropertyValues( + "ai.koog.deepseek.api-key=some_api_key", + "ai.koog.deepseek.retry.enabled=true" + ) .run { context -> val executor = context.getBean() val retryingClient = getPrivateFieldValue(executor, "llmClient") @@ -431,9 +490,10 @@ class KoogAutoConfigurationTest { @Test fun `should supply Ollama executor bean with provided baseUrl`() { val configBaseUrl = "https://some-url.com" - ApplicationContextRunner() - .withConfiguration(AutoConfigurations.of(KoogAutoConfiguration::class.java)) - .withPropertyValues("ai.koog.ollama.base-url=$configBaseUrl") + createApplicationContextRunner().withPropertyValues( + "ai.koog.ollama.enabled=true", + "ai.koog.ollama.base-url=$configBaseUrl" + ) .run { context -> val executor = context.getBean() val llmClient = getPrivateFieldValue(executor, "llmClient") @@ -448,9 +508,12 @@ class KoogAutoConfigurationTest { @Test fun `should supply Ollama executor bean with retry client and default config`() { ApplicationContextRunner() - .withConfiguration(AutoConfigurations.of(KoogAutoConfiguration::class.java)) - .withPropertyValues("ai.koog.ollama.base-url=https://some-url.com") - .withPropertyValues("ai.koog.ollama.retry.enabled=true") + .withConfiguration(AutoConfigurations.of(OllamaLLMAutoConfiguration::class.java)) + .withPropertyValues( + "ai.koog.ollama.enabled=true", + "ai.koog.ollama.base-url=https://some-url.com", + "ai.koog.ollama.retry.enabled=true" + ) .run { context -> val executor = context.getBean() val retryingClient = 
getPrivateFieldValue(executor, "llmClient") @@ -466,14 +529,13 @@ class KoogAutoConfigurationTest { @Test fun `should supply multiple executor beans`() { - ApplicationContextRunner() - .withConfiguration(AutoConfigurations.of(KoogAutoConfiguration::class.java)) + createApplicationContextRunner() .withPropertyValues( "ai.koog.openai.api-key=some_api_key", "ai.koog.anthropic.api-key=some_api_key", "ai.koog.google.api-key=some_api_key", "ai.koog.deepseek.api-key=some_api_key", - "ai.koog.ollama.base-url=http://localhist:8765", + "ai.koog.ollama.enabled=true", ) .run { context -> val beanNames = context.getBeanNamesForType() diff --git a/koog-spring-boot-starter/src/test/resources/it-application.properties b/koog-spring-boot-starter/src/test/resources/it-application.properties new file mode 100644 index 0000000000..2b7210053a --- /dev/null +++ b/koog-spring-boot-starter/src/test/resources/it-application.properties @@ -0,0 +1,6 @@ +ANTHROPIC_API_KEY=test_anthropic_api_key +DEEPSEEK_API_KEY=test_deepseek_api_key +GOOGLE_API_KEY=test_google_api_key +OPENAI_API_KEY=test_openai_api_key +OPENROUTER_API_KEY=test_openrouter_api_key +ai.koog.ollama.enabled=true diff --git a/koog-spring-boot-starter/src/test/resources/junit-platform.properties b/koog-spring-boot-starter/src/test/resources/junit-platform.properties new file mode 100644 index 0000000000..4ad8fb868b --- /dev/null +++ b/koog-spring-boot-starter/src/test/resources/junit-platform.properties @@ -0,0 +1,5 @@ +## https://docs.junit.org/5.3.0-M1/user-guide/index.html#writing-tests-parallel-execution +junit.jupiter.execution.parallel.enabled=true +junit.jupiter.execution.parallel.config.strategy=dynamic +junit.jupiter.execution.parallel.mode.classes.default=concurrent +junit.jupiter.execution.parallel.mode.default=concurrent diff --git a/prompt/prompt-executor/prompt-executor-clients/prompt-executor-deepseek-client/build.gradle.kts 
b/prompt/prompt-executor/prompt-executor-clients/prompt-executor-deepseek-client/build.gradle.kts index 773bd4b5c2..e9ee6bcff4 100644 --- a/prompt/prompt-executor/prompt-executor-clients/prompt-executor-deepseek-client/build.gradle.kts +++ b/prompt/prompt-executor/prompt-executor-clients/prompt-executor-deepseek-client/build.gradle.kts @@ -13,6 +13,8 @@ kotlin { commonMain { dependencies { api(project(":prompt:prompt-executor:prompt-executor-clients:prompt-executor-openai-client-base")) + api(project(":prompt:prompt-structure")) + api(project(":prompt:prompt-executor:prompt-executor-clients:prompt-executor-openai-client")) implementation(libs.oshai.kotlin.logging) } } diff --git a/prompt/prompt-executor/prompt-executor-clients/prompt-executor-deepseek-client/src/commonMain/kotlin/ai/koog/prompt/executor/clients/deepseek/DeepSeekLLMClient.kt b/prompt/prompt-executor/prompt-executor-clients/prompt-executor-deepseek-client/src/commonMain/kotlin/ai/koog/prompt/executor/clients/deepseek/DeepSeekLLMClient.kt index c86089cc08..1f072630ad 100644 --- a/prompt/prompt-executor/prompt-executor-clients/prompt-executor-deepseek-client/src/commonMain/kotlin/ai/koog/prompt/executor/clients/deepseek/DeepSeekLLMClient.kt +++ b/prompt/prompt-executor/prompt-executor-clients/prompt-executor-deepseek-client/src/commonMain/kotlin/ai/koog/prompt/executor/clients/deepseek/DeepSeekLLMClient.kt @@ -13,13 +13,20 @@ import ai.koog.prompt.executor.clients.openai.base.OpenAIBasedSettings import ai.koog.prompt.executor.clients.openai.base.models.OpenAIMessage import ai.koog.prompt.executor.clients.openai.base.models.OpenAITool import ai.koog.prompt.executor.clients.openai.base.models.OpenAIToolChoice +import ai.koog.prompt.executor.clients.openai.structure.OpenAIBasicJsonSchemaGenerator +import ai.koog.prompt.executor.clients.openai.structure.OpenAIStandardJsonSchemaGenerator import ai.koog.prompt.executor.model.LLMChoice +import ai.koog.prompt.llm.LLMProvider import ai.koog.prompt.llm.LLModel 
import ai.koog.prompt.params.LLMParams import ai.koog.prompt.streaming.StreamFrameFlowBuilder +import ai.koog.prompt.structure.RegisteredBasicJsonSchemaGenerators +import ai.koog.prompt.structure.RegisteredStandardJsonSchemaGenerators +import ai.koog.prompt.structure.annotations.InternalStructuredOutputApi import io.github.oshai.kotlinlogging.KotlinLogging import io.ktor.client.HttpClient import kotlinx.datetime.Clock +import kotlin.collections.set /** * Configuration settings for connecting to the DeepSeek API. @@ -54,8 +61,14 @@ public class DeepSeekLLMClient( staticLogger ) { + @OptIn(InternalStructuredOutputApi::class) private companion object { private val staticLogger = KotlinLogging.logger { } + + init { + RegisteredBasicJsonSchemaGenerators[LLMProvider.DeepSeek] = OpenAIBasicJsonSchemaGenerator + RegisteredStandardJsonSchemaGenerators[LLMProvider.DeepSeek] = OpenAIStandardJsonSchemaGenerator + } } override fun serializeProviderChatRequest( @@ -92,7 +105,12 @@ public class DeepSeekLLMClient( override fun processProviderChatResponse(response: DeepSeekChatCompletionResponse): List { require(response.choices.isNotEmpty()) { "Empty choices in response" } - return response.choices.map { it.toMessageResponses(createMetaInfo(response.usage)) } + return response.choices.map { + it.message.toMessageResponses( + it.finishReason, + createMetaInfo(response.usage), + ) + } } override fun decodeStreamingResponse(data: String): DeepSeekChatCompletionStreamResponse = diff --git a/prompt/prompt-executor/prompt-executor-clients/prompt-executor-google-client/src/commonMain/kotlin/ai/koog/prompt/executor/clients/google/GoogleLLMClient.kt b/prompt/prompt-executor/prompt-executor-clients/prompt-executor-google-client/src/commonMain/kotlin/ai/koog/prompt/executor/clients/google/GoogleLLMClient.kt index 8517a54464..a25bf2120a 100644 --- 
a/prompt/prompt-executor/prompt-executor-clients/prompt-executor-google-client/src/commonMain/kotlin/ai/koog/prompt/executor/clients/google/GoogleLLMClient.kt +++ b/prompt/prompt-executor/prompt-executor-clients/prompt-executor-google-client/src/commonMain/kotlin/ai/koog/prompt/executor/clients/google/GoogleLLMClient.kt @@ -273,7 +273,7 @@ public open class GoogleLLMClient( * @param tools Tools to include in the request * @return A formatted GoogleAI request */ - private fun createGoogleRequest(prompt: Prompt, model: LLModel, tools: List): GoogleRequest { + internal fun createGoogleRequest(prompt: Prompt, model: LLModel, tools: List): GoogleRequest { val systemMessageParts = mutableListOf() val contents = mutableListOf() val pendingCalls = mutableListOf() @@ -388,7 +388,7 @@ public open class GoogleLLMClient( responseJsonSchema = responseFormat?.responseJsonSchema, temperature = if (model.capabilities.contains(LLMCapability.Temperature)) prompt.params.temperature else null, candidateCount = if (model.capabilities.contains(LLMCapability.MultipleChoices)) prompt.params.numberOfChoices else null, - maxOutputTokens = 2048, + maxOutputTokens = prompt.params.maxTokens, thinkingConfig = GoogleThinkingConfig( includeThoughts = prompt.params.includeThoughts.takeIf { it == true }, thinkingBudget = prompt.params.thinkingBudget diff --git a/prompt/prompt-executor/prompt-executor-clients/prompt-executor-google-client/src/jvmTest/kotlin/ai/koog/prompt/executor/clients/google/GoogleLLMClientTest.kt b/prompt/prompt-executor/prompt-executor-clients/prompt-executor-google-client/src/jvmTest/kotlin/ai/koog/prompt/executor/clients/google/GoogleLLMClientTest.kt new file mode 100644 index 0000000000..d74f0cd053 --- /dev/null +++ b/prompt/prompt-executor/prompt-executor-clients/prompt-executor-google-client/src/jvmTest/kotlin/ai/koog/prompt/executor/clients/google/GoogleLLMClientTest.kt @@ -0,0 +1,40 @@ +package ai.koog.prompt.executor.clients.google + +import ai.koog.prompt.dsl.Prompt 
+import ai.koog.prompt.params.LLMParams +import kotlin.test.Test +import kotlin.test.assertEquals + +class GoogleLLMClientTest { + + @Test + fun `createGoogleRequest should use null maxTokens if unspecified`() { + val client = GoogleLLMClient(apiKey = "apiKey") + val model = GoogleModels.Gemini2_5Pro + val request = client.createGoogleRequest( + prompt = Prompt( + messages = emptyList(), + id = "id" + ), + model = model, + tools = emptyList() + ) + assertEquals(null, request.generationConfig!!.maxOutputTokens) + } + + @Test + fun `createGoogleRequest should use maxTokens from user specified parameters when available`() { + val client = GoogleLLMClient(apiKey = "apiKey") + val model = GoogleModels.Gemini2_5Pro + val request = client.createGoogleRequest( + prompt = Prompt( + messages = emptyList(), + id = "id", + params = LLMParams(maxTokens = 100) + ), + model = model, + tools = emptyList() + ) + assertEquals(100, request.generationConfig!!.maxOutputTokens) + } +} diff --git a/prompt/prompt-executor/prompt-executor-clients/prompt-executor-ollama-client/build.gradle.kts b/prompt/prompt-executor/prompt-executor-clients/prompt-executor-ollama-client/build.gradle.kts index 23dce9c48e..cbabf98aa8 100644 --- a/prompt/prompt-executor/prompt-executor-clients/prompt-executor-ollama-client/build.gradle.kts +++ b/prompt/prompt-executor/prompt-executor-clients/prompt-executor-ollama-client/build.gradle.kts @@ -15,6 +15,7 @@ kotlin { api(project(":agents:agents-tools")) api(project(":prompt:prompt-llm")) api(project(":prompt:prompt-model")) + api(project(":prompt:prompt-tokenizer")) api(project(":agents:agents-tools")) api(project(":prompt:prompt-executor:prompt-executor-model")) api(project(":prompt:prompt-executor:prompt-executor-clients")) @@ -64,6 +65,9 @@ kotlin { dependencies { implementation(project(":test-utils")) implementation(project(":agents:agents-features:agents-features-event-handler")) + implementation(libs.kotlinx.coroutines.core) + 
implementation(libs.kotlinx.coroutines.test) + implementation(libs.ktor.client.mock) } } diff --git a/prompt/prompt-executor/prompt-executor-clients/prompt-executor-ollama-client/src/commonMain/kotlin/ai/koog/prompt/executor/ollama/client/ContextWindowStrategy.kt b/prompt/prompt-executor/prompt-executor-clients/prompt-executor-ollama-client/src/commonMain/kotlin/ai/koog/prompt/executor/ollama/client/ContextWindowStrategy.kt new file mode 100644 index 0000000000..8d9c378916 --- /dev/null +++ b/prompt/prompt-executor/prompt-executor-clients/prompt-executor-ollama-client/src/commonMain/kotlin/ai/koog/prompt/executor/ollama/client/ContextWindowStrategy.kt @@ -0,0 +1,153 @@ +package ai.koog.prompt.executor.ollama.client + +import ai.koog.prompt.dsl.Prompt +import ai.koog.prompt.llm.LLModel +import ai.koog.prompt.tokenizer.PromptTokenizer +import io.github.oshai.kotlinlogging.KotlinLogging + +private val logger = KotlinLogging.logger { } + +/** + * Represents a strategy for computing the context window length for `OllamaClient`. + * Different implementations define specific approaches to computing the context window length. + * Based on the context window length computed by this strategy, Ollama will truncate the context window accordingly. + * + * To decide the context window length, Ollama proceeds as follows: + * - If a `num_ctx` parameter is specified in the chat request, the context window length is set to that value. + * - If the model definition contains a `num_ctx` parameter, the context window length is set to that value. + * - If an `OLLAMA_CONTEXT_LENGTH` environment variable is set, the context window length is set to that value. + * - Otherwise, the context window length is set to the default value of 2048. + * + * Effectively, this strategy allows you to specify what `num_ctx` value will be set in chat requests sent to Ollama, + * for a given prompt and model. 
+ * + * Important: You will want to have a context window length that does not change often for a specific model. + * Indeed, Ollama will reload the model every time the context window length changes. + * + * Example implementations: + * - [ContextWindowStrategy.None] + * - [ContextWindowStrategy.Fixed] + * - [ContextWindowStrategy.FitPrompt] + */ +public interface ContextWindowStrategy { + + /** + * Computes the context length for a given prompt and language model. + * This may involve calculating the number of tokens used in the prompt + * and determining if it fits within the model's context length constraints. + * + * @param prompt The [Prompt] containing the list of messages, unique identifier, + * and language model parameters that describe the input for the LLM. + * @param model The [LLModel] representing the language model used to process the prompt, + * which includes its provider, identifier, capabilities, and context length. + * @return The context length as a [Long], indicating the number of tokens used + * in the prompt, or `null` if it cannot be calculated. + */ + public fun computeContextLength(prompt: Prompt, model: LLModel): Long? + + /** + * Provides companion object-related strategies for determining the context window length. + * It contains multiple strategies that are implemented as subtypes of [ContextWindowStrategy]. + */ + public companion object { + /** + * A strategy for letting the Ollama server decide the context window length. + * To decide the context window length, Ollama proceeds as follows: + * - If the model definition contains a `num_ctx` parameter, the context window length is set to that value. + * - If an `OLLAMA_CONTEXT_LENGTH` environment variable is set, the context window length is set to that value. + * - Otherwise, the context window length is set to the default value of 2048. + */ + public data object None : ContextWindowStrategy { + override fun computeContextLength(prompt: Prompt, model: LLModel): Long? 
= null + } + + /** + * A strategy for specifying a fixed context window length. + * If the given [contextLength] is more than the maximum context window length supported by the model, + * the context window length will be set to the maximum context window length supported by the model. + * + * @param contextLength The context window length to use. + */ + public data class Fixed(val contextLength: Long) : ContextWindowStrategy { + init { + require(contextLength > 0) { "Context length must be positive but was: $contextLength" } + } + + override fun computeContextLength(prompt: Prompt, model: LLModel): Long { + if (contextLength > model.contextLength) { + logger.warn { + "Context length $contextLength was more than what is supported by model '${model.id}'," + + " falling back to the model's maximum context length ${model.contextLength}" + } + return model.contextLength + } + return contextLength + } + } + + /** + * A strategy for computing the context window length based on the prompt length. + * + * @param promptTokenizer The [PromptTokenizer] to use for computing the prompt length, + * or null to use the last reported token usage. + * @param contextChunkSize The granularity to use for computing the context window length. Defaults to 2048. + * @param minimumChunkCount The minimum number of context chunks in the context. + * @param maximumChunkCount The maximum number of context chunks in the context. + * + * Example: contextChunkSize = 512, minimumChunkCount = 2, maximumChunkCount = 4, + * then [minimumContextLength] = 1024 and [maximumContextLength] = 2048 + */ + public data class FitPrompt( + val promptTokenizer: PromptTokenizer? = null, + val contextChunkSize: Long = 2048, + val minimumChunkCount: Long? = null, + val maximumChunkCount: Long? = null + ) : ContextWindowStrategy { + + private val minimumContextLength: Long? = minimumChunkCount?.let { cnt -> cnt * contextChunkSize } + private val maximumContextLength: Long? 
= maximumChunkCount?.let { cnt -> cnt * contextChunkSize } + + init { + require(contextChunkSize > 0) { "`contextChunkSize` must be greater than 0" } + require(minimumChunkCount == null || minimumChunkCount > 0) { + "`minimumChunkCount` must be a positive number or `null`" + } + + if (minimumChunkCount != null && maximumChunkCount != null) { + require(minimumChunkCount <= maximumChunkCount) { + "`maximumChunkCount` ($maximumChunkCount) must be greater or equal" + + " to `minimumChunkCount` ($minimumChunkCount)" + } + } + } + + override fun computeContextLength(prompt: Prompt, model: LLModel): Long? { + val promptLength = when { + promptTokenizer != null -> promptTokenizer.tokenCountFor(prompt) + prompt.latestTokenUsage != 0 -> prompt.latestTokenUsage + else -> null + } + + if (promptLength == null) return minimumContextLength + + if (maximumContextLength != null && promptLength > maximumContextLength) { + logger.warn { + "Prompt length $promptLength was more than " + + "the maximum context length $maximumContextLength provided" + } + return maximumContextLength + } + + if (promptLength > model.contextLength) { + logger.warn { + "Prompt length $promptLength was more than the maximum context length of model '${model.id}'," + + " falling back to the model's maximum context length ${model.contextLength}" + } + return model.contextLength + } + + return (promptLength / contextChunkSize + 1) * contextChunkSize + } + } + } +} diff --git a/prompt/prompt-executor/prompt-executor-clients/prompt-executor-ollama-client/src/commonMain/kotlin/ai/koog/prompt/executor/ollama/client/OllamaClient.kt b/prompt/prompt-executor/prompt-executor-clients/prompt-executor-ollama-client/src/commonMain/kotlin/ai/koog/prompt/executor/ollama/client/OllamaClient.kt index cd3bde49e4..1565ba06a1 100644 --- a/prompt/prompt-executor/prompt-executor-clients/prompt-executor-ollama-client/src/commonMain/kotlin/ai/koog/prompt/executor/ollama/client/OllamaClient.kt +++ 
b/prompt/prompt-executor/prompt-executor-clients/prompt-executor-ollama-client/src/commonMain/kotlin/ai/koog/prompt/executor/ollama/client/OllamaClient.kt @@ -20,7 +20,6 @@ import ai.koog.prompt.executor.ollama.client.dto.OllamaPullModelResponseDTO import ai.koog.prompt.executor.ollama.client.dto.OllamaShowModelRequestDTO import ai.koog.prompt.executor.ollama.client.dto.OllamaShowModelResponseDTO import ai.koog.prompt.executor.ollama.client.dto.extractOllamaJsonFormat -import ai.koog.prompt.executor.ollama.client.dto.extractOllamaOptions import ai.koog.prompt.executor.ollama.client.dto.getToolCalls import ai.koog.prompt.executor.ollama.client.dto.toOllamaChatMessages import ai.koog.prompt.executor.ollama.client.dto.toOllamaModelCard @@ -57,19 +56,23 @@ import kotlinx.serialization.json.Json /** * Client for interacting with the Ollama API with comprehensive model support. * + * Implements: + * - [LLMClient] for executing prompts and streaming responses. + * - [LLMEmbeddingProvider] for generating embeddings from input text. + * * @param baseUrl The base URL of the Ollama server. Defaults to "http://localhost:11434". * @param baseClient The underlying HTTP client used for making requests. * @param timeoutConfig Configuration for connection, request, and socket timeouts. * @param clock Clock instance used for tracking response metadata timestamps. - * Implements: - * - LLMClient for executing prompts and streaming responses. - * - LLMEmbeddingProvider for generating embeddings from input text. + * @param contextWindowStrategy The [ContextWindowStrategy] to use for computing context window lengths. + * Defaults to [ContextWindowStrategy.None]. 
*/ public class OllamaClient( public val baseUrl: String = "http://localhost:11434", baseClient: HttpClient = HttpClient(engineFactoryProvider()), timeoutConfig: ConnectionTimeoutConfig = ConnectionTimeoutConfig(), - private val clock: Clock = Clock.System + private val clock: Clock = Clock.System, + private val contextWindowStrategy: ContextWindowStrategy = ContextWindowStrategy.Companion.None, ) : LLMClient, LLMEmbeddingProvider { private companion object { @@ -159,7 +162,7 @@ public class OllamaClient( messages = prompt.toOllamaChatMessages(model), tools = if (tools.isNotEmpty()) tools.map { it.toOllamaTool() } else null, format = prompt.extractOllamaJsonFormat(), - options = prompt.extractOllamaOptions(), + options = extractOllamaOptions(prompt, model), stream = false, additionalProperties = prompt.params.additionalProperties ) @@ -239,7 +242,7 @@ public class OllamaClient( OllamaChatRequestDTO( model = model.id, messages = prompt.toOllamaChatMessages(model), - options = prompt.extractOllamaOptions(), + options = extractOllamaOptions(prompt, model), stream = true, additionalProperties = prompt.params.additionalProperties, ) @@ -274,6 +277,16 @@ public class OllamaClient( } } + /** + * Prepare Ollama chat request options from the given prompt and model. + */ + internal fun extractOllamaOptions(prompt: Prompt, model: LLModel): OllamaChatRequestDTO.Options { + return OllamaChatRequestDTO.Options( + temperature = prompt.params.temperature, + numCtx = contextWindowStrategy.computeContextLength(prompt, model), + ) + } + /** * Embeds the given text using the Ollama model. 
* diff --git a/prompt/prompt-executor/prompt-executor-clients/prompt-executor-ollama-client/src/commonMain/kotlin/ai/koog/prompt/executor/ollama/client/dto/OllamaConverters.kt b/prompt/prompt-executor/prompt-executor-clients/prompt-executor-ollama-client/src/commonMain/kotlin/ai/koog/prompt/executor/ollama/client/dto/OllamaConverters.kt index ba4d390d23..65d9916f07 100644 --- a/prompt/prompt-executor/prompt-executor-clients/prompt-executor-ollama-client/src/commonMain/kotlin/ai/koog/prompt/executor/ollama/client/dto/OllamaConverters.kt +++ b/prompt/prompt-executor/prompt-executor-clients/prompt-executor-ollama-client/src/commonMain/kotlin/ai/koog/prompt/executor/ollama/client/dto/OllamaConverters.kt @@ -111,14 +111,6 @@ internal fun Prompt.extractOllamaJsonFormat(): JsonObject? { return if (schema is LLMParams.Schema.JSON) schema.schema else null } -/** - * Extracts options from the prompt, if temperature is defined. - */ -internal fun Prompt.extractOllamaOptions(): OllamaChatRequestDTO.Options? { - val temperature = params.temperature - return temperature?.let { OllamaChatRequestDTO.Options(temperature = temperature) } -} - /** * Extracts tool calls from a ChatMessage. * Returns the first tool call for compatibility, but logs if multiple calls exist. 
diff --git a/prompt/prompt-executor/prompt-executor-clients/prompt-executor-ollama-client/src/commonMain/kotlin/ai/koog/prompt/executor/ollama/client/dto/OllamaModels.kt b/prompt/prompt-executor/prompt-executor-clients/prompt-executor-ollama-client/src/commonMain/kotlin/ai/koog/prompt/executor/ollama/client/dto/OllamaModels.kt index f9a7da7c73..8b29db836b 100644 --- a/prompt/prompt-executor/prompt-executor-clients/prompt-executor-ollama-client/src/commonMain/kotlin/ai/koog/prompt/executor/ollama/client/dto/OllamaModels.kt +++ b/prompt/prompt-executor/prompt-executor-clients/prompt-executor-ollama-client/src/commonMain/kotlin/ai/koog/prompt/executor/ollama/client/dto/OllamaModels.kt @@ -72,6 +72,7 @@ internal data class OllamaChatRequestDTO( @Serializable internal data class Options( val temperature: Double? = null, + @SerialName("num_ctx") val numCtx: Long? = null, ) } diff --git a/prompt/prompt-executor/prompt-executor-clients/prompt-executor-ollama-client/src/commonTest/kotlin/ai/koog/prompt/executor/ollama/client/ContextWindowStrategyTest.kt b/prompt/prompt-executor/prompt-executor-clients/prompt-executor-ollama-client/src/commonTest/kotlin/ai/koog/prompt/executor/ollama/client/ContextWindowStrategyTest.kt new file mode 100644 index 0000000000..f56f7441f4 --- /dev/null +++ b/prompt/prompt-executor/prompt-executor-clients/prompt-executor-ollama-client/src/commonTest/kotlin/ai/koog/prompt/executor/ollama/client/ContextWindowStrategyTest.kt @@ -0,0 +1,231 @@ +package ai.koog.prompt.executor.ollama.client + +import ai.koog.prompt.dsl.Prompt +import ai.koog.prompt.dsl.prompt +import ai.koog.prompt.executor.ollama.client.dto.OllamaChatMessageDTO +import ai.koog.prompt.executor.ollama.client.dto.OllamaChatRequestDTO +import ai.koog.prompt.executor.ollama.client.dto.OllamaChatResponseDTO +import ai.koog.prompt.llm.OllamaModels +import ai.koog.prompt.message.Message +import ai.koog.prompt.message.ResponseMetaInfo +import ai.koog.prompt.tokenizer.PromptTokenizer +import 
io.ktor.client.HttpClient +import io.ktor.client.engine.mock.MockEngine +import io.ktor.client.engine.mock.respond +import io.ktor.client.request.HttpRequestData +import io.ktor.http.HttpHeaders +import io.ktor.http.HttpStatusCode +import io.ktor.http.content.TextContent +import io.ktor.http.headersOf +import kotlinx.coroutines.test.runTest +import kotlinx.datetime.Clock +import kotlinx.serialization.json.Json +import kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.assertNotNull +import kotlin.test.assertNull + +class ContextWindowStrategyTest { + @Test + fun `test None strategy`() = runTest { + val mockServer = MockOllamaChatServer { request -> makeDummyResponse(request) } + + val ollamaClient = OllamaClient( + baseClient = HttpClient(mockServer.mockEngine), + contextWindowStrategy = ContextWindowStrategy.Companion.None, + ) + + ollamaClient.execute( + prompt = prompt("test-prompt") { }, + model = OllamaModels.Meta.LLAMA_3_2, + ) + + val requestHistory = mockServer.requestHistory + assertEquals(requestHistory.size, 1) + + val response = requestHistory.first() + assertNotNull(response.options) + assertNull(response.options.numCtx) + } + + @Test + fun `test Fixed strategy`() = runTest { + val mockServer = MockOllamaChatServer { request -> makeDummyResponse(request) } + + val ollamaClient = OllamaClient( + baseClient = HttpClient(mockServer.mockEngine), + contextWindowStrategy = ContextWindowStrategy.Companion.Fixed(42), + ) + + ollamaClient.execute( + prompt = prompt("test-prompt") { }, + model = OllamaModels.Meta.LLAMA_3_2, + ) + + val requestHistory = mockServer.requestHistory + assertEquals(requestHistory.size, 1) + + val response = requestHistory.first() + assertNotNull(response.options) + assertEquals(42, response.options.numCtx) + } + + @Test + fun `test FitPrompt strategy with tokenizer`() = runTest { + val mockServer = MockOllamaChatServer { request -> makeDummyResponse(request) } + + val ollamaClient = OllamaClient( + baseClient = 
HttpClient(mockServer.mockEngine), + contextWindowStrategy = ContextWindowStrategy.Companion.FitPrompt( + promptTokenizer = object : PromptTokenizer { + override fun tokenCountFor(message: Message): Int = error("Not needed") + override fun tokenCountFor(prompt: Prompt): Int = 3000 + }, + contextChunkSize = 1024, + minimumChunkCount = 2 + ), + ) + + ollamaClient.execute( + prompt = prompt("test-prompt") { }, + model = OllamaModels.Meta.LLAMA_3_2, + ) + + val requestHistory = mockServer.requestHistory + assertEquals(requestHistory.size, 1) + + val response = requestHistory.first() + assertNotNull(response.options) + assertEquals(3072, response.options.numCtx) + } + + @Test + fun `test FitPrompt strategy without tokenizer and no previous token usage`() = runTest { + val mockServer = MockOllamaChatServer { request -> makeDummyResponse(request) } + + val ollamaClient = OllamaClient( + baseClient = HttpClient(mockServer.mockEngine), + contextWindowStrategy = ContextWindowStrategy.Companion.FitPrompt( + promptTokenizer = null, + contextChunkSize = 1024, + minimumChunkCount = 2 + ), + ) + + ollamaClient.execute( + prompt = prompt("test-prompt") { }, + model = OllamaModels.Meta.LLAMA_3_2, + ) + + val requestHistory = mockServer.requestHistory + assertEquals(requestHistory.size, 1) + + val response = requestHistory.first() + assertNotNull(response.options) + assertEquals(2048, response.options.numCtx) + } + + @Test + fun `test FitPrompt strategy without tokenizer and existing token usage`() = runTest { + val mockServer = MockOllamaChatServer { request -> makeDummyResponse(request) } + + val ollamaClient = OllamaClient( + baseClient = HttpClient(mockServer.mockEngine), + contextWindowStrategy = ContextWindowStrategy.Companion.FitPrompt( + promptTokenizer = null, + contextChunkSize = 1024, + minimumChunkCount = 2 + ), + ) + + ollamaClient.execute( + prompt = prompt("test-prompt") { + message( + Message.Assistant( + "Dummy message", + metaInfo = ResponseMetaInfo( + timestamp = 
Clock.System.now(), + totalTokensCount = 5000, + ) + ) + ) + }, + model = OllamaModels.Meta.LLAMA_3_2, + ) + + val requestHistory = mockServer.requestHistory + assertEquals(requestHistory.size, 1) + + val response = requestHistory.first() + assertNotNull(response.options) + assertEquals(5120, response.options.numCtx) + } + + @Test + fun `test FitPrompt strategy with tokenizer and too long prompt`() = runTest { + val mockServer = MockOllamaChatServer { request -> makeDummyResponse(request) } + + val ollamaClient = OllamaClient( + baseClient = HttpClient(mockServer.mockEngine), + contextWindowStrategy = ContextWindowStrategy.Companion.FitPrompt( + promptTokenizer = object : PromptTokenizer { + override fun tokenCountFor(message: Message): Int = error("Not needed") + override fun tokenCountFor(prompt: Prompt): Int = 9000 + }, + contextChunkSize = 1024, + minimumChunkCount = 2 + ), + ) + + ollamaClient.execute( + prompt = prompt("test-prompt") { }, + model = OllamaModels.Meta.LLAMA_3_2.copy( + contextLength = 8192 + ), + ) + + val requestHistory = mockServer.requestHistory + assertEquals(requestHistory.size, 1) + + val response = requestHistory.first() + assertNotNull(response.options) + assertEquals(8192, response.options.numCtx) + } +} + +private fun makeDummyResponse( + request: OllamaChatRequestDTO, + content: String = "OK", + promptEvalCount: Int = 10, + evalCount: Int = 100, +): OllamaChatResponseDTO = OllamaChatResponseDTO( + model = request.model, + message = OllamaChatMessageDTO(role = "assistant", content = content), + done = true, + promptEvalCount = promptEvalCount, + evalCount = evalCount, +) + +private class MockOllamaChatServer( + private val handler: (OllamaChatRequestDTO) -> OllamaChatResponseDTO, +) { + val mockEngine = MockEngine { requestData -> + val request = requestData.extractChatRequest() + val response = handler(request) + respond( + content = Json.encodeToString(response), + status = HttpStatusCode.OK, + headers = 
headersOf(HttpHeaders.ContentType to listOf("application/json")), + ) + } + + val requestHistory: List + get() = mockEngine.requestHistory.map { it.extractChatRequest() } + + private fun HttpRequestData.extractChatRequest(): OllamaChatRequestDTO { + val requestContent = body as TextContent + val requestBody = requestContent.text + val request = Json.decodeFromString(requestBody) + return request + } +} diff --git a/prompt/prompt-executor/prompt-executor-clients/prompt-executor-openai-client-base/src/commonMain/kotlin/ai/koog/prompt/executor/clients/openai/base/AbstractOpenAILLMClient.kt b/prompt/prompt-executor/prompt-executor-clients/prompt-executor-openai-client-base/src/commonMain/kotlin/ai/koog/prompt/executor/clients/openai/base/AbstractOpenAILLMClient.kt index 12e44f8317..3ba256a9d6 100644 --- a/prompt/prompt-executor/prompt-executor-clients/prompt-executor-openai-client-base/src/commonMain/kotlin/ai/koog/prompt/executor/clients/openai/base/AbstractOpenAILLMClient.kt +++ b/prompt/prompt-executor/prompt-executor-clients/prompt-executor-openai-client-base/src/commonMain/kotlin/ai/koog/prompt/executor/clients/openai/base/AbstractOpenAILLMClient.kt @@ -12,7 +12,6 @@ import ai.koog.prompt.executor.clients.openai.base.models.Content import ai.koog.prompt.executor.clients.openai.base.models.JsonSchemaObject import ai.koog.prompt.executor.clients.openai.base.models.OpenAIBaseLLMResponse import ai.koog.prompt.executor.clients.openai.base.models.OpenAIBaseLLMStreamResponse -import ai.koog.prompt.executor.clients.openai.base.models.OpenAIChoice import ai.koog.prompt.executor.clients.openai.base.models.OpenAIContentPart import ai.koog.prompt.executor.clients.openai.base.models.OpenAIFunction import ai.koog.prompt.executor.clients.openai.base.models.OpenAIMessage @@ -406,10 +405,10 @@ public abstract class AbstractOpenAILLMClient { + protected fun OpenAIMessage.toMessageResponses(finishReason: String?, metaInfo: ResponseMetaInfo): List { return when { - message is 
OpenAIMessage.Assistant && !message.toolCalls.isNullOrEmpty() -> { - message.toolCalls.map { toolCall -> + this is OpenAIMessage.Assistant && !this.toolCalls.isNullOrEmpty() -> { + this.toolCalls.map { toolCall -> Message.Tool.Call( id = toolCall.id, tool = toolCall.function.name, @@ -419,20 +418,20 @@ public abstract class AbstractOpenAILLMClient listOf( + this.content != null -> listOf( Message.Assistant( - content = message.content!!.text(), + content = this.content!!.text(), finishReason = finishReason, metaInfo = metaInfo ) ) - message is OpenAIMessage.Assistant && message.audio?.data != null -> listOf( + this is OpenAIMessage.Assistant && this.audio?.data != null -> listOf( Message.Assistant( - content = message.audio.transcript.orEmpty(), + content = this.audio.transcript.orEmpty(), attachments = listOf( Attachment.Audio( - content = AttachmentContent.Binary.Base64(message.audio.data), + content = AttachmentContent.Binary.Base64(this.audio.data), format = "unknown", // FIXME: clarify format from response ) ), diff --git a/prompt/prompt-executor/prompt-executor-clients/prompt-executor-openai-client-base/src/commonMain/kotlin/ai/koog/prompt/executor/clients/openai/base/models/OpenAIDataModels.kt b/prompt/prompt-executor/prompt-executor-clients/prompt-executor-openai-client-base/src/commonMain/kotlin/ai/koog/prompt/executor/clients/openai/base/models/OpenAIDataModels.kt index 089b22321f..48fda42ee8 100644 --- a/prompt/prompt-executor/prompt-executor-clients/prompt-executor-openai-client-base/src/commonMain/kotlin/ai/koog/prompt/executor/clients/openai/base/models/OpenAIDataModels.kt +++ b/prompt/prompt-executor/prompt-executor-clients/prompt-executor-openai-client-base/src/commonMain/kotlin/ai/koog/prompt/executor/clients/openai/base/models/OpenAIDataModels.kt @@ -352,6 +352,8 @@ public class OpenAIStreamFunction( * [coral][OpenAIAudioVoice.Coral], [echo][OpenAIAudioVoice.Echo], [fable][OpenAIAudioVoice.Fable], * [nova][OpenAIAudioVoice.Nova], 
[onyx][OpenAIAudioVoice.Onyx], [sage][OpenAIAudioVoice.Sage] * and [shimmer][OpenAIAudioVoice.Shimmer] + * + * See [audio](https://platform.openai.com/docs/api-reference/chat/create#chat-create-audio) */ @Serializable public class OpenAIAudioConfig( @@ -364,6 +366,8 @@ public class OpenAIAudioConfig( * * This enum is used to specify the format of the audio output. It contains several standard audio formats * which are widely compatible with various audio players and systems. + * + * See [audio/format](https://platform.openai.com/docs/api-reference/chat/create#chat-create-audio-format) */ @Serializable public enum class OpenAIAudioFormat { @@ -387,6 +391,8 @@ public enum class OpenAIAudioFormat { * Represents the available voice options for audio output in OpenAI's system. * * This enum defines a list of predefined voices that can be used to synthesize audio responses. + * + * See [audio/voice](https://platform.openai.com/docs/api-reference/chat/create#chat-create-audio-voice) */ @Serializable public enum class OpenAIAudioVoice { @@ -458,6 +464,8 @@ public class OpenAIStaticContent(public val content: Content) { * If not set, the model/provider default applies. * * Serialized as `"minimal" | "low" | "medium" | "high"`. + * + * See [reasoning_effort](https://platform.openai.com/docs/api-reference/chat/create#chat-create-reasoning_effort) */ @Serializable public enum class ReasoningEffort { @@ -549,6 +557,8 @@ public sealed interface OpenAIResponseFormat { * Note: When a tier is requested, the response payload includes the * `service_tier` actually used to serve the request. This value may differ from * the one provided in the request. + * + * See [service_tier](https://platform.openai.com/docs/api-reference/chat/create#chat-create-service_tier) */ public enum class ServiceTier { /** @@ -613,6 +623,8 @@ public class JsonSchemaObject( * All other chunks will also include a `usage` field, but with a null value. 
* NOTE: If the stream is interrupted, * you may not receive the final usage chunk which contains the total token usage for the request. + * + * See [stream_options](https://platform.openai.com/docs/api-reference/chat/create#chat-create-stream_options) */ @Serializable public class OpenAIStreamOptions(public val includeUsage: Boolean? = null) @@ -764,6 +776,8 @@ public class OpenAIUserLocation( * @property index The index of the choice in the list of choices. * @property logprobs Log probability information for the choice. * @property message A chat completion message generated by the model. + * + * See [choices](https://platform.openai.com/docs/api-reference/chat/object#chat/object-choices) */ @Serializable public class OpenAIChoice( @@ -776,6 +790,8 @@ public class OpenAIChoice( /** * @property content A list of message content tokens with log probability information. * @property refusal A list of message refusal tokens with log probability information. + * + * See [choices/logprobs](https://platform.openai.com/docs/api-reference/chat/object#chat/object-choices-logprobs) */ @Serializable public class OpenAIChoiceLogProbs( @@ -793,6 +809,9 @@ public class OpenAIChoiceLogProbs( * @property token The token. * @property topLogprobs List of the most likely tokens and their log probability, at this token position. * In rare cases, there may be fewer than the number of requested `[topLogprobs]` returned. + * + * See [choices/logprobs/content](https://platform.openai.com/docs/api-reference/chat/object#chat/object-choices-logprobs-content) + * and [choices/logprobs/refusal](https://platform.openai.com/docs/api-reference/chat/object#chat/object-choices-logprobs-refusal) */ @Serializable public class ContentLogProbs( @@ -810,6 +829,9 @@ public class OpenAIChoiceLogProbs( * @property logprob The log probability of this token, if it is within the top 20 most likely tokens. * Otherwise, the value `-9999.0` is used to signify that the token is very unlikely. 
* @property token The token. + * + * See [choices/logprobs/content/top_logprobs](https://platform.openai.com/docs/api-reference/chat/object#chat/object-choices-logprobs-content-top_logprobs) + * and [choices/logprobs/refusal/top_logprobs](https://platform.openai.com/docs/api-reference/chat/object#chat/object-choices-logprobs-refusal-top_logprobs) */ @Serializable public class ContentTopLogProbs( @@ -849,6 +871,9 @@ public class OpenAIWebUrlCitation(public val urlCitation: Citation) { * @property totalTokens Total number of tokens used in the request (prompt + completion). * @property completionTokensDetails Breakdown of tokens used in a completion. * @property promptTokensDetails Breakdown of tokens used in the prompt. + * + * See [chat completions usage](https://platform.openai.com/docs/api-reference/chat/object#chat/object-usage) + * and [streaming usage](https://platform.openai.com/docs/api-reference/chat-streaming/streaming#chat-streaming/streaming-usage) */ @Serializable public class OpenAIUsage( @@ -868,6 +893,9 @@ public class OpenAIUsage( * the number of tokens in the prediction that did not appear in the completion. * However, like reasoning tokens, these tokens are still counted in the total completion tokens for purposes of billing, * output and context window limits. + * + * See [chat completions usage/completion_tokens_details](https://platform.openai.com/docs/api-reference/chat/object#chat/object-usage-completion_tokens_details) + * and [streaming usage/completion_tokens_details](https://platform.openai.com/docs/api-reference/chat-streaming/streaming#chat-streaming/streaming-usage-completion_tokens_details) */ @Serializable public class CompletionTokensDetails( @@ -880,6 +908,9 @@ public class CompletionTokensDetails( /** * @property audioTokens Audio input tokens generated by the model. * @property cachedTokens Cached tokens present in the prompt. 
+ * + * See [chat completions usage/prompt_tokens_details](https://platform.openai.com/docs/api-reference/chat/object#chat/object-usage-prompt_tokens_details) + * and [streaming usage/prompt_tokens_details](https://platform.openai.com/docs/api-reference/chat-streaming/streaming#chat-streaming/streaming-usage-prompt_tokens_details) */ @Serializable public class PromptTokensDetails( @@ -897,6 +928,7 @@ public class PromptTokensDetails( * @property index The index of the choice in the list of choices. * @property logprobs Log probability information for the choice. * + * See [choices](https://platform.openai.com/docs/api-reference/chat-streaming/streaming#chat-streaming/streaming-choices) */ @Serializable public class OpenAIStreamChoice( @@ -911,6 +943,8 @@ public class OpenAIStreamChoice( * @property refusal The refusal message generated by the model. * @property role The role of the author of this message. * @property toolCalls + * + * See [choices/delta](https://platform.openai.com/docs/api-reference/chat-streaming/streaming#chat-streaming/streaming-choices-delta) */ @Serializable public class OpenAIStreamDelta( diff --git a/prompt/prompt-executor/prompt-executor-clients/prompt-executor-openai-client/src/commonMain/kotlin/ai/koog/prompt/executor/clients/openai/OpenAILLMClient.kt b/prompt/prompt-executor/prompt-executor-clients/prompt-executor-openai-client/src/commonMain/kotlin/ai/koog/prompt/executor/clients/openai/OpenAILLMClient.kt index c835f4424d..c7fd62ffe1 100644 --- a/prompt/prompt-executor/prompt-executor-clients/prompt-executor-openai-client/src/commonMain/kotlin/ai/koog/prompt/executor/clients/openai/OpenAILLMClient.kt +++ b/prompt/prompt-executor/prompt-executor-clients/prompt-executor-openai-client/src/commonMain/kotlin/ai/koog/prompt/executor/clients/openai/OpenAILLMClient.kt @@ -223,7 +223,12 @@ public open class OpenAILLMClient( override fun processProviderChatResponse(response: OpenAIChatCompletionResponse): List { 
require(response.choices.isNotEmpty()) { "Empty choices in response" } - return response.choices.map { it.toMessageResponses(createMetaInfo(response.usage)) } + return response.choices.map { + it.message.toMessageResponses( + it.finishReason, + createMetaInfo(response.usage), + ) + } } override fun decodeStreamingResponse(data: String): OpenAIChatCompletionStreamResponse = diff --git a/prompt/prompt-executor/prompt-executor-clients/prompt-executor-openai-client/src/commonMain/kotlin/ai/koog/prompt/executor/clients/openai/models/OpenAIChatCompletion.kt b/prompt/prompt-executor/prompt-executor-clients/prompt-executor-openai-client/src/commonMain/kotlin/ai/koog/prompt/executor/clients/openai/models/OpenAIChatCompletion.kt index a1c220c112..ad0d6a8b35 100644 --- a/prompt/prompt-executor/prompt-executor-clients/prompt-executor-openai-client/src/commonMain/kotlin/ai/koog/prompt/executor/clients/openai/models/OpenAIChatCompletion.kt +++ b/prompt/prompt-executor/prompt-executor-clients/prompt-executor-openai-client/src/commonMain/kotlin/ai/koog/prompt/executor/clients/openai/models/OpenAIChatCompletion.kt @@ -4,7 +4,7 @@ import ai.koog.prompt.executor.clients.openai.base.models.OpenAIAudioConfig import ai.koog.prompt.executor.clients.openai.base.models.OpenAIBaseLLMRequest import ai.koog.prompt.executor.clients.openai.base.models.OpenAIBaseLLMResponse import ai.koog.prompt.executor.clients.openai.base.models.OpenAIBaseLLMStreamResponse -import ai.koog.prompt.executor.clients.openai.base.models.OpenAIChoice +import ai.koog.prompt.executor.clients.openai.base.models.OpenAIChoiceLogProbs import ai.koog.prompt.executor.clients.openai.base.models.OpenAIMessage import ai.koog.prompt.executor.clients.openai.base.models.OpenAIModalities import ai.koog.prompt.executor.clients.openai.base.models.OpenAIResponseFormat @@ -194,6 +194,28 @@ internal class OpenAIChatCompletionRequest( val additionalProperties: Map? 
= null, ) : OpenAIBaseLLMRequest +/** + * Chat completion choice + * + * @property finishReason The reason the model stopped generating tokens. + * This will be `stop` if the model hit a natural stop point or a provided stop sequence, + * `length` if the maximum number of tokens specified in the request was reached, + * `content_filter` if content was omitted due to a flag from our content filters, + * `tool_calls` if the model called a tool, or `function_call` (deprecated) if the model called a function. + * @property index The index of the choice in the list of choices. + * @property logprobs Log probability information for the choice. + * @property message A chat completion message generated by the model. + * + * See [choices](https://platform.openai.com/docs/api-reference/chat/object#chat/object-choices) + */ +@Serializable +public class OpenAIChoice( + public val finishReason: String, + public val index: Int, + public val logprobs: OpenAIChoiceLogProbs? = null, + public val message: OpenAIMessage, +) + /** * Represents the response from the OpenAI chat completion API. * @@ -221,6 +243,8 @@ internal class OpenAIChatCompletionRequest( * Can be used in conjunction with the `seed` request parameter * to understand when backend changes have been made that might impact determinism. * @property usage Usage statistics for the completion request. + * + * See [The chat completion object](https://platform.openai.com/docs/api-reference/chat/object) */ @Serializable public class OpenAIChatCompletionResponse( @@ -264,6 +288,8 @@ public class OpenAIChatCompletionResponse( * Can be used in conjunction with the `seed` request parameter * to understand when backend changes have been made that might impact determinism. * @property usage Usage statistics for the completion request. 
+ * + * See [The chat completion chunk object](https://platform.openai.com/docs/api-reference/chat-streaming/streaming) */ @Serializable public class OpenAIChatCompletionStreamResponse( diff --git a/prompt/prompt-executor/prompt-executor-clients/prompt-executor-openrouter-client/build.gradle.kts b/prompt/prompt-executor/prompt-executor-clients/prompt-executor-openrouter-client/build.gradle.kts index 1a64d6a0ba..528a6e61b2 100644 --- a/prompt/prompt-executor/prompt-executor-clients/prompt-executor-openrouter-client/build.gradle.kts +++ b/prompt/prompt-executor/prompt-executor-clients/prompt-executor-openrouter-client/build.gradle.kts @@ -13,6 +13,8 @@ kotlin { commonMain { dependencies { api(project(":prompt:prompt-executor:prompt-executor-clients:prompt-executor-openai-client-base")) + api(project(":prompt:prompt-structure")) + api(project(":prompt:prompt-executor:prompt-executor-clients:prompt-executor-openai-client")) implementation(libs.oshai.kotlin.logging) } } diff --git a/prompt/prompt-executor/prompt-executor-clients/prompt-executor-openrouter-client/src/commonMain/kotlin/ai/koog/prompt/executor/clients/openrouter/OpenRouterLLMClient.kt b/prompt/prompt-executor/prompt-executor-clients/prompt-executor-openrouter-client/src/commonMain/kotlin/ai/koog/prompt/executor/clients/openrouter/OpenRouterLLMClient.kt index 6028cb9dbf..15fb337a30 100644 --- a/prompt/prompt-executor/prompt-executor-clients/prompt-executor-openrouter-client/src/commonMain/kotlin/ai/koog/prompt/executor/clients/openrouter/OpenRouterLLMClient.kt +++ b/prompt/prompt-executor/prompt-executor-clients/prompt-executor-openrouter-client/src/commonMain/kotlin/ai/koog/prompt/executor/clients/openrouter/OpenRouterLLMClient.kt @@ -11,14 +11,20 @@ import ai.koog.prompt.executor.clients.openai.base.models.OpenAIMessage import ai.koog.prompt.executor.clients.openai.base.models.OpenAIStaticContent import ai.koog.prompt.executor.clients.openai.base.models.OpenAITool import 
ai.koog.prompt.executor.clients.openai.base.models.OpenAIToolChoice +import ai.koog.prompt.executor.clients.openai.structure.OpenAIBasicJsonSchemaGenerator +import ai.koog.prompt.executor.clients.openai.structure.OpenAIStandardJsonSchemaGenerator import ai.koog.prompt.executor.clients.openrouter.models.OpenRouterChatCompletionRequest import ai.koog.prompt.executor.clients.openrouter.models.OpenRouterChatCompletionRequestSerializer import ai.koog.prompt.executor.clients.openrouter.models.OpenRouterChatCompletionResponse import ai.koog.prompt.executor.clients.openrouter.models.OpenRouterChatCompletionStreamResponse import ai.koog.prompt.executor.model.LLMChoice +import ai.koog.prompt.llm.LLMProvider import ai.koog.prompt.llm.LLModel import ai.koog.prompt.params.LLMParams import ai.koog.prompt.streaming.StreamFrameFlowBuilder +import ai.koog.prompt.structure.RegisteredBasicJsonSchemaGenerators +import ai.koog.prompt.structure.RegisteredStandardJsonSchemaGenerators +import ai.koog.prompt.structure.annotations.InternalStructuredOutputApi import io.github.oshai.kotlinlogging.KotlinLogging import io.ktor.client.HttpClient import kotlinx.datetime.Clock @@ -56,8 +62,14 @@ public class OpenRouterLLMClient( staticLogger ) { + @OptIn(InternalStructuredOutputApi::class) private companion object { private val staticLogger = KotlinLogging.logger { } + + init { + RegisteredBasicJsonSchemaGenerators[LLMProvider.OpenRouter] = OpenAIBasicJsonSchemaGenerator + RegisteredStandardJsonSchemaGenerators[LLMProvider.OpenRouter] = OpenAIStandardJsonSchemaGenerator + } } override fun serializeProviderChatRequest( @@ -104,7 +116,12 @@ public class OpenRouterLLMClient( override fun processProviderChatResponse(response: OpenRouterChatCompletionResponse): List { require(response.choices.isNotEmpty()) { "Empty choices in response" } - return response.choices.map { it.toMessageResponses(createMetaInfo(response.usage)) } + return response.choices.map { + it.message.toMessageResponses( + 
it.finishReason, + createMetaInfo(response.usage), + ) + } } override fun decodeStreamingResponse(data: String): OpenRouterChatCompletionStreamResponse = @@ -116,11 +133,10 @@ public class OpenRouterLLMClient( override suspend fun StreamFrameFlowBuilder.processStreamingChunk(chunk: OpenRouterChatCompletionStreamResponse) { chunk.choices.firstOrNull()?.let { choice -> choice.delta.content?.let { emitAppend(it) } - choice.delta.toolCalls?.forEach { openAIToolCall -> - val index = openAIToolCall.index + choice.delta.toolCalls?.forEachIndexed { index, openAIToolCall -> val id = openAIToolCall.id - val name = openAIToolCall.function?.name - val arguments = openAIToolCall.function?.arguments + val name = openAIToolCall.function.name + val arguments = openAIToolCall.function.arguments upsertToolCall(index, id, name, arguments) } choice.finishReason?.let { emitEnd(it, createMetaInfo(chunk.usage)) } diff --git a/prompt/prompt-executor/prompt-executor-clients/prompt-executor-openrouter-client/src/commonMain/kotlin/ai/koog/prompt/executor/clients/openrouter/OpenRouterModels.kt b/prompt/prompt-executor/prompt-executor-clients/prompt-executor-openrouter-client/src/commonMain/kotlin/ai/koog/prompt/executor/clients/openrouter/OpenRouterModels.kt index f1d4e033cf..ca16b67595 100644 --- a/prompt/prompt-executor/prompt-executor-clients/prompt-executor-openrouter-client/src/commonMain/kotlin/ai/koog/prompt/executor/clients/openrouter/OpenRouterModels.kt +++ b/prompt/prompt-executor/prompt-executor-clients/prompt-executor-openrouter-client/src/commonMain/kotlin/ai/koog/prompt/executor/clients/openrouter/OpenRouterModels.kt @@ -16,13 +16,20 @@ public object OpenRouterModels : LLModelDefinitions { */ private val standardCapabilities: List = listOf( LLMCapability.Temperature, - LLMCapability.Schema.JSON.Standard, LLMCapability.Speculation, LLMCapability.Tools, - LLMCapability.ToolChoice, LLMCapability.Completion ) + /** + * Additional capabilities available for models out of the Claude 
family. + * Includes structured output support and tool choice. + */ + private val additionalStandardCapabilities: List = listOf( + LLMCapability.Schema.JSON.Standard, + LLMCapability.ToolChoice + ) + /** * Multimodal capabilities including vision support. * Extends standard capabilities with image vision processing. @@ -139,7 +146,7 @@ public object OpenRouterModels : LLModelDefinitions { public val GPT4oMini: LLModel = LLModel( provider = LLMProvider.OpenRouter, id = "openai/gpt-4o-mini", - capabilities = multimodalCapabilities, + capabilities = multimodalCapabilities + additionalStandardCapabilities, contextLength = 128_000, maxOutputTokens = 16_400, ) @@ -152,7 +159,7 @@ public object OpenRouterModels : LLModelDefinitions { public val GPT5Chat: LLModel = LLModel( provider = LLMProvider.OpenRouter, id = "openai/gpt-5-chat", - capabilities = multimodalCapabilities, + capabilities = multimodalCapabilities + additionalStandardCapabilities, contextLength = 400_000, maxOutputTokens = 128_000, ) @@ -167,7 +174,7 @@ public object OpenRouterModels : LLModelDefinitions { public val GPT5: LLModel = LLModel( provider = LLMProvider.OpenRouter, id = "openai/gpt-5", - capabilities = standardCapabilities, + capabilities = standardCapabilities + additionalStandardCapabilities, contextLength = 400_000, ) @@ -181,7 +188,7 @@ public object OpenRouterModels : LLModelDefinitions { public val GPT5Mini: LLModel = LLModel( provider = LLMProvider.OpenRouter, id = "openai/gpt-5-mini", - capabilities = standardCapabilities, + capabilities = standardCapabilities + additionalStandardCapabilities, contextLength = 400_000, ) @@ -195,7 +202,7 @@ public object OpenRouterModels : LLModelDefinitions { public val GPT5Nano: LLModel = LLModel( provider = LLMProvider.OpenRouter, id = "openai/gpt-5-nano", - capabilities = standardCapabilities, + capabilities = standardCapabilities + additionalStandardCapabilities, contextLength = 400_000, ) @@ -209,7 +216,7 @@ public object OpenRouterModels : 
LLModelDefinitions { public val GPT_OSS_120b: LLModel = LLModel( provider = LLMProvider.OpenRouter, id = "openai/gpt-oss-120b", - capabilities = standardCapabilities, + capabilities = standardCapabilities + additionalStandardCapabilities, contextLength = 400_000, ) @@ -223,7 +230,7 @@ public object OpenRouterModels : LLModelDefinitions { public val GPT4: LLModel = LLModel( provider = LLMProvider.OpenRouter, id = "openai/gpt-4", - capabilities = standardCapabilities, + capabilities = standardCapabilities + additionalStandardCapabilities, contextLength = 32_768, ) @@ -234,7 +241,7 @@ public object OpenRouterModels : LLModelDefinitions { public val GPT4o: LLModel = LLModel( provider = LLMProvider.OpenRouter, id = "openai/gpt-4o", - capabilities = multimodalCapabilities, + capabilities = multimodalCapabilities + additionalStandardCapabilities, contextLength = 128_000, ) @@ -248,7 +255,7 @@ public object OpenRouterModels : LLModelDefinitions { public val GPT4Turbo: LLModel = LLModel( provider = LLMProvider.OpenRouter, id = "openai/gpt-4-turbo", - capabilities = multimodalCapabilities, + capabilities = multimodalCapabilities + additionalStandardCapabilities, contextLength = 128_000, ) @@ -261,7 +268,7 @@ public object OpenRouterModels : LLModelDefinitions { public val GPT35Turbo: LLModel = LLModel( provider = LLMProvider.OpenRouter, id = "openai/gpt-3.5-turbo", - capabilities = standardCapabilities, + capabilities = standardCapabilities + additionalStandardCapabilities, contextLength = 16_385, ) @@ -275,7 +282,7 @@ public object OpenRouterModels : LLModelDefinitions { public val Llama3: LLModel = LLModel( provider = LLMProvider.OpenRouter, id = "meta/llama-3-70b", - capabilities = standardCapabilities, + capabilities = standardCapabilities + additionalStandardCapabilities, contextLength = 8_000, ) @@ -288,7 +295,7 @@ public object OpenRouterModels : LLModelDefinitions { public val Llama3Instruct: LLModel = LLModel( provider = LLMProvider.OpenRouter, id = 
"meta/llama-3-70b-instruct", - capabilities = standardCapabilities, + capabilities = standardCapabilities + additionalStandardCapabilities, contextLength = 8_000, ) @@ -303,7 +310,7 @@ public object OpenRouterModels : LLModelDefinitions { public val Mistral7B: LLModel = LLModel( provider = LLMProvider.OpenRouter, id = "mistral/mistral-7b", - capabilities = standardCapabilities, + capabilities = standardCapabilities + additionalStandardCapabilities, contextLength = 32_768, ) @@ -319,7 +326,7 @@ public object OpenRouterModels : LLModelDefinitions { public val Mixtral8x7B: LLModel = LLModel( provider = LLMProvider.OpenRouter, id = "mistral/mixtral-8x7b", - capabilities = standardCapabilities, + capabilities = standardCapabilities + additionalStandardCapabilities, contextLength = 32_768, ) @@ -375,7 +382,7 @@ public object OpenRouterModels : LLModelDefinitions { public val DeepSeekV30324: LLModel = LLModel( provider = LLMProvider.OpenRouter, id = "deepseek/deepseek-chat-v3-0324", - capabilities = standardCapabilities, + capabilities = standardCapabilities + additionalStandardCapabilities, contextLength = 163_800, maxOutputTokens = 163_800, ) @@ -387,7 +394,7 @@ public object OpenRouterModels : LLModelDefinitions { public val Gemini2_5FlashLite: LLModel = LLModel( provider = LLMProvider.OpenRouter, id = "google/gemini-2.5-flash-lite", - capabilities = multimodalCapabilities, + capabilities = multimodalCapabilities + additionalStandardCapabilities, contextLength = 1_048_576, maxOutputTokens = 65_600, ) @@ -399,7 +406,7 @@ public object OpenRouterModels : LLModelDefinitions { public val Gemini2_5Flash: LLModel = LLModel( provider = LLMProvider.OpenRouter, id = "google/gemini-2.5-flash", - capabilities = multimodalCapabilities, + capabilities = multimodalCapabilities + additionalStandardCapabilities, contextLength = 1_048_576, maxOutputTokens = 65_600, ) @@ -411,7 +418,7 @@ public object OpenRouterModels : LLModelDefinitions { public val Gemini2_5Pro: LLModel = LLModel( 
provider = LLMProvider.OpenRouter, id = "google/gemini-2.5-pro", - capabilities = multimodalCapabilities, + capabilities = multimodalCapabilities + additionalStandardCapabilities, contextLength = 1_048_576, maxOutputTokens = 65_600, ) diff --git a/prompt/prompt-executor/prompt-executor-clients/prompt-executor-openrouter-client/src/commonMain/kotlin/ai/koog/prompt/executor/clients/openrouter/models/OpenRouterChatCompletion.kt b/prompt/prompt-executor/prompt-executor-clients/prompt-executor-openrouter-client/src/commonMain/kotlin/ai/koog/prompt/executor/clients/openrouter/models/OpenRouterChatCompletion.kt index 111aea9123..0d40b3d26e 100644 --- a/prompt/prompt-executor/prompt-executor-clients/prompt-executor-openrouter-client/src/commonMain/kotlin/ai/koog/prompt/executor/clients/openrouter/models/OpenRouterChatCompletion.kt +++ b/prompt/prompt-executor/prompt-executor-clients/prompt-executor-openrouter-client/src/commonMain/kotlin/ai/koog/prompt/executor/clients/openrouter/models/OpenRouterChatCompletion.kt @@ -3,12 +3,11 @@ package ai.koog.prompt.executor.clients.openrouter.models import ai.koog.prompt.executor.clients.openai.base.models.OpenAIBaseLLMRequest import ai.koog.prompt.executor.clients.openai.base.models.OpenAIBaseLLMResponse import ai.koog.prompt.executor.clients.openai.base.models.OpenAIBaseLLMStreamResponse -import ai.koog.prompt.executor.clients.openai.base.models.OpenAIChoice import ai.koog.prompt.executor.clients.openai.base.models.OpenAIMessage import ai.koog.prompt.executor.clients.openai.base.models.OpenAIResponseFormat import ai.koog.prompt.executor.clients.openai.base.models.OpenAIStaticContent -import ai.koog.prompt.executor.clients.openai.base.models.OpenAIStreamChoice import ai.koog.prompt.executor.clients.openai.base.models.OpenAITool +import ai.koog.prompt.executor.clients.openai.base.models.OpenAIToolCall import ai.koog.prompt.executor.clients.openai.base.models.OpenAIToolChoice import 
ai.koog.prompt.executor.clients.openai.base.models.OpenAIUsage import ai.koog.prompt.executor.clients.serialization.AdditionalPropertiesFlatteningSerializer @@ -78,13 +77,85 @@ public class ProviderPreferences( public val maxPrice: Map? = null ) +/** + * Chat completion choice + * + * @property finishReason The reason the model stopped generating tokens. + * This will be `stop` if the model hit a natural stop point or a provided stop sequence, + * `length` if the maximum number of tokens specified in the request was reached, + * `content_filter` if content was omitted due to a flag from our content filters, + * `tool_calls` if the model called a tool, or `function_call` (deprecated) if the model called a function. + * `error` if the model finishes the request with an error. + * @property nativeFinishReason The raw finish_reason string returned by the model + * @property message A chat completion message generated by the model. + * @property error An error response structure typically used for conveying error details to the clients. + * + * See [CompletionsResponse Format](https://openrouter.ai/docs/api-reference/overview#completionsresponse-format) + */ +@Serializable +public class OpenRouterChoice( + public val finishReason: String? = null, + public val nativeFinishReason: String? = null, + public val message: OpenAIMessage, + public val error: ErrorResponse? = null +) + +/** + * Chat completion choice + * + * @property finishReason The reason the model stopped generating tokens. + * This will be `stop` if the model hit a natural stop point or a provided stop sequence, + * `length` if the maximum number of tokens specified in the request was reached, + * `content_filter` if content was omitted due to a flag from our content filters, + * `tool_calls` if the model called a tool, or `function_call` (deprecated) if the model called a function. + * `error` if the model finishes the request with an error. 
+ * @property nativeFinishReason The raw finish_reason string returned by the model + * @property delta A chat completion delta generated by streamed model responses. + * @property error An error response structure typically used for conveying error details to the clients. + * + * See [CompletionsResponse Format](https://openrouter.ai/docs/api-reference/overview#completionsresponse-format) + */ +@Serializable +public class OpenRouterStreamChoice( + public val finishReason: String? = null, + public val nativeFinishReason: String? = null, + public val delta: OpenRouterStreamDelta, + public val error: ErrorResponse? = null +) + +/** + * @property content The contents of the chunk message. + * @property role The role of the author of this message. + * @property toolCalls The tool calls requested by the model. + */ +@Serializable +public class OpenRouterStreamDelta( + public val content: String? = null, + public val role: String? = null, + public val toolCalls: List? = null +) + +/** + * Represents an error response structure typically used for conveying error details to the clients. + * + * @property code The numeric code representing the error. + * @property message A descriptive message providing details about the error. + * @property metadata Optional additional information about the error in the form of key-value pairs. + */ +@Serializable +public class ErrorResponse( + public val code: Int, + public val message: String, + public val metadata: Map? 
= null, +) + /** * OpenRouter Chat Completion Response - * https://openrouter.ai/docs/responses + * See [CompletionsResponse Format](https://openrouter.ai/docs/api-reference/overview#completionsresponse-format) */ @Serializable public class OpenRouterChatCompletionResponse( - public val choices: List, + public val choices: List, override val created: Long, override val id: String, override val model: String, @@ -96,10 +167,11 @@ public class OpenRouterChatCompletionResponse( /** * OpenRouter Chat Completion Streaming Response + * See [CompletionsResponse Format](https://openrouter.ai/docs/api-reference/overview#completionsresponse-format) */ @Serializable public class OpenRouterChatCompletionStreamResponse( - public val choices: List, + public val choices: List, override val created: Long, override val id: String, override val model: String, diff --git a/prompt/prompt-executor/prompt-executor-clients/prompt-executor-openrouter-client/src/jvmTest/kotlin/ai/koog/prompt/executor/clients/openrouter/models/OpenRouterSerializationTest.kt b/prompt/prompt-executor/prompt-executor-clients/prompt-executor-openrouter-client/src/jvmTest/kotlin/ai/koog/prompt/executor/clients/openrouter/models/OpenRouterSerializationTest.kt index 17517b3f67..58c74f6036 100644 --- a/prompt/prompt-executor/prompt-executor-clients/prompt-executor-openrouter-client/src/jvmTest/kotlin/ai/koog/prompt/executor/clients/openrouter/models/OpenRouterSerializationTest.kt +++ b/prompt/prompt-executor/prompt-executor-clients/prompt-executor-openrouter-client/src/jvmTest/kotlin/ai/koog/prompt/executor/clients/openrouter/models/OpenRouterSerializationTest.kt @@ -1,10 +1,17 @@ package ai.koog.prompt.executor.clients.openrouter.models import ai.koog.prompt.executor.clients.openai.base.models.Content +import ai.koog.prompt.executor.clients.openai.base.models.OpenAIFunction import ai.koog.prompt.executor.clients.openai.base.models.OpenAIMessage +import ai.koog.prompt.executor.clients.openai.base.models.OpenAITool 
+import ai.koog.prompt.executor.clients.openai.base.models.OpenAIToolCall +import ai.koog.prompt.executor.clients.openai.base.models.OpenAIToolChoice +import ai.koog.prompt.executor.clients.openai.base.models.OpenAIToolFunction import kotlinx.serialization.json.Json import kotlinx.serialization.json.JsonElement +import kotlinx.serialization.json.JsonNamingStrategy import kotlinx.serialization.json.JsonPrimitive +import kotlinx.serialization.json.add import kotlinx.serialization.json.addJsonObject import kotlinx.serialization.json.booleanOrNull import kotlinx.serialization.json.buildJsonArray @@ -12,6 +19,7 @@ import kotlinx.serialization.json.buildJsonObject import kotlinx.serialization.json.contentOrNull import kotlinx.serialization.json.doubleOrNull import kotlinx.serialization.json.intOrNull +import kotlinx.serialization.json.jsonArray import kotlinx.serialization.json.jsonObject import kotlinx.serialization.json.jsonPrimitive import kotlinx.serialization.json.put @@ -22,11 +30,19 @@ import kotlin.test.assertNull class OpenRouterSerializationTest { - private val json = Json { + private val requestJson = Json { ignoreUnknownKeys = false explicitNulls = false } + private val responseJson = Json { + ignoreUnknownKeys = false + explicitNulls = false + isLenient = true + encodeDefaults = true + namingStrategy = JsonNamingStrategy.SnakeCase + } + @Test fun `test serialization without additionalProperties`() { val request = OpenRouterChatCompletionRequest( @@ -37,7 +53,7 @@ class OpenRouterSerializationTest { stream = false ) - val jsonElement = json.encodeToJsonElement(OpenRouterChatCompletionRequestSerializer, request) + val jsonElement = requestJson.encodeToJsonElement(OpenRouterChatCompletionRequestSerializer, request) val jsonObject = jsonElement.jsonObject assertEquals("anthropic/claude-3-sonnet", jsonObject["model"]?.jsonPrimitive?.contentOrNull) @@ -62,7 +78,7 @@ class OpenRouterSerializationTest { additionalProperties = additionalProperties ) - val jsonElement 
= json.encodeToJsonElement(OpenRouterChatCompletionRequestSerializer, request) + val jsonElement = requestJson.encodeToJsonElement(OpenRouterChatCompletionRequestSerializer, request) val jsonObject = jsonElement.jsonObject // Standard properties should be present @@ -96,7 +112,7 @@ class OpenRouterSerializationTest { put("stream", JsonPrimitive(false)) } - val request = json.decodeFromJsonElement(OpenRouterChatCompletionRequestSerializer, jsonInput) + val request = requestJson.decodeFromJsonElement(OpenRouterChatCompletionRequestSerializer, jsonInput) assertEquals("anthropic/claude-3-sonnet", request.model) assertEquals(0.7, request.temperature) @@ -124,7 +140,7 @@ class OpenRouterSerializationTest { put("customBoolean", JsonPrimitive(true)) } - val request = json.decodeFromJsonElement(OpenRouterChatCompletionRequestSerializer, jsonInput) + val request = requestJson.decodeFromJsonElement(OpenRouterChatCompletionRequestSerializer, jsonInput) assertEquals("anthropic/claude-3-sonnet", request.model) assertEquals(0.7, request.temperature) @@ -152,10 +168,10 @@ class OpenRouterSerializationTest { ) // Serialize to JSON string - val jsonString = json.encodeToString(OpenRouterChatCompletionRequestSerializer, originalRequest) + val jsonString = requestJson.encodeToString(OpenRouterChatCompletionRequestSerializer, originalRequest) // Deserialize back to object - val deserializedRequest = json.decodeFromString(OpenRouterChatCompletionRequestSerializer, jsonString) + val deserializedRequest = requestJson.decodeFromString(OpenRouterChatCompletionRequestSerializer, jsonString) // Verify standard properties assertEquals(originalRequest.model, deserializedRequest.model) @@ -175,4 +191,405 @@ class OpenRouterSerializationTest { deserializedAdditionalProps["customNumber"]?.jsonPrimitive?.intOrNull ) } + + @Test + fun `test OpenRouter response deserialization`() { + val jsonInput = buildJsonObject { + put("id", "gen-xxxxxxxxxxxxxx") + put("created", 1699000000L) + put("model", 
"openai/gpt-3.5-turbo") + put("object", "chat.completion") + put("system_fingerprint", "fp_44709d6fcb") + put( + "choices", + buildJsonArray { + addJsonObject { + put("finish_reason", "stop") + put( + "message", + buildJsonObject { + put("role", "assistant") + put("content", "Hello there!") + } + ) + } + } + ) + put( + "usage", + buildJsonObject { + put("prompt_tokens", 10) + put("completion_tokens", 4) + put("total_tokens", 14) + } + ) + } + + val response = responseJson.decodeFromJsonElement(OpenRouterChatCompletionResponse.serializer(), jsonInput) + + assertEquals("gen-xxxxxxxxxxxxxx", response.id) + assertEquals(1699000000L, response.created) + assertEquals("openai/gpt-3.5-turbo", response.model) + assertEquals("chat.completion", response.objectType) + assertEquals("fp_44709d6fcb", response.systemFingerprint) + + assertEquals(1, response.choices.size) + val choice = response.choices.first() + assertEquals("stop", choice.finishReason) + + val message = choice.message as OpenAIMessage.Assistant + assertEquals("Hello there!", message.content?.text()) + + assertNotNull(response.usage) + assertEquals(10, response.usage.promptTokens) + assertEquals(4, response.usage.completionTokens) + assertEquals(14, response.usage.totalTokens) + } + + @Test + fun `test OpenRouter error response deserialization`() { + val jsonInput = buildJsonObject { + put("id", "gen-error-test") + put("created", 1699000000L) + put("model", "openai/gpt-4") + put("object", "chat.completion") + put( + "choices", + buildJsonArray { + addJsonObject { + put("finish_reason", "error") + put("native_finish_reason", "content_filter") + put( + "message", + buildJsonObject { + put("role", "assistant") + put("content", "") + } + ) + put( + "error", + buildJsonObject { + put("code", 400) + put("message", "Content filtered due to policy violation") + put( + "metadata", + buildJsonObject { + put("provider", "openai") + put("raw_error", "content_filter") + } + ) + } + ) + } + } + ) + } + + val response = 
responseJson.decodeFromJsonElement(OpenRouterChatCompletionResponse.serializer(), jsonInput) + + val choice = response.choices.first() + assertEquals("error", choice.finishReason) + assertEquals("content_filter", choice.nativeFinishReason) + + assertNotNull(choice.error) + assertEquals(400, choice.error.code) + assertEquals("Content filtered due to policy violation", choice.error.message) + assertNotNull(choice.error.metadata) + assertEquals("openai", choice.error.metadata["provider"]) + } + + @Test + fun `test OpenRouter streaming response deserialization`() { + val jsonInput = buildJsonObject { + put("id", "gen-stream-test") + put("created", 1699000000L) + put("model", "anthropic/claude-3-sonnet") + put("object", "chat.completion.chunk") + put( + "choices", + buildJsonArray { + addJsonObject { + put("finish_reason", null) + put("native_finish_reason", null) + put( + "delta", + buildJsonObject { + put("role", "assistant") + put("content", "Hello") + } + ) + } + } + ) + } + + val response = responseJson.decodeFromJsonElement(OpenRouterChatCompletionStreamResponse.serializer(), jsonInput) + + assertEquals("gen-stream-test", response.id) + assertEquals("chat.completion.chunk", response.objectType) + assertEquals(1, response.choices.size) + + val choice = response.choices.first() + assertNull(choice.finishReason) + assertNull(choice.nativeFinishReason) + assertEquals("Hello", choice.delta.content) + assertEquals("assistant", choice.delta.role) + } + + @Test + fun `test OpenRouter response with tool calls deserialization`() { + val jsonInput = buildJsonObject { + put("id", "gen-tool-call-test") + put("created", 1699000000L) + put("model", "openai/gpt-4") + put("object", "chat.completion") + put("system_fingerprint", "fp_44709d6fcb") + put( + "choices", + buildJsonArray { + addJsonObject { + put("finish_reason", "tool_calls") + put( + "message", + buildJsonObject { + put("role", "assistant") + put("content", null) + put( + "tool_calls", + buildJsonArray { + 
addJsonObject { + put("id", "call_abc123") + put("type", "function") + put( + "function", + buildJsonObject { + put("name", "get_current_weather") + put("arguments", "{\"location\": \"Boston, MA\"}") + } + ) + } + addJsonObject { + put("id", "call_def456") + put("type", "function") + put( + "function", + buildJsonObject { + put("name", "get_forecast") + put("arguments", "{\"location\": \"Boston, MA\", \"days\": 3}") + } + ) + } + } + ) + } + ) + } + } + ) + put( + "usage", + buildJsonObject { + put("prompt_tokens", 82) + put("completion_tokens", 18) + put("total_tokens", 100) + } + ) + } + + val response = responseJson.decodeFromJsonElement(OpenRouterChatCompletionResponse.serializer(), jsonInput) + + assertEquals("gen-tool-call-test", response.id) + assertEquals(1699000000L, response.created) + assertEquals("openai/gpt-4", response.model) + assertEquals("chat.completion", response.objectType) + assertEquals("fp_44709d6fcb", response.systemFingerprint) + + assertEquals(1, response.choices.size) + val choice = response.choices.first() + assertEquals("tool_calls", choice.finishReason) + + val message = choice.message as OpenAIMessage.Assistant + assertNull(message.content) + assertNotNull(message.toolCalls) + assertEquals(2, message.toolCalls!!.size) + + val firstToolCall = message.toolCalls!![0] + assertEquals("call_abc123", firstToolCall.id) + assertEquals("function", firstToolCall.type) + assertEquals("get_current_weather", firstToolCall.function.name) + assertEquals("{\"location\": \"Boston, MA\"}", firstToolCall.function.arguments) + + val secondToolCall = message.toolCalls!![1] + assertEquals("call_def456", secondToolCall.id) + assertEquals("function", secondToolCall.type) + assertEquals("get_forecast", secondToolCall.function.name) + assertEquals("{\"location\": \"Boston, MA\", \"days\": 3}", secondToolCall.function.arguments) + + assertNotNull(response.usage) + assertEquals(82, response.usage.promptTokens) + assertEquals(18, response.usage.completionTokens) + 
assertEquals(100, response.usage.totalTokens) + } + + @Test + fun `test OpenRouter streaming response with tool calls deserialization`() { + val jsonInput = buildJsonObject { + put("id", "gen-stream-tool-test") + put("created", 1699000000L) + put("model", "openai/gpt-4") + put("object", "chat.completion.chunk") + put( + "choices", + buildJsonArray { + addJsonObject { + put("finish_reason", null) + put("native_finish_reason", null) + put( + "delta", + buildJsonObject { + put("role", "assistant") + put("content", null) + put( + "tool_calls", + buildJsonArray { + addJsonObject { + put("id", "call_xyz789") + put("type", "function") + put( + "function", + buildJsonObject { + put("name", "calculate_total") + put("arguments", "{\"items\": [") + } + ) + } + } + ) + } + ) + } + } + ) + } + + val response = responseJson.decodeFromJsonElement(OpenRouterChatCompletionStreamResponse.serializer(), jsonInput) + + assertEquals("gen-stream-tool-test", response.id) + assertEquals("chat.completion.chunk", response.objectType) + assertEquals(1, response.choices.size) + + val choice = response.choices.first() + assertNull(choice.finishReason) + assertNull(choice.nativeFinishReason) + assertNull(choice.delta.content) + assertEquals("assistant", choice.delta.role) + + assertNotNull(choice.delta.toolCalls) + assertEquals(1, choice.delta.toolCalls.size) + + val toolCall = choice.delta.toolCalls[0] + assertEquals("call_xyz789", toolCall.id) + assertEquals("function", toolCall.type) + assertEquals("calculate_total", toolCall.function.name) + assertEquals("{\"items\": [", toolCall.function.arguments) + } + + @Test + fun `test OpenRouter request with tools serialization`() { + val tools = listOf( + OpenAITool( + function = OpenAIToolFunction( + name = "get_current_weather", + description = "Get the current weather in a given location", + parameters = buildJsonObject { + put("type", "object") + put( + "properties", + buildJsonObject { + put( + "location", + buildJsonObject { + put("type", 
"string") + put("description", "The city and state, e.g. San Francisco, CA") + } + ) + put( + "unit", + buildJsonObject { + put("type", "string") + put( + "enum", + buildJsonArray { + add("celsius") + add("fahrenheit") + } + ) + } + ) + } + ) + put("required", buildJsonArray { add("location") }) + } + ) + ) + ) + + val request = OpenRouterChatCompletionRequest( + model = "openai/gpt-4", + messages = listOf( + OpenAIMessage.User(content = Content.Text("What's the weather like in Boston?")), + OpenAIMessage.Assistant( + content = null, + toolCalls = listOf( + OpenAIToolCall( + id = "call_abc123", + function = OpenAIFunction( + name = "get_current_weather", + arguments = "{\"location\": \"Boston, MA\"}" + ) + ) + ) + ), + OpenAIMessage.Tool( + content = Content.Text("The weather in Boston is 72°F and sunny"), + toolCallId = "call_abc123" + ) + ), + tools = tools, + toolChoice = OpenAIToolChoice.Auto + ) + + val jsonElement = responseJson.encodeToJsonElement(OpenRouterChatCompletionRequestSerializer, request) + val jsonObject = jsonElement.jsonObject + + assertEquals("openai/gpt-4", jsonObject["model"]?.jsonPrimitive?.contentOrNull) + assertNotNull(jsonObject["messages"]) + assertNotNull(jsonObject["tools"]) + assertEquals("auto", jsonObject["tool_choice"]?.jsonPrimitive?.contentOrNull) + + // Verify the serialized messages include tool calls + val messages = jsonObject["messages"]!!.jsonArray + assertEquals(3, messages.size) + + // Check assistant message with tool calls + val assistantMessage = messages[1].jsonObject + assertEquals("assistant", assistantMessage["role"]?.jsonPrimitive?.contentOrNull) + assertNotNull(assistantMessage["tool_calls"]) + val toolCalls = assistantMessage["tool_calls"]!!.jsonArray + assertEquals(1, toolCalls.size) + + val toolCall = toolCalls[0].jsonObject + assertEquals("call_abc123", toolCall["id"]?.jsonPrimitive?.contentOrNull) + assertEquals("function", toolCall["type"]?.jsonPrimitive?.contentOrNull) + + val function = 
toolCall["function"]!!.jsonObject + assertEquals("get_current_weather", function["name"]?.jsonPrimitive?.contentOrNull) + assertEquals("{\"location\": \"Boston, MA\"}", function["arguments"]?.jsonPrimitive?.contentOrNull) + + // Check tool message + val toolMessage = messages[2].jsonObject + assertEquals("tool", toolMessage["role"]?.jsonPrimitive?.contentOrNull) + assertEquals("The weather in Boston is 72°F and sunny", toolMessage["content"]?.jsonPrimitive?.contentOrNull) + assertEquals("call_abc123", toolMessage["tool_call_id"]?.jsonPrimitive?.contentOrNull) + } } diff --git a/prompt/prompt-executor/prompt-executor-clients/src/commonMain/kotlin/ai/koog/prompt/executor/clients/retry/RetryConfig.kt b/prompt/prompt-executor/prompt-executor-clients/src/commonMain/kotlin/ai/koog/prompt/executor/clients/retry/RetryConfig.kt index 1c57b9cffe..b865d5ea37 100644 --- a/prompt/prompt-executor/prompt-executor-clients/src/commonMain/kotlin/ai/koog/prompt/executor/clients/retry/RetryConfig.kt +++ b/prompt/prompt-executor/prompt-executor-clients/src/commonMain/kotlin/ai/koog/prompt/executor/clients/retry/RetryConfig.kt @@ -28,7 +28,9 @@ public data class RetryConfig( require(maxAttempts >= 1) { "maxAttempts must be at least 1" } require(backoffMultiplier >= 1.0) { "backoffMultiplier must be at least 1.0" } require(jitterFactor in 0.0..1.0) { "jitterFactor must be between 0.0 and 1.0" } - require(initialDelay <= maxDelay) { "initialDelay ($initialDelay) must not be greater than maxDelay ($maxDelay)" } + require(initialDelay <= maxDelay) { + "initialDelay ($initialDelay) must not be greater than maxDelay ($maxDelay)" + } } /** @@ -97,6 +99,13 @@ public data class RetryConfig( * No retry - effectively disables retry logic. */ public val DISABLED: RetryConfig = RetryConfig(maxAttempts = 1) + + /** + * The default retry configuration used by clients implementing retry logic. + * + * Suitable for general-purpose use cases where standard retry behavior is required. 
+ */ + public val DEFAULT: RetryConfig = RetryConfig() } } diff --git a/prompt/prompt-executor/prompt-executor-clients/src/commonMain/kotlin/ai/koog/prompt/executor/clients/retry/RetryingLLMClient.kt b/prompt/prompt-executor/prompt-executor-clients/src/commonMain/kotlin/ai/koog/prompt/executor/clients/retry/RetryingLLMClient.kt index 621263e8bc..1056e4d14a 100644 --- a/prompt/prompt-executor/prompt-executor-clients/src/commonMain/kotlin/ai/koog/prompt/executor/clients/retry/RetryingLLMClient.kt +++ b/prompt/prompt-executor/prompt-executor-clients/src/commonMain/kotlin/ai/koog/prompt/executor/clients/retry/RetryingLLMClient.kt @@ -34,7 +34,7 @@ import kotlin.time.Duration.Companion.milliseconds */ public class RetryingLLMClient( private val delegate: LLMClient, - private val config: RetryConfig = RetryConfig() + internal val config: RetryConfig = RetryConfig() ) : LLMClient { private companion object { @@ -163,3 +163,17 @@ public class RetryingLLMClient( return finalMs.milliseconds } } + +/** + * Converts an instance of [LLMClient] into a retrying client with customizable retry behavior. + * + * @param retryConfig Configuration for retry behavior. Defaults to [RetryConfig.DEFAULT]. + * @return A new instance of [RetryingLLMClient] that adds retry logic to the provided client. 
+ */ +public fun LLMClient.toRetryingClient( + retryConfig: RetryConfig = RetryConfig.DEFAULT +): RetryingLLMClient = + RetryingLLMClient( + delegate = this, + config = retryConfig + ) diff --git a/prompt/prompt-executor/prompt-executor-clients/src/commonTest/kotlin/ai/koog/prompt/executor/clients/retry/RetryingLLMClientTest.kt b/prompt/prompt-executor/prompt-executor-clients/src/commonTest/kotlin/ai/koog/prompt/executor/clients/retry/RetryingLLMClientTest.kt index 1c30949b1f..6f8c501af3 100644 --- a/prompt/prompt-executor/prompt-executor-clients/src/commonTest/kotlin/ai/koog/prompt/executor/clients/retry/RetryingLLMClientTest.kt +++ b/prompt/prompt-executor/prompt-executor-clients/src/commonTest/kotlin/ai/koog/prompt/executor/clients/retry/RetryingLLMClientTest.kt @@ -26,6 +26,7 @@ import kotlinx.datetime.Clock import kotlin.test.Test import kotlin.test.assertEquals import kotlin.test.assertFailsWith +import kotlin.test.assertSame import kotlin.time.Duration.Companion.milliseconds class RetryingLLMClientTest { @@ -65,6 +66,28 @@ class RetryingLLMClientTest { assertEquals(1, mockClient.executeCalls) } + @Test + fun testConvertLLMClientToRetryingClientWithDefaultConfig() = runTest { + val mockClient = MockLLMClient() + // when + val retryingClient = mockClient.toRetryingClient() + + // then + assertSame(actual = retryingClient.config, expected = RetryConfig.DEFAULT) + } + + @Test + fun testConvertLLMClientToRetryingClientWithCustomConfig() = runTest { + // given + val mockClient = MockLLMClient() + val retryConfig = RetryConfig(maxAttempts = 100500) + // when + val retryingClient = mockClient.toRetryingClient(retryConfig) + + // then + assertSame(actual = retryingClient.config, expected = retryConfig) + } + @Test fun testRetryOnRateLimitError() = runTest { val mockClient = MockLLMClient( diff --git a/prompt/prompt-model/src/commonMain/kotlin/ai/koog/prompt/message/Message.kt b/prompt/prompt-model/src/commonMain/kotlin/ai/koog/prompt/message/Message.kt index 
ce944055d9..18f133c9f5 100644 --- a/prompt/prompt-model/src/commonMain/kotlin/ai/koog/prompt/message/Message.kt +++ b/prompt/prompt-model/src/commonMain/kotlin/ai/koog/prompt/message/Message.kt @@ -230,6 +230,12 @@ public sealed interface MessageMetaInfo { * to the current system time if not explicitly set. */ public val timestamp: Instant + + /** + * Free-form information associated with a message. + * Can be used to store custom metadata that doesn't fit into the standard fields. + */ + public val metadata: JsonObject? } /** @@ -243,7 +249,8 @@ public sealed interface MessageMetaInfo { */ @Serializable public data class RequestMetaInfo( - override val timestamp: Instant + override val timestamp: Instant, + override val metadata: JsonObject? = null ) : MessageMetaInfo { /** * Companion object for `RequestMetaInfo` that provides factory methods and utilities related to creating instances. @@ -291,7 +298,12 @@ public data class ResponseMetaInfo( public val totalTokensCount: Int? = null, public val inputTokensCount: Int? = null, public val outputTokensCount: Int? = null, + @Deprecated( + "additionalInfo is deprecated, use metadata instead", + ReplaceWith("metadata") + ) public val additionalInfo: Map = emptyMap(), + override val metadata: JsonObject? = null, ) : MessageMetaInfo { /** * Companion object for the ResponseMetaInfo class. @@ -302,7 +314,11 @@ public data class ResponseMetaInfo( * Creates a ResponseMetadata instance with a timestamp from the provided clock. * * @param clock The clock to use for generating the timestamp. - * @param tokensCount The number of tokens used in the response, or null if not available. + * @param totalTokensCount The total number of tokens involved in the response, including both input and output tokens. + * @param inputTokensCount The number of tokens used in the input. + * @param outputTokensCount The number of tokens generated in the output. + * @param additionalInfo Deprecated: use [metadata] instead. 
Additional metadata as a map of string keys to string values. + * @param metadata Additional metadata as a JSON object. * @return A new ResponseMetadata instance with the timestamp from the provided clock. */ public fun create( @@ -310,9 +326,10 @@ public data class ResponseMetaInfo( totalTokensCount: Int? = null, inputTokensCount: Int? = null, outputTokensCount: Int? = null, - additionalInfo: Map = emptyMap() + additionalInfo: Map = emptyMap(), + metadata: JsonObject? = null, ): ResponseMetaInfo = - ResponseMetaInfo(clock.now(), totalTokensCount, inputTokensCount, outputTokensCount, additionalInfo) + ResponseMetaInfo(clock.now(), totalTokensCount, inputTokensCount, outputTokensCount, additionalInfo, metadata) /** * An empty instance of the [ResponseMetaInfo] with the timestamp set to a distant past. diff --git a/prompt/prompt-structure/src/commonMain/kotlin/ai/koog/prompt/structure/PromptExecutorExtensions.kt b/prompt/prompt-structure/src/commonMain/kotlin/ai/koog/prompt/structure/PromptExecutorExtensions.kt index 55a4112e7a..b8296e34f7 100644 --- a/prompt/prompt-structure/src/commonMain/kotlin/ai/koog/prompt/structure/PromptExecutorExtensions.kt +++ b/prompt/prompt-structure/src/commonMain/kotlin/ai/koog/prompt/structure/PromptExecutorExtensions.kt @@ -48,7 +48,65 @@ public data class StructuredOutputConfig( public val default: StructuredOutput? = null, public val byProvider: Map> = emptyMap(), public val fixingParser: StructureFixingParser? = null -) +) { + /** + * Updates a given prompt to configure structured output using the specified large language model (LLM). + * Depending on the model's support for structured outputs, the prompt is updated either manually or natively. + * + * @param model The large language model (LLModel) used to determine the structured output configuration. + * @param prompt The original prompt to be updated with the structured output configuration. + * @return A new prompt reflecting the updated structured output configuration. 
+ */ + public fun updatePrompt(model: LLModel, prompt: Prompt): Prompt { + return when (val mode = structuredOutput(model)) { + // Don't set schema parameter in prompt and coerce the model manually with user message to provide a structured response. + is StructuredOutput.Manual -> { + prompt(prompt) { + user { + markdown { + StructuredOutputPrompts.outputInstructionPrompt(this, mode.structure) + } + } + } + } + + // Rely on built-in model capabilities to provide structured response. + is StructuredOutput.Native -> { + prompt.withUpdatedParams { schema = mode.structure.schema } + } + } + } + + /** + * Retrieves the structured data configuration for a specific large language model (LLM). + * + * The method determines the appropriate structured data setup based on the given LLM + * instance, ensuring compatibility with the model's provider and capabilities. + * + * @param model The large language model (LLM) instance used to identify the structured data configuration. + * @return The structured data configuration represented as a `StructuredData` instance. + */ + public fun structure(model: LLModel): StructuredData { + return structuredOutput(model).structure + } + + /** + * Retrieves the structured output configuration for a specific large language model (LLM). + * + * The method determines the appropriate structured output instance based on the model's provider. + * If no specific configuration is found for the provider, it falls back to the default configuration. + * Throws an exception if no default configuration is available. + * + * @param model The large language model (LLM) used to identify the structured output configuration. + * @return An instance of `StructuredOutput` that represents the structured output configuration for the model. + * @throws IllegalArgumentException if no configuration is found for the provider and no default configuration is set. 
+ */ + private fun structuredOutput(model: LLModel): StructuredOutput { + return byProvider[model.provider] + ?: default + ?: throw IllegalArgumentException("No structure found for provider ${model.provider}") + } +} /** * Defines how structured outputs should be generated. @@ -106,42 +164,13 @@ public suspend fun PromptExecutor.executeStructured( model: LLModel, config: StructuredOutputConfig, ): Result> { - val mode = config.byProvider[model.provider] - ?: config.default - ?: throw IllegalArgumentException("No structure found for provider ${model.provider}") - - val (structure: StructuredData, updatedPrompt: Prompt) = when (mode) { - // Don't set schema parameter in prompt and coerce the model manually with user message to provide a structured response. - is StructuredOutput.Manual -> { - mode.structure to prompt(prompt) { - user { - markdown { - StructuredOutputPrompts.outputInstructionPrompt(this, mode.structure) - } - } - } - } - - // Rely on built-in model capabilities to provide structured response. 
- is StructuredOutput.Native -> { - mode.structure to prompt.withUpdatedParams { schema = mode.structure.schema } - } - } - + val updatedPrompt = config.updatePrompt(model, prompt) val response = this.execute(prompt = updatedPrompt, model = model).single() return runCatching { require(response is Message.Assistant) { "Response for structured output must be an assistant message, got ${response::class.simpleName} instead" } - // Use fixingParser if provided, otherwise parse directly - val structureResponse = config.fixingParser - ?.parse(this, structure, response.content) - ?: structure.parse(response.content) - - StructuredResponse( - structure = structureResponse, - message = response - ) + parseResponseToStructuredResponse(response, config, model) } } @@ -269,3 +298,31 @@ public suspend inline fun PromptExecutor.executeStructured( fixingParser = fixingParser, ) } + +/** + * Parses a structured response from the assistant message using the provided structured output configuration + * and language model. If a fixing parser is specified in the configuration, it will be used; otherwise, + * the structure will be parsed directly. + * + * @param T The type of the structured output. + * @param response The assistant's response message to be parsed. + * @param config The structured output configuration defining how the response should be parsed. + * @param model The language model to be used for parsing the structured output. + * @return A `StructuredResponse` containing the parsed structure and the original assistant message. 
+ */ +public suspend fun PromptExecutor.parseResponseToStructuredResponse( + response: Message.Assistant, + config: StructuredOutputConfig, + model: LLModel +): StructuredResponse { + // Use fixingParser if provided, otherwise parse directly + val structure = config.structure(model) + val structureResponse = config.fixingParser + ?.parse(this, structure, response.content) + ?: structure.parse(response.content) + + return StructuredResponse( + structure = structureResponse, + message = response + ) +} diff --git a/prompt/prompt-tokenizer/build.gradle.kts b/prompt/prompt-tokenizer/build.gradle.kts index c55e56f0cf..b033af0990 100644 --- a/prompt/prompt-tokenizer/build.gradle.kts +++ b/prompt/prompt-tokenizer/build.gradle.kts @@ -13,6 +13,7 @@ kotlin { commonMain { dependencies { api(project(":prompt:prompt-llm")) + api(project(":prompt:prompt-model")) api(libs.kotlinx.serialization.json) api(libs.kotlinx.datetime) } diff --git a/prompt/prompt-tokenizer/src/commonMain/kotlin/ai/koog/prompt/tokenizer/PromptTokenizer.kt b/prompt/prompt-tokenizer/src/commonMain/kotlin/ai/koog/prompt/tokenizer/PromptTokenizer.kt new file mode 100644 index 0000000000..0002dc7f80 --- /dev/null +++ b/prompt/prompt-tokenizer/src/commonMain/kotlin/ai/koog/prompt/tokenizer/PromptTokenizer.kt @@ -0,0 +1,110 @@ +package ai.koog.prompt.tokenizer + +import ai.koog.prompt.dsl.Prompt +import ai.koog.prompt.message.Message +import kotlin.collections.sumOf + +/** + * An interface that provides utilities for tokenizing and calculating token usage in messages and prompts. + */ +public interface PromptTokenizer { + /** + * Calculates the number of tokens required for a given message. + * + * @param message The message for which the token count should be determined. + * @return The number of tokens required to encode the message. + */ + public fun tokenCountFor(message: Message): Int + + /** + * Calculates the total number of tokens spent in a given prompt. 
+ * + * @param prompt The prompt for which the total tokens spent need to be calculated. + * @return The total number of tokens spent as an integer. + */ + public fun tokenCountFor(prompt: Prompt): Int +} + +/** + * An implementation of the [PromptTokenizer] interface that delegates token counting + * to an instance of the [Tokenizer] interface. The class provides methods to estimate + * the token count for individual messages and for the entirety of a prompt. + * + * This is useful in contexts where token-based costs or limitations are significant, + * such as when interacting with large language models (LLMs). + * + * @property tokenizer The [Tokenizer] instance used for token counting. + */ +public class OnDemandTokenizer(private val tokenizer: Tokenizer) : PromptTokenizer { + + /** + * Computes the number of tokens in a given message. + * + * @param message The message for which the token count needs to be calculated. + * The content of the message is analyzed to estimate the token count. + * @return The estimated number of tokens in the message content. + */ + public override fun tokenCountFor(message: Message): Int = tokenizer.countTokens(message.content) + + /** + * Calculates the total number of tokens spent for the given prompt based on its messages. + * + * @param prompt The `Prompt` instance containing the list of messages for which the total token count will be calculated. + * @return The total number of tokens across all messages in the prompt. + */ + public override fun tokenCountFor(prompt: Prompt): Int = prompt.messages.sumOf(::tokenCountFor) +} + +/** + * A caching implementation of the `PromptTokenizer` interface that optimizes token counting + * by storing previously computed token counts for messages. This reduces redundant computations + * when the same message is processed multiple times. + * + * @constructor Creates an instance of `CachingTokenizer` with a provided `Tokenizer` instance + * that performs the actual token counting. 
+ * @property tokenizer The underlying `Tokenizer` used for counting tokens in the message content. + */ +public class CachingTokenizer(private val tokenizer: Tokenizer) : PromptTokenizer { + /** + * A cache that maps a `Message` to its corresponding token count. + * + * This is used to store the results of token computations for reuse, optimizing performance + * by avoiding repeated invocations of the token counting process on the same message content. + * + * Token counts are computed lazily and stored in the cache when requested via the `tokensFor` + * method. This cache can be cleared using the `clearCache` method. + */ + internal val cache = mutableMapOf() + + /** + * Retrieves the number of tokens contained in the content of the given message. + * This method utilizes caching to improve performance, storing previously + * computed token counts and reusing them for identical messages. + * + * @param message The message whose content's token count is to be retrieved + * @return The number of tokens in the content of the message + */ + public override fun tokenCountFor(message: Message): Int = cache.getOrPut(message) { + tokenizer.countTokens(message.content) + } + + /** + * Calculates the total number of tokens spent on the given prompt by summing the token usage + * of all messages associated with the prompt. + * + * @param prompt The prompt containing the list of messages whose token usage will be calculated. + * @return The total number of tokens spent across all messages in the provided prompt. + */ + public override fun tokenCountFor(prompt: Prompt): Int = prompt.messages.sumOf(::tokenCountFor) + + /** + * Clears all cached token counts from the internal cache. + * + * This method is useful when the state of the cached data becomes invalid + * or needs resetting. After calling this, any subsequent token count + * calculations will be recomputed rather than retrieved from the cache. 
+ */ + public fun clearCache() { + cache.clear() + } +} diff --git a/prompt/prompt-tokenizer/src/commonTest/kotlin/ai/koog/prompt/tokenizer/PromptTokenizerTest.kt b/prompt/prompt-tokenizer/src/commonTest/kotlin/ai/koog/prompt/tokenizer/PromptTokenizerTest.kt new file mode 100644 index 0000000000..0f68ab6359 --- /dev/null +++ b/prompt/prompt-tokenizer/src/commonTest/kotlin/ai/koog/prompt/tokenizer/PromptTokenizerTest.kt @@ -0,0 +1,133 @@ +package ai.koog.prompt.tokenizer + +import ai.koog.prompt.dsl.prompt +import ai.koog.prompt.message.Message +import ai.koog.prompt.message.RequestMetaInfo +import ai.koog.prompt.message.ResponseMetaInfo +import kotlinx.datetime.Clock +import kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.assertTrue + +/** + * Test for the PromptTokenizer implementations. + */ +class PromptTokenizerTest { + + /** + * A mock tokenizer that tracks the total tokens counted. + * + * This implementation counts tokens by simply counting characters and dividing by 4, + * which is a very rough approximation but sufficient for testing purposes. + * It also keeps track of the total tokens counted across all calls. + */ + class MockTokenizer : Tokenizer { + private var _totalTokens = 0 + + /** + * The total number of tokens counted across all calls to countTokens. + */ + val totalTokens: Int + get() = _totalTokens + + /** + * Counts tokens by simply counting characters and dividing by 4. + * Also adds to the running total of tokens counted. + * + * @param text The text to tokenize + * @return The estimated number of tokens in the text + */ + override fun countTokens(text: String): Int { + // Simple approximation: 1 token ≈ 4 characters + println("countTokens: $text") + val tokens = (text.length / 4) + 1 + _totalTokens += tokens + return tokens + } + + /** + * Resets the total tokens counter to 0. 
+ */ + fun reset() { + _totalTokens = 0 + } + } + + @Test + fun testPromptTokenizer() { + // Create a mock tokenizer to track token usage + val mockTokenizer = MockTokenizer() + + // Create a prompt tokenizer with our mock tokenizer + val promptTokenizer = OnDemandTokenizer(mockTokenizer) + + // Create a prompt with some messages + val testPrompt = prompt("test-prompt") { + system("You are a helpful assistant.") + user("What is the capital of France?") + assistant("Paris is the capital of France.") + } + + // Count tokens in the prompt + val totalTokens = promptTokenizer.tokenCountFor(testPrompt) + + // Verify that tokens were counted + assertTrue(totalTokens > 0, "Total tokens should be greater than 0") + + // Verify that the tokenizer was used and counted tokens + assertTrue(mockTokenizer.totalTokens > 0, "Tokenizer should have counted tokens") + + // Verify that the total tokens match what we expect + assertEquals(totalTokens, mockTokenizer.totalTokens, "Total tokens should match the tokenizer's count") + + // Print the total tokens spent + println("[DEBUG_LOG] Total tokens spent: ${mockTokenizer.totalTokens}") + + val requestMetainfo = RequestMetaInfo.create(Clock.System) + val responseMetainfo = ResponseMetaInfo.create(Clock.System) + // Count tokens for individual messages + val systemTokens = promptTokenizer.tokenCountFor( + Message.System("You are a helpful assistant.", requestMetainfo) + ) + val userTokens = promptTokenizer.tokenCountFor(Message.User("What is the capital of France?", requestMetainfo)) + val assistantTokens = promptTokenizer.tokenCountFor( + Message.Assistant("Paris is the capital of France.", responseMetainfo) + ) + + // Print token counts for each message + println("[DEBUG_LOG] System message tokens: $systemTokens") + println("[DEBUG_LOG] User message tokens: $userTokens") + println("[DEBUG_LOG] Assistant message tokens: $assistantTokens") + + // Verify that the sum of individual message tokens equals the total + val sumOfMessageTokens = 
systemTokens + userTokens + assistantTokens + assertEquals(sumOfMessageTokens, totalTokens, "Sum of message tokens should equal total tokens") + } + + @Test + fun testCachingPromptTokenizer() { + // Create a mock tokenizer to track token usage + val mockTokenizer = MockTokenizer() + + // Create a prompt tokenizer with our mock tokenizer + val promptTokenizer = CachingTokenizer(mockTokenizer) + + // Create a prompt with some messages + val testPrompt = prompt("test-prompt") { + system("You are a helpful assistant.") + user("What is the capital of France?") + assistant("Paris is the capital of France.") + } + + assertEquals(0, promptTokenizer.cache.size) + promptTokenizer.tokenCountFor(testPrompt) + assertEquals(3, promptTokenizer.cache.size) + promptTokenizer.clearCache() + assertEquals(0, promptTokenizer.cache.size) + promptTokenizer.tokenCountFor(testPrompt.messages[1]) + promptTokenizer.tokenCountFor(testPrompt.messages[2]) + assertEquals(2, promptTokenizer.cache.size) + promptTokenizer.tokenCountFor(testPrompt) + assertEquals(3, promptTokenizer.cache.size) + } +} diff --git a/qodana.sarif.json b/qodana.sarif.json index a3ace14508..a3bb2eb138 100644 --- a/qodana.sarif.json +++ b/qodana.sarif.json @@ -76874,8 +76874,8 @@ "kind": "fail", "level": "warning", "message": { - "text": "Method 'getOnToolCallFailure' coverage is below the threshold 50%", - "markdown": "Method `getOnToolCallFailure` coverage is below the threshold 50%" + "text": "Method 'getOnToolCallFailed' coverage is below the threshold 50%", + "markdown": "Method `getOnToolCallFailed` coverage is below the threshold 50%" }, "locations": [ { @@ -76890,7 +76890,7 @@ "charOffset": 8585, "charLength": 17, "snippet": { - "text": "onToolCallFailure" + "text": "OnToolCallFailed" }, "sourceLanguage": "kotlin" }, @@ -76900,7 +76900,7 @@ "charOffset": 8570, "charLength": 180, "snippet": { - "text": " public var onToolCallFailure: suspend (tool: Tool<*, *>, toolArgs: Tool.Args, throwable: Throwable) -> Unit = { 
tool: Tool<*, *>, toolArgs: Tool.Args, throwable: Throwable -> }" + "text": " public var onToolCallFailed: suspend (tool: Tool<*, *>, toolArgs: Tool.Args, throwable: Throwable) -> Unit = { tool: Tool<*, *>, toolArgs: Tool.Args, throwable: Throwable -> }" }, "sourceLanguage": "kotlin" } @@ -77048,8 +77048,8 @@ "kind": "fail", "level": "warning", "message": { - "text": "Method 'setOnToolValidationError' coverage is below the threshold 50%", - "markdown": "Method `setOnToolValidationError` coverage is below the threshold 50%" + "text": "Method 'setOnToolValidationFailed' coverage is below the threshold 50%", + "markdown": "Method `setOnToolValidationFailed` coverage is below the threshold 50%" }, "locations": [ { @@ -77064,7 +77064,7 @@ "charOffset": 8397, "charLength": 46, "snippet": { - "text": "set(value) = this.onToolValidationError(value)" + "text": "set(value) = this.onToolValidationFailed(value)" }, "sourceLanguage": "kotlin" }, @@ -77074,7 +77074,7 @@ "charOffset": 8389, "charLength": 54, "snippet": { - "text": " set(value) = this.onToolValidationError(value)" + "text": " set(value) = this.onToolValidationFailed(value)" }, "sourceLanguage": "kotlin" } @@ -77280,8 +77280,8 @@ "kind": "fail", "level": "warning", "message": { - "text": "Method 'setOnStrategyFinished' coverage is below the threshold 50%", - "markdown": "Method `setOnStrategyFinished` coverage is below the threshold 50%" + "text": "Method 'setOnStrategyCompleted' coverage is below the threshold 50%", + "markdown": "Method `setOnStrategyCompleted` coverage is below the threshold 50%" }, "locations": [ { @@ -77296,7 +77296,7 @@ "charOffset": 5912, "charLength": 43, "snippet": { - "text": "set(value) = this.onStrategyFinished(value)" + "text": "set(value) = this.onStrategyCompleted(value)" }, "sourceLanguage": "kotlin" }, @@ -77306,7 +77306,7 @@ "charOffset": 5904, "charLength": 51, "snippet": { - "text": " set(value) = this.onStrategyFinished(value)" + "text": " set(value) = 
this.onStrategyCompleted(value)" }, "sourceLanguage": "kotlin" } @@ -79542,8 +79542,8 @@ "kind": "fail", "level": "warning", "message": { - "text": "Method 'onToolCall' coverage is below the threshold 50%", - "markdown": "Method `onToolCall` coverage is below the threshold 50%" + "text": "Method 'onToolCallStrarting' coverage is below the threshold 50%", + "markdown": "Method `onToolCallStarting` coverage is below the threshold 50%" }, "locations": [ { @@ -79558,7 +79558,7 @@ "charOffset": 3518, "charLength": 10, "snippet": { - "text": "onToolCall" + "text": "OnToolCallStarting" }, "sourceLanguage": "kotlin" }, @@ -80644,8 +80644,8 @@ "kind": "fail", "level": "warning", "message": { - "text": "Method 'getOnBeforeNode' coverage is below the threshold 50%", - "markdown": "Method `getOnBeforeNode` coverage is below the threshold 50%" + "text": "Method 'getOnNodeExecutionStarting' coverage is below the threshold 50%", + "markdown": "Method `getOnNodeExecutionStarting` coverage is below the threshold 50%" }, "locations": [ { @@ -80660,7 +80660,7 @@ "charOffset": 6172, "charLength": 12, "snippet": { - "text": "onBeforeNode" + "text": "OnNodeExecutionStarting" }, "sourceLanguage": "kotlin" }, @@ -80670,7 +80670,7 @@ "charOffset": 6157, "charLength": 195, "snippet": { - "text": " public var onBeforeNode: suspend (node: AIAgentNodeBase<*, *>, context: AIAgentContextBase, input: Any?) -> Unit = { node: AIAgentNodeBase<*, *>, context: AIAgentContextBase, input: Any? -> }" + "text": " public var onNodeExecutionStarting: suspend (node: AIAgentNodeBase<*, *>, context: AIAgentContextBase, input: Any?) -> Unit = { node: AIAgentNodeBase<*, *>, context: AIAgentContextBase, input: Any? 
-> }" }, "sourceLanguage": "kotlin" } @@ -85242,7 +85242,7 @@ "charOffset": 4202, "charLength": 10, "snippet": { - "text": "onToolCall" + "text": "OnToolCallStarting" }, "sourceLanguage": "kotlin" }, @@ -85284,8 +85284,8 @@ "kind": "fail", "level": "warning", "message": { - "text": "Method 'getOnBeforeAgentStarted' coverage is below the threshold 50%", - "markdown": "Method `getOnBeforeAgentStarted` coverage is below the threshold 50%" + "text": "Method 'getOnAgentStarting' coverage is below the threshold 50%", + "markdown": "Method `getOnAgentStarting` coverage is below the threshold 50%" }, "locations": [ { @@ -85300,7 +85300,7 @@ "charOffset": 4405, "charLength": 20, "snippet": { - "text": "onBeforeAgentStarted" + "text": "OnAgentStarting" }, "sourceLanguage": "kotlin" }, @@ -85310,7 +85310,7 @@ "charOffset": 4390, "charLength": 147, "snippet": { - "text": " public var onBeforeAgentStarted: suspend (strategy: AIAgentStrategy, agent: AIAgent) -> Unit = { strategy: AIAgentStrategy, agent: AIAgent -> }" + "text": " public var onAgentStarting: suspend (strategy: AIAgentStrategy, agent: AIAgent) -> Unit = { strategy: AIAgentStrategy, agent: AIAgent -> }" }, "sourceLanguage": "kotlin" } @@ -85342,8 +85342,8 @@ "kind": "fail", "level": "warning", "message": { - "text": "Method 'getOnToolValidationError' coverage is below the threshold 50%", - "markdown": "Method `getOnToolValidationError` coverage is below the threshold 50%" + "text": "Method 'getOnToolValidationFailed' coverage is below the threshold 50%", + "markdown": "Method `getOnToolValidationFailed` coverage is below the threshold 50%" }, "locations": [ { @@ -85358,7 +85358,7 @@ "charOffset": 8233, "charLength": 21, "snippet": { - "text": "onToolValidationError" + "text": "onToolValidationFailed" }, "sourceLanguage": "kotlin" }, @@ -85368,7 +85368,7 @@ "charOffset": 8218, "charLength": 170, "snippet": { - "text": " public var onToolValidationError: suspend (tool: Tool<*, *>, toolArgs: Tool.Args, value: String) 
-> Unit = { tool: Tool<*, *>, toolArgs: Tool.Args, value: String -> }" + "text": " public var onToolValidationFailed: suspend (tool: Tool<*, *>, toolArgs: Tool.Args, value: String) -> Unit = { tool: Tool<*, *>, toolArgs: Tool.Args, value: String -> }" }, "sourceLanguage": "kotlin" } @@ -86444,8 +86444,8 @@ "kind": "fail", "level": "warning", "message": { - "text": "Method 'setOnToolCallFailure' coverage is below the threshold 50%", - "markdown": "Method `setOnToolCallFailure` coverage is below the threshold 50%" + "text": "Method 'setOnToolCallFailed' coverage is below the threshold 50%", + "markdown": "Method `setOnToolCallFailed` coverage is below the threshold 50%" }, "locations": [ { @@ -86460,7 +86460,7 @@ "charOffset": 8759, "charLength": 42, "snippet": { - "text": "set(value) = this.onToolCallFailure(value)" + "text": "set(value) = this.onToolCallFailed(value)" }, "sourceLanguage": "kotlin" }, @@ -86470,7 +86470,7 @@ "charOffset": 8751, "charLength": 50, "snippet": { - "text": " set(value) = this.onToolCallFailure(value)" + "text": " set(value) = this.onToolCallFailed(value)" }, "sourceLanguage": "kotlin" } @@ -87314,8 +87314,8 @@ "kind": "fail", "level": "warning", "message": { - "text": "Method 'getOnAfterNode' coverage is below the threshold 50%", - "markdown": "Method `getOnAfterNode` coverage is below the threshold 50%" + "text": "Method 'getOnNodeExecutionCompleted' coverage is below the threshold 50%", + "markdown": "Method `getOnNodeExecutionCompleted` coverage is below the threshold 50%" }, "locations": [ { @@ -87330,7 +87330,7 @@ "charOffset": 6528, "charLength": 11, "snippet": { - "text": "onAfterNode" + "text": "OnNodeExecutionCompleted" }, "sourceLanguage": "kotlin" }, @@ -87340,7 +87340,7 @@ "charOffset": 6513, "charLength": 222, "snippet": { - "text": " public var onAfterNode: suspend (node: AIAgentNodeBase<*, *>, context: AIAgentContextBase, input: Any?, output: Any?) 
-> Unit = { node: AIAgentNodeBase<*, *>, context: AIAgentContextBase, input: Any?, output: Any? -> }" + "text": " public var onNodeExecutionCompleted: suspend (node: AIAgentNodeBase<*, *>, context: AIAgentContextBase, input: Any?, output: Any?) -> Unit = { node: AIAgentNodeBase<*, *>, context: AIAgentContextBase, input: Any?, output: Any? -> }" }, "sourceLanguage": "kotlin" } @@ -87430,8 +87430,8 @@ "kind": "fail", "level": "warning", "message": { - "text": "Method 'getOnStrategyFinished' coverage is below the threshold 50%", - "markdown": "Method `getOnStrategyFinished` coverage is below the threshold 50%" + "text": "Method 'getOnStrategyCompleted' coverage is below the threshold 50%", + "markdown": "Method `getOnStrategyCompleted` coverage is below the threshold 50%" }, "locations": [ { @@ -87446,7 +87446,7 @@ "charOffset": 5773, "charLength": 18, "snippet": { - "text": "onStrategyFinished" + "text": "OnStrategyCompleted" }, "sourceLanguage": "kotlin" }, @@ -87456,7 +87456,7 @@ "charOffset": 5758, "charLength": 145, "snippet": { - "text": " public var onStrategyFinished: suspend (strategy: AIAgentStrategy, result: String) -> Unit = { strategy: AIAgentStrategy, result: String -> }" + "text": " public var onStrategyCompleted: suspend (strategy: AIAgentStrategy, result: String) -> Unit = { strategy: AIAgentStrategy, result: String -> }" }, "sourceLanguage": "kotlin" } @@ -90794,8 +90794,8 @@ "kind": "fail", "level": "warning", "message": { - "text": "Method 'setOnToolCall' coverage is below the threshold 50%", - "markdown": "Method `setOnToolCall` coverage is below the threshold 50%" + "text": "Method 'setOnToolCallStarting' coverage is below the threshold 50%", + "markdown": "Method `setOnToolCallStarting` coverage is below the threshold 50%" }, "locations": [ { @@ -90810,7 +90810,7 @@ "charOffset": 8048, "charLength": 35, "snippet": { - "text": "set(value) = this.onToolCall(value)" + "text": "set(value) = this.onToolCallStarting(value)" }, "sourceLanguage": 
"kotlin" }, @@ -90820,7 +90820,7 @@ "charOffset": 8040, "charLength": 43, "snippet": { - "text": " set(value) = this.onToolCall(value)" + "text": " set(value) = this.onToolCallStarting(value)" }, "sourceLanguage": "kotlin" } @@ -91664,8 +91664,8 @@ "kind": "fail", "level": "warning", "message": { - "text": "Method 'setOnBeforeNode' coverage is below the threshold 50%", - "markdown": "Method `setOnBeforeNode` coverage is below the threshold 50%" + "text": "Method 'setOnNodeExecutionStarting' coverage is below the threshold 50%", + "markdown": "Method `setOnNodeExecutionStarting` coverage is below the threshold 50%" }, "locations": [ { @@ -91680,7 +91680,7 @@ "charOffset": 6361, "charLength": 37, "snippet": { - "text": "set(value) = this.onBeforeNode(value)" + "text": "set(value) = this.onNodeExecutionStarting(value)" }, "sourceLanguage": "kotlin" }, @@ -91690,7 +91690,7 @@ "charOffset": 6353, "charLength": 45, "snippet": { - "text": " set(value) = this.onBeforeNode(value)" + "text": " set(value) = this.onNodeExecutionStarting(value)" }, "sourceLanguage": "kotlin" } @@ -93752,8 +93752,8 @@ "kind": "fail", "level": "warning", "message": { - "text": "Method 'setOnAfterNode' coverage is below the threshold 50%", - "markdown": "Method `setOnAfterNode` coverage is below the threshold 50%" + "text": "Method 'setOnNodeExecutionCompleted' coverage is below the threshold 50%", + "markdown": "Method `setOnNodeExecutionCompleted` coverage is below the threshold 50%" }, "locations": [ { @@ -93768,7 +93768,7 @@ "charOffset": 6744, "charLength": 36, "snippet": { - "text": "set(value) = this.onAfterNode(value)" + "text": "set(value) = this.onNodeExecutionCompleted(value)" }, "sourceLanguage": "kotlin" }, @@ -93778,7 +93778,7 @@ "charOffset": 6736, "charLength": 44, "snippet": { - "text": " set(value) = this.onAfterNode(value)" + "text": " set(value) = this.onNodeExecutionCompleted(value)" }, "sourceLanguage": "kotlin" } @@ -99610,8 +99610,8 @@ "kind": "fail", "level": "warning", 
"message": { - "text": "Method 'getOnBeforeLLMCall' coverage is below the threshold 50%", - "markdown": "Method `getOnBeforeLLMCall` coverage is below the threshold 50%" + "text": "Method 'getOnLLMCallStarting' coverage is below the threshold 50%", + "markdown": "Method `getOnLLMCallStarting` coverage is below the threshold 50%" }, "locations": [ { @@ -99626,7 +99626,7 @@ "charOffset": 7003, "charLength": 15, "snippet": { - "text": "onBeforeLLMCall" + "text": "onLLMCallStarting" }, "sourceLanguage": "kotlin" }, @@ -99636,7 +99636,7 @@ "charOffset": 6988, "charLength": 216, "snippet": { - "text": " public var onBeforeLLMCall: suspend (prompt: Prompt, tools: List, model: LLModel, sessionId: String) -> Unit = { prompt: Prompt, tools: List, model: LLModel, sessionId: String -> }" + "text": " public var onLLMCallStarting: suspend (prompt: Prompt, tools: List, model: LLModel, sessionId: String) -> Unit = { prompt: Prompt, tools: List, model: LLModel, sessionId: String -> }" }, "sourceLanguage": "kotlin" } @@ -100016,8 +100016,8 @@ "kind": "fail", "level": "warning", "message": { - "text": "Method 'setOnStrategyStarted' coverage is below the threshold 50%", - "markdown": "Method `setOnStrategyStarted` coverage is below the threshold 50%" + "text": "Method 'setOnStrategyStarting' coverage is below the threshold 50%", + "markdown": "Method `setOnStrategyStarting` coverage is below the threshold 50%" }, "locations": [ { @@ -100032,7 +100032,7 @@ "charOffset": 5587, "charLength": 42, "snippet": { - "text": "set(value) = this.onStrategyStarted(value)" + "text": "set(value) = this.onStrategyStarting(value)" }, "sourceLanguage": "kotlin" }, @@ -100042,7 +100042,7 @@ "charOffset": 5579, "charLength": 50, "snippet": { - "text": " set(value) = this.onStrategyStarted(value)" + "text": " set(value) = this.onStrategyStarting(value)" }, "sourceLanguage": "kotlin" } @@ -102626,8 +102626,8 @@ "kind": "fail", "level": "warning", "message": { - "text": "Method 'setOnToolCallResult' 
coverage is below the threshold 50%", - "markdown": "Method `setOnToolCallResult` coverage is below the threshold 50%" + "text": "Method 'setOnToolCallCompleted' coverage is below the threshold 50%", + "markdown": "Method `setOnToolCallCompleted` coverage is below the threshold 50%" }, "locations": [ { @@ -102642,7 +102642,7 @@ "charOffset": 9112, "charLength": 41, "snippet": { - "text": "set(value) = this.onToolCallResult(value)" + "text": "set(value) = this.onToolCallCompleted(value)" }, "sourceLanguage": "kotlin" }, @@ -102652,7 +102652,7 @@ "charOffset": 9104, "charLength": 49, "snippet": { - "text": " set(value) = this.onToolCallResult(value)" + "text": " set(value) = this.onToolCallCompleted(value)" }, "sourceLanguage": "kotlin" } @@ -102684,8 +102684,8 @@ "kind": "fail", "level": "warning", "message": { - "text": "Method 'getOnStrategyStarted' coverage is below the threshold 50%", - "markdown": "Method `getOnStrategyStarted` coverage is below the threshold 50%" + "text": "Method 'getOnStrategyStarting' coverage is below the threshold 50%", + "markdown": "Method `getOnStrategyStarting` coverage is below the threshold 50%" }, "locations": [ { @@ -102700,7 +102700,7 @@ "charOffset": 5481, "charLength": 17, "snippet": { - "text": "onStrategyStarted" + "text": "OnStrategyStarting" }, "sourceLanguage": "kotlin" }, @@ -102710,7 +102710,7 @@ "charOffset": 5466, "charLength": 112, "snippet": { - "text": " public var onStrategyStarted: suspend (strategy: AIAgentStrategy) -> Unit = { strategy: AIAgentStrategy -> }" + "text": " public var onStrategyStarting: suspend (strategy: AIAgentStrategy) -> Unit = { strategy: AIAgentStrategy -> }" }, "sourceLanguage": "kotlin" } @@ -102916,8 +102916,8 @@ "kind": "fail", "level": "warning", "message": { - "text": "Method 'invokeOnToolCallFailure$koog_agents_agents_agents_features_agents_features_event_handler_commonMain' coverage is below the threshold 50%", - "markdown": "Method 
`invokeOnToolCallFailure$koog_agents_agents_agents_features_agents_features_event_handler_commonMain` coverage is below the threshold 50%" + "text": "Method 'invokeOnToolCallFailed$koog_agents_agents_agents_features_agents_features_event_handler_commonMain' coverage is below the threshold 50%", + "markdown": "Method `invokeOnToolCallFailed$koog_agents_agents_agents_features_agents_features_event_handler_commonMain` coverage is below the threshold 50%" }, "locations": [ { @@ -102932,7 +102932,7 @@ "charOffset": 19182, "charLength": 23, "snippet": { - "text": "invokeOnToolCallFailure" + "text": "invokeOnToolCallFailed" }, "sourceLanguage": "kotlin" }, @@ -102942,7 +102942,7 @@ "charOffset": 19157, "charLength": 111, "snippet": { - "text": " internal suspend fun invokeOnToolCallFailure(tool: Tool<*, *>, toolArgs: Tool.Args, throwable: Throwable) {" + "text": " internal suspend fun invokeOnToolCallFailed(tool: Tool<*, *>, toolArgs: Tool.Args, throwable: Throwable) {" }, "sourceLanguage": "kotlin" } @@ -104946,8 +104946,8 @@ "kind": "fail", "level": "warning", "message": { - "text": "Method 'setOnAgentFinished' coverage is below the threshold 50%", - "markdown": "Method `setOnAgentFinished` coverage is below the threshold 50%" + "text": "Method 'setOnAgentCompleted' coverage is below the threshold 50%", + "markdown": "Method `setOnAgentCompleted` coverage is below the threshold 50%" }, "locations": [ { @@ -104962,7 +104962,7 @@ "charOffset": 4857, "charLength": 40, "snippet": { - "text": "set(value) = this.onAgentFinished(value)" + "text": "set(value) = this.onAgentCompleted(value)" }, "sourceLanguage": "kotlin" }, @@ -104972,7 +104972,7 @@ "charOffset": 4849, "charLength": 48, "snippet": { - "text": " set(value) = this.onAgentFinished(value)" + "text": " set(value) = this.onAgentCompleted(value)" }, "sourceLanguage": "kotlin" } @@ -110340,8 +110340,8 @@ "kind": "fail", "level": "warning", "message": { - "text": "Method 'getOnToolCallResult' coverage is below the 
threshold 50%", - "markdown": "Method `getOnToolCallResult` coverage is below the threshold 50%" + "text": "Method 'getOnToolCallCompleted' coverage is below the threshold 50%", + "markdown": "Method `getOnToolCallCompleted` coverage is below the threshold 50%" }, "locations": [ { @@ -110356,7 +110356,7 @@ "charOffset": 8941, "charLength": 16, "snippet": { - "text": "onToolCallResult" + "text": "OnToolCallCompleted" }, "sourceLanguage": "kotlin" }, @@ -110366,7 +110366,7 @@ "charOffset": 8926, "charLength": 177, "snippet": { - "text": " public var onToolCallResult: suspend (tool: Tool<*, *>, toolArgs: Tool.Args, result: ToolResult?) -> Unit = { tool: Tool<*, *>, toolArgs: Tool.Args, result: ToolResult? -> }" + "text": " public var onToolCallCompleted: suspend (tool: Tool<*, *>, toolArgs: Tool.Args, result: ToolResult?) -> Unit = { tool: Tool<*, *>, toolArgs: Tool.Args, result: ToolResult? -> }" }, "sourceLanguage": "kotlin" } @@ -112428,8 +112428,8 @@ "kind": "fail", "level": "warning", "message": { - "text": "Method 'getOnAfterLLMCall' coverage is below the threshold 50%", - "markdown": "Method `getOnAfterLLMCall` coverage is below the threshold 50%" + "text": "Method 'getLLMCallCompleted' coverage is below the threshold 50%", + "markdown": "Method `getOnLLMCallCompleted` coverage is below the threshold 50%" }, "locations": [ { @@ -112444,7 +112444,7 @@ "charOffset": 7389, "charLength": 14, "snippet": { - "text": "onAfterLLMCall" + "text": "onLLMCallCompleted" }, "sourceLanguage": "kotlin" }, @@ -112454,7 +112454,7 @@ "charOffset": 7374, "charLength": 285, "snippet": { - "text": " public var onAfterLLMCall: suspend (prompt: Prompt, tools: List, model: LLModel, responses: List, sessionId: String) -> Unit = { prompt: Prompt, tools: List, model: LLModel, responses: List, sessionId: String -> }" + "text": " public var onLLMCallCompleted: suspend (prompt: Prompt, tools: List, model: LLModel, responses: List, sessionId: String) -> Unit = { prompt: Prompt, tools: List, 
model: LLModel, responses: List, sessionId: String -> }" }, "sourceLanguage": "kotlin" } @@ -115328,8 +115328,8 @@ "kind": "fail", "level": "warning", "message": { - "text": "Method 'getOnAgentRunError' coverage is below the threshold 50%", - "markdown": "Method `getOnAgentRunError` coverage is below the threshold 50%" + "text": "Method 'getOnAgentExecutionFailed' coverage is below the threshold 50%", + "markdown": "Method `getOnAgentExecutionFailed` coverage is below the threshold 50%" }, "locations": [ { @@ -115344,7 +115344,7 @@ "charOffset": 5035, "charLength": 15, "snippet": { - "text": "onAgentRunError" + "text": "OnAgentExecutionFailed" }, "sourceLanguage": "kotlin" }, @@ -115354,7 +115354,7 @@ "charOffset": 5020, "charLength": 184, "snippet": { - "text": " public var onAgentRunError: suspend (strategyName: String, sessionId: String?, throwable: Throwable) -> Unit = { strategyName: String, sessionId: String?, throwable: Throwable -> }" + "text": " public var onAgentExecutionFailed: suspend (strategyName: String, sessionId: String?, throwable: Throwable) -> Unit = { strategyName: String, sessionId: String?, throwable: Throwable -> }" }, "sourceLanguage": "kotlin" } @@ -117126,8 +117126,8 @@ "kind": "fail", "level": "warning", "message": { - "text": "Method 'setOnBeforeAgentStarted' coverage is below the threshold 50%", - "markdown": "Method `setOnBeforeAgentStarted` coverage is below the threshold 50%" + "text": "Method 'setOnAgentStarting' coverage is below the threshold 50%", + "markdown": "Method `setOnAgentStarting` coverage is below the threshold 50%" }, "locations": [ { @@ -117142,7 +117142,7 @@ "charOffset": 4546, "charLength": 45, "snippet": { - "text": "set(value) = this.onBeforeAgentStarted(value)" + "text": "set(value) = this.onAgentStarting(value)" }, "sourceLanguage": "kotlin" }, @@ -117152,7 +117152,7 @@ "charOffset": 4538, "charLength": 53, "snippet": { - "text": " set(value) = this.onBeforeAgentStarted(value)" + "text": " set(value) = 
this.onAgentStarting(value)" }, "sourceLanguage": "kotlin" } @@ -120838,8 +120838,8 @@ "kind": "fail", "level": "warning", "message": { - "text": "Method 'setOnAfterLLMCall' coverage is below the threshold 50%", - "markdown": "Method `setOnAfterLLMCall` coverage is below the threshold 50%" + "text": "Method 'setOnLLMCallCompleted' coverage is below the threshold 50%", + "markdown": "Method `setOnLLMCallCompleted` coverage is below the threshold 50%" }, "locations": [ { @@ -120854,7 +120854,7 @@ "charOffset": 7668, "charLength": 39, "snippet": { - "text": "set(value) = this.onAfterLLMCall(value)" + "text": "set(value) = this.onLLMCallCompleted(value)" }, "sourceLanguage": "kotlin" }, @@ -120864,7 +120864,7 @@ "charOffset": 7660, "charLength": 47, "snippet": { - "text": " set(value) = this.onAfterLLMCall(value)" + "text": " set(value) = this.onLLMCallCompleted(value)" }, "sourceLanguage": "kotlin" } @@ -121360,8 +121360,8 @@ "kind": "fail", "level": "warning", "message": { - "text": "Method 'getOnToolCall' coverage is below the threshold 50%", - "markdown": "Method `getOnToolCall` coverage is below the threshold 50%" + "text": "Method 'getOnToolCallStarting' coverage is below the threshold 50%", + "markdown": "Method `getOnToolCallStarting` coverage is below the threshold 50%" }, "locations": [ { @@ -121376,7 +121376,7 @@ "charOffset": 7925, "charLength": 10, "snippet": { - "text": "onToolCall" + "text": "OnToolCallStarting" }, "sourceLanguage": "kotlin" }, @@ -121386,7 +121386,7 @@ "charOffset": 7910, "charLength": 129, "snippet": { - "text": " public var onToolCall: suspend (tool: Tool<*, *>, toolArgs: Tool.Args) -> Unit = { tool: Tool<*, *>, toolArgs: Tool.Args -> }" + "text": " public var onToolCallStarting: suspend (tool: Tool<*, *>, toolArgs: Tool.Args) -> Unit = { tool: Tool<*, *>, toolArgs: Tool.Args -> }" }, "sourceLanguage": "kotlin" } @@ -121882,8 +121882,8 @@ "kind": "fail", "level": "warning", "message": { - "text": "Method 'setOnBeforeLLMCall' coverage 
is below the threshold 50%", - "markdown": "Method `setOnBeforeLLMCall` coverage is below the threshold 50%" + "text": "Method 'setOnLLMCallStrarting' coverage is below the threshold 50%", + "markdown": "Method `setOnLLMCallStarting` coverage is below the threshold 50%" }, "locations": [ { @@ -121898,7 +121898,7 @@ "charOffset": 7213, "charLength": 40, "snippet": { - "text": "set(value) = this.onBeforeLLMCall(value)" + "text": "set(value) = this.onLLMCallStarting(value)" }, "sourceLanguage": "kotlin" }, @@ -121908,7 +121908,7 @@ "charOffset": 7205, "charLength": 48, "snippet": { - "text": " set(value) = this.onBeforeLLMCall(value)" + "text": " set(value) = this.onLLMCallStarting(value)" }, "sourceLanguage": "kotlin" } @@ -123100,8 +123100,8 @@ "kind": "fail", "level": "warning", "message": { - "text": "Method 'getOnAgentFinished' coverage is below the threshold 50%", - "markdown": "Method `getOnAgentFinished` coverage is below the threshold 50%" + "text": "Method 'getOnAgentCompleted' coverage is below the threshold 50%", + "markdown": "Method `getOnAgentCompleted` coverage is below the threshold 50%" }, "locations": [ { @@ -123116,7 +123116,7 @@ "charOffset": 4729, "charLength": 15, "snippet": { - "text": "onAgentFinished" + "text": "OnAgentCompleted" }, "sourceLanguage": "kotlin" }, @@ -123126,7 +123126,7 @@ "charOffset": 4714, "charLength": 134, "snippet": { - "text": " public var onAgentFinished: suspend (strategyName: String, result: String?) -> Unit = { strategyName: String, result: String? -> }" + "text": " public var onAgentCompleted: suspend (strategyName: String, result: String?) -> Unit = { strategyName: String, result: String? 
-> }" }, "sourceLanguage": "kotlin" } @@ -125304,8 +125304,8 @@ "kind": "fail", "level": "warning", "message": { - "text": "Method 'setOnAgentRunError' coverage is below the threshold 50%", - "markdown": "Method `setOnAgentRunError` coverage is below the threshold 50%" + "text": "Method 'setOnAgentExecutionFailed' coverage is below the threshold 50%", + "markdown": "Method `setOnAgentExecutionFailed` coverage is below the threshold 50%" }, "locations": [ { @@ -125320,7 +125320,7 @@ "charOffset": 5213, "charLength": 40, "snippet": { - "text": "set(value) = this.onAgentRunError(value)" + "text": "set(value) = this.onAgentExecutionFailed(value)" }, "sourceLanguage": "kotlin" }, @@ -125330,7 +125330,7 @@ "charOffset": 5205, "charLength": 48, "snippet": { - "text": " set(value) = this.onAgentRunError(value)" + "text": " set(value) = this.onAgentExecutionFailed(value)" }, "sourceLanguage": "kotlin" } @@ -126116,8 +126116,8 @@ "kind": "fail", "level": "warning", "message": { - "text": "Method 'onToolCallFailure' coverage is below the threshold 50%", - "markdown": "Method `onToolCallFailure` coverage is below the threshold 50%" + "text": "Method 'onToolCallFailed' coverage is below the threshold 50%", + "markdown": "Method `onToolCallFailed` coverage is below the threshold 50%" }, "locations": [ { @@ -126132,7 +126132,7 @@ "charOffset": 14208, "charLength": 17, "snippet": { - "text": "onToolCallFailure" + "text": "OnToolCallFailed" }, "sourceLanguage": "kotlin" }, @@ -126142,7 +126142,7 @@ "charOffset": 14185, "charLength": 103, "snippet": { - "text": " public suspend fun onToolCallFailure(tool: Tool<*, *>, toolArgs: Tool.Args, throwable: Throwable) {" + "text": " public suspend fun onToolCallFailed(tool: Tool<*, *>, toolArgs: Tool.Args, throwable: Throwable) {" }, "sourceLanguage": "kotlin" } diff --git a/settings.gradle.kts b/settings.gradle.kts index 54f3402496..9e1fb87b82 100755 --- a/settings.gradle.kts +++ b/settings.gradle.kts @@ -17,6 +17,9 @@ 
include(":agents:agents-features:agents-features-sql") include(":agents:agents-features:agents-features-trace") include(":agents:agents-features:agents-features-tokenizer") include(":agents:agents-features:agents-features-snapshot") +include(":agents:agents-features:agents-features-a2a-core") +include(":agents:agents-features:agents-features-a2a-server") +include(":agents:agents-features:agents-features-a2a-client") include(":agents:agents-mcp") include(":agents:agents-mcp-server") @@ -61,6 +64,15 @@ include(":embeddings:embeddings-llm") include(":rag:rag-base") include(":rag:vector-storage") +include(":a2a:a2a-core") +include(":a2a:a2a-server") +include(":a2a:a2a-client") +include(":a2a:a2a-test") +include(":a2a:a2a-transport:a2a-transport-core-jsonrpc") +include(":a2a:a2a-transport:a2a-transport-server-jsonrpc-http") +include(":a2a:a2a-transport:a2a-transport-client-jsonrpc-http") +include(":a2a:test-tck:a2a-test-server-tck") + include(":koog-spring-boot-starter") include(":koog-ktor") diff --git a/test-utils/build.gradle.kts b/test-utils/build.gradle.kts index 688ab4e94b..79ea57e818 100644 --- a/test-utils/build.gradle.kts +++ b/test-utils/build.gradle.kts @@ -22,6 +22,7 @@ kotlin { dependencies { api(kotlin("test-junit5")) api(libs.junit.jupiter.params) + api(libs.testcontainers) runtimeOnly(libs.slf4j.simple) } } diff --git a/test-utils/src/jvmMain/kotlin/ai/koog/test/utils/DockerAvailableCondition.kt b/test-utils/src/jvmMain/kotlin/ai/koog/test/utils/DockerAvailableCondition.kt new file mode 100644 index 0000000000..35fe37e977 --- /dev/null +++ b/test-utils/src/jvmMain/kotlin/ai/koog/test/utils/DockerAvailableCondition.kt @@ -0,0 +1,20 @@ +package ai.koog.test.utils + +import org.junit.jupiter.api.extension.ConditionEvaluationResult +import org.junit.jupiter.api.extension.ExecutionCondition +import org.junit.jupiter.api.extension.ExtensionContext +import org.testcontainers.DockerClientFactory + +/** + * Helper test condition method to skip test suite if 
Docker is not available. + */ +public class DockerAvailableCondition : ExecutionCondition { + override fun evaluateExecutionCondition(context: ExtensionContext): ConditionEvaluationResult { + return try { + DockerClientFactory.instance().client() + ConditionEvaluationResult.enabled("Docker is available") + } catch (e: Exception) { + ConditionEvaluationResult.disabled("Docker is not available, skipping this test") + } + } +} diff --git a/utils/src/commonMain/kotlin/ai/koog/utils/lang/StringExtensions.kt b/utils/src/commonMain/kotlin/ai/koog/utils/lang/StringExtensions.kt new file mode 100644 index 0000000000..7ce21b654d --- /dev/null +++ b/utils/src/commonMain/kotlin/ai/koog/utils/lang/StringExtensions.kt @@ -0,0 +1,25 @@ +package ai.koog.utils.lang + +/** + * Returns string with masked symbols using the strictest security level. + * + * This method is designed for masking sensitive secrets and provides maximum security + * by returning a consistent pattern that reveals no information about the original content. + * + * Examples: + * - `null` -> `null` + * - `""` -> `null` + * - `" "` -> `null` + * - `"I"` -> `"***HIDDEN***"` + * - `"Hi"` -> `"***HIDDEN***"` + * - `"Hello"` -> `"***HIDDEN***"` + * - `"VeryLongSecretKey123"` -> `"***HIDDEN***"` + */ +public fun String?.masked( + maskChar: Char = '*', +): String? 
{ + if (this.isNullOrBlank()) return null + + // Strictest security level: consistent pattern regardless of input + return "${maskChar}${maskChar}${maskChar}HIDDEN${maskChar}${maskChar}$maskChar" +} diff --git a/utils/src/commonTest/kotlin/ai/koog/utils/lang/StringExtensionsTest.kt b/utils/src/commonTest/kotlin/ai/koog/utils/lang/StringExtensionsTest.kt new file mode 100644 index 0000000000..3c69fb02e9 --- /dev/null +++ b/utils/src/commonTest/kotlin/ai/koog/utils/lang/StringExtensionsTest.kt @@ -0,0 +1,45 @@ +package ai.koog.utils.lang + +import kotlin.test.Test +import kotlin.test.assertEquals + +class StringExtensionsTest { + + @Test + fun `String masked should hide string contents`() { + // Test null input + assertEquals(null, null.masked()) + + // Test empty string + assertEquals(null, "".masked()) + + // Test blank string + assertEquals(null, " ".masked()) + + // Test single character - strict security returns consistent pattern + assertEquals("***HIDDEN***", "I".masked()) + + // Test two characters - strict security returns consistent pattern + assertEquals("***HIDDEN***", "Hi".masked()) + + // Test three characters - strict security returns consistent pattern + assertEquals("***HIDDEN***", "Hey".masked()) + + // Test longer string - strict security returns consistent pattern + assertEquals("***HIDDEN***", "Hello".masked()) + + // Test very long string - strict security returns consistent pattern + assertEquals("***HIDDEN***", "cccccclulbkucigkbggivvitngjdbfuhkevedrdukvcr".masked()) + + // Test custom mask character - should use custom char in pattern + assertEquals("---HIDDEN---", "Hello".masked(maskChar = '-')) + + // Test string with whitespace - should still return consistent pattern + assertEquals("***HIDDEN***", " Hello ".masked()) + + // Test sensitive data - all return same pattern for maximum security + assertEquals("***HIDDEN***", "password123".masked()) + assertEquals("***HIDDEN***", "api-key-secret".masked()) + assertEquals("***HIDDEN***", 
"x".masked()) + } +}