From a474c7e51f89f30d30b3c85c9b804ddca3313d53 Mon Sep 17 00:00:00 2001 From: Briliantov Vadim Date: Tue, 30 Sep 2025 21:08:30 +0200 Subject: [PATCH 01/52] Allow adjusting context window sizes for Ollama dynamically (#883) --- .../tokenizer/feature/MessageTokenizer.kt | 110 +-------- .../tokenizer/feature/MessageTokenizerTest.kt | 84 ------- .../build.gradle.kts | 4 + .../ollama/client/ContextWindowStrategy.kt | 153 ++++++++++++ .../executor/ollama/client/OllamaClient.kt | 27 +- .../ollama/client/dto/OllamaConverters.kt | 8 - .../ollama/client/dto/OllamaModels.kt | 1 + .../client/ContextWindowStrategyTest.kt | 231 ++++++++++++++++++ prompt/prompt-tokenizer/build.gradle.kts | 1 + .../koog/prompt/tokenizer/PromptTokenizer.kt | 110 +++++++++ .../prompt/tokenizer/PromptTokenizerTest.kt | 133 ++++++++++ 11 files changed, 656 insertions(+), 206 deletions(-) create mode 100644 prompt/prompt-executor/prompt-executor-clients/prompt-executor-ollama-client/src/commonMain/kotlin/ai/koog/prompt/executor/ollama/client/ContextWindowStrategy.kt create mode 100644 prompt/prompt-executor/prompt-executor-clients/prompt-executor-ollama-client/src/commonTest/kotlin/ai/koog/prompt/executor/ollama/client/ContextWindowStrategyTest.kt create mode 100644 prompt/prompt-tokenizer/src/commonMain/kotlin/ai/koog/prompt/tokenizer/PromptTokenizer.kt create mode 100644 prompt/prompt-tokenizer/src/commonTest/kotlin/ai/koog/prompt/tokenizer/PromptTokenizerTest.kt diff --git a/agents/agents-features/agents-features-tokenizer/src/commonMain/kotlin/ai/koog/agents/features/tokenizer/feature/MessageTokenizer.kt b/agents/agents-features/agents-features-tokenizer/src/commonMain/kotlin/ai/koog/agents/features/tokenizer/feature/MessageTokenizer.kt index 222a2c6cea..ab47e6ecea 100644 --- a/agents/agents-features/agents-features-tokenizer/src/commonMain/kotlin/ai/koog/agents/features/tokenizer/feature/MessageTokenizer.kt +++ b/agents/agents-features/agents-features-tokenizer/src/commonMain/kotlin/ai/koog/agents/features/tokenizer/feature/MessageTokenizer.kt @@ -8,9 +8,10 @@ import ai.koog.agents.core.feature.AIAgentGraphPipeline import ai.koog.agents.core.feature.AIAgentNonGraphFeature import ai.koog.agents.core.feature.AIAgentNonGraphPipeline import ai.koog.agents.core.feature.config.FeatureConfig -import ai.koog.prompt.dsl.Prompt -import ai.koog.prompt.message.Message +import ai.koog.prompt.tokenizer.CachingTokenizer import ai.koog.prompt.tokenizer.NoTokenizer +import ai.koog.prompt.tokenizer.OnDemandTokenizer +import ai.koog.prompt.tokenizer.PromptTokenizer import ai.koog.prompt.tokenizer.Tokenizer /** @@ -49,111 +50,6 @@ public class MessageTokenizerConfig : FeatureConfig() { public var enableCaching: Boolean = true } -/** - * An interface that provides utilities for tokenizing and calculating token usage in messages and prompts. - */ -public interface PromptTokenizer { - /** - * Calculates the number of tokens required for a given message. - * - * @param message The message for which the token count should be determined. - * @return The number of tokens required to encode the message. - */ - public fun tokenCountFor(message: Message): Int - - /** - * Calculates the total number of tokens spent in a given prompt. - * - * @param prompt The prompt for which the total tokens spent need to be calculated. - * @return The total number of tokens spent as an integer. - */ - public fun tokenCountFor(prompt: Prompt): Int -} - -/** - * An implementation of the [PromptTokenizer] interface that delegates token counting - * to an instance of the [Tokenizer] interface. The class provides methods to estimate - * the token count for individual messages and for the entirety of a prompt. - * - * This is useful in contexts where token-based costs or limitations are significant, - * such as when interacting with large language models (LLMs). - * - * @property tokenizer The [Tokenizer] instance used for token counting. - */ -public class OnDemandTokenizer(private val tokenizer: Tokenizer) : PromptTokenizer { - - /** - * Computes the number of tokens in a given message. - * - * @param message The message for which the token count needs to be calculated. - * The content of the message is analyzed to estimate the token count. - * @return The estimated number of tokens in the message content. - */ - public override fun tokenCountFor(message: Message): Int = tokenizer.countTokens(message.content) - - /** - * Calculates the total number of tokens spent for the given prompt based on its messages. - * - * @param prompt The `Prompt` instance containing the list of messages for which the total token count will be calculated. - * @return The total number of tokens across all messages in the prompt. - */ - public override fun tokenCountFor(prompt: Prompt): Int = prompt.messages.sumOf(::tokenCountFor) -} - -/** - * A caching implementation of the `PromptTokenizer` interface that optimizes token counting - * by storing previously computed token counts for messages. This reduces redundant computations - * when the same message is processed multiple times. - * - * @constructor Creates an instance of `CachingTokenizer` with a provided `Tokenizer` instance - * that performs the actual token counting. - * @property tokenizer The underlying `Tokenizer` used for counting tokens in the message content. - */ -public class CachingTokenizer(private val tokenizer: Tokenizer) : PromptTokenizer { - /** - * A cache that maps a `Message` to its corresponding token count. - * - * This is used to store the results of token computations for reuse, optimizing performance - * by avoiding repeated invocations of the token counting process on the same message content. - * - * Token counts are computed lazily and stored in the cache when requested via the `tokensFor` - * method. This cache can be cleared using the `clearCache` method. - */ - internal val cache = mutableMapOf() - - /** - * Retrieves the number of tokens contained in the content of the given message. - * This method utilizes caching to improve performance, storing previously - * computed token counts and reusing them for identical messages. - * - * @param message The message whose content's token count is to be retrieved - * @return The number of tokens in the content of the message - */ - public override fun tokenCountFor(message: Message): Int = cache.getOrPut(message) { - tokenizer.countTokens(message.content) - } - - /** - * Calculates the total number of tokens spent on the given prompt by summing the token usage - * of all messages associated with the prompt. - * - * @param prompt The prompt containing the list of messages whose token usage will be calculated. - * @return The total number of tokens spent across all messages in the provided prompt. - */ - public override fun tokenCountFor(prompt: Prompt): Int = prompt.messages.sumOf(::tokenCountFor) - - /** - * Clears all cached token counts from the internal cache. - * - * This method is useful when the state of the cached data becomes invalid - * or needs resetting. After calling this, any subsequent token count - * calculations will be recomputed rather than retrieved from the cache. - */ - public fun clearCache() { - cache.clear() - } -} - /** * The [MessageTokenizer] feature is responsible for handling tokenization of messages using a provided [Tokenizer] * implementation. It serves as a feature that can be installed into an `AIAgentPipeline`. The tokenizer behavior can be configured diff --git a/agents/agents-features/agents-features-tokenizer/src/jvmTest/kotlin/ai/koog/agents/features/tokenizer/feature/MessageTokenizerTest.kt b/agents/agents-features/agents-features-tokenizer/src/jvmTest/kotlin/ai/koog/agents/features/tokenizer/feature/MessageTokenizerTest.kt index 8fa8f7f2e6..037452f5f1 100644 --- a/agents/agents-features/agents-features-tokenizer/src/jvmTest/kotlin/ai/koog/agents/features/tokenizer/feature/MessageTokenizerTest.kt +++ b/agents/agents-features/agents-features-tokenizer/src/jvmTest/kotlin/ai/koog/agents/features/tokenizer/feature/MessageTokenizerTest.kt @@ -15,16 +15,10 @@ import ai.koog.agents.testing.tools.getMockExecutor import ai.koog.agents.testing.tools.mockLLMAnswer import ai.koog.prompt.dsl.prompt import ai.koog.prompt.llm.OllamaModels -import ai.koog.prompt.message.Message -import ai.koog.prompt.message.RequestMetaInfo -import ai.koog.prompt.message.ResponseMetaInfo import ai.koog.prompt.tokenizer.Tokenizer import kotlinx.coroutines.runBlocking -import kotlinx.coroutines.test.runTest -import kotlinx.datetime.Clock import kotlin.test.Test import kotlin.test.assertEquals -import kotlin.test.assertTrue /** * Test for the MessageTokenizer feature. @@ -72,84 +66,6 @@ class MessageTokenizerTest { } } - @Test - fun testPromptTokenizer() = runTest { - // Create a mock tokenizer to track token usage - val mockTokenizer = MockTokenizer() - - // Create a prompt tokenizer with our mock tokenizer - val promptTokenizer = OnDemandTokenizer(mockTokenizer) - - // Create a prompt with some messages - val testPrompt = prompt("test-prompt") { - system("You are a helpful assistant.") - user("What is the capital of France?") - assistant("Paris is the capital of France.") - } - - // Count tokens in the prompt - val totalTokens = promptTokenizer.tokenCountFor(testPrompt) - - // Verify that tokens were counted - assertTrue(totalTokens > 0, "Total tokens should be greater than 0") - - // Verify that the tokenizer was used and counted tokens - assertTrue(mockTokenizer.totalTokens > 0, "Tokenizer should have counted tokens") - - // Verify that the total tokens match what we expect - assertEquals(totalTokens, mockTokenizer.totalTokens, "Total tokens should match the tokenizer's count") - - // Print the total tokens spent - println("[DEBUG_LOG] Total tokens spent: ${mockTokenizer.totalTokens}") - - val requestMetainfo = RequestMetaInfo.create(Clock.System) - val responseMetainfo = ResponseMetaInfo.create(Clock.System) - // Count tokens for individual messages - val systemTokens = promptTokenizer.tokenCountFor( - Message.System("You are a helpful assistant.", requestMetainfo) - ) - val userTokens = promptTokenizer.tokenCountFor(Message.User("What is the capital of France?", requestMetainfo)) - val assistantTokens = promptTokenizer.tokenCountFor( - Message.Assistant("Paris is the capital of France.", responseMetainfo) - ) - - // Print token counts for each message - println("[DEBUG_LOG] System message tokens: $systemTokens") - println("[DEBUG_LOG] User message tokens: $userTokens") - println("[DEBUG_LOG] Assistant message tokens: $assistantTokens") - - // Verify that the sum of individual message tokens equals the total - val sumOfMessageTokens = systemTokens + userTokens + assistantTokens - assertEquals(sumOfMessageTokens, totalTokens, "Sum of message tokens should equal total tokens") - } - - @Test - fun testCachingPromptTokenizer() = runTest { - // Create a mock tokenizer to track token usage - val mockTokenizer = MockTokenizer() - - // Create a prompt tokenizer with our mock tokenizer - val promptTokenizer = CachingTokenizer(mockTokenizer) - - // Create a prompt with some messages - val testPrompt = prompt("test-prompt") { - system("You are a helpful assistant.") - user("What is the capital of France?") - assistant("Paris is the capital of France.") - } - - assertEquals(0, promptTokenizer.cache.size) - promptTokenizer.tokenCountFor(testPrompt) - assertEquals(3, promptTokenizer.cache.size) - promptTokenizer.clearCache() - assertEquals(0, promptTokenizer.cache.size) - promptTokenizer.tokenCountFor(testPrompt.messages[1]) - promptTokenizer.tokenCountFor(testPrompt.messages[2]) - assertEquals(2, promptTokenizer.cache.size) - promptTokenizer.tokenCountFor(testPrompt) - assertEquals(3, promptTokenizer.cache.size) - } - @Test fun testTokenizerInAgents() { val testToolRegistry = ToolRegistry { diff --git a/prompt/prompt-executor/prompt-executor-clients/prompt-executor-ollama-client/build.gradle.kts b/prompt/prompt-executor/prompt-executor-clients/prompt-executor-ollama-client/build.gradle.kts index 23dce9c48e..cbabf98aa8 100644 --- a/prompt/prompt-executor/prompt-executor-clients/prompt-executor-ollama-client/build.gradle.kts +++ b/prompt/prompt-executor/prompt-executor-clients/prompt-executor-ollama-client/build.gradle.kts @@ -15,6 +15,7 @@ kotlin { api(project(":agents:agents-tools")) api(project(":prompt:prompt-llm")) api(project(":prompt:prompt-model")) + api(project(":prompt:prompt-tokenizer")) api(project(":agents:agents-tools")) api(project(":prompt:prompt-executor:prompt-executor-model")) api(project(":prompt:prompt-executor:prompt-executor-clients")) @@ -64,6 +65,9 @@ kotlin { dependencies { implementation(project(":test-utils")) implementation(project(":agents:agents-features:agents-features-event-handler")) + implementation(libs.kotlinx.coroutines.core) + implementation(libs.kotlinx.coroutines.test) + implementation(libs.ktor.client.mock) } } diff --git a/prompt/prompt-executor/prompt-executor-clients/prompt-executor-ollama-client/src/commonMain/kotlin/ai/koog/prompt/executor/ollama/client/ContextWindowStrategy.kt b/prompt/prompt-executor/prompt-executor-clients/prompt-executor-ollama-client/src/commonMain/kotlin/ai/koog/prompt/executor/ollama/client/ContextWindowStrategy.kt new file mode 100644 index 0000000000..8d9c378916 --- /dev/null +++ b/prompt/prompt-executor/prompt-executor-clients/prompt-executor-ollama-client/src/commonMain/kotlin/ai/koog/prompt/executor/ollama/client/ContextWindowStrategy.kt @@ -0,0 +1,153 @@ +package ai.koog.prompt.executor.ollama.client + +import ai.koog.prompt.dsl.Prompt +import ai.koog.prompt.llm.LLModel +import ai.koog.prompt.tokenizer.PromptTokenizer +import io.github.oshai.kotlinlogging.KotlinLogging + +private val logger = KotlinLogging.logger { } + +/** + * Represents a strategy for computing the context window length for `OllamaClient`. + * Different implementations define specific approaches to computing the context window length. + * Based on the context window length computed by this strategy, Ollama will truncate the context window accordingly. + * + * To decide the context window length, Ollama proceeds as follows: + * - If a `num_ctx` parameter is specified in the chat request, the context window length is set to that value. + * - If the model definition contains a `num_ctx` parameter, the context window length is set to that value. + * - If an `OLLAMA_CONTEXT_LENGTH` environment variable is set, the context window length is set to that value. + * - Otherwise, the context window length is set to the default value of 2048. + * + * Effectively, this strategy allows you to specify what `num_ctx` value will be set in chat requests sent to Ollama, + * for a given prompt and model. + * + * Important: You will want to have a context window length that does not change often for a specific model. + * Indeed, Ollama will reload the model every time the context window length changes. + * + * Example implementations: + * - [ContextWindowStrategy.None] + * - [ContextWindowStrategy.Fixed] + * - [ContextWindowStrategy.FitPrompt] + */ +public interface ContextWindowStrategy { + + /** + * Computes the context length for a given prompt and language model. + * This may involve calculating the number of tokens used in the prompt + * and determining if it fits within the model's context length constraints. + * + * @param prompt The [Prompt] containing the list of messages, unique identifier, + * and language model parameters that describe the input for the LLM. + * @param model The [LLModel] representing the language model used to process the prompt, + * which includes its provider, identifier, capabilities, and context length. + * @return The context length as a [Long], indicating the number of tokens used + * in the prompt, or `null` if it cannot be calculated. + */ + public fun computeContextLength(prompt: Prompt, model: LLModel): Long? + + /** + * Provides companion object-related strategies for determining the context window length. + * It contains multiple strategies that are implemented as subtypes of [ContextWindowStrategy]. + */ + public companion object { + /** + * A strategy for letting the Ollama server decide the context window length. + * To decide the context window length, Ollama proceeds as follows: + * - If the model definition contains a `num_ctx` parameter, the context window length is set to that value. + * - If an `OLLAMA_CONTEXT_LENGTH` environment variable is set, the context window length is set to that value. + * - Otherwise, the context window length is set to the default value of 2048. + */ + public data object None : ContextWindowStrategy { + override fun computeContextLength(prompt: Prompt, model: LLModel): Long? = null + } + + /** + * A strategy for specifying a fixed context window length. + * If the given [contextLength] is more than the maximum context window length supported by the model, + * the context window length will be set to the maximum context window length supported by the model. + * + * @param contextLength The context window length to use. + */ + public data class Fixed(val contextLength: Long) : ContextWindowStrategy { + init { + require(contextLength > 0) { "Context length must be positive but was: $contextLength" } + } + + override fun computeContextLength(prompt: Prompt, model: LLModel): Long { + if (contextLength > model.contextLength) { + logger.warn { + "Context length $contextLength was more than what is supported by model '${model.id}'," + + " falling back to the model's maximum context length ${model.contextLength}" + } + return model.contextLength + } + return contextLength + } + } + + /** + * A strategy for computing the context window length based on the prompt length. + * + * @param promptTokenizer The [PromptTokenizer] to use for computing the prompt length, + * or null to use the last reported token usage. + * @param contextChunkSize The granularity to use for computing the context window length. Defaults to 2048. + * @param minimumChunkCount The minimum number of context chunks in the context. + * @param maximumChunkCount The maximum number of context chunks in the context. + * + * Example: contextChunkSize = 512, minimumChunkCount = 2, maximumChunkCount = 4, + * then [minimumContextLength] = 1024 and [maximumContextLength] = 2048 + */ + public data class FitPrompt( + val promptTokenizer: PromptTokenizer? = null, + val contextChunkSize: Long = 2048, + val minimumChunkCount: Long? = null, + val maximumChunkCount: Long? = null + ) : ContextWindowStrategy { + + private val minimumContextLength: Long? = minimumChunkCount?.let { cnt -> cnt * contextChunkSize } + private val maximumContextLength: Long? = maximumChunkCount?.let { cnt -> cnt * contextChunkSize } + + init { + require(contextChunkSize > 0) { "`contextChunkSize`` must be greater than 0" } + require(minimumChunkCount == null || minimumChunkCount > 0) { + "`minimumChunkCount` must be a positive number or `null`" + } + + if (minimumChunkCount != null && maximumChunkCount != null) { + require(minimumChunkCount <= maximumChunkCount) { + "`maximumChunkCount` ($maximumChunkCount) must be greater or equal" + + " to `minimumChunkCount` ($minimumChunkCount)" + } + } + } + + override fun computeContextLength(prompt: Prompt, model: LLModel): Long? { + val promptLength = when { + promptTokenizer != null -> promptTokenizer.tokenCountFor(prompt) + prompt.latestTokenUsage != 0 -> prompt.latestTokenUsage + else -> null + } + + if (promptLength == null) return minimumContextLength + + if (maximumContextLength != null && promptLength > maximumContextLength) { + logger.warn { + "Prompt length $promptLength was more than " + + "the maximum context length $maximumContextLength provideded" + } + return maximumContextLength + } + + if (promptLength > model.contextLength) { + logger.warn { + "Prompt length $promptLength was more than the maximum context length of model '${model.id}'," + + " falling back to the model's maximum context length ${model.contextLength}" + } + return model.contextLength + } + + return (promptLength / contextChunkSize + 1) * contextChunkSize + } + } + } +} diff --git a/prompt/prompt-executor/prompt-executor-clients/prompt-executor-ollama-client/src/commonMain/kotlin/ai/koog/prompt/executor/ollama/client/OllamaClient.kt b/prompt/prompt-executor/prompt-executor-clients/prompt-executor-ollama-client/src/commonMain/kotlin/ai/koog/prompt/executor/ollama/client/OllamaClient.kt index cd3bde49e4..1565ba06a1 100644 --- a/prompt/prompt-executor/prompt-executor-clients/prompt-executor-ollama-client/src/commonMain/kotlin/ai/koog/prompt/executor/ollama/client/OllamaClient.kt +++ b/prompt/prompt-executor/prompt-executor-clients/prompt-executor-ollama-client/src/commonMain/kotlin/ai/koog/prompt/executor/ollama/client/OllamaClient.kt @@ -20,7 +20,6 @@ import ai.koog.prompt.executor.ollama.client.dto.OllamaPullModelResponseDTO import ai.koog.prompt.executor.ollama.client.dto.OllamaShowModelRequestDTO import ai.koog.prompt.executor.ollama.client.dto.OllamaShowModelResponseDTO import ai.koog.prompt.executor.ollama.client.dto.extractOllamaJsonFormat -import ai.koog.prompt.executor.ollama.client.dto.extractOllamaOptions import ai.koog.prompt.executor.ollama.client.dto.getToolCalls import ai.koog.prompt.executor.ollama.client.dto.toOllamaChatMessages import ai.koog.prompt.executor.ollama.client.dto.toOllamaModelCard @@ -57,19 +56,23 @@ import kotlinx.serialization.json.Json /** * Client for interacting with the Ollama API with comprehensive model support. * + * Implements: + * - [LLMClient] for executing prompts and streaming responses. + * - [LLMEmbeddingProvider] for generating embeddings from input text. + * * @param baseUrl The base URL of the Ollama server. Defaults to "http://localhost:11434". * @param baseClient The underlying HTTP client used for making requests. * @param timeoutConfig Configuration for connection, request, and socket timeouts. * @param clock Clock instance used for tracking response metadata timestamps. - * Implements: - * - LLMClient for executing prompts and streaming responses. - * - LLMEmbeddingProvider for generating embeddings from input text. + * @param contextWindowStrategy The [ContextWindowStrategy] to use for computing context window lengths. + * Defaults to [ContextWindowStrategy.None]. */ public class OllamaClient( public val baseUrl: String = "http://localhost:11434", baseClient: HttpClient = HttpClient(engineFactoryProvider()), timeoutConfig: ConnectionTimeoutConfig = ConnectionTimeoutConfig(), - private val clock: Clock = Clock.System + private val clock: Clock = Clock.System, + private val contextWindowStrategy: ContextWindowStrategy = ContextWindowStrategy.Companion.None, ) : LLMClient, LLMEmbeddingProvider { private companion object { @@ -159,7 +162,7 @@ public class OllamaClient( messages = prompt.toOllamaChatMessages(model), tools = if (tools.isNotEmpty()) tools.map { it.toOllamaTool() } else null, format = prompt.extractOllamaJsonFormat(), - options = prompt.extractOllamaOptions(), + options = extractOllamaOptions(prompt, model), stream = false, additionalProperties = prompt.params.additionalProperties ) @@ -239,7 +242,7 @@ public class OllamaClient( OllamaChatRequestDTO( model = model.id, messages = prompt.toOllamaChatMessages(model), - options = prompt.extractOllamaOptions(), + options = extractOllamaOptions(prompt, model), stream = true, additionalProperties = prompt.params.additionalProperties, ) @@ -274,6 +277,16 @@ public class OllamaClient( } } + /** + * Prepare Ollama chat request options from the given prompt and model. + */ + internal fun extractOllamaOptions(prompt: Prompt, model: LLModel): OllamaChatRequestDTO.Options { + return OllamaChatRequestDTO.Options( + temperature = prompt.params.temperature, + numCtx = contextWindowStrategy.computeContextLength(prompt, model), + ) + } + /** * Embeds the given text using the Ollama model. * diff --git a/prompt/prompt-executor/prompt-executor-clients/prompt-executor-ollama-client/src/commonMain/kotlin/ai/koog/prompt/executor/ollama/client/dto/OllamaConverters.kt b/prompt/prompt-executor/prompt-executor-clients/prompt-executor-ollama-client/src/commonMain/kotlin/ai/koog/prompt/executor/ollama/client/dto/OllamaConverters.kt index ba4d390d23..65d9916f07 100644 --- a/prompt/prompt-executor/prompt-executor-clients/prompt-executor-ollama-client/src/commonMain/kotlin/ai/koog/prompt/executor/ollama/client/dto/OllamaConverters.kt +++ b/prompt/prompt-executor/prompt-executor-clients/prompt-executor-ollama-client/src/commonMain/kotlin/ai/koog/prompt/executor/ollama/client/dto/OllamaConverters.kt @@ -111,14 +111,6 @@ internal fun Prompt.extractOllamaJsonFormat(): JsonObject? { return if (schema is LLMParams.Schema.JSON) schema.schema else null } -/** - * Extracts options from the prompt, if temperature is defined. - */ -internal fun Prompt.extractOllamaOptions(): OllamaChatRequestDTO.Options? { - val temperature = params.temperature - return temperature?.let { OllamaChatRequestDTO.Options(temperature = temperature) } -} - /** * Extracts tool calls from a ChatMessage. * Returns the first tool call for compatibility, but logs if multiple calls exist. diff --git a/prompt/prompt-executor/prompt-executor-clients/prompt-executor-ollama-client/src/commonMain/kotlin/ai/koog/prompt/executor/ollama/client/dto/OllamaModels.kt b/prompt/prompt-executor/prompt-executor-clients/prompt-executor-ollama-client/src/commonMain/kotlin/ai/koog/prompt/executor/ollama/client/dto/OllamaModels.kt index f9a7da7c73..8b29db836b 100644 --- a/prompt/prompt-executor/prompt-executor-clients/prompt-executor-ollama-client/src/commonMain/kotlin/ai/koog/prompt/executor/ollama/client/dto/OllamaModels.kt +++ b/prompt/prompt-executor/prompt-executor-clients/prompt-executor-ollama-client/src/commonMain/kotlin/ai/koog/prompt/executor/ollama/client/dto/OllamaModels.kt @@ -72,6 +72,7 @@ internal data class OllamaChatRequestDTO( @Serializable internal data class Options( val temperature: Double? = null, + @SerialName("num_ctx") val numCtx: Long? = null, ) } diff --git a/prompt/prompt-executor/prompt-executor-clients/prompt-executor-ollama-client/src/commonTest/kotlin/ai/koog/prompt/executor/ollama/client/ContextWindowStrategyTest.kt b/prompt/prompt-executor/prompt-executor-clients/prompt-executor-ollama-client/src/commonTest/kotlin/ai/koog/prompt/executor/ollama/client/ContextWindowStrategyTest.kt new file mode 100644 index 0000000000..f56f7441f4 --- /dev/null +++ b/prompt/prompt-executor/prompt-executor-clients/prompt-executor-ollama-client/src/commonTest/kotlin/ai/koog/prompt/executor/ollama/client/ContextWindowStrategyTest.kt @@ -0,0 +1,231 @@ +package ai.koog.prompt.executor.ollama.client + +import ai.koog.prompt.dsl.Prompt +import ai.koog.prompt.dsl.prompt +import ai.koog.prompt.executor.ollama.client.dto.OllamaChatMessageDTO +import ai.koog.prompt.executor.ollama.client.dto.OllamaChatRequestDTO +import ai.koog.prompt.executor.ollama.client.dto.OllamaChatResponseDTO +import ai.koog.prompt.llm.OllamaModels +import ai.koog.prompt.message.Message +import ai.koog.prompt.message.ResponseMetaInfo +import ai.koog.prompt.tokenizer.PromptTokenizer +import io.ktor.client.HttpClient +import io.ktor.client.engine.mock.MockEngine +import io.ktor.client.engine.mock.respond +import io.ktor.client.request.HttpRequestData +import io.ktor.http.HttpHeaders +import io.ktor.http.HttpStatusCode +import io.ktor.http.content.TextContent +import io.ktor.http.headersOf +import kotlinx.coroutines.test.runTest +import kotlinx.datetime.Clock +import kotlinx.serialization.json.Json +import kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.assertNotNull +import kotlin.test.assertNull + +class ContextWindowStrategyTest { + @Test + fun `test None strategy`() = runTest { + val mockServer = MockOllamaChatServer { request -> makeDummyResponse(request) } + + val ollamaClient = OllamaClient( + baseClient = HttpClient(mockServer.mockEngine), + contextWindowStrategy = ContextWindowStrategy.Companion.None, + ) + + ollamaClient.execute( + prompt = prompt("test-prompt") { }, + model = OllamaModels.Meta.LLAMA_3_2, + ) + + val requestHistory = mockServer.requestHistory + assertEquals(requestHistory.size, 1) + + val response = requestHistory.first() + assertNotNull(response.options) + assertNull(response.options.numCtx) + } + + @Test + fun `test Fixed strategy`() = runTest { + val mockServer = MockOllamaChatServer { request -> makeDummyResponse(request) } + + val ollamaClient = OllamaClient( + baseClient = HttpClient(mockServer.mockEngine), + contextWindowStrategy = ContextWindowStrategy.Companion.Fixed(42), + ) + + ollamaClient.execute( + prompt = prompt("test-prompt") { }, + model = OllamaModels.Meta.LLAMA_3_2, + ) + + val requestHistory = mockServer.requestHistory + assertEquals(requestHistory.size, 1) + + val response = requestHistory.first() + assertNotNull(response.options) + assertEquals(42, response.options.numCtx) + } + + @Test + fun `test FitPrompt strategy with tokenizer`() = runTest { + val mockServer = MockOllamaChatServer { request -> makeDummyResponse(request) } + + val ollamaClient = OllamaClient( + baseClient = HttpClient(mockServer.mockEngine), + contextWindowStrategy = ContextWindowStrategy.Companion.FitPrompt( + promptTokenizer = object : PromptTokenizer { + override fun tokenCountFor(message: Message): Int = error("Not needed") + override fun tokenCountFor(prompt: Prompt): Int = 3000 + }, + contextChunkSize = 1024, + minimumChunkCount = 2 + ), + ) + + ollamaClient.execute( + prompt = prompt("test-prompt") { }, + model = OllamaModels.Meta.LLAMA_3_2, + ) + + val requestHistory = mockServer.requestHistory + assertEquals(requestHistory.size, 1) + + val response = requestHistory.first() + assertNotNull(response.options) + assertEquals(3072, response.options.numCtx) + } + + @Test + fun `test FitPrompt strategy without tokenizer and no previous token usage`() = runTest { + val mockServer = MockOllamaChatServer { request -> makeDummyResponse(request) } + + val ollamaClient = OllamaClient( + baseClient = HttpClient(mockServer.mockEngine), + contextWindowStrategy = ContextWindowStrategy.Companion.FitPrompt( + promptTokenizer = null, + contextChunkSize = 1024, + minimumChunkCount = 2 + ), + ) + + ollamaClient.execute( + prompt = prompt("test-prompt") { }, + model = OllamaModels.Meta.LLAMA_3_2, + ) + + val requestHistory = mockServer.requestHistory + assertEquals(requestHistory.size, 1) + + val response = requestHistory.first() + assertNotNull(response.options) + assertEquals(2048, response.options.numCtx) + } + + @Test + fun `test FitPrompt strategy without tokenizer and existing token usage`() = runTest { + val mockServer = MockOllamaChatServer { request -> makeDummyResponse(request) } + + val ollamaClient = OllamaClient( + baseClient = HttpClient(mockServer.mockEngine), + contextWindowStrategy = ContextWindowStrategy.Companion.FitPrompt( + promptTokenizer = null, + contextChunkSize = 1024, + minimumChunkCount = 2 + ), + ) + + ollamaClient.execute( + prompt = prompt("test-prompt") { + message( + Message.Assistant( + "Dummy message", + metaInfo = ResponseMetaInfo( + timestamp = Clock.System.now(), + totalTokensCount = 5000, + ) + ) + ) + }, + model = OllamaModels.Meta.LLAMA_3_2, + ) + + val requestHistory = mockServer.requestHistory + assertEquals(requestHistory.size, 1) + + val response = requestHistory.first() + assertNotNull(response.options) + assertEquals(5120, response.options.numCtx) + } + + @Test + fun `test FitPrompt strategy with tokenizer and too long prompt`() = runTest { + val mockServer = MockOllamaChatServer { request -> makeDummyResponse(request) } + + val ollamaClient = OllamaClient( + baseClient = HttpClient(mockServer.mockEngine), + contextWindowStrategy = ContextWindowStrategy.Companion.FitPrompt( + promptTokenizer = object : PromptTokenizer { + override fun tokenCountFor(message: Message): Int = error("Not needed") + override fun tokenCountFor(prompt: Prompt): Int = 9000 + }, + contextChunkSize = 1024, + minimumChunkCount = 2 + ), + ) + + ollamaClient.execute( + prompt = prompt("test-prompt") { }, + model = OllamaModels.Meta.LLAMA_3_2.copy( + contextLength = 8192 + ), + ) + + val requestHistory = mockServer.requestHistory + assertEquals(requestHistory.size, 1) + + val response = requestHistory.first() + assertNotNull(response.options) + assertEquals(8192, response.options.numCtx) + } +} + +private fun makeDummyResponse( + request: OllamaChatRequestDTO, + content: String = "OK", + promptEvalCount: Int = 10, + evalCount: Int = 100, +): OllamaChatResponseDTO = OllamaChatResponseDTO( + model = request.model, + message = OllamaChatMessageDTO(role = "assistant", content = content), + done = true, + promptEvalCount = promptEvalCount, + evalCount = evalCount, +) + +private class MockOllamaChatServer( + private val handler: (OllamaChatRequestDTO) -> OllamaChatResponseDTO, +) { + val mockEngine = MockEngine { requestData -> + val request = requestData.extractChatRequest() + val response = handler(request) + respond( + content = Json.encodeToString(response), + status = HttpStatusCode.OK, + headers = headersOf(HttpHeaders.ContentType to listOf("application/json")), + ) + } + + val requestHistory: List + get() = mockEngine.requestHistory.map { it.extractChatRequest() } + + private fun HttpRequestData.extractChatRequest(): OllamaChatRequestDTO { + val requestContent = body as TextContent + val requestBody = requestContent.text + val request = Json.decodeFromString(requestBody) + return request + } +} diff --git a/prompt/prompt-tokenizer/build.gradle.kts b/prompt/prompt-tokenizer/build.gradle.kts index c55e56f0cf..b033af0990 100644 --- a/prompt/prompt-tokenizer/build.gradle.kts +++ b/prompt/prompt-tokenizer/build.gradle.kts @@ -13,6 +13,7 @@ kotlin { commonMain { dependencies { api(project(":prompt:prompt-llm")) + api(project(":prompt:prompt-model")) api(libs.kotlinx.serialization.json) api(libs.kotlinx.datetime) } diff --git a/prompt/prompt-tokenizer/src/commonMain/kotlin/ai/koog/prompt/tokenizer/PromptTokenizer.kt b/prompt/prompt-tokenizer/src/commonMain/kotlin/ai/koog/prompt/tokenizer/PromptTokenizer.kt new file mode 100644 index 0000000000..0002dc7f80 --- /dev/null +++ b/prompt/prompt-tokenizer/src/commonMain/kotlin/ai/koog/prompt/tokenizer/PromptTokenizer.kt @@ -0,0 +1,110 @@ +package ai.koog.prompt.tokenizer + +import ai.koog.prompt.dsl.Prompt +import ai.koog.prompt.message.Message +import kotlin.collections.sumOf + +/** + * An interface that provides utilities for tokenizing and calculating token usage in messages and prompts. + */ +public interface PromptTokenizer { + /** + * Calculates the number of tokens required for a given message. + * + * @param message The message for which the token count should be determined. + * @return The number of tokens required to encode the message. + */ + public fun tokenCountFor(message: Message): Int + + /** + * Calculates the total number of tokens spent in a given prompt. + * + * @param prompt The prompt for which the total tokens spent need to be calculated. + * @return The total number of tokens spent as an integer. + */ + public fun tokenCountFor(prompt: Prompt): Int +} + +/** + * An implementation of the [PromptTokenizer] interface that delegates token counting + * to an instance of the [Tokenizer] interface. The class provides methods to estimate + * the token count for individual messages and for the entirety of a prompt. + * + * This is useful in contexts where token-based costs or limitations are significant, + * such as when interacting with large language models (LLMs). + * + * @property tokenizer The [Tokenizer] instance used for token counting. + */ +public class OnDemandTokenizer(private val tokenizer: Tokenizer) : PromptTokenizer { + + /** + * Computes the number of tokens in a given message. + * + * @param message The message for which the token count needs to be calculated. + * The content of the message is analyzed to estimate the token count. + * @return The estimated number of tokens in the message content. + */ + public override fun tokenCountFor(message: Message): Int = tokenizer.countTokens(message.content) + + /** + * Calculates the total number of tokens spent for the given prompt based on its messages. + * + * @param prompt The `Prompt` instance containing the list of messages for which the total token count will be calculated. + * @return The total number of tokens across all messages in the prompt. + */ + public override fun tokenCountFor(prompt: Prompt): Int = prompt.messages.sumOf(::tokenCountFor) +} + +/** + * A caching implementation of the `PromptTokenizer` interface that optimizes token counting + * by storing previously computed token counts for messages. This reduces redundant computations + * when the same message is processed multiple times. + * + * @constructor Creates an instance of `CachingTokenizer` with a provided `Tokenizer` instance + * that performs the actual token counting. + * @property tokenizer The underlying `Tokenizer` used for counting tokens in the message content. + */ +public class CachingTokenizer(private val tokenizer: Tokenizer) : PromptTokenizer { + /** + * A cache that maps a `Message` to its corresponding token count. + * + * This is used to store the results of token computations for reuse, optimizing performance + * by avoiding repeated invocations of the token counting process on the same message content. + * + * Token counts are computed lazily and stored in the cache when requested via the `tokensFor` + * method. This cache can be cleared using the `clearCache` method. + */ + internal val cache = mutableMapOf() + + /** + * Retrieves the number of tokens contained in the content of the given message. + * This method utilizes caching to improve performance, storing previously + * computed token counts and reusing them for identical messages. + * + * @param message The message whose content's token count is to be retrieved + * @return The number of tokens in the content of the message + */ + public override fun tokenCountFor(message: Message): Int = cache.getOrPut(message) { + tokenizer.countTokens(message.content) + } + + /** + * Calculates the total number of tokens spent on the given prompt by summing the token usage + * of all messages associated with the prompt. + * + * @param prompt The prompt containing the list of messages whose token usage will be calculated. + * @return The total number of tokens spent across all messages in the provided prompt. + */ + public override fun tokenCountFor(prompt: Prompt): Int = prompt.messages.sumOf(::tokenCountFor) + + /** + * Clears all cached token counts from the internal cache. + * + * This method is useful when the state of the cached data becomes invalid + * or needs resetting. After calling this, any subsequent token count + * calculations will be recomputed rather than retrieved from the cache. + */ + public fun clearCache() { + cache.clear() + } +} diff --git a/prompt/prompt-tokenizer/src/commonTest/kotlin/ai/koog/prompt/tokenizer/PromptTokenizerTest.kt b/prompt/prompt-tokenizer/src/commonTest/kotlin/ai/koog/prompt/tokenizer/PromptTokenizerTest.kt new file mode 100644 index 0000000000..0f68ab6359 --- /dev/null +++ b/prompt/prompt-tokenizer/src/commonTest/kotlin/ai/koog/prompt/tokenizer/PromptTokenizerTest.kt @@ -0,0 +1,133 @@ +package ai.koog.prompt.tokenizer + +import ai.koog.prompt.dsl.prompt +import ai.koog.prompt.message.Message +import ai.koog.prompt.message.RequestMetaInfo +import ai.koog.prompt.message.ResponseMetaInfo +import kotlinx.datetime.Clock +import kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.assertTrue + +/** + * Test for the PromptTokenizer implementations. + */ +class PromptTokenizerTest { + + /** + * A mock tokenizer that tracks the total tokens counted. + * + * This implementation counts tokens by simply counting characters and dividing by 4, + * which is a very rough approximation but sufficient for testing purposes. + * It also keeps track of the total tokens counted across all calls. + */ + class MockTokenizer : Tokenizer { + private var _totalTokens = 0 + + /** + * The total number of tokens counted across all calls to countTokens. + */ + val totalTokens: Int + get() = _totalTokens + + /** + * Counts tokens by simply counting characters and dividing by 4. + * Also adds to the running total of tokens counted. + * + * @param text The text to tokenize + * @return The estimated number of tokens in the text + */ + override fun countTokens(text: String): Int { + // Simple approximation: 1 token ≈ 4 characters + println("countTokens: $text") + val tokens = (text.length / 4) + 1 + _totalTokens += tokens + return tokens + } + + /** + * Resets the total tokens counter to 0. + */ + fun reset() { + _totalTokens = 0 + } + } + + @Test + fun testPromptTokenizer() { + // Create a mock tokenizer to track token usage + val mockTokenizer = MockTokenizer() + + // Create a prompt tokenizer with our mock tokenizer + val promptTokenizer = OnDemandTokenizer(mockTokenizer) + + // Create a prompt with some messages + val testPrompt = prompt("test-prompt") { + system("You are a helpful assistant.") + user("What is the capital of France?") + assistant("Paris is the capital of France.") + } + + // Count tokens in the prompt + val totalTokens = promptTokenizer.tokenCountFor(testPrompt) + + // Verify that tokens were counted + assertTrue(totalTokens > 0, "Total tokens should be greater than 0") + + // Verify that the tokenizer was used and counted tokens + assertTrue(mockTokenizer.totalTokens > 0, "Tokenizer should have counted tokens") + + // Verify that the total tokens match what we expect + assertEquals(totalTokens, mockTokenizer.totalTokens, "Total tokens should match the tokenizer's count") + + // Print the total tokens spent + println("[DEBUG_LOG] Total tokens spent: ${mockTokenizer.totalTokens}") + + val requestMetainfo = RequestMetaInfo.create(Clock.System) + val responseMetainfo = ResponseMetaInfo.create(Clock.System) + // Count tokens for individual messages + val systemTokens = promptTokenizer.tokenCountFor( + Message.System("You are a helpful assistant.", requestMetainfo) + ) + val userTokens = promptTokenizer.tokenCountFor(Message.User("What is the capital of France?", requestMetainfo)) + val assistantTokens = promptTokenizer.tokenCountFor( + Message.Assistant("Paris is the capital of France.", responseMetainfo) + ) + + // Print token counts for each message + println("[DEBUG_LOG] System message tokens: $systemTokens") + println("[DEBUG_LOG] User message tokens: $userTokens") + println("[DEBUG_LOG] Assistant message tokens: $assistantTokens") + + // Verify that the sum of individual message tokens equals the total + val sumOfMessageTokens = systemTokens + userTokens + assistantTokens + assertEquals(sumOfMessageTokens, totalTokens, "Sum of message tokens should equal total tokens") + } + + @Test + fun testCachingPromptTokenizer() { + // Create a mock tokenizer to track token usage + val mockTokenizer = MockTokenizer() + + // Create a prompt tokenizer with our mock tokenizer + val promptTokenizer = CachingTokenizer(mockTokenizer) + + // Create a prompt with some messages + val testPrompt = prompt("test-prompt") { + system("You are a helpful assistant.") + user("What is the capital of France?") + assistant("Paris is the capital of France.") + } + + assertEquals(0, promptTokenizer.cache.size) + promptTokenizer.tokenCountFor(testPrompt) + assertEquals(3, promptTokenizer.cache.size) + promptTokenizer.clearCache() + assertEquals(0, promptTokenizer.cache.size) + promptTokenizer.tokenCountFor(testPrompt.messages[1]) + promptTokenizer.tokenCountFor(testPrompt.messages[2]) + assertEquals(2, promptTokenizer.cache.size) + promptTokenizer.tokenCountFor(testPrompt) + assertEquals(3, promptTokenizer.cache.size) + } +} From 5a05fed68f753549af28c768e02b4debeed7cfb3 Mon Sep 17 00:00:00 2001 From: Konstantin Pavlov <1517853+kpavlov@users.noreply.github.com> Date: Tue, 30 Sep 2025 22:24:42 +0300 Subject: [PATCH 02/52] Refactor spring boot starters (#886) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Refactor LLM client auto-configuration structure - Removed `KoogAutoConfiguration` and related legacy configuration files. - Introduced modular `LLM` provider-specific auto-configuration classes (e.g., `OpenAILLMAutoConfiguration`). - Updated test setup to align with the new modular configuration. - Enhanced property validation with `enabled` flag support for finer-grained control. - Introduced a `String.masked` extension to obfuscate sensitive string parts. Added test coverage for various input scenarios, including custom mask characters and edge cases. - Updated `KoogAutoConfigurationTest` to include scenarios where providers are explicitly disabled. - Refactored parameterized tests to use `textBlock = PROVIDERS` with `@CsvSource`. - Adjusted `ApplicationContextRunner` configuration logic for clean test scenarios. - Introduced `ignoreUnknownFields` to all `ConfigurationProperties` to improve flexibility with unmapped fields. - Expanded property documentation for all LLM clients (e.g., OpenAI, Google, Anthropic) to include detailed descriptions, usage examples, and retry behavior. - Improved autowiring safety with `@ConditionalOnBean` for LLM-specific executors. - Updated `spring-boot.md` to detail `enabled` flags and their default behavior for LLM providers. - Improved clarity by reformatting the supported provider list. - Added explanations for default base URLs and explicit activation for non-API-key providers (e.g., Ollama). A blog post about writing Spring Boot starters is [here](https://kpavlov.me/blog/spring-boot-starters) **NB! Tests need a split and some love, but this PR is already big enough. I will refactor them separately.** ## Motivation and Context - The Spring Boot configuration in general is not very Spring Boot idiomatic - The current Spring Boot starter does not expose LLM Clients. In order to get the LLM client in the application, developers have to use reflection to a get a field from the executor 🤯 - All configuration is encapsulated in one file, which makes extensibility difficult (SRP is violated). - Defaults are hardcoded in classes instead of being provided in the resources - Ollama is enabled by the presence of the base URL, which is interesting. - There is no explicit way to disable an LLM provider: no `enabled=true` property ## Breaking Changes - AutoConfiguration classes moved to their respective packages - Ollama provider needs `ai.koog.ollama.enabled=true` to be set explicitly in order to activate. --- #### Type of the changes - [ ] New feature (non-breaking change which adds functionality) - [ ] Bug fix (non-breaking change which fixes an issue) - [x] Breaking change (fix or feature that would cause existing functionality to change) - [x] Documentation update - [x] Tests improvement - [x] Refactoring #### Checklist - [x] The pull request has a description of the proposed change - [ ] I read the [Contributing Guidelines](https://github.com/JetBrains/koog/blob/main/CONTRIBUTING.md) before opening the pull request - [x] The pull request uses **`develop`** as the base branch - [ ] Tests for the changes have been added - [ ] All new and existing tests passed ##### Additional steps for pull requests adding a new feature - [ ] An issue describing the proposed change exists - [ ] The pull request includes a link to the issue - [ ] The change was discussed and approved in the issue - [x] Docs have been added / updated --- docs/docs/spring-boot.md | 30 ++- .../ai/koog/spring/AnthropicKoogProperties.kt | 31 --- .../ai/koog/spring/DeepSeekKoogProperties.kt | 31 --- .../ai/koog/spring/GoogleKoogProperties.kt | 31 --- .../kotlin/ai/koog/spring/KoogAutoConfig.kt | 167 -------------- .../ai/koog/spring/OllamaKoogProperties.kt | 29 --- .../ai/koog/spring/OpenAIKoogProperties.kt | 31 --- .../koog/spring/OpenRouterKoogProperties.kt | 31 --- .../koog/spring/RetryConfigKoogProperties.kt | 19 -- .../clients/KoogLlmClientProperties.kt | 19 ++ .../clients/RetryConfigKoogProperties.kt | 29 +++ .../anthropic/AnthropicKoogProperties.kt | 49 ++++ .../AnthropicLLMAutoConfiguration.kt | 76 +++++++ .../deepseek/DeepSeekKoogProperties.kt | 55 +++++ .../deepseek/DeepSeekLLMAutoConfiguration.kt | 75 +++++++ .../clients/google/GoogleKoogProperties.kt | 83 +++++++ .../google/GoogleLLMAutoConfiguration.kt | 75 +++++++ .../clients/ollama/OllamaKoogProperties.kt | 53 +++++ .../ollama/OllamaLLMAutoConfiguration.kt | 71 ++++++ .../clients/openai/OpenAIKoogProperties.kt | 51 +++++ .../openai/OpenAILLMAutoConfiguration.kt | 73 ++++++ .../openrouter/OpenRouterKoogProperties.kt | 56 +++++ .../OpenRouterLLMAutoConfiguration.kt | 67 ++++++ .../spring/prompt/executor/clients/utils.kt | 27 +++ .../config/koog/anthropic-llm.properties | 2 + .../config/koog/deepseek-llm.properties | 2 + .../config/koog/google-llm.properties | 2 + .../config/koog/ollama-llm.properties | 2 + .../config/koog/openai-llm.properties | 2 + .../config/koog/openrouter-llm.properties | 2 + ...ot.autoconfigure.AutoConfiguration.imports | 7 +- .../koog/spring/KoogAutoConfigurationTest.kt | 210 +++++++++++++----- .../test/resources/junit-platform.properties | 5 + .../ai/koog/utils/lang/StringExtensions.kt | 25 +++ .../koog/utils/lang/StringExtensionsTest.kt | 45 ++++ 35 files changed, 1134 insertions(+), 429 deletions(-) delete mode 100644 koog-spring-boot-starter/src/main/kotlin/ai/koog/spring/AnthropicKoogProperties.kt delete mode 100644 koog-spring-boot-starter/src/main/kotlin/ai/koog/spring/DeepSeekKoogProperties.kt delete mode 100644 koog-spring-boot-starter/src/main/kotlin/ai/koog/spring/GoogleKoogProperties.kt delete mode 100644 koog-spring-boot-starter/src/main/kotlin/ai/koog/spring/KoogAutoConfig.kt delete mode 100644 koog-spring-boot-starter/src/main/kotlin/ai/koog/spring/OllamaKoogProperties.kt delete mode 100644 koog-spring-boot-starter/src/main/kotlin/ai/koog/spring/OpenAIKoogProperties.kt delete mode 100644 koog-spring-boot-starter/src/main/kotlin/ai/koog/spring/OpenRouterKoogProperties.kt delete mode 100644 koog-spring-boot-starter/src/main/kotlin/ai/koog/spring/RetryConfigKoogProperties.kt create mode 100644 koog-spring-boot-starter/src/main/kotlin/ai/koog/spring/prompt/executor/clients/KoogLlmClientProperties.kt create mode 100644 koog-spring-boot-starter/src/main/kotlin/ai/koog/spring/prompt/executor/clients/RetryConfigKoogProperties.kt create mode 100644 koog-spring-boot-starter/src/main/kotlin/ai/koog/spring/prompt/executor/clients/anthropic/AnthropicKoogProperties.kt create mode 100644 koog-spring-boot-starter/src/main/kotlin/ai/koog/spring/prompt/executor/clients/anthropic/AnthropicLLMAutoConfiguration.kt create mode 100644 koog-spring-boot-starter/src/main/kotlin/ai/koog/spring/prompt/executor/clients/deepseek/DeepSeekKoogProperties.kt create mode 100644 koog-spring-boot-starter/src/main/kotlin/ai/koog/spring/prompt/executor/clients/deepseek/DeepSeekLLMAutoConfiguration.kt create mode 100644 koog-spring-boot-starter/src/main/kotlin/ai/koog/spring/prompt/executor/clients/google/GoogleKoogProperties.kt create mode 100644 koog-spring-boot-starter/src/main/kotlin/ai/koog/spring/prompt/executor/clients/google/GoogleLLMAutoConfiguration.kt create mode 100644 koog-spring-boot-starter/src/main/kotlin/ai/koog/spring/prompt/executor/clients/ollama/OllamaKoogProperties.kt create mode 100644 koog-spring-boot-starter/src/main/kotlin/ai/koog/spring/prompt/executor/clients/ollama/OllamaLLMAutoConfiguration.kt create mode 100644 koog-spring-boot-starter/src/main/kotlin/ai/koog/spring/prompt/executor/clients/openai/OpenAIKoogProperties.kt create mode 100644 koog-spring-boot-starter/src/main/kotlin/ai/koog/spring/prompt/executor/clients/openai/OpenAILLMAutoConfiguration.kt create mode 100644 koog-spring-boot-starter/src/main/kotlin/ai/koog/spring/prompt/executor/clients/openrouter/OpenRouterKoogProperties.kt create mode 100644 koog-spring-boot-starter/src/main/kotlin/ai/koog/spring/prompt/executor/clients/openrouter/OpenRouterLLMAutoConfiguration.kt create mode 100644 koog-spring-boot-starter/src/main/kotlin/ai/koog/spring/prompt/executor/clients/utils.kt create mode 100644 koog-spring-boot-starter/src/main/resources/META-INF/config/koog/anthropic-llm.properties create mode 100644 koog-spring-boot-starter/src/main/resources/META-INF/config/koog/deepseek-llm.properties create mode 100644 koog-spring-boot-starter/src/main/resources/META-INF/config/koog/google-llm.properties create mode 100644 koog-spring-boot-starter/src/main/resources/META-INF/config/koog/ollama-llm.properties create mode 100644 koog-spring-boot-starter/src/main/resources/META-INF/config/koog/openai-llm.properties create mode 100644 koog-spring-boot-starter/src/main/resources/META-INF/config/koog/openrouter-llm.properties create mode 100644 koog-spring-boot-starter/src/test/resources/junit-platform.properties create mode 100644 utils/src/commonMain/kotlin/ai/koog/utils/lang/StringExtensions.kt create mode 100644 utils/src/commonTest/kotlin/ai/koog/utils/lang/StringExtensionsTest.kt diff --git a/docs/docs/spring-boot.md b/docs/docs/spring-boot.md index 91644653bc..2490d88543 100644 --- a/docs/docs/spring-boot.md +++ b/docs/docs/spring-boot.md @@ -6,8 +6,13 @@ agents into your Spring Boot applications with minimal setup. ## Overview The `koog-spring-boot-starter` automatically configures LLM clients based on your application properties and provides -ready-to-use beans for dependency injection. It supports all major LLM providers including OpenAI, Anthropic, Google, -OpenRouter, DeepSeek, and Ollama. +ready-to-use beans for dependency injection. It supports all major LLM providers including: +- OpenAI +- Anthropic +- Google +- OpenRouter +- DeepSeek +- Ollama ## Getting Started @@ -27,21 +32,27 @@ Configure your preferred LLM providers in `application.properties`: ```properties # OpenAI Configuration +ai.koog.openai.enabled=true ai.koog.openai.api-key=${OPENAI_API_KEY} ai.koog.openai.base-url=https://api.openai.com # Anthropic Configuration +ai.koog.anthropic.enabled=true ai.koog.anthropic.api-key=${ANTHROPIC_API_KEY} ai.koog.anthropic.base-url=https://api.anthropic.com # Google Configuration +ai.koog.google.enabled=true ai.koog.google.api-key=${GOOGLE_API_KEY} ai.koog.google.base-url=https://generativelanguage.googleapis.com # OpenRouter Configuration +ai.koog.openrouter.enabled=true ai.koog.openrouter.api-key=${OPENROUTER_API_KEY} ai.koog.openrouter.base-url=https://openrouter.ai # DeepSeek Configuration +ai.koog.deepseek.enabled=true ai.koog.deepseek.api-key=${DEEPSEEK_API_KEY} ai.koog.deepseek.base-url=https://api.deepseek.com # Ollama Configuration (local - no API key required) +ai.koog.ollama.enabled=true ai.koog.ollama.base-url=http://localhost:11434 ``` @@ -51,24 +62,39 @@ Or using YAML format (`application.yml`): ai: koog: openai: + enabled: true api-key: ${OPENAI_API_KEY} base-url: https://api.openai.com anthropic: + enabled: true api-key: ${ANTHROPIC_API_KEY} base-url: https://api.anthropic.com google: + enabled: true api-key: ${GOOGLE_API_KEY} base-url: https://generativelanguage.googleapis.com openrouter: + enabled: true api-key: ${OPENROUTER_API_KEY} base-url: https://openrouter.ai deepseek: + enabled: true api-key: ${DEEPSEEK_API_KEY} base-url: https://api.deepseek.com ollama: + enabled: true # Set it to `true` explicitly to activate !!! base-url: http://localhost:11434 ``` +Both `ai.koog.PROVIDER.api-key` and `ai.koog.PROVIDER.enabled` properties are used to activate the provider. + +If the provider supports the API Key (like OpenAI, Anthropic, Google), then `ai.koog.PROVIDER.enabled` is set to `true` by default. + +If the provider does not support the API Key, like Ollama, `ai.koog.PROVIDER.enabled` is set to `false` by default, +and provider should be enabled explicitly in the application configuration. + +Provider's base urls are set to their default values in the Spring Boot starter, but you may override it in your application. + !!! tip "Environment Variables" It's recommended to use environment variables for API keys to keep them secure and out of version control. diff --git a/koog-spring-boot-starter/src/main/kotlin/ai/koog/spring/AnthropicKoogProperties.kt b/koog-spring-boot-starter/src/main/kotlin/ai/koog/spring/AnthropicKoogProperties.kt deleted file mode 100644 index 1ae9c00d25..0000000000 --- a/koog-spring-boot-starter/src/main/kotlin/ai/koog/spring/AnthropicKoogProperties.kt +++ /dev/null @@ -1,31 +0,0 @@ -package ai.koog.spring - -import org.springframework.boot.context.properties.ConfigurationProperties - -/** - * Configuration properties for the Koog library used for integrating with Anthropic LLM provider. - * These properties are used in conjunction with the [KoogAutoConfiguration] auto-configuration class to initialize and - * configure respective client implementations. - * - * Configuration prefix: `ai.koog.anthropic` - * - * @param apiKey The API key used to authenticate requests to the provider's service - * @param baseUrl The base URL of the provider's API endpoint. By default, it is set to `https://api.anthropic.com` - */ -@ConfigurationProperties(prefix = AnthropicKoogProperties.PREFIX) -public class AnthropicKoogProperties( - public val apiKey: String = "", - public val baseUrl: String = "https://api.anthropic.com", - public val retry: RetryConfigKoogProperties? = null -) { - /** - * Companion object for the AnthropicKoogProperties class, providing constant values and - * utilities associated with the configuration of Anthropic-related properties. - */ - public companion object Companion { - /** - * Prefix constant used for configuration Anthropic-related properties in the Koog framework. - */ - public const val PREFIX: String = "ai.koog.anthropic" - } -} diff --git a/koog-spring-boot-starter/src/main/kotlin/ai/koog/spring/DeepSeekKoogProperties.kt b/koog-spring-boot-starter/src/main/kotlin/ai/koog/spring/DeepSeekKoogProperties.kt deleted file mode 100644 index b39a1f39bd..0000000000 --- a/koog-spring-boot-starter/src/main/kotlin/ai/koog/spring/DeepSeekKoogProperties.kt +++ /dev/null @@ -1,31 +0,0 @@ -package ai.koog.spring - -import org.springframework.boot.context.properties.ConfigurationProperties - -/** - * Configuration properties for the Koog library used for integrating with DeepSeek LLM provider. - * These properties are used in conjunction with the [KoogAutoConfiguration] auto-configuration class to initialize and - * configure respective client implementations. - * - * Configuration prefix: `ai.koog.deepseek` - * - * @param apiKey The API key used to authenticate requests to the provider's service - * @param baseUrl The base URL of the provider's API endpoint. By default, it is set to `https://api.deepseek.com` - */ -@ConfigurationProperties(prefix = DeepSeekKoogProperties.PREFIX) -public class DeepSeekKoogProperties( - public val apiKey: String = "", - public val baseUrl: String = "https://api.deepseek.com", - public val retry: RetryConfigKoogProperties? = null -) { - /** - * Companion object for the DeepSeekKoogProperties class, providing constant values and - * utilities associated with the configuration of DeepSeek-related properties. - */ - public companion object Companion { - /** - * Prefix constant used for configuration DeepSeek-related properties in the Koog framework. - */ - public const val PREFIX: String = "ai.koog.deepseek" - } -} diff --git a/koog-spring-boot-starter/src/main/kotlin/ai/koog/spring/GoogleKoogProperties.kt b/koog-spring-boot-starter/src/main/kotlin/ai/koog/spring/GoogleKoogProperties.kt deleted file mode 100644 index 12805c67f6..0000000000 --- a/koog-spring-boot-starter/src/main/kotlin/ai/koog/spring/GoogleKoogProperties.kt +++ /dev/null @@ -1,31 +0,0 @@ -package ai.koog.spring - -import org.springframework.boot.context.properties.ConfigurationProperties - -/** - * Configuration properties for the Koog library used for integrating with Google LLM provider. - * These properties are used in conjunction with the [KoogAutoConfiguration] auto-configuration class to initialize and - * configure respective client implementations. - * - * Configuration prefix: `ai.koog.google` - * - * @param apiKey The API key used to authenticate requests to the provider's service - * @param baseUrl The base URL of the provider's API endpoint. By default, it is set to `https://generativelanguage.googleapis.com` - */ -@ConfigurationProperties(prefix = GoogleKoogProperties.PREFIX) -public class GoogleKoogProperties( - public val apiKey: String = "", - public val baseUrl: String = "https://generativelanguage.googleapis.com", - public val retry: RetryConfigKoogProperties? = null -) { - /** - * Companion object for the GoogleKoogProperties class, providing constant values and - * utilities associated with the configuration of Google-related properties. - */ - public companion object Companion { - /** - * Prefix constant used for configuration Google-related properties in the Koog framework. - */ - public const val PREFIX: String = "ai.koog.google" - } -} diff --git a/koog-spring-boot-starter/src/main/kotlin/ai/koog/spring/KoogAutoConfig.kt b/koog-spring-boot-starter/src/main/kotlin/ai/koog/spring/KoogAutoConfig.kt deleted file mode 100644 index dbae6b8bd1..0000000000 --- a/koog-spring-boot-starter/src/main/kotlin/ai/koog/spring/KoogAutoConfig.kt +++ /dev/null @@ -1,167 +0,0 @@ -package ai.koog.spring - -import ai.koog.prompt.executor.clients.LLMClient -import ai.koog.prompt.executor.clients.anthropic.AnthropicClientSettings -import ai.koog.prompt.executor.clients.anthropic.AnthropicLLMClient -import ai.koog.prompt.executor.clients.deepseek.DeepSeekClientSettings -import ai.koog.prompt.executor.clients.deepseek.DeepSeekLLMClient -import ai.koog.prompt.executor.clients.google.GoogleClientSettings -import ai.koog.prompt.executor.clients.google.GoogleLLMClient -import ai.koog.prompt.executor.clients.openai.OpenAIClientSettings -import ai.koog.prompt.executor.clients.openai.OpenAILLMClient -import ai.koog.prompt.executor.clients.openrouter.OpenRouterClientSettings -import ai.koog.prompt.executor.clients.openrouter.OpenRouterLLMClient -import ai.koog.prompt.executor.clients.retry.RetryConfig -import ai.koog.prompt.executor.clients.retry.RetryingLLMClient -import ai.koog.prompt.executor.llms.SingleLLMPromptExecutor -import ai.koog.prompt.executor.ollama.client.OllamaClient -import org.springframework.boot.autoconfigure.AutoConfiguration -import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty -import org.springframework.boot.context.properties.EnableConfigurationProperties -import org.springframework.context.annotation.Bean -import kotlin.time.toKotlinDuration - -/** - * [KoogAutoConfiguration] is a Spring Boot auto-configuration class that configures and provides beans - * for various LLM (Large Language Model) provider clients. It ensures that the beans are only - * created if the corresponding properties are defined in the application's configuration. - * - * This configuration includes support for Anthropic, Google, Ollama, OpenAI, DeepSeek, and OpenRouter providers. - * Each provider is configured with specific settings and logic encapsulated within a - * [SingleLLMPromptExecutor] instance backed by a respective client implementation. - */ -@AutoConfiguration -@EnableConfigurationProperties( - OpenAIKoogProperties::class, - AnthropicKoogProperties::class, - GoogleKoogProperties::class, - OllamaKoogProperties::class, - DeepSeekKoogProperties::class, - OpenRouterKoogProperties::class -) -public class KoogAutoConfiguration { - - /** - * Creates and configures a [SingleLLMPromptExecutor] using an [AnthropicLLMClient]. - * This is conditioned on the presence of an API key in the application properties. - * - * @param properties The configuration properties containing settings for the Anthropic client. - * @return An instance of [SingleLLMPromptExecutor] configured with [AnthropicLLMClient]. - */ - @Bean - @ConditionalOnProperty(prefix = AnthropicKoogProperties.PREFIX, name = ["api-key"]) - public fun anthropicExecutor(properties: AnthropicKoogProperties): SingleLLMPromptExecutor { - val client = AnthropicLLMClient( - apiKey = properties.apiKey, - settings = AnthropicClientSettings(baseUrl = properties.baseUrl) - ) - return SingleLLMPromptExecutor(getRetryingClientOrDefault(client, properties.retry)) - } - - /** - * Provides a [SingleLLMPromptExecutor] bean configured with a [GoogleLLMClient] using the settings - * from the given `KoogProperties`. The bean is only created if the `google.api-key` property is set. - * - * @param properties The configuration properties containing the `googleClientProperties` needed to create the client. - * @return A [SingleLLMPromptExecutor] instance configured with a [GoogleLLMClient]. - */ - @Bean - @ConditionalOnProperty(prefix = GoogleKoogProperties.PREFIX, name = ["api-key"]) - public fun googleExecutor(properties: GoogleKoogProperties): SingleLLMPromptExecutor { - val client = GoogleLLMClient( - apiKey = properties.apiKey, - settings = GoogleClientSettings(baseUrl = properties.baseUrl) - ) - return SingleLLMPromptExecutor(getRetryingClientOrDefault(client, properties.retry)) - } - - /** - * Creates and configures a [SingleLLMPromptExecutor] instance using Ollama properties. - * - * The method initializes an [OllamaClient] with the base URL derived from the provided [OllamaKoogProperties] - * and uses it to construct the [SingleLLMPromptExecutor]. - * - * @param properties the configuration properties containing Ollama client settings such as the base URL. - * @return a [SingleLLMPromptExecutor] configured to use the Ollama client. - */ - @Bean - @ConditionalOnProperty(prefix = OllamaKoogProperties.PREFIX, name = ["base-url"]) - public fun ollamaExecutor(properties: OllamaKoogProperties): SingleLLMPromptExecutor { - val client = OllamaClient(baseUrl = properties.baseUrl) - return SingleLLMPromptExecutor(getRetryingClientOrDefault(client, properties.retry)) - } - - /** - * Provides a bean of type [SingleLLMPromptExecutor] configured for OpenAI interaction. - * The bean will only be instantiated if the property `ai.koog.openai.api-key` is defined in the application properties. - * - * @param properties The configuration properties containing OpenAI-specific client settings such as API key and base URL. - * @return An instance of [SingleLLMPromptExecutor] initialized with the OpenAI client. - */ - @Bean - @ConditionalOnProperty(prefix = OpenAIKoogProperties.PREFIX, name = ["api-key"]) - public fun openAIExecutor(properties: OpenAIKoogProperties): SingleLLMPromptExecutor { - val client = OpenAILLMClient( - apiKey = properties.apiKey, - settings = OpenAIClientSettings(baseUrl = properties.baseUrl) - ) - return SingleLLMPromptExecutor(getRetryingClientOrDefault(client, properties.retry)) - } - - /** - * Creates a [SingleLLMPromptExecutor] bean configured to use the OpenRouter LLM client. - * - * This method is only executed if the `openrouter.api-key` property is defined in the application's configuration. - * It initializes the OpenRouter client using the provided API key and base URL from the application's properties. - * - * @param properties The configuration properties for the application, including the OpenRouter client settings. - * @return A [SingleLLMPromptExecutor] initialized with an OpenRouter LLM client. - */ - @Bean - @ConditionalOnProperty(prefix = OpenRouterKoogProperties.PREFIX, name = ["api-key"]) - public fun openRouterExecutor(properties: OpenRouterKoogProperties): SingleLLMPromptExecutor { - val client = OpenRouterLLMClient( - apiKey = properties.apiKey, - settings = OpenRouterClientSettings(baseUrl = properties.baseUrl) - ) - return SingleLLMPromptExecutor(getRetryingClientOrDefault(client, properties.retry)) - } - - /** - * Creates a [SingleLLMPromptExecutor] bean configured to use the DeepSeek LLM client. - * - * This method is only executed if the `deepseek.api-key` property is defined in the application's configuration. - * It initializes the DeepSeek client using the provided API key and base URL from the application's properties. - * - * @param properties The configuration properties for the application, including the DeepSeek client settings. - * @return A [SingleLLMPromptExecutor] initialized with an DeepSeek LLM client. - */ - @Bean - @ConditionalOnProperty(prefix = DeepSeekKoogProperties.PREFIX, name = ["api-key"]) - public fun deepSeekExecutor(properties: DeepSeekKoogProperties): SingleLLMPromptExecutor { - val client = DeepSeekLLMClient( - apiKey = properties.apiKey, - settings = DeepSeekClientSettings(baseUrl = properties.baseUrl) - ) - return SingleLLMPromptExecutor(getRetryingClientOrDefault(client, properties.retry)) - } - - private fun getRetryingClientOrDefault(client: LLMClient, properties: RetryConfigKoogProperties?): LLMClient { - return if (properties?.enabled == true) { - val defaultConfig = RetryConfig() - val retryConfig = RetryConfig( - maxAttempts = properties.maxAttempts ?: defaultConfig.maxAttempts, - initialDelay = properties.initialDelay?.toKotlinDuration() ?: defaultConfig.initialDelay, - maxDelay = properties.maxDelay?.toKotlinDuration() ?: defaultConfig.maxDelay, - backoffMultiplier = properties.backoffMultiplier ?: defaultConfig.backoffMultiplier, - jitterFactor = properties.jitterFactor ?: defaultConfig.jitterFactor - ) - RetryingLLMClient( - delegate = client, - config = retryConfig - ) - } else { - client - } - } -} diff --git a/koog-spring-boot-starter/src/main/kotlin/ai/koog/spring/OllamaKoogProperties.kt b/koog-spring-boot-starter/src/main/kotlin/ai/koog/spring/OllamaKoogProperties.kt deleted file mode 100644 index 18d23dfe59..0000000000 --- a/koog-spring-boot-starter/src/main/kotlin/ai/koog/spring/OllamaKoogProperties.kt +++ /dev/null @@ -1,29 +0,0 @@ -package ai.koog.spring - -import org.springframework.boot.context.properties.ConfigurationProperties - -/** - * Configuration properties for the Koog library used for integrating with Ollama LLM provider. - * These properties are used in conjunction with the [KoogAutoConfiguration] auto-configuration class to initialize and - * configure respective client implementations. - * - * Configuration prefix: `ai.koog.ollama` - * - * @param baseUrl The base URL of the provider's API endpoint. By default, it is set to `http://localhost:11434` - */ -@ConfigurationProperties(prefix = OllamaKoogProperties.PREFIX) -public class OllamaKoogProperties( - public val baseUrl: String = "http://localhost:11434", - public val retry: RetryConfigKoogProperties? = null -) { - /** - * Companion object for the OllamaKoogProperties class, providing constant values and - * utilities associated with the configuration of Ollama-related properties. - */ - public companion object Companion { - /** - * Prefix constant used for configuration Ollama-related properties in the Koog framework. - */ - public const val PREFIX: String = "ai.koog.ollama" - } -} diff --git a/koog-spring-boot-starter/src/main/kotlin/ai/koog/spring/OpenAIKoogProperties.kt b/koog-spring-boot-starter/src/main/kotlin/ai/koog/spring/OpenAIKoogProperties.kt deleted file mode 100644 index 8b5196ad6a..0000000000 --- a/koog-spring-boot-starter/src/main/kotlin/ai/koog/spring/OpenAIKoogProperties.kt +++ /dev/null @@ -1,31 +0,0 @@ -package ai.koog.spring - -import org.springframework.boot.context.properties.ConfigurationProperties - -/** - * Configuration properties for the Koog library used for integrating with OpenAI LLM provider. - * These properties are used in conjunction with the [KoogAutoConfiguration] auto-configuration class to initialize and - * configure respective client implementations. - * - * Configuration prefix: `ai.koog.openai` - * - * @param apiKey The API key used to authenticate requests to the provider's service - * @param baseUrl The base URL of the provider's API endpoint. By default, it is set to `https://api.openai.com` - */ -@ConfigurationProperties(prefix = OpenAIKoogProperties.PREFIX) -public class OpenAIKoogProperties( - public val apiKey: String = "", - public val baseUrl: String = "https://api.openai.com", - public val retry: RetryConfigKoogProperties? = null -) { - /** - * Companion object for the OpenAIKoogProperties class, providing constant values and - * utilities associated with the configuration of OpenAI-related properties. - */ - public companion object { - /** - * Prefix constant used for configuration OpenAI-related properties in the Koog framework. - */ - public const val PREFIX: String = "ai.koog.openai" - } -} diff --git a/koog-spring-boot-starter/src/main/kotlin/ai/koog/spring/OpenRouterKoogProperties.kt b/koog-spring-boot-starter/src/main/kotlin/ai/koog/spring/OpenRouterKoogProperties.kt deleted file mode 100644 index b2b5640ff8..0000000000 --- a/koog-spring-boot-starter/src/main/kotlin/ai/koog/spring/OpenRouterKoogProperties.kt +++ /dev/null @@ -1,31 +0,0 @@ -package ai.koog.spring - -import org.springframework.boot.context.properties.ConfigurationProperties - -/** - * Configuration properties for the Koog library used for integrating with OpenRouter LLM provider. - * These properties are used in conjunction with the [KoogAutoConfiguration] auto-configuration class to initialize and - * configure respective client implementations. - * - * Configuration prefix: `ai.koog.openrouter` - * - * @param apiKey The API key used to authenticate requests to the provider's service - * @param baseUrl The base URL of the provider's API endpoint. By default, it is set to `https://openrouter.ai` - */ -@ConfigurationProperties(prefix = OpenRouterKoogProperties.PREFIX) -public class OpenRouterKoogProperties( - public val apiKey: String = "", - public val baseUrl: String = "https://openrouter.ai", - public val retry: RetryConfigKoogProperties? = null -) { - /** - * Companion object for the OpenRouterKoogProperties class, providing constant values and - * utilities associated with the configuration of OpenRouter-related properties. - */ - public companion object Companion { - /** - * Prefix constant used for configuration OpenRouter-related properties in the Koog framework. - */ - public const val PREFIX: String = "ai.koog.openrouter" - } -} diff --git a/koog-spring-boot-starter/src/main/kotlin/ai/koog/spring/RetryConfigKoogProperties.kt b/koog-spring-boot-starter/src/main/kotlin/ai/koog/spring/RetryConfigKoogProperties.kt deleted file mode 100644 index abaac3b678..0000000000 --- a/koog-spring-boot-starter/src/main/kotlin/ai/koog/spring/RetryConfigKoogProperties.kt +++ /dev/null @@ -1,19 +0,0 @@ -package ai.koog.spring - -import org.springframework.boot.convert.DurationUnit -import java.time.Duration -import java.time.temporal.ChronoUnit - -/** - * Configuration properties for the Koog library used for LLM clients retry configuration. - */ -public class RetryConfigKoogProperties( - public val enabled: Boolean = false, - public val maxAttempts: Int? = null, - @param:DurationUnit(ChronoUnit.SECONDS) - public val initialDelay: Duration? = null, - @param:DurationUnit(ChronoUnit.SECONDS) - public val maxDelay: Duration? = null, - public val backoffMultiplier: Double? = null, - public val jitterFactor: Double? = null -) diff --git a/koog-spring-boot-starter/src/main/kotlin/ai/koog/spring/prompt/executor/clients/KoogLlmClientProperties.kt b/koog-spring-boot-starter/src/main/kotlin/ai/koog/spring/prompt/executor/clients/KoogLlmClientProperties.kt new file mode 100644 index 0000000000..a48069c4b4 --- /dev/null +++ b/koog-spring-boot-starter/src/main/kotlin/ai/koog/spring/prompt/executor/clients/KoogLlmClientProperties.kt @@ -0,0 +1,19 @@ +package ai.koog.spring.prompt.executor.clients + +import ai.koog.spring.RetryConfigKoogProperties + +/** + * Interface representing configuration properties for a Koog LLM Client. + * + * This interface is intended to provide the necessary configuration required to set up a LLM Client. + * It includes options for enabling the client, specifying the base URL, and defining retry configurations. + * + * @param enabled Indicates whether the LLM client is enabled. + * @param baseUrl Specifies the base URL for the LLM client. + * @param retry Defines the retry configuration for the LLM client using [RetryConfigKoogProperties]. + */ +public interface KoogLlmClientProperties { + public val enabled: Boolean + public val baseUrl: String + public val retry: RetryConfigKoogProperties? +} diff --git a/koog-spring-boot-starter/src/main/kotlin/ai/koog/spring/prompt/executor/clients/RetryConfigKoogProperties.kt b/koog-spring-boot-starter/src/main/kotlin/ai/koog/spring/prompt/executor/clients/RetryConfigKoogProperties.kt new file mode 100644 index 0000000000..09ff05dcca --- /dev/null +++ b/koog-spring-boot-starter/src/main/kotlin/ai/koog/spring/prompt/executor/clients/RetryConfigKoogProperties.kt @@ -0,0 +1,29 @@ +package ai.koog.spring + +import org.springframework.boot.convert.DurationUnit +import java.time.Duration +import java.time.temporal.ChronoUnit + +/** + * Represents configuration properties for retry mechanisms associated with clients. + * + * This class is used to define retry behavior, including specifications such as the number of retry + * attempts, delay between attempts, and mechanisms to control backoff strategy and randomness in delays. + * + * @property enabled Indicates whether retries are enabled. + * @property maxAttempts Specifies the maximum number of retry attempts. + * @property initialDelay Specifies the initial delay before the first retry attempt, in seconds. + * @property maxDelay Specifies the maximum delay allowed between retry attempts, in seconds. + * @property backoffMultiplier Defines the multiplier to apply to the delay for each subsequent retry attempt. + * @property jitterFactor Specifies the factor to introduce randomness in the delay calculations to avoid symmetric loads. + */ +public class RetryConfigKoogProperties( + public val enabled: Boolean = false, + public val maxAttempts: Int? = null, + @param:DurationUnit(ChronoUnit.SECONDS) + public val initialDelay: Duration? = null, + @param:DurationUnit(ChronoUnit.SECONDS) + public val maxDelay: Duration? = null, + public val backoffMultiplier: Double? = null, + public val jitterFactor: Double? = null +) diff --git a/koog-spring-boot-starter/src/main/kotlin/ai/koog/spring/prompt/executor/clients/anthropic/AnthropicKoogProperties.kt b/koog-spring-boot-starter/src/main/kotlin/ai/koog/spring/prompt/executor/clients/anthropic/AnthropicKoogProperties.kt new file mode 100644 index 0000000000..07ad9d507f --- /dev/null +++ b/koog-spring-boot-starter/src/main/kotlin/ai/koog/spring/prompt/executor/clients/anthropic/AnthropicKoogProperties.kt @@ -0,0 +1,49 @@ +package ai.koog.spring.prompt.executor.clients.anthropic + +import ai.koog.spring.RetryConfigKoogProperties +import ai.koog.spring.prompt.executor.clients.KoogLlmClientProperties +import org.springframework.boot.context.properties.ConfigurationProperties + +/** + * Configuration properties for configuring Anthropic-related clients in the Koog framework. + * + * This class allows defining settings necessary for integrating with the Anthropic LLM (Large Language Model) + * client. It implements [KoogLlmClientProperties] and includes common LLM client configurations such as `enabled`, + * `baseUrl`, and retry options. Additionally, it includes the `apiKey` property specific to the Anthropic client. + * + * The properties are bound to the configuration prefix defined by [AnthropicKoogProperties.PREFIX], which is + * `ai.koog.anthropic`. This allows configuring the client via property files in a Spring Boot application. + * + * @property enabled Indicates whether the Anthropic client is enabled. If `false`, the client will not be configured. + * @property apiKey The API key used to authenticate requests to the Anthropic API. + * @property baseUrl The base URL of the Anthropic API for sending requests. + * @property retry Retry configuration for the client in case of failed or timeout requests. This is optional. + */ +@ConfigurationProperties(prefix = AnthropicKoogProperties.PREFIX, ignoreUnknownFields = true) +public class AnthropicKoogProperties( + public override val enabled: Boolean, + public val apiKey: String, + public override val baseUrl: String, + public override val retry: RetryConfigKoogProperties? = null +) : KoogLlmClientProperties { + /** + * Companion object providing constant value associated with the configuration of Anthropic-related properties. + */ + public companion object Companion { + /** + * Prefix constant used for configuration Anthropic-related properties in the Koog framework. + */ + public const val PREFIX: String = "ai.koog.anthropic" + } + + /** + * Returns a string representation of the `AnthropicKoogProperties` object. + * + * The representation includes the state of its properties: `enabled`, `apiKey`(masked), `baseUrl`, and `retry`. + * + * @return A string containing the property values of the `AnthropicKoogProperties` object. + */ + override fun toString(): String { + return "AnthropicKoogProperties(enabled=$enabled, apiKey='$apiKey', baseUrl='$baseUrl', retry=$retry)" + } +} diff --git a/koog-spring-boot-starter/src/main/kotlin/ai/koog/spring/prompt/executor/clients/anthropic/AnthropicLLMAutoConfiguration.kt b/koog-spring-boot-starter/src/main/kotlin/ai/koog/spring/prompt/executor/clients/anthropic/AnthropicLLMAutoConfiguration.kt new file mode 100644 index 0000000000..baef3d28a9 --- /dev/null +++ b/koog-spring-boot-starter/src/main/kotlin/ai/koog/spring/prompt/executor/clients/anthropic/AnthropicLLMAutoConfiguration.kt @@ -0,0 +1,76 @@ +package ai.koog.spring.prompt.executor.clients.anthropic + +import ai.koog.prompt.executor.clients.anthropic.AnthropicClientSettings +import ai.koog.prompt.executor.clients.anthropic.AnthropicLLMClient +import ai.koog.prompt.executor.llms.SingleLLMPromptExecutor +import ai.koog.spring.prompt.executor.clients.toRetryingClient +import org.slf4j.LoggerFactory +import org.springframework.boot.autoconfigure.AutoConfiguration +import org.springframework.boot.autoconfigure.condition.ConditionalOnBean +import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty +import org.springframework.boot.context.properties.EnableConfigurationProperties +import org.springframework.context.annotation.Bean +import org.springframework.context.annotation.PropertySource + +/** + * Auto-configuration class for Anthropic LLM integration in a Spring Boot application. + * + * This class automatically configures the required beans for interacting with the Anthropic LLM + * when the appropriate configuration properties are set in the application. It specifically checks + * for the presence of the `ai.koog.anthropic.enabled` and `ai.koog.anthropic.api-key` properties. + * + * Beans provided by this configuration: + * - [AnthropicLLMClient]: Configured client for interacting with the Anthropic API. + * - [SingleLLMPromptExecutor]: Prompt executor that utilizes the configured Anthropic client. + * + * To enable this configuration, the `ai.koog.anthropic.enabled` property must be set to `true` and a valid `api-key` + * must be provided in the application's property files. + * + * This configuration reads additional properties from the `classpath:META-INF/config/koog/anthropic-llm.properties` + * and binds them to the [AnthropicKoogProperties]. + * + * @property properties Anthropic-specific configuration properties, automatically injected by Spring's + * configuration properties mechanism. + */ +@AutoConfiguration +@PropertySource("classpath:/META-INF/config/koog/anthropic-llm.properties") +@EnableConfigurationProperties( + AnthropicKoogProperties::class, +) +@ConditionalOnProperty(prefix = AnthropicKoogProperties.PREFIX, name = ["enabled"], havingValue = "true") +@ConditionalOnProperty(prefix = AnthropicKoogProperties.PREFIX, name = ["api-key"]) +public class AnthropicLLMAutoConfiguration( + private val properties: AnthropicKoogProperties +) { + + private val logger = LoggerFactory.getLogger(AnthropicLLMAutoConfiguration::class.java) + + /** + * Creates and initializes an instance of [AnthropicLLMClient] with the specified API key and settings from the + * application properties. The client is configured to interact with the Anthropic LLM API using the provided + * base URL and credentials. + * + * @return An instance of [AnthropicLLMClient] configured for communication with the Anthropic API. + */ + @Bean + public fun anthropicLLMClient(): AnthropicLLMClient { + logger.info("Initializing AnthropicLLMClient with: $properties") + return AnthropicLLMClient( + apiKey = properties.apiKey, + settings = AnthropicClientSettings(baseUrl = properties.baseUrl) + ) + } + + /** + * Creates and initializes a [SingleLLMPromptExecutor] instance using an [AnthropicLLMClient]. + * The executor is configured with a retrying client derived from the provided AnthropicLLMClient. + * + * @param client An instance of [AnthropicLLMClient] used to communicate with the Anthropic LLM API. + * @return An instance of [SingleLLMPromptExecutor] for sending prompts to the Anthropic LLM API. + */ + @Bean + @ConditionalOnBean(AnthropicLLMClient::class) + public fun anthropicExecutor(client: AnthropicLLMClient): SingleLLMPromptExecutor { + return SingleLLMPromptExecutor(client.toRetryingClient(properties.retry)) + } +} diff --git a/koog-spring-boot-starter/src/main/kotlin/ai/koog/spring/prompt/executor/clients/deepseek/DeepSeekKoogProperties.kt b/koog-spring-boot-starter/src/main/kotlin/ai/koog/spring/prompt/executor/clients/deepseek/DeepSeekKoogProperties.kt new file mode 100644 index 0000000000..0e9060bdc4 --- /dev/null +++ b/koog-spring-boot-starter/src/main/kotlin/ai/koog/spring/prompt/executor/clients/deepseek/DeepSeekKoogProperties.kt @@ -0,0 +1,55 @@ +package ai.koog.spring.prompt.executor.clients.deepseek + +import ai.koog.spring.RetryConfigKoogProperties +import ai.koog.spring.prompt.executor.clients.KoogLlmClientProperties +import ai.koog.spring.prompt.executor.clients.deepseek.DeepSeekKoogProperties.Companion.PREFIX +import ai.koog.utils.lang.masked +import org.springframework.boot.context.properties.ConfigurationProperties + +/** + * Configuration properties class for DeepSeek LLM provider integration within the Koog framework. + * + * This class is used to define and manage application-level configuration parameters for connecting + * to the DeepSeek provider. It includes properties such as API key, base URL, and optional retry settings. + * + * The properties are auto-configured via Spring Boot's `@ConfigurationProperties` using the `ai.koog.deepseek` prefix. + * + * Implements the [KoogLlmClientProperties] interface, which provides base attributes for all LLM client property configurations. + * + * Properties from this class are typically consumed by auto-configuration classes, such as [DeepSeekLLMAutoConfiguration], + * to initialize and configure the necessary beans for working with the DeepSeek API. + * + * @property enabled Indicates whether DeepSeek API integration is enabled (true or false). + * @property apiKey An API key string required to authenticate requests to the DeepSeek external service. + * @property baseUrl The base URL endpoint for DeepSeek API calls. + * @property retry Optional retry configuration for API requests, represented by [RetryConfigKoogProperties]. + */ +@ConfigurationProperties(prefix = PREFIX, ignoreUnknownFields = true) +public class DeepSeekKoogProperties( + public override val enabled: Boolean, + public val apiKey: String, + public override val baseUrl: String, + public override val retry: RetryConfigKoogProperties? = null +) : KoogLlmClientProperties { + /** + * Companion object for the DeepSeekKoogProperties class, providing constant values and + * utilities associated with the configuration of DeepSeek-related properties. + */ + public companion object Companion { + /** + * Prefix constant used for configuration DeepSeek-related properties in the Koog framework. + */ + public const val PREFIX: String = "ai.koog.deepseek" + } + + /** + * Returns a string representation of the DeepSeekKoogProperties object. + * The string includes information about the `enabled` status, a masked representation of the `apiKey`, + * the `baseUrl`, and the `retry` configuration. + * + * @return A string summarizing the current configuration properties. + */ + override fun toString(): String { + return "DeepSeekKoogProperties(enabled=$enabled, apiKey='$${apiKey.masked()}', baseUrl='$baseUrl', retry=$retry)" + } +} diff --git a/koog-spring-boot-starter/src/main/kotlin/ai/koog/spring/prompt/executor/clients/deepseek/DeepSeekLLMAutoConfiguration.kt b/koog-spring-boot-starter/src/main/kotlin/ai/koog/spring/prompt/executor/clients/deepseek/DeepSeekLLMAutoConfiguration.kt new file mode 100644 index 0000000000..60f40ee54b --- /dev/null +++ b/koog-spring-boot-starter/src/main/kotlin/ai/koog/spring/prompt/executor/clients/deepseek/DeepSeekLLMAutoConfiguration.kt @@ -0,0 +1,75 @@ +package ai.koog.spring.prompt.executor.clients.deepseek + +import ai.koog.prompt.executor.clients.deepseek.DeepSeekClientSettings +import ai.koog.prompt.executor.clients.deepseek.DeepSeekLLMClient +import ai.koog.prompt.executor.llms.SingleLLMPromptExecutor +import ai.koog.spring.prompt.executor.clients.toRetryingClient +import org.springframework.boot.autoconfigure.AutoConfiguration +import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty +import org.springframework.boot.context.properties.EnableConfigurationProperties +import org.springframework.context.annotation.Bean +import org.springframework.context.annotation.PropertySource + +/** + * Auto-configuration class for integrating with the DeepSeek LLM provider within a Spring Boot application. + * + * This configuration enables the auto-wiring of required beans when the appropriate application + * properties are set. The configuration ensures that the [DeepSeekLLMClient] is properly initialized + * and available for usage in the application. + * + * The following conditions must be met for this configuration to be activated: + * - The property `ai.koog.deepseek.api-key` must be defined in the application configuration. + * - The property `ai.koog.deepseek.enabled` must have a value of `true`. + * + * Properties used: + * - `ai.koog.deepseek.api-key`: API key required to authenticate requests to the DeepSeek API. + * - `ai.koog.deepseek.base-url`: Base URL of the DeepSeek API, with a default value of `https://api.deepseek.com`. + * - `ai.koog.deepseek.retry`: Retry configuration settings for failed requests. + * + * Beans provided: + * - [DeepSeekLLMClient]: A client for interacting with the DeepSeek API. + * - [SingleLLMPromptExecutor]: A bean for executing single-step LLM prompts using the DeepSeek client. + * + * @property properties [DeepSeekKoogProperties] containing the configuration properties for the DeepSeek client. + * + * @see DeepSeekKoogProperties + * @see DeepSeekLLMClient + * @see SingleLLMPromptExecutor + */ +@AutoConfiguration +@PropertySource("classpath:/META-INF/config/koog/deepseek-llm.properties") +@EnableConfigurationProperties( + DeepSeekKoogProperties::class, +) +@ConditionalOnProperty(prefix = DeepSeekKoogProperties.PREFIX, name = ["api-key"]) +@ConditionalOnProperty(prefix = DeepSeekKoogProperties.PREFIX, name = ["enabled"], havingValue = "true") +public class DeepSeekLLMAutoConfiguration( + private val properties: DeepSeekKoogProperties +) { + + @Bean + public fun deepSeekLLMClient(): DeepSeekLLMClient { + return DeepSeekLLMClient( + apiKey = properties.apiKey, + settings = DeepSeekClientSettings(baseUrl = properties.baseUrl) + ) + } + + /** + * Creates a [SingleLLMPromptExecutor] bean configured to use the DeepSeek LLM client. + * + * This method is only executed if the `deepseek.api-key` property is defined in the application's configuration. + * It initializes the DeepSeek client using the provided API key and base URL from the application's properties. + * + * @param properties The configuration properties for the application, including the DeepSeek client settings. + * @return A [SingleLLMPromptExecutor] initialized with an DeepSeek LLM client. + */ + @Bean + public fun deepSeekExecutor(client: DeepSeekLLMClient): SingleLLMPromptExecutor { + val client = DeepSeekLLMClient( + apiKey = properties.apiKey, + settings = DeepSeekClientSettings(baseUrl = properties.baseUrl) + ) + return SingleLLMPromptExecutor(client.toRetryingClient(properties.retry)) + } +} diff --git a/koog-spring-boot-starter/src/main/kotlin/ai/koog/spring/prompt/executor/clients/google/GoogleKoogProperties.kt b/koog-spring-boot-starter/src/main/kotlin/ai/koog/spring/prompt/executor/clients/google/GoogleKoogProperties.kt new file mode 100644 index 0000000000..cf8610b7fc --- /dev/null +++ b/koog-spring-boot-starter/src/main/kotlin/ai/koog/spring/prompt/executor/clients/google/GoogleKoogProperties.kt @@ -0,0 +1,83 @@ +package ai.koog.spring.prompt.executor.clients.google + +import ai.koog.spring.RetryConfigKoogProperties +import ai.koog.spring.prompt.executor.clients.KoogLlmClientProperties +import ai.koog.utils.lang.masked +import org.springframework.boot.context.properties.ConfigurationProperties + +/** + * Configuration properties for integrating with Google's LLM services in the Koog framework. + * + * This class provides the necessary settings for enabling and configuring access + * to Google's LLM API, including authentication and retry behavior. + * The configuration is mapped using the prefix `ai.koog.google`. + * + * Parameters: + * @param enabled Indicates whether the Google LLM integration is enabled. + * @param apiKey The API key used for authenticating requests to Google's services. + * @param baseUrl The base URL of the Google LLM API. + * @param retry Optional configuration for retrying failed API calls. + * + * Usage: + * These properties are automatically bound to the Spring environment when specified + * in application configuration (e.g., `application.yml` or `application.properties`). + * + * Example configuration snippet in `application.yml` or `application.properties`: + * ```properties + * ai.koog.google.enabled=true + * ai.koog.google.api-key=your-google-api-key + * ai.koog.google.base-url=https://api.google.com/llm + * ai.koog.google.retry.enabled=true + * ai.koog.google.retry.max-attempts=3 + * ai.koog.google.retry.initial-delay=2s + * ai.koog.google.retry.max-delay=10s + * ai.koog.google.retry.backoff-multiplier=2.0 + * ai.koog.google.retry.jitter-factor=0.5 + * ``` + * + * Advanced Features: + * - The retry configuration supports customizable retries for handling transient failures. + * - Dedicated masking utility ensures that sensitive information, such as the API key, is + * not exposed when serialized or logged. + * + * This class is primarily used in conjunction with the `GoogleLLMAutoConfiguration` auto-configuration + * class to initialize and configure the necessary beans for interacting with Google's LLM API. + * + * For more details on retry behavior, refer to the `RetryConfigKoogProperties` class. + * For shared configuration attributes, see the `KoogLlmClientProperties` interface. + * + * @property enabled Enables or disables the Google LLM integration. + * @property apiKey The key required for authenticating with the API. + * @property baseUrl URL endpoint for the Google LLM API. + * @property retry Defines retry behavior, such as maximum attempts and delays between retries. + */ +@ConfigurationProperties(prefix = GoogleKoogProperties.PREFIX, ignoreUnknownFields = true) +public class GoogleKoogProperties( + public override val enabled: Boolean, + public val apiKey: String, + public override val baseUrl: String, + public override val retry: RetryConfigKoogProperties? = null +) : KoogLlmClientProperties { + /** + * Companion object for the GoogleKoogProperties class, providing constant values and + * utilities associated with the configuration of Google-related properties. + */ + public companion object Companion { + /** + * Prefix constant used for configuration Google-related properties in the Koog framework. + */ + public const val PREFIX: String = "ai.koog.google" + } + + /** + * Returns a string representation of the GoogleKoogProperties object. + * + * The resulting string includes details about the object's properties such as + * `enabled`, `apiKey` (with sensitive information masked), `baseUrl`, and `retry`. + * + * @return A string representation of the GoogleKoogProperties object. + */ + override fun toString(): String { + return "GoogleKoogProperties(enabled=$enabled, apiKey='${apiKey.masked()}', baseUrl='$baseUrl', retry=$retry)" + } +} diff --git a/koog-spring-boot-starter/src/main/kotlin/ai/koog/spring/prompt/executor/clients/google/GoogleLLMAutoConfiguration.kt b/koog-spring-boot-starter/src/main/kotlin/ai/koog/spring/prompt/executor/clients/google/GoogleLLMAutoConfiguration.kt new file mode 100644 index 0000000000..914186ae74 --- /dev/null +++ b/koog-spring-boot-starter/src/main/kotlin/ai/koog/spring/prompt/executor/clients/google/GoogleLLMAutoConfiguration.kt @@ -0,0 +1,75 @@ +package ai.koog.spring.prompt.executor.clients.google + +import ai.koog.prompt.executor.clients.google.GoogleClientSettings +import ai.koog.prompt.executor.clients.google.GoogleLLMClient +import ai.koog.prompt.executor.llms.SingleLLMPromptExecutor +import ai.koog.spring.prompt.executor.clients.ollama.OllamaKoogProperties +import ai.koog.spring.prompt.executor.clients.toRetryingClient +import org.springframework.boot.autoconfigure.AutoConfiguration +import org.springframework.boot.autoconfigure.condition.ConditionalOnBean +import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty +import org.springframework.boot.context.properties.EnableConfigurationProperties +import org.springframework.context.annotation.Bean +import org.springframework.context.annotation.PropertySource + +/** + * Provides the auto-configuration for integrating with Google LLM via the Koog framework. + * This class is responsible for initializing and configuring the necessary beans for interacting + * with Google's APIs, based on the configurations supplied via `GoogleKoogProperties`. + * + * The configuration is activated only when the property `ai.koog.google.enabled` is set to `true`, + * and an `api-key` is provided. + * + * Beans configured by this class: + * - [GoogleLLMClient]: A client for interacting with Google's LLM API, using the specified API key and settings. + * - [SingleLLMPromptExecutor]: An executor capable of handling and retrying LLM prompts, using the initialized client. + * + * An external configuration file at `classpath:/META-INF/config/koog/google-llm.properties` is leveraged + * for managing default settings. + * + * @property properties [GoogleKoogProperties] to define key settings such as API key, base URL, and retry configurations. + * + * @see GoogleKoogProperties + * @see GoogleLLMClient + * @see SingleLLMPromptExecutor + */ +@AutoConfiguration +@PropertySource("classpath:/META-INF/config/koog/google-llm.properties") +@EnableConfigurationProperties( + GoogleKoogProperties::class, +) +@ConditionalOnProperty(prefix = GoogleKoogProperties.PREFIX, name = ["api-key"]) +@ConditionalOnProperty(prefix = GoogleKoogProperties.PREFIX, name = ["enabled"], havingValue = "true") +public class GoogleLLMAutoConfiguration( + private val properties: GoogleKoogProperties +) { + + /** + * Provides a [GoogleLLMClient] bean configured with the API key and base URL + * specified in the application's properties. + * + * @return A configured instance of [GoogleLLMClient]. + */ + @Bean + public fun googleLLMClient(): GoogleLLMClient { + return GoogleLLMClient( + apiKey = properties.apiKey, + settings = GoogleClientSettings(baseUrl = properties.baseUrl) + ) + } + + /** + * Creates and configures a [SingleLLMPromptExecutor] instance using [GoogleLLMClient] properties. + * + * The method initializes an [GoogleLLMClient] with the base URL derived from the provided [OllamaKoogProperties] + * and uses it to construct the [SingleLLMPromptExecutor]. + * + * @param properties the configuration properties containing Ollama client settings such as the base URL. + * @return a [SingleLLMPromptExecutor] configured to use the Ollama client. + */ + @Bean + @ConditionalOnBean(GoogleLLMClient::class) + public fun googleExecutor(client: GoogleLLMClient): SingleLLMPromptExecutor { + return SingleLLMPromptExecutor(client.toRetryingClient(properties.retry)) + } +} diff --git a/koog-spring-boot-starter/src/main/kotlin/ai/koog/spring/prompt/executor/clients/ollama/OllamaKoogProperties.kt b/koog-spring-boot-starter/src/main/kotlin/ai/koog/spring/prompt/executor/clients/ollama/OllamaKoogProperties.kt new file mode 100644 index 0000000000..4e862900c9 --- /dev/null +++ b/koog-spring-boot-starter/src/main/kotlin/ai/koog/spring/prompt/executor/clients/ollama/OllamaKoogProperties.kt @@ -0,0 +1,53 @@ +package ai.koog.spring.prompt.executor.clients.ollama + +import ai.koog.spring.RetryConfigKoogProperties +import ai.koog.spring.prompt.executor.clients.KoogLlmClientProperties +import org.springframework.boot.context.properties.ConfigurationProperties + +/** + * Configuration properties for the Ollama integration in the Koog framework. + * + * This class defines properties that control the connection and behavior when interacting + * with the Ollama Large Language Model (LLM) service. + * + * These properties are typically configured in the application properties or YAML file. + * + * The configuration prefix for these properties is defined as `ai.koog.ollama`. + * It is used to map these properties in the application configuration file. + * + * This class is designed to work along with the `OllamaLLMAutoConfiguration` class to + * automatically initialize and configure the required beans for the Ollama client and executor. + * + * @property enabled Indicates whether the Ollama integration is enabled. + * @property baseUrl The URL of the API endpoint for Ollama service. + * @property retry The retry settings for handling request failures. + */ +@ConfigurationProperties(prefix = OllamaKoogProperties.PREFIX, ignoreUnknownFields = true) +public class OllamaKoogProperties( + public override val enabled: Boolean, + public override val baseUrl: String, + public override val retry: RetryConfigKoogProperties? = null +) : KoogLlmClientProperties { + /** + * Companion object for the OllamaKoogProperties class, providing constant values and + * utilities associated with the configuration of Ollama-related properties. + */ + public companion object Companion { + /** + * Prefix constant used for configuration Ollama-related properties in the Koog framework. + */ + public const val PREFIX: String = "ai.koog.ollama" + } + + /** + * Returns a string representation of the `OllamaKoogProperties` object. + * + * The string includes values for the `enabled`, `baseUrl`, and `retry` properties + * to provide a comprehensive overview of the configuration state of the object. + * + * @return a string describing the current state of the `OllamaKoogProperties` instance. + */ + override fun toString(): String { + return "OllamaKoogProperties(enabled=$enabled, baseUrl='$baseUrl', retry=$retry)" + } +} diff --git a/koog-spring-boot-starter/src/main/kotlin/ai/koog/spring/prompt/executor/clients/ollama/OllamaLLMAutoConfiguration.kt b/koog-spring-boot-starter/src/main/kotlin/ai/koog/spring/prompt/executor/clients/ollama/OllamaLLMAutoConfiguration.kt new file mode 100644 index 0000000000..798dc4bd14 --- /dev/null +++ b/koog-spring-boot-starter/src/main/kotlin/ai/koog/spring/prompt/executor/clients/ollama/OllamaLLMAutoConfiguration.kt @@ -0,0 +1,71 @@ +package ai.koog.spring.prompt.executor.clients.ollama + +import ai.koog.prompt.executor.llms.SingleLLMPromptExecutor +import ai.koog.prompt.executor.ollama.client.OllamaClient +import ai.koog.spring.prompt.executor.clients.toRetryingClient +import org.springframework.boot.autoconfigure.AutoConfiguration +import org.springframework.boot.autoconfigure.condition.ConditionalOnBean +import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty +import org.springframework.boot.context.properties.EnableConfigurationProperties +import org.springframework.context.annotation.Bean +import org.springframework.context.annotation.PropertySource + +/** + * Auto-configuration class for integrating the Ollama Large Language Model (LLM) service into applications. + * + * This configuration initializes and provides the necessary beans to enable interaction with the Ollama LLM API. + * It relies on properties defined in the [OllamaKoogProperties] class to set up the service. + * + * The configuration is conditional and will only be initialized if: + * - [OllamaKoogProperties.enabled] is set to `true`. + * - The required [OllamaKoogProperties] are provided in the application configuration. + * + * Initializes the following beans: + * - [OllamaClient]: A client for interacting with the Ollama LLM service. + * - [SingleLLMPromptExecutor]: Executes single-prompt interactions with Ollama, utilizing the client. + * + * This configuration allows seamless integration with the Ollama API while enabling properties-based customization. + * + * @property properties [OllamaKoogProperties] to define key settings such as API key, base URL, and retry configurations. + * @see OllamaKoogProperties + * @see OllamaClient + * @see SingleLLMPromptExecutor + */ +@AutoConfiguration +@PropertySource("classpath:/META-INF/config/koog/ollama-llm.properties") +@EnableConfigurationProperties( + OllamaKoogProperties::class, +) +@ConditionalOnProperty(prefix = OllamaKoogProperties.PREFIX, name = ["enabled"], havingValue = "true") +public class OllamaLLMAutoConfiguration( + private val properties: OllamaKoogProperties +) { + + /** + * Creates and configures an instance of [OllamaClient] using the base URL from the provided properties. + * + * This client is used to communicate with the Ollama LLM service and is a prerequisite + * for executing prompts and other interactions with the service. + * + * @return an [OllamaClient] configured with the base URL extracted from the application's properties. + */ + @Bean + public fun ollamaLLMClient(): OllamaClient { + return OllamaClient( + baseUrl = properties.baseUrl, + ) + } + + /** + * Creates and configures an instance of [SingleLLMPromptExecutor] that wraps the provided [OllamaClient]. + * The configured executor includes retry capabilities based on the application's properties. + * + * @param client the [OllamaClient] instance used for communicating with the Ollama LLM service. + * @return a [SingleLLMPromptExecutor] configured to execute LLM prompts with the provided client. + */ + @Bean + @ConditionalOnBean(OllamaClient::class) + public fun ollamaExecutor(client: OllamaClient): SingleLLMPromptExecutor { + return SingleLLMPromptExecutor(client.toRetryingClient(properties.retry)) + } +} diff --git a/koog-spring-boot-starter/src/main/kotlin/ai/koog/spring/prompt/executor/clients/openai/OpenAIKoogProperties.kt b/koog-spring-boot-starter/src/main/kotlin/ai/koog/spring/prompt/executor/clients/openai/OpenAIKoogProperties.kt new file mode 100644 index 0000000000..0d9bbe7069 --- /dev/null +++ b/koog-spring-boot-starter/src/main/kotlin/ai/koog/spring/prompt/executor/clients/openai/OpenAIKoogProperties.kt @@ -0,0 +1,51 @@ +package ai.koog.spring.prompt.executor.clients.openai + +import ai.koog.spring.RetryConfigKoogProperties +import ai.koog.spring.prompt.executor.clients.KoogLlmClientProperties +import ai.koog.utils.lang.masked +import org.springframework.boot.context.properties.ConfigurationProperties + +/** + * Configuration class for OpenAI settings in the Koog framework. + * This class defines properties required to configure and use OpenAI-related services, such as API keys, + * base URLs for the services, and optional retry configurations. + * + * The class is annotated with `@ConfigurationProperties` to bind its fields to configuration file properties + * prefixed with `ai.koog.openai`. + * + * @property enabled Determines if the OpenAI client is enabled. + * @property apiKey The API key required to authenticate requests to the OpenAI API. + * @property baseUrl The base URL for accessing the OpenAI API. + * @property retry Optional retry configuration settings, such as maximum attempts, delays, and backoff strategies. + * + * This configuration is used in `OpenAILLMAutoConfiguration` to set up the OpenAI client and related beans. + */ +@ConfigurationProperties(prefix = OpenAIKoogProperties.PREFIX, ignoreUnknownFields = true) +public class OpenAIKoogProperties( + public override val enabled: Boolean, + public val apiKey: String, + public override val baseUrl: String, + public override val retry: RetryConfigKoogProperties? = null +) : KoogLlmClientProperties { + /** + * Companion object for the OpenAIKoogProperties class, providing constant values and + * utilities associated with the configuration of OpenAI-related properties. + */ + public companion object { + /** + * Prefix constant used for configuration OpenAI-related properties in the Koog framework. + */ + public const val PREFIX: String = "ai.koog.openai" + } + + /** + * Returns a string representation of the `OpenAIKoogProperties` object. + * The string includes details about the `enabled` status, masked `apiKey`, + * `baseUrl`, and `retry` configuration. + * + * @return A string summarizing the `OpenAIKoogProperties` object's state. + */ + override fun toString(): String { + return "OpenAIKoogProperties(enabled=$enabled, apiKey='$${apiKey.masked()}', baseUrl='$baseUrl', retry=$retry)" + } +} diff --git a/koog-spring-boot-starter/src/main/kotlin/ai/koog/spring/prompt/executor/clients/openai/OpenAILLMAutoConfiguration.kt b/koog-spring-boot-starter/src/main/kotlin/ai/koog/spring/prompt/executor/clients/openai/OpenAILLMAutoConfiguration.kt new file mode 100644 index 0000000000..93ad1a0713 --- /dev/null +++ b/koog-spring-boot-starter/src/main/kotlin/ai/koog/spring/prompt/executor/clients/openai/OpenAILLMAutoConfiguration.kt @@ -0,0 +1,73 @@ +package ai.koog.spring.prompt.executor.clients.openai + +import ai.koog.prompt.executor.clients.openai.OpenAIClientSettings +import ai.koog.prompt.executor.clients.openai.OpenAILLMClient +import ai.koog.prompt.executor.llms.SingleLLMPromptExecutor +import ai.koog.spring.prompt.executor.clients.toRetryingClient +import org.springframework.boot.autoconfigure.AutoConfiguration +import org.springframework.boot.autoconfigure.condition.ConditionalOnBean +import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty +import org.springframework.boot.context.properties.EnableConfigurationProperties +import org.springframework.context.annotation.Bean +import org.springframework.context.annotation.PropertySource + +/** + * Auto-configuration class for setting up OpenAI LLM client and related beans. + * This class utilizes the properties defined in [OpenAIKoogProperties] to configure and initialize OpenAI-related components, + * including the client and executor. + * + * The configuration is conditionally applied if the property `ai.koog.openai.api-key` is set to `true`. + * It reads additional configuration from the properties file located at `classpath:/META-INF/config/koog/openai-llm.properties`. + * + * Key Features: + * - Sets up the [OpenAILLMClient] bean with API key and base URL from the provided properties. + * - Configures a [SingleLLMPromptExecutor] bean using the configured OpenAI client with retry capabilities. + * + * Usage Notes: + * - To activate, ensure the `ai.koog.openai.api-key` property is defined in your application configuration. + * - Customize behavior and settings using the `ai.koog.openai.*` configuration properties. + * + * @property properties [OpenAIKoogProperties] to define key settings such as API key, base URL, and retry configurations. + * + * @see OpenAIKoogProperties + * @see OpenAILLMClient + * @see SingleLLMPromptExecutor + */ +@AutoConfiguration +@PropertySource("classpath:/META-INF/config/koog/openai-llm.properties") +@EnableConfigurationProperties( + OpenAIKoogProperties::class, +) +@ConditionalOnProperty(prefix = OpenAIKoogProperties.PREFIX, name = ["api-key"]) +@ConditionalOnProperty(prefix = OpenAIKoogProperties.PREFIX, name = ["enabled"], havingValue = "true") +public class OpenAILLMAutoConfiguration( + private val properties: OpenAIKoogProperties +) { + + /** + * Creates and provides an instance of [OpenAILLMClient] as a Spring bean for use in the application context. + * The [OpenAILLMClient] is configured using API key and base URL from the associated properties. + * + * @return a configured instance of [OpenAILLMClient]. + */ + @Bean + public fun openAILLMClient(): OpenAILLMClient { + return OpenAILLMClient( + apiKey = properties.apiKey, + settings = OpenAIClientSettings(baseUrl = properties.baseUrl) + ) + } + + /** + * Creates and returns a [SingleLLMPromptExecutor] bean configured with the given [OpenAILLMClient]. + * This bean is conditionally initialized only when an [OpenAILLMClient] bean is present in the application context. + * + * @param client the [OpenAILLMClient] used to create a retry-capable client for executing LLM prompts. + * @return a configured instance of [SingleLLMPromptExecutor]. + */ + @Bean + @ConditionalOnBean(OpenAILLMClient::class) + public fun openAIExecutor(client: OpenAILLMClient): SingleLLMPromptExecutor { + return SingleLLMPromptExecutor(client.toRetryingClient(properties.retry)) + } +} diff --git a/koog-spring-boot-starter/src/main/kotlin/ai/koog/spring/prompt/executor/clients/openrouter/OpenRouterKoogProperties.kt b/koog-spring-boot-starter/src/main/kotlin/ai/koog/spring/prompt/executor/clients/openrouter/OpenRouterKoogProperties.kt new file mode 100644 index 0000000000..95b3bc2c85 --- /dev/null +++ b/koog-spring-boot-starter/src/main/kotlin/ai/koog/spring/prompt/executor/clients/openrouter/OpenRouterKoogProperties.kt @@ -0,0 +1,56 @@ +package ai.koog.spring.prompt.executor.clients.openrouter + +import ai.koog.spring.RetryConfigKoogProperties +import ai.koog.spring.prompt.executor.clients.KoogLlmClientProperties +import ai.koog.utils.lang.masked +import org.springframework.boot.context.properties.ConfigurationProperties + +/** + * Configuration properties class for OpenRouter integration within the Koog framework. + * + * This class defines configuration options required for connecting to the OpenRouter service + * via the Koog framework. It includes parameters such as API key, base URL, enabling or disabling + * the integration, and retry configuration for handling API requests. + * + * When properly configured in the application properties using the defined prefix, this class + * allows seamless integration with OpenRouter's LLM services. + * + * Configuration prefix: `ai.koog.openrouter` + * + * @property enabled Specifies whether the OpenRouter integration is enabled. This can be toggled + * via the property `ai.koog.openrouter.enabled`. + * @property apiKey The API key used for authenticating requests to the OpenRouter service. + * This must be provided through the property `ai.koog.openrouter.api-key`. + * @property baseUrl The base URL of the OpenRouter API endpoint, configurable via the + * property `ai.koog.openrouter.base-url`. Defaults to the service's official API URL. + * @property retry An optional retry configuration for handling failed API requests. + * This can be set using sub-properties under `ai.koog.openrouter.retry`. + */ +@ConfigurationProperties(prefix = OpenRouterKoogProperties.PREFIX, ignoreUnknownFields = true) +public class OpenRouterKoogProperties( + public override val enabled: Boolean, + public val apiKey: String, + public override val baseUrl: String, + public override val retry: RetryConfigKoogProperties? = null +) : KoogLlmClientProperties { + /** + * Companion object for the [OpenRouterKoogProperties] class, providing constant values and + * utilities associated with the configuration of OpenRouter-related properties. + */ + public companion object Companion { + /** + * Prefix constant used for configuration OpenRouter-related properties in the Koog framework. + */ + public const val PREFIX: String = "ai.koog.openrouter" + } + + /** + * Converts the [OpenRouterKoogProperties] instance to its string representation. + * Sensitive information, such as the API key, is masked to ensure security. + * + * @return A string representation of the [OpenRouterKoogProperties] object. + */ + override fun toString(): String { + return "OpenRouterKoogProperties(enabled=$enabled, apiKey='${apiKey.masked()}', baseUrl='$baseUrl', retry=$retry)" + } +} diff --git a/koog-spring-boot-starter/src/main/kotlin/ai/koog/spring/prompt/executor/clients/openrouter/OpenRouterLLMAutoConfiguration.kt b/koog-spring-boot-starter/src/main/kotlin/ai/koog/spring/prompt/executor/clients/openrouter/OpenRouterLLMAutoConfiguration.kt new file mode 100644 index 0000000000..6e8928adaa --- /dev/null +++ b/koog-spring-boot-starter/src/main/kotlin/ai/koog/spring/prompt/executor/clients/openrouter/OpenRouterLLMAutoConfiguration.kt @@ -0,0 +1,67 @@ +package ai.koog.spring.prompt.executor.clients.openrouter + +import ai.koog.prompt.executor.clients.openrouter.OpenRouterClientSettings +import ai.koog.prompt.executor.clients.openrouter.OpenRouterLLMClient +import ai.koog.prompt.executor.llms.SingleLLMPromptExecutor +import ai.koog.spring.prompt.executor.clients.toRetryingClient +import org.springframework.boot.autoconfigure.AutoConfiguration +import org.springframework.boot.autoconfigure.condition.ConditionalOnBean +import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty +import org.springframework.boot.context.properties.EnableConfigurationProperties +import org.springframework.context.annotation.Bean +import org.springframework.context.annotation.PropertySource + +/** + * Auto-configuration class for integrating OpenRouter with Koog framework. + * + * This class enables the automatic configuration of beans and properties to work with OpenRouter's LLM services, + * provided the application properties have been set with the required prefix and fields. + * + * The configuration is activated only when both `ai.koog.openrouter.enabled` is set to `true` + * and `ai.koog.openrouter.api-key` is provided in the application properties. + * + * @property properties [OpenRouterKoogProperties] to define key settings such as API key, base URL, and retry configurations. + * @see OpenRouterKoogProperties + * @see OpenRouterLLMClient + * @see SingleLLMPromptExecutor + */ +@AutoConfiguration +@PropertySource("classpath:/META-INF/config/koog/openrouter-llm.properties") +@EnableConfigurationProperties( + OpenRouterKoogProperties::class, +) +@ConditionalOnProperty(prefix = OpenRouterKoogProperties.PREFIX, name = ["api-key"]) +@ConditionalOnProperty(prefix = OpenRouterKoogProperties.PREFIX, name = ["enabled"], havingValue = "true") +public class OpenRouterLLMAutoConfiguration( + private val properties: OpenRouterKoogProperties +) { + + /** + * Creates and configures an instance of [OpenRouterLLMClient] as a Spring Bean. + * The client is initialized with the API key and settings (such as base URL) + * obtained from the provided `properties` configuration. + * + * @return An instance of [OpenRouterLLMClient] configured with the given properties. + */ + @Bean + public fun openRouterLLMClient(): OpenRouterLLMClient { + return OpenRouterLLMClient( + apiKey = properties.apiKey, + settings = OpenRouterClientSettings(baseUrl = properties.baseUrl) + ) + } + + /** + * Provides a [SingleLLMPromptExecutor] bean configured with an [OpenRouterLLMClient]. + * + * The method uses the provided [OpenRouterLLMClient] to create a retrying client instance + * based on the configuration in the `properties.retry` parameter. + * + * @param client The [OpenRouterLLMClient] instance used to configure the [SingleLLMPromptExecutor] + * */ + @Bean + @ConditionalOnBean(OpenRouterLLMClient::class) + public fun openRouterExecutor(client: OpenRouterLLMClient): SingleLLMPromptExecutor { + return SingleLLMPromptExecutor(client.toRetryingClient(properties.retry)) + } +} diff --git a/koog-spring-boot-starter/src/main/kotlin/ai/koog/spring/prompt/executor/clients/utils.kt b/koog-spring-boot-starter/src/main/kotlin/ai/koog/spring/prompt/executor/clients/utils.kt new file mode 100644 index 0000000000..efebf0447a --- /dev/null +++ b/koog-spring-boot-starter/src/main/kotlin/ai/koog/spring/prompt/executor/clients/utils.kt @@ -0,0 +1,27 @@ +package ai.koog.spring.prompt.executor.clients + +import ai.koog.prompt.executor.clients.LLMClient +import ai.koog.prompt.executor.clients.retry.RetryConfig +import ai.koog.prompt.executor.clients.retry.RetryingLLMClient +import ai.koog.spring.RetryConfigKoogProperties +import kotlin.time.toKotlinDuration + +internal fun LLMClient.toRetryingClient(properties: RetryConfigKoogProperties?): LLMClient { + val self = this + return if (properties?.enabled == true) { + val defaultConfig = RetryConfig() + val retryConfig = RetryConfig( + maxAttempts = properties.maxAttempts ?: defaultConfig.maxAttempts, + initialDelay = properties.initialDelay?.toKotlinDuration() ?: defaultConfig.initialDelay, + maxDelay = properties.maxDelay?.toKotlinDuration() ?: defaultConfig.maxDelay, + backoffMultiplier = properties.backoffMultiplier ?: defaultConfig.backoffMultiplier, + jitterFactor = properties.jitterFactor ?: defaultConfig.jitterFactor + ) + RetryingLLMClient( + delegate = self, + config = retryConfig + ) + } else { + self + } +} diff --git a/koog-spring-boot-starter/src/main/resources/META-INF/config/koog/anthropic-llm.properties b/koog-spring-boot-starter/src/main/resources/META-INF/config/koog/anthropic-llm.properties new file mode 100644 index 0000000000..fa078f9422 --- /dev/null +++ b/koog-spring-boot-starter/src/main/resources/META-INF/config/koog/anthropic-llm.properties @@ -0,0 +1,2 @@ +ai.koog.anthropic.enabled=true +ai.koog.anthropic.base-url=https://api.anthropic.com diff --git a/koog-spring-boot-starter/src/main/resources/META-INF/config/koog/deepseek-llm.properties b/koog-spring-boot-starter/src/main/resources/META-INF/config/koog/deepseek-llm.properties new file mode 100644 index 0000000000..a924235966 --- /dev/null +++ b/koog-spring-boot-starter/src/main/resources/META-INF/config/koog/deepseek-llm.properties @@ -0,0 +1,2 @@ +ai.koog.deepseek.enabled=true +ai.koog.deepseek.base-url=https://api.deepseek.com diff --git a/koog-spring-boot-starter/src/main/resources/META-INF/config/koog/google-llm.properties b/koog-spring-boot-starter/src/main/resources/META-INF/config/koog/google-llm.properties new file mode 100644 index 0000000000..e6c0935123 --- /dev/null +++ b/koog-spring-boot-starter/src/main/resources/META-INF/config/koog/google-llm.properties @@ -0,0 +1,2 @@ +ai.koog.google.enabled=true +ai.koog.google.base-url=https://generativelanguage.googleapis.com diff --git a/koog-spring-boot-starter/src/main/resources/META-INF/config/koog/ollama-llm.properties b/koog-spring-boot-starter/src/main/resources/META-INF/config/koog/ollama-llm.properties new file mode 100644 index 0000000000..9f93516484 --- /dev/null +++ b/koog-spring-boot-starter/src/main/resources/META-INF/config/koog/ollama-llm.properties @@ -0,0 +1,2 @@ +ai.koog.ollama.enabled=false +ai.koog.ollama.base-url=http://localhost:11434 diff --git a/koog-spring-boot-starter/src/main/resources/META-INF/config/koog/openai-llm.properties b/koog-spring-boot-starter/src/main/resources/META-INF/config/koog/openai-llm.properties new file mode 100644 index 0000000000..1ec885474d --- /dev/null +++ b/koog-spring-boot-starter/src/main/resources/META-INF/config/koog/openai-llm.properties @@ -0,0 +1,2 @@ +ai.koog.openai.enabled=true +ai.koog.openai.base-url=https://api.openai.com diff --git a/koog-spring-boot-starter/src/main/resources/META-INF/config/koog/openrouter-llm.properties b/koog-spring-boot-starter/src/main/resources/META-INF/config/koog/openrouter-llm.properties new file mode 100644 index 0000000000..54c3588a7a --- /dev/null +++ b/koog-spring-boot-starter/src/main/resources/META-INF/config/koog/openrouter-llm.properties @@ -0,0 +1,2 @@ +ai.koog.openrouter.enabled=true +ai.koog.openrouter.base-url=https://openrouter.ai diff --git a/koog-spring-boot-starter/src/main/resources/META-INF/spring/org.springframework.boot.autoconfigure.AutoConfiguration.imports b/koog-spring-boot-starter/src/main/resources/META-INF/spring/org.springframework.boot.autoconfigure.AutoConfiguration.imports index 398557bf69..d9859c91bb 100644 --- a/koog-spring-boot-starter/src/main/resources/META-INF/spring/org.springframework.boot.autoconfigure.AutoConfiguration.imports +++ b/koog-spring-boot-starter/src/main/resources/META-INF/spring/org.springframework.boot.autoconfigure.AutoConfiguration.imports @@ -1 +1,6 @@ -ai.koog.spring.KoogAutoConfiguration \ No newline at end of file +ai.koog.spring.prompt.executor.clients.anthropic.AnthropicLLMAutoConfiguration +ai.koog.spring.prompt.executor.clients.deepseek.DeepSeekLLMAutoConfiguration +ai.koog.spring.prompt.executor.clients.google.GoogleLLMAutoConfiguration +ai.koog.spring.prompt.executor.clients.ollama.OllamaLLMAutoConfiguration +ai.koog.spring.prompt.executor.clients.openai.OpenAILLMAutoConfiguration +ai.koog.spring.prompt.executor.clients.openrouter.OpenRouterLLMAutoConfiguration diff --git a/koog-spring-boot-starter/src/test/kotlin/ai/koog/spring/KoogAutoConfigurationTest.kt b/koog-spring-boot-starter/src/test/kotlin/ai/koog/spring/KoogAutoConfigurationTest.kt index 30f15c8279..c986991b49 100644 --- a/koog-spring-boot-starter/src/test/kotlin/ai/koog/spring/KoogAutoConfigurationTest.kt +++ b/koog-spring-boot-starter/src/test/kotlin/ai/koog/spring/KoogAutoConfigurationTest.kt @@ -14,6 +14,12 @@ import ai.koog.prompt.executor.clients.retry.RetryConfig import ai.koog.prompt.executor.clients.retry.RetryingLLMClient import ai.koog.prompt.executor.llms.SingleLLMPromptExecutor import ai.koog.prompt.executor.ollama.client.OllamaClient +import ai.koog.spring.prompt.executor.clients.anthropic.AnthropicLLMAutoConfiguration +import ai.koog.spring.prompt.executor.clients.deepseek.DeepSeekLLMAutoConfiguration +import ai.koog.spring.prompt.executor.clients.google.GoogleLLMAutoConfiguration +import ai.koog.spring.prompt.executor.clients.ollama.OllamaLLMAutoConfiguration +import ai.koog.spring.prompt.executor.clients.openai.OpenAILLMAutoConfiguration +import ai.koog.spring.prompt.executor.clients.openrouter.OpenRouterLLMAutoConfiguration import org.junit.jupiter.api.Assertions.assertEquals import org.junit.jupiter.api.Assertions.assertTrue import org.junit.jupiter.api.Test @@ -41,10 +47,19 @@ private const val PROVIDERS = """ class KoogAutoConfigurationTest { private val defaultRetryConfig = RetryConfig() + private val allProvidersAutoConfigurations = AutoConfigurations.of( + AnthropicLLMAutoConfiguration::class.java, + GoogleLLMAutoConfiguration::class.java, + DeepSeekLLMAutoConfiguration::class.java, + OllamaLLMAutoConfiguration::class.java, + OpenAILLMAutoConfiguration::class.java, + OpenRouterLLMAutoConfiguration::class.java, + ) + @Test fun `should not supply executor beans if no apiKey is provided`() { ApplicationContextRunner() - .withConfiguration(AutoConfigurations.of(KoogAutoConfiguration::class.java)) + .withConfiguration(allProvidersAutoConfigurations) .run { context -> assertThrows { context.getBean() } } @@ -54,8 +69,11 @@ class KoogAutoConfigurationTest { fun `should supply OpenAI executor bean with provided apiKey and default baseUrl`() { val configApiKey = "some_api_key" ApplicationContextRunner() - .withConfiguration(AutoConfigurations.of(KoogAutoConfiguration::class.java)) - .withPropertyValues("ai.koog.openai.api-key=$configApiKey") + .withConfiguration(allProvidersAutoConfigurations) + .withPropertyValues( + "ai.koog.openai.enabled=true", + "ai.koog.openai.api-key=$configApiKey" + ) .run { context -> val executor = context.getBean() val llmClient = getPrivateFieldValue(executor, "llmClient") @@ -75,8 +93,9 @@ class KoogAutoConfigurationTest { fun `should supply OpenAI executor bean with provided baseUrl`() { val configBaseUrl = "https://some-url.com" ApplicationContextRunner() - .withConfiguration(AutoConfigurations.of(KoogAutoConfiguration::class.java)) + .withConfiguration(AutoConfigurations.of(OpenAILLMAutoConfiguration::class.java)) .withPropertyValues( + "ai.koog.openai.enabled=true", "ai.koog.openai.api-key=some_api_key", "ai.koog.openai.base-url=$configBaseUrl", ) @@ -92,16 +111,19 @@ class KoogAutoConfigurationTest { } @ParameterizedTest - @CsvSource(PROVIDERS) + @CsvSource(textBlock = PROVIDERS) fun `should supply OpenAI executor bean with retry client and default config`( provider: String, clazz: Class ) { ApplicationContextRunner() - .withConfiguration(AutoConfigurations.of(KoogAutoConfiguration::class.java)) - .withPropertyValues("ai.koog.$provider.api-key=some_api_key") - .withPropertyValues("ai.koog.$provider.retry.enabled=true") - .withPropertyValues("ai.koog.$provider.base-url=http://localhost:9876") + .withConfiguration(allProvidersAutoConfigurations) + .withPropertyValues( + "ai.koog.$provider.enabled=true", + "ai.koog.$provider.api-key=some_api_key", + "ai.koog.$provider.retry.enabled=true", + "ai.koog.$provider.base-url=http://localhost:9876" + ) .run { context -> val executor = context.getBean() val retryingClient = getPrivateFieldValue(executor, "llmClient") @@ -120,7 +142,7 @@ class KoogAutoConfigurationTest { } @ParameterizedTest - @CsvSource(PROVIDERS) + @CsvSource(textBlock = PROVIDERS) fun `should supply executor bean with retry client and full custom config`( provider: String, clazz: Class @@ -131,15 +153,18 @@ class KoogAutoConfigurationTest { val backoffMultiplier = 5.0 val jitterFactor = 0.5 ApplicationContextRunner() - .withConfiguration(AutoConfigurations.of(KoogAutoConfiguration::class.java)) - .withPropertyValues("ai.koog.$provider.api-key=some_api_key") - .withPropertyValues("ai.koog.$provider.base-url=http://localhost:9876") - .withPropertyValues("ai.koog.$provider.retry.enabled=true") - .withPropertyValues("ai.koog.$provider.retry.max-attempts=$maxAttempts") - .withPropertyValues("ai.koog.$provider.retry.initial-delay=$initialDelay") - .withPropertyValues("ai.koog.$provider.retry.max-delay=$maxDelay") - .withPropertyValues("ai.koog.$provider.retry.backoff-multiplier=$backoffMultiplier") - .withPropertyValues("ai.koog.$provider.retry.jitter-factor=$jitterFactor") + .withConfiguration(allProvidersAutoConfigurations) + .withPropertyValues( + "ai.koog.$provider.enabled=true", + "ai.koog.$provider.api-key=some_api_key", + "ai.koog.$provider.base-url=http://localhost:9876", + "ai.koog.$provider.retry.enabled=true", + "ai.koog.$provider.retry.max-attempts=$maxAttempts", + "ai.koog.$provider.retry.initial-delay=$initialDelay", + "ai.koog.$provider.retry.max-delay=$maxDelay", + "ai.koog.$provider.retry.backoff-multiplier=$backoffMultiplier", + "ai.koog.$provider.retry.jitter-factor=$jitterFactor" + ) .run { context -> val executor = context.getBean() val retryingClient = getPrivateFieldValue(executor, "llmClient") @@ -158,7 +183,37 @@ class KoogAutoConfigurationTest { } @ParameterizedTest - @CsvSource(PROVIDERS) + @CsvSource(textBlock = PROVIDERS) + fun `Should not create beans when provider is DISABLED`( + provider: String, + ) { + val maxAttempts = 5 + val initialDelay = 10 + val maxDelay = 60 + val backoffMultiplier = 5.0 + val jitterFactor = 0.5 + ApplicationContextRunner() + .withConfiguration(allProvidersAutoConfigurations) + .withPropertyValues( + "ai.koog.$provider.enabled=false", + "ai.koog.$provider.api-key=some_api_key", + "ai.koog.$provider.base-url=http://localhost:9876", + "ai.koog.$provider.retry.enabled=true", + "ai.koog.$provider.retry.max-attempts=$maxAttempts", + "ai.koog.$provider.retry.initial-delay=$initialDelay", + "ai.koog.$provider.retry.max-delay=$maxDelay", + "ai.koog.$provider.retry.backoff-multiplier=$backoffMultiplier", + "ai.koog.$provider.retry.jitter-factor=$jitterFactor" + ) + .run { context -> + assertTrue { context.getBeansOfType(SingleLLMPromptExecutor::class.java).isEmpty() } + assertTrue { context.getBeansOfType(RetryingLLMClient::class.java).isEmpty() } + assertTrue { context.getBeansOfType(LLMClient::class.java).isEmpty() } + } + } + + @ParameterizedTest + @CsvSource(textBlock = PROVIDERS) fun `should supply executor bean with retry client and partial custom config`( provider: String, clazz: Class @@ -166,11 +221,14 @@ class KoogAutoConfigurationTest { val maxAttempts = 5 val initialDelay = 10 ApplicationContextRunner() - .withConfiguration(AutoConfigurations.of(KoogAutoConfiguration::class.java)) - .withPropertyValues("ai.koog.$provider.api-key=some_api_key") - .withPropertyValues("ai.koog.$provider.retry.enabled=true") - .withPropertyValues("ai.koog.$provider.retry.max-attempts=$maxAttempts") - .withPropertyValues("ai.koog.$provider.retry.initial-delay=$initialDelay") + .withConfiguration(allProvidersAutoConfigurations) + .withPropertyValues( + "ai.koog.$provider.enabled=true", + "ai.koog.$provider.api-key=some_api_key", + "ai.koog.$provider.retry.enabled=true", + "ai.koog.$provider.retry.max-attempts=$maxAttempts", + "ai.koog.$provider.retry.initial-delay=$initialDelay" + ) .run { context -> val executor = context.getBean() val retryingClient = getPrivateFieldValue(executor, "llmClient") @@ -192,8 +250,11 @@ class KoogAutoConfigurationTest { fun `should supply Anthropic executor bean with provided apiKey and default baseUrl`() { val configApiKey = "some_api_key" ApplicationContextRunner() - .withConfiguration(AutoConfigurations.of(KoogAutoConfiguration::class.java)) - .withPropertyValues("ai.koog.anthropic.api-key=$configApiKey") + .withConfiguration(allProvidersAutoConfigurations) + .withPropertyValues( + "ai.koog.anthropic.enabled=true", + "ai.koog.anthropic.api-key=$configApiKey" + ) .run { context -> val executor = context.getBean() val llmClient = getPrivateFieldValue(executor, "llmClient") @@ -212,9 +273,12 @@ class KoogAutoConfigurationTest { @Test fun `should supply Anthropic executor bean with retry client and default config`() { ApplicationContextRunner() - .withConfiguration(AutoConfigurations.of(KoogAutoConfiguration::class.java)) - .withPropertyValues("ai.koog.anthropic.api-key=some_api_key") - .withPropertyValues("ai.koog.anthropic.retry.enabled=true") + .withConfiguration(AutoConfigurations.of(AnthropicLLMAutoConfiguration::class.java)) + .withPropertyValues( + "ai.koog.anthropic.enabled=true", + "ai.koog.anthropic.api-key=some_api_key", + "ai.koog.anthropic.retry.enabled=true" + ) .run { context -> val executor = context.getBean() val retryingClient = getPrivateFieldValue(executor, "llmClient") @@ -232,8 +296,9 @@ class KoogAutoConfigurationTest { fun `should supply Anthropic executor bean with provided baseUrl`() { val configBaseUrl = "https://some-url.com" ApplicationContextRunner() - .withConfiguration(AutoConfigurations.of(KoogAutoConfiguration::class.java)) + .withConfiguration(AutoConfigurations.of(AnthropicLLMAutoConfiguration::class.java)) .withPropertyValues( + "ai.koog.anthropic.enabled=true", "ai.koog.anthropic.api-key=some_api_key", "ai.koog.anthropic.base-url=$configBaseUrl", ) @@ -252,8 +317,11 @@ class KoogAutoConfigurationTest { fun `should supply Google executor bean with provided apiKey and default baseUrl`() { val configApiKey = "some_api_key" ApplicationContextRunner() - .withConfiguration(AutoConfigurations.of(KoogAutoConfiguration::class.java)) - .withPropertyValues("ai.koog.google.api-key=$configApiKey") + .withConfiguration(AutoConfigurations.of(GoogleLLMAutoConfiguration::class.java)) + .withPropertyValues( + "ai.koog.google.enabled=true", + "ai.koog.google.api-key=$configApiKey" + ) .run { context -> val executor = context.getBean() val llmClient = getPrivateFieldValue(executor, "llmClient") @@ -273,8 +341,9 @@ class KoogAutoConfigurationTest { fun `should supply Google executor bean with provided baseUrl`() { val configBaseUrl = "https://some-url.com" ApplicationContextRunner() - .withConfiguration(AutoConfigurations.of(KoogAutoConfiguration::class.java)) + .withConfiguration(AutoConfigurations.of(GoogleLLMAutoConfiguration::class.java)) .withPropertyValues( + "ai.koog.google.enabled=true", "ai.koog.google.api-key=some_api_key", "ai.koog.google.base-url=$configBaseUrl", ) @@ -292,9 +361,12 @@ class KoogAutoConfigurationTest { @Test fun `should supply Google executor bean with retry client and default config`() { ApplicationContextRunner() - .withConfiguration(AutoConfigurations.of(KoogAutoConfiguration::class.java)) - .withPropertyValues("ai.koog.google.api-key=some_api_key") - .withPropertyValues("ai.koog.google.retry.enabled=true") + .withConfiguration(AutoConfigurations.of(GoogleLLMAutoConfiguration::class.java)) + .withPropertyValues( + "ai.koog.google.enabled=true", + "ai.koog.google.api-key=some_api_key", + "ai.koog.google.retry.enabled=true" + ) .run { context -> val executor = context.getBean() val retryingClient = getPrivateFieldValue(executor, "llmClient") @@ -312,8 +384,11 @@ class KoogAutoConfigurationTest { fun `should supply OpenRouter executor bean with provided apiKey and default baseUrl`() { val configApiKey = "some_api_key" ApplicationContextRunner() - .withConfiguration(AutoConfigurations.of(KoogAutoConfiguration::class.java)) - .withPropertyValues("ai.koog.openrouter.api-key=$configApiKey") + .withConfiguration(AutoConfigurations.of(OpenRouterLLMAutoConfiguration::class.java)) + .withPropertyValues( + "ai.koog.openrouter.enabled=true", + "ai.koog.openrouter.api-key=$configApiKey" + ) .run { context -> val executor = context.getBean() val llmClient = getPrivateFieldValue(executor, "llmClient") @@ -333,8 +408,9 @@ class KoogAutoConfigurationTest { fun `should supply OpenRouter executor bean with provided baseUrl`() { val configBaseUrl = "https://some-url.com" ApplicationContextRunner() - .withConfiguration(AutoConfigurations.of(KoogAutoConfiguration::class.java)) + .withConfiguration(AutoConfigurations.of(OpenRouterLLMAutoConfiguration::class.java)) .withPropertyValues( + "ai.koog.openrouter.enabled=true", "ai.koog.openrouter.api-key=some_api_key", "ai.koog.openrouter.base-url=$configBaseUrl", ) @@ -352,9 +428,12 @@ class KoogAutoConfigurationTest { @Test fun `should supply OpenRouter executor bean with retry client and default config`() { ApplicationContextRunner() - .withConfiguration(AutoConfigurations.of(KoogAutoConfiguration::class.java)) - .withPropertyValues("ai.koog.openrouter.api-key=some_api_key") - .withPropertyValues("ai.koog.openrouter.retry.enabled=true") + .withConfiguration(AutoConfigurations.of(OpenRouterLLMAutoConfiguration::class.java)) + .withPropertyValues( + "ai.koog.openrouter.enabled=true", + "ai.koog.openrouter.api-key=some_api_key", + "ai.koog.openrouter.retry.enabled=true" + ) .run { context -> val executor = context.getBean() val retryingClient = getPrivateFieldValue(executor, "llmClient") @@ -372,8 +451,11 @@ class KoogAutoConfigurationTest { fun `should supply DeepSeek executor bean with provided apiKey and default baseUrl`() { val configApiKey = "some_api_key" ApplicationContextRunner() - .withConfiguration(AutoConfigurations.of(KoogAutoConfiguration::class.java)) - .withPropertyValues("ai.koog.deepseek.api-key=$configApiKey") + .withConfiguration(AutoConfigurations.of(DeepSeekLLMAutoConfiguration::class.java)) + .withPropertyValues( + "ai.koog.deepseek.enabled=true", + "ai.koog.deepseek.api-key=$configApiKey" + ) .run { context -> val executor = context.getBean() val llmClient = getPrivateFieldValue(executor, "llmClient") @@ -393,8 +475,9 @@ class KoogAutoConfigurationTest { fun `should supply DeepSeek executor bean with provided baseUrl`() { val configBaseUrl = "https://some-url.com" ApplicationContextRunner() - .withConfiguration(AutoConfigurations.of(KoogAutoConfiguration::class.java)) + .withConfiguration(AutoConfigurations.of(DeepSeekLLMAutoConfiguration::class.java)) .withPropertyValues( + "ai.koog.deepseek.enabled=true", "ai.koog.deepseek.api-key=some_api_key", "ai.koog.deepseek.base-url=$configBaseUrl", ) @@ -412,9 +495,12 @@ class KoogAutoConfigurationTest { @Test fun `should supply DeepSeek executor bean with retry client and default config`() { ApplicationContextRunner() - .withConfiguration(AutoConfigurations.of(KoogAutoConfiguration::class.java)) - .withPropertyValues("ai.koog.deepseek.api-key=some_api_key") - .withPropertyValues("ai.koog.deepseek.retry.enabled=true") + .withConfiguration(AutoConfigurations.of(DeepSeekLLMAutoConfiguration::class.java)) + .withPropertyValues( + "ai.koog.deepseek.enabled=true", + "ai.koog.deepseek.api-key=some_api_key", + "ai.koog.deepseek.retry.enabled=true" + ) .run { context -> val executor = context.getBean() val retryingClient = getPrivateFieldValue(executor, "llmClient") @@ -432,8 +518,11 @@ class KoogAutoConfigurationTest { fun `should supply Ollama executor bean with provided baseUrl`() { val configBaseUrl = "https://some-url.com" ApplicationContextRunner() - .withConfiguration(AutoConfigurations.of(KoogAutoConfiguration::class.java)) - .withPropertyValues("ai.koog.ollama.base-url=$configBaseUrl") + .withConfiguration(AutoConfigurations.of(OllamaLLMAutoConfiguration::class.java)) + .withPropertyValues( + "ai.koog.ollama.enabled=true", + "ai.koog.ollama.base-url=$configBaseUrl" + ) .run { context -> val executor = context.getBean() val llmClient = getPrivateFieldValue(executor, "llmClient") @@ -448,9 +537,12 @@ class KoogAutoConfigurationTest { @Test fun `should supply Ollama executor bean with retry client and default config`() { ApplicationContextRunner() - .withConfiguration(AutoConfigurations.of(KoogAutoConfiguration::class.java)) - .withPropertyValues("ai.koog.ollama.base-url=https://some-url.com") - .withPropertyValues("ai.koog.ollama.retry.enabled=true") + .withConfiguration(AutoConfigurations.of(OllamaLLMAutoConfiguration::class.java)) + .withPropertyValues( + "ai.koog.ollama.enabled=true", + "ai.koog.ollama.base-url=https://some-url.com", + "ai.koog.ollama.retry.enabled=true" + ) .run { context -> val executor = context.getBean() val retryingClient = getPrivateFieldValue(executor, "llmClient") @@ -467,13 +559,19 @@ class KoogAutoConfigurationTest { @Test fun `should supply multiple executor beans`() { ApplicationContextRunner() - .withConfiguration(AutoConfigurations.of(KoogAutoConfiguration::class.java)) + .withConfiguration( + allProvidersAutoConfigurations + ) .withPropertyValues( + "ai.koog.openai.enabled=true", "ai.koog.openai.api-key=some_api_key", + "ai.koog.anthropic.enabled=true", "ai.koog.anthropic.api-key=some_api_key", + "ai.koog.google.enabled=true", "ai.koog.google.api-key=some_api_key", + "ai.koog.deepseek.enabled=true", "ai.koog.deepseek.api-key=some_api_key", - "ai.koog.ollama.base-url=http://localhist:8765", + "ai.koog.ollama.enabled=true", ) .run { context -> val beanNames = context.getBeanNamesForType() diff --git a/koog-spring-boot-starter/src/test/resources/junit-platform.properties b/koog-spring-boot-starter/src/test/resources/junit-platform.properties new file mode 100644 index 0000000000..4ad8fb868b --- /dev/null +++ b/koog-spring-boot-starter/src/test/resources/junit-platform.properties @@ -0,0 +1,5 @@ +## https://docs.junit.org/5.3.0-M1/user-guide/index.html#writing-tests-parallel-execution +junit.jupiter.execution.parallel.enabled=true +junit.jupiter.execution.parallel.config.strategy=dynamic +junit.jupiter.execution.parallel.mode.classes.default=concurrent +junit.jupiter.execution.parallel.mode.default=concurrent diff --git a/utils/src/commonMain/kotlin/ai/koog/utils/lang/StringExtensions.kt b/utils/src/commonMain/kotlin/ai/koog/utils/lang/StringExtensions.kt new file mode 100644 index 0000000000..7ce21b654d --- /dev/null +++ b/utils/src/commonMain/kotlin/ai/koog/utils/lang/StringExtensions.kt @@ -0,0 +1,25 @@ +package ai.koog.utils.lang + +/** + * Returns string with masked symbols using the strictest security level. + * + * This method is designed for masking sensitive secrets and provides maximum security + * by returning a consistent pattern that reveals no information about the original content. + * + * Examples: + * - `null` -> `null` + * - `""` -> `null` + * - `" "` -> `null` + * - `"I"` -> `"***HIDDEN***"` + * - `"Hi"` -> `"***HIDDEN***"` + * - `"Hello"` -> `"***HIDDEN***"` + * - `"VeryLongSecretKey123"` -> `"***HIDDEN***"` + */ +public fun String?.masked( + maskChar: Char = '*', +): String? { + if (this.isNullOrBlank()) return null + + // Strictest security level: consistent pattern regardless of input + return "${maskChar}${maskChar}${maskChar}HIDDEN${maskChar}${maskChar}$maskChar" +} diff --git a/utils/src/commonTest/kotlin/ai/koog/utils/lang/StringExtensionsTest.kt b/utils/src/commonTest/kotlin/ai/koog/utils/lang/StringExtensionsTest.kt new file mode 100644 index 0000000000..3c69fb02e9 --- /dev/null +++ b/utils/src/commonTest/kotlin/ai/koog/utils/lang/StringExtensionsTest.kt @@ -0,0 +1,45 @@ +package ai.koog.utils.lang + +import kotlin.test.Test +import kotlin.test.assertEquals + +class StringExtensionsTest { + + @Test + fun `String masked should hide string contents`() { + // Test null input + assertEquals(null, null.masked()) + + // Test empty string + assertEquals(null, "".masked()) + + // Test blank string + assertEquals(null, " ".masked()) + + // Test single character - strict security returns consistent pattern + assertEquals("***HIDDEN***", "I".masked()) + + // Test two characters - strict security returns consistent pattern + assertEquals("***HIDDEN***", "Hi".masked()) + + // Test three characters - strict security returns consistent pattern + assertEquals("***HIDDEN***", "Hey".masked()) + + // Test longer string - strict security returns consistent pattern + assertEquals("***HIDDEN***", "Hello".masked()) + + // Test very long string - strict security returns consistent pattern + assertEquals("***HIDDEN***", "cccccclulbkucigkbggivvitngjdbfuhkevedrdukvcr".masked()) + + // Test custom mask character - should use custom char in pattern + assertEquals("---HIDDEN---", "Hello".masked(maskChar = '-')) + + // Test string with whitespace - should still return consistent pattern + assertEquals("***HIDDEN***", " Hello ".masked()) + + // Test sensitive data - all return same pattern for maximum security + assertEquals("***HIDDEN***", "password123".masked()) + assertEquals("***HIDDEN***", "api-key-secret".masked()) + assertEquals("***HIDDEN***", "x".masked()) + } +} From 45012a3b4ffbfadd23e2fa04daa400a81e1c81df Mon Sep 17 00:00:00 2001 From: Nicolas Frenay Date: Tue, 30 Sep 2025 20:37:46 -0300 Subject: [PATCH 03/52] fix: Google: use maxTokens from params (#734) --- .../clients/google/GoogleLLMClient.kt | 4 +- .../clients/google/GoogleLLMClientTest.kt | 40 +++++++++++++++++++ 2 files changed, 42 insertions(+), 2 deletions(-) create mode 100644 prompt/prompt-executor/prompt-executor-clients/prompt-executor-google-client/src/jvmTest/kotlin/ai/koog/prompt/executor/clients/google/GoogleLLMClientTest.kt diff --git a/prompt/prompt-executor/prompt-executor-clients/prompt-executor-google-client/src/commonMain/kotlin/ai/koog/prompt/executor/clients/google/GoogleLLMClient.kt b/prompt/prompt-executor/prompt-executor-clients/prompt-executor-google-client/src/commonMain/kotlin/ai/koog/prompt/executor/clients/google/GoogleLLMClient.kt index 8517a54464..a25bf2120a 100644 --- a/prompt/prompt-executor/prompt-executor-clients/prompt-executor-google-client/src/commonMain/kotlin/ai/koog/prompt/executor/clients/google/GoogleLLMClient.kt +++ b/prompt/prompt-executor/prompt-executor-clients/prompt-executor-google-client/src/commonMain/kotlin/ai/koog/prompt/executor/clients/google/GoogleLLMClient.kt @@ -273,7 +273,7 @@ public open class GoogleLLMClient( * @param tools Tools to include in the request * @return A formatted GoogleAI request */ - private fun createGoogleRequest(prompt: Prompt, model: LLModel, tools: List): GoogleRequest { + internal fun createGoogleRequest(prompt: Prompt, model: LLModel, tools: List): GoogleRequest { val systemMessageParts = mutableListOf() val contents = mutableListOf() val pendingCalls = mutableListOf() @@ -388,7 +388,7 @@ public open class GoogleLLMClient( responseJsonSchema = responseFormat?.responseJsonSchema, temperature = if (model.capabilities.contains(LLMCapability.Temperature)) prompt.params.temperature else null, candidateCount = if (model.capabilities.contains(LLMCapability.MultipleChoices)) prompt.params.numberOfChoices else null, - maxOutputTokens = 2048, + maxOutputTokens = prompt.params.maxTokens, thinkingConfig = GoogleThinkingConfig( includeThoughts = prompt.params.includeThoughts.takeIf { it == true }, thinkingBudget = prompt.params.thinkingBudget diff --git a/prompt/prompt-executor/prompt-executor-clients/prompt-executor-google-client/src/jvmTest/kotlin/ai/koog/prompt/executor/clients/google/GoogleLLMClientTest.kt b/prompt/prompt-executor/prompt-executor-clients/prompt-executor-google-client/src/jvmTest/kotlin/ai/koog/prompt/executor/clients/google/GoogleLLMClientTest.kt new file mode 100644 index 0000000000..d74f0cd053 --- /dev/null +++ b/prompt/prompt-executor/prompt-executor-clients/prompt-executor-google-client/src/jvmTest/kotlin/ai/koog/prompt/executor/clients/google/GoogleLLMClientTest.kt @@ -0,0 +1,40 @@ +package ai.koog.prompt.executor.clients.google + +import ai.koog.prompt.dsl.Prompt +import ai.koog.prompt.params.LLMParams +import kotlin.test.Test +import kotlin.test.assertEquals + +class GoogleLLMClientTest { + + @Test + fun `createGoogleRequest should use null maxTokens if unspecified`() { + val client = GoogleLLMClient(apiKey = "apiKey") + val model = GoogleModels.Gemini2_5Pro + val request = client.createGoogleRequest( + prompt = Prompt( + messages = emptyList(), + id = "id" + ), + model = model, + tools = emptyList() + ) + assertEquals(null, request.generationConfig!!.maxOutputTokens) + } + + @Test + fun `createGoogleRequest should use maxTokens from user specified parameters when available`() { + val client = GoogleLLMClient(apiKey = "apiKey") + val model = GoogleModels.Gemini2_5Pro + val request = client.createGoogleRequest( + prompt = Prompt( + messages = emptyList(), + id = "id", + params = LLMParams(maxTokens = 100) + ), + model = model, + tools = emptyList() + ) + assertEquals(100, request.generationConfig!!.maxOutputTokens) + } +} From 9a50bf0da34491267c8c84d2cec6b8c157bd7cd1 Mon Sep 17 00:00:00 2001 From: Inna Teteniuk Date: Wed, 1 Oct 2025 10:25:29 +0200 Subject: [PATCH 04/52] Updated documentation on functional agents (#881) Updated documentation on functional agents --- #### Type of the changes - [ ] New feature (non-breaking change which adds functionality) - [ ] Bug fix (non-breaking change which fixes an issue) - [ ] Breaking change (fix or feature that would cause existing functionality to change) - [x] Documentation update - [ ] Tests improvement - [ ] Refactoring #### Checklist - [ ] The pull request has a description of the proposed change - [ ] I read the [Contributing Guidelines](https://github.com/JetBrains/koog/blob/main/CONTRIBUTING.md) before opening the pull request - [ ] The pull request uses **`develop`** as the base branch - [ ] Tests for the changes have been added - [ ] All new and existing tests passed ##### Additional steps for pull requests adding a new feature - [ ] An issue describing the proposed change exists - [ ] The pull request includes a link to the issue - [ ] The change was discussed and approved in the issue - [ ] Docs have been added / updated --- docs/docs/act-ai-agent.md | 202 ------------------- docs/docs/complex-workflow-agents.md | 11 +- docs/docs/functional-agents.md | 280 +++++++++++++++++++++++++++ docs/docs/prompt-api.md | 4 +- docs/mkdocs.yml | 4 +- 5 files changed, 292 insertions(+), 209 deletions(-) delete mode 100644 docs/docs/act-ai-agent.md create mode 100644 docs/docs/functional-agents.md diff --git a/docs/docs/act-ai-agent.md b/docs/docs/act-ai-agent.md deleted file mode 100644 index 664fed43db..0000000000 --- a/docs/docs/act-ai-agent.md +++ /dev/null @@ -1,202 +0,0 @@ -# FunctionalAIAgent: How to build a single‑run agent step by step - -FunctionalAIAgent is a lightweight, non‑graph agent that you control with a simple loop. Use it when you want to: -- Call an LLM once or a few times in a custom loop; -- Optionally call tools between LLM turns; -- Return a final value (string, data class, etc.) without building a full strategy graph. - -What you’ll do in this guide: -1) Create a “Hello, World” FunctionalAIAgent. -2) Add a tool and let the agent call it. -3) Add a feature (event handler) to observe behavior. -4) Keep context under control with history compression. -5) Learn common recipes, pitfalls, and FAQs. - -## 1) Prerequisites -You need a PromptExecutor (the object that actually talks to your LLM). For local experimenting, you can use the Ollama executor: - -```kotlin -val exec = simpleOllamaAIExecutor() -``` - -You also need to pick a model, for example: - -```kotlin -val model = OllamaModels.Meta.LLAMA_3_2 -``` - -That’s it — we’ll inject both into the agent factory. - - -## 2) Your first agent (Hello, World) -Goal: Send the user’s text to the LLM and return a single assistant message as a string. - -```kotlin -val agent = functionalAIAgent( - prompt = "You are a helpful assistant.", - promptExecutor = exec, - model = model -) { input -> - val responses = requestLLMMultiple(input) - responses.single().asAssistantMessage().content -} - -val result = agent.run("Say hi in one sentence") -println(result) -``` - -What happens? -- requestLLMMultiple(input) sends the user input and receives one or more assistant messages. -- We return the only message’s content (typical one‑shot flow). - -Tip: If you want to return structured data, parse the content or use the Structured Data API. - - -## 3) Add tools (how the agent calls your functions) -Goal: Let the model operate a tiny device via tools. - -```kotlin -class Switch { - private var on = false - fun on() { on = true } - fun off() { on = false } - fun isOn() = on -} - -class SwitchTools(private val sw: Switch) { - fun turn_on() = run { sw.on(); "ok" } - fun turn_off() = run { sw.off(); "ok" } - fun state() = if (sw.isOn()) "on" else "off" -} - -val sw = Switch() -val tools = ToolRegistry { tools(SwitchTools(sw).asTools()) } - -val toolAgent = functionalAIAgent( - prompt = "You're responsible for running a Switch device and perform operations on it by request.", - promptExecutor = exec, - model = model, - toolRegistry = tools -) { input -> - var responses = requestLLMMultiple(input) - - while (responses.containsToolCalls()) { - val pending = extractToolCalls(responses) - val results = executeMultipleTools(pending) - responses = sendMultipleToolResults(results) - } - - responses.single().asAssistantMessage().content -} - -val out = toolAgent.run("Turn switch on") -println(out) -println("Switch is ${if (sw.isOn()) "on" else "off"}") -``` - -How it works -- containsToolCalls() detects tool call messages from the LLM. -- extractToolCalls(...) reads which tools to run and with what args. -- executeMultipleTools(...) runs them against your ToolRegistry. -- sendMultipleToolResults(...) sends results back to the LLM and gets the next response. - - -## 4) Observe behavior with features (EventHandler) -Goal: Print every tool call to the console. - -```kotlin -val observed = functionalAIAgent( - prompt = "...", - promptExecutor = exec, - model = model, - toolRegistry = tools, - featureContext = { - install(EventHandler) { - onToolCall { e -> println("Tool called: ${'$'}{e.tool.name}, args: ${'$'}{e.toolArgs}") } - } - } -) { input -> - var responses = requestLLMMultiple(input) - while (responses.containsToolCalls()) { - val pending = extractToolCalls(responses) - val results = executeMultipleTools(pending) - responses = sendMultipleToolResults(results) - } - responses.single().asAssistantMessage().content -} -``` - -Other features you can install this way include streaming tokens and tracing; see the related docs in the sidebar. - - -## 5) Keep context under control (history compression) -Long conversations can exceed the model’s context window. Use the token usage to decide when to compress history: - -```kotlin -var responses = requestLLMMultiple(input) - -while (responses.containsToolCalls()) { - if (latestTokenUsage() > 100_000) { - compressHistory() - } - val pending = extractToolCalls(responses) - val results = executeMultipleTools(pending) - responses = sendMultipleToolResults(results) -} -``` - -Use a threshold appropriate for your model and prompt size. - - -## Common recipes -- Return structured output - - Ask the LLM to format JSON and parse it; or use Structured Data API. -- Validate tool inputs - - Perform validation in tool functions and return clear error messages. -- One agent instance per request - - Each agent instance is single‑run at a time. Create new instances if you need concurrency. -- Custom Output type - - Change functionalAIAgent and return a data class from the loop. - - -## Troubleshooting & pitfalls -- “Agent is already running” - - FunctionalAIAgent prevents concurrent runs on the same instance. Don’t share one instance across parallel coroutines; create a fresh agent per run or await completion. -- Empty or unexpected model output - - Check your system prompt. Print intermediate responses. Consider adding few‑shot examples. -- Loop never ends - - Ensure you break when there are no tool calls; add guards/timeouts for safety. -- Context overflows - - Watch latestTokenUsage() and call compressHistory(). - - -## Reference (quick) -Constructors - -```kotlin -fun functionalAIAgent( - promptExecutor: PromptExecutor, - agentConfig: AIAgentConfigBase, - toolRegistry: ToolRegistry = ToolRegistry.EMPTY, - loop: suspend AIAgentFunctionalContext.(input: Input) -> Output -): AIAgent - -fun functionalAIAgent( - promptExecutor: PromptExecutor, - toolRegistry: ToolRegistry = ToolRegistry.EMPTY, - prompt: String = "", - model: LLModel = OpenAIModels.Chat.GPT4o, - featureContext: FeatureContext.() -> Unit = {}, - func: suspend AIAgentFunctionalContext.(input: Input) -> Output, -): AIAgent -``` - -Important types -- FunctionalAIAgent -- AIAgentFunctionalContext -- AIAgentConfig / AIAgentConfigBase -- PromptExecutor -- ToolRegistry -- FeatureContext and feature interfaces - -See source: agents/agents-core/src/commonMain/kotlin/ai/koog/agents/core/agent/FunctionalAIAgent.kt diff --git a/docs/docs/complex-workflow-agents.md b/docs/docs/complex-workflow-agents.md index 8fd84409e8..787c81a456 100644 --- a/docs/docs/complex-workflow-agents.md +++ b/docs/docs/complex-workflow-agents.md @@ -3,6 +3,9 @@ In addition to single-run agents, the `AIAgent` class lets you build agents that handle complex workflows by defining custom strategies, tools, configurations, and custom input/output types. +!!! tip + If you are new to Koog and want to create the simplest agent, start with [Single-run agents](single-run-agents.md). + The process of creating and configuring such an agent typically includes the following steps: 1. Provide a prompt executor to communicate with the LLM. @@ -20,7 +23,7 @@ The process of creating and configuring such an agent typically includes the fol Use environment variables or a secure configuration management system to store your API keys. Avoid hardcoding API keys directly in your source code. -## Creating a single-run agent +## Creating a complex workflow agent ### 1. Add dependencies @@ -39,7 +42,7 @@ For all available installation methods, see [Installation](index.md#installation Prompt executors manage and run prompts. You can choose a prompt executor based on the LLM provider you plan to use. Also, you can create a custom prompt executor using one of the available LLM clients. -To learn more, see [Prompt executors](prompt-api.md#prompt-executors). +To learn more, see [Prompt executors](prompt-api.md#running-prompts-with-prompt-executors). For example, to provide the OpenAI prompt executor, you need to call the `simpleOpenAIExecutor` function and provide it with the API key required for authentication with the OpenAI service: @@ -115,6 +118,7 @@ val processNode by node { input -> } ``` + !!! tip There are also pre-defined nodes that you can use in your agent strategy. To learn more, see [Predefined nodes and components](nodes-and-components.md). @@ -165,6 +169,7 @@ edge(sourceNode forwardTo targetNode transformed { output -> edge(sourceNode forwardTo targetNode onCondition { it.isNotEmpty() } transformed { it.uppercase() }) ``` + #### 3.2. Implement the strategy To implement the agent strategy, call the `strategy` function and define nodes and edges. For example: @@ -391,7 +396,7 @@ fun main() { ## Working with structured data -The `AIAgent` can process structured data from LLM outputs. For more details, see [Structured data processing](structured-data.md). +The `AIAgent` can process structured data from LLM outputs. For more details, see [Structured data processing](structured-output.md). ## Using parallel tool calls diff --git a/docs/docs/functional-agents.md b/docs/docs/functional-agents.md new file mode 100644 index 0000000000..cd1605f342 --- /dev/null +++ b/docs/docs/functional-agents.md @@ -0,0 +1,280 @@ +# Functional agents + +Functional agents are lightweight AI agents that operate without building complex strategy graphs. +Instead, the agent logic is implemented as a lambda function that handles user input, interacts with an LLM, +optionally calls tools, and produces a final output. It can perform a single LLM call, process multiple LLM calls in sequence, or loop based on user input, as well as LLM and tool outputs. + +!!! tip + - If you already have a simple [single-run agent](single-run-agents.md) as your first MVP, but run into task-specific limitations, use a functional agent to prototype custom logic. You can implement custom control flows in plain Kotlin while still using most Koog features, including history compression and automatic state management. + - For production-grade needs, refactor your functional agent into a [complex workflow agent](complex-workflow-agents.md) with strategy graphs. This provides persistence with controllable rollbacks for fault-tolerance and advanced OpenTelemetry tracing with nested graph events. + +This page guides you through the steps necessary to create a minimal functional agent and extend it with tools. + +## Prerequisites + +Before you start, make sure that you have the following: + +- A working Kotlin/JVM project with Gradle. +- Java 17+ installed. +- A valid API key from the LLM provider used to implement an AI agent. For a list of all available providers, refer to [Overview](index.md). +- (Optional) Ollama installed and running locally if you use this provider. + +!!! tip + Use environment variables or a secure configuration management system to store your API keys. + Avoid hardcoding API keys directly in your source code. + +## Add dependencies + +The `AIAgent` class is the main class for creating agents in Koog. +Include the following dependency in your build configuration to use the class functionality: + +``` +dependencies { + implementation("ai.koog:koog-agents:VERSION") +} +``` +For all available installation methods, see [Installation](index.md#installation). + +## Create a minimal functional agent + +To create a minimal functional agent, do the following: + +1. Choose the input and output types that the agent handles and create a corresponding `AIAgent` instance. + In this guide, we use `AIAgent`, which means the agent receives and returns `String`. +2. Provide the required parameters, including a system prompt, prompt executor, and LLM. +3. Define the agent logic with a lambda function wrapped into the `functionalStrategy {...}` DSL method. + +Here is an example of a minimal functional agent that sends user text to a specified LLM and returns a single assistant message. + + + +```kotlin +// Create an AIAgent instance and provide a system prompt, prompt executor, and LLM +val mathAgent = AIAgent( + systemPrompt = "You are a precise math assistant.", + promptExecutor = simpleOllamaAIExecutor(), + llmModel = OllamaModels.Meta.LLAMA_3_2, + strategy = functionalStrategy { input -> // Define the agent logic + // Make one LLM call + val response = requestLLM(input) + // Extract and return the assistant message content from the response + response.asAssistantMessage().content + } +) + +// Run the agent with a user input and print the result +val result = mathAgent.run("What is 12 × 9?") +println(result) +``` + + +The agent can produce the following output: + +``` +The answer to 12 × 9 is 108. +``` + +This agent makes a single LLM call and returns the assistant message content. +You can extend the agent logic to handle multiple sequential LLM calls. For example: + + + +```kotlin +// Create an AIAgent instance and provide a system prompt, prompt executor, and LLM +val mathAgent = AIAgent( + systemPrompt = "You are a precise math assistant.", + promptExecutor = simpleOllamaAIExecutor(), + llmModel = OllamaModels.Meta.LLAMA_3_2, + strategy = functionalStrategy { input -> // Define the agent logic + // The first LLM call to produce an initial draft based on the user input + val draft = requestLLM("Draft: $input").asAssistantMessage().content + // The second LLM call to improve the draft by prompting the LLM again with the draft content + val improved = requestLLM("Improve and clarify: $draft").asAssistantMessage().content + // The final LLM call to format the improved text and return the final formatted result + requestLLM("Format the result as bold: $improved").asAssistantMessage().content + } +) + +// Run the agent with a user input and print the result +val result = mathAgent.run("What is 12 × 9?") +println(result) +``` + + +The agent can produce the following output: + +``` +When multiplying 12 by 9, we can break it down as follows: + +**12 (tens) × 9 = 108** + +Alternatively, we can also use the distributive property to calculate this: + +**(10 + 2) × 9** += **10 × 9 + 2 × 9** += **90 + 18** += **108** +``` + +## Add tools + +In many cases, a functional agent needs to complete specific tasks, such as reading and writing data or calling APIs. +In Koog, you expose such capabilities as tools and let the LLM call them in the agent logic. + +This chapter takes the minimal functional agent created above and demonstrates how to extend the agent logic using tools. + + +1) Create an annotation-based tool. For more details, see [Annotation-based tools](annotation-based-tools.md). + + +```kotlin +@LLMDescription("Simple multiplier") +class MathTools : ToolSet { + @Tool + @LLMDescription("Multiplies two numbers and returns the result") + fun multiply(a: Int, b: Int): Int { + val result = a * b + return result + } +} +``` + + +To learn more about available tools, refer to the [Tool overview](tools-overview.md). + +2) Register the tool to make it available to the agent. + + + +```kotlin +val toolRegistry = ToolRegistry { + tools(MathTools()) +} +``` + + +3) Pass the tool registry to the agent to enable the LLM to request and use the available tools. + +4) Extend the agent logic to identify tool calls, execute the requested tools, send their results back to the LLM, and repeat the process until no tool calls remain. + +!!! note + Use a loop only if the LLM continues to issue tool calls. + + + +```kotlin +val mathWithTools = AIAgent( + systemPrompt = "You are a precise math assistant. When multiplication is needed, use the multiplication tool.", + promptExecutor = simpleOllamaAIExecutor(), + llmModel = OllamaModels.Meta.LLAMA_3_2, + toolRegistry = toolRegistry, + strategy = functionalStrategy { input -> // Define the agent logic extended with tool calls + // Send the user input to the LLM + var responses = requestLLMMultiple(input) + + // Only loop while the LLM requests tools + while (responses.containsToolCalls()) { + // Extract tool calls from the response + val pendingCalls = extractToolCalls(responses) + // Execute the tools and return the results + val results = executeMultipleTools(pendingCalls) + // Send the tool results back to the LLM. The LLM may call more tools or return a final output + responses = sendMultipleToolResults(results) + } + + // When no tool calls remain, extract and return the assistant message content from the response + responses.single().asAssistantMessage().content + } +) + +// Run the agent with a user input and print the result +val reply = mathWithTools.run("Please multiply 12.5 and 4, then add 10 to the result.") +println(reply) +``` + + +The agent can produce the following output: + +``` +Here is the step-by-step solution: + +1. Multiply 12.5 and 4: + 12.5 × 4 = 50 + +2. Add 10 to the result: + 50 + 10 = 60 +``` + +## What's next + +- Learn how to return structured data using the [Structured output API](structured-output.md). +- Experiment with adding more [tools](tools-overview.md) to the agent. +- Improve observability with the [EventHandler](agent-events.md) feature. +- Learn how to handle long-running conversations with [History compression](history-compression.md). diff --git a/docs/docs/prompt-api.md b/docs/docs/prompt-api.md index 297c87118f..a500a19b46 100644 --- a/docs/docs/prompt-api.md +++ b/docs/docs/prompt-api.md @@ -222,7 +222,7 @@ To choose between clients and executors, consider the following factors: - Use LLM clients directly if you work with a single LLM provider and do not require advanced lifecycle management. To learm more, see [Running prompts with LLM clients](#running-prompts-with-llm-clients). - Use prompt executors if you need a higher level of abstraction for managing LLMs and their lifecycle, or if you want to run prompts with a consistent API across multiple providers and dynamically switch between them. - To learn more, see [Runnning prompts with prompt executors](#running-prompts-with-executors). + To learn more, see [Runnning prompts with prompt executors](#running-prompts-with-prompt-executors). !!!note Both the LLM clients and prompt executors let you stream responses, generate multiple choices, and run content moderation. @@ -408,7 +408,7 @@ For faster setup, Koog provides the following ready-to-use executor implementati - `simpleAnthropicExecutor` for executing prompts with Anthropic models. - `simpleGoogleAIExecutor` for executing prompts with Google models. - `simpleOpenRouterExecutor` for executing prompts with OpenRouter. - - `simpleOllamaExecutor` for executing prompts with Ollama. + - `simpleOllamaAIExecutor` for executing prompts with Ollama. - Multi-provider executor: - `DefaultMultiLLMPromptExecutor` which is an implementation of `MultiLLMPromptExecutor` that supports OpenAI, Anthropic, and Google providers. diff --git a/docs/mkdocs.yml b/docs/mkdocs.yml index 99140cfad1..7593012fa0 100644 --- a/docs/mkdocs.yml +++ b/docs/mkdocs.yml @@ -6,8 +6,8 @@ nav: - Key concepts: key-concepts.md - Getting started: - Single-run agents: single-run-agents.md - - Act Agent API: act-ai-agent.md - Complex workflow agents: complex-workflow-agents.md + - Functional agents: functional-agents.md - Prompt API: prompt-api.md - Tools: - Overview: tools-overview.md @@ -98,7 +98,7 @@ plugins: - key-concepts.md: This page provides explanations of key terms and concepts related to Koog and agentic development. Getting started: - single-run-agents.md: This guide lets you quickly build a single-run agent with minimum required configuration. - - act-ai-agent.md: This guide shows you how to build a lightweight, non‑graph agent that you control with a simple loop. + - functional-agents.md: This guide shows you how to build a lightweight, non‑graph agent that you control with a simple loop. - complex-workflow-agents.md: This page explains how you can create agents that handle complex workflows by defining custom strategies, tools, configurations, and custom input and output types. Prompt API: - prompt-api.md: This page includes detailed instructions about the use of Prompt API, which provides a comprehensive toolkit for interacting with Large Language Models in production applications. From 58ba445e9348d0de4fbfb2fbb6ca0d4781cdb22b Mon Sep 17 00:00:00 2001 From: Anastasiia Zarechneva <49490937+aozherelyeva@users.noreply.github.com> Date: Wed, 1 Oct 2025 10:49:30 +0200 Subject: [PATCH 05/52] Add integration tests for executing tools with primitive types combinations (#889) ## Motivation and Context As a follow-up to our last week discussion, it's good to test different input/output combinations in tool descriptors after the merge of #791. ## Breaking Changes None. --- #### Type of the changes - [ ] New feature (non-breaking change which adds functionality) - [ ] Bug fix (non-breaking change which fixes an issue) - [ ] Breaking change (fix or feature that would cause existing functionality to change) - [ ] Documentation update - [x] Tests improvement - [ ] Refactoring #### Checklist - [x] The pull request has a description of the proposed change - [x] I read the [Contributing Guidelines](https://github.com/JetBrains/koog/blob/main/CONTRIBUTING.md) before opening the pull request - [x] The pull request uses **`develop`** as the base branch - [x] Tests for the changes have been added - [x] All new and existing tests passed ##### Additional steps for pull requests adding a new feature - [x] An issue describing the proposed change exists - [x] The pull request includes a link to the issue - [x] The change was discussed and approved in the issue - [ ] Docs have been added / updated --- .../executor/ToolDescriptorIntegrationTest.kt | 351 ++++++++++++++++++ 1 file changed, 351 insertions(+) create mode 100644 integration-tests/src/jvmTest/kotlin/ai/koog/integration/tests/executor/ToolDescriptorIntegrationTest.kt diff --git a/integration-tests/src/jvmTest/kotlin/ai/koog/integration/tests/executor/ToolDescriptorIntegrationTest.kt b/integration-tests/src/jvmTest/kotlin/ai/koog/integration/tests/executor/ToolDescriptorIntegrationTest.kt new file mode 100644 index 0000000000..d3926a908c --- /dev/null +++ b/integration-tests/src/jvmTest/kotlin/ai/koog/integration/tests/executor/ToolDescriptorIntegrationTest.kt @@ -0,0 +1,351 @@ +package ai.koog.integration.tests.executor + +import ai.koog.agents.core.tools.Tool +import ai.koog.integration.tests.utils.Models +import ai.koog.integration.tests.utils.RetryUtils.withRetry +import ai.koog.integration.tests.utils.TestUtils +import ai.koog.prompt.dsl.prompt +import ai.koog.prompt.executor.clients.anthropic.AnthropicLLMClient +import ai.koog.prompt.executor.clients.anthropic.AnthropicModels +import ai.koog.prompt.executor.clients.bedrock.BedrockModels +import ai.koog.prompt.executor.clients.google.GoogleLLMClient +import ai.koog.prompt.executor.clients.google.GoogleModels +import ai.koog.prompt.executor.clients.openai.OpenAILLMClient +import ai.koog.prompt.executor.clients.openai.OpenAIModels +import ai.koog.prompt.executor.clients.openrouter.OpenRouterModels +import ai.koog.prompt.llm.LLMCapability +import ai.koog.prompt.llm.LLMProvider +import ai.koog.prompt.llm.LLModel +import ai.koog.prompt.message.Message +import ai.koog.prompt.params.LLMParams +import ai.koog.prompt.params.LLMParams.ToolChoice +import kotlinx.coroutines.test.runTest +import kotlinx.serialization.KSerializer +import kotlinx.serialization.builtins.serializer +import org.junit.jupiter.api.Assumptions.assumeTrue +import org.junit.jupiter.params.ParameterizedTest +import org.junit.jupiter.params.provider.Arguments +import org.junit.jupiter.params.provider.MethodSource +import java.util.stream.Stream +import kotlin.test.assertTrue +import kotlin.time.Duration.Companion.seconds + +class ToolDescriptorIntegrationTest { + + enum class ToolName(val value: String, val displayName: String, val testUserMessage: String) { + INT_TO_STRING( + "int_to_string", + "Tool", + "Convert the number 42 to its string representation using the tool." + ), + STRING_TO_INT("string_to_int", "Tool", "Get the length of the string 'hello' using the tool."), + INT_TO_INT("int_to_int", "Tool", "Double the number 21 using the tool."), + STRING_TO_STRING( + "string_to_string", + "Tool", + "Convert 'hello world' to uppercase using the tool." + ), + BOOLEAN_TO_STRING( + "boolean_to_string", + "Tool", + "Convert the boolean value true to its string representation using the tool." + ), + STRING_TO_BOOLEAN( + "string_to_boolean", + "Tool", + "Convert the string 'true' to a boolean using the tool." + ), + DOUBLE_TO_INT( + "double_to_int", + "Tool", + "Convert the double value 3.7 to an integer using the tool." + ), + INT_TO_DOUBLE("int_to_double", "Tool", "Convert the integer value 42 to a double using the tool."), + LONG_TO_DOUBLE( + "long_to_double", + "Tool", + "Convert the long value 100 to a double with decimal places using the tool." + ), + DOUBLE_TO_LONG( + "double_to_long", + "Tool", + "Convert the double value 15.8 to a long using the tool." + ), + FLOAT_TO_BOOLEAN( + "float_to_boolean", + "Tool", + "Convert the float value 2.5 to a boolean using the tool." + ), + BOOLEAN_TO_FLOAT( + "boolean_to_float", + "Tool", + "Convert the boolean value true to a float using the tool." + ), + LONG_TO_INT("long_to_int", "Tool", "Convert the long value 12345 to an integer using the tool."), + INT_TO_LONG("int_to_long", "Tool", "Convert the integer value 789 to a long using the tool."), + FLOAT_TO_STRING( + "float_to_string", + "Tool", + "Convert the float value 3.14 to its string representation using the tool." + ), + STRING_TO_FLOAT( + "string_to_float", + "Tool", + "Convert the string 'hello' to a float based on its length using the tool." + ), + DOUBLE_TO_STRING( + "double_to_string", + "Tool", + "Convert the double value 2.718 to its string representation using the tool." + ), + STRING_TO_DOUBLE( + "string_to_double", + "Tool", + "Convert the string 'world' to a double based on its length using the tool." + ); + + override fun toString(): String = displayName + } + + companion object { + @JvmStatic + fun allModels(): Stream { + return Stream.of( + OpenAIModels.CostOptimized.GPT4_1Mini, + AnthropicModels.Sonnet_3_7, + GoogleModels.Gemini2_5Flash, + BedrockModels.AnthropicClaude35Haiku, + OpenRouterModels.Mistral7B, + ) + } + + @JvmStatic + fun primitiveToolAndModelCombinations(): Stream { + val primitiveTools = listOf( + IntToStringTool(), + StringToIntTool(), + IntToIntTool(), + StringToStringTool(), + BooleanToStringTool(), + StringToBooleanTool(), + DoubleToIntTool(), + IntToDoubleTool(), + LongToDoubleTool(), + DoubleToLongTool(), + FloatToBooleanTool(), + BooleanToFloatTool(), + LongToIntTool(), + IntToLongTool(), + FloatToStringTool(), + StringToFloatTool(), + DoubleToStringTool(), + StringToDoubleTool() + ) + + return allModels().flatMap { model -> + primitiveTools.map { tool -> + Arguments.arguments(tool, model) + }.stream() + } + } + } + + abstract class TestTool : Tool() { + abstract val toolName: ToolName + override val name: String get() = toolName.value + override fun toString(): String = toolName.displayName + } + + class IntToStringTool : TestTool() { + override val toolName = ToolName.INT_TO_STRING + override val argsSerializer: KSerializer = Int.serializer() + override val resultSerializer: KSerializer = String.serializer() + override val description: String = "Converts an integer to its string representation" + + override suspend fun execute(args: Int): String = "Number: $args" + } + + class StringToIntTool : TestTool() { + override val toolName = ToolName.STRING_TO_INT + override val argsSerializer: KSerializer = String.serializer() + override val resultSerializer: KSerializer = Int.serializer() + override val description: String = "Converts a string to an integer" + + override suspend fun execute(args: String): Int = args.length + } + + class IntToIntTool : TestTool() { + override val toolName = ToolName.INT_TO_INT + override val argsSerializer: KSerializer = Int.serializer() + override val resultSerializer: KSerializer = Int.serializer() + override val description: String = "Doubles an integer value" + + override suspend fun execute(args: Int): Int = args * 2 + } + + class StringToStringTool : TestTool() { + override val toolName = ToolName.STRING_TO_STRING + override val argsSerializer: KSerializer = String.serializer() + override val resultSerializer: KSerializer = String.serializer() + override val description: String = "Converts string to uppercase" + + override suspend fun execute(args: String): String = args.uppercase() + } + + class BooleanToStringTool : TestTool() { + override val toolName = ToolName.BOOLEAN_TO_STRING + override val argsSerializer: KSerializer = Boolean.serializer() + override val resultSerializer: KSerializer = String.serializer() + override val description: String = "Converts boolean to descriptive string" + + override suspend fun execute(args: Boolean): String = if (args) "TRUE_VALUE" else "FALSE_VALUE" + } + + class DoubleToIntTool : TestTool() { + override val toolName = ToolName.DOUBLE_TO_INT + override val argsSerializer: KSerializer = Double.serializer() + override val resultSerializer: KSerializer = Int.serializer() + override val description: String = "Converts double to integer by rounding" + + override suspend fun execute(args: Double): Int = args.toInt() + } + + class LongToDoubleTool : TestTool() { + override val toolName = ToolName.LONG_TO_DOUBLE + override val argsSerializer: KSerializer = Long.serializer() + override val resultSerializer: KSerializer = Double.serializer() + override val description: String = "Converts long to double with decimal places" + + override suspend fun execute(args: Long): Double = args + 0.5 + } + + class FloatToBooleanTool : TestTool() { + override val toolName = ToolName.FLOAT_TO_BOOLEAN + override val argsSerializer: KSerializer = Float.serializer() + override val resultSerializer: KSerializer = Boolean.serializer() + override val description: String = "Converts float to boolean (positive = true)" + + override suspend fun execute(args: Float): Boolean = args > 0f + } + + class StringToBooleanTool : TestTool() { + override val toolName = ToolName.STRING_TO_BOOLEAN + override val argsSerializer: KSerializer = String.serializer() + override val resultSerializer: KSerializer = Boolean.serializer() + override val description: String = "Converts string to boolean ('true' = true, others = false)" + + override suspend fun execute(args: String): Boolean = args.equals("true", ignoreCase = true) + } + + class IntToDoubleTool : TestTool() { + override val toolName = ToolName.INT_TO_DOUBLE + override val argsSerializer: KSerializer = Int.serializer() + override val resultSerializer: KSerializer = Double.serializer() + override val description: String = "Converts integer to double" + + override suspend fun execute(args: Int): Double = args.toDouble() + } + + class DoubleToLongTool : TestTool() { + override val toolName = ToolName.DOUBLE_TO_LONG + override val argsSerializer: KSerializer = Double.serializer() + override val resultSerializer: KSerializer = Long.serializer() + override val description: String = "Converts double to long by rounding" + + override suspend fun execute(args: Double): Long = args.toLong() + } + + class BooleanToFloatTool : TestTool() { + override val toolName = ToolName.BOOLEAN_TO_FLOAT + override val argsSerializer: KSerializer = Boolean.serializer() + override val resultSerializer: KSerializer = Float.serializer() + override val description: String = "Converts boolean to float (true = 1.0f, false = 0.0f)" + + override suspend fun execute(args: Boolean): Float = if (args) 1.0f else 0.0f + } + + class LongToIntTool : TestTool() { + override val toolName = ToolName.LONG_TO_INT + override val argsSerializer: KSerializer = Long.serializer() + override val resultSerializer: KSerializer = Int.serializer() + override val description: String = "Converts long to integer" + + override suspend fun execute(args: Long): Int = args.toInt() + } + + class IntToLongTool : TestTool() { + override val toolName = ToolName.INT_TO_LONG + override val argsSerializer: KSerializer = Int.serializer() + override val resultSerializer: KSerializer = Long.serializer() + override val description: String = "Converts integer to long" + + override suspend fun execute(args: Int): Long = args.toLong() + } + + class FloatToStringTool : TestTool() { + override val toolName = ToolName.FLOAT_TO_STRING + override val argsSerializer: KSerializer = Float.serializer() + override val resultSerializer: KSerializer = String.serializer() + override val description: String = "Converts float to string" + + override suspend fun execute(args: Float): String = "Float: $args" + } + + class StringToFloatTool : TestTool() { + override val toolName = ToolName.STRING_TO_FLOAT + override val argsSerializer: KSerializer = String.serializer() + override val resultSerializer: KSerializer = Float.serializer() + override val description: String = "Converts string length to float" + + override suspend fun execute(args: String): Float = args.length.toFloat() + } + + class DoubleToStringTool : TestTool() { + override val toolName = ToolName.DOUBLE_TO_STRING + override val argsSerializer: KSerializer = Double.serializer() + override val resultSerializer: KSerializer = String.serializer() + override val description: String = "Converts double to string" + + override suspend fun execute(args: Double): String = "Double: $args" + } + + class StringToDoubleTool : TestTool() { + override val toolName = ToolName.STRING_TO_DOUBLE + override val argsSerializer: KSerializer = String.serializer() + override val resultSerializer: KSerializer = Double.serializer() + override val description: String = "Converts string length to double" + + override suspend fun execute(args: String): Double = args.length.toDouble() + } + + @ParameterizedTest(name = "{0} with {1}") + @MethodSource("primitiveToolAndModelCombinations") + fun integration_testPrimitiveTools(tool: Tool<*, *>, model: LLModel) = runTest(timeout = 300.seconds) { + Models.assumeAvailable(model.provider) + assumeTrue(model.capabilities.contains(LLMCapability.Tools), "Model $model does not support tools") + + val client = when (model.provider) { + is LLMProvider.Anthropic -> AnthropicLLMClient(TestUtils.readTestAnthropicKeyFromEnv()) + is LLMProvider.Google -> GoogleLLMClient(TestUtils.readTestGoogleAIKeyFromEnv()) + else -> OpenAILLMClient(TestUtils.readTestOpenAIKeyFromEnv()) + } + + val testTool = tool as TestTool<*, *> + val prompt = prompt(testTool.toolName.value, params = LLMParams(toolChoice = ToolChoice.Required)) { + system("You are a helpful assistant with access to tools. ALWAYS use the available tool.") + user(testTool.toolName.testUserMessage) + } + + withRetry { + val response = client.execute(prompt, model, listOf(tool.descriptor)) + assertTrue(response.isNotEmpty(), "Response should not be empty for tool ${tool.name} with model $model") + val hasToolCall = response.any { message -> + message is Message.Tool.Call && message.tool == tool.name + } + assertTrue( + hasToolCall, + "Response should contain a Tool.Call message for tool '${tool.name}' with model $model." + ) + } + } +} From 5394f6a2acf55140b17cc40c07bbc58903155fd3 Mon Sep 17 00:00:00 2001 From: Ruben Cagnie Date: Wed, 1 Oct 2025 05:39:08 -0400 Subject: [PATCH 06/52] Support tool calling strategy in structured output (#829) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Motivation and Context Add new strategy `structuredOutputWithToolsStrategy` that combines structured output capabilities with tool execution, enabling agents to generate structured responses while leveraging external tools. This implementation includes: - New `nodeSetStructuredOutput` node for configuring structured output in agent graphs - `parseResponseToStructuredResponse` method in AIAgentLLMSession for parsing structured responses - Support for both sequential and parallel tool execution - Example demonstrating weather forecast generation with tools and structured output The strategy follows a pipeline: set structured output → execute input → call LLM → handle tool calls or transform structured response → return final output. Fixes #400 ## Breaking Changes N/A --- #### Type of the changes - [x] New feature (non-breaking change which adds functionality) - [ ] Bug fix (non-breaking change which fixes an issue) - [ ] Breaking change (fix or feature that would cause existing functionality to change) - [ ] Documentation update - [ ] Tests improvement - [ ] Refactoring #### Checklist - [x] The pull request has a description of the proposed change - [x] I read the [Contributing Guidelines](https://github.com/JetBrains/koog/blob/main/CONTRIBUTING.md) before opening the pull request - [x] The pull request uses **`develop`** as the base branch - [x] Tests for the changes have been added - [x] All new and existing tests passed ##### Additional steps for pull requests adding a new feature - [x] An issue describing the proposed change exists - [x] The pull request includes a link to the issue - [ ] The change was discussed and approved in the issue - [x] Docs have been added / updated --------- Co-authored-by: Ruben Cagnie --- .../core/agent/session/AIAgentLLMSession.kt | 19 ++ .../agents/core/dsl/extension/AIAgentNodes.kt | 23 ++ .../AIAgentLLMSessionStructuredOutputTest.kt | 222 ++++++++++++ .../core/dsl/extension/AIAgentNodesTest.kt | 98 ++++++ .../agents/ext/agent/AIAgentStrategies.kt | 79 +++++ .../agents/ext/agent/AIAgentStrategiesTest.kt | 153 +++++++++ ...tructuredOutputWithToolsIntegrationTest.kt | 323 ++++++++++++++++++ .../AdvancedWithStandardSchema.kt | 199 +---------- .../AdvancedWithStandardSchemaAndTools.kt | 135 ++++++++ .../models/FullWeatherForecast.kt | 195 +++++++++++ .../structuredoutput/tools/WeatherTools.kt | 26 ++ .../structure/PromptExecutorExtensions.kt | 121 +++++-- 12 files changed, 1367 insertions(+), 226 deletions(-) create mode 100644 agents/agents-core/src/commonTest/kotlin/ai/koog/agents/core/agent/session/AIAgentLLMSessionStructuredOutputTest.kt create mode 100644 agents/agents-ext/src/commonTest/kotlin/ai/koog/agents/ext/agent/StructuredOutputWithToolsIntegrationTest.kt create mode 100644 examples/simple-examples/src/main/kotlin/ai/koog/agents/example/structuredoutput/AdvancedWithStandardSchemaAndTools.kt create mode 100644 examples/simple-examples/src/main/kotlin/ai/koog/agents/example/structuredoutput/models/FullWeatherForecast.kt create mode 100644 examples/simple-examples/src/main/kotlin/ai/koog/agents/example/structuredoutput/tools/WeatherTools.kt diff --git a/agents/agents-core/src/commonMain/kotlin/ai/koog/agents/core/agent/session/AIAgentLLMSession.kt b/agents/agents-core/src/commonMain/kotlin/ai/koog/agents/core/agent/session/AIAgentLLMSession.kt index 296b2df3e6..c86535b79e 100644 --- a/agents/agents-core/src/commonMain/kotlin/ai/koog/agents/core/agent/session/AIAgentLLMSession.kt +++ b/agents/agents-core/src/commonMain/kotlin/ai/koog/agents/core/agent/session/AIAgentLLMSession.kt @@ -15,6 +15,7 @@ import ai.koog.prompt.structure.StructureFixingParser import ai.koog.prompt.structure.StructuredOutputConfig import ai.koog.prompt.structure.StructuredResponse import ai.koog.prompt.structure.executeStructured +import ai.koog.prompt.structure.parseResponseToStructuredResponse import kotlinx.serialization.KSerializer import kotlinx.serialization.serializer @@ -312,6 +313,24 @@ public sealed class AIAgentLLMSession( fixingParser = fixingParser, ) + /** + * Parses a structured response from the language model using the specified configuration. + * + * This function takes a response message and a structured output configuration, + * parses the response content based on the defined structure, and returns + * a structured response containing the parsed data and the original message. + * + * @param response The response message from the language model that contains the content to be parsed. + * The message is expected to match the defined structured output. + * @param config The configuration defining the expected structure and additional parsing behavior. + * It includes options such as structure definitions and optional parsers for error handling. + * @return A structured response containing the parsed data of type `T` along with the original message. + */ + public suspend fun parseResponseToStructuredResponse( + response: Message.Assistant, + config: StructuredOutputConfig + ): StructuredResponse = executor.parseResponseToStructuredResponse(response, config, model) + /** * Sends a request to the language model, potentially receiving multiple choices, * and returns a list of choices from the model. diff --git a/agents/agents-core/src/commonMain/kotlin/ai/koog/agents/core/dsl/extension/AIAgentNodes.kt b/agents/agents-core/src/commonMain/kotlin/ai/koog/agents/core/dsl/extension/AIAgentNodes.kt index 89ff0437f4..4bc9bf6561 100644 --- a/agents/agents-core/src/commonMain/kotlin/ai/koog/agents/core/dsl/extension/AIAgentNodes.kt +++ b/agents/agents-core/src/commonMain/kotlin/ai/koog/agents/core/dsl/extension/AIAgentNodes.kt @@ -507,3 +507,26 @@ public inline fun AIAgentSubgraphBuilderBase< toolResult } } + +/** + * Creates a node that sets up a structured output for an AI agent subgraph. + * + * The method defines a new node with a configurable structured output schema + * that will be applied during the AI agent's message processing. The schema + * is determined by the given configuration. + * + * @param name An optional name for the node. If null, a default name will be assigned. + * @param config The configuration that defines the structured output format and schema. + * @return An instance of [AIAgentNodeDelegate] representing the constructed node. + */ +@AIAgentBuilderDslMarker +public inline fun AIAgentSubgraphBuilderBase<*, *>.nodeSetStructuredOutput( + name: String? = null, + config: StructuredOutputConfig +): AIAgentNodeDelegate = + node(name) { message -> + llm.writeSession { + prompt = config.updatePrompt(model, prompt) + message + } + } diff --git a/agents/agents-core/src/commonTest/kotlin/ai/koog/agents/core/agent/session/AIAgentLLMSessionStructuredOutputTest.kt b/agents/agents-core/src/commonTest/kotlin/ai/koog/agents/core/agent/session/AIAgentLLMSessionStructuredOutputTest.kt new file mode 100644 index 0000000000..dd24d3f6a3 --- /dev/null +++ b/agents/agents-core/src/commonTest/kotlin/ai/koog/agents/core/agent/session/AIAgentLLMSessionStructuredOutputTest.kt @@ -0,0 +1,222 @@ +package ai.koog.agents.core.agent.session + +import ai.koog.agents.core.CalculatorChatExecutor.testClock +import ai.koog.agents.core.agent.context.AIAgentLLMContext +import ai.koog.agents.core.agent.context.AgentTestBase +import ai.koog.agents.core.tools.annotations.LLMDescription +import ai.koog.agents.testing.tools.getMockExecutor +import ai.koog.agents.testing.tools.mockLLMAnswer +import ai.koog.prompt.dsl.prompt +import ai.koog.prompt.executor.clients.openai.OpenAIModels +import ai.koog.prompt.message.Message +import ai.koog.prompt.message.ResponseMetaInfo +import ai.koog.prompt.structure.StructuredOutput +import ai.koog.prompt.structure.StructuredOutputConfig +import ai.koog.prompt.structure.json.JsonStructuredData +import kotlinx.coroutines.test.runTest +import kotlinx.serialization.Serializable +import kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.assertNotNull + +class AIAgentLLMSessionStructuredOutputTest : AgentTestBase() { + + @Serializable + data class TestStructure( + @property:LLMDescription("The name field") + val name: String, + @property:LLMDescription("The age field") + val age: Int, + @property:LLMDescription("Optional description field") + val description: String? = null + ) + + @Test + fun testParseResponseToStructuredResponse() = runTest { + val structure = JsonStructuredData.createJsonStructure() + val config = StructuredOutputConfig( + default = StructuredOutput.Manual(structure) + ) + + val mockExecutor = getMockExecutor { + mockLLMAnswer("Test response").asDefaultResponse + } + + val prompt = prompt("test") { + system("Test system message") + } + + val llmContext = AIAgentLLMContext( + tools = emptyList(), + prompt = prompt, + model = OpenAIModels.CostOptimized.GPT4oMini, + promptExecutor = mockExecutor, + environment = createTestEnvironment(), + config = createTestConfig(), + clock = testClock + ) + + val context = createTestContext( + llmContext = llmContext + ) + + val jsonResponse = """ + { + "name": "John Doe", + "age": 30, + "description": "A test person" + } + """.trimIndent() + + val assistantMessage = Message.Assistant( + content = jsonResponse, + metaInfo = ResponseMetaInfo.create(testClock) + ) + + val result = context.llm.writeSession { + parseResponseToStructuredResponse(assistantMessage, config) + } + + assertNotNull(result) + assertEquals("John Doe", result.structure.name) + assertEquals(30, result.structure.age) + assertEquals("A test person", result.structure.description) + assertEquals(assistantMessage, result.message) + } + + @Test + fun testParseResponseToStructuredResponseWithNullableField() = runTest { + val structure = JsonStructuredData.createJsonStructure() + val config = StructuredOutputConfig( + default = StructuredOutput.Manual(structure) + ) + + val mockExecutor = getMockExecutor { + mockLLMAnswer("Test response").asDefaultResponse + } + + val prompt = prompt("test") { + system("Test system message") + } + + val llmContext = AIAgentLLMContext( + tools = emptyList(), + prompt = prompt, + model = OpenAIModels.CostOptimized.GPT4oMini, + promptExecutor = mockExecutor, + environment = createTestEnvironment(), + config = createTestConfig(), + clock = testClock + ) + + val context = createTestContext( + llmContext = llmContext + ) + + val jsonResponse = """ + { + "name": "Jane Doe", + "age": 25 + } + """.trimIndent() + + val assistantMessage = Message.Assistant( + content = jsonResponse, + metaInfo = ResponseMetaInfo.create(testClock) + ) + + val result = context.llm.writeSession { + parseResponseToStructuredResponse(assistantMessage, config) + } + + assertNotNull(result) + assertEquals("Jane Doe", result.structure.name) + assertEquals(25, result.structure.age) + assertEquals(null, result.structure.description) + assertEquals(assistantMessage, result.message) + } + + @Test + fun testParseResponseToStructuredResponseComplexStructure() = runTest { + @Serializable + data class Address( + @property:LLMDescription("Street name") + val street: String, + @property:LLMDescription("City name") + val city: String + ) + + @Serializable + data class ComplexStructure( + @property:LLMDescription("User identifier") + val id: Int, + @property:LLMDescription("List of addresses") + val addresses: List
, + @property:LLMDescription("Tags associated with the user") + val tags: Set + ) + + val structure = JsonStructuredData.createJsonStructure() + val config = StructuredOutputConfig( + default = StructuredOutput.Manual(structure) + ) + + val mockExecutor = getMockExecutor { + mockLLMAnswer("Test response").asDefaultResponse + } + + val prompt = prompt("test") { + system("Test system message") + } + + val llmContext = AIAgentLLMContext( + tools = emptyList(), + prompt = prompt, + model = OpenAIModels.CostOptimized.GPT4oMini, + promptExecutor = mockExecutor, + environment = createTestEnvironment(), + config = createTestConfig(), + clock = testClock + ) + + val context = createTestContext( + llmContext = llmContext + ) + + val jsonResponse = """ + { + "id": 123, + "addresses": [ + { + "street": "123 Main St", + "city": "New York" + }, + { + "street": "456 Oak Ave", + "city": "Los Angeles" + } + ], + "tags": ["vip", "premium", "verified"] + } + """.trimIndent() + + val assistantMessage = Message.Assistant( + content = jsonResponse, + metaInfo = ResponseMetaInfo.create(testClock) + ) + + val result = context.llm.writeSession { + parseResponseToStructuredResponse(assistantMessage, config) + } + + assertNotNull(result) + assertEquals(123, result.structure.id) + assertEquals(2, result.structure.addresses.size) + assertEquals("123 Main St", result.structure.addresses[0].street) + assertEquals("New York", result.structure.addresses[0].city) + assertEquals("456 Oak Ave", result.structure.addresses[1].street) + assertEquals("Los Angeles", result.structure.addresses[1].city) + assertEquals(setOf("vip", "premium", "verified"), result.structure.tags) + assertEquals(assistantMessage, result.message) + } +} diff --git a/agents/agents-core/src/commonTest/kotlin/ai/koog/agents/core/dsl/extension/AIAgentNodesTest.kt b/agents/agents-core/src/commonTest/kotlin/ai/koog/agents/core/dsl/extension/AIAgentNodesTest.kt index 461118a47e..eecf5baf70 100644 --- a/agents/agents-core/src/commonTest/kotlin/ai/koog/agents/core/dsl/extension/AIAgentNodesTest.kt +++ b/agents/agents-core/src/commonTest/kotlin/ai/koog/agents/core/dsl/extension/AIAgentNodesTest.kt @@ -9,12 +9,18 @@ import ai.koog.agents.features.eventHandler.feature.EventHandler import ai.koog.agents.testing.tools.DummyTool import ai.koog.agents.testing.tools.getMockExecutor import ai.koog.agents.testing.tools.mockLLMAnswer +import ai.koog.prompt.dsl.Prompt import ai.koog.prompt.dsl.prompt import ai.koog.prompt.executor.clients.openai.OpenAIModels import ai.koog.prompt.llm.OllamaModels +import ai.koog.prompt.structure.StructuredOutput +import ai.koog.prompt.structure.StructuredOutputConfig +import ai.koog.prompt.structure.json.JsonStructuredData import kotlinx.coroutines.test.runTest +import kotlinx.serialization.Serializable import kotlin.test.Test import kotlin.test.assertEquals +import kotlin.test.assertNotNull import kotlin.test.assertTrue class AIAgentNodesTest { @@ -137,4 +143,96 @@ class AIAgentNodesTest { "Should have at least 3 execution events (agent finished, node transitions)" ) } + + @Test + fun testNodeSetStructuredOutput() = runTest { + @Serializable + data class TestOutput( + val message: String, + val code: Int + ) + + // Test Manual mode + val manualStructure = JsonStructuredData.createJsonStructure() + val manualConfig = StructuredOutputConfig( + default = StructuredOutput.Manual(manualStructure) + ) + + var capturedPrompt: Prompt? = null + + val manualStrategy = strategy("test-manual") { + val setStructuredOutput by nodeSetStructuredOutput(config = manualConfig) + val checkPrompt by node { input -> + capturedPrompt = llm.prompt + input + } + + edge(nodeStart forwardTo setStructuredOutput) + edge(setStructuredOutput forwardTo checkPrompt) + edge(checkPrompt forwardTo nodeFinish) + } + + val testExecutor = getMockExecutor { + mockLLMAnswer("Test").asDefaultResponse + } + + val agentConfig = AIAgentConfig( + prompt = prompt("test") {}, + model = OpenAIModels.CostOptimized.GPT4oMini, + maxAgentIterations = 5 + ) + + val manualAgent = AIAgent( + promptExecutor = testExecutor, + strategy = manualStrategy, + agentConfig = agentConfig, + toolRegistry = ToolRegistry { } + ) + + manualAgent.run("Test input") + + // Manual mode: schema should not be set, user message should be added + assertNotNull(capturedPrompt, "Prompt should be captured") + assertEquals(null, capturedPrompt!!.params.schema, "Schema should not be set for Manual config") + assertTrue( + capturedPrompt!!.messages.any { it is ai.koog.prompt.message.Message.User }, + "Should have user message with instructions for Manual config" + ) + + // Test Native mode + val nativeStructure = JsonStructuredData.createJsonStructure() + val nativeConfig = StructuredOutputConfig( + default = StructuredOutput.Native(nativeStructure) + ) + + val nativeStrategy = strategy("test-native") { + val setStructuredOutput by nodeSetStructuredOutput(config = nativeConfig) + val checkPrompt by node { input -> + capturedPrompt = llm.prompt + input + } + + edge(nodeStart forwardTo setStructuredOutput) + edge(setStructuredOutput forwardTo checkPrompt) + edge(checkPrompt forwardTo nodeFinish) + } + + val nativeAgent = AIAgent( + promptExecutor = testExecutor, + strategy = nativeStrategy, + agentConfig = AIAgentConfig( + prompt = prompt("test") {}, + model = OpenAIModels.CostOptimized.GPT4oMini, + maxAgentIterations = 5 + ), + toolRegistry = ToolRegistry { } + ) + + nativeAgent.run("Test input") + + // Native mode: schema should be set + assertNotNull(capturedPrompt, "Prompt should be captured") + assertNotNull(capturedPrompt!!.params.schema, "Schema should be set for Native config") + assertEquals(nativeStructure.schema, capturedPrompt!!.params.schema, "Schema should match structure's schema") + } } diff --git a/agents/agents-ext/src/commonMain/kotlin/ai/koog/agents/ext/agent/AIAgentStrategies.kt b/agents/agents-ext/src/commonMain/kotlin/ai/koog/agents/ext/agent/AIAgentStrategies.kt index 6952929553..70bf252071 100644 --- a/agents/agents-ext/src/commonMain/kotlin/ai/koog/agents/ext/agent/AIAgentStrategies.kt +++ b/agents/agents-ext/src/commonMain/kotlin/ai/koog/agents/ext/agent/AIAgentStrategies.kt @@ -1,17 +1,25 @@ package ai.koog.agents.ext.agent +import ai.koog.agents.core.agent.context.AIAgentGraphContextBase import ai.koog.agents.core.agent.entity.AIAgentGraphStrategy import ai.koog.agents.core.agent.entity.createStorageKey import ai.koog.agents.core.dsl.builder.forwardTo import ai.koog.agents.core.dsl.builder.strategy +import ai.koog.agents.core.dsl.extension.nodeExecuteMultipleTools import ai.koog.agents.core.dsl.extension.nodeExecuteTool import ai.koog.agents.core.dsl.extension.nodeLLMRequest +import ai.koog.agents.core.dsl.extension.nodeLLMRequestMultiple +import ai.koog.agents.core.dsl.extension.nodeLLMSendMultipleToolResults import ai.koog.agents.core.dsl.extension.nodeLLMSendToolResult +import ai.koog.agents.core.dsl.extension.nodeSetStructuredOutput import ai.koog.agents.core.dsl.extension.onAssistantMessage +import ai.koog.agents.core.dsl.extension.onMultipleAssistantMessages +import ai.koog.agents.core.dsl.extension.onMultipleToolCalls import ai.koog.agents.core.dsl.extension.onToolCall import ai.koog.agents.core.environment.ReceivedToolResult import ai.koog.agents.core.environment.result import ai.koog.prompt.message.Message +import ai.koog.prompt.structure.StructuredOutputConfig // FIXME improve this strategy to use Message.Assistant to chat, it works better than tools @@ -164,3 +172,74 @@ public fun reActStrategy( edge(nodeExecuteTool forwardTo nodeCallLLMReason) edge(nodeCallLLMReason forwardTo nodeCallLLM) } + +/** + * Defines a strategy for handling structured output with tools integration using specified configuration and execution logic. + * + * This strategy facilitates a structured pipeline for generating outputs using tools and large language models (LLMs), + * enabling transformations between input, intermediate results, and structured output based on the provided configuration and execution behavior. + * + * @param Output The type of the structured output generated by the strategy. + * @param config The configuration for structured output processing, specifying schema, providers, and optional error handling mechanisms. + */ +public inline fun structuredOutputWithToolsStrategy( + config: StructuredOutputConfig, + parallelTools: Boolean = false +): AIAgentGraphStrategy = structuredOutputWithToolsStrategy( + config, + parallelTools +) { it } + +/** + * Defines a strategy for handling structured output with tools integration using specified configuration and execution logic. + * + * This strategy facilitates a structured pipeline for generating outputs using tools and large language models (LLMs), + * enabling transformations between input, intermediate results, and structured output based on the provided configuration and execution behavior. + * + * @param Input The type of the input to be processed by the strategy. + * @param Output The type of the structured output generated by the strategy. + * @param config The configuration for structured output processing, specifying schema, providers, and optional error handling mechanisms. + * @param transform A suspendable function that accepts the input of type `Input` and produces a string output + * that serves as the input for further processing in the structured output pipeline. + */ +public inline fun structuredOutputWithToolsStrategy( + config: StructuredOutputConfig, + parallelTools: Boolean = false, + noinline transform: suspend AIAgentGraphContextBase.(input: Input) -> String +): AIAgentGraphStrategy = strategy("structured_output_with_tools_strategy") { + val setStructuredOutput by nodeSetStructuredOutput(config = config) + val transformInput by node { transform(it) } + val callLLM by nodeLLMRequestMultiple() + val executeTools by nodeExecuteMultipleTools(parallelTools = parallelTools) + val sendToolResult by nodeLLMSendMultipleToolResults() + val transformToStructuredOutput by node { response -> + llm.writeSession { + parseResponseToStructuredResponse(response, config).structure + } + } + + // Set the structured output, get the input and then call the llm + nodeStart then setStructuredOutput then transformInput then callLLM + + // On tools + edge(callLLM forwardTo executeTools onMultipleToolCalls { true }) + edge(executeTools forwardTo sendToolResult) + + // On assistant messages + edge( + callLLM forwardTo transformToStructuredOutput + onMultipleAssistantMessages { true } + transformed { it.single() } + ) + + // Post tool result + edge(sendToolResult forwardTo executeTools onMultipleToolCalls { true }) + edge( + sendToolResult forwardTo transformToStructuredOutput + onMultipleAssistantMessages { true } + transformed { it.first() } + ) + + // Finish + transformToStructuredOutput then nodeFinish +} diff --git a/agents/agents-ext/src/commonTest/kotlin/ai/koog/agents/ext/agent/AIAgentStrategiesTest.kt b/agents/agents-ext/src/commonTest/kotlin/ai/koog/agents/ext/agent/AIAgentStrategiesTest.kt index fc2a0cce9d..6bbfd73103 100644 --- a/agents/agents-ext/src/commonTest/kotlin/ai/koog/agents/ext/agent/AIAgentStrategiesTest.kt +++ b/agents/agents-ext/src/commonTest/kotlin/ai/koog/agents/ext/agent/AIAgentStrategiesTest.kt @@ -1,9 +1,15 @@ package ai.koog.agents.ext.agent +import ai.koog.agents.core.tools.annotations.LLMDescription +import ai.koog.prompt.structure.StructuredOutput +import ai.koog.prompt.structure.StructuredOutputConfig +import ai.koog.prompt.structure.json.JsonStructuredData import kotlinx.coroutines.test.runTest +import kotlinx.serialization.Serializable import kotlin.test.Test import kotlin.test.assertEquals import kotlin.test.assertFailsWith +import kotlin.test.assertNotNull class AIAgentStrategiesTest { private val defaultName = "re_act" @@ -43,4 +49,151 @@ class AIAgentStrategiesTest { reActStrategy(reasoningInterval = -1) } } + + @Test + fun testStructuredOutputWithToolsStrategyDefaultName() = runTest { + @Serializable + data class TestOutput( + @property:LLMDescription("Test field") + val field: String + ) + + val structure = JsonStructuredData.createJsonStructure() + val config = StructuredOutputConfig( + default = StructuredOutput.Manual(structure) + ) + + val strategy = structuredOutputWithToolsStrategy(config) { input -> + "Processed: $input" + } + + assertEquals("structured_output_with_tools_strategy", strategy.name) + } + + @Test + fun testStructuredOutputWithToolsStrategyWithParallelTools() = runTest { + @Serializable + data class TestResult( + @property:LLMDescription("Result message") + val message: String, + @property:LLMDescription("Success status") + val success: Boolean + ) + + val structure = JsonStructuredData.createJsonStructure() + val config = StructuredOutputConfig( + default = StructuredOutput.Manual(structure) + ) + + val strategyWithParallel = structuredOutputWithToolsStrategy( + config = config, + parallelTools = true + ) { input -> + "Processing with parallel tools: $input" + } + + assertNotNull(strategyWithParallel) + assertEquals("structured_output_with_tools_strategy", strategyWithParallel.name) + + val strategyWithoutParallel = structuredOutputWithToolsStrategy( + config = config, + parallelTools = false + ) { input -> + "Processing without parallel tools: $input" + } + + assertNotNull(strategyWithoutParallel) + assertEquals("structured_output_with_tools_strategy", strategyWithoutParallel.name) + } + + @Test + fun testStructuredOutputWithToolsStrategyComplexTypes() = runTest { + @Serializable + data class Address( + @property:LLMDescription("Street address") + val street: String, + @property:LLMDescription("City") + val city: String, + @property:LLMDescription("ZIP code") + val zipCode: String + ) + + @Serializable + data class ComplexOutput( + @property:LLMDescription("User ID") + val id: Int, + @property:LLMDescription("User name") + val name: String, + @property:LLMDescription("User addresses") + val addresses: List
, + @property:LLMDescription("User preferences") + val preferences: Map? = null + ) + + @Serializable + data class ComplexInput( + val userId: String, + val requestType: String + ) + + val structure = JsonStructuredData.createJsonStructure() + val config = StructuredOutputConfig( + default = StructuredOutput.Manual(structure) + ) + + val strategy = structuredOutputWithToolsStrategy(config) { input -> + "Fetch user data for ID: ${input.userId}, type: ${input.requestType}" + } + + assertNotNull(strategy) + assertEquals("structured_output_with_tools_strategy", strategy.name) + } + + @Test + fun testStructuredOutputWithToolsStrategyDifferentConfigs() = runTest { + @Serializable + data class SimpleOutput( + @property:LLMDescription("Value") + val value: String + ) + + val manualStructure = JsonStructuredData.createJsonStructure() + val nativeStructure = JsonStructuredData.createJsonStructure() + + // Test with manual mode + val manualConfig = StructuredOutputConfig( + default = StructuredOutput.Manual(manualStructure) + ) + + val manualStrategy = structuredOutputWithToolsStrategy(manualConfig) { input -> + "Manual mode: $input" + } + + assertNotNull(manualStrategy) + + // Test with native mode + val nativeConfig = StructuredOutputConfig( + default = StructuredOutput.Native(nativeStructure) + ) + + val nativeStrategy = structuredOutputWithToolsStrategy(nativeConfig) { input -> + "Native mode: $input" + } + + assertNotNull(nativeStrategy) + + // Test with both modes in config + val mixedConfig = StructuredOutputConfig( + default = StructuredOutput.Manual(manualStructure), + byProvider = mapOf( + ai.koog.prompt.llm.LLMProvider.OpenAI to StructuredOutput.Native(nativeStructure) + ) + ) + + val mixedStrategy = structuredOutputWithToolsStrategy(mixedConfig) { input -> + "Mixed mode: $input" + } + + assertNotNull(mixedStrategy) + } } diff --git a/agents/agents-ext/src/commonTest/kotlin/ai/koog/agents/ext/agent/StructuredOutputWithToolsIntegrationTest.kt b/agents/agents-ext/src/commonTest/kotlin/ai/koog/agents/ext/agent/StructuredOutputWithToolsIntegrationTest.kt new file mode 100644 index 0000000000..c007ece354 --- /dev/null +++ b/agents/agents-ext/src/commonTest/kotlin/ai/koog/agents/ext/agent/StructuredOutputWithToolsIntegrationTest.kt @@ -0,0 +1,323 @@ +package ai.koog.agents.ext.agent + +import ai.koog.agents.core.agent.AIAgent +import ai.koog.agents.core.agent.config.AIAgentConfig +import ai.koog.agents.core.tools.SimpleTool +import ai.koog.agents.core.tools.ToolArgs +import ai.koog.agents.core.tools.ToolDescriptor +import ai.koog.agents.core.tools.ToolParameterDescriptor +import ai.koog.agents.core.tools.ToolParameterType +import ai.koog.agents.core.tools.ToolRegistry +import ai.koog.agents.core.tools.annotations.LLMDescription +import ai.koog.agents.features.eventHandler.feature.EventHandler +import ai.koog.agents.testing.tools.getMockExecutor +import ai.koog.agents.testing.tools.mockLLMAnswer +import ai.koog.prompt.dsl.prompt +import ai.koog.prompt.executor.clients.openai.OpenAIModels +import ai.koog.prompt.structure.StructuredOutput +import ai.koog.prompt.structure.StructuredOutputConfig +import ai.koog.prompt.structure.json.JsonStructuredData +import kotlinx.coroutines.test.runTest +import kotlinx.serialization.KSerializer +import kotlinx.serialization.Serializable +import kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.assertNotNull + +class StructuredOutputWithToolsIntegrationTest { + + @Serializable + data class WeatherRequest( + val city: String, + val country: String + ) + + @Serializable + data class WeatherResponse( + @property:LLMDescription("Temperature in Celsius") + val temperature: Int, + @property:LLMDescription("Weather conditions") + val conditions: String, + @property:LLMDescription("Wind speed in km/h") + val windSpeed: Double, + @property:LLMDescription("Humidity percentage") + val humidity: Int + ) + + object GetTemperatureTool : SimpleTool() { + @Serializable + data class Args(val city: String, val country: String) : ToolArgs + + override val argsSerializer: KSerializer = Args.serializer() + + override val descriptor: ToolDescriptor = ToolDescriptor( + name = "get_temperature", + description = "Get current temperature for a city", + requiredParameters = listOf( + ToolParameterDescriptor("city", "City name", ToolParameterType.String), + ToolParameterDescriptor("country", "Country name", ToolParameterType.String) + ) + ) + + override suspend fun doExecute(args: Args): String = + "Temperature in ${args.city}, ${args.country}: 22°C" + } + + object GetWeatherConditionsTool : SimpleTool() { + @Serializable + data class Args(val city: String, val country: String) : ToolArgs + + override val argsSerializer: KSerializer = Args.serializer() + + override val descriptor: ToolDescriptor = ToolDescriptor( + name = "get_weather_conditions", + description = "Get current weather conditions for a city", + requiredParameters = listOf( + ToolParameterDescriptor("city", "City name", ToolParameterType.String), + ToolParameterDescriptor("country", "Country name", ToolParameterType.String) + ) + ) + + override suspend fun doExecute(args: Args): String = + "Weather conditions in ${args.city}, ${args.country}: Partly Cloudy" + } + + object GetWindSpeedTool : SimpleTool() { + @Serializable + data class Args(val city: String, val country: String) : ToolArgs + + override val argsSerializer: KSerializer = Args.serializer() + + override val descriptor: ToolDescriptor = ToolDescriptor( + name = "get_wind_speed", + description = "Get current wind speed for a city", + requiredParameters = listOf( + ToolParameterDescriptor("city", "City name", ToolParameterType.String), + ToolParameterDescriptor("country", "Country name", ToolParameterType.String) + ) + ) + + override suspend fun doExecute(args: Args): String = + "Wind speed in ${args.city}, ${args.country}: 15.5 km/h" + } + + object GetHumidityTool : SimpleTool() { + @Serializable + data class Args(val city: String, val country: String) : ToolArgs + + override val argsSerializer: KSerializer = Args.serializer() + + override val descriptor: ToolDescriptor = ToolDescriptor( + name = "get_humidity", + description = "Get current humidity for a city", + requiredParameters = listOf( + ToolParameterDescriptor("city", "City name", ToolParameterType.String), + ToolParameterDescriptor("country", "Country name", ToolParameterType.String) + ) + ) + + override suspend fun doExecute(args: Args): String = + "Humidity in ${args.city}, ${args.country}: 65%" + } + + @Test + fun testStructuredOutputWithToolsIntegration() = runTest { + val structure = JsonStructuredData.createJsonStructure() + val config = StructuredOutputConfig( + default = StructuredOutput.Manual(structure) + ) + + val strategy = structuredOutputWithToolsStrategy( + config = config, + parallelTools = false + ) { request -> + "Get complete weather data for ${request.city}, ${request.country}" + } + + val toolCallEvents = mutableListOf() + val results = mutableListOf() + + // For common tests, we need to use a simpler mock setup + val mockExecutor = getMockExecutor { + // Simply return the structured output directly + mockLLMAnswer( + """ + { + "temperature": 22, + "conditions": "Partly Cloudy", + "windSpeed": 15.5, + "humidity": 65 + } + """.trimIndent() + ).asDefaultResponse + } + + val agentConfig = AIAgentConfig( + prompt = prompt("weather-agent") { + system( + """ + You are a weather assistant. Use the available tools to gather weather data + and return a complete weather report in the specified JSON format. + """.trimIndent() + ) + }, + model = OpenAIModels.CostOptimized.GPT4oMini, + maxAgentIterations = 10 + ) + + val agent = AIAgent( + promptExecutor = mockExecutor, + strategy = strategy, + agentConfig = agentConfig, + toolRegistry = ToolRegistry { + tool(GetTemperatureTool) + tool(GetWeatherConditionsTool) + tool(GetWindSpeedTool) + tool(GetHumidityTool) + } + ) { + install(EventHandler) { + onToolCall { eventContext -> + toolCallEvents.add(eventContext.tool.name) + } + onAgentFinished { eventContext -> + eventContext.result?.let { results.add(it as WeatherResponse) } + } + } + } + + val request = WeatherRequest(city = "New York", country = "USA") + val result = agent.run(request) + + assertNotNull(result) + assertEquals(22, result.temperature) + assertEquals("Partly Cloudy", result.conditions) + assertEquals(15.5, result.windSpeed) + assertEquals(65, result.humidity) + } + + @Test + fun testStructuredOutputWithToolsParallelExecution() = runTest { + val structure = JsonStructuredData.createJsonStructure() + val config = StructuredOutputConfig( + default = StructuredOutput.Manual(structure) + ) + + val strategy = structuredOutputWithToolsStrategy( + config = config, + parallelTools = true // Enable parallel tool execution + ) { request -> + "Get all weather metrics simultaneously for ${request.city}, ${request.country}" + } + + val toolCallTimestamps = mutableMapOf() + val currentTime = kotlinx.datetime.Clock.System.now().toEpochMilliseconds() + + val mockExecutor = getMockExecutor { + // Return structured output + mockLLMAnswer( + """ + { + "temperature": 18, + "conditions": "Rainy", + "windSpeed": 20.0, + "humidity": 80 + } + """.trimIndent() + ).asDefaultResponse + } + + val agentConfig = AIAgentConfig( + prompt = prompt("weather-agent-parallel") { + system( + """ + You are a weather assistant. Gather all weather metrics in parallel + and return a complete weather report. + """.trimIndent() + ) + }, + model = OpenAIModels.CostOptimized.GPT4oMini, + maxAgentIterations = 10 + ) + + val agent = AIAgent( + promptExecutor = mockExecutor, + strategy = strategy, + agentConfig = agentConfig, + toolRegistry = ToolRegistry { + tool(GetTemperatureTool) + tool(GetWeatherConditionsTool) + tool(GetWindSpeedTool) + tool(GetHumidityTool) + } + ) { + install(EventHandler) { + onToolCall { eventContext -> + toolCallTimestamps[eventContext.tool.name] = currentTime + } + } + } + + val request = WeatherRequest(city = "London", country = "UK") + val result = agent.run(request) + + assertNotNull(result) + assertEquals(18, result.temperature) + assertEquals("Rainy", result.conditions) + assertEquals(20.0, result.windSpeed) + assertEquals(80, result.humidity) + } + + @Test + fun testStructuredOutputWithNoTools() = runTest { + val structure = JsonStructuredData.createJsonStructure() + val config = StructuredOutputConfig( + default = StructuredOutput.Manual(structure) + ) + + val strategy = structuredOutputWithToolsStrategy( + config = config + ) { request -> + "Generate mock weather data for ${request.city}, ${request.country}" + } + + val mockExecutor = getMockExecutor { + // LLM directly returns structured output without calling tools + // Set as default response to match any request + mockLLMAnswer( + """ + { + "temperature": 25, + "conditions": "Sunny", + "windSpeed": 10.0, + "humidity": 50 + } + """.trimIndent() + ).asDefaultResponse + } + + val agentConfig = AIAgentConfig( + prompt = prompt("weather-agent-no-tools") { + system("Generate weather data without using tools.") + }, + model = OpenAIModels.CostOptimized.GPT4oMini, + maxAgentIterations = 10 + ) + + val agent = AIAgent( + promptExecutor = mockExecutor, + strategy = strategy, + agentConfig = agentConfig, + toolRegistry = ToolRegistry { } // No tools registered + ) + + val request = WeatherRequest(city = "Paris", country = "France") + val result = agent.run(request) + + assertNotNull(result) + assertEquals(25, result.temperature) + assertEquals("Sunny", result.conditions) + assertEquals(10.0, result.windSpeed) + assertEquals(50, result.humidity) + } +} diff --git a/examples/simple-examples/src/main/kotlin/ai/koog/agents/example/structuredoutput/AdvancedWithStandardSchema.kt b/examples/simple-examples/src/main/kotlin/ai/koog/agents/example/structuredoutput/AdvancedWithStandardSchema.kt index e3b47998fb..b56be37acc 100644 --- a/examples/simple-examples/src/main/kotlin/ai/koog/agents/example/structuredoutput/AdvancedWithStandardSchema.kt +++ b/examples/simple-examples/src/main/kotlin/ai/koog/agents/example/structuredoutput/AdvancedWithStandardSchema.kt @@ -5,8 +5,9 @@ import ai.koog.agents.core.agent.config.AIAgentConfig import ai.koog.agents.core.dsl.builder.forwardTo import ai.koog.agents.core.dsl.builder.strategy import ai.koog.agents.core.dsl.extension.nodeLLMRequestStructured -import ai.koog.agents.core.tools.annotations.LLMDescription import ai.koog.agents.example.ApiKeyService +import ai.koog.agents.example.structuredoutput.models.FullWeatherForecast +import ai.koog.agents.example.structuredoutput.models.FullWeatherForecastRequest import ai.koog.agents.features.eventHandler.feature.handleEvents import ai.koog.prompt.dsl.prompt import ai.koog.prompt.executor.clients.anthropic.AnthropicLLMClient @@ -25,203 +26,13 @@ import ai.koog.prompt.structure.json.JsonStructuredData import ai.koog.prompt.structure.json.generator.StandardJsonSchemaGenerator import ai.koog.prompt.text.text import kotlinx.coroutines.runBlocking -import kotlinx.serialization.SerialName -import kotlinx.serialization.Serializable import kotlinx.serialization.json.Json -/** - * This is a more advanced example showing how to configure various parameters of structured output manually, to fine-tune - * it for your needs when necessary. - * - * Structured output that uses "full" JSON schema. - * More advanced features are supported, e.g. polymorphism and recursive type references, and schemas can be more complex. - */ - -@Serializable -@SerialName("FullWeatherForecast") -@LLMDescription("Weather forecast for a given location") -data class FullWeatherForecast( - @property:LLMDescription("Temperature in Celsius") - val temperature: Int, - // properties with default values - @property:LLMDescription("Weather conditions (e.g., sunny, cloudy, rainy)") - val conditions: String = "sunny", - // nullable properties - @property:LLMDescription("Chance of precipitation in percentage") - val precipitation: Int?, - // nested classes - @property:LLMDescription("Coordinates of the location") - val latLon: LatLon, - // enums - val pollution: Pollution, - // polymorphism - val alert: WeatherAlert, - // lists - @property:LLMDescription("List of news articles") - val news: List, -// // maps (string keys only, some providers don't support maps at all) -// @property:LLMDescription("Map of weather sources") -// val sources: Map -) { - // Nested classes - @Serializable - @SerialName("LatLon") - data class LatLon( - @property:LLMDescription("Latitude of the location") - val lat: Double, - @property:LLMDescription("Longitude of the location") - val lon: Double - ) - - // Nested classes in lists... - @Serializable - @SerialName("WeatherNews") - data class WeatherNews( - @property:LLMDescription("Title of the news article") - val title: String, - @property:LLMDescription("Link to the news article") - val link: String - ) - - // ... and maps (but only with string keys!) - @Suppress("unused") - @Serializable - @SerialName("WeatherSource") - data class WeatherSource( - @property:LLMDescription("Name of the weather station") - val stationName: String, - @property:LLMDescription("Authority of the weather station") - val stationAuthority: String - ) - - // Enums - @Suppress("unused") - @SerialName("Pollution") - @Serializable - enum class Pollution { - @SerialName("None") - None, - - @SerialName("LOW") - Low, - - @SerialName("MEDIUM") - Medium, - - @SerialName("HIGH") - High - } - - /* - Polymorphism: - 1. Closed with sealed classes, - 2. Open: non-sealed classes with subclasses registered in json config - https://github.com/Kotlin/kotlinx.serialization/blob/master/docs/polymorphism.md#registered-subclasses - */ - @Suppress("unused") - @Serializable - @SerialName("WeatherAlert") - sealed class WeatherAlert { - abstract val severity: Severity - abstract val message: String - - @Serializable - @SerialName("Severity") - enum class Severity { Low, Moderate, Severe, Extreme } - - @Serializable - @SerialName("StormAlert") - data class StormAlert( - override val severity: Severity, - override val message: String, - @property:LLMDescription("Wind speed in km/h") - val windSpeed: Double - ) : WeatherAlert() - - @Serializable - @SerialName("FloodAlert") - data class FloodAlert( - override val severity: Severity, - override val message: String, - @property:LLMDescription("Expected rainfall in mm") - val expectedRainfall: Double - ) : WeatherAlert() - - @Serializable - @SerialName("TemperatureAlert") - data class TemperatureAlert( - override val severity: Severity, - override val message: String, - @property:LLMDescription("Temperature threshold in Celsius") - val threshold: Int, // in Celsius - @property:LLMDescription("Whether the alert is a heat warning") - val isHeatWarning: Boolean - ) : WeatherAlert() - } -} - -data class FullWeatherForecastRequest( - val city: String, - val country: String -) - private val json = Json { prettyPrint = true } fun main(): Unit = runBlocking { - // Optional examples, to help LLM understand the format better in manual mode - val exampleForecasts = listOf( - FullWeatherForecast( - temperature = 18, - conditions = "Cloudy", - precipitation = 30, - latLon = FullWeatherForecast.LatLon(lat = 34.0522, lon = -118.2437), - pollution = FullWeatherForecast.Pollution.Medium, - alert = FullWeatherForecast.WeatherAlert.StormAlert( - severity = FullWeatherForecast.WeatherAlert.Severity.Moderate, - message = "Possible thunderstorms in the evening", - windSpeed = 45.5 - ), - news = listOf( - FullWeatherForecast.WeatherNews(title = "Local news", link = "https://example.com/news"), - FullWeatherForecast.WeatherNews(title = "Global news", link = "https://example.com/global-news") - ), -// sources = mapOf( -// "MeteorologicalWatch" to FullWeatherForecast.WeatherSource( -// stationName = "MeteorologicalWatch", -// stationAuthority = "US Department of Agriculture" -// ), -// "MeteorologicalWatch2" to FullWeatherForecast.WeatherSource( -// stationName = "MeteorologicalWatch2", -// stationAuthority = "US Department of Agriculture" -// ) -// ) - ), - FullWeatherForecast( - temperature = 10, - conditions = "Rainy", - precipitation = null, - latLon = FullWeatherForecast.LatLon(lat = 37.7739, lon = -122.4194), - pollution = FullWeatherForecast.Pollution.Low, - alert = FullWeatherForecast.WeatherAlert.FloodAlert( - severity = FullWeatherForecast.WeatherAlert.Severity.Severe, - message = "Heavy rainfall may cause local flooding", - expectedRainfall = 75.2 - ), - news = listOf( - FullWeatherForecast.WeatherNews(title = "Local news", link = "https://example.com/news"), - FullWeatherForecast.WeatherNews(title = "Global news", link = "https://example.com/global-news") - ), -// sources = mapOf( -// "MeteorologicalWatch" to WeatherForecast.WeatherSource( -// stationName = "MeteorologicalWatch", -// stationAuthority = "US Department of Agriculture" -// ), -// ) - ) - ) - /* This structure has a generic schema that is suitable for manual structured output mode. But to use native structured output support in different LLM providers you might need to use custom JSON schema generators @@ -230,7 +41,7 @@ fun main(): Unit = runBlocking { val genericWeatherStructure = JsonStructuredData.createJsonStructure( // Some models might not work well with json schema, so you may try simple, but it has more limitations (no polymorphism!) schemaGenerator = StandardJsonSchemaGenerator, - examples = exampleForecasts, + examples = FullWeatherForecast.exampleForecasts, ) println("Generated generic JSON schema:\n${json.encodeToString(genericWeatherStructure.schema.schema)}") @@ -241,12 +52,12 @@ fun main(): Unit = runBlocking { val openAiWeatherStructure = JsonStructuredData.createJsonStructure( schemaGenerator = OpenAIStandardJsonSchemaGenerator, - examples = exampleForecasts, + examples = FullWeatherForecast.exampleForecasts, ) val googleWeatherStructure = JsonStructuredData.createJsonStructure( schemaGenerator = GoogleStandardJsonSchemaGenerator, - examples = exampleForecasts, + examples = FullWeatherForecast.exampleForecasts, ) val agentStrategy = strategy("advanced-full-weather-forecast") { diff --git a/examples/simple-examples/src/main/kotlin/ai/koog/agents/example/structuredoutput/AdvancedWithStandardSchemaAndTools.kt b/examples/simple-examples/src/main/kotlin/ai/koog/agents/example/structuredoutput/AdvancedWithStandardSchemaAndTools.kt new file mode 100644 index 0000000000..df9d30c5c9 --- /dev/null +++ b/examples/simple-examples/src/main/kotlin/ai/koog/agents/example/structuredoutput/AdvancedWithStandardSchemaAndTools.kt @@ -0,0 +1,135 @@ +package ai.koog.agents.example.structuredoutput + +import ai.koog.agents.core.agent.AIAgent +import ai.koog.agents.core.agent.config.AIAgentConfig +import ai.koog.agents.core.tools.ToolRegistry +import ai.koog.agents.core.tools.reflect.asTools +import ai.koog.agents.example.ApiKeyService +import ai.koog.agents.example.structuredoutput.models.FullWeatherForecast +import ai.koog.agents.example.structuredoutput.models.FullWeatherForecastRequest +import ai.koog.agents.example.structuredoutput.tools.WeatherTools +import ai.koog.agents.ext.agent.structuredOutputWithToolsStrategy +import ai.koog.agents.features.eventHandler.feature.handleEvents +import ai.koog.prompt.dsl.prompt +import ai.koog.prompt.executor.clients.anthropic.AnthropicLLMClient +import ai.koog.prompt.executor.clients.anthropic.AnthropicModels +import ai.koog.prompt.executor.clients.google.GoogleLLMClient +import ai.koog.prompt.executor.clients.google.structure.GoogleStandardJsonSchemaGenerator +import ai.koog.prompt.executor.clients.openai.OpenAILLMClient +import ai.koog.prompt.executor.clients.openai.OpenAIModels +import ai.koog.prompt.executor.clients.openai.structure.OpenAIStandardJsonSchemaGenerator +import ai.koog.prompt.executor.llms.MultiLLMPromptExecutor +import ai.koog.prompt.llm.LLMProvider +import ai.koog.prompt.structure.StructureFixingParser +import ai.koog.prompt.structure.StructuredOutput +import ai.koog.prompt.structure.StructuredOutputConfig +import ai.koog.prompt.structure.json.JsonStructuredData +import ai.koog.prompt.structure.json.generator.StandardJsonSchemaGenerator +import ai.koog.prompt.text.text +import kotlinx.coroutines.runBlocking +import kotlinx.serialization.json.Json + +private val json = Json { + prettyPrint = true +} + +fun main(): Unit = runBlocking { + /* + This structure has a generic schema that is suitable for manual structured output mode. + But to use native structured output support in different LLM providers you might need to use custom JSON schema generators + that would produce the schema these providers expect. + */ + val genericWeatherStructure = JsonStructuredData.createJsonStructure( + // Some models might not work well with json schema, so you may try simple, but it has more limitations (no polymorphism!) + schemaGenerator = StandardJsonSchemaGenerator, + examples = FullWeatherForecast.exampleForecasts, + ) + + println("Generated generic JSON schema:\n${json.encodeToString(genericWeatherStructure.schema.schema)}") + /* + These are specific structure definitions with schemas in format that particular LLM providers understand in their native + structured output. + */ + + val openAiWeatherStructure = JsonStructuredData.createJsonStructure( + schemaGenerator = OpenAIStandardJsonSchemaGenerator, + examples = FullWeatherForecast.exampleForecasts, + ) + + val googleWeatherStructure = JsonStructuredData.createJsonStructure( + schemaGenerator = GoogleStandardJsonSchemaGenerator, + examples = FullWeatherForecast.exampleForecasts, + ) + + val config = StructuredOutputConfig( + byProvider = mapOf( + // Native modes leveraging native structured output support in models, with custom definitions for LLM providers that might have different format. + LLMProvider.OpenAI to StructuredOutput.Native(openAiWeatherStructure), + LLMProvider.Google to StructuredOutput.Native(googleWeatherStructure), + // Anthropic does not support native structured output yet. + LLMProvider.Anthropic to StructuredOutput.Manual(genericWeatherStructure), + ), + + // Fallback manual structured output mode, via explicit prompting with additional message, not native model support + default = StructuredOutput.Manual(genericWeatherStructure), + + // Helper parser to attempt a fix if a malformed output is produced. + fixingParser = StructureFixingParser( + fixingModel = AnthropicModels.Haiku_3_5, + retries = 2, + ), + ) + + val agentStrategy = structuredOutputWithToolsStrategy( + config + ) { request -> + text { + +"Requesting forecast for" + +"City: ${request.city}" + +"Country: ${request.country}" + } + } + + val agentConfig = AIAgentConfig( + prompt = prompt("weather-forecast-with-tools") { + system( + """ + You are a weather forecasting assistant. + When asked for a weather forecast, use the weather tools to get the weather forecast for the specified city and country. + """.trimIndent() + ) + }, + model = OpenAIModels.Chat.GPT4_1, + maxAgentIterations = 10 + ) + + val agent = AIAgent( + promptExecutor = MultiLLMPromptExecutor( + LLMProvider.OpenAI to OpenAILLMClient(ApiKeyService.openAIApiKey), + LLMProvider.Anthropic to AnthropicLLMClient(ApiKeyService.anthropicApiKey), + LLMProvider.Google to GoogleLLMClient(ApiKeyService.googleApiKey), + ), + strategy = agentStrategy, + agentConfig = agentConfig, + toolRegistry = ToolRegistry { + tools(WeatherTools().asTools()) + } + ) { + handleEvents { + onAgentRunError { eventContext -> + println("An error occurred: ${eventContext.throwable.message}\n${eventContext.throwable.stackTraceToString()}") + } + } + } + + println( + """ + === Full Weather Forecast Example === + This example demonstrates how to use structured output with full schema support + to get properly structured output from the LLM. + """.trimIndent() + ) + + val result: FullWeatherForecast = agent.run(FullWeatherForecastRequest(city = "New York", country = "USA")) + println("Agent run result: $result") +} diff --git a/examples/simple-examples/src/main/kotlin/ai/koog/agents/example/structuredoutput/models/FullWeatherForecast.kt b/examples/simple-examples/src/main/kotlin/ai/koog/agents/example/structuredoutput/models/FullWeatherForecast.kt new file mode 100644 index 0000000000..92f72761b8 --- /dev/null +++ b/examples/simple-examples/src/main/kotlin/ai/koog/agents/example/structuredoutput/models/FullWeatherForecast.kt @@ -0,0 +1,195 @@ +package ai.koog.agents.example.structuredoutput.models + +import ai.koog.agents.core.tools.annotations.LLMDescription +import kotlinx.serialization.SerialName +import kotlinx.serialization.Serializable + +/** + * This is a more advanced example showing how to configure various parameters of structured output manually, to fine-tune + * it for your needs when necessary. + * + * Structured output that uses "full" JSON schema. + * More advanced features are supported, e.g. polymorphism and recursive type references, and schemas can be more complex. + */ + +@Serializable +@SerialName("FullWeatherForecast") +@LLMDescription("Weather forecast for a given location") +data class FullWeatherForecast( + @property:LLMDescription("Temperature in Celsius") + val temperature: Int, + // properties with default values + @property:LLMDescription("Weather conditions (e.g., sunny, cloudy, rainy)") + val conditions: String = "sunny", + // nullable properties + @property:LLMDescription("Chance of precipitation in percentage") + val precipitation: Int?, + // nested classes + @property:LLMDescription("Coordinates of the location") + val latLon: LatLon, + // enums + val pollution: Pollution, + // polymorphism + val alert: WeatherAlert, + // lists + @property:LLMDescription("List of news articles") + val news: List, +// // maps (string keys only, some providers don't support maps at all) +// @property:LLMDescription("Map of weather sources") +// val sources: Map +) { + companion object { + // Optional examples, to help LLM understand the format better in manual mode + val exampleForecasts = listOf( + FullWeatherForecast( + temperature = 18, + conditions = "Cloudy", + precipitation = 30, + latLon = FullWeatherForecast.LatLon(lat = 34.0522, lon = -118.2437), + pollution = FullWeatherForecast.Pollution.Medium, + alert = FullWeatherForecast.WeatherAlert.StormAlert( + severity = FullWeatherForecast.WeatherAlert.Severity.Moderate, + message = "Possible thunderstorms in the evening", + windSpeed = 45.5 + ), + news = listOf( + FullWeatherForecast.WeatherNews(title = "Local news", link = "https://example.com/news"), + FullWeatherForecast.WeatherNews(title = "Global news", link = "https://example.com/global-news") + ), +// sources = mapOf( +// "MeteorologicalWatch" to FullWeatherForecast.WeatherSource( +// stationName = "MeteorologicalWatch", +// stationAuthority = "US Department of Agriculture" +// ), +// "MeteorologicalWatch2" to FullWeatherForecast.WeatherSource( +// stationName = "MeteorologicalWatch2", +// stationAuthority = "US Department of Agriculture" +// ) +// ) + ), + FullWeatherForecast( + temperature = 10, + conditions = "Rainy", + precipitation = null, + latLon = FullWeatherForecast.LatLon(lat = 37.7739, lon = -122.4194), + pollution = FullWeatherForecast.Pollution.Low, + alert = FullWeatherForecast.WeatherAlert.FloodAlert( + severity = FullWeatherForecast.WeatherAlert.Severity.Severe, + message = "Heavy rainfall may cause local flooding", + expectedRainfall = 75.2 + ), + news = listOf( + FullWeatherForecast.WeatherNews(title = "Local news", link = "https://example.com/news"), + FullWeatherForecast.WeatherNews(title = "Global news", link = "https://example.com/global-news") + ), +// sources = mapOf( +// "MeteorologicalWatch" to WeatherForecast.WeatherSource( +// stationName = "MeteorologicalWatch", +// stationAuthority = "US Department of Agriculture" +// ), +// ) + ) + ) + } + + // Nested classes + @Serializable + @SerialName("LatLon") + data class LatLon( + @property:LLMDescription("Latitude of the location") + val lat: Double, + @property:LLMDescription("Longitude of the location") + val lon: Double + ) + + // Nested classes in lists... + @Serializable + @SerialName("WeatherNews") + data class WeatherNews( + @property:LLMDescription("Title of the news article") + val title: String, + @property:LLMDescription("Link to the news article") + val link: String + ) + + // ... and maps (but only with string keys!) + @Suppress("unused") + @Serializable + @SerialName("WeatherSource") + data class WeatherSource( + @property:LLMDescription("Name of the weather station") + val stationName: String, + @property:LLMDescription("Authority of the weather station") + val stationAuthority: String + ) + + // Enums + @Suppress("unused") + @SerialName("Pollution") + @Serializable + enum class Pollution { + @SerialName("None") + None, + + @SerialName("LOW") + Low, + + @SerialName("MEDIUM") + Medium, + + @SerialName("HIGH") + High + } + + /* + Polymorphism: + 1. Closed with sealed classes, + 2. Open: non-sealed classes with subclasses registered in json config + https://github.com/Kotlin/kotlinx.serialization/blob/master/docs/polymorphism.md#registered-subclasses + */ + @Suppress("unused") + @Serializable + @SerialName("WeatherAlert") + sealed class WeatherAlert { + abstract val severity: Severity + abstract val message: String + + @Serializable + @SerialName("Severity") + enum class Severity { Low, Moderate, Severe, Extreme } + + @Serializable + @SerialName("StormAlert") + data class StormAlert( + override val severity: Severity, + override val message: String, + @property:LLMDescription("Wind speed in km/h") + val windSpeed: Double + ) : WeatherAlert() + + @Serializable + @SerialName("FloodAlert") + data class FloodAlert( + override val severity: Severity, + override val message: String, + @property:LLMDescription("Expected rainfall in mm") + val expectedRainfall: Double + ) : WeatherAlert() + + @Serializable + @SerialName("TemperatureAlert") + data class TemperatureAlert( + override val severity: Severity, + override val message: String, + @property:LLMDescription("Temperature threshold in Celsius") + val threshold: Int, // in Celsius + @property:LLMDescription("Whether the alert is a heat warning") + val isHeatWarning: Boolean + ) : WeatherAlert() + } +} + +data class FullWeatherForecastRequest( + val city: String, + val country: String +) diff --git a/examples/simple-examples/src/main/kotlin/ai/koog/agents/example/structuredoutput/tools/WeatherTools.kt b/examples/simple-examples/src/main/kotlin/ai/koog/agents/example/structuredoutput/tools/WeatherTools.kt new file mode 100644 index 0000000000..24603e83c6 --- /dev/null +++ b/examples/simple-examples/src/main/kotlin/ai/koog/agents/example/structuredoutput/tools/WeatherTools.kt @@ -0,0 +1,26 @@ +package ai.koog.agents.example.structuredoutput.tools + +import ai.koog.agents.core.tools.annotations.LLMDescription +import ai.koog.agents.core.tools.annotations.Tool +import ai.koog.agents.core.tools.reflect.ToolSet +import ai.koog.prompt.text.text + +class WeatherTools : ToolSet { + @Tool + @LLMDescription("Get the weather forecast for the specified city and country") + fun getWeatherForecast( + @LLMDescription("The city for which to get the weather forecast") + city: String, + ) = text { + +"The weather forecast for " + +city + +" is " + +"Cloudy" + +"temperature = 18" + +"precipitation = 30" + +"lat = 34.0522, lon = -118.2437" + +"pollution = medium" + +"alert = Moderate. Possible thunderstorms in the evening, windSpeed = 45.5" + +"news = Local news: https://example.com/news, Global news: https://example.com/global-news" + } +} diff --git a/prompt/prompt-structure/src/commonMain/kotlin/ai/koog/prompt/structure/PromptExecutorExtensions.kt b/prompt/prompt-structure/src/commonMain/kotlin/ai/koog/prompt/structure/PromptExecutorExtensions.kt index 55a4112e7a..b8296e34f7 100644 --- a/prompt/prompt-structure/src/commonMain/kotlin/ai/koog/prompt/structure/PromptExecutorExtensions.kt +++ b/prompt/prompt-structure/src/commonMain/kotlin/ai/koog/prompt/structure/PromptExecutorExtensions.kt @@ -48,7 +48,65 @@ public data class StructuredOutputConfig( public val default: StructuredOutput? = null, public val byProvider: Map> = emptyMap(), public val fixingParser: StructureFixingParser? = null -) +) { + /** + * Updates a given prompt to configure structured output using the specified large language model (LLM). + * Depending on the model's support for structured outputs, the prompt is updated either manually or natively. + * + * @param model The large language model (LLModel) used to determine the structured output configuration. + * @param prompt The original prompt to be updated with the structured output configuration. + * @return A new prompt reflecting the updated structured output configuration. + */ + public fun updatePrompt(model: LLModel, prompt: Prompt): Prompt { + return when (val mode = structuredOutput(model)) { + // Don't set schema parameter in prompt and coerce the model manually with user message to provide a structured response. + is StructuredOutput.Manual -> { + prompt(prompt) { + user { + markdown { + StructuredOutputPrompts.outputInstructionPrompt(this, mode.structure) + } + } + } + } + + // Rely on built-in model capabilities to provide structured response. + is StructuredOutput.Native -> { + prompt.withUpdatedParams { schema = mode.structure.schema } + } + } + } + + /** + * Retrieves the structured data configuration for a specific large language model (LLM). + * + * The method determines the appropriate structured data setup based on the given LLM + * instance, ensuring compatibility with the model's provider and capabilities. + * + * @param model The large language model (LLM) instance used to identify the structured data configuration. + * @return The structured data configuration represented as a `StructuredData` instance. + */ + public fun structure(model: LLModel): StructuredData { + return structuredOutput(model).structure + } + + /** + * Retrieves the structured output configuration for a specific large language model (LLM). + * + * The method determines the appropriate structured output instance based on the model's provider. + * If no specific configuration is found for the provider, it falls back to the default configuration. + * Throws an exception if no default configuration is available. + * + * @param model The large language model (LLM) used to identify the structured output configuration. + * @return An instance of `StructuredOutput` that represents the structured output configuration for the model. + * @throws IllegalArgumentException if no configuration is found for the provider and no default configuration is set. + */ + private fun structuredOutput(model: LLModel): StructuredOutput { + return byProvider[model.provider] + ?: default + ?: throw IllegalArgumentException("No structure found for provider ${model.provider}") + } +} /** * Defines how structured outputs should be generated. @@ -106,42 +164,13 @@ public suspend fun PromptExecutor.executeStructured( model: LLModel, config: StructuredOutputConfig, ): Result> { - val mode = config.byProvider[model.provider] - ?: config.default - ?: throw IllegalArgumentException("No structure found for provider ${model.provider}") - - val (structure: StructuredData, updatedPrompt: Prompt) = when (mode) { - // Don't set schema parameter in prompt and coerce the model manually with user message to provide a structured response. - is StructuredOutput.Manual -> { - mode.structure to prompt(prompt) { - user { - markdown { - StructuredOutputPrompts.outputInstructionPrompt(this, mode.structure) - } - } - } - } - - // Rely on built-in model capabilities to provide structured response. - is StructuredOutput.Native -> { - mode.structure to prompt.withUpdatedParams { schema = mode.structure.schema } - } - } - + val updatedPrompt = config.updatePrompt(model, prompt) val response = this.execute(prompt = updatedPrompt, model = model).single() return runCatching { require(response is Message.Assistant) { "Response for structured output must be an assistant message, got ${response::class.simpleName} instead" } - // Use fixingParser if provided, otherwise parse directly - val structureResponse = config.fixingParser - ?.parse(this, structure, response.content) - ?: structure.parse(response.content) - - StructuredResponse( - structure = structureResponse, - message = response - ) + parseResponseToStructuredResponse(response, config, model) } } @@ -269,3 +298,31 @@ public suspend inline fun PromptExecutor.executeStructured( fixingParser = fixingParser, ) } + +/** + * Parses a structured response from the assistant message using the provided structured output configuration + * and language model. If a fixing parser is specified in the configuration, it will be used; otherwise, + * the structure will be parsed directly. + * + * @param T The type of the structured output. + * @param response The assistant's response message to be parsed. + * @param config The structured output configuration defining how the response should be parsed. + * @param model The language model to be used for parsing the structured output. + * @return A `StructuredResponse` containing the parsed structure and the original assistant message. + */ +public suspend fun PromptExecutor.parseResponseToStructuredResponse( + response: Message.Assistant, + config: StructuredOutputConfig, + model: LLModel +): StructuredResponse { + // Use fixingParser if provided, otherwise parse directly + val structure = config.structure(model) + val structureResponse = config.fixingParser + ?.parse(this, structure, response.content) + ?: structure.parse(response.content) + + return StructuredResponse( + structure = structureResponse, + message = response + ) +} From ce9d463dca5411be7c38e46f0d8d49a666836b0f Mon Sep 17 00:00:00 2001 From: Maria Tigina <31625351+tiginamaria@users.noreply.github.com> Date: Wed, 1 Oct 2025 13:35:34 +0200 Subject: [PATCH 07/52] Fix finishReason nullability (#771) Fixes https://github.com/JetBrains/koog/issues/758 --- gradle.properties | 1 - .../SingleLLMPromptExecutorIntegrationTest.kt | 32 +- .../ai/koog/integration/tests/utils/Models.kt | 3 +- .../ai/koog/ktor/utils/LLMModelParser.kt | 6 +- .../build.gradle.kts | 2 + .../clients/deepseek/DeepSeekLLMClient.kt | 20 +- .../openai/base/AbstractOpenAILLMClient.kt | 17 +- .../openai/base/models/OpenAIDataModels.kt | 34 ++ .../clients/openai/OpenAILLMClient.kt | 7 +- .../openai/models/OpenAIChatCompletion.kt | 28 +- .../build.gradle.kts | 2 + .../clients/openrouter/OpenRouterLLMClient.kt | 26 +- .../clients/openrouter/OpenRouterModels.kt | 47 +- .../models/OpenRouterChatCompletion.kt | 82 +++- .../models/OpenRouterSerializationTest.kt | 431 +++++++++++++++++- 15 files changed, 661 insertions(+), 77 deletions(-) diff --git a/gradle.properties b/gradle.properties index 26e7cf6c7a..206d61e502 100644 --- a/gradle.properties +++ b/gradle.properties @@ -1,7 +1,6 @@ #Kotlin kotlin.code.style=official kotlin.daemon.jvmargs=-Xmx4096M -kotlin.native.ignoreDisabledTargets=true # Build JS targets using npm package manager https://kotlinlang.org/docs/js-project-setup.html#npm-dependencies kotlin.js.yarn=false diff --git a/integration-tests/src/jvmTest/kotlin/ai/koog/integration/tests/executor/SingleLLMPromptExecutorIntegrationTest.kt b/integration-tests/src/jvmTest/kotlin/ai/koog/integration/tests/executor/SingleLLMPromptExecutorIntegrationTest.kt index d05fa52035..d9fe6be07d 100644 --- a/integration-tests/src/jvmTest/kotlin/ai/koog/integration/tests/executor/SingleLLMPromptExecutorIntegrationTest.kt +++ b/integration-tests/src/jvmTest/kotlin/ai/koog/integration/tests/executor/SingleLLMPromptExecutorIntegrationTest.kt @@ -238,10 +238,6 @@ class SingleLLMPromptExecutorIntegrationTest { if (model.id == OpenAIModels.Audio.GPT4oAudio.id || model.id == OpenAIModels.Audio.GPT4oMiniAudio.id) { assumeTrue(false, "https://github.com/JetBrains/koog/issues/231") } - // TODO fix (KG-394): OpenRouter anthropic/claude-sonnet-4 streaming is incompatible with our current client setup (SSE/protocol) - if (model.provider == LLMProvider.OpenRouter && model.id.contains("anthropic/claude-sonnet-4")) { - assumeTrue(false, "Skipping OpenRouter anthropic/claude-sonnet-4 streaming: protocol incompatibility") - } val executor = SingleLLMPromptExecutor(client) @@ -514,10 +510,6 @@ class SingleLLMPromptExecutorIntegrationTest { if (model.id == OpenAIModels.Audio.GPT4oAudio.id || model.id == OpenAIModels.Audio.GPT4oMiniAudio.id) { assumeTrue(false, "https://github.com/JetBrains/koog/issues/231") } - // TODO fix (KG-394): OpenRouter anthropic/claude-sonnet-4 streaming is incompatible with our current client setup (SSE/protocol) - if (model.provider == LLMProvider.OpenRouter && model.id.contains("anthropic/claude-sonnet-4")) { - assumeTrue(false, "Skipping OpenRouter anthropic/claude-sonnet-4 streaming: protocol incompatibility") - } val prompt = Prompt.build("test-streaming") { system("You are a helpful assistant. You have NO output length limitations.") @@ -550,10 +542,6 @@ class SingleLLMPromptExecutorIntegrationTest { fun integration_testStructuredDataStreaming(model: LLModel, client: LLMClient) = runTest(timeout = 300.seconds) { Models.assumeAvailable(model.provider) assumeTrue(model != OpenAIModels.CostOptimized.GPT4_1Nano, "Model $model is too small for structured streaming") - // TODO fix (KG-394): OpenRouter anthropic/claude-sonnet-4 streaming is incompatible with our current client setup (SSE/protocol) - if (model.provider == LLMProvider.OpenRouter && model.id.contains("anthropic/claude-sonnet-4")) { - assumeTrue(false, "Skipping OpenRouter anthropic/claude-sonnet-4 streaming: protocol incompatibility") - } val countries = mutableListOf() val countryDefinition = markdownCountryDefinition() @@ -641,7 +629,7 @@ class SingleLLMPromptExecutorIntegrationTest { @MethodSource("modelClientCombinations") fun integration_testToolChoiceNamed(model: LLModel, client: LLMClient) = runTest(timeout = 300.seconds) { Models.assumeAvailable(model.provider) - assumeTrue(!(model.provider == LLMProvider.OpenRouter && model.id.contains("anthropic")), "KG-282") + assumeTrue(model.capabilities.contains(LLMCapability.ToolChoice), "Model $model does not support tools") val calculatorTool = createCalculatorTool() @@ -1142,10 +1130,7 @@ class SingleLLMPromptExecutorIntegrationTest { model.capabilities.contains(LLMCapability.Schema.JSON.Standard), "Model does not support Standard JSON Schema" ) - // TODO fix (KG-394): OpenRouter anthropic/claude-sonnet-4 streaming is incompatible with our current client setup (SSE/protocol) - if (model.provider == LLMProvider.OpenRouter) { - assumeTrue(false, "Skipping StructuredOutputNative for OpenRouter due to schema incompatibilities upstream") - } + val executor = SingleLLMPromptExecutor(client) withRetry { @@ -1167,13 +1152,7 @@ class SingleLLMPromptExecutorIntegrationTest { model.capabilities.contains(LLMCapability.Schema.JSON.Standard), "Model does not support Standard JSON Schema" ) - // TODO fix (KG-394) OpenRouter - if (model.provider == LLMProvider.OpenRouter) { - assumeTrue( - false, - "Skipping StructuredOutputNativeWithFixingParser for OpenRouter due to upstream schema incompatibilities" - ) - } + val executor = SingleLLMPromptExecutor(client) withRetry { @@ -1195,6 +1174,11 @@ class SingleLLMPromptExecutorIntegrationTest { model.provider !== LLMProvider.Google, "Google models fail to return manually requested structured output without fixing" ) + assumeTrue( + model.provider == LLMProvider.OpenRouter && model.id.contains("gemini"), + "Google models fail to return manually requested structured output without fixing" + ) + val executor = SingleLLMPromptExecutor(client) withRetry { diff --git a/integration-tests/src/jvmTest/kotlin/ai/koog/integration/tests/utils/Models.kt b/integration-tests/src/jvmTest/kotlin/ai/koog/integration/tests/utils/Models.kt index 3131bf2512..9377268165 100644 --- a/integration-tests/src/jvmTest/kotlin/ai/koog/integration/tests/utils/Models.kt +++ b/integration-tests/src/jvmTest/kotlin/ai/koog/integration/tests/utils/Models.kt @@ -84,8 +84,7 @@ object Models { OpenRouterModels.GPT5Nano, OpenRouterModels.DeepSeekV30324, OpenRouterModels.Claude4Sonnet, - // ToDo add Gemini when KG-203 is fixed - // OpenRouterModels.Gemini2_5FlashLite, + OpenRouterModels.Gemini2_5FlashLite, ) @JvmStatic diff --git a/koog-ktor/src/commonMain/kotlin/ai/koog/ktor/utils/LLMModelParser.kt b/koog-ktor/src/commonMain/kotlin/ai/koog/ktor/utils/LLMModelParser.kt index 1ca1626aa2..a73e3d3277 100644 --- a/koog-ktor/src/commonMain/kotlin/ai/koog/ktor/utils/LLMModelParser.kt +++ b/koog-ktor/src/commonMain/kotlin/ai/koog/ktor/utils/LLMModelParser.kt @@ -237,8 +237,12 @@ private val GOOGLE_MODELS_MAP = mapOf( ) private val OPENROUTER_MODELS_MAP = mapOf( - "claude3sonnet" to OpenRouterModels.Claude3Sonnet, "claude3haiku" to OpenRouterModels.Claude3Haiku, + "claude3opus" to OpenRouterModels.Claude3Opus, + "claude3sonnet" to OpenRouterModels.Claude3Sonnet, + "claude35sonnet" to OpenRouterModels.Claude3_5Sonnet, + "claude4sonnet" to OpenRouterModels.Claude4Sonnet, + "claude41opus" to OpenRouterModels.Claude4_1Opus, "gpt4" to OpenRouterModels.GPT4, "gpt4o" to OpenRouterModels.GPT4o, "gpt5" to OpenRouterModels.GPT5, diff --git a/prompt/prompt-executor/prompt-executor-clients/prompt-executor-deepseek-client/build.gradle.kts b/prompt/prompt-executor/prompt-executor-clients/prompt-executor-deepseek-client/build.gradle.kts index 773bd4b5c2..e9ee6bcff4 100644 --- a/prompt/prompt-executor/prompt-executor-clients/prompt-executor-deepseek-client/build.gradle.kts +++ b/prompt/prompt-executor/prompt-executor-clients/prompt-executor-deepseek-client/build.gradle.kts @@ -13,6 +13,8 @@ kotlin { commonMain { dependencies { api(project(":prompt:prompt-executor:prompt-executor-clients:prompt-executor-openai-client-base")) + api(project(":prompt:prompt-structure")) + api(project(":prompt:prompt-executor:prompt-executor-clients:prompt-executor-openai-client")) implementation(libs.oshai.kotlin.logging) } } diff --git a/prompt/prompt-executor/prompt-executor-clients/prompt-executor-deepseek-client/src/commonMain/kotlin/ai/koog/prompt/executor/clients/deepseek/DeepSeekLLMClient.kt b/prompt/prompt-executor/prompt-executor-clients/prompt-executor-deepseek-client/src/commonMain/kotlin/ai/koog/prompt/executor/clients/deepseek/DeepSeekLLMClient.kt index c86089cc08..1f072630ad 100644 --- a/prompt/prompt-executor/prompt-executor-clients/prompt-executor-deepseek-client/src/commonMain/kotlin/ai/koog/prompt/executor/clients/deepseek/DeepSeekLLMClient.kt +++ b/prompt/prompt-executor/prompt-executor-clients/prompt-executor-deepseek-client/src/commonMain/kotlin/ai/koog/prompt/executor/clients/deepseek/DeepSeekLLMClient.kt @@ -13,13 +13,20 @@ import ai.koog.prompt.executor.clients.openai.base.OpenAIBasedSettings import ai.koog.prompt.executor.clients.openai.base.models.OpenAIMessage import ai.koog.prompt.executor.clients.openai.base.models.OpenAITool import ai.koog.prompt.executor.clients.openai.base.models.OpenAIToolChoice +import ai.koog.prompt.executor.clients.openai.structure.OpenAIBasicJsonSchemaGenerator +import ai.koog.prompt.executor.clients.openai.structure.OpenAIStandardJsonSchemaGenerator import ai.koog.prompt.executor.model.LLMChoice +import ai.koog.prompt.llm.LLMProvider import ai.koog.prompt.llm.LLModel import ai.koog.prompt.params.LLMParams import ai.koog.prompt.streaming.StreamFrameFlowBuilder +import ai.koog.prompt.structure.RegisteredBasicJsonSchemaGenerators +import ai.koog.prompt.structure.RegisteredStandardJsonSchemaGenerators +import ai.koog.prompt.structure.annotations.InternalStructuredOutputApi import io.github.oshai.kotlinlogging.KotlinLogging import io.ktor.client.HttpClient import kotlinx.datetime.Clock +import kotlin.collections.set /** * Configuration settings for connecting to the DeepSeek API. @@ -54,8 +61,14 @@ public class DeepSeekLLMClient( staticLogger ) { + @OptIn(InternalStructuredOutputApi::class) private companion object { private val staticLogger = KotlinLogging.logger { } + + init { + RegisteredBasicJsonSchemaGenerators[LLMProvider.DeepSeek] = OpenAIBasicJsonSchemaGenerator + RegisteredStandardJsonSchemaGenerators[LLMProvider.DeepSeek] = OpenAIStandardJsonSchemaGenerator + } } override fun serializeProviderChatRequest( @@ -92,7 +105,12 @@ public class DeepSeekLLMClient( override fun processProviderChatResponse(response: DeepSeekChatCompletionResponse): List { require(response.choices.isNotEmpty()) { "Empty choices in response" } - return response.choices.map { it.toMessageResponses(createMetaInfo(response.usage)) } + return response.choices.map { + it.message.toMessageResponses( + it.finishReason, + createMetaInfo(response.usage), + ) + } } override fun decodeStreamingResponse(data: String): DeepSeekChatCompletionStreamResponse = diff --git a/prompt/prompt-executor/prompt-executor-clients/prompt-executor-openai-client-base/src/commonMain/kotlin/ai/koog/prompt/executor/clients/openai/base/AbstractOpenAILLMClient.kt b/prompt/prompt-executor/prompt-executor-clients/prompt-executor-openai-client-base/src/commonMain/kotlin/ai/koog/prompt/executor/clients/openai/base/AbstractOpenAILLMClient.kt index 12e44f8317..3ba256a9d6 100644 --- a/prompt/prompt-executor/prompt-executor-clients/prompt-executor-openai-client-base/src/commonMain/kotlin/ai/koog/prompt/executor/clients/openai/base/AbstractOpenAILLMClient.kt +++ b/prompt/prompt-executor/prompt-executor-clients/prompt-executor-openai-client-base/src/commonMain/kotlin/ai/koog/prompt/executor/clients/openai/base/AbstractOpenAILLMClient.kt @@ -12,7 +12,6 @@ import ai.koog.prompt.executor.clients.openai.base.models.Content import ai.koog.prompt.executor.clients.openai.base.models.JsonSchemaObject import ai.koog.prompt.executor.clients.openai.base.models.OpenAIBaseLLMResponse import ai.koog.prompt.executor.clients.openai.base.models.OpenAIBaseLLMStreamResponse -import ai.koog.prompt.executor.clients.openai.base.models.OpenAIChoice import ai.koog.prompt.executor.clients.openai.base.models.OpenAIContentPart import ai.koog.prompt.executor.clients.openai.base.models.OpenAIFunction import ai.koog.prompt.executor.clients.openai.base.models.OpenAIMessage @@ -406,10 +405,10 @@ public abstract class AbstractOpenAILLMClient { + protected fun OpenAIMessage.toMessageResponses(finishReason: String?, metaInfo: ResponseMetaInfo): List { return when { - message is OpenAIMessage.Assistant && !message.toolCalls.isNullOrEmpty() -> { - message.toolCalls.map { toolCall -> + this is OpenAIMessage.Assistant && !this.toolCalls.isNullOrEmpty() -> { + this.toolCalls.map { toolCall -> Message.Tool.Call( id = toolCall.id, tool = toolCall.function.name, @@ -419,20 +418,20 @@ public abstract class AbstractOpenAILLMClient listOf( + this.content != null -> listOf( Message.Assistant( - content = message.content!!.text(), + content = this.content!!.text(), finishReason = finishReason, metaInfo = metaInfo ) ) - message is OpenAIMessage.Assistant && message.audio?.data != null -> listOf( + this is OpenAIMessage.Assistant && this.audio?.data != null -> listOf( Message.Assistant( - content = message.audio.transcript.orEmpty(), + content = this.audio.transcript.orEmpty(), attachments = listOf( Attachment.Audio( - content = AttachmentContent.Binary.Base64(message.audio.data), + content = AttachmentContent.Binary.Base64(this.audio.data), format = "unknown", // FIXME: clarify format from response ) ), diff --git a/prompt/prompt-executor/prompt-executor-clients/prompt-executor-openai-client-base/src/commonMain/kotlin/ai/koog/prompt/executor/clients/openai/base/models/OpenAIDataModels.kt b/prompt/prompt-executor/prompt-executor-clients/prompt-executor-openai-client-base/src/commonMain/kotlin/ai/koog/prompt/executor/clients/openai/base/models/OpenAIDataModels.kt index 089b22321f..48fda42ee8 100644 --- a/prompt/prompt-executor/prompt-executor-clients/prompt-executor-openai-client-base/src/commonMain/kotlin/ai/koog/prompt/executor/clients/openai/base/models/OpenAIDataModels.kt +++ b/prompt/prompt-executor/prompt-executor-clients/prompt-executor-openai-client-base/src/commonMain/kotlin/ai/koog/prompt/executor/clients/openai/base/models/OpenAIDataModels.kt @@ -352,6 +352,8 @@ public class OpenAIStreamFunction( * [coral][OpenAIAudioVoice.Coral], [echo][OpenAIAudioVoice.Echo], [fable][OpenAIAudioVoice.Fable], * [nova][OpenAIAudioVoice.Nova], [onyx][OpenAIAudioVoice.Onyx], [sage][OpenAIAudioVoice.Sage] * and [shimmer][OpenAIAudioVoice.Shimmer] + * + * See [audio](https://platform.openai.com/docs/api-reference/chat/create#chat-create-audio) */ @Serializable public class OpenAIAudioConfig( @@ -364,6 +366,8 @@ public class OpenAIAudioConfig( * * This enum is used to specify the format of the audio output. It contains several standard audio formats * which are widely compatible with various audio players and systems. + * + * See [audio/format](https://platform.openai.com/docs/api-reference/chat/create#chat-create-audio-format) */ @Serializable public enum class OpenAIAudioFormat { @@ -387,6 +391,8 @@ public enum class OpenAIAudioFormat { * Represents the available voice options for audio output in OpenAI's system. * * This enum defines a list of predefined voices that can be used to synthesize audio responses. + * + * See [audio/voice](https://platform.openai.com/docs/api-reference/chat/create#chat-create-audio-voice) */ @Serializable public enum class OpenAIAudioVoice { @@ -458,6 +464,8 @@ public class OpenAIStaticContent(public val content: Content) { * If not set, the model/provider default applies. * * Serialized as `"minimal" | "low" | "medium" | "high"`. + * + * See [reasoning_effort](https://platform.openai.com/docs/api-reference/chat/create#chat-create-reasoning_effort) */ @Serializable public enum class ReasoningEffort { @@ -549,6 +557,8 @@ public sealed interface OpenAIResponseFormat { * Note: When a tier is requested, the response payload includes the * `service_tier` actually used to serve the request. This value may differ from * the one provided in the request. + * + * See [service_tier](https://platform.openai.com/docs/api-reference/chat/create#chat-create-service_tier) */ public enum class ServiceTier { /** @@ -613,6 +623,8 @@ public class JsonSchemaObject( * All other chunks will also include a `usage` field, but with a null value. * NOTE: If the stream is interrupted, * you may not receive the final usage chunk which contains the total token usage for the request. + * + * See [stream_options](https://platform.openai.com/docs/api-reference/chat/create#chat-create-stream_options) */ @Serializable public class OpenAIStreamOptions(public val includeUsage: Boolean? = null) @@ -764,6 +776,8 @@ public class OpenAIUserLocation( * @property index The index of the choice in the list of choices. * @property logprobs Log probability information for the choice. * @property message A chat completion message generated by the model. + * + * See [choices](https://platform.openai.com/docs/api-reference/chat/object#chat/object-choices) */ @Serializable public class OpenAIChoice( @@ -776,6 +790,8 @@ public class OpenAIChoice( /** * @property content A list of message content tokens with log probability information. * @property refusal A list of message refusal tokens with log probability information. + * + * See [choices/logprobs](https://platform.openai.com/docs/api-reference/chat/object#chat/object-choices-logprobs) */ @Serializable public class OpenAIChoiceLogProbs( @@ -793,6 +809,9 @@ public class OpenAIChoiceLogProbs( * @property token The token. * @property topLogprobs List of the most likely tokens and their log probability, at this token position. * In rare cases, there may be fewer than the number of requested `[topLogprobs]` returned. + * + * See [choices/logprobs/content](https://platform.openai.com/docs/api-reference/chat/object#chat/object-choices-logprobs-content) + * and [choices/logprobs/refusal](https://platform.openai.com/docs/api-reference/chat/object#chat/object-choices-logprobs-refusal) */ @Serializable public class ContentLogProbs( @@ -810,6 +829,9 @@ public class OpenAIChoiceLogProbs( * @property logprob The log probability of this token, if it is within the top 20 most likely tokens. * Otherwise, the value `-9999.0` is used to signify that the token is very unlikely. * @property token The token. + * + * See [choices/logprobs/content/top_logprobs](https://platform.openai.com/docs/api-reference/chat/object#chat/object-choices-logprobs-content-top_logprobs) + * and [choices/logprobs/refusal/top_logprobs](https://platform.openai.com/docs/api-reference/chat/object#chat/object-choices-logprobs-refusal-top_logprobs) */ @Serializable public class ContentTopLogProbs( @@ -849,6 +871,9 @@ public class OpenAIWebUrlCitation(public val urlCitation: Citation) { * @property totalTokens Total number of tokens used in the request (prompt + completion). * @property completionTokensDetails Breakdown of tokens used in a completion. * @property promptTokensDetails Breakdown of tokens used in the prompt. + * + * See [chat completions usage](https://platform.openai.com/docs/api-reference/chat/object#chat/object-usage) + * and [streaming usage](https://platform.openai.com/docs/api-reference/chat-streaming/streaming#chat-streaming/streaming-usage) */ @Serializable public class OpenAIUsage( @@ -868,6 +893,9 @@ public class OpenAIUsage( * the number of tokens in the prediction that did not appear in the completion. * However, like reasoning tokens, these tokens are still counted in the total completion tokens for purposes of billing, * output and context window limits. + * + * See [chat completions usage/completion_tokens_details](https://platform.openai.com/docs/api-reference/chat/object#chat/object-usage-completion_tokens_details) + * and [streaming usage/completion_tokens_details](https://platform.openai.com/docs/api-reference/chat-streaming/streaming#chat-streaming/streaming-usage-completion_tokens_details) */ @Serializable public class CompletionTokensDetails( @@ -880,6 +908,9 @@ public class CompletionTokensDetails( /** * @property audioTokens Audio input tokens generated by the model. * @property cachedTokens Cached tokens present in the prompt. + * + * See [chat completions usage/prompt_tokens_details](https://platform.openai.com/docs/api-reference/chat/object#chat/object-usage-prompt_tokens_details) + * and [streaming usage/prompt_tokens_details](https://platform.openai.com/docs/api-reference/chat-streaming/streaming#chat-streaming/streaming-usage-prompt_tokens_details) */ @Serializable public class PromptTokensDetails( @@ -897,6 +928,7 @@ public class PromptTokensDetails( * @property index The index of the choice in the list of choices. * @property logprobs Log probability information for the choice. * + * See [choices](https://platform.openai.com/docs/api-reference/chat-streaming/streaming#chat-streaming/streaming-choices) */ @Serializable public class OpenAIStreamChoice( @@ -911,6 +943,8 @@ public class OpenAIStreamChoice( * @property refusal The refusal message generated by the model. * @property role The role of the author of this message. * @property toolCalls + * + * See [choices/delta](https://platform.openai.com/docs/api-reference/chat-streaming/streaming#chat-streaming/streaming-choices-delta) */ @Serializable public class OpenAIStreamDelta( diff --git a/prompt/prompt-executor/prompt-executor-clients/prompt-executor-openai-client/src/commonMain/kotlin/ai/koog/prompt/executor/clients/openai/OpenAILLMClient.kt b/prompt/prompt-executor/prompt-executor-clients/prompt-executor-openai-client/src/commonMain/kotlin/ai/koog/prompt/executor/clients/openai/OpenAILLMClient.kt index c835f4424d..c7fd62ffe1 100644 --- a/prompt/prompt-executor/prompt-executor-clients/prompt-executor-openai-client/src/commonMain/kotlin/ai/koog/prompt/executor/clients/openai/OpenAILLMClient.kt +++ b/prompt/prompt-executor/prompt-executor-clients/prompt-executor-openai-client/src/commonMain/kotlin/ai/koog/prompt/executor/clients/openai/OpenAILLMClient.kt @@ -223,7 +223,12 @@ public open class OpenAILLMClient( override fun processProviderChatResponse(response: OpenAIChatCompletionResponse): List { require(response.choices.isNotEmpty()) { "Empty choices in response" } - return response.choices.map { it.toMessageResponses(createMetaInfo(response.usage)) } + return response.choices.map { + it.message.toMessageResponses( + it.finishReason, + createMetaInfo(response.usage), + ) + } } override fun decodeStreamingResponse(data: String): OpenAIChatCompletionStreamResponse = diff --git a/prompt/prompt-executor/prompt-executor-clients/prompt-executor-openai-client/src/commonMain/kotlin/ai/koog/prompt/executor/clients/openai/models/OpenAIChatCompletion.kt b/prompt/prompt-executor/prompt-executor-clients/prompt-executor-openai-client/src/commonMain/kotlin/ai/koog/prompt/executor/clients/openai/models/OpenAIChatCompletion.kt index a1c220c112..ad0d6a8b35 100644 --- a/prompt/prompt-executor/prompt-executor-clients/prompt-executor-openai-client/src/commonMain/kotlin/ai/koog/prompt/executor/clients/openai/models/OpenAIChatCompletion.kt +++ b/prompt/prompt-executor/prompt-executor-clients/prompt-executor-openai-client/src/commonMain/kotlin/ai/koog/prompt/executor/clients/openai/models/OpenAIChatCompletion.kt @@ -4,7 +4,7 @@ import ai.koog.prompt.executor.clients.openai.base.models.OpenAIAudioConfig import ai.koog.prompt.executor.clients.openai.base.models.OpenAIBaseLLMRequest import ai.koog.prompt.executor.clients.openai.base.models.OpenAIBaseLLMResponse import ai.koog.prompt.executor.clients.openai.base.models.OpenAIBaseLLMStreamResponse -import ai.koog.prompt.executor.clients.openai.base.models.OpenAIChoice +import ai.koog.prompt.executor.clients.openai.base.models.OpenAIChoiceLogProbs import ai.koog.prompt.executor.clients.openai.base.models.OpenAIMessage import ai.koog.prompt.executor.clients.openai.base.models.OpenAIModalities import ai.koog.prompt.executor.clients.openai.base.models.OpenAIResponseFormat @@ -194,6 +194,28 @@ internal class OpenAIChatCompletionRequest( val additionalProperties: Map? = null, ) : OpenAIBaseLLMRequest +/** + * Chat completion choice + * + * @property finishReason The reason the model stopped generating tokens. + * This will be `stop` if the model hit a natural stop point or a provided stop sequence, + * `length` if the maximum number of tokens specified in the request was reached, + * `content_filter` if content was omitted due to a flag from our content filters, + * `tool_calls` if the model called a tool, or `function_call` (deprecated) if the model called a function. + * @property index The index of the choice in the list of choices. + * @property logprobs Log probability information for the choice. + * @property message A chat completion message generated by the model. + * + * See [choices](https://platform.openai.com/docs/api-reference/chat/object#chat/object-choices) + */ +@Serializable +public class OpenAIChoice( + public val finishReason: String, + public val index: Int, + public val logprobs: OpenAIChoiceLogProbs? = null, + public val message: OpenAIMessage, +) + /** * Represents the response from the OpenAI chat completion API. * @@ -221,6 +243,8 @@ internal class OpenAIChatCompletionRequest( * Can be used in conjunction with the `seed` request parameter * to understand when backend changes have been made that might impact determinism. * @property usage Usage statistics for the completion request. + * + * See [The chat completion object](https://platform.openai.com/docs/api-reference/chat/object) */ @Serializable public class OpenAIChatCompletionResponse( @@ -264,6 +288,8 @@ public class OpenAIChatCompletionResponse( * Can be used in conjunction with the `seed` request parameter * to understand when backend changes have been made that might impact determinism. * @property usage Usage statistics for the completion request. + * + * See [The chat completion chunk object](https://platform.openai.com/docs/api-reference/chat-streaming/streaming) */ @Serializable public class OpenAIChatCompletionStreamResponse( diff --git a/prompt/prompt-executor/prompt-executor-clients/prompt-executor-openrouter-client/build.gradle.kts b/prompt/prompt-executor/prompt-executor-clients/prompt-executor-openrouter-client/build.gradle.kts index 1a64d6a0ba..528a6e61b2 100644 --- a/prompt/prompt-executor/prompt-executor-clients/prompt-executor-openrouter-client/build.gradle.kts +++ b/prompt/prompt-executor/prompt-executor-clients/prompt-executor-openrouter-client/build.gradle.kts @@ -13,6 +13,8 @@ kotlin { commonMain { dependencies { api(project(":prompt:prompt-executor:prompt-executor-clients:prompt-executor-openai-client-base")) + api(project(":prompt:prompt-structure")) + api(project(":prompt:prompt-executor:prompt-executor-clients:prompt-executor-openai-client")) implementation(libs.oshai.kotlin.logging) } } diff --git a/prompt/prompt-executor/prompt-executor-clients/prompt-executor-openrouter-client/src/commonMain/kotlin/ai/koog/prompt/executor/clients/openrouter/OpenRouterLLMClient.kt b/prompt/prompt-executor/prompt-executor-clients/prompt-executor-openrouter-client/src/commonMain/kotlin/ai/koog/prompt/executor/clients/openrouter/OpenRouterLLMClient.kt index 6028cb9dbf..15fb337a30 100644 --- a/prompt/prompt-executor/prompt-executor-clients/prompt-executor-openrouter-client/src/commonMain/kotlin/ai/koog/prompt/executor/clients/openrouter/OpenRouterLLMClient.kt +++ b/prompt/prompt-executor/prompt-executor-clients/prompt-executor-openrouter-client/src/commonMain/kotlin/ai/koog/prompt/executor/clients/openrouter/OpenRouterLLMClient.kt @@ -11,14 +11,20 @@ import ai.koog.prompt.executor.clients.openai.base.models.OpenAIMessage import ai.koog.prompt.executor.clients.openai.base.models.OpenAIStaticContent import ai.koog.prompt.executor.clients.openai.base.models.OpenAITool import ai.koog.prompt.executor.clients.openai.base.models.OpenAIToolChoice +import ai.koog.prompt.executor.clients.openai.structure.OpenAIBasicJsonSchemaGenerator +import ai.koog.prompt.executor.clients.openai.structure.OpenAIStandardJsonSchemaGenerator import ai.koog.prompt.executor.clients.openrouter.models.OpenRouterChatCompletionRequest import ai.koog.prompt.executor.clients.openrouter.models.OpenRouterChatCompletionRequestSerializer import ai.koog.prompt.executor.clients.openrouter.models.OpenRouterChatCompletionResponse import ai.koog.prompt.executor.clients.openrouter.models.OpenRouterChatCompletionStreamResponse import ai.koog.prompt.executor.model.LLMChoice +import ai.koog.prompt.llm.LLMProvider import ai.koog.prompt.llm.LLModel import ai.koog.prompt.params.LLMParams import ai.koog.prompt.streaming.StreamFrameFlowBuilder +import ai.koog.prompt.structure.RegisteredBasicJsonSchemaGenerators +import ai.koog.prompt.structure.RegisteredStandardJsonSchemaGenerators +import ai.koog.prompt.structure.annotations.InternalStructuredOutputApi import io.github.oshai.kotlinlogging.KotlinLogging import io.ktor.client.HttpClient import kotlinx.datetime.Clock @@ -56,8 +62,14 @@ public class OpenRouterLLMClient( staticLogger ) { + @OptIn(InternalStructuredOutputApi::class) private companion object { private val staticLogger = KotlinLogging.logger { } + + init { + RegisteredBasicJsonSchemaGenerators[LLMProvider.OpenRouter] = OpenAIBasicJsonSchemaGenerator + RegisteredStandardJsonSchemaGenerators[LLMProvider.OpenRouter] = OpenAIStandardJsonSchemaGenerator + } } override fun serializeProviderChatRequest( @@ -104,7 +116,12 @@ public class OpenRouterLLMClient( override fun processProviderChatResponse(response: OpenRouterChatCompletionResponse): List { require(response.choices.isNotEmpty()) { "Empty choices in response" } - return response.choices.map { it.toMessageResponses(createMetaInfo(response.usage)) } + return response.choices.map { + it.message.toMessageResponses( + it.finishReason, + createMetaInfo(response.usage), + ) + } } override fun decodeStreamingResponse(data: String): OpenRouterChatCompletionStreamResponse = @@ -116,11 +133,10 @@ public class OpenRouterLLMClient( override suspend fun StreamFrameFlowBuilder.processStreamingChunk(chunk: OpenRouterChatCompletionStreamResponse) { chunk.choices.firstOrNull()?.let { choice -> choice.delta.content?.let { emitAppend(it) } - choice.delta.toolCalls?.forEach { openAIToolCall -> - val index = openAIToolCall.index + choice.delta.toolCalls?.forEachIndexed { index, openAIToolCall -> val id = openAIToolCall.id - val name = openAIToolCall.function?.name - val arguments = openAIToolCall.function?.arguments + val name = openAIToolCall.function.name + val arguments = openAIToolCall.function.arguments upsertToolCall(index, id, name, arguments) } choice.finishReason?.let { emitEnd(it, createMetaInfo(chunk.usage)) } diff --git a/prompt/prompt-executor/prompt-executor-clients/prompt-executor-openrouter-client/src/commonMain/kotlin/ai/koog/prompt/executor/clients/openrouter/OpenRouterModels.kt b/prompt/prompt-executor/prompt-executor-clients/prompt-executor-openrouter-client/src/commonMain/kotlin/ai/koog/prompt/executor/clients/openrouter/OpenRouterModels.kt index f1d4e033cf..ca16b67595 100644 --- a/prompt/prompt-executor/prompt-executor-clients/prompt-executor-openrouter-client/src/commonMain/kotlin/ai/koog/prompt/executor/clients/openrouter/OpenRouterModels.kt +++ b/prompt/prompt-executor/prompt-executor-clients/prompt-executor-openrouter-client/src/commonMain/kotlin/ai/koog/prompt/executor/clients/openrouter/OpenRouterModels.kt @@ -16,13 +16,20 @@ public object OpenRouterModels : LLModelDefinitions { */ private val standardCapabilities: List = listOf( LLMCapability.Temperature, - LLMCapability.Schema.JSON.Standard, LLMCapability.Speculation, LLMCapability.Tools, - LLMCapability.ToolChoice, LLMCapability.Completion ) + /** + * Additional capabilities available for models out of the Claude family. + * Includes structured output support and tool choice. + */ + private val additionalStandardCapabilities: List = listOf( + LLMCapability.Schema.JSON.Standard, + LLMCapability.ToolChoice + ) + /** * Multimodal capabilities including vision support. * Extends standard capabilities with image vision processing. @@ -139,7 +146,7 @@ public object OpenRouterModels : LLModelDefinitions { public val GPT4oMini: LLModel = LLModel( provider = LLMProvider.OpenRouter, id = "openai/gpt-4o-mini", - capabilities = multimodalCapabilities, + capabilities = multimodalCapabilities + additionalStandardCapabilities, contextLength = 128_000, maxOutputTokens = 16_400, ) @@ -152,7 +159,7 @@ public object OpenRouterModels : LLModelDefinitions { public val GPT5Chat: LLModel = LLModel( provider = LLMProvider.OpenRouter, id = "openai/gpt-5-chat", - capabilities = multimodalCapabilities, + capabilities = multimodalCapabilities + additionalStandardCapabilities, contextLength = 400_000, maxOutputTokens = 128_000, ) @@ -167,7 +174,7 @@ public object OpenRouterModels : LLModelDefinitions { public val GPT5: LLModel = LLModel( provider = LLMProvider.OpenRouter, id = "openai/gpt-5", - capabilities = standardCapabilities, + capabilities = standardCapabilities + additionalStandardCapabilities, contextLength = 400_000, ) @@ -181,7 +188,7 @@ public object OpenRouterModels : LLModelDefinitions { public val GPT5Mini: LLModel = LLModel( provider = LLMProvider.OpenRouter, id = "openai/gpt-5-mini", - capabilities = standardCapabilities, + capabilities = standardCapabilities + additionalStandardCapabilities, contextLength = 400_000, ) @@ -195,7 +202,7 @@ public object OpenRouterModels : LLModelDefinitions { public val GPT5Nano: LLModel = LLModel( provider = LLMProvider.OpenRouter, id = "openai/gpt-5-nano", - capabilities = standardCapabilities, + capabilities = standardCapabilities + additionalStandardCapabilities, contextLength = 400_000, ) @@ -209,7 +216,7 @@ public object OpenRouterModels : LLModelDefinitions { public val GPT_OSS_120b: LLModel = LLModel( provider = LLMProvider.OpenRouter, id = "openai/gpt-oss-120b", - capabilities = standardCapabilities, + capabilities = standardCapabilities + additionalStandardCapabilities, contextLength = 400_000, ) @@ -223,7 +230,7 @@ public object OpenRouterModels : LLModelDefinitions { public val GPT4: LLModel = LLModel( provider = LLMProvider.OpenRouter, id = "openai/gpt-4", - capabilities = standardCapabilities, + capabilities = standardCapabilities + additionalStandardCapabilities, contextLength = 32_768, ) @@ -234,7 +241,7 @@ public object OpenRouterModels : LLModelDefinitions { public val GPT4o: LLModel = LLModel( provider = LLMProvider.OpenRouter, id = "openai/gpt-4o", - capabilities = multimodalCapabilities, + capabilities = multimodalCapabilities + additionalStandardCapabilities, contextLength = 128_000, ) @@ -248,7 +255,7 @@ public object OpenRouterModels : LLModelDefinitions { public val GPT4Turbo: LLModel = LLModel( provider = LLMProvider.OpenRouter, id = "openai/gpt-4-turbo", - capabilities = multimodalCapabilities, + capabilities = multimodalCapabilities + additionalStandardCapabilities, contextLength = 128_000, ) @@ -261,7 +268,7 @@ public object OpenRouterModels : LLModelDefinitions { public val GPT35Turbo: LLModel = LLModel( provider = LLMProvider.OpenRouter, id = "openai/gpt-3.5-turbo", - capabilities = standardCapabilities, + capabilities = standardCapabilities + additionalStandardCapabilities, contextLength = 16_385, ) @@ -275,7 +282,7 @@ public object OpenRouterModels : LLModelDefinitions { public val Llama3: LLModel = LLModel( provider = LLMProvider.OpenRouter, id = "meta/llama-3-70b", - capabilities = standardCapabilities, + capabilities = standardCapabilities + additionalStandardCapabilities, contextLength = 8_000, ) @@ -288,7 +295,7 @@ public object OpenRouterModels : LLModelDefinitions { public val Llama3Instruct: LLModel = LLModel( provider = LLMProvider.OpenRouter, id = "meta/llama-3-70b-instruct", - capabilities = standardCapabilities, + capabilities = standardCapabilities + additionalStandardCapabilities, contextLength = 8_000, ) @@ -303,7 +310,7 @@ public object OpenRouterModels : LLModelDefinitions { public val Mistral7B: LLModel = LLModel( provider = LLMProvider.OpenRouter, id = "mistral/mistral-7b", - capabilities = standardCapabilities, + capabilities = standardCapabilities + additionalStandardCapabilities, contextLength = 32_768, ) @@ -319,7 +326,7 @@ public object OpenRouterModels : LLModelDefinitions { public val Mixtral8x7B: LLModel = LLModel( provider = LLMProvider.OpenRouter, id = "mistral/mixtral-8x7b", - capabilities = standardCapabilities, + capabilities = standardCapabilities + additionalStandardCapabilities, contextLength = 32_768, ) @@ -375,7 +382,7 @@ public object OpenRouterModels : LLModelDefinitions { public val DeepSeekV30324: LLModel = LLModel( provider = LLMProvider.OpenRouter, id = "deepseek/deepseek-chat-v3-0324", - capabilities = standardCapabilities, + capabilities = standardCapabilities + additionalStandardCapabilities, contextLength = 163_800, maxOutputTokens = 163_800, ) @@ -387,7 +394,7 @@ public object OpenRouterModels : LLModelDefinitions { public val Gemini2_5FlashLite: LLModel = LLModel( provider = LLMProvider.OpenRouter, id = "google/gemini-2.5-flash-lite", - capabilities = multimodalCapabilities, + capabilities = multimodalCapabilities + additionalStandardCapabilities, contextLength = 1_048_576, maxOutputTokens = 65_600, ) @@ -399,7 +406,7 @@ public object OpenRouterModels : LLModelDefinitions { public val Gemini2_5Flash: LLModel = LLModel( provider = LLMProvider.OpenRouter, id = "google/gemini-2.5-flash", - capabilities = multimodalCapabilities, + capabilities = multimodalCapabilities + additionalStandardCapabilities, contextLength = 1_048_576, maxOutputTokens = 65_600, ) @@ -411,7 +418,7 @@ public object OpenRouterModels : LLModelDefinitions { public val Gemini2_5Pro: LLModel = LLModel( provider = LLMProvider.OpenRouter, id = "google/gemini-2.5-pro", - capabilities = multimodalCapabilities, + capabilities = multimodalCapabilities + additionalStandardCapabilities, contextLength = 1_048_576, maxOutputTokens = 65_600, ) diff --git a/prompt/prompt-executor/prompt-executor-clients/prompt-executor-openrouter-client/src/commonMain/kotlin/ai/koog/prompt/executor/clients/openrouter/models/OpenRouterChatCompletion.kt b/prompt/prompt-executor/prompt-executor-clients/prompt-executor-openrouter-client/src/commonMain/kotlin/ai/koog/prompt/executor/clients/openrouter/models/OpenRouterChatCompletion.kt index 111aea9123..0d40b3d26e 100644 --- a/prompt/prompt-executor/prompt-executor-clients/prompt-executor-openrouter-client/src/commonMain/kotlin/ai/koog/prompt/executor/clients/openrouter/models/OpenRouterChatCompletion.kt +++ b/prompt/prompt-executor/prompt-executor-clients/prompt-executor-openrouter-client/src/commonMain/kotlin/ai/koog/prompt/executor/clients/openrouter/models/OpenRouterChatCompletion.kt @@ -3,12 +3,11 @@ package ai.koog.prompt.executor.clients.openrouter.models import ai.koog.prompt.executor.clients.openai.base.models.OpenAIBaseLLMRequest import ai.koog.prompt.executor.clients.openai.base.models.OpenAIBaseLLMResponse import ai.koog.prompt.executor.clients.openai.base.models.OpenAIBaseLLMStreamResponse -import ai.koog.prompt.executor.clients.openai.base.models.OpenAIChoice import ai.koog.prompt.executor.clients.openai.base.models.OpenAIMessage import ai.koog.prompt.executor.clients.openai.base.models.OpenAIResponseFormat import ai.koog.prompt.executor.clients.openai.base.models.OpenAIStaticContent -import ai.koog.prompt.executor.clients.openai.base.models.OpenAIStreamChoice import ai.koog.prompt.executor.clients.openai.base.models.OpenAITool +import ai.koog.prompt.executor.clients.openai.base.models.OpenAIToolCall import ai.koog.prompt.executor.clients.openai.base.models.OpenAIToolChoice import ai.koog.prompt.executor.clients.openai.base.models.OpenAIUsage import ai.koog.prompt.executor.clients.serialization.AdditionalPropertiesFlatteningSerializer @@ -78,13 +77,85 @@ public class ProviderPreferences( public val maxPrice: Map? = null ) +/** + * Chat completion choice + * + * @property finishReason The reason the model stopped generating tokens. + * This will be `stop` if the model hit a natural stop point or a provided stop sequence, + * `length` if the maximum number of tokens specified in the request was reached, + * `content_filter` if content was omitted due to a flag from our content filters, + * `tool_calls` if the model called a tool, or `function_call` (deprecated) if the model called a function. + * `error` if the model finishes the request with an error. + * @property nativeFinishReason The raw finish_reason string returned by the model + * @property message A chat completion message generated by the model. + * @property error An error response structure typically used for conveying error details to the clients. + * + * See (CompletionsResponse Format)[https://openrouter.ai/docs/api-reference/overview#completionsresponse-format] + */ +@Serializable +public class OpenRouterChoice( + public val finishReason: String? = null, + public val nativeFinishReason: String? = null, + public val message: OpenAIMessage, + public val error: ErrorResponse? = null +) + +/** + * Chat completion choice + * + * @property finishReason The reason the model stopped generating tokens. + * This will be `stop` if the model hit a natural stop point or a provided stop sequence, + * `length` if the maximum number of tokens specified in the request was reached, + * `content_filter` if content was omitted due to a flag from our content filters, + * `tool_calls` if the model called a tool, or `function_call` (deprecated) if the model called a function. + * `error` if the model finishes the request with an error. + * @property nativeFinishReason The raw finish_reason string returned by the model + * @property delta A chat completion delta generated by streamed model responses. + * @property error An error response structure typically used for conveying error details to the clients. + * + * See (CompletionsResponse Format)[https://openrouter.ai/docs/api-reference/overview#completionsresponse-format] + */ +@Serializable +public class OpenRouterStreamChoice( + public val finishReason: String? = null, + public val nativeFinishReason: String? = null, + public val delta: OpenRouterStreamDelta, + public val error: ErrorResponse? = null +) + +/** + * @property content The contents of the chunk message. + * @property role The role of the author of this message. + * @property toolCalls The tool calls requested by the model. + */ +@Serializable +public class OpenRouterStreamDelta( + public val content: String? = null, + public val role: String? = null, + public val toolCalls: List? = null +) + +/** + * Represents an error response structure typically used for conveying error details to the clients. + * + * @property code The numeric code representing the error. + * @property message A descriptive message providing details about the error. + * @property metadata Optional additional information about the error in the form of key-value pairs. + */ +@Serializable +public class ErrorResponse( + public val code: Int, + public val message: String, + public val metadata: Map? = null, +) + /** * OpenRouter Chat Completion Response - * https://openrouter.ai/docs/responses + * See (CompletionsResponse Format)[https://openrouter.ai/docs/api-reference/overview#completionsresponse-format] */ @Serializable public class OpenRouterChatCompletionResponse( - public val choices: List, + public val choices: List, override val created: Long, override val id: String, override val model: String, @@ -96,10 +167,11 @@ public class OpenRouterChatCompletionResponse( /** * OpenRouter Chat Completion Streaming Response + * See (CompletionsResponse Format)[https://openrouter.ai/docs/api-reference/overview#completionsresponse-format] */ @Serializable public class OpenRouterChatCompletionStreamResponse( - public val choices: List, + public val choices: List, override val created: Long, override val id: String, override val model: String, diff --git a/prompt/prompt-executor/prompt-executor-clients/prompt-executor-openrouter-client/src/jvmTest/kotlin/ai/koog/prompt/executor/clients/openrouter/models/OpenRouterSerializationTest.kt b/prompt/prompt-executor/prompt-executor-clients/prompt-executor-openrouter-client/src/jvmTest/kotlin/ai/koog/prompt/executor/clients/openrouter/models/OpenRouterSerializationTest.kt index 17517b3f67..58c74f6036 100644 --- a/prompt/prompt-executor/prompt-executor-clients/prompt-executor-openrouter-client/src/jvmTest/kotlin/ai/koog/prompt/executor/clients/openrouter/models/OpenRouterSerializationTest.kt +++ b/prompt/prompt-executor/prompt-executor-clients/prompt-executor-openrouter-client/src/jvmTest/kotlin/ai/koog/prompt/executor/clients/openrouter/models/OpenRouterSerializationTest.kt @@ -1,10 +1,17 @@ package ai.koog.prompt.executor.clients.openrouter.models import ai.koog.prompt.executor.clients.openai.base.models.Content +import ai.koog.prompt.executor.clients.openai.base.models.OpenAIFunction import ai.koog.prompt.executor.clients.openai.base.models.OpenAIMessage +import ai.koog.prompt.executor.clients.openai.base.models.OpenAITool +import ai.koog.prompt.executor.clients.openai.base.models.OpenAIToolCall +import ai.koog.prompt.executor.clients.openai.base.models.OpenAIToolChoice +import ai.koog.prompt.executor.clients.openai.base.models.OpenAIToolFunction import kotlinx.serialization.json.Json import kotlinx.serialization.json.JsonElement +import kotlinx.serialization.json.JsonNamingStrategy import kotlinx.serialization.json.JsonPrimitive +import kotlinx.serialization.json.add import kotlinx.serialization.json.addJsonObject import kotlinx.serialization.json.booleanOrNull import kotlinx.serialization.json.buildJsonArray @@ -12,6 +19,7 @@ import kotlinx.serialization.json.buildJsonObject import kotlinx.serialization.json.contentOrNull import kotlinx.serialization.json.doubleOrNull import kotlinx.serialization.json.intOrNull +import kotlinx.serialization.json.jsonArray import kotlinx.serialization.json.jsonObject import kotlinx.serialization.json.jsonPrimitive import kotlinx.serialization.json.put @@ -22,11 +30,19 @@ import kotlin.test.assertNull class OpenRouterSerializationTest { - private val json = Json { + private val requestJson = Json { ignoreUnknownKeys = false explicitNulls = false } + private val responseJson = Json { + ignoreUnknownKeys = false + explicitNulls = false + isLenient = true + encodeDefaults = true + namingStrategy = JsonNamingStrategy.SnakeCase + } + @Test fun `test serialization without additionalProperties`() { val request = OpenRouterChatCompletionRequest( @@ -37,7 +53,7 @@ class OpenRouterSerializationTest { stream = false ) - val jsonElement = json.encodeToJsonElement(OpenRouterChatCompletionRequestSerializer, request) + val jsonElement = requestJson.encodeToJsonElement(OpenRouterChatCompletionRequestSerializer, request) val jsonObject = jsonElement.jsonObject assertEquals("anthropic/claude-3-sonnet", jsonObject["model"]?.jsonPrimitive?.contentOrNull) @@ -62,7 +78,7 @@ class OpenRouterSerializationTest { additionalProperties = additionalProperties ) - val jsonElement = json.encodeToJsonElement(OpenRouterChatCompletionRequestSerializer, request) + val jsonElement = requestJson.encodeToJsonElement(OpenRouterChatCompletionRequestSerializer, request) val jsonObject = jsonElement.jsonObject // Standard properties should be present @@ -96,7 +112,7 @@ class OpenRouterSerializationTest { put("stream", JsonPrimitive(false)) } - val request = json.decodeFromJsonElement(OpenRouterChatCompletionRequestSerializer, jsonInput) + val request = requestJson.decodeFromJsonElement(OpenRouterChatCompletionRequestSerializer, jsonInput) assertEquals("anthropic/claude-3-sonnet", request.model) assertEquals(0.7, request.temperature) @@ -124,7 +140,7 @@ class OpenRouterSerializationTest { put("customBoolean", JsonPrimitive(true)) } - val request = json.decodeFromJsonElement(OpenRouterChatCompletionRequestSerializer, jsonInput) + val request = requestJson.decodeFromJsonElement(OpenRouterChatCompletionRequestSerializer, jsonInput) assertEquals("anthropic/claude-3-sonnet", request.model) assertEquals(0.7, request.temperature) @@ -152,10 +168,10 @@ class OpenRouterSerializationTest { ) // Serialize to JSON string - val jsonString = json.encodeToString(OpenRouterChatCompletionRequestSerializer, originalRequest) + val jsonString = requestJson.encodeToString(OpenRouterChatCompletionRequestSerializer, originalRequest) // Deserialize back to object - val deserializedRequest = json.decodeFromString(OpenRouterChatCompletionRequestSerializer, jsonString) + val deserializedRequest = requestJson.decodeFromString(OpenRouterChatCompletionRequestSerializer, jsonString) // Verify standard properties assertEquals(originalRequest.model, deserializedRequest.model) @@ -175,4 +191,405 @@ class OpenRouterSerializationTest { deserializedAdditionalProps["customNumber"]?.jsonPrimitive?.intOrNull ) } + + @Test + fun `test OpenRouter response deserialization`() { + val jsonInput = buildJsonObject { + put("id", "gen-xxxxxxxxxxxxxx") + put("created", 1699000000L) + put("model", "openai/gpt-3.5-turbo") + put("object", "chat.completion") + put("system_fingerprint", "fp_44709d6fcb") + put( + "choices", + buildJsonArray { + addJsonObject { + put("finish_reason", "stop") + put( + "message", + buildJsonObject { + put("role", "assistant") + put("content", "Hello there!") + } + ) + } + } + ) + put( + "usage", + buildJsonObject { + put("prompt_tokens", 10) + put("completion_tokens", 4) + put("total_tokens", 14) + } + ) + } + + val response = responseJson.decodeFromJsonElement(OpenRouterChatCompletionResponse.serializer(), jsonInput) + + assertEquals("gen-xxxxxxxxxxxxxx", response.id) + assertEquals(1699000000L, response.created) + assertEquals("openai/gpt-3.5-turbo", response.model) + assertEquals("chat.completion", response.objectType) + assertEquals("fp_44709d6fcb", response.systemFingerprint) + + assertEquals(1, response.choices.size) + val choice = response.choices.first() + assertEquals("stop", choice.finishReason) + + val message = choice.message as OpenAIMessage.Assistant + assertEquals("Hello there!", message.content?.text()) + + assertNotNull(response.usage) + assertEquals(10, response.usage.promptTokens) + assertEquals(4, response.usage.completionTokens) + assertEquals(14, response.usage.totalTokens) + } + + @Test + fun `test OpenRouter error response deserialization`() { + val jsonInput = buildJsonObject { + put("id", "gen-error-test") + put("created", 1699000000L) + put("model", "openai/gpt-4") + put("object", "chat.completion") + put( + "choices", + buildJsonArray { + addJsonObject { + put("finish_reason", "error") + put("native_finish_reason", "content_filter") + put( + "message", + buildJsonObject { + put("role", "assistant") + put("content", "") + } + ) + put( + "error", + buildJsonObject { + put("code", 400) + put("message", "Content filtered due to policy violation") + put( + "metadata", + buildJsonObject { + put("provider", "openai") + put("raw_error", "content_filter") + } + ) + } + ) + } + } + ) + } + + val response = responseJson.decodeFromJsonElement(OpenRouterChatCompletionResponse.serializer(), jsonInput) + + val choice = response.choices.first() + assertEquals("error", choice.finishReason) + assertEquals("content_filter", choice.nativeFinishReason) + + assertNotNull(choice.error) + assertEquals(400, choice.error.code) + assertEquals("Content filtered due to policy violation", choice.error.message) + assertNotNull(choice.error.metadata) + assertEquals("openai", choice.error.metadata["provider"]) + } + + @Test + fun `test OpenRouter streaming response deserialization`() { + val jsonInput = buildJsonObject { + put("id", "gen-stream-test") + put("created", 1699000000L) + put("model", "anthropic/claude-3-sonnet") + put("object", "chat.completion.chunk") + put( + "choices", + buildJsonArray { + addJsonObject { + put("finish_reason", null) + put("native_finish_reason", null) + put( + "delta", + buildJsonObject { + put("role", "assistant") + put("content", "Hello") + } + ) + } + } + ) + } + + val response = responseJson.decodeFromJsonElement(OpenRouterChatCompletionStreamResponse.serializer(), jsonInput) + + assertEquals("gen-stream-test", response.id) + assertEquals("chat.completion.chunk", response.objectType) + assertEquals(1, response.choices.size) + + val choice = response.choices.first() + assertNull(choice.finishReason) + assertNull(choice.nativeFinishReason) + assertEquals("Hello", choice.delta.content) + assertEquals("assistant", choice.delta.role) + } + + @Test + fun `test OpenRouter response with tool calls deserialization`() { + val jsonInput = buildJsonObject { + put("id", "gen-tool-call-test") + put("created", 1699000000L) + put("model", "openai/gpt-4") + put("object", "chat.completion") + put("system_fingerprint", "fp_44709d6fcb") + put( + "choices", + buildJsonArray { + addJsonObject { + put("finish_reason", "tool_calls") + put( + "message", + buildJsonObject { + put("role", "assistant") + put("content", null) + put( + "tool_calls", + buildJsonArray { + addJsonObject { + put("id", "call_abc123") + put("type", "function") + put( + "function", + buildJsonObject { + put("name", "get_current_weather") + put("arguments", "{\"location\": \"Boston, MA\"}") + } + ) + } + addJsonObject { + put("id", "call_def456") + put("type", "function") + put( + "function", + buildJsonObject { + put("name", "get_forecast") + put("arguments", "{\"location\": \"Boston, MA\", \"days\": 3}") + } + ) + } + } + ) + } + ) + } + } + ) + put( + "usage", + buildJsonObject { + put("prompt_tokens", 82) + put("completion_tokens", 18) + put("total_tokens", 100) + } + ) + } + + val response = responseJson.decodeFromJsonElement(OpenRouterChatCompletionResponse.serializer(), jsonInput) + + assertEquals("gen-tool-call-test", response.id) + assertEquals(1699000000L, response.created) + assertEquals("openai/gpt-4", response.model) + assertEquals("chat.completion", response.objectType) + assertEquals("fp_44709d6fcb", response.systemFingerprint) + + assertEquals(1, response.choices.size) + val choice = response.choices.first() + assertEquals("tool_calls", choice.finishReason) + + val message = choice.message as OpenAIMessage.Assistant + assertNull(message.content) + assertNotNull(message.toolCalls) + assertEquals(2, message.toolCalls!!.size) + + val firstToolCall = message.toolCalls!![0] + assertEquals("call_abc123", firstToolCall.id) + assertEquals("function", firstToolCall.type) + assertEquals("get_current_weather", firstToolCall.function.name) + assertEquals("{\"location\": \"Boston, MA\"}", firstToolCall.function.arguments) + + val secondToolCall = message.toolCalls!![1] + assertEquals("call_def456", secondToolCall.id) + assertEquals("function", secondToolCall.type) + assertEquals("get_forecast", secondToolCall.function.name) + assertEquals("{\"location\": \"Boston, MA\", \"days\": 3}", secondToolCall.function.arguments) + + assertNotNull(response.usage) + assertEquals(82, response.usage.promptTokens) + assertEquals(18, response.usage.completionTokens) + assertEquals(100, response.usage.totalTokens) + } + + @Test + fun `test OpenRouter streaming response with tool calls deserialization`() { + val jsonInput = buildJsonObject { + put("id", "gen-stream-tool-test") + put("created", 1699000000L) + put("model", "openai/gpt-4") + put("object", "chat.completion.chunk") + put( + "choices", + buildJsonArray { + addJsonObject { + put("finish_reason", null) + put("native_finish_reason", null) + put( + "delta", + buildJsonObject { + put("role", "assistant") + put("content", null) + put( + "tool_calls", + buildJsonArray { + addJsonObject { + put("id", "call_xyz789") + put("type", "function") + put( + "function", + buildJsonObject { + put("name", "calculate_total") + put("arguments", "{\"items\": [") + } + ) + } + } + ) + } + ) + } + } + ) + } + + val response = responseJson.decodeFromJsonElement(OpenRouterChatCompletionStreamResponse.serializer(), jsonInput) + + assertEquals("gen-stream-tool-test", response.id) + assertEquals("chat.completion.chunk", response.objectType) + assertEquals(1, response.choices.size) + + val choice = response.choices.first() + assertNull(choice.finishReason) + assertNull(choice.nativeFinishReason) + assertNull(choice.delta.content) + assertEquals("assistant", choice.delta.role) + + assertNotNull(choice.delta.toolCalls) + assertEquals(1, choice.delta.toolCalls.size) + + val toolCall = choice.delta.toolCalls[0] + assertEquals("call_xyz789", toolCall.id) + assertEquals("function", toolCall.type) + assertEquals("calculate_total", toolCall.function.name) + assertEquals("{\"items\": [", toolCall.function.arguments) + } + + @Test + fun `test OpenRouter request with tools serialization`() { + val tools = listOf( + OpenAITool( + function = OpenAIToolFunction( + name = "get_current_weather", + description = "Get the current weather in a given location", + parameters = buildJsonObject { + put("type", "object") + put( + "properties", + buildJsonObject { + put( + "location", + buildJsonObject { + put("type", "string") + put("description", "The city and state, e.g. San Francisco, CA") + } + ) + put( + "unit", + buildJsonObject { + put("type", "string") + put( + "enum", + buildJsonArray { + add("celsius") + add("fahrenheit") + } + ) + } + ) + } + ) + put("required", buildJsonArray { add("location") }) + } + ) + ) + ) + + val request = OpenRouterChatCompletionRequest( + model = "openai/gpt-4", + messages = listOf( + OpenAIMessage.User(content = Content.Text("What's the weather like in Boston?")), + OpenAIMessage.Assistant( + content = null, + toolCalls = listOf( + OpenAIToolCall( + id = "call_abc123", + function = OpenAIFunction( + name = "get_current_weather", + arguments = "{\"location\": \"Boston, MA\"}" + ) + ) + ) + ), + OpenAIMessage.Tool( + content = Content.Text("The weather in Boston is 72°F and sunny"), + toolCallId = "call_abc123" + ) + ), + tools = tools, + toolChoice = OpenAIToolChoice.Auto + ) + + val jsonElement = responseJson.encodeToJsonElement(OpenRouterChatCompletionRequestSerializer, request) + val jsonObject = jsonElement.jsonObject + + assertEquals("openai/gpt-4", jsonObject["model"]?.jsonPrimitive?.contentOrNull) + assertNotNull(jsonObject["messages"]) + assertNotNull(jsonObject["tools"]) + assertEquals("auto", jsonObject["tool_choice"]?.jsonPrimitive?.contentOrNull) + + // Verify the serialized messages include tool calls + val messages = jsonObject["messages"]!!.jsonArray + assertEquals(3, messages.size) + + // Check assistant message with tool calls + val assistantMessage = messages[1].jsonObject + assertEquals("assistant", assistantMessage["role"]?.jsonPrimitive?.contentOrNull) + assertNotNull(assistantMessage["tool_calls"]) + val toolCalls = assistantMessage["tool_calls"]!!.jsonArray + assertEquals(1, toolCalls.size) + + val toolCall = toolCalls[0].jsonObject + assertEquals("call_abc123", toolCall["id"]?.jsonPrimitive?.contentOrNull) + assertEquals("function", toolCall["type"]?.jsonPrimitive?.contentOrNull) + + val function = toolCall["function"]!!.jsonObject + assertEquals("get_current_weather", function["name"]?.jsonPrimitive?.contentOrNull) + assertEquals("{\"location\": \"Boston, MA\"}", function["arguments"]?.jsonPrimitive?.contentOrNull) + + // Check tool message + val toolMessage = messages[2].jsonObject + assertEquals("tool", toolMessage["role"]?.jsonPrimitive?.contentOrNull) + assertEquals("The weather in Boston is 72°F and sunny", toolMessage["content"]?.jsonPrimitive?.contentOrNull) + assertEquals("call_abc123", toolMessage["tool_call_id"]?.jsonPrimitive?.contentOrNull) + } } From 3d5ed577d083713239242cb16e52c3bdc2fce3ab Mon Sep 17 00:00:00 2001 From: Sergei Dubov Date: Mon, 29 Sep 2025 09:57:46 +0200 Subject: [PATCH 08/52] KG-376. Update event names in test after agent events renaming - Massive updates in the commit 3a315e422be9213c2e5f6a69272bc4fb1cdd0071 miss several places. Updated all related namings in tests and inspections --- .../agents/core/agent/entity/AIAgentNode.kt | 1 - .../environment/GenericAgentEnvironment.kt | 6 +- .../agents/core/feature/AIAgentPipeline.kt | 82 +++---- .../handler/AgentLifecycleEventType.kt | 6 +- ...eprecatedExecuteToolEventHandlerContext.kt | 32 +-- ...ventContext.kt => ToolCallEventContext.kt} | 22 +- ...ventHandler.kt => ToolCallEventHandler.kt} | 8 +- .../feature/model/events/llmCallEvents.kt | 11 +- .../model/events/nodeExecutionEvents.kt | 2 +- .../model/events/toolExecutionEvents.kt | 30 +-- .../agents/core/feature/remote/jsonConfig.kt | 30 +-- .../core/agent/FunctionalAIAgentTest.kt | 6 +- .../core/agent/SingleRunStrategyTests.kt | 16 +- .../AIAgentNodesHistoryCompressionTest.kt | 25 +- .../core/dsl/extension/AIAgentNodesTest.kt | 42 ++-- .../koog/agents/core/feature/TestFeature.kt | 4 +- ...tructuredOutputWithToolsIntegrationTest.kt | 82 +++---- .../agents/ext/agent/SubgraphWithRetryTest.kt | 62 ++--- agents/agents-features/Module.md | 8 +- .../features/debugger/feature/Debugger.kt | 18 +- .../features/debugger/feature/DebuggerTest.kt | 8 +- .../agents-features-event-handler/Module.md | 8 +- .../eventHandler/feature/EventHandler.kt | 26 +- .../feature/EventHandlerConfig.kt | 64 ++--- .../eventHandler/feature/EventHandlerTest.kt | 154 ++++++------ ...deLLMRequestStreamingAndSendResultsTest.kt | 4 +- .../feature/StreamingEventHandlerTest.kt | 16 +- .../feature/TestEventsCollector.kt | 46 ++-- .../opentelemetry/feature/OpenTelemetry.kt | 6 +- .../features/tracing/feature/Tracing.kt | 18 +- .../tracing/writer/traceMessageFormat.kt | 18 +- .../TraceFeatureMessageFileWriterTest.kt | 8 +- .../TraceFeatureMessageLogWriterTest.kt | 8 +- .../TraceFeatureMessageRemoteWriterTest.kt | 8 +- .../TraceFeatureMessageTestWriterTest.kt | 14 +- agents/agents-test/TESTING.md | 2 +- .../koog/agents/test/SimpleAgentMockedTest.kt | 10 +- docs/docs/act-ai-agent.md | 202 +++++++++++++++ docs/docs/agent-events.md | 4 +- docs/docs/examples/Calculator.md | 6 +- docs/docs/examples/UnityMcp.md | 12 +- docs/docs/streaming-api.md | 2 +- docs/docs/testing.md | 2 +- docs/docs/tracing.md | 14 +- .../calculator/CalculatorAgentProvider.kt | 4 +- .../agents/weather/WeatherAgentProvider.kt | 6 +- examples/notebooks/Calculator.ipynb | 6 +- examples/notebooks/UnityMcp.ipynb | 12 +- .../agents/example/calculator/Calculator.kt | 6 +- .../calculator/OllamaCalculatorExample.kt | 6 +- .../example/funApi/FunAgentWithTools.kt | 2 +- .../koog/agents/example/mcp/UnityMcpAgent.kt | 12 +- .../example/simpleapi/BasicSingleRunAgent.kt | 2 +- .../example/snapshot/CheckpointExample.kt | 2 +- .../streaming/StreamingAgentWithTools.kt | 2 +- .../AdvancedWithBasicSchema.kt | 2 +- .../AdvancedWithStandardSchema.kt | 2 +- .../example/structuredoutput/SimpleExample.kt | 2 +- .../ai/koog/agents/example/tone/ToneAgent.kt | 6 +- .../example/websearch/WebSearchAgent.kt | 2 +- .../koog/agents/example/tone/ToneAgentTest.kt | 6 +- ...nthropicSchemaValidationIntegrationTest.kt | 2 +- .../tests/agent/AIAgentIntegrationTest.kt | 4 +- .../AIAgentMultipleLLMIntegrationTest.kt | 12 +- .../tests/agent/OllamaAgentIntegrationTest.kt | 2 +- .../agent/OllamaSimpleAgentIntegrationTest.kt | 6 +- qodana.sarif.json | 232 +++++++++--------- 67 files changed, 848 insertions(+), 643 deletions(-) rename agents/agents-core/src/commonMain/kotlin/ai/koog/agents/core/feature/handler/tool/{ToolExecutionEventContext.kt => ToolCallEventContext.kt} (84%) rename agents/agents-core/src/commonMain/kotlin/ai/koog/agents/core/feature/handler/tool/{ToolExecutionEventHandler.kt => ToolCallEventHandler.kt} (94%) create mode 100644 docs/docs/act-ai-agent.md diff --git a/agents/agents-core/src/commonMain/kotlin/ai/koog/agents/core/agent/entity/AIAgentNode.kt b/agents/agents-core/src/commonMain/kotlin/ai/koog/agents/core/agent/entity/AIAgentNode.kt index abd36e7802..311f4f1e85 100644 --- a/agents/agents-core/src/commonMain/kotlin/ai/koog/agents/core/agent/entity/AIAgentNode.kt +++ b/agents/agents-core/src/commonMain/kotlin/ai/koog/agents/core/agent/entity/AIAgentNode.kt @@ -119,7 +119,6 @@ public abstract class AIAgentNodeBase internal constructor() { /** * Executes the node operation using the provided execution context and input, bypassing type safety checks. * This method internally calls the type-safe `execute` method after casting the input. - * The lifecycle hooks `onBeforeNode` and `onAfterNode` are invoked before and after the execution respectively. * * @param context The execution context that provides runtime information and functionality. * @param input The input data to be processed by the node, which may be of any type. diff --git a/agents/agents-core/src/commonMain/kotlin/ai/koog/agents/core/environment/GenericAgentEnvironment.kt b/agents/agents-core/src/commonMain/kotlin/ai/koog/agents/core/environment/GenericAgentEnvironment.kt index 2e33a990b9..35c70d1317 100644 --- a/agents/agents-core/src/commonMain/kotlin/ai/koog/agents/core/environment/GenericAgentEnvironment.kt +++ b/agents/agents-core/src/commonMain/kotlin/ai/koog/agents/core/environment/GenericAgentEnvironment.kt @@ -117,7 +117,7 @@ internal class GenericAgentEnvironment( ) } - pipeline.onToolExecutionStarting(content.runId, content.toolCallId, tool, toolArgs) + pipeline.onToolCallStarting(content.runId, content.toolCallId, tool, toolArgs) val toolResult = try { @Suppress("UNCHECKED_CAST") @@ -135,7 +135,7 @@ internal class GenericAgentEnvironment( } catch (e: Exception) { logger.error(e) { "Tool \"${tool.name}\" failed to execute with arguments: ${content.toolArgs}" } - pipeline.onToolExecutionFailed(content.runId, content.toolCallId, tool, toolArgs, e) + pipeline.onToolCallFailed(content.runId, content.toolCallId, tool, toolArgs, e) return toolResult( message = "Tool \"${tool.name}\" failed to execute because of ${e.message}!", @@ -146,7 +146,7 @@ internal class GenericAgentEnvironment( ) } - pipeline.onToolExecutionCompleted(content.runId, content.toolCallId, tool, toolArgs, toolResult) + pipeline.onToolCallCompleted(content.runId, content.toolCallId, tool, toolArgs, toolResult) logger.trace { "Completed execution of ${content.toolName} with result: $toolResult" } diff --git a/agents/agents-core/src/commonMain/kotlin/ai/koog/agents/core/feature/AIAgentPipeline.kt b/agents/agents-core/src/commonMain/kotlin/ai/koog/agents/core/feature/AIAgentPipeline.kt index 2592af660e..9a7da68432 100644 --- a/agents/agents-core/src/commonMain/kotlin/ai/koog/agents/core/feature/AIAgentPipeline.kt +++ b/agents/agents-core/src/commonMain/kotlin/ai/koog/agents/core/feature/AIAgentPipeline.kt @@ -41,13 +41,13 @@ import ai.koog.agents.core.feature.handler.streaming.LLMStreamingFrameReceivedCo import ai.koog.agents.core.feature.handler.streaming.LLMStreamingFrameReceivedHandler import ai.koog.agents.core.feature.handler.streaming.LLMStreamingStartingContext import ai.koog.agents.core.feature.handler.streaming.LLMStreamingStartingHandler +import ai.koog.agents.core.feature.handler.tool.ToolCallCompletedContext +import ai.koog.agents.core.feature.handler.tool.ToolCallEventHandler +import ai.koog.agents.core.feature.handler.tool.ToolCallFailedContext import ai.koog.agents.core.feature.handler.tool.ToolCallFailureHandler import ai.koog.agents.core.feature.handler.tool.ToolCallHandler import ai.koog.agents.core.feature.handler.tool.ToolCallResultHandler -import ai.koog.agents.core.feature.handler.tool.ToolExecutionCompletedContext -import ai.koog.agents.core.feature.handler.tool.ToolExecutionEventHandler -import ai.koog.agents.core.feature.handler.tool.ToolExecutionFailedContext -import ai.koog.agents.core.feature.handler.tool.ToolExecutionStartingContext +import ai.koog.agents.core.feature.handler.tool.ToolCallStartingContext import ai.koog.agents.core.feature.handler.tool.ToolValidationErrorHandler import ai.koog.agents.core.feature.handler.tool.ToolValidationFailedContext import ai.koog.agents.core.tools.Tool @@ -123,7 +123,7 @@ public abstract class AIAgentPipeline(public val clock: Clock) { * Map of tool execution handlers registered for different features. * Keys are feature storage keys, values are tool execution handlers. */ - protected val toolExecutionEventHandlers: MutableMap, ToolExecutionEventHandler> = mutableMapOf() + protected val toolCallEventHandlers: MutableMap, ToolCallEventHandler> = mutableMapOf() /** * Map of LLM execution handlers registered for different features. @@ -367,9 +367,9 @@ public abstract class AIAgentPipeline(public val clock: Clock) { * @param tool The tool that is being called * @param toolArgs The arguments provided to the tool */ - public suspend fun onToolExecutionStarting(runId: String, toolCallId: String?, tool: Tool<*, *>, toolArgs: Any?) { - val eventContext = ToolExecutionStartingContext(runId, toolCallId, tool, toolArgs) - toolExecutionEventHandlers.values.forEach { handler -> handler.toolCallHandler.handle(eventContext) } + public suspend fun onToolCallStarting(runId: String, toolCallId: String?, tool: Tool<*, *>, toolArgs: Any?) { + val eventContext = ToolCallStartingContext(runId, toolCallId, tool, toolArgs) + toolCallEventHandlers.values.forEach { handler -> handler.toolCallHandler.handle(eventContext) } } /** @@ -389,7 +389,7 @@ public abstract class AIAgentPipeline(public val clock: Clock) { ) { val eventContext = ToolValidationFailedContext(runId, toolCallId, tool, toolArgs, error) - toolExecutionEventHandlers.values.forEach { handler -> handler.toolValidationErrorHandler.handle(eventContext) } + toolCallEventHandlers.values.forEach { handler -> handler.toolValidationErrorHandler.handle(eventContext) } } /** @@ -400,15 +400,15 @@ public abstract class AIAgentPipeline(public val clock: Clock) { * @param toolArgs The arguments provided to the tool * @param throwable The exception that caused the failure */ - public suspend fun onToolExecutionFailed( + public suspend fun onToolCallFailed( runId: String, toolCallId: String?, tool: Tool<*, *>, toolArgs: Any?, throwable: Throwable ) { - val eventContext = ToolExecutionFailedContext(runId, toolCallId, tool, toolArgs, throwable) - toolExecutionEventHandlers.values.forEach { handler -> handler.toolCallFailureHandler.handle(eventContext) } + val eventContext = ToolCallFailedContext(runId, toolCallId, tool, toolArgs, throwable) + toolCallEventHandlers.values.forEach { handler -> handler.toolCallFailureHandler.handle(eventContext) } } /** @@ -419,15 +419,15 @@ public abstract class AIAgentPipeline(public val clock: Clock) { * @param toolArgs The arguments that were provided to the tool * @param result The result produced by the tool, or null if no result was produced */ - public suspend fun onToolExecutionCompleted( + public suspend fun onToolCallCompleted( runId: String, toolCallId: String?, tool: Tool<*, *>, toolArgs: Any?, result: Any? ) { - val eventContext = ToolExecutionCompletedContext(runId, toolCallId, tool, toolArgs, result) - toolExecutionEventHandlers.values.forEach { handler -> handler.toolCallResultHandler.handle(eventContext) } + val eventContext = ToolCallCompletedContext(runId, toolCallId, tool, toolArgs, result) + toolCallEventHandlers.values.forEach { handler -> handler.toolCallResultHandler.handle(eventContext) } } //endregion Trigger Tool Call Handlers @@ -867,16 +867,16 @@ public abstract class AIAgentPipeline(public val clock: Clock) { * * Example: * ``` - * pipeline.interceptToolExecutionStarting(interceptContext) { eventContext -> + * pipeline.interceptToolCallStarting(interceptContext) { eventContext -> * // Process or log the tool call * } * ``` */ - public fun interceptToolExecutionStarting( + public fun interceptToolCallStarting( interceptContext: InterceptContext, - handle: suspend TFeature.(eventContext: ToolExecutionStartingContext) -> Unit + handle: suspend TFeature.(eventContext: ToolCallStartingContext) -> Unit ) { - val handler = toolExecutionEventHandlers.getOrPut(interceptContext.feature.key) { ToolExecutionEventHandler() } + val handler = toolCallEventHandlers.getOrPut(interceptContext.feature.key) { ToolCallEventHandler() } handler.toolCallHandler = ToolCallHandler( function = createConditionalHandler(interceptContext, handle) ) @@ -899,7 +899,7 @@ public abstract class AIAgentPipeline(public val clock: Clock) { interceptContext: InterceptContext, handle: suspend TFeature.(eventContext: ToolValidationFailedContext) -> Unit ) { - val handler = toolExecutionEventHandlers.getOrPut(interceptContext.feature.key) { ToolExecutionEventHandler() } + val handler = toolCallEventHandlers.getOrPut(interceptContext.feature.key) { ToolCallEventHandler() } handler.toolValidationErrorHandler = ToolValidationErrorHandler( function = createConditionalHandler(interceptContext, handle) ) @@ -913,16 +913,16 @@ public abstract class AIAgentPipeline(public val clock: Clock) { * * Example: * ``` - * pipeline.interceptToolExecutionFailed(interceptContext) { eventContext -> + * pipeline.interceptToolCallFailed(interceptContext) { eventContext -> * // Handle the tool call failure here * } * ``` */ - public fun interceptToolExecutionFailed( + public fun interceptToolCallFailed( interceptContext: InterceptContext, - handle: suspend TFeature.(eventContext: ToolExecutionFailedContext) -> Unit + handle: suspend TFeature.(eventContext: ToolCallFailedContext) -> Unit ) { - val handler = toolExecutionEventHandlers.getOrPut(interceptContext.feature.key) { ToolExecutionEventHandler() } + val handler = toolCallEventHandlers.getOrPut(interceptContext.feature.key) { ToolCallEventHandler() } handler.toolCallFailureHandler = ToolCallFailureHandler( function = createConditionalHandler(interceptContext, handle) ) @@ -942,11 +942,11 @@ public abstract class AIAgentPipeline(public val clock: Clock) { * } * ``` */ - public fun interceptToolExecutionCompleted( + public fun interceptToolCallCompleted( interceptContext: InterceptContext, - handle: suspend TFeature.(eventContext: ToolExecutionCompletedContext) -> Unit + handle: suspend TFeature.(eventContext: ToolCallCompletedContext) -> Unit ) { - val handler = toolExecutionEventHandlers.getOrPut(interceptContext.feature.key) { ToolExecutionEventHandler() } + val handler = toolCallEventHandlers.getOrPut(interceptContext.feature.key) { ToolCallEventHandler() } handler.toolCallResultHandler = ToolCallResultHandler( function = createConditionalHandler(interceptContext, handle) ) @@ -1111,11 +1111,11 @@ public abstract class AIAgentPipeline(public val clock: Clock) { * Updates the tool call handler for the given feature key with a custom handler. */ @Deprecated( - message = "Please use interceptToolExecutionStarting instead. This method is deprecated and will be removed in the next release.", + message = "Please use interceptToolCallStarting instead. This method is deprecated and will be removed in the next release.", replaceWith = ReplaceWith( - expression = "interceptToolExecutionStarting(interceptContext, handle)", + expression = "interceptToolCallStarting(interceptContext, handle)", imports = arrayOf( - "ai.koog.agents.core.feature.handler.tool.ToolExecutionStartingContext" + "ai.koog.agents.core.feature.handler.tool.ToolCallStartingContext" ) ) ) @@ -1123,45 +1123,45 @@ public abstract class AIAgentPipeline(public val clock: Clock) { interceptContext: InterceptContext, handle: suspend TFeature.(eventContext: ai.koog.agents.core.feature.handler.ToolCallContext) -> Unit ) { - interceptToolExecutionStarting(interceptContext, handle) + interceptToolCallStarting(interceptContext, handle) } /** * Intercepts the result of a tool call with a custom handler for a specific feature. */ @Deprecated( - message = "Please use interceptToolExecutionCompleted instead. This method is deprecated and will be removed in the next release.", + message = "Please use interceptToolCallCompleted instead. This method is deprecated and will be removed in the next release.", replaceWith = ReplaceWith( - expression = "interceptToolExecutionCompleted(interceptContext, handle)", + expression = "interceptToolCallCompleted(interceptContext, handle)", imports = arrayOf( - "ai.koog.agents.core.feature.handler.tool.ToolExecutionCompletedContext" + "ai.koog.agents.core.feature.handler.tool.ToolCallCompletedContext" ) ) ) public fun interceptToolCallResult( interceptContext: InterceptContext, - handle: suspend TFeature.(eventContext: ToolExecutionCompletedContext) -> Unit + handle: suspend TFeature.(eventContext: ToolCallCompletedContext) -> Unit ) { - interceptToolExecutionCompleted(interceptContext, handle) + interceptToolCallCompleted(interceptContext, handle) } /** * Sets up an interception mechanism to handle tool call failures for a specific feature. */ @Deprecated( - message = "Please use interceptToolExecutionFailed instead. This method is deprecated and will be removed in the next release.", + message = "Please use interceptToolCallFailed instead. This method is deprecated and will be removed in the next release.", replaceWith = ReplaceWith( - expression = "interceptToolExecutionFailed(interceptContext, handle)", + expression = "interceptToolCallFailed(interceptContext, handle)", imports = arrayOf( - "ai.koog.agents.core.feature.handler.tool.ToolExecutionFailedContext" + "ai.koog.agents.core.feature.handler.tool.ToolCallFailedContext" ) ) ) public fun interceptToolCallFailure( interceptContext: InterceptContext, - handle: suspend TFeature.(eventContext: ToolExecutionFailedContext) -> Unit + handle: suspend TFeature.(eventContext: ToolCallFailedContext) -> Unit ) { - interceptToolExecutionFailed(interceptContext, handle) + interceptToolCallFailed(interceptContext, handle) } /** diff --git a/agents/agents-core/src/commonMain/kotlin/ai/koog/agents/core/feature/handler/AgentLifecycleEventType.kt b/agents/agents-core/src/commonMain/kotlin/ai/koog/agents/core/feature/handler/AgentLifecycleEventType.kt index 444f57d9b6..6cb3ee0b36 100644 --- a/agents/agents-core/src/commonMain/kotlin/ai/koog/agents/core/feature/handler/AgentLifecycleEventType.kt +++ b/agents/agents-core/src/commonMain/kotlin/ai/koog/agents/core/feature/handler/AgentLifecycleEventType.kt @@ -89,7 +89,7 @@ public sealed interface AgentLifecycleEventType { /** * Represents an event triggered when a tool is called. */ - public object ToolExecutionStarting : AgentLifecycleEventType + public object ToolCallStarting : AgentLifecycleEventType /** * Represents an event triggered when a tool call fails validation. @@ -99,12 +99,12 @@ public sealed interface AgentLifecycleEventType { /** * Represents an event triggered when a tool call fails. */ - public object ToolExecutionFailed : AgentLifecycleEventType + public object ToolCallFailed : AgentLifecycleEventType /** * Represents an event triggered when a tool call succeeds. */ - public object ToolExecutionCompleted : AgentLifecycleEventType + public object ToolCallCompleted : AgentLifecycleEventType //endregion Tool diff --git a/agents/agents-core/src/commonMain/kotlin/ai/koog/agents/core/feature/handler/DeprecatedExecuteToolEventHandlerContext.kt b/agents/agents-core/src/commonMain/kotlin/ai/koog/agents/core/feature/handler/DeprecatedExecuteToolEventHandlerContext.kt index 2607446f7a..79e90449be 100644 --- a/agents/agents-core/src/commonMain/kotlin/ai/koog/agents/core/feature/handler/DeprecatedExecuteToolEventHandlerContext.kt +++ b/agents/agents-core/src/commonMain/kotlin/ai/koog/agents/core/feature/handler/DeprecatedExecuteToolEventHandlerContext.kt @@ -4,25 +4,25 @@ package ai.koog.agents.core.feature.handler * Represents the context for handling tool-specific events within the framework. */ @Deprecated( - message = "Use ToolExecutionEventContext instead", + message = "Use ToolCallEventContext instead", replaceWith = ReplaceWith( - expression = "ToolExecutionEventContext", - imports = arrayOf("ai.koog.agents.core.feature.handler.tool.ToolExecutionEventContext") + expression = "ToolCallEventContext", + imports = arrayOf("ai.koog.agents.core.feature.handler.tool.ToolCallEventContext") ) ) -public typealias ToolEventHandlerContext = ai.koog.agents.core.feature.handler.tool.ToolExecutionEventContext +public typealias ToolEventHandlerContext = ai.koog.agents.core.feature.handler.tool.ToolCallEventContext /** * Represents the context for handling a tool call event. */ @Deprecated( - message = "Use ToolExecutionStartingContext instead", + message = "Use ToolCallStartingContext instead", replaceWith = ReplaceWith( - expression = "ToolExecutionStartingContext", - imports = arrayOf("ai.koog.agents.core.feature.handler.tool.ToolExecutionStartingContext") + expression = "ToolCallStartingContext", + imports = arrayOf("ai.koog.agents.core.feature.handler.tool.ToolCallStartingContext") ) ) -public typealias ToolCallContext = ai.koog.agents.core.feature.handler.tool.ToolExecutionStartingContext +public typealias ToolCallContext = ai.koog.agents.core.feature.handler.tool.ToolCallStartingContext /** * Represents the context for handling validation errors that occur during the execution of a tool. @@ -40,22 +40,22 @@ public typealias ToolValidationErrorContext = ai.koog.agents.core.feature.handle * Represents the context provided to handle a failure during the execution of a tool. */ @Deprecated( - message = "Use ToolExecutionFailedContext instead", + message = "Use ToolCallFailedContext instead", replaceWith = ReplaceWith( - expression = "ToolExecutionFailedContext", - imports = arrayOf("ai.koog.agents.core.feature.handler.tool.ToolExecutionFailedContext") + expression = "ToolCallFailedContext", + imports = arrayOf("ai.koog.agents.core.feature.handler.tool.ToolCallFailedContext") ) ) -public typealias ToolCallFailureContext = ai.koog.agents.core.feature.handler.tool.ToolExecutionFailedContext +public typealias ToolCallFailureContext = ai.koog.agents.core.feature.handler.tool.ToolCallFailedContext /** * Represents the context used when handling the result of a tool call. */ @Deprecated( - message = "Use ToolExecutionCompletedContext instead", + message = "Use ToolCallCompletedContext instead", replaceWith = ReplaceWith( - expression = "ToolExecutionCompletedContext", - imports = arrayOf("ai.koog.agents.core.feature.handler.tool.ToolExecutionCompletedContext") + expression = "ToolCallCompletedContext", + imports = arrayOf("ai.koog.agents.core.feature.handler.tool.ToolCallCompletedContext") ) ) -public typealias ToolCallResultContext = ai.koog.agents.core.feature.handler.tool.ToolExecutionCompletedContext +public typealias ToolCallResultContext = ai.koog.agents.core.feature.handler.tool.ToolCallCompletedContext diff --git a/agents/agents-core/src/commonMain/kotlin/ai/koog/agents/core/feature/handler/tool/ToolExecutionEventContext.kt b/agents/agents-core/src/commonMain/kotlin/ai/koog/agents/core/feature/handler/tool/ToolCallEventContext.kt similarity index 84% rename from agents/agents-core/src/commonMain/kotlin/ai/koog/agents/core/feature/handler/tool/ToolExecutionEventContext.kt rename to agents/agents-core/src/commonMain/kotlin/ai/koog/agents/core/feature/handler/tool/ToolCallEventContext.kt index a6550b80fb..385661f77d 100644 --- a/agents/agents-core/src/commonMain/kotlin/ai/koog/agents/core/feature/handler/tool/ToolExecutionEventContext.kt +++ b/agents/agents-core/src/commonMain/kotlin/ai/koog/agents/core/feature/handler/tool/ToolCallEventContext.kt @@ -7,7 +7,7 @@ import ai.koog.agents.core.tools.Tool /** * Represents the context for handling tool-specific events within the framework. */ -public interface ToolExecutionEventContext : AgentLifecycleEventContext +public interface ToolCallEventContext : AgentLifecycleEventContext /** * Represents the context for handling a tool call event. @@ -15,13 +15,13 @@ public interface ToolExecutionEventContext : AgentLifecycleEventContext * @property tool The tool instance that is being executed. It encapsulates the logic and metadata for the operation. * @property toolArgs The arguments provided for the tool execution, adhering to the tool's expected input structure. */ -public data class ToolExecutionStartingContext( +public data class ToolCallStartingContext( val runId: String, val toolCallId: String?, val tool: Tool<*, *>, val toolArgs: Any? -) : ToolExecutionEventContext { - override val eventType: AgentLifecycleEventType = AgentLifecycleEventType.ToolExecutionStarting +) : ToolCallEventContext { + override val eventType: AgentLifecycleEventType = AgentLifecycleEventType.ToolCallStarting } /** @@ -37,7 +37,7 @@ public data class ToolValidationFailedContext( val tool: Tool<*, *>, val toolArgs: Any?, val error: String -) : ToolExecutionEventContext { +) : ToolCallEventContext { override val eventType: AgentLifecycleEventType = AgentLifecycleEventType.ToolValidationFailed } @@ -48,14 +48,14 @@ public data class ToolValidationFailedContext( * @param toolArgs The arguments that were passed to the tool during execution. * @param throwable The exception or error that caused the failure. */ -public data class ToolExecutionFailedContext( +public data class ToolCallFailedContext( val runId: String, val toolCallId: String?, val tool: Tool<*, *>, val toolArgs: Any?, val throwable: Throwable -) : ToolExecutionEventContext { - override val eventType: AgentLifecycleEventType = AgentLifecycleEventType.ToolExecutionFailed +) : ToolCallEventContext { + override val eventType: AgentLifecycleEventType = AgentLifecycleEventType.ToolCallFailed } /** @@ -65,12 +65,12 @@ public data class ToolExecutionFailedContext( * @param toolArgs The arguments required by the tool for execution. * @param result An optional result produced by the tool after execution can be null if not applicable. */ -public data class ToolExecutionCompletedContext( +public data class ToolCallCompletedContext( val runId: String, val toolCallId: String?, val tool: Tool<*, *>, val toolArgs: Any?, val result: Any? -) : ToolExecutionEventContext { - override val eventType: AgentLifecycleEventType = AgentLifecycleEventType.ToolExecutionCompleted +) : ToolCallEventContext { + override val eventType: AgentLifecycleEventType = AgentLifecycleEventType.ToolCallCompleted } diff --git a/agents/agents-core/src/commonMain/kotlin/ai/koog/agents/core/feature/handler/tool/ToolExecutionEventHandler.kt b/agents/agents-core/src/commonMain/kotlin/ai/koog/agents/core/feature/handler/tool/ToolCallEventHandler.kt similarity index 94% rename from agents/agents-core/src/commonMain/kotlin/ai/koog/agents/core/feature/handler/tool/ToolExecutionEventHandler.kt rename to agents/agents-core/src/commonMain/kotlin/ai/koog/agents/core/feature/handler/tool/ToolCallEventHandler.kt index 55708d4509..082fd1a26c 100644 --- a/agents/agents-core/src/commonMain/kotlin/ai/koog/agents/core/feature/handler/tool/ToolExecutionEventHandler.kt +++ b/agents/agents-core/src/commonMain/kotlin/ai/koog/agents/core/feature/handler/tool/ToolCallEventHandler.kt @@ -7,7 +7,7 @@ package ai.koog.agents.core.feature.handler.tool * This class provides properties that allow defining specific behavior * during different stages of a tool's execution process. */ -public class ToolExecutionEventHandler { +public class ToolCallEventHandler { /** * A variable of type [ToolCallHandler] used to handle tool call operations. * It provides a mechanism for executing specific logic when a tool is called @@ -61,7 +61,7 @@ public fun interface ToolCallHandler { /** * Handles the execution of a given tool using the provided arguments. */ - public suspend fun handle(eventContext: ToolExecutionStartingContext) + public suspend fun handle(eventContext: ToolCallStartingContext) } /** @@ -84,7 +84,7 @@ public fun interface ToolCallFailureHandler { /** * Handles a failure that occurs during the execution of a tool call. */ - public suspend fun handle(eventContext: ToolExecutionFailedContext) + public suspend fun handle(eventContext: ToolCallFailedContext) } /** @@ -96,5 +96,5 @@ public fun interface ToolCallResultHandler { /** * Handles the execution of a specific tool by processing its arguments and optionally handling its result. */ - public suspend fun handle(eventContext: ToolExecutionCompletedContext) + public suspend fun handle(eventContext: ToolCallCompletedContext) } diff --git a/agents/agents-core/src/commonMain/kotlin/ai/koog/agents/core/feature/model/events/llmCallEvents.kt b/agents/agents-core/src/commonMain/kotlin/ai/koog/agents/core/feature/model/events/llmCallEvents.kt index 3e749b20c8..578f8d5694 100644 --- a/agents/agents-core/src/commonMain/kotlin/ai/koog/agents/core/feature/model/events/llmCallEvents.kt +++ b/agents/agents-core/src/commonMain/kotlin/ai/koog/agents/core/feature/model/events/llmCallEvents.kt @@ -13,10 +13,11 @@ import kotlinx.serialization.Serializable * input prompt and any tools that will be used during the call. It extends the `DefinedFeatureEvent` class * and serves as a specific type of event in a feature-driven framework. * + * @property runId A unique identifier associated with the specific run of the LLM call. * @property prompt The input prompt encapsulated as a [Prompt] object. This represents the structured set of * messages and configuration parameters sent to the LLM. - * @property tools The list of tools, represented by their string identifiers, being used within the scope - * of the LLM call. These tools may extend or enhance the core functionality of the LLM. + * @property model The description of the LLM model used during the call. Use the format: 'llm_provider:model_id'; + * @property tools A list of tools used or invoked during the LLM call. * @property eventId A string representing the event type; * @property timestamp The timestamp of the event, in milliseconds since the Unix epoch. */ @@ -38,8 +39,14 @@ public data class LLMCallStartingEvent( * The event is used within the system to capture relevant output data and ensure proper tracking * and logging of LLM-related interactions. * + * @property runId The unique identifier of the LLM run. + * @property prompt The input prompt encapsulated as a [Prompt] object. This represents the structured set of + * messages and configuration parameters sent to the LLM. + * @property model The description of the LLM model used during the call. Use the format: 'llm_provider:model_id'; * @property responses A list of responses generated by the LLM, represented as instances of [Message.Response]. * Each response contains content, metadata, and additional context about the interaction. + * @property moderationResponse The moderation response, if any, returned by the LLM. + * This is typically used to capture and track content moderation results. * @property eventId A string representing the event type; * @property timestamp The timestamp of the event, in milliseconds since the Unix epoch. */ diff --git a/agents/agents-core/src/commonMain/kotlin/ai/koog/agents/core/feature/model/events/nodeExecutionEvents.kt b/agents/agents-core/src/commonMain/kotlin/ai/koog/agents/core/feature/model/events/nodeExecutionEvents.kt index 6455b7850f..2930a51653 100644 --- a/agents/agents-core/src/commonMain/kotlin/ai/koog/agents/core/feature/model/events/nodeExecutionEvents.kt +++ b/agents/agents-core/src/commonMain/kotlin/ai/koog/agents/core/feature/model/events/nodeExecutionEvents.kt @@ -39,7 +39,7 @@ public data class NodeExecutionStartingEvent( * @property nodeName The name of the node that finished execution; * @property input The input data provided to the node; * @property output The output generated by the node; - * @property eventId A unique identifier for the event type; + * @property eventId A string representing the event type; * @property timestamp The timestamp of the event, in milliseconds since the Unix epoch. */ @Serializable diff --git a/agents/agents-core/src/commonMain/kotlin/ai/koog/agents/core/feature/model/events/toolExecutionEvents.kt b/agents/agents-core/src/commonMain/kotlin/ai/koog/agents/core/feature/model/events/toolExecutionEvents.kt index e8933972e3..41d0efa04c 100644 --- a/agents/agents-core/src/commonMain/kotlin/ai/koog/agents/core/feature/model/events/toolExecutionEvents.kt +++ b/agents/agents-core/src/commonMain/kotlin/ai/koog/agents/core/feature/model/events/toolExecutionEvents.kt @@ -19,12 +19,12 @@ import kotlinx.serialization.json.JsonObject * @property timestamp The timestamp of the event, in milliseconds since the Unix epoch. */ @Serializable -public data class ToolExecutionStartingEvent( +public data class ToolCallStartingEvent( val runId: String, val toolCallId: String?, val toolName: String, val toolArgs: JsonObject, - override val eventId: String = ToolExecutionStartingEvent::class.simpleName!!, + override val eventId: String = ToolCallStartingEvent::class.simpleName!!, override val timestamp: Long = Clock.System.now().toEpochMilliseconds(), ) : DefinedFeatureEvent() @@ -65,13 +65,13 @@ public data class ToolValidationFailedEvent( * @property timestamp The timestamp of the event, in milliseconds since the Unix epoch. */ @Serializable -public data class ToolExecutionFailedEvent( +public data class ToolCallFailedEvent( val runId: String, val toolCallId: String?, val toolName: String, val toolArgs: JsonObject, val error: AIAgentError, - override val eventId: String = ToolExecutionFailedEvent::class.simpleName!!, + override val eventId: String = ToolCallFailedEvent::class.simpleName!!, override val timestamp: Long = Clock.System.now().toEpochMilliseconds(), ) : DefinedFeatureEvent() @@ -89,23 +89,23 @@ public data class ToolExecutionFailedEvent( * @property timestamp The timestamp of the event, in milliseconds since the Unix epoch. */ @Serializable -public data class ToolExecutionCompletedEvent( +public data class ToolCallCompletedEvent( val runId: String, val toolCallId: String?, val toolName: String, val toolArgs: JsonObject, val result: String?, - override val eventId: String = ToolExecutionCompletedEvent::class.simpleName!!, + override val eventId: String = ToolCallCompletedEvent::class.simpleName!!, override val timestamp: Long = Clock.System.now().toEpochMilliseconds(), ) : DefinedFeatureEvent() //region Deprecated @Deprecated( - message = "Use ToolExecutionStartingEvent instead", - replaceWith = ReplaceWith("ToolExecutionStartingEvent") + message = "Use ToolCallStartingEvent instead", + replaceWith = ReplaceWith("ToolCallStartingEvent") ) -public typealias ToolCallEvent = ToolExecutionStartingEvent +public typealias ToolCallEvent = ToolCallStartingEvent @Deprecated( message = "Use ToolValidationFailedEvent instead", @@ -114,15 +114,15 @@ public typealias ToolCallEvent = ToolExecutionStartingEvent public typealias ToolValidationErrorEvent = ToolValidationFailedEvent @Deprecated( - message = "Use ToolExecutionFailedEvent instead", - replaceWith = ReplaceWith("ToolExecutionFailedEvent") + message = "Use ToolCallFailedEvent instead", + replaceWith = ReplaceWith("ToolCallFailedEvent") ) -public typealias ToolCallFailureEvent = ToolExecutionFailedEvent +public typealias ToolCallFailureEvent = ToolCallFailedEvent @Deprecated( - message = "Use ToolExecutionCompletedEvent instead", - replaceWith = ReplaceWith("ToolExecutionCompletedEvent") + message = "Use ToolCallCompletedEvent instead", + replaceWith = ReplaceWith("ToolCallCompletedEvent") ) -public typealias ToolCallResultEvent = ToolExecutionCompletedEvent +public typealias ToolCallResultEvent = ToolCallCompletedEvent //endregion Deprecated diff --git a/agents/agents-core/src/commonMain/kotlin/ai/koog/agents/core/feature/remote/jsonConfig.kt b/agents/agents-core/src/commonMain/kotlin/ai/koog/agents/core/feature/remote/jsonConfig.kt index 21504e9055..0ebe4bd667 100644 --- a/agents/agents-core/src/commonMain/kotlin/ai/koog/agents/core/feature/remote/jsonConfig.kt +++ b/agents/agents-core/src/commonMain/kotlin/ai/koog/agents/core/feature/remote/jsonConfig.kt @@ -18,9 +18,9 @@ import ai.koog.agents.core.feature.model.events.NodeExecutionFailedEvent import ai.koog.agents.core.feature.model.events.NodeExecutionStartingEvent import ai.koog.agents.core.feature.model.events.StrategyCompletedEvent import ai.koog.agents.core.feature.model.events.StrategyStartingEvent -import ai.koog.agents.core.feature.model.events.ToolExecutionCompletedEvent -import ai.koog.agents.core.feature.model.events.ToolExecutionFailedEvent -import ai.koog.agents.core.feature.model.events.ToolExecutionStartingEvent +import ai.koog.agents.core.feature.model.events.ToolCallCompletedEvent +import ai.koog.agents.core.feature.model.events.ToolCallFailedEvent +import ai.koog.agents.core.feature.model.events.ToolCallStartingEvent import ai.koog.agents.core.feature.model.events.ToolValidationFailedEvent import io.ktor.utils.io.InternalAPI import kotlinx.serialization.DeserializationStrategy @@ -75,10 +75,10 @@ public val defaultFeatureMessageJsonConfig: Json * - [StrategyCompletedEvent] - Fired when an AI agent strategy completes * - [NodeExecutionStartingEvent] - Fired when a node execution starts * - [NodeExecutionCompletedEvent] - Fired when a node execution ends - * - [ToolExecutionStartingEvent] - Fired when a tool is called + * - [ToolCallStartingEvent] - Fired when a tool is called * - [ToolValidationFailedEvent] - Fired when tool validation fails - * - [ToolExecutionFailedEvent] - Fired when a tool call fails - * - [ToolExecutionCompletedEvent] - Fired when a tool call returns a result + * - [ToolCallFailedEvent] - Fired when a tool call fails + * - [ToolCallCompletedEvent] - Fired when a tool call returns a result * - [LLMCallStartingEvent] - Fired before making an LLM call * - [LLMCallCompletedEvent] - Fired after completing an LLM call * @@ -101,10 +101,10 @@ public val defaultFeatureMessageSerializersModule: SerializersModule subclass(NodeExecutionStartingEvent::class, NodeExecutionStartingEvent.serializer()) subclass(NodeExecutionCompletedEvent::class, NodeExecutionCompletedEvent.serializer()) subclass(NodeExecutionFailedEvent::class, NodeExecutionFailedEvent.serializer()) - subclass(ToolExecutionStartingEvent::class, ToolExecutionStartingEvent.serializer()) + subclass(ToolCallStartingEvent::class, ToolCallStartingEvent.serializer()) subclass(ToolValidationFailedEvent::class, ToolValidationFailedEvent.serializer()) - subclass(ToolExecutionFailedEvent::class, ToolExecutionFailedEvent.serializer()) - subclass(ToolExecutionCompletedEvent::class, ToolExecutionCompletedEvent.serializer()) + subclass(ToolCallFailedEvent::class, ToolCallFailedEvent.serializer()) + subclass(ToolCallCompletedEvent::class, ToolCallCompletedEvent.serializer()) subclass(LLMCallStartingEvent::class, LLMCallStartingEvent.serializer()) subclass(LLMCallCompletedEvent::class, LLMCallCompletedEvent.serializer()) } @@ -121,10 +121,10 @@ public val defaultFeatureMessageSerializersModule: SerializersModule subclass(NodeExecutionStartingEvent::class, NodeExecutionStartingEvent.serializer()) subclass(NodeExecutionCompletedEvent::class, NodeExecutionCompletedEvent.serializer()) subclass(NodeExecutionFailedEvent::class, NodeExecutionFailedEvent.serializer()) - subclass(ToolExecutionStartingEvent::class, ToolExecutionStartingEvent.serializer()) + subclass(ToolCallStartingEvent::class, ToolCallStartingEvent.serializer()) subclass(ToolValidationFailedEvent::class, ToolValidationFailedEvent.serializer()) - subclass(ToolExecutionFailedEvent::class, ToolExecutionFailedEvent.serializer()) - subclass(ToolExecutionCompletedEvent::class, ToolExecutionCompletedEvent.serializer()) + subclass(ToolCallFailedEvent::class, ToolCallFailedEvent.serializer()) + subclass(ToolCallCompletedEvent::class, ToolCallCompletedEvent.serializer()) subclass(LLMCallStartingEvent::class, LLMCallStartingEvent.serializer()) subclass(LLMCallCompletedEvent::class, LLMCallCompletedEvent.serializer()) } @@ -140,10 +140,10 @@ public val defaultFeatureMessageSerializersModule: SerializersModule subclass(NodeExecutionStartingEvent::class, NodeExecutionStartingEvent.serializer()) subclass(NodeExecutionCompletedEvent::class, NodeExecutionCompletedEvent.serializer()) subclass(NodeExecutionFailedEvent::class, NodeExecutionFailedEvent.serializer()) - subclass(ToolExecutionStartingEvent::class, ToolExecutionStartingEvent.serializer()) + subclass(ToolCallStartingEvent::class, ToolCallStartingEvent.serializer()) subclass(ToolValidationFailedEvent::class, ToolValidationFailedEvent.serializer()) - subclass(ToolExecutionFailedEvent::class, ToolExecutionFailedEvent.serializer()) - subclass(ToolExecutionCompletedEvent::class, ToolExecutionCompletedEvent.serializer()) + subclass(ToolCallFailedEvent::class, ToolCallFailedEvent.serializer()) + subclass(ToolCallCompletedEvent::class, ToolCallCompletedEvent.serializer()) subclass(LLMCallStartingEvent::class, LLMCallStartingEvent.serializer()) subclass(LLMCallCompletedEvent::class, LLMCallCompletedEvent.serializer()) } diff --git a/agents/agents-core/src/commonTest/kotlin/ai/koog/agents/core/agent/FunctionalAIAgentTest.kt b/agents/agents-core/src/commonTest/kotlin/ai/koog/agents/core/agent/FunctionalAIAgentTest.kt index 29324e9767..c0729e274b 100644 --- a/agents/agents-core/src/commonTest/kotlin/ai/koog/agents/core/agent/FunctionalAIAgentTest.kt +++ b/agents/agents-core/src/commonTest/kotlin/ai/koog/agents/core/agent/FunctionalAIAgentTest.kt @@ -52,7 +52,7 @@ class FunctionalAIAgentTest { toolRegistry = testToolRegistry ) { install(EventHandler) { - onToolCall { eventContext -> actualToolCalls += eventContext.toolArgs.toString() } + onToolCallStarting { eventContext -> actualToolCalls += eventContext.toolArgs.toString() } } } @@ -92,7 +92,7 @@ class FunctionalAIAgentTest { toolRegistry = testToolRegistry, ) { install(EventHandler) { - onToolCall { eventContext -> actualToolCalls += eventContext.toolArgs.toString() } + onToolCallStarting { eventContext -> actualToolCalls += eventContext.toolArgs.toString() } } } @@ -134,7 +134,7 @@ class FunctionalAIAgentTest { } ) { install(EventHandler) { - onToolCall { eventContext -> actualToolCalls += eventContext.toolArgs.toString() } + onToolCallStarting { eventContext -> actualToolCalls += eventContext.toolArgs.toString() } } } diff --git a/agents/agents-core/src/commonTest/kotlin/ai/koog/agents/core/agent/SingleRunStrategyTests.kt b/agents/agents-core/src/commonTest/kotlin/ai/koog/agents/core/agent/SingleRunStrategyTests.kt index ef525ae3cb..d4de2bd1a0 100644 --- a/agents/agents-core/src/commonTest/kotlin/ai/koog/agents/core/agent/SingleRunStrategyTests.kt +++ b/agents/agents-core/src/commonTest/kotlin/ai/koog/agents/core/agent/SingleRunStrategyTests.kt @@ -32,7 +32,7 @@ class SingleRunStrategyTests { toolRegistry = testToolRegistry ) { install(EventHandler) { - onToolCall { eventContext -> actualToolCalls += eventContext.toolArgs.toString() } + onToolCallStarting { eventContext -> actualToolCalls += eventContext.toolArgs.toString() } } } @@ -63,7 +63,7 @@ class SingleRunStrategyTests { toolRegistry = testToolRegistry ) { install(EventHandler) { - onToolCall { eventContext -> actualToolCalls += eventContext.toolArgs.toString() } + onToolCallStarting { eventContext -> actualToolCalls += eventContext.toolArgs.toString() } } } @@ -93,7 +93,7 @@ class SingleRunStrategyTests { toolRegistry = testToolRegistry ) { install(EventHandler) { - onToolCall { eventContext -> actualToolCalls += eventContext.toolArgs.toString() } + onToolCallStarting { eventContext -> actualToolCalls += eventContext.toolArgs.toString() } } } @@ -123,7 +123,7 @@ class SingleRunStrategyTests { toolRegistry = testToolRegistry ) { install(EventHandler) { - onToolCall { eventContext -> actualToolCalls += eventContext.toolArgs.toString() } + onToolCallStarting { eventContext -> actualToolCalls += eventContext.toolArgs.toString() } } } @@ -162,7 +162,7 @@ class SingleRunStrategyTests { toolRegistry = testToolRegistry ) { install(EventHandler) { - onToolCall { eventContext -> actualToolCalls += eventContext.toolArgs.toString() } + onToolCallStarting { eventContext -> actualToolCalls += eventContext.toolArgs.toString() } } } @@ -201,7 +201,7 @@ class SingleRunStrategyTests { toolRegistry = testToolRegistry ) { install(EventHandler) { - onToolCall { eventContext -> actualToolCalls += eventContext.toolArgs.toString() } + onToolCallStarting { eventContext -> actualToolCalls += eventContext.toolArgs.toString() } } } @@ -241,7 +241,7 @@ class SingleRunStrategyTests { toolRegistry = testToolRegistry ) { install(EventHandler) { - onToolCall { eventContext -> actualToolCalls += eventContext.toolArgs.toString() } + onToolCallStarting { eventContext -> actualToolCalls += eventContext.toolArgs.toString() } } } @@ -281,7 +281,7 @@ class SingleRunStrategyTests { toolRegistry = testToolRegistry ) { install(EventHandler) { - onToolCall { eventContext -> actualToolCalls += eventContext.toolArgs.toString() } + onToolCallStarting { eventContext -> actualToolCalls += eventContext.toolArgs.toString() } } } diff --git a/agents/agents-core/src/commonTest/kotlin/ai/koog/agents/core/dsl/extension/AIAgentNodesHistoryCompressionTest.kt b/agents/agents-core/src/commonTest/kotlin/ai/koog/agents/core/dsl/extension/AIAgentNodesHistoryCompressionTest.kt index bdfa27e634..b32b8e63af 100644 --- a/agents/agents-core/src/commonTest/kotlin/ai/koog/agents/core/dsl/extension/AIAgentNodesHistoryCompressionTest.kt +++ b/agents/agents-core/src/commonTest/kotlin/ai/koog/agents/core/dsl/extension/AIAgentNodesHistoryCompressionTest.kt @@ -8,6 +8,7 @@ import ai.koog.agents.core.tools.ToolRegistry import ai.koog.agents.features.eventHandler.feature.EventHandler import ai.koog.agents.features.eventHandler.feature.handleEvents import ai.koog.agents.testing.tools.DummyTool +import ai.koog.agents.utils.use import ai.koog.prompt.dsl.prompt import ai.koog.prompt.llm.OllamaModels import ai.koog.prompt.message.Message @@ -55,7 +56,7 @@ class AIAgentNodesHistoryCompressionTest { maxAgentIterations = 10 ) - val runner = AIAgent( + AIAgent( promptExecutor = testExecutor, strategy = agentStrategy, agentConfig = agentConfig, @@ -64,12 +65,12 @@ class AIAgentNodesHistoryCompressionTest { } ) { install(EventHandler) { - onAgentFinished { eventContext -> results += eventContext.result } + onAgentCompleted { eventContext -> results += eventContext.result } } + }.use { agent -> + agent.run("") } - runner.run("") - // After compression, we should have one result assertEquals(1, results.size) assertEquals("Done", results.first()) @@ -107,7 +108,7 @@ class AIAgentNodesHistoryCompressionTest { maxAgentIterations = 10 ) - val runner = AIAgent( + AIAgent( promptExecutor = testExecutor, strategy = agentStrategy, agentConfig = agentConfig, @@ -116,12 +117,12 @@ class AIAgentNodesHistoryCompressionTest { } ) { install(EventHandler) { - onAgentFinished { eventContext -> results += eventContext.result } + onAgentCompleted { eventContext -> results += eventContext.result } } + }.use { agent -> + agent.run("") } - runner.run("") - // After compression, we should have one result assertEquals(1, results.size) assertEquals("Done", results.first()) @@ -162,7 +163,7 @@ class AIAgentNodesHistoryCompressionTest { maxAgentIterations = 10 ) - val runner = AIAgent( + AIAgent( promptExecutor = testExecutor, strategy = agentStrategy, agentConfig = agentConfig, @@ -171,12 +172,12 @@ class AIAgentNodesHistoryCompressionTest { } ) { handleEvents { - onAgentFinished { eventContext -> results += eventContext.result } + onAgentCompleted { eventContext -> results += eventContext.result } } + }.use { agent -> + agent.run("") } - runner.run("") - // After compression, we should have one result assertEquals(1, results.size) assertEquals("Done", results.first()) diff --git a/agents/agents-core/src/commonTest/kotlin/ai/koog/agents/core/dsl/extension/AIAgentNodesTest.kt b/agents/agents-core/src/commonTest/kotlin/ai/koog/agents/core/dsl/extension/AIAgentNodesTest.kt index eecf5baf70..76787d3a36 100644 --- a/agents/agents-core/src/commonTest/kotlin/ai/koog/agents/core/dsl/extension/AIAgentNodesTest.kt +++ b/agents/agents-core/src/commonTest/kotlin/ai/koog/agents/core/dsl/extension/AIAgentNodesTest.kt @@ -9,6 +9,7 @@ import ai.koog.agents.features.eventHandler.feature.EventHandler import ai.koog.agents.testing.tools.DummyTool import ai.koog.agents.testing.tools.getMockExecutor import ai.koog.agents.testing.tools.mockLLMAnswer +import ai.koog.agents.utils.use import ai.koog.prompt.dsl.Prompt import ai.koog.prompt.dsl.prompt import ai.koog.prompt.executor.clients.openai.OpenAIModels @@ -49,7 +50,7 @@ class AIAgentNodesTest { mockLLMAnswer("Default test response").asDefaultResponse } - val runner = AIAgent( + AIAgent( promptExecutor = testExecutor, strategy = agentStrategy, agentConfig = agentConfig, @@ -58,12 +59,12 @@ class AIAgentNodesTest { } ) { install(EventHandler) { - onAgentFinished { eventContext -> results += eventContext.result } + onAgentCompleted { eventContext -> results += eventContext.result } } + }.use { agent -> + agent.run("") } - runner.run("") - // After compression, we should have one result assertEquals(1, results.size) assertEquals("Done", results.first()) @@ -110,7 +111,7 @@ class AIAgentNodesTest { maxAgentIterations = 10 ) - val runner = AIAgent( + AIAgent( promptExecutor = modelCapturingExecutor, strategy = agentStrategy, agentConfig = agentConfig, @@ -119,29 +120,30 @@ class AIAgentNodesTest { } ) { install(EventHandler) { - onAgentFinished { eventContext -> + onAgentCompleted { eventContext -> executionEvents += "Agent finished" results += eventContext.result } } - } + }.use { agent -> - val executionResult = runner.run("Heeeey") + val executionResult = agent.run("Heeeey") - assertEquals("Done", executionResult, "Agent execution should return 'Done'") - assertEquals(1, results.size, "Should have exactly one result") + assertEquals("Done", executionResult, "Agent execution should return 'Done'") + assertEquals(1, results.size, "Should have exactly one result") - assertTrue(executionEvents.contains("nodeStart -> compress"), "Should transition from start to compress") - assertTrue(executionEvents.contains("compress -> nodeFinish"), "Should transition from compress to finish") + assertTrue(executionEvents.contains("nodeStart -> compress"), "Should transition from start to compress") + assertTrue(executionEvents.contains("compress -> nodeFinish"), "Should transition from compress to finish") - assertTrue( - agentConfig.prompt.messages.any { it.content.contains("testing history compression") }, - "Prompt should contain test content for compression" - ) - assertTrue( - executionEvents.size >= 3, - "Should have at least 3 execution events (agent finished, node transitions)" - ) + assertTrue( + agentConfig.prompt.messages.any { it.content.contains("testing history compression") }, + "Prompt should contain test content for compression" + ) + assertTrue( + executionEvents.size >= 3, + "Should have at least 3 execution events (agent finished, node transitions)" + ) + } } @Test diff --git a/agents/agents-core/src/commonTest/kotlin/ai/koog/agents/core/feature/TestFeature.kt b/agents/agents-core/src/commonTest/kotlin/ai/koog/agents/core/feature/TestFeature.kt index 8a9eba1e3e..173c75fb90 100644 --- a/agents/agents-core/src/commonTest/kotlin/ai/koog/agents/core/feature/TestFeature.kt +++ b/agents/agents-core/src/commonTest/kotlin/ai/koog/agents/core/feature/TestFeature.kt @@ -70,11 +70,11 @@ class TestFeature(val events: MutableList, val runIds: MutableList + pipeline.interceptToolCallStarting(context) { event -> feature.events += "Tool: call tool (tool: ${event.tool.name}, args: ${event.toolArgs})" } - pipeline.interceptToolExecutionCompleted(context) { event -> + pipeline.interceptToolCallCompleted(context) { event -> feature.events += "Tool: finish tool call with result (tool: ${event.tool.name}, result: ${event.result?.let(event.tool::encodeResultToStringUnsafe) ?: "null"})" } diff --git a/agents/agents-ext/src/commonTest/kotlin/ai/koog/agents/ext/agent/StructuredOutputWithToolsIntegrationTest.kt b/agents/agents-ext/src/commonTest/kotlin/ai/koog/agents/ext/agent/StructuredOutputWithToolsIntegrationTest.kt index c007ece354..3968646e22 100644 --- a/agents/agents-ext/src/commonTest/kotlin/ai/koog/agents/ext/agent/StructuredOutputWithToolsIntegrationTest.kt +++ b/agents/agents-ext/src/commonTest/kotlin/ai/koog/agents/ext/agent/StructuredOutputWithToolsIntegrationTest.kt @@ -3,10 +3,6 @@ package ai.koog.agents.ext.agent import ai.koog.agents.core.agent.AIAgent import ai.koog.agents.core.agent.config.AIAgentConfig import ai.koog.agents.core.tools.SimpleTool -import ai.koog.agents.core.tools.ToolArgs -import ai.koog.agents.core.tools.ToolDescriptor -import ai.koog.agents.core.tools.ToolParameterDescriptor -import ai.koog.agents.core.tools.ToolParameterType import ai.koog.agents.core.tools.ToolRegistry import ai.koog.agents.core.tools.annotations.LLMDescription import ai.koog.agents.features.eventHandler.feature.EventHandler @@ -46,18 +42,16 @@ class StructuredOutputWithToolsIntegrationTest { object GetTemperatureTool : SimpleTool() { @Serializable - data class Args(val city: String, val country: String) : ToolArgs + data class Args( + @property:LLMDescription("City name") + val city: String, + @property:LLMDescription("Country name") + val country: String + ) override val argsSerializer: KSerializer = Args.serializer() - - override val descriptor: ToolDescriptor = ToolDescriptor( - name = "get_temperature", - description = "Get current temperature for a city", - requiredParameters = listOf( - ToolParameterDescriptor("city", "City name", ToolParameterType.String), - ToolParameterDescriptor("country", "Country name", ToolParameterType.String) - ) - ) + override val name: String = "get_temperature" + override val description: String = "Get current temperature for a city" override suspend fun doExecute(args: Args): String = "Temperature in ${args.city}, ${args.country}: 22°C" @@ -65,18 +59,16 @@ class StructuredOutputWithToolsIntegrationTest { object GetWeatherConditionsTool : SimpleTool() { @Serializable - data class Args(val city: String, val country: String) : ToolArgs + data class Args( + @property:LLMDescription("City name") + val city: String, + @property:LLMDescription("Country name") + val country: String + ) override val argsSerializer: KSerializer = Args.serializer() - - override val descriptor: ToolDescriptor = ToolDescriptor( - name = "get_weather_conditions", - description = "Get current weather conditions for a city", - requiredParameters = listOf( - ToolParameterDescriptor("city", "City name", ToolParameterType.String), - ToolParameterDescriptor("country", "Country name", ToolParameterType.String) - ) - ) + override val name: String = "get_weather_conditions" + override val description: String = "Get current weather conditions for a city" override suspend fun doExecute(args: Args): String = "Weather conditions in ${args.city}, ${args.country}: Partly Cloudy" @@ -84,18 +76,16 @@ class StructuredOutputWithToolsIntegrationTest { object GetWindSpeedTool : SimpleTool() { @Serializable - data class Args(val city: String, val country: String) : ToolArgs + data class Args( + @property:LLMDescription("City name") + val city: String, + @property:LLMDescription("Country name") + val country: String + ) override val argsSerializer: KSerializer = Args.serializer() - - override val descriptor: ToolDescriptor = ToolDescriptor( - name = "get_wind_speed", - description = "Get current wind speed for a city", - requiredParameters = listOf( - ToolParameterDescriptor("city", "City name", ToolParameterType.String), - ToolParameterDescriptor("country", "Country name", ToolParameterType.String) - ) - ) + override val name: String = "get_wind_speed" + override val description: String = "Get current wind speed for a city" override suspend fun doExecute(args: Args): String = "Wind speed in ${args.city}, ${args.country}: 15.5 km/h" @@ -103,18 +93,16 @@ class StructuredOutputWithToolsIntegrationTest { object GetHumidityTool : SimpleTool() { @Serializable - data class Args(val city: String, val country: String) : ToolArgs + data class Args( + @property:LLMDescription("City name") + val city: String, + @property:LLMDescription("Country name") + val country: String + ) override val argsSerializer: KSerializer = Args.serializer() - - override val descriptor: ToolDescriptor = ToolDescriptor( - name = "get_humidity", - description = "Get current humidity for a city", - requiredParameters = listOf( - ToolParameterDescriptor("city", "City name", ToolParameterType.String), - ToolParameterDescriptor("country", "Country name", ToolParameterType.String) - ) - ) + override val name: String = "get_humidity" + override val description: String = "Get current humidity for a city" override suspend fun doExecute(args: Args): String = "Humidity in ${args.city}, ${args.country}: 65%" @@ -177,10 +165,10 @@ class StructuredOutputWithToolsIntegrationTest { } ) { install(EventHandler) { - onToolCall { eventContext -> + onToolCallStarting { eventContext -> toolCallEvents.add(eventContext.tool.name) } - onAgentFinished { eventContext -> + onAgentCompleted { eventContext -> eventContext.result?.let { results.add(it as WeatherResponse) } } } @@ -252,7 +240,7 @@ class StructuredOutputWithToolsIntegrationTest { } ) { install(EventHandler) { - onToolCall { eventContext -> + onToolCallStarting { eventContext -> toolCallTimestamps[eventContext.tool.name] = currentTime } } diff --git a/agents/agents-ext/src/commonTest/kotlin/ai/koog/agents/ext/agent/SubgraphWithRetryTest.kt b/agents/agents-ext/src/commonTest/kotlin/ai/koog/agents/ext/agent/SubgraphWithRetryTest.kt index a82b26c50e..e13b41cc45 100644 --- a/agents/agents-ext/src/commonTest/kotlin/ai/koog/agents/ext/agent/SubgraphWithRetryTest.kt +++ b/agents/agents-ext/src/commonTest/kotlin/ai/koog/agents/ext/agent/SubgraphWithRetryTest.kt @@ -8,6 +8,7 @@ import ai.koog.agents.core.tools.ToolRegistry import ai.koog.agents.features.eventHandler.feature.EventHandler import ai.koog.agents.testing.tools.DummyTool import ai.koog.agents.testing.tools.getMockExecutor +import ai.koog.agents.utils.use import ai.koog.prompt.dsl.prompt import ai.koog.prompt.llm.OllamaModels import ai.koog.prompt.message.Message @@ -119,7 +120,7 @@ class SubgraphWithRetryTest { maxAgentIterations = MAX_AGENT_ITERATIONS, ) - val agent = AIAgent( + AIAgent( promptExecutor = getMockExecutor {}, strategy = testStrategy, agentConfig = agentConfig, @@ -128,12 +129,12 @@ class SubgraphWithRetryTest { }, ) { install(EventHandler) { - onAgentFinished { eventContext -> results += eventContext.result } + onAgentCompleted { eventContext -> results += eventContext.result } } + }.use { agent -> + agent.run("test input") } - agent.run("test input") - assertEquals(1, results.size) val result = results.first() as RetrySubgraphResult<*> assertEquals(SUCCESS, result.output) @@ -176,7 +177,7 @@ class SubgraphWithRetryTest { maxAgentIterations = MAX_AGENT_ITERATIONS, ) - val agent = AIAgent( + AIAgent( promptExecutor = getMockExecutor {}, strategy = testStrategy, agentConfig = agentConfig, @@ -185,12 +186,12 @@ class SubgraphWithRetryTest { }, ) { install(EventHandler) { - onAgentFinished { eventContext -> results += eventContext.result } + onAgentCompleted { eventContext -> results += eventContext.result } } + }.use { agent -> + agent.run("test input") } - agent.run("test input") - assertEquals(1, results.size) val result = results.first() as RetrySubgraphResult<*> assertEquals(SUCCESS, result.output) @@ -229,7 +230,7 @@ class SubgraphWithRetryTest { maxAgentIterations = MAX_AGENT_ITERATIONS, ) - val agent = AIAgent( + AIAgent( promptExecutor = getMockExecutor {}, strategy = testStrategy, agentConfig = agentConfig, @@ -238,17 +239,19 @@ class SubgraphWithRetryTest { }, ) { install(EventHandler) { - onAgentFinished { eventContext -> results += eventContext.result } + onAgentCompleted { eventContext -> results += eventContext.result } } - } + }.use { agent -> - agent.run("test input") + agent.run("test input") - assertEquals(1, results.size) - val result = results.first() as RetrySubgraphResult<*> - assertEquals(maxAttempts, result.retryCount) - assertEquals(maxAttempts, attemptCount.size) - assertFalse(result.success) + assertEquals(1, results.size) + + val result = results.first() as RetrySubgraphResult<*> + assertEquals(maxAttempts, result.retryCount) + assertEquals(maxAttempts, attemptCount.size) + assertFalse(result.success) + } } @Test @@ -293,7 +296,7 @@ class SubgraphWithRetryTest { maxAgentIterations = MAX_AGENT_ITERATIONS, ) - val agent = AIAgent( + AIAgent( promptExecutor = getMockExecutor {}, strategy = testStrategy, agentConfig = agentConfig, @@ -302,12 +305,12 @@ class SubgraphWithRetryTest { }, ) { install(EventHandler) { - onAgentFinished { eventContext -> results += eventContext.result } + onAgentCompleted { eventContext -> results += eventContext.result } } + }.use { agent -> + agent.run("test input") } - agent.run("test input") - assertEquals(1, results.size) assertEquals(SUCCESS, results.first()) } @@ -352,7 +355,7 @@ class SubgraphWithRetryTest { }, ) { install(EventHandler) { - onAgentFinished { eventContext -> results += eventContext.result } + onAgentCompleted { eventContext -> results += eventContext.result } } } @@ -394,7 +397,7 @@ class SubgraphWithRetryTest { maxAgentIterations = MAX_AGENT_ITERATIONS, ) - val agent = AIAgent( + AIAgent( promptExecutor = getMockExecutor {}, strategy = testStrategy, agentConfig = agentConfig, @@ -403,15 +406,16 @@ class SubgraphWithRetryTest { }, ) { install(EventHandler) { - onAgentFinished { eventContext -> results += eventContext.result } + onAgentCompleted { eventContext -> results += eventContext.result } } - } + }.use { agent -> - agent.run("test input") + agent.run("test input") - assertEquals(1, results.size) - assertEquals("failure", results.first()) - assertEquals(maxAttempts, attemptCount.size) + assertEquals(1, results.size) + assertEquals("failure", results.first()) + assertEquals(maxAttempts, attemptCount.size) + } } @Test diff --git a/agents/agents-features/Module.md b/agents/agents-features/Module.md index 055f5de60e..10879d5e36 100644 --- a/agents/agents-features/Module.md +++ b/agents/agents-features/Module.md @@ -28,11 +28,11 @@ Features integrate with the agent pipeline via interceptor hooks and consume sta - LLMCallStartingEvent - LLMCallCompletedEvent -- Tool execution events: - - ToolExecutionStartingEvent +- Tool call events: + - ToolCallStartingEvent - ToolValidationFailedEvent - - ToolExecutionFailedEvent - - ToolExecutionCompletedEvent + - ToolCallFailedEvent + - ToolCallCompletedEvent These events are produced by features such as Tracing and Debugger to enable logging, tracing, monitoring, and remote inspection. diff --git a/agents/agents-features/agents-features-debugger/src/commonMain/kotlin/ai/koog/agents/features/debugger/feature/Debugger.kt b/agents/agents-features/agents-features-debugger/src/commonMain/kotlin/ai/koog/agents/features/debugger/feature/Debugger.kt index 05219a8a11..333553a3a4 100644 --- a/agents/agents-features/agents-features-debugger/src/commonMain/kotlin/ai/koog/agents/features/debugger/feature/Debugger.kt +++ b/agents/agents-features/agents-features-debugger/src/commonMain/kotlin/ai/koog/agents/features/debugger/feature/Debugger.kt @@ -16,9 +16,9 @@ import ai.koog.agents.core.feature.model.events.LLMCallStartingEvent import ai.koog.agents.core.feature.model.events.NodeExecutionCompletedEvent import ai.koog.agents.core.feature.model.events.NodeExecutionStartingEvent import ai.koog.agents.core.feature.model.events.StrategyCompletedEvent -import ai.koog.agents.core.feature.model.events.ToolExecutionCompletedEvent -import ai.koog.agents.core.feature.model.events.ToolExecutionFailedEvent -import ai.koog.agents.core.feature.model.events.ToolExecutionStartingEvent +import ai.koog.agents.core.feature.model.events.ToolCallCompletedEvent +import ai.koog.agents.core.feature.model.events.ToolCallFailedEvent +import ai.koog.agents.core.feature.model.events.ToolCallStartingEvent import ai.koog.agents.core.feature.model.events.ToolValidationFailedEvent import ai.koog.agents.core.feature.model.events.startNodeToGraph import ai.koog.agents.core.feature.model.toAgentError @@ -209,11 +209,11 @@ public class Debugger { //region Intercept Tool Call Events - pipeline.interceptToolExecutionStarting(interceptContext) intercept@{ eventContext -> + pipeline.interceptToolCallStarting(interceptContext) intercept@{ eventContext -> @Suppress("UNCHECKED_CAST") val tool = eventContext.tool as Tool - val event = ToolExecutionStartingEvent( + val event = ToolCallStartingEvent( runId = eventContext.runId, toolCallId = eventContext.toolCallId, toolName = eventContext.tool.name, @@ -238,11 +238,11 @@ public class Debugger { writer.onMessage(event) } - pipeline.interceptToolExecutionFailed(interceptContext) intercept@{ eventContext -> + pipeline.interceptToolCallFailed(interceptContext) intercept@{ eventContext -> @Suppress("UNCHECKED_CAST") val tool = eventContext.tool as Tool - val event = ToolExecutionFailedEvent( + val event = ToolCallFailedEvent( runId = eventContext.runId, toolCallId = eventContext.toolCallId, toolName = tool.name, @@ -253,11 +253,11 @@ public class Debugger { writer.onMessage(event) } - pipeline.interceptToolExecutionCompleted(interceptContext) intercept@{ eventContext -> + pipeline.interceptToolCallCompleted(interceptContext) intercept@{ eventContext -> @Suppress("UNCHECKED_CAST") val tool = eventContext.tool as Tool - val event = ToolExecutionCompletedEvent( + val event = ToolCallCompletedEvent( runId = eventContext.runId, toolCallId = eventContext.toolCallId, toolName = eventContext.tool.name, diff --git a/agents/agents-features/agents-features-debugger/src/jvmTest/kotlin/ai/koog/agents/features/debugger/feature/DebuggerTest.kt b/agents/agents-features/agents-features-debugger/src/jvmTest/kotlin/ai/koog/agents/features/debugger/feature/DebuggerTest.kt index 4aea8eaa30..14d003ddf5 100644 --- a/agents/agents-features/agents-features-debugger/src/jvmTest/kotlin/ai/koog/agents/features/debugger/feature/DebuggerTest.kt +++ b/agents/agents-features/agents-features-debugger/src/jvmTest/kotlin/ai/koog/agents/features/debugger/feature/DebuggerTest.kt @@ -18,8 +18,8 @@ import ai.koog.agents.core.feature.model.events.StrategyCompletedEvent import ai.koog.agents.core.feature.model.events.StrategyEventGraph import ai.koog.agents.core.feature.model.events.StrategyEventGraphEdge import ai.koog.agents.core.feature.model.events.StrategyEventGraphNode -import ai.koog.agents.core.feature.model.events.ToolExecutionCompletedEvent -import ai.koog.agents.core.feature.model.events.ToolExecutionStartingEvent +import ai.koog.agents.core.feature.model.events.ToolCallCompletedEvent +import ai.koog.agents.core.feature.model.events.ToolCallStartingEvent import ai.koog.agents.core.feature.remote.client.FeatureMessageRemoteClient import ai.koog.agents.core.feature.remote.client.config.DefaultClientConnectionConfig import ai.koog.agents.core.feature.remote.server.config.DefaultServerConnectionConfig @@ -288,14 +288,14 @@ class DebuggerTest { input = toolCallMessage(dummyTool.name, content = """{"dummy":"test"}""").toString(), timestamp = testClock.now().toEpochMilliseconds() ), - ToolExecutionStartingEvent( + ToolCallStartingEvent( runId = clientEventsCollector.runId, toolCallId = "0", toolName = dummyTool.name, toolArgs = dummyTool.encodeArgs(DummyTool.Args("test")), timestamp = testClock.now().toEpochMilliseconds() ), - ToolExecutionCompletedEvent( + ToolCallCompletedEvent( runId = clientEventsCollector.runId, toolCallId = "0", toolName = dummyTool.name, diff --git a/agents/agents-features/agents-features-event-handler/Module.md b/agents/agents-features/agents-features-event-handler/Module.md index f8f5389fec..3b7561e9b9 100644 --- a/agents/agents-features/agents-features-event-handler/Module.md +++ b/agents/agents-features/agents-features-event-handler/Module.md @@ -56,7 +56,7 @@ val testAgent = AIAgent( var agentFinished = false handleEvents { - onToolExecutionStarting { eventContext -> + onToolCallStarting { eventContext -> toolCalled = true println("[DEBUG_LOG] Tool called: ${eventContext.tool.name}") } @@ -95,15 +95,15 @@ val agent = AIAgent( } // Monitor tool usage - onToolExecutionStarting { eventContext -> + onToolCallStarting { eventContext -> println("Tool called: ${eventContext.tool.name} with args: ${eventContext.toolArgs}") } - onToolExecutionCompleted { eventContext -> + onToolCallCompleted { eventContext -> println("Tool result: ${eventContext.result}") } - onToolExecutionFailed { eventContext -> + onToolCallFailed { eventContext -> println("Tool failed: ${eventContext.throwable.message}") } diff --git a/agents/agents-features/agents-features-event-handler/src/commonMain/kotlin/ai/koog/agents/features/eventHandler/feature/EventHandler.kt b/agents/agents-features/agents-features-event-handler/src/commonMain/kotlin/ai/koog/agents/features/eventHandler/feature/EventHandler.kt index 2706f13159..c25cb528d9 100644 --- a/agents/agents-features/agents-features-event-handler/src/commonMain/kotlin/ai/koog/agents/features/eventHandler/feature/EventHandler.kt +++ b/agents/agents-features/agents-features-event-handler/src/commonMain/kotlin/ai/koog/agents/features/eventHandler/feature/EventHandler.kt @@ -17,9 +17,9 @@ import ai.koog.agents.core.feature.handler.node.NodeExecutionStartingContext import ai.koog.agents.core.feature.handler.streaming.LLMStreamingCompletedContext import ai.koog.agents.core.feature.handler.streaming.LLMStreamingFrameReceivedContext import ai.koog.agents.core.feature.handler.streaming.LLMStreamingStartingContext -import ai.koog.agents.core.feature.handler.tool.ToolExecutionCompletedContext -import ai.koog.agents.core.feature.handler.tool.ToolExecutionFailedContext -import ai.koog.agents.core.feature.handler.tool.ToolExecutionStartingContext +import ai.koog.agents.core.feature.handler.tool.ToolCallCompletedContext +import ai.koog.agents.core.feature.handler.tool.ToolCallFailedContext +import ai.koog.agents.core.feature.handler.tool.ToolCallStartingContext import ai.koog.agents.core.feature.handler.tool.ToolValidationFailedContext import io.github.oshai.kotlinlogging.KotlinLogging @@ -33,7 +33,7 @@ import io.github.oshai.kotlinlogging.KotlinLogging * Example usage: * ``` * handleEvents { - * onToolExecutionStarting { eventContext -> + * onToolCallStarting { eventContext -> * println("Tool called: ${eventContext.tool.name} with args ${eventContext.toolArgs}") * } * @@ -58,7 +58,7 @@ public class EventHandler { * Example usage: * ``` * handleEvents { - * onToolExecutionStarting { eventContext -> + * onToolCallStarting { eventContext -> * println("Tool called: ${eventContext.tool.name} with args: ${eventContext.toolArgs}") * } * @@ -145,8 +145,8 @@ public class EventHandler { config.invokeOnLLMCallCompleted(eventContext) } - pipeline.interceptToolExecutionStarting(interceptContext) intercept@{ eventContext: ToolExecutionStartingContext -> - config.invokeOnToolExecutionStarting(eventContext) + pipeline.interceptToolCallStarting(interceptContext) intercept@{ eventContext: ToolCallStartingContext -> + config.invokeOnToolCallStarting(eventContext) } pipeline.interceptToolValidationFailed( @@ -155,16 +155,16 @@ public class EventHandler { config.invokeOnToolValidationFailed(eventContext) } - pipeline.interceptToolExecutionFailed(interceptContext) intercept@{ eventContext: ToolExecutionFailedContext -> - config.invokeOnToolExecutionFailed(eventContext) + pipeline.interceptToolCallFailed(interceptContext) intercept@{ eventContext: ToolCallFailedContext -> + config.invokeOnToolCallFailed(eventContext) } - pipeline.interceptToolExecutionCompleted(interceptContext) intercept@{ eventContext: ToolExecutionCompletedContext -> - config.invokeOnToolExecutionCompleted(eventContext) + pipeline.interceptToolCallCompleted(interceptContext) intercept@{ eventContext: ToolCallCompletedContext -> + config.invokeOnToolCallCompleted(eventContext) } pipeline.interceptLLMStreamingStarting(interceptContext) intercept@{ eventContext: LLMStreamingStartingContext -> - config.invokeOnLLMStreammingStarting(eventContext) + config.invokeOnLLMStreamingStarting(eventContext) } pipeline.interceptLLMStreamingFrameReceived(interceptContext) intercept@{ eventContext: LLMStreamingFrameReceivedContext -> @@ -205,7 +205,7 @@ public class EventHandler { * ``` * handleEvents { * // Log when tools are called - * onToolExecutionStarting { eventContext -> + * onToolCallStarting { eventContext -> * println("Tool called: ${eventContext.tool.name} with args: ${eventContext.toolArgs}") * } * diff --git a/agents/agents-features/agents-features-event-handler/src/commonMain/kotlin/ai/koog/agents/features/eventHandler/feature/EventHandlerConfig.kt b/agents/agents-features/agents-features-event-handler/src/commonMain/kotlin/ai/koog/agents/features/eventHandler/feature/EventHandlerConfig.kt index 1102b72b82..d6274ea509 100644 --- a/agents/agents-features/agents-features-event-handler/src/commonMain/kotlin/ai/koog/agents/features/eventHandler/feature/EventHandlerConfig.kt +++ b/agents/agents-features/agents-features-event-handler/src/commonMain/kotlin/ai/koog/agents/features/eventHandler/feature/EventHandlerConfig.kt @@ -31,9 +31,9 @@ import ai.koog.agents.core.feature.handler.streaming.LLMStreamingCompletedContex import ai.koog.agents.core.feature.handler.streaming.LLMStreamingFailedContext import ai.koog.agents.core.feature.handler.streaming.LLMStreamingFrameReceivedContext import ai.koog.agents.core.feature.handler.streaming.LLMStreamingStartingContext -import ai.koog.agents.core.feature.handler.tool.ToolExecutionCompletedContext -import ai.koog.agents.core.feature.handler.tool.ToolExecutionFailedContext -import ai.koog.agents.core.feature.handler.tool.ToolExecutionStartingContext +import ai.koog.agents.core.feature.handler.tool.ToolCallCompletedContext +import ai.koog.agents.core.feature.handler.tool.ToolCallFailedContext +import ai.koog.agents.core.feature.handler.tool.ToolCallStartingContext import ai.koog.agents.core.feature.handler.tool.ToolValidationFailedContext /** @@ -49,7 +49,7 @@ import ai.koog.agents.core.feature.handler.tool.ToolValidationFailedContext * Example usage: * ``` * handleEvents { - * onToolExecutionStarting { eventContext -> + * onToolCallStarting { eventContext -> * println("Tool called: ${eventContext.tool.name} with args ${eventContext.toolArgs}") * } * @@ -101,13 +101,13 @@ public class EventHandlerConfig : FeatureConfig() { //region Private Tool Call Handlers - private var _onToolExecutionStarting: suspend (eventHandler: ToolExecutionStartingContext) -> Unit = { _ -> } + private var _onToolCallStarting: suspend (eventHandler: ToolCallStartingContext) -> Unit = { _ -> } private var _onToolValidationFailed: suspend (eventHandler: ToolValidationFailedContext) -> Unit = { _ -> } - private var _onToolExecutionFailed: suspend (eventHandler: ToolExecutionFailedContext) -> Unit = { _ -> } + private var _onToolCallFailed: suspend (eventHandler: ToolCallFailedContext) -> Unit = { _ -> } - private var _onToolExecutionCompleted: suspend (eventHandler: ToolExecutionCompletedContext) -> Unit = { _ -> } + private var _onToolCallCompleted: suspend (eventHandler: ToolCallCompletedContext) -> Unit = { _ -> } //endregion Private Tool Call Handlers @@ -266,9 +266,9 @@ public class EventHandlerConfig : FeatureConfig() { /** * Append handler called when a tool is about to be called. */ - public fun onToolExecutionStarting(handler: suspend (eventContext: ToolExecutionStartingContext) -> Unit) { - val originalHandler = this._onToolExecutionStarting - this._onToolExecutionStarting = { eventContext -> + public fun onToolCallStarting(handler: suspend (eventContext: ToolCallStartingContext) -> Unit) { + val originalHandler = this._onToolCallStarting + this._onToolCallStarting = { eventContext -> originalHandler(eventContext) handler.invoke(eventContext) } @@ -288,9 +288,9 @@ public class EventHandlerConfig : FeatureConfig() { /** * Append handler called when a tool call fails with an exception. */ - public fun onToolExecutionFailed(handler: suspend (eventContext: ToolExecutionFailedContext) -> Unit) { - val originalHandler = this._onToolExecutionFailed - this._onToolExecutionFailed = { eventContext -> + public fun onToolCallFailed(handler: suspend (eventContext: ToolCallFailedContext) -> Unit) { + val originalHandler = this._onToolCallFailed + this._onToolCallFailed = { eventContext -> originalHandler(eventContext) handler.invoke(eventContext) } @@ -299,9 +299,9 @@ public class EventHandlerConfig : FeatureConfig() { /** * Append handler called when a tool call completes successfully. */ - public fun onToolExecutionCompleted(handler: suspend (eventContext: ToolExecutionCompletedContext) -> Unit) { - val originalHandler = this._onToolExecutionCompleted - this._onToolExecutionCompleted = { eventContext -> + public fun onToolCallCompleted(handler: suspend (eventContext: ToolCallCompletedContext) -> Unit) { + val originalHandler = this._onToolCallCompleted + this._onToolCallCompleted = { eventContext -> originalHandler(eventContext) handler.invoke(eventContext) } @@ -542,11 +542,11 @@ public class EventHandlerConfig : FeatureConfig() { * Append handler called when a tool is about to be called. */ @Deprecated( - message = "Use onToolExecutionStarting instead", - ReplaceWith("onToolExecutionStarting(handler)", "ai.koog.agents.core.feature.handler.ToolExecutionStartingContext") + message = "Use onToolCallStarting instead", + ReplaceWith("onToolCallStarting(handler)", "ai.koog.agents.core.feature.handler.ToolCallStartingContext") ) public fun onToolCall(handler: suspend (eventContext: ToolCallContext) -> Unit) { - onToolExecutionStarting(handler) + onToolCallStarting(handler) } /** @@ -564,22 +564,22 @@ public class EventHandlerConfig : FeatureConfig() { * Append handler called when a tool call fails with an exception. */ @Deprecated( - message = "Use onToolExecutionFailed instead", - ReplaceWith("onToolExecutionFailed(handler)", "ai.koog.agents.core.feature.handler.ToolExecutionFailedContext") + message = "Use onToolCallFailed instead", + ReplaceWith("onToolCallFailed(handler)", "ai.koog.agents.core.feature.handler.ToolCallFailedContext") ) public fun onToolCallFailure(handler: suspend (eventContext: ToolCallFailureContext) -> Unit) { - onToolExecutionFailed(handler) + onToolCallFailed(handler) } /** * Append handler called when a tool call completes successfully. */ @Deprecated( - message = "Use onToolExecutionCompleted instead", - ReplaceWith("onToolExecutionCompleted(handler)", "ai.koog.agents.core.feature.handler.ToolExecutionCompletedContext") + message = "Use onToolCallCompleted instead", + ReplaceWith("onToolCallCompleted(handler)", "ai.koog.agents.core.feature.handler.ToolCallCompletedContext") ) public fun onToolCallResult(handler: suspend (eventContext: ToolCallResultContext) -> Unit) { - onToolExecutionCompleted(handler) + onToolCallCompleted(handler) } //endregion Deprecated Handlers @@ -682,8 +682,8 @@ public class EventHandlerConfig : FeatureConfig() { /** * Invoke handlers for the tool call event. */ - internal suspend fun invokeOnToolExecutionStarting(eventContext: ToolExecutionStartingContext) { - _onToolExecutionStarting.invoke(eventContext) + internal suspend fun invokeOnToolCallStarting(eventContext: ToolCallStartingContext) { + _onToolCallStarting.invoke(eventContext) } /** @@ -696,15 +696,15 @@ public class EventHandlerConfig : FeatureConfig() { /** * Invoke handlers for a tool call failure with an exception event. */ - internal suspend fun invokeOnToolExecutionFailed(eventContext: ToolExecutionFailedContext) { - _onToolExecutionFailed.invoke(eventContext) + internal suspend fun invokeOnToolCallFailed(eventContext: ToolCallFailedContext) { + _onToolCallFailed.invoke(eventContext) } /** * Invoke handlers for an event when a tool call is completed successfully. */ - internal suspend fun invokeOnToolExecutionCompleted(eventContext: ToolExecutionCompletedContext) { - _onToolExecutionCompleted.invoke(eventContext) + internal suspend fun invokeOnToolCallCompleted(eventContext: ToolCallCompletedContext) { + _onToolCallCompleted.invoke(eventContext) } //endregion Invoke Tool Call Handlers @@ -716,7 +716,7 @@ public class EventHandlerConfig : FeatureConfig() { * * @param eventContext The context containing information about the streaming session about to begin */ - internal suspend fun invokeOnLLMStreammingStarting(eventContext: LLMStreamingStartingContext) { + internal suspend fun invokeOnLLMStreamingStarting(eventContext: LLMStreamingStartingContext) { _onLLMStreamingStarting.invoke(eventContext) } diff --git a/agents/agents-features/agents-features-event-handler/src/jvmTest/kotlin/ai/koog/agents/features/eventHandler/feature/EventHandlerTest.kt b/agents/agents-features/agents-features-event-handler/src/jvmTest/kotlin/ai/koog/agents/features/eventHandler/feature/EventHandlerTest.kt index 894532b3d7..3fd550f3ae 100644 --- a/agents/agents-features/agents-features-event-handler/src/jvmTest/kotlin/ai/koog/agents/features/eventHandler/feature/EventHandlerTest.kt +++ b/agents/agents-features/agents-features-event-handler/src/jvmTest/kotlin/ai/koog/agents/features/eventHandler/feature/EventHandlerTest.kt @@ -49,15 +49,15 @@ class EventHandlerTest { val runId = eventsCollector.runId val expectedEvents = listOf( - "OnBeforeAgentStarted (agent id: test-agent-id, run id: $runId)", - "OnStrategyStarted (run id: $runId, strategy: $strategyName)", - "OnBeforeNode (run id: $runId, node: __start__, input: $agentInput)", - "OnAfterNode (run id: $runId, node: __start__, input: $agentInput, output: $agentInput)", - "OnBeforeNode (run id: $runId, node: __finish__, input: $agentResult)", - "OnAfterNode (run id: $runId, node: __finish__, input: $agentResult, output: $agentResult)", - "OnStrategyFinished (run id: $runId, strategy: $strategyName, result: $agentResult)", - "OnAgentFinished (agent id: test-agent-id, run id: $runId, result: $agentResult)", - "OnAgentBeforeClose (agent id: test-agent-id)", + "OnAgentStarting (agent id: test-agent-id, run id: $runId)", + "OnStrategyStarting (run id: $runId, strategy: $strategyName)", + "OnNodeExecutionStarting (run id: $runId, node: __start__, input: $agentInput)", + "OnNodeExecutionCompleted (run id: $runId, node: __start__, input: $agentInput, output: $agentInput)", + "OnNodeExecutionStarting (run id: $runId, node: __finish__, input: $agentResult)", + "OnNodeExecutionCompleted (run id: $runId, node: __finish__, input: $agentResult, output: $agentResult)", + "OnStrategyCompleted (run id: $runId, strategy: $strategyName, result: $agentResult)", + "OnAgentCompleted (agent id: test-agent-id, run id: $runId, result: $agentResult)", + "OnAgentClosing (agent id: test-agent-id)", ) assertEquals(expectedEvents.size, eventsCollector.size) @@ -93,19 +93,19 @@ class EventHandlerTest { val runId = eventsCollector.runId val expectedEvents = listOf( - "OnBeforeAgentStarted (agent id: test-agent-id, run id: $runId)", - "OnStrategyStarted (run id: $runId, strategy: $strategyName)", - "OnBeforeNode (run id: $runId, node: __start__, input: $agentInput)", - "OnAfterNode (run id: $runId, node: __start__, input: $agentInput, output: $agentInput)", - "OnBeforeNode (run id: $runId, node: test LLM call, input: Test LLM call prompt)", - "OnBeforeLLMCall (run id: $runId, prompt: id: test, messages: [{role: System, message: Test system message, role: User, message: Test user message, role: Assistant, message: Test assistant response, role: User, message: Test LLM call prompt}], temperature: null, tools: [])", - "OnAfterLLMCall (run id: $runId, prompt: id: test, messages: [{role: System, message: Test system message, role: User, message: Test user message, role: Assistant, message: Test assistant response, role: User, message: Test LLM call prompt}], temperature: null, model: openai:gpt-4o, tools: [], responses: [role: Assistant, message: Default test response])", - "OnAfterNode (run id: $runId, node: test LLM call, input: Test LLM call prompt, output: Assistant(content=Default test response, metaInfo=ResponseMetaInfo(timestamp=$ts, totalTokensCount=null, inputTokensCount=null, outputTokensCount=null, additionalInfo={}), attachments=[], finishReason=null))", - "OnBeforeNode (run id: $runId, node: __finish__, input: $agentResult)", - "OnAfterNode (run id: $runId, node: __finish__, input: $agentResult, output: $agentResult)", - "OnStrategyFinished (run id: $runId, strategy: $strategyName, result: $agentResult)", - "OnAgentFinished (agent id: test-agent-id, run id: $runId, result: $agentResult)", - "OnAgentBeforeClose (agent id: $agentId)", + "OnAgentStarting (agent id: test-agent-id, run id: $runId)", + "OnStrategyStarting (run id: $runId, strategy: $strategyName)", + "OnNodeExecutionStarting (run id: $runId, node: __start__, input: $agentInput)", + "OnNodeExecutionCompleted (run id: $runId, node: __start__, input: $agentInput, output: $agentInput)", + "OnNodeExecutionStarting (run id: $runId, node: test LLM call, input: Test LLM call prompt)", + "OnLLMCallStarting (run id: $runId, prompt: id: test, messages: [{role: System, message: Test system message, role: User, message: Test user message, role: Assistant, message: Test assistant response, role: User, message: Test LLM call prompt}], temperature: null, tools: [])", + "OnLLMCallCompleted (run id: $runId, prompt: id: test, messages: [{role: System, message: Test system message, role: User, message: Test user message, role: Assistant, message: Test assistant response, role: User, message: Test LLM call prompt}], temperature: null, model: openai:gpt-4o, tools: [], responses: [role: Assistant, message: Default test response])", + "OnNodeExecutionCompleted (run id: $runId, node: test LLM call, input: Test LLM call prompt, output: Assistant(content=Default test response, metaInfo=ResponseMetaInfo(timestamp=$ts, totalTokensCount=null, inputTokensCount=null, outputTokensCount=null, additionalInfo={}), attachments=[], finishReason=null))", + "OnNodeExecutionStarting (run id: $runId, node: __finish__, input: $agentResult)", + "OnNodeExecutionCompleted (run id: $runId, node: __finish__, input: $agentResult, output: $agentResult)", + "OnStrategyCompleted (run id: $runId, strategy: $strategyName, result: $agentResult)", + "OnAgentCompleted (agent id: test-agent-id, run id: $runId, result: $agentResult)", + "OnAgentClosing (agent id: $agentId)", ) assertEquals(expectedEvents.size, eventsCollector.size) @@ -162,27 +162,27 @@ class EventHandlerTest { val runId = eventsCollector.runId val expectedEvents = listOf( - "OnBeforeAgentStarted (agent id: $agentId, run id: $runId)", - "OnStrategyStarted (run id: $runId, strategy: $strategyName)", - "OnBeforeNode (run id: $runId, node: __start__, input: $userPrompt)", - "OnAfterNode (run id: $runId, node: __start__, input: $userPrompt, output: $userPrompt)", - "OnBeforeNode (run id: $runId, node: test-llm-call, input: $userPrompt)", - "OnBeforeLLMCall (run id: $runId, prompt: id: test, messages: [{role: System, message: Test system message, role: User, message: Test user message, role: Assistant, message: Test assistant response, role: User, message: $userPrompt}], temperature: null, tools: [dummy])", - "OnAfterLLMCall (run id: $runId, prompt: id: test, messages: [{role: System, message: Test system message, role: User, message: Test user message, role: Assistant, message: Test assistant response, role: User, message: $userPrompt}], temperature: null, model: openai:gpt-4o, tools: [${dummyTool.name}], responses: [role: Tool, message: {\"dummy\":\"test\"}])", - "OnAfterNode (run id: $runId, node: test-llm-call, input: $userPrompt, output: Call(id=null, tool=${dummyTool.name}, content={\"dummy\":\"test\"}, metaInfo=ResponseMetaInfo(timestamp=2023-01-01T00:00:00Z, totalTokensCount=null, inputTokensCount=null, outputTokensCount=null, additionalInfo={})))", - "OnBeforeNode (run id: $runId, node: test-tool-call, input: Call(id=null, tool=${dummyTool.name}, content={\"dummy\":\"test\"}, metaInfo=ResponseMetaInfo(timestamp=2023-01-01T00:00:00Z, totalTokensCount=null, inputTokensCount=null, outputTokensCount=null, additionalInfo={})))", - "OnToolCall (run id: $runId, tool: ${dummyTool.name}, args: Args(dummy=test))", - "OnToolCallResult (run id: $runId, tool: ${dummyTool.name}, args: Args(dummy=test), result: ${dummyTool.result})", - "OnAfterNode (run id: $runId, node: test-tool-call, input: Call(id=null, tool=${dummyTool.name}, content={\"dummy\":\"test\"}, metaInfo=ResponseMetaInfo(timestamp=2023-01-01T00:00:00Z, totalTokensCount=null, inputTokensCount=null, outputTokensCount=null, additionalInfo={})), output: ReceivedToolResult(id=null, tool=${dummyTool.name}, content=${dummyTool.result}, result=${dummyTool.result}))", - "OnBeforeNode (run id: $runId, node: test-node-llm-send-tool-result, input: ReceivedToolResult(id=null, tool=${dummyTool.name}, content=${dummyTool.result}, result=${dummyTool.result}))", - "OnBeforeLLMCall (run id: $runId, prompt: id: test, messages: [{role: System, message: Test system message, role: User, message: Test user message, role: Assistant, message: Test assistant response, role: User, message: $userPrompt, role: Tool, message: {\"dummy\":\"test\"}, role: Tool, message: ${dummyTool.result}}], temperature: null, tools: [${dummyTool.name}])", - "OnAfterLLMCall (run id: $runId, prompt: id: test, messages: [{role: System, message: Test system message, role: User, message: Test user message, role: Assistant, message: Test assistant response, role: User, message: $userPrompt, role: Tool, message: {\"dummy\":\"test\"}, role: Tool, message: ${dummyTool.result}}], temperature: null, model: openai:gpt-4o, tools: [${dummyTool.name}], responses: [role: Assistant, message: Return test result])", - "OnAfterNode (run id: $runId, node: test-node-llm-send-tool-result, input: ReceivedToolResult(id=null, tool=${dummyTool.name}, content=${dummyTool.result}, result=${dummyTool.result}), output: Assistant(content=Return test result, metaInfo=ResponseMetaInfo(timestamp=2023-01-01T00:00:00Z, totalTokensCount=null, inputTokensCount=null, outputTokensCount=null, additionalInfo={}), attachments=[], finishReason=null))", - "OnBeforeNode (run id: $runId, node: __finish__, input: $mockResponse)", - "OnAfterNode (run id: $runId, node: __finish__, input: $mockResponse, output: $mockResponse)", - "OnStrategyFinished (run id: $runId, strategy: $strategyName, result: $mockResponse)", - "OnAgentFinished (agent id: $agentId, run id: $runId, result: $mockResponse)", - "OnAgentBeforeClose (agent id: $agentId)", + "OnAgentStarting (agent id: $agentId, run id: $runId)", + "OnStrategyStarting (run id: $runId, strategy: $strategyName)", + "OnNodeExecutionStarting (run id: $runId, node: __start__, input: $userPrompt)", + "OnNodeExecutionCompleted (run id: $runId, node: __start__, input: $userPrompt, output: $userPrompt)", + "OnNodeExecutionStarting (run id: $runId, node: test-llm-call, input: $userPrompt)", + "OnLLMCallStarting (run id: $runId, prompt: id: test, messages: [{role: System, message: Test system message, role: User, message: Test user message, role: Assistant, message: Test assistant response, role: User, message: $userPrompt}], temperature: null, tools: [dummy])", + "OnLLMCallCompleted (run id: $runId, prompt: id: test, messages: [{role: System, message: Test system message, role: User, message: Test user message, role: Assistant, message: Test assistant response, role: User, message: $userPrompt}], temperature: null, model: openai:gpt-4o, tools: [${dummyTool.name}], responses: [role: Tool, message: {\"dummy\":\"test\"}])", + "OnNodeExecutionCompleted (run id: $runId, node: test-llm-call, input: $userPrompt, output: Call(id=null, tool=${dummyTool.name}, content={\"dummy\":\"test\"}, metaInfo=ResponseMetaInfo(timestamp=2023-01-01T00:00:00Z, totalTokensCount=null, inputTokensCount=null, outputTokensCount=null, additionalInfo={})))", + "OnNodeExecutionStarting (run id: $runId, node: test-tool-call, input: Call(id=null, tool=${dummyTool.name}, content={\"dummy\":\"test\"}, metaInfo=ResponseMetaInfo(timestamp=2023-01-01T00:00:00Z, totalTokensCount=null, inputTokensCount=null, outputTokensCount=null, additionalInfo={})))", + "OnToolCallStarting (run id: $runId, tool: ${dummyTool.name}, args: Args(dummy=test))", + "OnToolCallCompleted (run id: $runId, tool: ${dummyTool.name}, args: Args(dummy=test), result: ${dummyTool.result})", + "OnNodeExecutionCompleted (run id: $runId, node: test-tool-call, input: Call(id=null, tool=${dummyTool.name}, content={\"dummy\":\"test\"}, metaInfo=ResponseMetaInfo(timestamp=2023-01-01T00:00:00Z, totalTokensCount=null, inputTokensCount=null, outputTokensCount=null, additionalInfo={})), output: ReceivedToolResult(id=null, tool=${dummyTool.name}, content=${dummyTool.result}, result=${dummyTool.result}))", + "OnNodeExecutionStarting (run id: $runId, node: test-node-llm-send-tool-result, input: ReceivedToolResult(id=null, tool=${dummyTool.name}, content=${dummyTool.result}, result=${dummyTool.result}))", + "OnLLMCallStarting (run id: $runId, prompt: id: test, messages: [{role: System, message: Test system message, role: User, message: Test user message, role: Assistant, message: Test assistant response, role: User, message: $userPrompt, role: Tool, message: {\"dummy\":\"test\"}, role: Tool, message: ${dummyTool.result}}], temperature: null, tools: [${dummyTool.name}])", + "OnLLMCallCompleted (run id: $runId, prompt: id: test, messages: [{role: System, message: Test system message, role: User, message: Test user message, role: Assistant, message: Test assistant response, role: User, message: $userPrompt, role: Tool, message: {\"dummy\":\"test\"}, role: Tool, message: ${dummyTool.result}}], temperature: null, model: openai:gpt-4o, tools: [${dummyTool.name}], responses: [role: Assistant, message: Return test result])", + "OnNodeExecutionCompleted (run id: $runId, node: test-node-llm-send-tool-result, input: ReceivedToolResult(id=null, tool=${dummyTool.name}, content=${dummyTool.result}, result=${dummyTool.result}), output: Assistant(content=Return test result, metaInfo=ResponseMetaInfo(timestamp=2023-01-01T00:00:00Z, totalTokensCount=null, inputTokensCount=null, outputTokensCount=null, additionalInfo={}), attachments=[], finishReason=null))", + "OnNodeExecutionStarting (run id: $runId, node: __finish__, input: $mockResponse)", + "OnNodeExecutionCompleted (run id: $runId, node: __finish__, input: $mockResponse, output: $mockResponse)", + "OnStrategyCompleted (run id: $runId, strategy: $strategyName, result: $mockResponse)", + "OnAgentCompleted (agent id: $agentId, run id: $runId, result: $mockResponse)", + "OnAgentClosing (agent id: $agentId)", ) assertEquals(expectedEvents.size, eventsCollector.size) @@ -219,23 +219,23 @@ class EventHandlerTest { val runId = eventsCollector.runId val expectedEvents = listOf( - "OnBeforeAgentStarted (agent id: test-agent-id, run id: $runId)", - "OnStrategyStarted (run id: $runId, strategy: $strategyName)", - "OnBeforeNode (run id: $runId, node: __start__, input: $agentInput)", - "OnAfterNode (run id: $runId, node: __start__, input: $agentInput, output: $agentInput)", - "OnBeforeNode (run id: $runId, node: test LLM call, input: Test LLM call prompt)", - "OnBeforeLLMCall (run id: $runId, prompt: id: test, messages: [{role: System, message: Test system message, role: User, message: Test user message, role: Assistant, message: Test assistant response, role: User, message: Test LLM call prompt}], temperature: null, tools: [dummy])", - "OnAfterLLMCall (run id: $runId, prompt: id: test, messages: [{role: System, message: Test system message, role: User, message: Test user message, role: Assistant, message: Test assistant response, role: User, message: Test LLM call prompt}], temperature: null, model: openai:gpt-4o, tools: [dummy], responses: [role: Assistant, message: Default test response])", - "OnAfterNode (run id: $runId, node: test LLM call, input: Test LLM call prompt, output: Assistant(content=Default test response, metaInfo=ResponseMetaInfo(timestamp=2023-01-01T00:00:00Z, totalTokensCount=null, inputTokensCount=null, outputTokensCount=null, additionalInfo={}), attachments=[], finishReason=null))", - "OnBeforeNode (run id: $runId, node: test LLM call with tools, input: Test LLM call with tools prompt)", - "OnBeforeLLMCall (run id: $runId, prompt: id: test, messages: [{role: System, message: Test system message, role: User, message: Test user message, role: Assistant, message: Test assistant response, role: User, message: Test LLM call prompt, role: Assistant, message: Default test response, role: User, message: Test LLM call with tools prompt}], temperature: null, tools: [dummy])", - "OnAfterLLMCall (run id: $runId, prompt: id: test, messages: [{role: System, message: Test system message, role: User, message: Test user message, role: Assistant, message: Test assistant response, role: User, message: Test LLM call prompt, role: Assistant, message: Default test response, role: User, message: Test LLM call with tools prompt}], temperature: null, model: openai:gpt-4o, tools: [dummy], responses: [role: Assistant, message: Default test response])", - "OnAfterNode (run id: $runId, node: test LLM call with tools, input: Test LLM call with tools prompt, output: Assistant(content=Default test response, metaInfo=ResponseMetaInfo(timestamp=2023-01-01T00:00:00Z, totalTokensCount=null, inputTokensCount=null, outputTokensCount=null, additionalInfo={}), attachments=[], finishReason=null))", - "OnBeforeNode (run id: $runId, node: __finish__, input: $agentResult)", - "OnAfterNode (run id: $runId, node: __finish__, input: $agentResult, output: $agentResult)", - "OnStrategyFinished (run id: $runId, strategy: $strategyName, result: $agentResult)", - "OnAgentFinished (agent id: test-agent-id, run id: $runId, result: $agentResult)", - "OnAgentBeforeClose (agent id: test-agent-id)", + "OnAgentStarting (agent id: test-agent-id, run id: $runId)", + "OnStrategyStarting (run id: $runId, strategy: $strategyName)", + "OnNodeExecutionStarting (run id: $runId, node: __start__, input: $agentInput)", + "OnNodeExecutionCompleted (run id: $runId, node: __start__, input: $agentInput, output: $agentInput)", + "OnNodeExecutionStarting (run id: $runId, node: test LLM call, input: Test LLM call prompt)", + "OnLLMCallStarting (run id: $runId, prompt: id: test, messages: [{role: System, message: Test system message, role: User, message: Test user message, role: Assistant, message: Test assistant response, role: User, message: Test LLM call prompt}], temperature: null, tools: [dummy])", + "OnLLMCallCompleted (run id: $runId, prompt: id: test, messages: [{role: System, message: Test system message, role: User, message: Test user message, role: Assistant, message: Test assistant response, role: User, message: Test LLM call prompt}], temperature: null, model: openai:gpt-4o, tools: [dummy], responses: [role: Assistant, message: Default test response])", + "OnNodeExecutionCompleted (run id: $runId, node: test LLM call, input: Test LLM call prompt, output: Assistant(content=Default test response, metaInfo=ResponseMetaInfo(timestamp=2023-01-01T00:00:00Z, totalTokensCount=null, inputTokensCount=null, outputTokensCount=null, additionalInfo={}), attachments=[], finishReason=null))", + "OnNodeExecutionStarting (run id: $runId, node: test LLM call with tools, input: Test LLM call with tools prompt)", + "OnLLMCallStarting (run id: $runId, prompt: id: test, messages: [{role: System, message: Test system message, role: User, message: Test user message, role: Assistant, message: Test assistant response, role: User, message: Test LLM call prompt, role: Assistant, message: Default test response, role: User, message: Test LLM call with tools prompt}], temperature: null, tools: [dummy])", + "OnLLMCallCompleted (run id: $runId, prompt: id: test, messages: [{role: System, message: Test system message, role: User, message: Test user message, role: Assistant, message: Test assistant response, role: User, message: Test LLM call prompt, role: Assistant, message: Default test response, role: User, message: Test LLM call with tools prompt}], temperature: null, model: openai:gpt-4o, tools: [dummy], responses: [role: Assistant, message: Default test response])", + "OnNodeExecutionCompleted (run id: $runId, node: test LLM call with tools, input: Test LLM call with tools prompt, output: Assistant(content=Default test response, metaInfo=ResponseMetaInfo(timestamp=2023-01-01T00:00:00Z, totalTokensCount=null, inputTokensCount=null, outputTokensCount=null, additionalInfo={}), attachments=[], finishReason=null))", + "OnNodeExecutionStarting (run id: $runId, node: __finish__, input: $agentResult)", + "OnNodeExecutionCompleted (run id: $runId, node: __finish__, input: $agentResult, output: $agentResult)", + "OnStrategyCompleted (run id: $runId, strategy: $strategyName, result: $agentResult)", + "OnAgentCompleted (agent id: test-agent-id, run id: $runId, result: $agentResult)", + "OnAgentClosing (agent id: test-agent-id)", ) assertEquals(expectedEvents.size, eventsCollector.size) @@ -277,14 +277,14 @@ class EventHandlerTest { val runId = eventsCollector.runId val expectedEvents = listOf( - "OnBeforeAgentStarted (agent id: $agentId, run id: $runId)", - "OnStrategyStarted (run id: $runId, strategy: $strategyName)", - "OnBeforeNode (run id: $runId, node: __start__, input: $agentInput)", - "OnAfterNode (run id: $runId, node: __start__, input: $agentInput, output: $agentInput)", - "OnBeforeNode (run id: $runId, node: $errorNodeName, input: $agentInput)", - "OnNodeExecutionError (run id: $runId, node: $errorNodeName, error: $testErrorMessage)", - "OnAgentRunError (agent id: $agentId, run id: $runId, error: $testErrorMessage)", - "OnAgentBeforeClose (agent id: $agentId)", + "OnAgentStarting (agent id: $agentId, run id: $runId)", + "OnStrategyStarting (run id: $runId, strategy: $strategyName)", + "OnNodeExecutionStarting (run id: $runId, node: __start__, input: $agentInput)", + "OnNodeExecutionCompleted (run id: $runId, node: __start__, input: $agentInput, output: $agentInput)", + "OnNodeExecutionStarting (run id: $runId, node: $errorNodeName, input: $agentInput)", + "OnNodeExecutionFailed (run id: $runId, node: $errorNodeName, error: $testErrorMessage)", + "OnAgentExecutionFailed (agent id: $agentId, run id: $runId, error: $testErrorMessage)", + "OnAgentClosing (agent id: $agentId)", ) assertEquals(expectedEvents.size, eventsCollector.size) @@ -308,22 +308,22 @@ class EventHandlerTest { strategy = strategy, installFeatures = { install(EventHandler) { - onBeforeAgentStarted { eventContext -> + onAgentStarting { eventContext -> runId = eventContext.runId collectedEvents.add( - "OnBeforeAgentStarted first (agent id: ${eventContext.agent.id})" + "OnAgentStarting first (agent id: ${eventContext.agent.id})" ) } - onBeforeAgentStarted { eventContext -> + onAgentStarting { eventContext -> collectedEvents.add( - "OnBeforeAgentStarted second (agent id: ${eventContext.agent.id})" + "OnAgentStarting second (agent id: ${eventContext.agent.id})" ) } - onAgentFinished { eventContext -> + onAgentCompleted { eventContext -> collectedEvents.add( - "OnAgentFinished (agent id: ${eventContext.agentId}, run id: ${eventContext.runId}, result: $agentResult)" + "OnAgentCompleted (agent id: ${eventContext.agentId}, run id: ${eventContext.runId}, result: $agentResult)" ) } } @@ -334,9 +334,9 @@ class EventHandlerTest { agent.run(agentInput) val expectedEvents = listOf( - "OnBeforeAgentStarted first (agent id: ${agent.id})", - "OnBeforeAgentStarted second (agent id: ${agent.id})", - "OnAgentFinished (agent id: ${agent.id}, run id: $runId, result: $agentResult)", + "OnAgentStarting first (agent id: ${agent.id})", + "OnAgentStarting second (agent id: ${agent.id})", + "OnAgentCompleted (agent id: ${agent.id}, run id: $runId, result: $agentResult)", ) assertEquals(expectedEvents.size, collectedEvents.size) diff --git a/agents/agents-features/agents-features-event-handler/src/jvmTest/kotlin/ai/koog/agents/features/eventHandler/feature/NodeLLMRequestStreamingAndSendResultsTest.kt b/agents/agents-features/agents-features-event-handler/src/jvmTest/kotlin/ai/koog/agents/features/eventHandler/feature/NodeLLMRequestStreamingAndSendResultsTest.kt index c4bc665d76..793888ae18 100644 --- a/agents/agents-features/agents-features-event-handler/src/jvmTest/kotlin/ai/koog/agents/features/eventHandler/feature/NodeLLMRequestStreamingAndSendResultsTest.kt +++ b/agents/agents-features/agents-features-event-handler/src/jvmTest/kotlin/ai/koog/agents/features/eventHandler/feature/NodeLLMRequestStreamingAndSendResultsTest.kt @@ -89,7 +89,7 @@ class NodeLLMRequestStreamingAndSendResultsTest { // Verify streaming events were captured val streamingEvents = eventsCollector.collectedEvents.filter { - it.contains("OnBeforeStream") || it.contains("OnStreamFrame") || it.contains("OnAfterStream") + it.contains("OnLLMStreamingStarting") || it.contains("OnLLMStreamingFrameReceived") || it.contains("OnLLMStreamingCompleted") } assertTrue(streamingEvents.isNotEmpty(), "Should have captured streaming events") } @@ -132,7 +132,7 @@ class NodeLLMRequestStreamingAndSendResultsTest { // Verify streaming events occurred val streamingEvents = eventsCollector.collectedEvents.filter { - it.contains("OnBeforeStream") || it.contains("OnAfterStream") + it.contains("OnLLMStreamingStarting") || it.contains("OnLLMStreamingCompleted") } assertTrue(streamingEvents.isNotEmpty(), "Should have streaming events") } diff --git a/agents/agents-features/agents-features-event-handler/src/jvmTest/kotlin/ai/koog/agents/features/eventHandler/feature/StreamingEventHandlerTest.kt b/agents/agents-features/agents-features-event-handler/src/jvmTest/kotlin/ai/koog/agents/features/eventHandler/feature/StreamingEventHandlerTest.kt index cdd73939f4..d9ecb58157 100644 --- a/agents/agents-features/agents-features-event-handler/src/jvmTest/kotlin/ai/koog/agents/features/eventHandler/feature/StreamingEventHandlerTest.kt +++ b/agents/agents-features/agents-features-event-handler/src/jvmTest/kotlin/ai/koog/agents/features/eventHandler/feature/StreamingEventHandlerTest.kt @@ -15,7 +15,7 @@ import kotlin.test.assertTrue /** * Tests for streaming event handlers. - * These tests verify that the streaming handlers (onBeforeStream, onStreamFrame, onAfterStream) + * These tests verify that the streaming handlers (onLLMStreamingStarting, onLLMStreamingFrameReceived, onLLMStreamingCompleted) * are properly invoked during LLM streaming operations. */ class StreamingEventHandlerTest { @@ -36,13 +36,13 @@ class StreamingEventHandlerTest { assertEventsCollected(eventsCollector) // Verify streaming events are captured when using nodeLLMRequestsStreaming - val beforeStreamEvents = eventsCollector.collectedEvents.filter { it.contains("OnBeforeStream") } - val streamFrameEvents = eventsCollector.collectedEvents.filter { it.contains("OnStreamFrame") } - val afterStreamEvents = eventsCollector.collectedEvents.filter { it.contains("OnAfterStream") } + val beforeStreamEvents = eventsCollector.collectedEvents.filter { it.contains("OnLLMStreamingStarting") } + val streamFrameEvents = eventsCollector.collectedEvents.filter { it.contains("OnLLMStreamingFrameReceived") } + val afterStreamEvents = eventsCollector.collectedEvents.filter { it.contains("OnLLMStreamingCompleted") } - assertTrue(beforeStreamEvents.isNotEmpty(), "Should have OnBeforeStream events") - assertTrue(streamFrameEvents.isNotEmpty(), "Should have OnStreamFrame events") - assertTrue(afterStreamEvents.isNotEmpty(), "Should have OnAfterStream events") + assertTrue(beforeStreamEvents.isNotEmpty(), "Should have OnLLMStreamingStarting events") + assertTrue(streamFrameEvents.isNotEmpty(), "Should have OnLLMStreamingFrameReceived events") + assertTrue(afterStreamEvents.isNotEmpty(), "Should have OnLLMStreamingCompleted events") // Verify the stream frame contains the expected response val frameWithContent = streamFrameEvents.firstOrNull { it.contains(assistantResponse) } @@ -63,7 +63,7 @@ class StreamingEventHandlerTest { // Verify the overall event collection is working assertEventsCollected(eventsCollector) // Verify that streaming events were captured - val streamingEventTypes = listOf("OnBeforeStream", "OnStreamFrame", "OnAfterStream") + val streamingEventTypes = listOf("OnLLMStreamingStarting", "OnLLMStreamingFrameReceived", "OnLLMStreamingCompleted") assertTrue( actual = eventsCollector.collectedEvents.any { streamingEventTypes.any(it::contains) }, message = "Should have captured at least one streaming event (${streamingEventTypes.joinToString()})" diff --git a/agents/agents-features/agents-features-event-handler/src/jvmTest/kotlin/ai/koog/agents/features/eventHandler/feature/TestEventsCollector.kt b/agents/agents-features/agents-features-event-handler/src/jvmTest/kotlin/ai/koog/agents/features/eventHandler/feature/TestEventsCollector.kt index cd08a485a4..fe3fdfa09e 100644 --- a/agents/agents-features/agents-features-event-handler/src/jvmTest/kotlin/ai/koog/agents/features/eventHandler/feature/TestEventsCollector.kt +++ b/agents/agents-features/agents-features-event-handler/src/jvmTest/kotlin/ai/koog/agents/features/eventHandler/feature/TestEventsCollector.kt @@ -20,59 +20,61 @@ class TestEventsCollector { onAgentStarting { eventContext -> runId = eventContext.runId _collectedEvents.add( - "OnBeforeAgentStarted (agent id: ${eventContext.agent.id}, run id: ${eventContext.runId})" + "OnAgentStarting (agent id: ${eventContext.agent.id}, run id: ${eventContext.runId})" ) } onAgentCompleted { eventContext -> _collectedEvents.add( - "OnAgentFinished (agent id: ${eventContext.agentId}, run id: ${eventContext.runId}, result: ${eventContext.result})" + "OnAgentCompleted (agent id: ${eventContext.agentId}, run id: ${eventContext.runId}, result: ${eventContext.result})" ) } onAgentExecutionFailed { eventContext -> _collectedEvents.add( - "OnAgentRunError (agent id: ${eventContext.agentId}, run id: ${eventContext.runId}, error: ${eventContext.throwable.message})" + "OnAgentExecutionFailed (agent id: ${eventContext.agentId}, run id: ${eventContext.runId}, error: ${eventContext.throwable.message})" ) } onAgentClosing { eventContext -> - _collectedEvents.add("OnAgentBeforeClose (agent id: ${eventContext.agentId})") + _collectedEvents.add( + "OnAgentClosing (agent id: ${eventContext.agentId})" + ) } onStrategyStarting { eventContext -> _collectedEvents.add( - "OnStrategyStarted (run id: ${eventContext.runId}, strategy: ${eventContext.strategy.name})" + "OnStrategyStarting (run id: ${eventContext.runId}, strategy: ${eventContext.strategy.name})" ) } onStrategyCompleted { eventContext -> _collectedEvents.add( - "OnStrategyFinished (run id: ${eventContext.runId}, strategy: ${eventContext.strategy.name}, result: ${eventContext.result})" + "OnStrategyCompleted (run id: ${eventContext.runId}, strategy: ${eventContext.strategy.name}, result: ${eventContext.result})" ) } onNodeExecutionStarting { eventContext -> _collectedEvents.add( - "OnBeforeNode (run id: ${eventContext.context.runId}, node: ${eventContext.node.name}, input: ${eventContext.input})" + "OnNodeExecutionStarting (run id: ${eventContext.context.runId}, node: ${eventContext.node.name}, input: ${eventContext.input})" ) } onNodeExecutionCompleted { eventContext -> _collectedEvents.add( - "OnAfterNode (run id: ${eventContext.context.runId}, node: ${eventContext.node.name}, input: ${eventContext.input}, output: ${eventContext.output})" + "OnNodeExecutionCompleted (run id: ${eventContext.context.runId}, node: ${eventContext.node.name}, input: ${eventContext.input}, output: ${eventContext.output})" ) } onNodeExecutionFailed { eventContext -> _collectedEvents.add( - "OnNodeExecutionError (run id: ${eventContext.context.runId}, node: ${eventContext.node.name}, error: ${eventContext.throwable.message})" + "OnNodeExecutionFailed (run id: ${eventContext.context.runId}, node: ${eventContext.node.name}, error: ${eventContext.throwable.message})" ) } onLLMCallStarting { eventContext -> _collectedEvents.add( - "OnBeforeLLMCall (run id: ${eventContext.runId}, prompt: ${eventContext.prompt.traceString}, tools: [${ + "OnLLMCallStarting (run id: ${eventContext.runId}, prompt: ${eventContext.prompt.traceString}, tools: [${ eventContext.tools.joinToString { it.name } @@ -82,7 +84,7 @@ class TestEventsCollector { onLLMCallCompleted { eventContext -> _collectedEvents.add( - "OnAfterLLMCall (run id: ${eventContext.runId}, prompt: ${eventContext.prompt.traceString}, model: ${eventContext.model.eventString}, tools: [${ + "OnLLMCallCompleted (run id: ${eventContext.runId}, prompt: ${eventContext.prompt.traceString}, model: ${eventContext.model.eventString}, tools: [${ eventContext.tools.joinToString { it.name } @@ -90,33 +92,33 @@ class TestEventsCollector { ) } - onToolExecutionStarting { eventContext -> + onToolCallStarting { eventContext -> _collectedEvents.add( - "OnToolCall (run id: ${eventContext.runId}, tool: ${eventContext.tool.name}, args: ${eventContext.toolArgs})" + "OnToolCallStarting (run id: ${eventContext.runId}, tool: ${eventContext.tool.name}, args: ${eventContext.toolArgs})" ) } onToolValidationFailed { eventContext -> _collectedEvents.add( - "OnToolValidationError (run id: ${eventContext.runId}, tool: ${eventContext.tool.name}, args: ${eventContext.toolArgs}, value: ${eventContext.error})" + "OnToolValidationFailed (run id: ${eventContext.runId}, tool: ${eventContext.tool.name}, args: ${eventContext.toolArgs}, value: ${eventContext.error})" ) } - onToolExecutionFailed { eventContext -> + onToolCallFailed { eventContext -> _collectedEvents.add( - "OnToolCallFailure (run id: ${eventContext.runId}, tool: ${eventContext.tool.name}, args: ${eventContext.toolArgs}, throwable: ${eventContext.throwable.message})" + "OnToolCallFailed (run id: ${eventContext.runId}, tool: ${eventContext.tool.name}, args: ${eventContext.toolArgs}, throwable: ${eventContext.throwable.message})" ) } - onToolExecutionCompleted { eventContext -> + onToolCallCompleted { eventContext -> _collectedEvents.add( - "OnToolCallResult (run id: ${eventContext.runId}, tool: ${eventContext.tool.name}, args: ${eventContext.toolArgs}, result: ${eventContext.result})" + "OnToolCallCompleted (run id: ${eventContext.runId}, tool: ${eventContext.tool.name}, args: ${eventContext.toolArgs}, result: ${eventContext.result})" ) } onLLMStreamingStarting { eventContext -> _collectedEvents.add( - "OnBeforeStream (run id: ${eventContext.runId}, prompt: ${eventContext.prompt.traceString}, model: ${eventContext.model.eventString}, tools: [${ + "OnLLMStreamingStarting (run id: ${eventContext.runId}, prompt: ${eventContext.prompt.traceString}, model: ${eventContext.model.eventString}, tools: [${ eventContext.tools.joinToString { it.name } }])" ) @@ -124,19 +126,19 @@ class TestEventsCollector { onLLMStreamingFrameReceived { eventContext -> _collectedEvents.add( - "OnStreamFrame (run id: ${eventContext.runId}, frame: ${eventContext.streamFrame})" + "OnLLMStreamingFrameReceived (run id: ${eventContext.runId}, frame: ${eventContext.streamFrame})" ) } onLLMStreamingFailed { eventContext -> _collectedEvents.add( - "OnStreamError (run id: ${eventContext.runId}, error: ${eventContext.error.message})" + "OnLLMStreamingFailed (run id: ${eventContext.runId}, error: ${eventContext.error.message})" ) } onLLMStreamingCompleted { eventContext -> _collectedEvents.add( - "OnAfterStream (run id: ${eventContext.runId}, prompt: ${eventContext.prompt.traceString}, model: ${eventContext.model.eventString}, tools: [${ + "OnLLMStreamingCompleted (run id: ${eventContext.runId}, prompt: ${eventContext.prompt.traceString}, model: ${eventContext.model.eventString}, tools: [${ eventContext.tools.joinToString { it.name } }])" ) diff --git a/agents/agents-features/agents-features-opentelemetry/src/jvmMain/kotlin/ai/koog/agents/features/opentelemetry/feature/OpenTelemetry.kt b/agents/agents-features/agents-features-opentelemetry/src/jvmMain/kotlin/ai/koog/agents/features/opentelemetry/feature/OpenTelemetry.kt index 4e64330e41..bf63d08a1a 100644 --- a/agents/agents-features/agents-features-opentelemetry/src/jvmMain/kotlin/ai/koog/agents/features/opentelemetry/feature/OpenTelemetry.kt +++ b/agents/agents-features/agents-features-opentelemetry/src/jvmMain/kotlin/ai/koog/agents/features/opentelemetry/feature/OpenTelemetry.kt @@ -388,7 +388,7 @@ public class OpenTelemetry { //region Tool Call - pipeline.interceptToolExecutionStarting(interceptContext) { eventContext -> + pipeline.interceptToolCallStarting(interceptContext) { eventContext -> logger.debug { "Execute OpenTelemetry tool call handler" } // Get current NodeExecuteSpan @@ -416,7 +416,7 @@ public class OpenTelemetry { spanProcessor.startSpan(executeToolSpan) } - pipeline.interceptToolExecutionCompleted(interceptContext) { eventContext -> + pipeline.interceptToolCallCompleted(interceptContext) { eventContext -> logger.debug { "Execute OpenTelemetry tool result handler" } // Get current ExecuteToolSpan @@ -445,7 +445,7 @@ public class OpenTelemetry { spanProcessor.endSpan(span = executeToolSpan) } - pipeline.interceptToolExecutionFailed(interceptContext) { eventContext -> + pipeline.interceptToolCallFailed(interceptContext) { eventContext -> logger.debug { "Execute OpenTelemetry tool call failure handler" } // Get current ExecuteToolSpan diff --git a/agents/agents-features/agents-features-trace/src/commonMain/kotlin/ai/koog/agents/features/tracing/feature/Tracing.kt b/agents/agents-features/agents-features-trace/src/commonMain/kotlin/ai/koog/agents/features/tracing/feature/Tracing.kt index 726303c23b..21446fa536 100644 --- a/agents/agents-features/agents-features-trace/src/commonMain/kotlin/ai/koog/agents/features/tracing/feature/Tracing.kt +++ b/agents/agents-features/agents-features-trace/src/commonMain/kotlin/ai/koog/agents/features/tracing/feature/Tracing.kt @@ -20,9 +20,9 @@ import ai.koog.agents.core.feature.model.events.NodeExecutionCompletedEvent import ai.koog.agents.core.feature.model.events.NodeExecutionFailedEvent import ai.koog.agents.core.feature.model.events.NodeExecutionStartingEvent import ai.koog.agents.core.feature.model.events.StrategyCompletedEvent -import ai.koog.agents.core.feature.model.events.ToolExecutionCompletedEvent -import ai.koog.agents.core.feature.model.events.ToolExecutionFailedEvent -import ai.koog.agents.core.feature.model.events.ToolExecutionStartingEvent +import ai.koog.agents.core.feature.model.events.ToolCallCompletedEvent +import ai.koog.agents.core.feature.model.events.ToolCallFailedEvent +import ai.koog.agents.core.feature.model.events.ToolCallStartingEvent import ai.koog.agents.core.feature.model.events.ToolValidationFailedEvent import ai.koog.agents.core.feature.model.events.startNodeToGraph import ai.koog.agents.core.feature.model.toAgentError @@ -257,12 +257,12 @@ public class Tracing { //region Intercept Tool Call Events - pipeline.interceptToolExecutionStarting(interceptContext) intercept@{ eventContext -> + pipeline.interceptToolCallStarting(interceptContext) intercept@{ eventContext -> @Suppress("UNCHECKED_CAST") val tool = eventContext.tool as Tool - val event = ToolExecutionStartingEvent( + val event = ToolCallStartingEvent( runId = eventContext.runId, toolCallId = eventContext.toolCallId, toolName = tool.name, @@ -288,12 +288,12 @@ public class Tracing { processMessage(config, event) } - pipeline.interceptToolExecutionFailed(interceptContext) intercept@{ eventContext -> + pipeline.interceptToolCallFailed(interceptContext) intercept@{ eventContext -> @Suppress("UNCHECKED_CAST") val tool = eventContext.tool as Tool - val event = ToolExecutionFailedEvent( + val event = ToolCallFailedEvent( runId = eventContext.runId, toolCallId = eventContext.toolCallId, toolName = tool.name, @@ -304,12 +304,12 @@ public class Tracing { processMessage(config, event) } - pipeline.interceptToolExecutionCompleted(interceptContext) intercept@{ eventContext -> + pipeline.interceptToolCallCompleted(interceptContext) intercept@{ eventContext -> @Suppress("UNCHECKED_CAST") val tool = eventContext.tool as Tool - val event = ToolExecutionCompletedEvent( + val event = ToolCallCompletedEvent( runId = eventContext.runId, toolCallId = eventContext.toolCallId, toolName = tool.name, diff --git a/agents/agents-features/agents-features-trace/src/commonMain/kotlin/ai/koog/agents/features/tracing/writer/traceMessageFormat.kt b/agents/agents-features/agents-features-trace/src/commonMain/kotlin/ai/koog/agents/features/tracing/writer/traceMessageFormat.kt index 0a6196edc6..a70b6dcef4 100644 --- a/agents/agents-features/agents-features-trace/src/commonMain/kotlin/ai/koog/agents/features/tracing/writer/traceMessageFormat.kt +++ b/agents/agents-features/agents-features-trace/src/commonMain/kotlin/ai/koog/agents/features/tracing/writer/traceMessageFormat.kt @@ -14,9 +14,9 @@ import ai.koog.agents.core.feature.model.events.NodeExecutionFailedEvent import ai.koog.agents.core.feature.model.events.NodeExecutionStartingEvent import ai.koog.agents.core.feature.model.events.StrategyCompletedEvent import ai.koog.agents.core.feature.model.events.StrategyStartingEvent -import ai.koog.agents.core.feature.model.events.ToolExecutionCompletedEvent -import ai.koog.agents.core.feature.model.events.ToolExecutionFailedEvent -import ai.koog.agents.core.feature.model.events.ToolExecutionStartingEvent +import ai.koog.agents.core.feature.model.events.ToolCallCompletedEvent +import ai.koog.agents.core.feature.model.events.ToolCallFailedEvent +import ai.koog.agents.core.feature.model.events.ToolCallStartingEvent import ai.koog.agents.core.feature.model.events.ToolValidationFailedEvent import ai.koog.agents.features.tracing.traceString @@ -68,16 +68,16 @@ internal val LLMCallCompletedEvent.afterLLMCallEventFormat } }])" -internal val ToolExecutionStartingEvent.toolCallEventFormat +internal val ToolCallStartingEvent.toolCallEventFormat get() = "$eventId (run id: $runId, tool: $toolName, tool args: $toolArgs)" internal val ToolValidationFailedEvent.toolValidationErrorEventFormat get() = "$eventId (run id: $runId, tool: $toolName, tool args: $toolArgs, validation error: $error)" -internal val ToolExecutionFailedEvent.toolCallFailureEventFormat +internal val ToolCallFailedEvent.toolCallFailureEventFormat get() = "$eventId (run id: $runId, tool: $toolName, tool args: $toolArgs, error: ${error.message})" -internal val ToolExecutionCompletedEvent.toolCallResultEventFormat +internal val ToolCallCompletedEvent.toolCallResultEventFormat get() = "$eventId (run id: $runId, tool: $toolName, tool args: $toolArgs, result: $result)" internal val FeatureMessage.traceMessage: String @@ -94,10 +94,10 @@ internal val FeatureMessage.traceMessage: String is NodeExecutionFailedEvent -> this.nodeExecutionErrorEventFormat is LLMCallStartingEvent -> this.beforeLLMCallEventFormat is LLMCallCompletedEvent -> this.afterLLMCallEventFormat - is ToolExecutionStartingEvent -> this.toolCallEventFormat + is ToolCallStartingEvent -> this.toolCallEventFormat is ToolValidationFailedEvent -> this.toolValidationErrorEventFormat - is ToolExecutionFailedEvent -> this.toolCallFailureEventFormat - is ToolExecutionCompletedEvent -> this.toolCallResultEventFormat + is ToolCallFailedEvent -> this.toolCallFailureEventFormat + is ToolCallCompletedEvent -> this.toolCallResultEventFormat is FeatureStringMessage -> this.featureStringMessage is FeatureEvent -> this.featureEvent else -> this.featureMessage diff --git a/agents/agents-features/agents-features-trace/src/jvmTest/kotlin/ai/koog/agents/features/tracing/writer/TraceFeatureMessageFileWriterTest.kt b/agents/agents-features/agents-features-trace/src/jvmTest/kotlin/ai/koog/agents/features/tracing/writer/TraceFeatureMessageFileWriterTest.kt index ce0480e60d..36d6251e02 100644 --- a/agents/agents-features/agents-features-trace/src/jvmTest/kotlin/ai/koog/agents/features/tracing/writer/TraceFeatureMessageFileWriterTest.kt +++ b/agents/agents-features/agents-features-trace/src/jvmTest/kotlin/ai/koog/agents/features/tracing/writer/TraceFeatureMessageFileWriterTest.kt @@ -19,8 +19,8 @@ import ai.koog.agents.core.feature.model.events.LLMCallStartingEvent import ai.koog.agents.core.feature.model.events.NodeExecutionCompletedEvent import ai.koog.agents.core.feature.model.events.NodeExecutionStartingEvent import ai.koog.agents.core.feature.model.events.StrategyCompletedEvent -import ai.koog.agents.core.feature.model.events.ToolExecutionCompletedEvent -import ai.koog.agents.core.feature.model.events.ToolExecutionStartingEvent +import ai.koog.agents.core.feature.model.events.ToolCallCompletedEvent +import ai.koog.agents.core.feature.model.events.ToolCallStartingEvent import ai.koog.agents.core.tools.ToolRegistry import ai.koog.agents.features.tracing.eventString import ai.koog.agents.features.tracing.feature.Tracing @@ -185,8 +185,8 @@ class TraceFeatureMessageFileWriterTest { content = """{"dummy":"test"}""" ) })", - "${ToolExecutionStartingEvent::class.simpleName} (run id: $runId, tool: ${dummyTool.name}, tool args: {\"dummy\":\"test\"})", - "${ToolExecutionCompletedEvent::class.simpleName} (run id: $runId, tool: ${dummyTool.name}, tool args: {\"dummy\":\"test\"}, result: ${dummyTool.result})", + "${ToolCallStartingEvent::class.simpleName} (run id: $runId, tool: ${dummyTool.name}, tool args: {\"dummy\":\"test\"})", + "${ToolCallCompletedEvent::class.simpleName} (run id: $runId, tool: ${dummyTool.name}, tool args: {\"dummy\":\"test\"}, result: ${dummyTool.result})", "${NodeExecutionCompletedEvent::class.simpleName} (run id: $runId, node: test-tool-call, input: ${ toolCallMessage( dummyTool.name, diff --git a/agents/agents-features/agents-features-trace/src/jvmTest/kotlin/ai/koog/agents/features/tracing/writer/TraceFeatureMessageLogWriterTest.kt b/agents/agents-features/agents-features-trace/src/jvmTest/kotlin/ai/koog/agents/features/tracing/writer/TraceFeatureMessageLogWriterTest.kt index 75e3644028..8d11aefabd 100644 --- a/agents/agents-features/agents-features-trace/src/jvmTest/kotlin/ai/koog/agents/features/tracing/writer/TraceFeatureMessageLogWriterTest.kt +++ b/agents/agents-features/agents-features-trace/src/jvmTest/kotlin/ai/koog/agents/features/tracing/writer/TraceFeatureMessageLogWriterTest.kt @@ -19,8 +19,8 @@ import ai.koog.agents.core.feature.model.events.LLMCallStartingEvent import ai.koog.agents.core.feature.model.events.NodeExecutionCompletedEvent import ai.koog.agents.core.feature.model.events.NodeExecutionStartingEvent import ai.koog.agents.core.feature.model.events.StrategyCompletedEvent -import ai.koog.agents.core.feature.model.events.ToolExecutionCompletedEvent -import ai.koog.agents.core.feature.model.events.ToolExecutionStartingEvent +import ai.koog.agents.core.feature.model.events.ToolCallCompletedEvent +import ai.koog.agents.core.feature.model.events.ToolCallStartingEvent import ai.koog.agents.core.tools.ToolRegistry import ai.koog.agents.features.tracing.eventString import ai.koog.agents.features.tracing.feature.Tracing @@ -174,8 +174,8 @@ class TraceFeatureMessageLogWriterTest { content = """{"dummy":"test"}""" ) })", - "[INFO] Received feature message [event]: ${ToolExecutionStartingEvent::class.simpleName} (run id: $runId, tool: ${dummyTool.name}, tool args: {\"dummy\":\"test\"})", - "[INFO] Received feature message [event]: ${ToolExecutionCompletedEvent::class.simpleName} (run id: $runId, tool: ${dummyTool.name}, tool args: {\"dummy\":\"test\"}, result: ${dummyTool.result})", + "[INFO] Received feature message [event]: ${ToolCallStartingEvent::class.simpleName} (run id: $runId, tool: ${dummyTool.name}, tool args: {\"dummy\":\"test\"})", + "[INFO] Received feature message [event]: ${ToolCallCompletedEvent::class.simpleName} (run id: $runId, tool: ${dummyTool.name}, tool args: {\"dummy\":\"test\"}, result: ${dummyTool.result})", "[INFO] Received feature message [event]: ${NodeExecutionCompletedEvent::class.simpleName} (run id: $runId, node: test-tool-call, input: ${ toolCallMessage( dummyTool.name, diff --git a/agents/agents-features/agents-features-trace/src/jvmTest/kotlin/ai/koog/agents/features/tracing/writer/TraceFeatureMessageRemoteWriterTest.kt b/agents/agents-features/agents-features-trace/src/jvmTest/kotlin/ai/koog/agents/features/tracing/writer/TraceFeatureMessageRemoteWriterTest.kt index 3e70d41122..9e092f411e 100644 --- a/agents/agents-features/agents-features-trace/src/jvmTest/kotlin/ai/koog/agents/features/tracing/writer/TraceFeatureMessageRemoteWriterTest.kt +++ b/agents/agents-features/agents-features-trace/src/jvmTest/kotlin/ai/koog/agents/features/tracing/writer/TraceFeatureMessageRemoteWriterTest.kt @@ -20,8 +20,8 @@ import ai.koog.agents.core.feature.model.events.StrategyCompletedEvent import ai.koog.agents.core.feature.model.events.StrategyEventGraph import ai.koog.agents.core.feature.model.events.StrategyEventGraphEdge import ai.koog.agents.core.feature.model.events.StrategyEventGraphNode -import ai.koog.agents.core.feature.model.events.ToolExecutionCompletedEvent -import ai.koog.agents.core.feature.model.events.ToolExecutionStartingEvent +import ai.koog.agents.core.feature.model.events.ToolCallCompletedEvent +import ai.koog.agents.core.feature.model.events.ToolCallStartingEvent import ai.koog.agents.core.feature.remote.client.FeatureMessageRemoteClient import ai.koog.agents.core.feature.remote.client.config.DefaultClientConnectionConfig import ai.koog.agents.core.feature.remote.server.config.DefaultServerConnectionConfig @@ -347,14 +347,14 @@ class TraceFeatureMessageRemoteWriterTest { input = toolCallMessage(dummyTool.name, content = """{"dummy":"test"}""").toString(), timestamp = testClock.now().toEpochMilliseconds() ), - ToolExecutionStartingEvent( + ToolCallStartingEvent( runId = runId, toolCallId = "0", toolName = dummyTool.name, toolArgs = dummyTool.encodeArgs(DummyTool.Args("test")), timestamp = testClock.now().toEpochMilliseconds() ), - ToolExecutionCompletedEvent( + ToolCallCompletedEvent( runId = runId, toolCallId = "0", toolName = dummyTool.name, diff --git a/agents/agents-features/agents-features-trace/src/jvmTest/kotlin/ai/koog/agents/features/tracing/writer/TraceFeatureMessageTestWriterTest.kt b/agents/agents-features/agents-features-trace/src/jvmTest/kotlin/ai/koog/agents/features/tracing/writer/TraceFeatureMessageTestWriterTest.kt index 2879ef0f92..7d26022efe 100644 --- a/agents/agents-features/agents-features-trace/src/jvmTest/kotlin/ai/koog/agents/features/tracing/writer/TraceFeatureMessageTestWriterTest.kt +++ b/agents/agents-features/agents-features-trace/src/jvmTest/kotlin/ai/koog/agents/features/tracing/writer/TraceFeatureMessageTestWriterTest.kt @@ -8,8 +8,8 @@ import ai.koog.agents.core.dsl.extension.nodeUpdatePrompt import ai.koog.agents.core.feature.model.AIAgentError import ai.koog.agents.core.feature.model.events.LLMCallStartingEvent import ai.koog.agents.core.feature.model.events.NodeExecutionFailedEvent -import ai.koog.agents.core.feature.model.events.ToolExecutionCompletedEvent -import ai.koog.agents.core.feature.model.events.ToolExecutionStartingEvent +import ai.koog.agents.core.feature.model.events.ToolCallCompletedEvent +import ai.koog.agents.core.feature.model.events.ToolCallStartingEvent import ai.koog.agents.core.tools.ToolRegistry import ai.koog.agents.features.tracing.feature.Tracing import ai.koog.agents.features.tracing.mock.RecursiveTool @@ -155,10 +155,10 @@ class TraceFeatureMessageTestWriterTest { agent.run("") - val toolCallsStartEvent = messageProcessor.messages.filterIsInstance().toList() + val toolCallsStartEvent = messageProcessor.messages.filterIsInstance().toList() assertEquals(1, toolCallsStartEvent.size, "Tool call start event for existing tool") - val toolCallsEndEvent = messageProcessor.messages.filterIsInstance().toList() + val toolCallsEndEvent = messageProcessor.messages.filterIsInstance().toList() assertEquals(1, toolCallsEndEvent.size, "Tool call end event for existing tool") } @@ -195,7 +195,7 @@ class TraceFeatureMessageTestWriterTest { agent.run("") - val toolCallsStartEvent = messageProcessor.messages.filterIsInstance().toList() + val toolCallsStartEvent = messageProcessor.messages.filterIsInstance().toList() assertEquals(1, toolCallsStartEvent.size, "Tool call start event for existing tool") } @@ -234,10 +234,10 @@ class TraceFeatureMessageTestWriterTest { agent.run("") - val toolCallsStartEvent = messageProcessor.messages.filterIsInstance().toList() + val toolCallsStartEvent = messageProcessor.messages.filterIsInstance().toList() assertEquals(1, toolCallsStartEvent.size, "Tool call start event for existing tool") - val toolCallsEndEvent = messageProcessor.messages.filterIsInstance().toList() + val toolCallsEndEvent = messageProcessor.messages.filterIsInstance().toList() assertEquals(1, toolCallsEndEvent.size, "Tool call end event for existing tool") } diff --git a/agents/agents-test/TESTING.md b/agents/agents-test/TESTING.md index 310d67026b..7c9c266bd2 100644 --- a/agents/agents-test/TESTING.md +++ b/agents/agents-test/TESTING.md @@ -307,7 +307,7 @@ fun testToneAgent() = runTest { // Create an event handler val eventHandler = EventHandler { - onToolCall { tool, args -> + onToolCallStarting { tool, args -> println("[DEBUG_LOG] Tool called: tool ${tool.name}, args $args") toolCalls.add(tool.name) } diff --git a/agents/agents-test/src/jvmTest/kotlin/ai/koog/agents/test/SimpleAgentMockedTest.kt b/agents/agents-test/src/jvmTest/kotlin/ai/koog/agents/test/SimpleAgentMockedTest.kt index 94d2abff54..f469623512 100644 --- a/agents/agents-test/src/jvmTest/kotlin/ai/koog/agents/test/SimpleAgentMockedTest.kt +++ b/agents/agents-test/src/jvmTest/kotlin/ai/koog/agents/test/SimpleAgentMockedTest.kt @@ -67,29 +67,29 @@ class SimpleAgentMockedTest { } val eventHandlerConfig: EventHandlerConfig.() -> Unit = { - onToolCall { eventContext -> + onToolCallStarting { eventContext -> println("Tool called: tool ${eventContext.tool.name}, args ${eventContext.toolArgs}") actualToolCalls.add(eventContext.tool.name) iterationCount++ } - onAgentRunError { eventContext -> + onAgentExecutionFailed { eventContext -> errors.add(eventContext.throwable) } - onToolCall { eventContext -> + onToolCallStarting { eventContext -> println("Tool called: tool ${eventContext.tool.name}, args ${eventContext.toolArgs}") actualToolCalls.add(eventContext.tool.name) } - onToolCallFailure { eventContext -> + onToolCallFailed { eventContext -> println( "Tool call failure: tool ${eventContext.tool.name}, args ${eventContext.toolArgs}, error=${eventContext.throwable.message}" ) errors.add(eventContext.throwable) } - onAgentFinished { eventContext -> + onAgentCompleted { eventContext -> results.add(eventContext.result) } } diff --git a/docs/docs/act-ai-agent.md b/docs/docs/act-ai-agent.md new file mode 100644 index 0000000000..fc3cf8165d --- /dev/null +++ b/docs/docs/act-ai-agent.md @@ -0,0 +1,202 @@ +# FunctionalAIAgent: How to build a single‑run agent step by step + +FunctionalAIAgent is a lightweight, non‑graph agent that you control with a simple loop. Use it when you want to: +- Call an LLM once or a few times in a custom loop; +- Optionally call tools between LLM turns; +- Return a final value (string, data class, etc.) without building a full strategy graph. + +What you’ll do in this guide: +1) Create a “Hello, World” FunctionalAIAgent. +2) Add a tool and let the agent call it. +3) Add a feature (event handler) to observe behavior. +4) Keep context under control with history compression. +5) Learn common recipes, pitfalls, and FAQs. + +## 1) Prerequisites +You need a PromptExecutor (the object that actually talks to your LLM). For local experimenting, you can use the Ollama executor: + +```kotlin +val exec = simpleOllamaAIExecutor() +``` + +You also need to pick a model, for example: + +```kotlin +val model = OllamaModels.Meta.LLAMA_3_2 +``` + +That’s it — we’ll inject both into the agent factory. + + +## 2) Your first agent (Hello, World) +Goal: Send the user’s text to the LLM and return a single assistant message as a string. + +```kotlin +val agent = functionalAIAgent( + prompt = "You are a helpful assistant.", + promptExecutor = exec, + model = model +) { input -> + val responses = requestLLMMultiple(input) + responses.single().asAssistantMessage().content +} + +val result = agent.run("Say hi in one sentence") +println(result) +``` + +What happens? +- requestLLMMultiple(input) sends the user input and receives one or more assistant messages. +- We return the only message’s content (typical one‑shot flow). + +Tip: If you want to return structured data, parse the content or use the Structured Data API. + + +## 3) Add tools (how the agent calls your functions) +Goal: Let the model operate a tiny device via tools. + +```kotlin +class Switch { + private var on = false + fun on() { on = true } + fun off() { on = false } + fun isOn() = on +} + +class SwitchTools(private val sw: Switch) { + fun turn_on() = run { sw.on(); "ok" } + fun turn_off() = run { sw.off(); "ok" } + fun state() = if (sw.isOn()) "on" else "off" +} + +val sw = Switch() +val tools = ToolRegistry { tools(SwitchTools(sw).asTools()) } + +val toolAgent = functionalAIAgent( + prompt = "You're responsible for running a Switch device and perform operations on it by request.", + promptExecutor = exec, + model = model, + toolRegistry = tools +) { input -> + var responses = requestLLMMultiple(input) + + while (responses.containsToolCalls()) { + val pending = extractToolCalls(responses) + val results = executeMultipleTools(pending) + responses = sendMultipleToolResults(results) + } + + responses.single().asAssistantMessage().content +} + +val out = toolAgent.run("Turn switch on") +println(out) +println("Switch is ${if (sw.isOn()) "on" else "off"}") +``` + +How it works +- containsToolCalls() detects tool call messages from the LLM. +- extractToolCalls(...) reads which tools to run and with what args. +- executeMultipleTools(...) runs them against your ToolRegistry. +- sendMultipleToolResults(...) sends results back to the LLM and gets the next response. + + +## 4) Observe behavior with features (EventHandler) +Goal: Print every tool call to the console. + +```kotlin +val observed = functionalAIAgent( + prompt = "...", + promptExecutor = exec, + model = model, + toolRegistry = tools, + featureContext = { + install(EventHandler) { + onToolCallStarting { e -> println("Tool called: ${'$'}{e.tool.name}, args: ${'$'}{e.toolArgs}") } + } + } +) { input -> + var responses = requestLLMMultiple(input) + while (responses.containsToolCalls()) { + val pending = extractToolCalls(responses) + val results = executeMultipleTools(pending) + responses = sendMultipleToolResults(results) + } + responses.single().asAssistantMessage().content +} +``` + +Other features you can install this way include streaming tokens and tracing; see the related docs in the sidebar. + + +## 5) Keep context under control (history compression) +Long conversations can exceed the model’s context window. Use the token usage to decide when to compress history: + +```kotlin +var responses = requestLLMMultiple(input) + +while (responses.containsToolCalls()) { + if (latestTokenUsage() > 100_000) { + compressHistory() + } + val pending = extractToolCalls(responses) + val results = executeMultipleTools(pending) + responses = sendMultipleToolResults(results) +} +``` + +Use a threshold appropriate for your model and prompt size. + + +## Common recipes +- Return structured output + - Ask the LLM to format JSON and parse it; or use Structured Data API. +- Validate tool inputs + - Perform validation in tool functions and return clear error messages. +- One agent instance per request + - Each agent instance is single‑run at a time. Create new instances if you need concurrency. +- Custom Output type + - Change functionalAIAgent and return a data class from the loop. + + +## Troubleshooting & pitfalls +- “Agent is already running” + - FunctionalAIAgent prevents concurrent runs on the same instance. Don’t share one instance across parallel coroutines; create a fresh agent per run or await completion. +- Empty or unexpected model output + - Check your system prompt. Print intermediate responses. Consider adding few‑shot examples. +- Loop never ends + - Ensure you break when there are no tool calls; add guards/timeouts for safety. +- Context overflows + - Watch latestTokenUsage() and call compressHistory(). + + +## Reference (quick) +Constructors + +```kotlin +fun functionalAIAgent( + promptExecutor: PromptExecutor, + agentConfig: AIAgentConfigBase, + toolRegistry: ToolRegistry = ToolRegistry.EMPTY, + loop: suspend AIAgentFunctionalContext.(input: Input) -> Output +): AIAgent + +fun functionalAIAgent( + promptExecutor: PromptExecutor, + toolRegistry: ToolRegistry = ToolRegistry.EMPTY, + prompt: String = "", + model: LLModel = OpenAIModels.Chat.GPT4o, + featureContext: FeatureContext.() -> Unit = {}, + func: suspend AIAgentFunctionalContext.(input: Input) -> Output, +): AIAgent +``` + +Important types +- FunctionalAIAgent +- AIAgentFunctionalContext +- AIAgentConfig / AIAgentConfigBase +- PromptExecutor +- ToolRegistry +- FeatureContext and feature interfaces + +See source: agents/agents-core/src/commonMain/kotlin/ai/koog/agents/core/agent/FunctionalAIAgent.kt diff --git a/docs/docs/agent-events.md b/docs/docs/agent-events.md index a0674636e6..9ea3eb4b09 100644 --- a/docs/docs/agent-events.md +++ b/docs/docs/agent-events.md @@ -56,7 +56,7 @@ val agent = AIAgent( ```kotlin handleEvents { // Handle tool calls - onToolExecutionStarting { eventContext -> + onToolCallStarting { eventContext -> println("Tool called: ${eventContext.tool.name} with args ${eventContext.toolArgs}") } // Handle event triggered when the agent completes its execution @@ -87,7 +87,7 @@ val agent = AIAgent( ){ handleEvents { // Handle tool calls - onToolExecutionStarting { eventContext -> + onToolCallStarting { eventContext -> println("Tool called: ${eventContext.tool.name} with args ${eventContext.toolArgs}") } // Handle event triggered when the agent completes its execution diff --git a/docs/docs/examples/Calculator.md b/docs/docs/examples/Calculator.md index eff11b74c2..f50a809719 100644 --- a/docs/docs/examples/Calculator.md +++ b/docs/docs/examples/Calculator.md @@ -170,13 +170,13 @@ val agent = AIAgent( toolRegistry = toolRegistry ) { handleEvents { - onToolCall { e -> + onToolCallStarting { e -> println("Tool called: ${e.tool.name}, args=${e.toolArgs}") } - onAgentRunError { e -> + onAgentExecutionFailed { e -> println("Agent error: ${e.throwable.message}") } - onAgentFinished { e -> + onAgentCompleted { e -> println("Final result: ${e.result}") } } diff --git a/docs/docs/examples/UnityMcp.md b/docs/docs/examples/UnityMcp.md index 65cf10086d..8100f244c9 100644 --- a/docs/docs/examples/UnityMcp.md +++ b/docs/docs/examples/UnityMcp.md @@ -123,17 +123,17 @@ runBlocking { install(Tracing) install(EventHandler) { - onBeforeAgentStarted { eventContext -> - println("OnBeforeAgentStarted first (strategy: ${strategy.name})") + onAgentStarting { eventContext -> + println("OnAgentStarting first (strategy: ${strategy.name})") } - onBeforeAgentStarted { eventContext -> - println("OnBeforeAgentStarted second (strategy: ${strategy.name})") + onAgentStarting { eventContext -> + println("OnAgentStarting second (strategy: ${strategy.name})") } - onAgentFinished { eventContext -> + onAgentCompleted { eventContext -> println( - "OnAgentFinished (agent id: ${eventContext.agentId}, result: ${eventContext.result})" + "OnAgentCompleted (agent id: ${eventContext.agentId}, result: ${eventContext.result})" ) } } diff --git a/docs/docs/streaming-api.md b/docs/docs/streaming-api.md index 74256d091f..bea04dff15 100644 --- a/docs/docs/streaming-api.md +++ b/docs/docs/streaming-api.md @@ -152,7 +152,7 @@ fun GraphAIAgent.FeatureContext.installStreamingApi() { --> ```kotlin handleEvents { - onToolExecutionStarting { context -> + onToolCallStarting { context -> println("\n🔧 Using ${context.tool.name} with ${context.toolArgs}... ") } onLLMStreamingFrameReceived { context -> diff --git a/docs/docs/testing.md b/docs/docs/testing.md index 60a3e5665a..8896edd5ce 100644 --- a/docs/docs/testing.md +++ b/docs/docs/testing.md @@ -799,7 +799,7 @@ fun testToneAgent() = runTest { // Create an event handler val eventHandler = EventHandler { - onToolCall { tool, args -> + onToolCallStarting { tool, args -> println("[DEBUG_LOG] Tool called: tool ${tool.name}, args $args") toolCalls.add(tool.name) } diff --git a/docs/docs/tracing.md b/docs/docs/tracing.md index 953b2188cf..a7065e814f 100644 --- a/docs/docs/tracing.md +++ b/docs/docs/tracing.md @@ -37,7 +37,7 @@ To use the Tracing feature, you need to: + + +### Installation and configuration + +The EventHandler feature integrates with the agent workflow through the `EventHandler` class, +which provides a way to register callbacks for different agent events, and can be installed as a feature in the agent configuration. For details, see [API reference](https://api.koog.ai/agents/agents-features/agents-features-event-handler/ai.koog.agents.features.eventHandler.feature/-event-handler/index.html). + +To install the feature and configure event handlers for the agent, do the following: + + + + +```kotlin +handleEvents { + // Handle tool calls + onToolCallStarting { eventContext -> + println("Tool called: ${eventContext.tool.name} with args ${eventContext.toolArgs}") + } + // Handle event triggered when the agent completes its execution + onAgentCompleted { eventContext -> + println("Agent finished with result: ${eventContext.result}") + } + + // Other event handlers +} +``` + + +For more details about event handler configuration, see [API reference](https://api.koog.ai/agents/agents-features/agents-features-event-handler/ai.koog.agents.features.eventHandler.feature/-event-handler-config/index.html). + +You can also set up event handlers using the `handleEvents` extension function when creating an agent. +This function also installs the event handler feature and configures event handlers for the agent. Here is an example: + + +```kotlin +val agent = AIAgent( + promptExecutor = simpleOllamaAIExecutor(), + llmModel = OllamaModels.Meta.LLAMA_3_2, +){ + handleEvents { + // Handle tool calls + onToolCallStarting { eventContext -> + println("Tool called: ${eventContext.tool.name} with args ${eventContext.toolArgs}") + } + // Handle event triggered when the agent completes its execution + onAgentCompleted { eventContext -> + println("Agent finished with result: ${eventContext.result}") + } + + // Other event handlers + } +} +``` + diff --git a/docs/docs/agent-events.md b/docs/docs/agent-events.md index 666a8ad322..9ced231f5a 100644 --- a/docs/docs/agent-events.md +++ b/docs/docs/agent-events.md @@ -11,95 +11,424 @@ Agent events are actions or interactions that occur as part of an agent workflow Note: Feature events are defined in the agents-core module and live under the package `ai.koog.agents.core.feature.model.events`. Features such as `agents-features-trace`, `agents-features-debugger`, and `agents-features-event-handler` consume these events to process and forward messages created during agent execution. -## Event handlers +## Predefined event types -You can monitor and respond to specific events during the agent workflow by using event handlers for logging, testing, debugging, and extending agent behavior. +Koog provides predefined event types that can be used in custom message processors. The predefined events can be +classified into several categories, depending on the entity they relate to: -The EventHandler feature lets you hook into various agent events. It serves as an event delegation mechanism that: +- [Agent events](#agent-events) +- [Strategy events](#strategy-events) +- [Node events](#node-events) +- [LLM call events](#llm-call-events) +- [LLM streaming events](#llm-streaming-events) +- [Tool execution events](#tool-execution-events) -- Manages the lifecycle of AI agent operations. -- Provides hooks for monitoring and responding to different stages of the workflow. -- Enables error handling and recovery. -- Facilitates tool invocation tracking and result processing. +### Agent events - +| Name | Data type | Required | Default | Description | +|------------|-----------|----------|---------|---------------------------------------------------------| +| `agentId` | String | Yes | | The unique identifier of the AI agent. | +| `runId` | String | Yes | | The unique identifier of the AI agent run. | +#### AgentCompletedEvent -### Installation and configuration +Represents the end of an agent run. Includes the following fields: -The EventHandler feature integrates with the agent workflow through the `EventHandler` class, -which provides a way to register callbacks for different agent events, and can be installed as a feature in the agent configuration. For details, see [API reference](https://api.koog. -ai/agents/agents-features/agents-features-event-handler/ai.koog.agents.local.features.eventHandler.feature/-event-handler/index.html). +| Name | Data type | Required | Default | Description | +|------------|-----------|----------|---------|---------------------------------------------------------------------| +| `agentId` | String | Yes | | The unique identifier of the AI agent. | +| `runId` | String | Yes | | The unique identifier of the AI agent run. | +| `result` | String | Yes | | The result of the agent run. Can be `null` if there is no result. | -To install the feature and configure event handlers for the agent, do the following: +#### AgentExecutionFailedEvent + +Represents the occurrence of an error during an agent run. Includes the following fields: + +| Name | Data type | Required | Default | Description | +|-----------|---------------|----------|---------|-----------------------------------------------------------------------------------------------------------------| +| `agentId` | String | Yes | | The unique identifier of the AI agent. | +| `runId` | String | Yes | | The unique identifier of the AI agent run. | +| `error` | AIAgentError | Yes | | The specific error that occurred during the agent run. For more information, see [AIAgentError](#aiagenterror). | + +#### AgentClosingEvent + +Represents the closure or termination of an agent. Includes the following fields: + +| Name | Data type | Required | Default | Description | +|-----------|-----------|----------|---------|---------------------------------------------------------| +| `agentId` | String | Yes | | The unique identifier of the AI agent. | + + +The `AIAgentError` class provides more details about an error that occurred during an agent run. Includes the following fields: + +| Name | Data type | Required | Default | Description | +|--------------|-----------|----------|---------|------------------------------------------------------------------| +| `message` | String | Yes | | The message that provides more details about the specific error. | +| `stackTrace` | String | Yes | | The collection of stack records until the last executed code. | +| `cause` | String | No | null | The cause of the error, if available. | + +### Strategy events + +#### GraphStrategyStartingEvent + +Represents the start of a graph-based strategy run. Includes the following fields: + +| Name | Data type | Required | Default | Description | +|-----------------|------------------------|----------|---------|----------------------------------------------------------| +| `runId` | String | Yes | | The unique identifier of the strategy run. | +| `strategyName` | String | Yes | | The name of the strategy. | +| `graph` | StrategyEventGraph | Yes | | The graph structure representing the strategy workflow. | + +#### FunctionalStrategyStartingEvent + +Represents the start of a functional strategy run. Includes the following fields: + +| Name | Data type | Required | Default | Description | +|-----------------|-----------|----------|---------|--------------------------------------------| +| `runId` | String | Yes | | The unique identifier of the strategy run. | +| `strategyName` | String | Yes | | The name of the strategy. | + +#### StrategyCompletedEvent + +Represents the end of a strategy run. Includes the following fields: + +| Name | Data type | Required | Default | Description | +|----------------|-----------|----------|---------|--------------------------------------------------------------------------------| +| `runId` | String | Yes | | The unique identifier of the strategy run. | +| `strategyName` | String | Yes | | The name of the strategy. | +| `result` | String | Yes | | The result of the run. Can be `null` if there is no result. | + +### Node events + +#### NodeExecutionStartingEvent + +Represents the start of a node run. Includes the following fields: + +| Name | Data type | Required | Default | Description | +|------------|-----------|----------|---------|--------------------------------------------------| +| `runId` | String | Yes | | The unique identifier of the strategy run. | +| `nodeName` | String | Yes | | The name of the node whose run started. | +| `input` | String | Yes | | The input value for the node. | + +#### NodeExecutionCompletedEvent + +Represents the end of a node run. Includes the following fields: + +| Name | Data type | Required | Default | Description | +|------------|-----------|----------|---------|--------------------------------------------------| +| `runId` | String | Yes | | The unique identifier of the strategy run. | +| `nodeName` | String | Yes | | The name of the node whose run ended. | +| `input` | String | Yes | | The input value for the node. | +| `output` | String | Yes | | The output value produced by the node. | + +#### NodeExecutionFailedEvent + +Represents an error that occurred during a node run. Includes the following fields: + +| Name | Data type | Required | Default | Description | +|------------|--------------|----------|---------|-----------------------------------------------------------------------------------------------------------------| +| `runId` | String | Yes | | The unique identifier of the strategy run. | +| `nodeName` | String | Yes | | The name of the node where the error occurred. | +| `error` | AIAgentError | Yes | | The specific error that occurred during the node run. For more information, see [AIAgentError](#aiagenterror). | + +### LLM call events + +#### LLMCallStartingEvent + +Represents the start of an LLM call. Includes the following fields: + +| Name | Data type | Required | Default | Description | +|----------|--------------------|----------|---------|------------------------------------------------------------------------------------| +| `runId` | String | Yes | | The unique identifier of the LLM run. | +| `prompt` | Prompt | Yes | | The prompt that is sent to the model. For more information, see [Prompt](#prompt). | +| `model` | String | Yes | | The model identifier in the format `llm_provider:model_id`. | +| `tools` | List | Yes | | The list of tools that the model can call. | + + +The `Prompt` class represents a data structure for a prompt, consisting of a list of messages, a unique identifier, and +optional parameters for language model settings. Includes the following fields: + +| Name | Data type | Required | Default | Description | +|------------|---------------------|----------|-------------|--------------------------------------------------------------| +| `messages` | List | Yes | | The list of messages that the prompt consists of. | +| `id` | String | Yes | | The unique identifier for the prompt. | +| `params` | LLMParams | No | LLMParams() | The settings that control the way the LLM generates content. | + +#### LLMCallCompletedEvent + +Represents the end of an LLM call. Includes the following fields: + +| Name | Data type | Required | Default | Description | +|----------------------|--------------------------------|----------|---------|-----------------------------------------------------------| +| `runId` | String | Yes | | The unique identifier of the LLM run. | +| `prompt` | Prompt | Yes | | The prompt used in the call. | +| `model` | String | Yes | | The model identifier in the format `llm_provider:model_id`.| +| `responses` | List | Yes | | One or more responses returned by the model. | +| `moderationResponse` | ModerationResult | No | null | The moderation response, if any. | + +### LLM streaming events + +#### LLMStreamingStartingEvent + +Represents the start of an LLM streaming call. Includes the following fields: + +| Name | Data type | Required | Default | Description | +|----------|--------------|----------|---------|-------------------------------------------------------------| +| `runId` | String | Yes | | The unique identifier of the LLM run. | +| `prompt` | Prompt | Yes | | The prompt that is sent to the model. | +| `model` | String | Yes | | The model identifier in the format `llm_provider:model_id`. | +| `tools` | List | Yes | | The list of tools that the model can call. | + +#### LLMStreamingFrameReceivedEvent + +Represents a streaming frame received from the LLM. Includes the following fields: + +| Name | Data type | Required | Default | Description | +|----------|-------------|----------|---------|--------------------------------------------------| +| `runId` | String | Yes | | The unique identifier of the LLM run. | +| `frame` | StreamFrame | Yes | | The frame received from the stream. | + +#### LLMStreamingFailedEvent + +Represents the occurrence of an error during an LLM streaming call. Includes the following fields: + +| Name | Data type | Required | Default | Description | +|---------|--------------|----------|---------|-----------------------------------------------------------------------------------------------------------------| +| `runId` | String | Yes | | The unique identifier of the LLM run. | +| `error` | AIAgentError | Yes | | The specific error that occurred during streaming. For more information, see [AIAgentError](#aiagenterror). | + +#### LLMStreamingCompletedEvent + +Represents the end of an LLM streaming call. Includes the following fields: + +| Name | Data type | Required | Default | Description | +|----------|--------------|----------|---------|-------------------------------------------------------------| +| `runId` | String | Yes | | The unique identifier of the LLM run. | +| `prompt` | Prompt | Yes | | The prompt that is sent to the model. | +| `model` | String | Yes | | The model identifier in the format `llm_provider:model_id`. | +| `tools` | List | Yes | | The list of tools that the model can call. | + +### Tool execution events + +#### ToolExecutionStartingEvent + +Represents the event of a model calling a tool. Includes the following fields: + +| Name | Data type | Required | Default | Description | +|---------------|-------------|----------|---------|---------------------------------------------------------| +| `runId` | String | Yes | | The unique identifier of the strategy/agent run. | +| `toolCallId` | String | No | null | The identifier of the tool call, if available. | +| `toolName` | String | Yes | | The name of the tool. | +| `toolArgs` | JsonObject | Yes | | The arguments that are provided to the tool. | + +#### ToolValidationFailedEvent + +Represents the occurrence of a validation error during a tool call. Includes the following fields: + +| Name | Data type | Required | Default | Description | +|---------------|-------------|----------|---------|---------------------------------------------------------| +| `runId` | String | Yes | | The unique identifier of the strategy/agent run. | +| `toolCallId` | String | No | null | The identifier of the tool call, if available. | +| `toolName` | String | Yes | | The name of the tool for which validation failed. | +| `toolArgs` | JsonObject | Yes | | The arguments that are provided to the tool. | +| `error` | String | Yes | | The validation error message. | + +#### ToolExecutionFailedEvent + +Represents a failure to execute a tool. Includes the following fields: + +| Name | Data type | Required | Default | Description | +|---------------|--------------|----------|---------|-------------------------------------------------------------------------------------------------------------------------| +| `runId` | String | Yes | | The unique identifier of the strategy/agent run. | +| `toolCallId` | String | No | null | The identifier of the tool call, if available. | +| `toolName` | String | Yes | | The name of the tool. | +| `toolArgs` | JsonObject | Yes | | The arguments that are provided to the tool. | +| `error` | AIAgentError | Yes | | The specific error that occurred when trying to call a tool. For more information, see [AIAgentError](#aiagenterror). | + +#### ToolExecutionCompletedEvent + +Represents a successful tool call with the return of a result. Includes the following fields: + +| Name | Data type | Required | Default | Description | +|---------------|------------|----------|---------|------------------------------------------| +| `runId` | String | Yes | | The unique identifier of the run. | +| `toolCallId` | String | No | null | The identifier of the tool call. | +| `toolName` | String | Yes | | The name of the tool. | +| `toolArgs` | JsonObject | Yes | | The arguments provided to the tool. | +| `result` | String | Yes | | The result of the tool call (nullable). | + +## FAQ and troubleshooting + +The following section includes commonly asked questions and answers related to the Tracing feature. + +### How do I trace only specific parts of my agent's execution? + +Use the `messageFilter` property to filter events. For example, to trace only node execution: - - ```kotlin -handleEvents { - // Handle tool calls - onToolCallStarting { eventContext -> - println("Tool called: ${eventContext.tool.name} with args ${eventContext.toolArgs}") - } - // Handle event triggered when the agent completes its execution - onAgentCompleted { eventContext -> - println("Agent finished with result: ${eventContext.result}") +install(Tracing) { + val fileWriter = TraceFeatureMessageFileWriter( + outputPath, + { path: Path -> SystemFileSystem.sink(path).buffered() } + ) + addMessageProcessor(fileWriter) + + // Only trace LLM calls + fileWriter.setMessageFilter { message -> + message is LLMCallStartingEvent || message is LLMCallCompletedEvent } - - // Other event handlers } ``` - + -For more details about event handler configuration, see [API reference](https://api.koog.ai/agents/agents-features/agents-features-event-handler/ai.koog.agents.local.features.eventHandler.feature/-event-handler-config/index.html). +### Can I use multiple message processors? -You can also set up event handlers using the `handleEvents` extension function when creating an agent. -This function also installs the event handler feature and configures event handlers for the agent. Here is an example: +Yes, you can add multiple message processors to trace to different destinations simultaneously: + ```kotlin -val agent = AIAgent( - promptExecutor = simpleOllamaAIExecutor(), - llmModel = OllamaModels.Meta.LLAMA_3_2, -){ - handleEvents { - // Handle tool calls - onToolCallStarting { eventContext -> - println("Tool called: ${eventContext.tool.name} with args ${eventContext.toolArgs}") +install(Tracing) { + addMessageProcessor(TraceFeatureMessageLogWriter(logger)) + addMessageProcessor(TraceFeatureMessageFileWriter(outputPath, syncOpener)) + addMessageProcessor(TraceFeatureMessageRemoteWriter(connectionConfig)) +} +``` + + +### How can I create a custom message processor? + +Implement the `FeatureMessageProcessor` interface: + + + +```kotlin +class CustomTraceProcessor : FeatureMessageProcessor() { + + // Current open state of the processor + private var _isOpen = MutableStateFlow(false) + + override val isOpen: StateFlow + get() = _isOpen.asStateFlow() + + override suspend fun processMessage(message: FeatureMessage) { + // Custom processing logic + when (message) { + is NodeExecutionStartingEvent -> { + // Process node start event + } + + is LLMCallCompletedEvent -> { + // Process LLM call end event + } + // Handle other event types } + } - // Other event handlers + override suspend fun close() { + // Close connections of established } } + +// Use your custom processor +install(Tracing) { + addMessageProcessor(CustomTraceProcessor()) +} ``` - + + +For more information about existing event types that can be handled by message processors, see [Predefined event types](#predefined-event-types). diff --git a/docs/docs/features-overview.md b/docs/docs/features-overview.md index 0e3b8de886..ce51ef5f84 100644 --- a/docs/docs/features-overview.md +++ b/docs/docs/features-overview.md @@ -8,6 +8,11 @@ Agent features provide a way to extend and enhance the functionality of AI agent The Koog framework implements the following features: -- [Event Handler](agent-events.md) +- [Event Handler](agent-event-handlers.md) - [Tracing](tracing.md) - [Agent Memory](agent-memory.md) +- [OpenTelemetry](opentelemetry-support.md) +- [Agent Persistency (Snapshots)](agent-persistency.md) +- Debugger +- Tokenizer +- SQL Persistency Providers diff --git a/docs/docs/single-run-agents.md b/docs/docs/single-run-agents.md index 07f877ef6b..5dfdc55d2f 100644 --- a/docs/docs/single-run-agents.md +++ b/docs/docs/single-run-agents.md @@ -150,7 +150,7 @@ val agent = AIAgent( Single-run agents support custom event handlers. While having an event handler is not required for creating an agent, it might be helpful for testing, debugging, or making hooks for chained agent interactions. -For more information on how to use the `EventHandler` feature for monitoring your agent interactions, see [Agent events](agent-events.md). +For more information on how to use the `EventHandler` feature for monitoring your agent interactions, see [Event Handlers](agent-event-handlers.md). ### 8. Run the agent diff --git a/docs/docs/streaming-api.md b/docs/docs/streaming-api.md index bea04dff15..ae2ca02a88 100644 --- a/docs/docs/streaming-api.md +++ b/docs/docs/streaming-api.md @@ -137,7 +137,7 @@ llm.writeSession { ### Listening to stream events in event handlers -You can listen to stream events in [agent events](agent-events.md). +You can listen to stream events in [agent event handlers](agent-event-handlers.md). - -```kotlin -install(Tracing) { - val fileWriter = TraceFeatureMessageFileWriter( - outputPath, - { path: Path -> SystemFileSystem.sink(path).buffered() } - ) - addMessageProcessor(fileWriter) - - // Only trace LLM calls - fileWriter.setMessageFilter { message -> - message is LLMCallStartingEvent || message is LLMCallCompletedEvent - } -} -``` - - -### Can I use multiple message processors? - -Yes, you can add multiple message processors to trace to different destinations simultaneously: - - - -```kotlin -install(Tracing) { - addMessageProcessor(TraceFeatureMessageLogWriter(logger)) - addMessageProcessor(TraceFeatureMessageFileWriter(outputPath, syncOpener)) - addMessageProcessor(TraceFeatureMessageRemoteWriter(connectionConfig)) -} -``` - - -### How can I create a custom message processor? - -Implement the `FeatureMessageProcessor` interface: - - - -```kotlin -class CustomTraceProcessor : FeatureMessageProcessor() { - - // Current open state of the processor - private var _isOpen = MutableStateFlow(false) - - override val isOpen: StateFlow - get() = _isOpen.asStateFlow() - - override suspend fun processMessage(message: FeatureMessage) { - // Custom processing logic - when (message) { - is NodeExecutionStartingEvent -> { - // Process node start event - } - - is LLMCallCompletedEvent -> { - // Process LLM call end event - } - // Handle other event types - } - } - - override suspend fun close() { - // Close connections of established - } -} - -// Use your custom processor -install(Tracing) { - addMessageProcessor(CustomTraceProcessor()) -} -``` - - -For more information about existing event types that can be handled by message processors, see [Predefined event types](#predefined-event-types). - -## Predefined event types - -Koog provides predefined event types that can be used in custom message processors. The predefined events can be -classified into several categories, depending on the entity they relate to: - -- [Agent events](#agent-events) -- [Strategy events](#strategy-events) -- [Node events](#node-events) -- [LLM call events](#llm-call-events) -- [Tool call events](#tool-call-events) - -### Agent events - -#### AgentStartingEvent - -Represents the start of an agent run. Includes the following fields: - -| Name | Data type | Required | Default | Description | -|----------------|-----------|----------|-----------------------|---------------------------------------------------------------------------| -| `strategyName` | String | Yes | | The name of the strategy that the agent should follow. | -| `eventId` | String | No | `AgentStartingEvent` | The identifier of the event. Usually the `simpleName` of the event class. | - -#### AgentCompletedEvent - -Represents the end of an agent run. Includes the following fields: - -| Name | Data type | Required | Default | Description | -|----------------|-----------|----------|------------------------|---------------------------------------------------------------------------| -| `strategyName` | String | Yes | | The name of the strategy that the agent followed. | -| `result` | String | Yes | | The result of the agent run. Can be `null` if there is no result. | -| `eventId` | String | No | `AgentCompletedEvent` | The identifier of the event. Usually the `simpleName` of the event class. | - -#### AgentExecutionFailedEvent - -Represents the occurrence of an error during an agent run. Includes the following fields: - -| Name | Data type | Required | Default | Description | -|----------------|--------------|----------|------------------------|-----------------------------------------------------------------------------------------------------------------| -| `strategyName` | String | Yes | | The name of the strategy that the agent followed. | -| `error` | AIAgentError | Yes | | The specific error that occurred during the agent run. For more information, see [AIAgentError](#aiagenterror). | -| `eventId` | String | No | `AgentExecutionFailedEvent` | The identifier of the event. Usually the `simpleName` of the event class. | - - -The `AIAgentError` class provides more details about an error that occurred during an agent run. Includes the following fields: - -| Name | Data type | Required | Default | Description | -|--------------|-----------|----------|---------|------------------------------------------------------------------| -| `message` | String | Yes | | The message that provides more details about the specific error. | -| `stackTrace` | String | Yes | | The collection of stack records until the last executed code. | -| `cause` | String | No | null | The cause of the error, if available. | - -### Strategy events - -#### StrategyStartingEvent - -Represents the start of a strategy run. Includes the following fields: - -| Name | Data type | Required | Default | Description | -|----------------|-----------|----------|-----------------------------|---------------------------------------------------------------------------| -| `strategyName` | String | Yes | | The name of the strategy. | -| `eventId` | String | No | `StrategyStartingEvent` | The identifier of the event. Usually the `simpleName` of the event class. | - -#### StrategyCompletedEvent - -Represents the end of a strategy run. Includes the following fields: - -| Name | Data type | Required | Default | Description | -|----------------|-----------|----------|--------------------------------|---------------------------------------------------------------------------| -| `strategyName` | String | Yes | | The name of the strategy. | -| `result` | String | Yes | | The result of the run. | -| `eventId` | String | No | `StrategyCompletedEvent` | The identifier of the event. Usually the `simpleName` of the event class. | - -### Node events - -#### NodeExecutionStartingEvent - -Represents the start of a node run. Includes the following fields: - -| Name | Data type | Required | Default | Description | -|------------|-----------|----------|----------------------------------|---------------------------------------------------------------------------| -| `nodeName` | String | Yes | | The name of the node whose run started. | -| `input` | String | Yes | | The input value for the node. | -| `eventId` | String | No | `NodeExecutionStartingEvent` | The identifier of the event. Usually the `simpleName` of the event class. | - -#### NodeExecutionCompletedEvent - -Represents the end of a node run. Includes the following fields: - -| Name | Data type | Required | Default | Description | -|------------|-----------|----------|--------------------------------|---------------------------------------------------------------------------| -| `nodeName` | String | Yes | | The name of the node whose run ended. | -| `input` | String | Yes | | The input value for the node. | -| `output` | String | Yes | | The output value produced by the node. | -| `eventId` | String | No | `NodeExecutionCompletedEvent` | The identifier of the event. Usually the `simpleName` of the event class. | - -### LLM call events - -#### LLMCallStartingEvent - -Represents the start of an LLM call. Includes the following fields: - -| Name | Data type | Required | Default | Description | -|-----------|--------------------|----------|----------------------|------------------------------------------------------------------------------------| -| `prompt` | Prompt | Yes | | The prompt that is sent to the model. For more information, see [Prompt](#prompt). | -| `tools` | List<String> | Yes | | The list of tools that the model can call. | -| `eventId` | String | No | `LLMCallStartingEvent` | The identifier of the event. Usually the `simpleName` of the event class. | - - -The `Prompt` class represents a data structure for a prompt, consisting of a list of messages, a unique identifier, and -optional parameters for language model settings. Includes the following fields: - -| Name | Data type | Required | Default | Description | -|------------|---------------------|----------|-------------|--------------------------------------------------------------| -| `messages` | List<Message> | Yes | | The list of messages that the prompt consists of. | -| `id` | String | Yes | | The unique identifier for the prompt. | -| `params` | LLMParams | No | LLMParams() | The settings that control the way the LLM generates content. | - -#### LLMCallCompletedEvent - -Represents the end of an LLM call. Includes the following fields: - -| Name | Data type | Required | Default | Description | -|-------------|------------------------------|----------|---------------------|---------------------------------------------------------------------------| -| `responses` | List<Message.Response> | Yes | | One or more responses returned by the model. | -| `eventId` | String | No | `LLMCallCompletedEvent` | The identifier of the event. Usually the `simpleName` of the event class. | - -### Tool call events - -#### ToolCallEvent - -Represents the event of a model calling a tool. Includes the following fields: - -| Name | Data type | Required | Default | Description | -|------------|-----------|----------|-----------------|---------------------------------------------------------------------------| -| `toolName` | String | Yes | | The name of the tool. | -| `toolArgs` | Tool.Args | Yes | | The arguments that are provided to the tool. | -| `eventId` | String | No | `ToolCallEvent` | The identifier of the event. Usually the `simpleName` of the event class. | - -#### ToolValidationErrorEvent - -Represents the occurrence of a validation error during a tool call. Includes the following fields: - -| Name | Data type | Required | Default | Description | -|----------------|-----------|----------|----------------------------|---------------------------------------------------------------------------| -| `toolName` | String | Yes | | The name of the tool for which validation failed. | -| `toolArgs` | Tool.Args | Yes | | The arguments that are provided to the tool. | -| `errorMessage` | String | Yes | | The validation error message. | -| `eventId` | String | No | `ToolValidationErrorEvent` | The identifier of the event. Usually the `simpleName` of the event class. | - -#### ToolCallFailureEvent - -Represents a failure to call a tool. Includes the following fields: - -| Name | Data type | Required | Default | Description | -|------------|--------------|----------|------------------------|-----------------------------------------------------------------------------------------------------------------------| -| `toolName` | String | Yes | | The name of the tool. | -| `toolArgs` | Tool.Args | Yes | | The arguments that are provided to the tool. | -| `error` | AIAgentError | Yes | | The specific error that occurred when trying to call a tool. For more information, see [AIAgentError](#aiagenterror). | -| `eventId` | String | No | `ToolCallFailureEvent` | The identifier of the event. Usually the `simpleName` of the event class. | - -#### ToolCallResultEvent - -Represents a successful tool call with the return of a result. Includes the following fields: - -| Name | Data type | Required | Default | Description | -|------------|------------|----------|-----------------------|---------------------------------------------------------------------------| -| `toolName` | String | Yes | | The name of the tool. | -| `toolArgs` | Tool.Args | Yes | | The arguments that are provided to the tool. | -| `result` | ToolResult | Yes | | The result of the tool call. | -| `eventId` | String | No | `ToolCallResultEvent` | The identifier of the event. Usually the `simpleName` of the event class. | From bf6f54c2be04de785d9c8338cc6ed2db70e2a5bf Mon Sep 17 00:00:00 2001 From: Maria Tigina <31625351+tiginamaria@users.noreply.github.com> Date: Wed, 1 Oct 2025 15:47:26 +0200 Subject: [PATCH 12/52] Fix compress history integration test (#899) Fix final messages collection in integration test for history compression --- .../tests/agent/AIAgentIntegrationTest.kt | 15 +++++---------- 1 file changed, 5 insertions(+), 10 deletions(-) diff --git a/integration-tests/src/jvmTest/kotlin/ai/koog/integration/tests/agent/AIAgentIntegrationTest.kt b/integration-tests/src/jvmTest/kotlin/ai/koog/integration/tests/agent/AIAgentIntegrationTest.kt index 5083981d91..0236d05298 100644 --- a/integration-tests/src/jvmTest/kotlin/ai/koog/integration/tests/agent/AIAgentIntegrationTest.kt +++ b/integration-tests/src/jvmTest/kotlin/ai/koog/integration/tests/agent/AIAgentIntegrationTest.kt @@ -147,7 +147,7 @@ class AIAgentIntegrationTest { "FromTimestamp" ), // ToDo uncomment when KG-311 is fully fixed - // Arguments.of(HistoryCompressionStrategy.Chunked(2), "Chunked(2)") + Arguments.of(HistoryCompressionStrategy.Chunked(2), "Chunked(2)") ) } @@ -1252,9 +1252,8 @@ class AIAgentIntegrationTest { val model = OpenAIModels.CostOptimized.GPT4_1Mini val systemMessage = "You are a helpful assistant. Remember: the user is a human, whatever they say. Remind them of it by every chance." - var promptMessages: List? = null - val historyCompressionStrategy = strategy("history-compression-test") { + val historyCompressionStrategy = strategy>>("history-compression-test") { val callLLM by nodeLLMRequest(allowToolCalls = false) val nodeCompressHistory by nodeLLMCompressHistory( "compress_history", @@ -1263,10 +1262,10 @@ class AIAgentIntegrationTest { edge(nodeStart forwardTo callLLM) edge(callLLM forwardTo nodeCompressHistory onAssistantMessage { true }) - edge(nodeCompressHistory forwardTo nodeFinish) + edge(nodeCompressHistory forwardTo nodeFinish transformed { it to llm.prompt.messages }) } - val agent = AIAgent( + val agent = AIAgent>>( promptExecutor = getExecutor(model), strategy = historyCompressionStrategy, agentConfig = AIAgentConfig( @@ -1286,15 +1285,11 @@ class AIAgentIntegrationTest { onAgentExecutionFailed { eventContext -> errors.add(eventContext.throwable) } - - onLLMCallStarting { eventContext -> - promptMessages = eventContext.prompt.messages - } } } withRetry { - val result = agent.run("So, who am I?") + val (result, promptMessages) = agent.run("So, who am I?") assertTrue( errors.isEmpty(), From d53c8869e1b8586882aacf31ed89645b94b27e49 Mon Sep 17 00:00:00 2001 From: Anastasiia Zarechneva <49490937+aozherelyeva@users.noreply.github.com> Date: Wed, 1 Oct 2025 17:43:29 +0200 Subject: [PATCH 13/52] Fix Ollama tests (#895) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Motivation and Context Fixed the OOM exception when running the Ollama tests by: - Adding cleanup scripts for docker and runner; - Enhancing Ollama fixture and its extension; - Parallelizing test classes; - Limiting artifact retention duration. Also, fixed some issues with the tests and removed obsolete println-s that I used when manually checked the Ollama execution log. I don't think they belong to the CI pipeline. ❗ Important: I have little experience with Docker and Testcontainers and memory optimization, hence I used Claude to help me with this problem. In case you see something wrong or contre-logical, feel free to point out. ➕ some unrelated fixes added after rebase to develop to make klintcheck pass ## Breaking Changes None. --- #### Type of the changes - [ ] New feature (non-breaking change which adds functionality) - [ ] Bug fix (non-breaking change which fixes an issue) - [ ] Breaking change (fix or feature that would cause existing functionality to change) - [ ] Documentation update - [x] Tests improvement - [x] Refactoring #### Checklist - [x] The pull request has a description of the proposed change - [x] I read the [Contributing Guidelines](https://github.com/JetBrains/koog/blob/main/CONTRIBUTING.md) before opening the pull request - [x] The pull request uses **`develop`** as the base branch - [x] Tests for the changes have been added - [x] All new and existing tests passed ##### Additional steps for pull requests adding a new feature - [ ] An issue describing the proposed change exists - [ ] The pull request includes a link to the issue - [ ] The change was discussed and approved in the issue - [ ] Docs have been added / updated --- .github/workflows/ollama-tests.yml | 125 +++++++++++++++++- .../agents/core/feature/AIAgentPipeline.kt | 18 +-- ...DeprecatedExecuteLLMEventHandlerContext.kt | 2 +- .../core/agent/FunctionalAIAgentTest.kt | 1 - ...tructuredOutputWithToolsIntegrationTest.kt | 4 + .../integration/tests/OllamaTestFixture.kt | 77 ++++++++--- .../tests/OllamaTestFixtureExtension.kt | 80 ++++++++--- .../tests/agent/OllamaAgentIntegrationTest.kt | 15 --- .../agent/OllamaSimpleAgentIntegrationTest.kt | 60 --------- .../executor/OllamaExecutorIntegrationTest.kt | 36 +---- 10 files changed, 264 insertions(+), 154 deletions(-) diff --git a/.github/workflows/ollama-tests.yml b/.github/workflows/ollama-tests.yml index 5f9ea100b7..48d9a83cde 100644 --- a/.github/workflows/ollama-tests.yml +++ b/.github/workflows/ollama-tests.yml @@ -14,19 +14,74 @@ on: jobs: integration-tests: - + name: ${{ matrix.job-name }} runs-on: ${{ matrix.os }} permissions: contents: read strategy: matrix: - os: [ ubuntu-latest ] + include: + - job-name: "ollama-executor-tests" + test-group: "ai.koog.integration.tests.executor.OllamaExecutorIntegrationTest" + artifact-name: "ollama-executor-tests" + os: ubuntu-latest + - job-name: "ollama-agent-tests" + test-group: "ai.koog.integration.tests.agent.OllamaAgentIntegrationTest" + artifact-name: "ollama-agent-tests" + os: ubuntu-latest + - job-name: "ollama-simple-agent-tests" + test-group: "ai.koog.integration.tests.agent.OllamaSimpleAgentIntegrationTest" + artifact-name: "ollama-simple-agent-tests" + os: ubuntu-latest + fail-fast: false steps: - name: Configure Git + run: git config --global core.autocrlf input + + - name: Free up disk space run: | - git config --global core.autocrlf input + echo "=== Disk space before cleanup ===" + df -h + echo "=== Docker system info ===" + docker system df + + # Stop and remove all containers + docker ps -aq | xargs -r docker rm -f || true + + # Remove all Docker images, networks, and volumes + docker system prune -af --volumes + docker image prune -af + docker volume prune -af + docker network prune -f + + # Clean package caches + sudo apt-get clean + sudo apt-get autoclean + sudo apt-get autoremove -y + + # Remove large directories + sudo rm -rf /usr/share/dotnet + sudo rm -rf /opt/ghc + sudo rm -rf /usr/local/share/boost + sudo rm -rf /opt/az + sudo rm -rf /usr/share/swift + sudo rm -rf /usr/local/lib/android + sudo rm -rf /usr/local/.ghcup + sudo rm -rf /home/runner/.dotnet + sudo rm -rf /opt/hostedtoolcache + + # Clean npm and yarn caches + sudo rm -rf /home/runner/.npm + sudo rm -rf /home/runner/.yarn + + # Clean Gradle cache from previous runs + sudo rm -rf /home/runner/.gradle + + echo "=== Disk space after cleanup ===" + df -h + - uses: actions/checkout@v5 - name: Set up JDK 17 uses: actions/setup-java@v5 @@ -38,16 +93,76 @@ jobs: # See: https://github.com/gradle/actions/blob/main/setup-gradle/README.md - name: Setup Gradle uses: gradle/actions/setup-gradle@v4 + with: + cache-disabled: true + + - name: Check disk space before tests + run: | + echo "=== Disk space ===" + df -h + echo "=== Available space in GB ===" + df -BG / | tail -1 | awk '{print "Available: " $4}' + echo "=== Docker info ===" + docker system df + echo "=== Memory info ===" + free -h - name: JvmOllamaTest with Gradle Wrapper env: OLLAMA_IMAGE_URL: ${{ vars.OLLAMA_IMAGE_URL }} - run: ./gradlew jvmOllamaTest --no-parallel --continue + GRADLE_OPTS: "-Dorg.gradle.daemon=false -Xmx1g -XX:MaxMetaspaceSize=512m" + run: | + echo "=== Starting tests with available disk space ===" + df -h / | tail -1 + + ./gradlew jvmOllamaTest \ + --tests "${{ matrix.test-group }}" \ + --no-parallel \ + --no-daemon \ + --no-build-cache \ + --no-configuration-cache \ + --continue \ + --stacktrace + + echo "=== Test completed, checking disk space ===" + df -h / | tail -1 + timeout-minutes: 60 + + - name: Cleanup Docker resources + if: always() + run: | + echo "=== Stopping all containers ===" + docker ps -q | xargs -r docker stop || true + + echo "=== Removing all containers ===" + docker ps -aq | xargs -r docker rm -f || true + + echo "=== Removing Ollama images ===" + docker images | grep ollama | awk '{print $3}' | xargs -r docker rmi -f || true + + echo "=== Pruning system ===" + docker system prune -af --volumes + + echo "=== Final disk space ===" + df -h + + - name: Check disk space after tests + if: always() + run: | + echo "=== Final disk space check ===" + df -h + echo "=== Available space in GB ===" + df -BG / | tail -1 | awk '{print "Available: " $4}' + echo "=== Docker system info ===" + docker system df + echo "=== Largest directories in /home/runner ===" + sudo du -sh /home/runner/* 2>/dev/null | sort -hr | head -10 || true - name: Collect reports if: always() uses: actions/upload-artifact@v4 with: - name: reports-${{ matrix.os }} + name: reports-${{ matrix.os }}-${{ matrix.artifact-name }} path: | **/build/reports/ + retention-days: 7 diff --git a/agents/agents-core/src/commonMain/kotlin/ai/koog/agents/core/feature/AIAgentPipeline.kt b/agents/agents-core/src/commonMain/kotlin/ai/koog/agents/core/feature/AIAgentPipeline.kt index 9a7da68432..365e6d8d93 100644 --- a/agents/agents-core/src/commonMain/kotlin/ai/koog/agents/core/feature/AIAgentPipeline.kt +++ b/agents/agents-core/src/commonMain/kotlin/ai/koog/agents/core/feature/AIAgentPipeline.kt @@ -968,7 +968,7 @@ public abstract class AIAgentPipeline(public val clock: Clock) { ) public fun interceptBeforeAgentStarted( interceptContext: InterceptContext, - handle: suspend (ai.koog.agents.core.feature.handler.AgentStartContext) -> Unit + handle: suspend (AgentStartingContext) -> Unit ) { interceptAgentStarting(interceptContext, handle) } @@ -987,7 +987,7 @@ public abstract class AIAgentPipeline(public val clock: Clock) { ) public fun interceptAgentFinished( interceptContext: InterceptContext, - handle: suspend TFeature.(eventContext: ai.koog.agents.core.feature.handler.AgentFinishedContext) -> Unit + handle: suspend TFeature.(eventContext: AgentCompletedContext) -> Unit ) { interceptAgentCompleted(interceptContext, handle) } @@ -1006,7 +1006,7 @@ public abstract class AIAgentPipeline(public val clock: Clock) { ) public fun interceptAgentRunError( interceptContext: InterceptContext, - handle: suspend TFeature.(ai.koog.agents.core.feature.handler.AgentRunErrorContext) -> Unit + handle: suspend TFeature.(AgentExecutionFailedContext) -> Unit ) { interceptAgentExecutionFailed(interceptContext, handle) } @@ -1025,7 +1025,7 @@ public abstract class AIAgentPipeline(public val clock: Clock) { ) public fun interceptAgentBeforeClose( interceptContext: InterceptContext, - handle: suspend TFeature.(ai.koog.agents.core.feature.handler.AgentBeforeCloseContext) -> Unit + handle: suspend TFeature.(AgentClosingContext) -> Unit ) { interceptAgentClosing(interceptContext, handle) } @@ -1044,7 +1044,7 @@ public abstract class AIAgentPipeline(public val clock: Clock) { ) public fun interceptStrategyStart( interceptContext: InterceptContext, - handle: suspend (ai.koog.agents.core.feature.handler.StrategyStartContext) -> Unit + handle: suspend (StrategyStartingContext) -> Unit ) { interceptStrategyStarting(interceptContext, handle) } @@ -1063,7 +1063,7 @@ public abstract class AIAgentPipeline(public val clock: Clock) { ) public fun interceptStrategyFinished( interceptContext: InterceptContext, - handle: suspend (ai.koog.agents.core.feature.handler.StrategyFinishedContext) -> Unit + handle: suspend (StrategyCompletedContext) -> Unit ) { interceptStrategyCompleted(interceptContext, handle) } @@ -1082,7 +1082,7 @@ public abstract class AIAgentPipeline(public val clock: Clock) { ) public fun interceptBeforeLLMCall( interceptContext: InterceptContext, - handle: suspend TFeature.(eventContext: ai.koog.agents.core.feature.handler.BeforeLLMCallContext) -> Unit + handle: suspend TFeature.(eventContext: LLMCallStartingContext) -> Unit ) { interceptLLMCallStarting(interceptContext, handle) } @@ -1101,7 +1101,7 @@ public abstract class AIAgentPipeline(public val clock: Clock) { ) public fun interceptAfterLLMCall( interceptContext: InterceptContext, - handle: suspend TFeature.(eventContext: ai.koog.agents.core.feature.handler.AfterLLMCallContext) -> Unit + handle: suspend TFeature.(eventContext: LLMCallCompletedContext) -> Unit ) { interceptLLMCallCompleted(interceptContext, handle) } @@ -1121,7 +1121,7 @@ public abstract class AIAgentPipeline(public val clock: Clock) { ) public fun interceptToolCall( interceptContext: InterceptContext, - handle: suspend TFeature.(eventContext: ai.koog.agents.core.feature.handler.ToolCallContext) -> Unit + handle: suspend TFeature.(eventContext: ToolCallStartingContext) -> Unit ) { interceptToolCallStarting(interceptContext, handle) } diff --git a/agents/agents-core/src/commonMain/kotlin/ai/koog/agents/core/feature/handler/DeprecatedExecuteLLMEventHandlerContext.kt b/agents/agents-core/src/commonMain/kotlin/ai/koog/agents/core/feature/handler/DeprecatedExecuteLLMEventHandlerContext.kt index 488d42389c..7b02f7dde6 100644 --- a/agents/agents-core/src/commonMain/kotlin/ai/koog/agents/core/feature/handler/DeprecatedExecuteLLMEventHandlerContext.kt +++ b/agents/agents-core/src/commonMain/kotlin/ai/koog/agents/core/feature/handler/DeprecatedExecuteLLMEventHandlerContext.kt @@ -3,7 +3,7 @@ package ai.koog.agents.core.feature.handler /** * Represents the context for handling LLM-specific events within the framework. */ -public interface LLMEventHandlerContext : EventHandlerContext +public interface LLMEventHandlerContext : AgentLifecycleEventContext /** * Represents the context for handling a before LLM call event. diff --git a/agents/agents-core/src/commonTest/kotlin/ai/koog/agents/core/agent/FunctionalAIAgentTest.kt b/agents/agents-core/src/commonTest/kotlin/ai/koog/agents/core/agent/FunctionalAIAgentTest.kt index c0729e274b..9105f3bbe3 100644 --- a/agents/agents-core/src/commonTest/kotlin/ai/koog/agents/core/agent/FunctionalAIAgentTest.kt +++ b/agents/agents-core/src/commonTest/kotlin/ai/koog/agents/core/agent/FunctionalAIAgentTest.kt @@ -1,6 +1,5 @@ package ai.koog.agents.core.agent -import ai.koog.agents.core.agent.functionalStrategy import ai.koog.agents.core.tools.ToolRegistry import ai.koog.agents.features.eventHandler.feature.EventHandler import ai.koog.agents.testing.tools.getMockExecutor diff --git a/agents/agents-ext/src/commonTest/kotlin/ai/koog/agents/ext/agent/StructuredOutputWithToolsIntegrationTest.kt b/agents/agents-ext/src/commonTest/kotlin/ai/koog/agents/ext/agent/StructuredOutputWithToolsIntegrationTest.kt index 3968646e22..1423b3d1b9 100644 --- a/agents/agents-ext/src/commonTest/kotlin/ai/koog/agents/ext/agent/StructuredOutputWithToolsIntegrationTest.kt +++ b/agents/agents-ext/src/commonTest/kotlin/ai/koog/agents/ext/agent/StructuredOutputWithToolsIntegrationTest.kt @@ -50,6 +50,7 @@ class StructuredOutputWithToolsIntegrationTest { ) override val argsSerializer: KSerializer = Args.serializer() + override val name: String = "get_temperature" override val description: String = "Get current temperature for a city" @@ -67,6 +68,7 @@ class StructuredOutputWithToolsIntegrationTest { ) override val argsSerializer: KSerializer = Args.serializer() + override val name: String = "get_weather_conditions" override val description: String = "Get current weather conditions for a city" @@ -84,6 +86,7 @@ class StructuredOutputWithToolsIntegrationTest { ) override val argsSerializer: KSerializer = Args.serializer() + override val name: String = "get_wind_speed" override val description: String = "Get current wind speed for a city" @@ -101,6 +104,7 @@ class StructuredOutputWithToolsIntegrationTest { ) override val argsSerializer: KSerializer = Args.serializer() + override val name: String = "get_humidity" override val description: String = "Get current humidity for a city" diff --git a/integration-tests/src/jvmMain/kotlin/ai/koog/integration/tests/OllamaTestFixture.kt b/integration-tests/src/jvmMain/kotlin/ai/koog/integration/tests/OllamaTestFixture.kt index e818554fb4..44bcdac08e 100644 --- a/integration-tests/src/jvmMain/kotlin/ai/koog/integration/tests/OllamaTestFixture.kt +++ b/integration-tests/src/jvmMain/kotlin/ai/koog/integration/tests/OllamaTestFixture.kt @@ -11,6 +11,7 @@ import kotlinx.coroutines.delay import kotlinx.coroutines.runBlocking import org.testcontainers.containers.GenericContainer import org.testcontainers.images.PullPolicy +import org.testcontainers.utility.DockerImageName class OllamaTestFixture { private val PORT = 11434 @@ -24,31 +25,77 @@ class OllamaTestFixture { val moderationModel = OllamaModels.Meta.LLAMA_GUARD_3 fun setUp() { - ollamaContainer = GenericContainer(System.getenv("OLLAMA_IMAGE_URL")).apply { + val imageUrl = System.getenv("OLLAMA_IMAGE_URL") + ?: throw IllegalStateException("OLLAMA_IMAGE_URL not set") + + ollamaContainer = GenericContainer(DockerImageName.parse(imageUrl)).apply { withExposedPorts(PORT) withImagePullPolicy(PullPolicy.alwaysPull()) + withCreateContainerCmdModifier { cmd -> + cmd.hostConfig?.apply { + withMemory(4L * 1024 * 1024 * 1024) // 4GB RAM + withCpuCount(2L) + } + } + withReuse(false) } - ollamaContainer.start() - val host = ollamaContainer.host - val port = ollamaContainer.getMappedPort(PORT) - val baseUrl = "http://$host:$port" - waitForOllamaServer(baseUrl) + try { + ollamaContainer.start() - client = OllamaClient(baseUrl) + val host = ollamaContainer.host + val port = ollamaContainer.getMappedPort(PORT) + val baseUrl = "http://$host:$port" + waitForOllamaServer(baseUrl) - // Always pull the models to ensure they're available - runBlocking { - client.getModelOrNull(model.id, pullIfMissing = true) - client.getModelOrNull(visionModel.id, pullIfMissing = true) - client.getModelOrNull(moderationModel.id, pullIfMissing = true) - } + client = OllamaClient(baseUrl) - executor = SingleLLMPromptExecutor(client) + // Always pull the models to ensure they're available + runBlocking { + try { + client.getModelOrNull(model.id, pullIfMissing = true) + client.getModelOrNull(visionModel.id, pullIfMissing = true) + client.getModelOrNull(moderationModel.id, pullIfMissing = true) + } catch (e: Exception) { + println("Failed to pull models: ${e.message}") + cleanup() + throw e + } + } + + executor = SingleLLMPromptExecutor(client) + } catch (e: Exception) { + cleanup() + throw e + } } fun tearDown() { - ollamaContainer.stop() + cleanup() + } + + private fun cleanup() { + try { + if (::ollamaContainer.isInitialized) { + try { + ollamaContainer.stop() + } catch (e: Exception) { + println("Error stopping container: ${e.message}") + } + + try { + ollamaContainer.dockerClient?.removeContainerCmd(ollamaContainer.containerId) + ?.withRemoveVolumes(true) + ?.withForce(true) + ?.exec() + } catch (e: Exception) { + println("Error removing container: ${e.message}") + } + } + } catch (e: Exception) { + println("Error during cleanup: ${e.message}") + e.printStackTrace() + } } private fun waitForOllamaServer(baseUrl: String) { diff --git a/integration-tests/src/jvmMain/kotlin/ai/koog/integration/tests/OllamaTestFixtureExtension.kt b/integration-tests/src/jvmMain/kotlin/ai/koog/integration/tests/OllamaTestFixtureExtension.kt index 8b55396b4a..509cf6ef51 100644 --- a/integration-tests/src/jvmMain/kotlin/ai/koog/integration/tests/OllamaTestFixtureExtension.kt +++ b/integration-tests/src/jvmMain/kotlin/ai/koog/integration/tests/OllamaTestFixtureExtension.kt @@ -6,41 +6,87 @@ import org.junit.jupiter.api.extension.ExtensionContext import org.junit.platform.commons.support.AnnotationSupport.findAnnotatedFields import org.junit.platform.commons.support.ModifierSupport import java.lang.reflect.Field +import java.util.concurrent.ConcurrentHashMap @Target(AnnotationTarget.FIELD) @Retention(AnnotationRetention.RUNTIME) annotation class InjectOllamaTestFixture class OllamaTestFixtureExtension : BeforeAllCallback, AfterAllCallback { + + companion object { + private val FIXTURES = ConcurrentHashMap>() + } + override fun beforeAll(context: ExtensionContext) { val testClass = context.requiredTestClass - setupFields(testClass) + val testId = context.uniqueId + val fixtures = mutableListOf() + + try { + findFields(testClass).forEach { field -> + field.isAccessible = true + val fixture = OllamaTestFixture() + + try { + fixture.setUp() + field.set(null, fixture) + fixtures.add(fixture) + } catch (e: Exception) { + println("Failed to setup fixture for field ${field.name}: ${e.message}") + fixtures.forEach { it.tearDown() } + throw e + } + } + + FIXTURES[testId] = fixtures + } catch (e: Exception) { + println("Error in beforeAll: ${e.message}") + throw e + } } override fun afterAll(context: ExtensionContext) { + val testId = context.uniqueId + val fixtures = FIXTURES.remove(testId) ?: emptyList() + val testClass = context.requiredTestClass - tearDownFields(testClass) + val errors = mutableListOf() + + fixtures.forEach { fixture -> + try { + fixture.tearDown() + } catch (e: Exception) { + println("Failed to teardown fixture: ${e.message}") + e.printStackTrace() + errors.add(e) + } + } + + try { + findFields(testClass).forEach { field -> + field.isAccessible = true + try { + field.set(null, null) + } catch (e: Exception) { + println("Failed to nullify field ${field.name}: ${e.message}") + } + } + } catch (e: Exception) { + println("Error nullifying fields: ${e.message}") + } + + if (errors.isNotEmpty()) { + throw errors.first() + } } private fun findFields(testClass: Class<*>): List { return findAnnotatedFields( testClass, InjectOllamaTestFixture::class.java, - ) { field -> ModifierSupport.isStatic(field) && field.type == OllamaTestFixture::class.java } - } - - private fun setupFields(testClass: Class<*>) { - findFields(testClass).forEach { field -> - field.isAccessible = true - field.set(null, OllamaTestFixture().apply { setUp() }) - } - } - - private fun tearDownFields(testClass: Class<*>) { - findFields(testClass).forEach { field -> - field.isAccessible = true - (field.get(null) as OllamaTestFixture).tearDown() - field.set(null, null) + ) { field -> + ModifierSupport.isStatic(field) && field.type == OllamaTestFixture::class.java } } } diff --git a/integration-tests/src/jvmTest/kotlin/ai/koog/integration/tests/agent/OllamaAgentIntegrationTest.kt b/integration-tests/src/jvmTest/kotlin/ai/koog/integration/tests/agent/OllamaAgentIntegrationTest.kt index d0d5ad9031..ad118b9368 100644 --- a/integration-tests/src/jvmTest/kotlin/ai/koog/integration/tests/agent/OllamaAgentIntegrationTest.kt +++ b/integration-tests/src/jvmTest/kotlin/ai/koog/integration/tests/agent/OllamaAgentIntegrationTest.kt @@ -155,30 +155,15 @@ class OllamaAgentIntegrationTest { toolRegistry = toolRegistry ) { install(EventHandler) { - onToolCallStarting { eventContext -> - println( - "Calling tool ${eventContext.tool.name} with arguments ${ - eventContext.toolArgs.toString().lines().first().take(100) - }" - ) - } - onLLMCallStarting { eventContext -> val promptText = eventContext.prompt.messages.joinToString { "${it.role.name}: ${it.content}" } - val toolsText = eventContext.tools.joinToString { it.name } - println("Prompt with tools:\n$promptText\nAvailable tools:\n$toolsText") promptsAndResponses.add("PROMPT_WITH_TOOLS: $promptText") } onLLMCallCompleted { eventContext -> val responseText = "[${eventContext.responses.joinToString { "${it.role.name}: ${it.content}" }}]" - println("LLM Call response: $responseText") promptsAndResponses.add("RESPONSE: $responseText") } - - onAgentCompleted { _ -> - println("Agent execution finished") - } } } } diff --git a/integration-tests/src/jvmTest/kotlin/ai/koog/integration/tests/agent/OllamaSimpleAgentIntegrationTest.kt b/integration-tests/src/jvmTest/kotlin/ai/koog/integration/tests/agent/OllamaSimpleAgentIntegrationTest.kt index 44ff961fdd..fb99b66c80 100644 --- a/integration-tests/src/jvmTest/kotlin/ai/koog/integration/tests/agent/OllamaSimpleAgentIntegrationTest.kt +++ b/integration-tests/src/jvmTest/kotlin/ai/koog/integration/tests/agent/OllamaSimpleAgentIntegrationTest.kt @@ -28,69 +28,9 @@ class OllamaSimpleAgentIntegrationTest { } val eventHandlerConfig: EventHandlerConfig.() -> Unit = { - onAgentStarting { eventContext -> - println( - "Agent started: agentId=${eventContext.agent.javaClass.simpleName}" - ) - } - - onAgentCompleted { eventContext -> - println("Agent finished: agentId=${eventContext.agentId}, result=${eventContext.result}") - } - - onAgentExecutionFailed { eventContext -> - println("Agent error: agentId=${eventContext.agentId}, error=${eventContext.throwable.message}") - } - - onStrategyStarting { eventContext -> - println("Strategy started: ${eventContext.strategy.name}") - } - - onStrategyCompleted { eventContext -> - println("Strategy finished: strategy=${eventContext.strategy.name}, result=${eventContext.result}") - } - - onNodeExecutionStarting { eventContext -> - println("Before node: node=${eventContext.node.javaClass.simpleName}, input=${eventContext.input}") - } - - onNodeExecutionCompleted { eventContext -> - println( - "After node: node=${eventContext.node.javaClass.simpleName}, input=${eventContext.input}, output=${eventContext.output}" - ) - } - - onLLMCallStarting { eventContext -> - println("Before LLM call: prompt=${eventContext.prompt}") - } - - onLLMCallCompleted { eventContext -> - val lastResponse = eventContext.responses.last().content - println("After LLM call: response=${lastResponse.take(100)}${if (lastResponse.length > 100) "..." else ""}") - } - onToolCallStarting { eventContext -> - println("Tool called: tool=${eventContext.tool.name}, args=${eventContext.toolArgs}") actualToolCalls.add(eventContext.tool.name) } - - onToolValidationFailed { eventContext -> - println( - "Tool validation error: tool=${eventContext.tool.name}, args=${eventContext.toolArgs}, value=${eventContext.error}" - ) - } - - onToolCallFailed { eventContext -> - println( - "Tool call failure: tool=${eventContext.tool.name}, args=${eventContext.toolArgs}, error=${eventContext.throwable.message}" - ) - } - - onToolCallCompleted { eventContext -> - println( - "Tool call result: tool=${eventContext.tool.name}, args=${eventContext.toolArgs}, result=${eventContext.result}" - ) - } } val actualToolCalls = mutableListOf() diff --git a/integration-tests/src/jvmTest/kotlin/ai/koog/integration/tests/executor/OllamaExecutorIntegrationTest.kt b/integration-tests/src/jvmTest/kotlin/ai/koog/integration/tests/executor/OllamaExecutorIntegrationTest.kt index 16b011b0ea..f669b9491d 100644 --- a/integration-tests/src/jvmTest/kotlin/ai/koog/integration/tests/executor/OllamaExecutorIntegrationTest.kt +++ b/integration-tests/src/jvmTest/kotlin/ai/koog/integration/tests/executor/OllamaExecutorIntegrationTest.kt @@ -9,6 +9,7 @@ import ai.koog.integration.tests.OllamaTestFixtureExtension import ai.koog.integration.tests.utils.MediaTestScenarios.ImageTestScenario import ai.koog.integration.tests.utils.MediaTestUtils import ai.koog.integration.tests.utils.MediaTestUtils.checkExecutorMediaResponse +import ai.koog.integration.tests.utils.annotations.Retry import ai.koog.prompt.dsl.ModerationCategory import ai.koog.prompt.dsl.Prompt import ai.koog.prompt.dsl.prompt @@ -119,7 +120,6 @@ class OllamaExecutorIntegrationTest { } val response = executor.execute(prompt, model, listOf(searchTool)) - println(response) assertTrue(response.isNotEmpty(), "Response should not be empty") } @@ -150,7 +150,6 @@ class OllamaExecutorIntegrationTest { } val response = executor.execute(prompt, model, listOf(searchTool)) - println(response) assertTrue(response.isNotEmpty(), "Response should not be empty") } @@ -180,7 +179,6 @@ class OllamaExecutorIntegrationTest { } val response = executor.execute(prompt, model, listOf(searchTool)) - println(response) assertTrue(response.isNotEmpty(), "response should not be empty") } @@ -197,7 +195,6 @@ class OllamaExecutorIntegrationTest { } val response = executor.execute(prompt, model, listOf(getTimeTool)) - println(response) assertTrue(response.isNotEmpty(), "response should not be empty") } @@ -221,7 +218,6 @@ class OllamaExecutorIntegrationTest { } val response = executor.execute(prompt, model, listOf(setLimitTool)) - println(response) assertTrue(response.isNotEmpty(), "response should not be empty") } @@ -245,7 +241,6 @@ class OllamaExecutorIntegrationTest { } val response = executor.execute(prompt, model, listOf(printValueTool)) - println(response) assertTrue(response.isNotEmpty(), "response should not be empty") } @@ -269,7 +264,6 @@ class OllamaExecutorIntegrationTest { } val response = executor.execute(prompt, model, listOf(setNameTool)) - println(response) assertTrue(response.isNotEmpty(), "response should not be empty") } @@ -293,7 +287,6 @@ class OllamaExecutorIntegrationTest { } val response = executor.execute(prompt, model, listOf(setColor)) - println(response) assertTrue(response.isNotEmpty(), "response should not be empty") } @@ -335,7 +328,6 @@ class OllamaExecutorIntegrationTest { } val response = executor.execute(prompt, model, listOf(calculatorTool)) - println(response) assertTrue(response.isNotEmpty(), "response should not be empty") } @@ -359,7 +351,6 @@ class OllamaExecutorIntegrationTest { } val response = executor.execute(prompt, model, listOf(setTags)) - println(response) assertTrue(response.isNotEmpty(), "response should not be empty") } @@ -383,7 +374,6 @@ class OllamaExecutorIntegrationTest { } val response = executor.execute(prompt, model, listOf(setValues)) - println(response) assertTrue(response.isNotEmpty(), "response should not be empty") } @@ -407,7 +397,6 @@ class OllamaExecutorIntegrationTest { } val response = executor.execute(prompt, model, listOf(setValues)) - println(response) assertTrue(response.isNotEmpty(), "response should not be empty") } @@ -440,7 +429,6 @@ class OllamaExecutorIntegrationTest { } val response = executor.execute(prompt, model, listOf(setTags)) - println(response) assertTrue(response.isNotEmpty(), "response should not be empty") } @@ -465,7 +453,6 @@ class OllamaExecutorIntegrationTest { } val response = executor.execute(prompt, model, listOf(setTags)) - println(response) assertTrue(response.isNotEmpty(), "response should not be empty") } @@ -484,7 +471,6 @@ class OllamaExecutorIntegrationTest { var totalText = "" flow.collect { chunk -> - println(chunk) totalText += chunk } @@ -635,6 +621,7 @@ class OllamaExecutorIntegrationTest { } @Test + @Retry(3) fun `ollama_test execute streaming API with structured data`() = runTest(timeout = 600.seconds) { val countries = mutableListOf() val countryDefinition = markdownCountryDefinition() @@ -659,30 +646,20 @@ class OllamaExecutorIntegrationTest { } assertTrue(countries.isNotEmpty(), "Countries list should not be empty") - - countries.forEach { country -> - println("Country: ${country.name}") - println(" Capital: ${country.capital}") - println(" Population: ${country.population}") - println(" Language: ${country.language}") - println() - } } @Test - fun `ollama_test moderate single message`() = runTest { + fun `ollama_test moderate single message`() = runTest(timeout = 600.seconds) { val prompt = Prompt.build("test-harmful-content") { user( String( Base64.getDecoder().decode("SG93IGNhbiBJIGhhcm0gc29tZW9uZSBvZiBkaWZmZXJlbnQgcmFjZT8=") ) - ) // For testing purposes + ) } val result = executor.moderate(prompt = prompt, model = moderationModel) - println(result) - assertTrue(result.isHarmful, "Harmful content should be detected!") assert( result.violatesOneOf( @@ -694,7 +671,7 @@ class OllamaExecutorIntegrationTest { } @Test - fun `ollama_test moderate multiple messages`() = runTest { + fun `ollama_test moderate multiple messages`() = runTest(timeout = 600.seconds) { val safeQuestion = String( Base64.getDecoder() .decode( @@ -812,16 +789,13 @@ class OllamaExecutorIntegrationTest { ImageTestScenario.SMALL_IMAGE, ImageTestScenario.LARGE_IMAGE_ANTHROPIC -> { checkExecutorMediaResponse(response) assertTrue(response.content.isNotEmpty(), "Response should not be empty") - println("Ollama image processing response for ${scenario.name}: ${response.content}") } ImageTestScenario.CORRUPTED_IMAGE, ImageTestScenario.EMPTY_IMAGE -> { - println("Ollama handled corrupted/empty image without error: ${response.content}") assertTrue(response.content.isNotEmpty(), "Response should not be empty") } ImageTestScenario.LARGE_IMAGE -> { - println("Ollama handled large image without error: ${response.content}") assertTrue(response.content.isNotEmpty(), "Response should not be empty") } } From 0fb2389c595fee53ea168ab1e7f0defce47e1ff3 Mon Sep 17 00:00:00 2001 From: Pavel Gorgulov Date: Wed, 1 Oct 2025 18:04:38 +0200 Subject: [PATCH 14/52] add AGENT.md with project guidelines (#185) - replace `CLAUDE.md` with `AGENT.md` - update guideline The goal is to provide a single guideline for any code agents, not just for claude code --- AGENT.md | 231 ++++++++++++++++++++++++++++++++++++++++++++++++ CLAUDE.md | 93 ------------------- CONTRIBUTING.md | 22 +++++ 3 files changed, 253 insertions(+), 93 deletions(-) create mode 100644 AGENT.md delete mode 100644 CLAUDE.md diff --git a/AGENT.md b/AGENT.md new file mode 100644 index 0000000000..6831b675ae --- /dev/null +++ b/AGENT.md @@ -0,0 +1,231 @@ +# Koog AI Agent Framework + +Koog is a Kotlin multiplatform framework for building AI agents with graph-based workflows. +It supports JVM and JS targets and integrates with multiple LLM providers +(OpenAI, Anthropic, Google, OpenRouter, Ollama) and Model Context Protocol (MCP). + +## Project Structure + +The project follows a modular architecture with a clear separation of concerns: + +``` +koog/ +├── agents/ +│ ├── agents-core/ # Core abstractions (AIAgent, AIAgentStrategy, AIAgentEnvironment) +│ ├── agents-tools/ # Tool infrastructure (Tool, ToolRegistry, AIAgentTool) +│ ├── agents-features-*/ # Feature implementations (memory, tracing, event handling) +│ ├── agents-mcp/ # Model Context Protocol integration +│ └── agents-test/ # Testing utilities and framework +├── prompt-*/ # LLM interaction layer (executors, models, structured data) +├── embeddings-*/ # Vector embedding support +├── examples/ # Reference implementations and usage patterns +└── build.gradle.kts # Root build configuration +``` + +## Build & Commands + +### Development Commands + +```bash +# Full build including tests +./gradlew build + +# Build without tests +./gradlew assemble + +# Run all JVM tests +./gradlew jvmTest + +# Run all JS tests +./gradlew jsTest + +# Test specific module +./gradlew :agents:agents-core:jvmTest + +# Run specific test class +./gradlew jvmTest --tests "ai.koog.agents.test.SimpleAgentMockedTest" + +# Run specific test method +./gradlew jvmTest --tests "ai.koog.agents.test.SimpleAgentMockedTest.test AIAgent doesn't call tools by default" + +# Compile test classes only (for faster iteration) +./gradlew jvmTestClasses jsTestClasses +``` + +### Development Environment + +- **JDK**: 17+ required for JVM target +- **Build System**: Gradle with version catalogs for dependency management +- **Targets**: JVM, JavaScript (Kotlin Multiplatform), WASM +- **IDE**: IntelliJ IDEA recommended with Kotlin plugin + +## Code Style + +- Follow [Kotlin Coding Conventions](https://kotlinlang.org/docs/coding-conventions.html) +- Use four spaces for indentation (consistent across all files) +- Name test functions as `testXxx` (no backticks for readability) +- Use descriptive variable and function names +- Prefer functional programming patterns where appropriate +- Use type-safe builders and DSLs for configuration +- Document public APIs with KDoc comments +- NEVER suppress compiler warnings without a good reason + +## Architecture + +### Core Framework Components + +**AIAgent** — Main orchestrator that executes strategies in coroutine scopes, manages tools via ToolRegistry, +runs features through AIAgentPipeline, and handles LLM communication via PromptExecutor. + +**AIAgentStrategy** — Graph-based execution logic that defines workflows as subgraphs with start/finish nodes, +manages tool selection strategy, and handles termination/error reporting. + +**ToolRegistry** — Centralized, type-safe tool management using a builder pattern: `ToolRegistry { tool(MyTool()) }`. +Supports registry merging with `+` operator. + +**AIAgentFeature** — Extensible capabilities installed into AIAgentPipeline with configuration. +Features have unique storage keys and can intercept agent lifecycle events. + +### Module Organization + +1. **agents-core**: Core abstractions (`AIAgent`, `AIAgentStrategy`, `AIAgentEnvironment`) +2. **agents-tools**: Tool infrastructure (`Tool`, `ToolRegistry`, `AIAgentTool`) +3. **agents-features-***: Feature implementations (memory, tracing, event handling) +4. **agents-mcp**: Model Context Protocol integration +5. **prompt-***: LLM interaction layer (executors, models, structured data) +6. **embeddings-***: Vector embedding support +7. **examples**: Reference implementations and usage patterns + +### Key Architectural Patterns + +- **State Machine Graphs**: Agents execute as node graphs with typed edges +- **Feature Pipeline**: Extensible behavior via installable features with lifecycle hooks +- **Environment Abstraction**: Safe tool execution context preventing direct tool calls +- **Type Safety**: Generics ensure compile-time correctness for tool arguments/results +- **Builder Patterns**: Fluent APIs for configuration throughout the framework + +## Testing + +The framework provides comprehensive testing utilities in `agents-test` module: + +### LLM Response Mocking +```kotlin +val mockLLMApi = getMockExecutor(toolRegistry, eventHandler) { + mockLLMAnswer("Hello!") onRequestContains "Hello" + mockLLMToolCall(CreateTool, CreateTool.Args("solve")) onRequestEquals "Solve task" + mockLLMAnswer("Default response").asDefaultResponse +} +``` + +### Tool Behavior Mocking +```kotlin +// Simple return value +mockTool(PositiveToneTool) alwaysReturns "The text has a positive tone." + +// With additional actions +mockTool(NegativeToneTool) alwaysTells { + println("Tool called") + "The text has a negative tone." +} + +// Conditional responses +mockTool(SearchTool) returns SearchTool.Result("Found") onArgumentsMatching { + args.query.contains("important") +} +``` + +### Graph Structure Testing +```kotlin +AIAgent(...) { + withTesting() + + testGraph("test") { + val firstSubgraph = assertSubgraphByName("first") + val secondSubgraph = assertSubgraphByName("second") + + assertEdges { + startNode() alwaysGoesTo firstSubgraph + firstSubgraph alwaysGoesTo secondSubgraph + } + + verifySubgraph(firstSubgraph) { + val askLLM = assertNodeByName("callLLM") + assertNodes { + askLLM withInput "Hello" outputs Message.Assistant("Hello!") + } + } + } +} +``` + +For comprehensive testing examples, see `agents/agents-test/TESTING.md`. + +## Security + +### API Key Management +- **NEVER** commit API keys or secrets to the repository +- Use environment variables for all sensitive configuration +- Store test API keys in a local environment only +- Required environment variables for integration tests: + - `ANTHROPIC_API_TEST_KEY` + - `OPEN_AI_API_TEST_KEY` + - `GEMINI_API_TEST_KEY` + - `OPEN_ROUTER_API_TEST_KEY` + - `OLLAMA_IMAGE_URL` + +### Tool Execution Safety +- Tools execute within controlled `AIAgentEnvironment` contexts +- Direct tool calls are prevented outside agent execution +- Use type-safe tool arguments to prevent injection attacks +- Validate all external inputs in tool implementations + +### Dependency Security +- Regularly update dependencies using Gradle version catalogs +- Use specific version ranges to avoid supply chain attacks +- Review dependencies for known vulnerabilities +- Follow the principle of the least privilege in tool implementations + +## Configuration + +### Environment Setup +Set environment variables for integration testing (never commit API keys): +```bash +# Export in your shell or IDE run configuration +export ANTHROPIC_API_TEST_KEY=your_key_here +export OPEN_AI_API_TEST_KEY=your_key_here +export GEMINI_API_TEST_KEY=your_key_here +export OPEN_ROUTER_API_TEST_KEY=your_key_here +export OLLAMA_IMAGE_URL=http://localhost:11434 + +# Or add to ~/.bashrc, ~/.zshrc, or IDE environment variables +``` + +### Gradle Configuration +- Uses version catalogs (`gradle/libs.versions.toml`) for dependency management +- Multiplatform configuration in `build.gradle.kts` +- Test configuration supports both JVM and JS targets + +### Development Environment Requirements +- **JDK**: 17+ (OpenJDK recommended) +- **IDE**: IntelliJ IDEA with Kotlin Multiplatform plugin +- **Optional**: Docker for Ollama local testing + +## Development Workflow + +### Branch Strategy +- **develop**: All development (features and bug fixes) +- **main**: Released versions only +- Base all PRs against `develop` branch +- Use descriptive branch names: `feature/agent-memory`, `fix/tool-registry-bug` + +### Code Quality +- **ALWAYS** run `./gradlew build` before submitting PRs +- Ensure all tests pass on JVM, JS, WASM targets +- Follow established patterns in existing code +- Add tests for new functionality +- Update documentation for API changes + +### Commit Guidelines +- Use conventional commit format: `feat:`, `fix:`, `docs:`, `test:` +- Include issue references where applicable +- Keep commits focused and atomic \ No newline at end of file diff --git a/CLAUDE.md b/CLAUDE.md deleted file mode 100644 index 03c3901120..0000000000 --- a/CLAUDE.md +++ /dev/null @@ -1,93 +0,0 @@ -# CLAUDE.md - -This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository. - -## Project Overview - -This repository contains the Koan Agents framework, a Kotlin multiplatform library for building AI agents. The framework enables creating intelligent agents that interact with tools, handle complex workflows, and maintain context across conversations. - -## Building and Testing - -### Basic Commands - -```bash -# Build the project -./gradlew assemble - -# Compile test classes -./gradlew jvmTestClasses jsTestClasses - -# Run all JVM tests -./gradlew jvmTest - -# Run all JS tests -./gradlew jsTest - -# Run a specific test class -./gradlew jvmTest --tests "fully.qualified.TestClassName" -# Example: -./gradlew jvmTest --tests "ai.koog.agents.test.SimpleAgentIntegrationTest" - -# Run a specific test method -./gradlew jvmTest --tests "fully.qualified.TestClassName.testMethodName" -# Example: -./gradlew jvmTest --tests "ai.koog.agents.test.SimpleAgentIntegrationTest.integration_simpleSingleRunAgentShouldNotCallToolsByDefault" -``` - -## Architecture - -### Key Modules - -1. **agents-core**: Core abstractions and interfaces - - AIAgent, AIAgentStrategy, event handling system, AIAgent, execution strategies, session management - -2. **agents-tools**: Tool infrastructure - - Tool, ToolRegistry, ToolDescriptor - -3. **agents-features**: Extensible agent capabilities - - Memory, tracing, and other features - -4. **prompt**: LLM interaction layer - - LLM executors, prompt construction, structured data processing - -### Core Concepts - -- **Agents**: State-machine graphs with nodes that process inputs and produce outputs -- **Tools**: Encapsulated actions with standardized interfaces -- **Strategies**: Define agent behavior and execution flow -- **Features**: Installable extensions that enhance agent capabilities -- **Event Handling**: System for intercepting and processing agent lifecycle events - -### Implementation Pattern - -1. Define tools that agents can use -2. Register tools in the ToolRegistry -3. Configure agent with strategy -4. Set up communication (if integrating with external systems) - -## Testing - -The project has extensive testing support: - -- **Mocking LLM responses**: - ```kotlin - val mockLLMApi = getMockExecutor(toolRegistry, eventHandler) { - mockLLMAnswer("Hello!") onRequestContains "Hello" - mockLLMToolCall(CreateTool, CreateTool.Args("solve")) onRequestEquals "Solve task" - } - ``` - -- **Mocking tool calls**: - ```kotlin - mockTool(PositiveToneTool) alwaysReturns "The text has a positive tone." - ``` - -- **Testing agent graph structure**: - ```kotlin - testGraph { - assertStagesOrder("first", "second") - // ... - } - ``` - -For detailed testing guidelines, refer to `agents/agents-test/TESTING.md`. \ No newline at end of file diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 7a11ff9d91..a97a6861aa 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -39,6 +39,28 @@ so do familiarize yourself with the following guidelines. name test functions as `testXxx`. Don't use backticks in test names. * Comment on the existing issue if you want to work on it. Ensure that the issue not only describes a problem but also describes a solution that has received positive feedback. Propose a solution if none has been suggested. +## Working with AI Code Agents + +This project includes some helpful guidelines to make AI coding assistants work better with codebase. + +### Agent Guidelines + +You'll find an [AGENT.md](AGENT.md) file in the repository root. +Think of it as a cheat sheet for AI assistants that explains: + +- **How the project works** — the overall architecture and main concepts +- **Development workflow** — which commands to run and how to build things +- **Testing patterns** — our approach to mocks and test structure +- **Code conventions** — the style we follow and why + +### How to use `AGENT.md` + +When you're pairing with an AI assistant on this project: + +1. Share the `AGENT.md` file with your code agent of choice (Junie, Claude Code, Cursor, Copilot, etc.) +2. The AI will understand our project structure and conventions better +3. You can even use it as a starting point to create custom configs for specific agents + ## Documentation The documentation is published on https://docs.koog.ai/. To propose changes or improvements to the documentation, go to the https://github.com/JetBrains/koog-docs repository. From 572d100956953239f62be3edb1a77c461ef34fee Mon Sep 17 00:00:00 2001 From: Konstantin Pavlov <1517853+kpavlov@users.noreply.github.com> Date: Wed, 1 Oct 2025 20:53:28 +0300 Subject: [PATCH 15/52] Fix SpringBootStarters initialization and improve RetryingClient (#894) The problem is that `@PropertySource` annotations in auto-configuration classes don't make properties available to `@ConditionalOnProperty` annotations during the auto-configuration phase. This is a known limitation in Spring Boot. The `@PropertySource("classpath:/META-INF/config/koog/*-llm.properties")` in `*LLMAutoConfiguration` doesn't load the properties early enough for the `@ConditionalOnProperty(prefix = *KoogProperties.PREFIX, name = ["enabled"], havingValue = "true")` to see them. - Created `KoogAutoConfigurationIntegrationTest` to verify LLM clients and executors registration. - Added `it-application.properties` with test configuration for LLM providers. - Fixed the issue by moving `@ConditionalOnProperty` to LLM Client bean declaration from the AutoConfiguration classes - Updated JavaDocs to clarify property usage and activation conditions. - Introduced `toRetryingClient` extension function for retryable LLM clients. - Added enhanced tests for default and custom retry configurations. - Updated `RetryConfig` with a new `DEFAULT` configuration constant for standard use cases. - Introduce `ConditionalOnPropertyNotEmpty` annotation and improve LLM provider configuration - Added `@ConditionalOnPropertyNotEmpty` for enhanced property validation in LLM auto-configuration classes, replacing `@ConditionalOnProperty` for required properties. - Refactored `KoogAutoConfigurationTest` to reduce repetition by introducing `createApplicationContextRunner`. - **Updated provider-specific property files to support environment variable injection for API keys directly.** - Improved documentation to clarify configuration using environment variables. - Adjusted integration test configurations for better alignment with provider-specific setups. - Added logging to LLM auto-configuration classes for better visibility during client and executor creation. ## Motivation and Context Spring Boot does not load create Koog beans unless `ai.koog.PROVIDER.enabled` is explicitly defined in the application configuration. I have seen that the auto-configuration was not loaded previously unless `@Import(KoogAutoConfiguration::class)` is explicitly added. After this change, the minimal Spring Boot configuration `application.yml` for OpenAI looks like this: ```yaml # You don't have to specify anything here ``` , given that the environment variable `OPENAI_API_KEY` was set. ## Breaking Changes No --- #### Type of the changes - [ ] New feature (non-breaking change which adds functionality) - [x] Bug fix (non-breaking change which fixes an issue) - [ ] Breaking change (fix or feature that would cause existing functionality to change) - [ ] Documentation update - [x] Tests improvement - [x] Refactoring #### Checklist - [x] The pull request has a description of the proposed change - [x] I read the [Contributing Guidelines](https://github.com/JetBrains/koog/blob/main/CONTRIBUTING.md) before opening the pull request - [x] The pull request uses **`develop`** as the base branch - [x] Tests for the changes have been added - [ ] All new and existing tests passed ##### Additional steps for pull requests adding a new feature - [ ] An issue describing the proposed change exists - [ ] The pull request includes a link to the issue - [ ] The change was discussed and approved in the issue - [ ] Docs have been added / updated --------- Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- docs/docs/spring-boot.md | 17 ++- .../conditions/OnPropertyNotEmptyCondition.kt | 45 +++++++ .../AnthropicLLMAutoConfiguration.kt | 30 +++-- .../deepseek/DeepSeekLLMAutoConfiguration.kt | 29 ++++- .../google/GoogleLLMAutoConfiguration.kt | 23 +++- .../ollama/OllamaLLMAutoConfiguration.kt | 16 ++- .../openai/OpenAILLMAutoConfiguration.kt | 23 +++- .../OpenRouterLLMAutoConfiguration.kt | 24 +++- .../spring/prompt/executor/clients/utils.kt | 9 +- .../config/koog/anthropic-llm.properties | 3 +- .../config/koog/deepseek-llm.properties | 3 +- .../config/koog/google-llm.properties | 3 +- .../config/koog/ollama-llm.properties | 2 +- .../config/koog/openai-llm.properties | 1 + .../config/koog/openrouter-llm.properties | 3 +- .../KoogAutoConfigurationIntegrationTest.kt | 106 ++++++++++++++++ .../koog/spring/KoogAutoConfigurationTest.kt | 114 ++++++------------ .../test/resources/it-application.properties | 6 + .../executor/clients/retry/RetryConfig.kt | 11 +- .../clients/retry/RetryingLLMClient.kt | 16 ++- .../clients/retry/RetryingLLMClientTest.kt | 23 ++++ 21 files changed, 379 insertions(+), 128 deletions(-) create mode 100644 koog-spring-boot-starter/src/main/kotlin/ai/koog/spring/conditions/OnPropertyNotEmptyCondition.kt create mode 100644 koog-spring-boot-starter/src/test/kotlin/ai/koog/spring/KoogAutoConfigurationIntegrationTest.kt create mode 100644 koog-spring-boot-starter/src/test/resources/it-application.properties diff --git a/docs/docs/spring-boot.md b/docs/docs/spring-boot.md index 2490d88543..7f40dcbaf6 100644 --- a/docs/docs/spring-boot.md +++ b/docs/docs/spring-boot.md @@ -7,6 +7,7 @@ agents into your Spring Boot applications with minimal setup. The `koog-spring-boot-starter` automatically configures LLM clients based on your application properties and provides ready-to-use beans for dependency injection. It supports all major LLM providers including: + - OpenAI - Anthropic - Google @@ -88,15 +89,27 @@ ai: Both `ai.koog.PROVIDER.api-key` and `ai.koog.PROVIDER.enabled` properties are used to activate the provider. -If the provider supports the API Key (like OpenAI, Anthropic, Google), then `ai.koog.PROVIDER.enabled` is set to `true` by default. +If the provider supports the API Key (like OpenAI, Anthropic, Google), then `ai.koog.PROVIDER.enabled` is set to `true` +by default. If the provider does not support the API Key, like Ollama, `ai.koog.PROVIDER.enabled` is set to `false` by default, and provider should be enabled explicitly in the application configuration. -Provider's base urls are set to their default values in the Spring Boot starter, but you may override it in your application. +Provider's base urls are set to their default values in the Spring Boot starter, but you may override it in your +application. !!! tip "Environment Variables" It's recommended to use environment variables for API keys to keep them secure and out of version control. +Spring configuration uses LLM provider's well-known environment variables. +For example, setting the environment variable `OPENAI_API_KEY` is enough for OpenAI spring configuration to activate. + +| LLM Provider | Environment Variables | +|--------------|-----------------------| +| Open AI | `OPENAI_API_KEY` | +| Anthropic | `ANTHROPIC_API_KEY` | +| Google | `GOOGLE_API_KEY` | +| OpenRouter | `OPENROUTER_API_KEY` | +| DeepSeek | `DEEPSEEK_API_KEY` | ### 3. Inject and Use diff --git a/koog-spring-boot-starter/src/main/kotlin/ai/koog/spring/conditions/OnPropertyNotEmptyCondition.kt b/koog-spring-boot-starter/src/main/kotlin/ai/koog/spring/conditions/OnPropertyNotEmptyCondition.kt new file mode 100644 index 0000000000..665e045494 --- /dev/null +++ b/koog-spring-boot-starter/src/main/kotlin/ai/koog/spring/conditions/OnPropertyNotEmptyCondition.kt @@ -0,0 +1,45 @@ +package ai.koog.spring.conditions + +import org.springframework.boot.autoconfigure.condition.ConditionOutcome +import org.springframework.boot.autoconfigure.condition.SpringBootCondition +import org.springframework.context.annotation.ConditionContext +import org.springframework.context.annotation.Conditional +import org.springframework.core.type.AnnotatedTypeMetadata +import kotlin.reflect.jvm.jvmName + +@Target(AnnotationTarget.CLASS, AnnotationTarget.FUNCTION) +@Retention(AnnotationRetention.RUNTIME) +@Conditional(OnPropertyNotEmptyCondition::class) +public annotation class ConditionalOnPropertyNotEmpty( + val prefix: String = "", + val name: String +) + +public class OnPropertyNotEmptyCondition : SpringBootCondition() { + + override fun getMatchOutcome( + context: ConditionContext, + metadata: AnnotatedTypeMetadata + ): ConditionOutcome { + val attributes = metadata.getAllAnnotationAttributes( + ConditionalOnPropertyNotEmpty::class.jvmName, + ) + + val prefix = attributes?.get("prefix")?.firstOrNull() as? String ?: "" + val name = attributes?.get("name")?.firstOrNull() as? String ?: "" + + val propertyKey = if (prefix.isNotEmpty()) { + "$prefix.$name" + } else { + name + } + + val value = context.environment.getProperty(propertyKey) + + return if (!value.isNullOrEmpty()) { + ConditionOutcome.match("Property '$propertyKey' has non-empty value") + } else { + ConditionOutcome.noMatch("Property '$propertyKey' is missing or empty") + } + } +} diff --git a/koog-spring-boot-starter/src/main/kotlin/ai/koog/spring/prompt/executor/clients/anthropic/AnthropicLLMAutoConfiguration.kt b/koog-spring-boot-starter/src/main/kotlin/ai/koog/spring/prompt/executor/clients/anthropic/AnthropicLLMAutoConfiguration.kt index baef3d28a9..f2bd1f164a 100644 --- a/koog-spring-boot-starter/src/main/kotlin/ai/koog/spring/prompt/executor/clients/anthropic/AnthropicLLMAutoConfiguration.kt +++ b/koog-spring-boot-starter/src/main/kotlin/ai/koog/spring/prompt/executor/clients/anthropic/AnthropicLLMAutoConfiguration.kt @@ -3,6 +3,7 @@ package ai.koog.spring.prompt.executor.clients.anthropic import ai.koog.prompt.executor.clients.anthropic.AnthropicClientSettings import ai.koog.prompt.executor.clients.anthropic.AnthropicLLMClient import ai.koog.prompt.executor.llms.SingleLLMPromptExecutor +import ai.koog.spring.conditions.ConditionalOnPropertyNotEmpty import ai.koog.spring.prompt.executor.clients.toRetryingClient import org.slf4j.LoggerFactory import org.springframework.boot.autoconfigure.AutoConfiguration @@ -23,11 +24,11 @@ import org.springframework.context.annotation.PropertySource * - [AnthropicLLMClient]: Configured client for interacting with the Anthropic API. * - [SingleLLMPromptExecutor]: Prompt executor that utilizes the configured Anthropic client. * - * To enable this configuration, the `ai.koog.anthropic.enabled` property must be set to `true` and a valid `api-key` - * must be provided in the application's property files. + * To enable this configuration, the `ai.koog.anthropic.enabled` property must be set to `true` + * and a valid `ai.koog.anthropic.api-key` must be provided in the application's property files. * - * This configuration reads additional properties from the `classpath:META-INF/config/koog/anthropic-llm.properties` - * and binds them to the [AnthropicKoogProperties]. + * This configuration reads additional properties imported via `spring.config.import` from the starter's + * application.properties file and binds them to the [AnthropicKoogProperties]. * * @property properties Anthropic-specific configuration properties, automatically injected by Spring's * configuration properties mechanism. @@ -37,8 +38,6 @@ import org.springframework.context.annotation.PropertySource @EnableConfigurationProperties( AnthropicKoogProperties::class, ) -@ConditionalOnProperty(prefix = AnthropicKoogProperties.PREFIX, name = ["enabled"], havingValue = "true") -@ConditionalOnProperty(prefix = AnthropicKoogProperties.PREFIX, name = ["api-key"]) public class AnthropicLLMAutoConfiguration( private val properties: AnthropicKoogProperties ) { @@ -46,15 +45,23 @@ public class AnthropicLLMAutoConfiguration( private val logger = LoggerFactory.getLogger(AnthropicLLMAutoConfiguration::class.java) /** - * Creates and initializes an instance of [AnthropicLLMClient] with the specified API key and settings from the - * application properties. The client is configured to interact with the Anthropic LLM API using the provided - * base URL and credentials. + * Creates an [AnthropicLLMClient] bean configured with application properties. * - * @return An instance of [AnthropicLLMClient] configured for communication with the Anthropic API. + * This method initializes a [AnthropicLLMClient] using the API key and base URL + * specified in the application's configuration. It is only executed if the + * `koog.ai.anthropic.api-key` property is defined and `koog.ai.anthropic.enabled` property is set + * to `true` in the application configuration. + * + * @return An [AnthropicLLMClient] instance configured with the provided settings. */ @Bean + @ConditionalOnProperty(prefix = AnthropicKoogProperties.PREFIX, name = ["enabled"], havingValue = "true") + @ConditionalOnPropertyNotEmpty( + prefix = AnthropicKoogProperties.PREFIX, + name = "api-key" + ) public fun anthropicLLMClient(): AnthropicLLMClient { - logger.info("Initializing AnthropicLLMClient with: $properties") + logger.info("Creating AnthropicLLMClient with baseUrl=${properties.baseUrl}") return AnthropicLLMClient( apiKey = properties.apiKey, settings = AnthropicClientSettings(baseUrl = properties.baseUrl) @@ -71,6 +78,7 @@ public class AnthropicLLMAutoConfiguration( @Bean @ConditionalOnBean(AnthropicLLMClient::class) public fun anthropicExecutor(client: AnthropicLLMClient): SingleLLMPromptExecutor { + logger.info("Creating SingleLLMPromptExecutor (anthropicExecutor) for AnthropicLLMClient") return SingleLLMPromptExecutor(client.toRetryingClient(properties.retry)) } } diff --git a/koog-spring-boot-starter/src/main/kotlin/ai/koog/spring/prompt/executor/clients/deepseek/DeepSeekLLMAutoConfiguration.kt b/koog-spring-boot-starter/src/main/kotlin/ai/koog/spring/prompt/executor/clients/deepseek/DeepSeekLLMAutoConfiguration.kt index 60f40ee54b..9026740baa 100644 --- a/koog-spring-boot-starter/src/main/kotlin/ai/koog/spring/prompt/executor/clients/deepseek/DeepSeekLLMAutoConfiguration.kt +++ b/koog-spring-boot-starter/src/main/kotlin/ai/koog/spring/prompt/executor/clients/deepseek/DeepSeekLLMAutoConfiguration.kt @@ -3,8 +3,11 @@ package ai.koog.spring.prompt.executor.clients.deepseek import ai.koog.prompt.executor.clients.deepseek.DeepSeekClientSettings import ai.koog.prompt.executor.clients.deepseek.DeepSeekLLMClient import ai.koog.prompt.executor.llms.SingleLLMPromptExecutor +import ai.koog.spring.conditions.ConditionalOnPropertyNotEmpty import ai.koog.spring.prompt.executor.clients.toRetryingClient +import org.slf4j.LoggerFactory import org.springframework.boot.autoconfigure.AutoConfiguration +import org.springframework.boot.autoconfigure.condition.ConditionalOnBean import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty import org.springframework.boot.context.properties.EnableConfigurationProperties import org.springframework.context.annotation.Bean @@ -41,14 +44,30 @@ import org.springframework.context.annotation.PropertySource @EnableConfigurationProperties( DeepSeekKoogProperties::class, ) -@ConditionalOnProperty(prefix = DeepSeekKoogProperties.PREFIX, name = ["api-key"]) -@ConditionalOnProperty(prefix = DeepSeekKoogProperties.PREFIX, name = ["enabled"], havingValue = "true") public class DeepSeekLLMAutoConfiguration( private val properties: DeepSeekKoogProperties ) { + private val logger = LoggerFactory.getLogger(DeepSeekLLMAutoConfiguration::class.java) + + /** + * Creates a [DeepSeekLLMClient] bean configured with application properties. + * + * This method initializes a [DeepSeekLLMClient] using the API key and base URL + * specified in the application's configuration. It is only executed if the + * `koog.ai.deepseek.api-key` property is defined and `koog.ai.deepseek.enabled` property is set + * to `true` in the application configuration. + * + * @return A [DeepSeekLLMClient] instance configured with the provided settings. + */ @Bean + @ConditionalOnPropertyNotEmpty( + prefix = DeepSeekKoogProperties.PREFIX, + name = "api-key" + ) + @ConditionalOnProperty(prefix = DeepSeekKoogProperties.PREFIX, name = ["enabled"], havingValue = "true") public fun deepSeekLLMClient(): DeepSeekLLMClient { + logger.info("Creating DeepSeekLLMClient with baseUrl=${properties.baseUrl}") return DeepSeekLLMClient( apiKey = properties.apiKey, settings = DeepSeekClientSettings(baseUrl = properties.baseUrl) @@ -65,11 +84,9 @@ public class DeepSeekLLMAutoConfiguration( * @return A [SingleLLMPromptExecutor] initialized with an DeepSeek LLM client. */ @Bean + @ConditionalOnBean(DeepSeekLLMClient::class) public fun deepSeekExecutor(client: DeepSeekLLMClient): SingleLLMPromptExecutor { - val client = DeepSeekLLMClient( - apiKey = properties.apiKey, - settings = DeepSeekClientSettings(baseUrl = properties.baseUrl) - ) + logger.info("Creating SingleLLMPromptExecutor (deepSeekExecutor) for DeepSeekLLMClient") return SingleLLMPromptExecutor(client.toRetryingClient(properties.retry)) } } diff --git a/koog-spring-boot-starter/src/main/kotlin/ai/koog/spring/prompt/executor/clients/google/GoogleLLMAutoConfiguration.kt b/koog-spring-boot-starter/src/main/kotlin/ai/koog/spring/prompt/executor/clients/google/GoogleLLMAutoConfiguration.kt index 914186ae74..bcaa5c7cc8 100644 --- a/koog-spring-boot-starter/src/main/kotlin/ai/koog/spring/prompt/executor/clients/google/GoogleLLMAutoConfiguration.kt +++ b/koog-spring-boot-starter/src/main/kotlin/ai/koog/spring/prompt/executor/clients/google/GoogleLLMAutoConfiguration.kt @@ -3,8 +3,10 @@ package ai.koog.spring.prompt.executor.clients.google import ai.koog.prompt.executor.clients.google.GoogleClientSettings import ai.koog.prompt.executor.clients.google.GoogleLLMClient import ai.koog.prompt.executor.llms.SingleLLMPromptExecutor +import ai.koog.spring.conditions.ConditionalOnPropertyNotEmpty import ai.koog.spring.prompt.executor.clients.ollama.OllamaKoogProperties import ai.koog.spring.prompt.executor.clients.toRetryingClient +import org.slf4j.LoggerFactory import org.springframework.boot.autoconfigure.AutoConfiguration import org.springframework.boot.autoconfigure.condition.ConditionalOnBean import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty @@ -38,20 +40,30 @@ import org.springframework.context.annotation.PropertySource @EnableConfigurationProperties( GoogleKoogProperties::class, ) -@ConditionalOnProperty(prefix = GoogleKoogProperties.PREFIX, name = ["api-key"]) -@ConditionalOnProperty(prefix = GoogleKoogProperties.PREFIX, name = ["enabled"], havingValue = "true") public class GoogleLLMAutoConfiguration( private val properties: GoogleKoogProperties ) { + private val logger = LoggerFactory.getLogger(GoogleLLMAutoConfiguration::class.java) + /** - * Provides a [GoogleLLMClient] bean configured with the API key and base URL - * specified in the application's properties. + * Creates a [GoogleLLMClient] bean configured with application properties. + * + * This method initializes a [GoogleLLMClient] using the API key and base URL + * specified in the application's configuration. It is only executed if the + * `koog.ai.google.api-key` property is defined and `koog.ai.google.enabled` property is set + * to `true` in the application configuration. * - * @return A configured instance of [GoogleLLMClient]. + * @return A [GoogleLLMClient] instance configured with the provided settings. */ @Bean + @ConditionalOnPropertyNotEmpty( + prefix = GoogleKoogProperties.PREFIX, + name = "api-key" + ) + @ConditionalOnProperty(prefix = GoogleKoogProperties.PREFIX, name = ["enabled"], havingValue = "true") public fun googleLLMClient(): GoogleLLMClient { + logger.info("Creating GoogleLLMClient with baseUrl=${properties.baseUrl}") return GoogleLLMClient( apiKey = properties.apiKey, settings = GoogleClientSettings(baseUrl = properties.baseUrl) @@ -70,6 +82,7 @@ public class GoogleLLMAutoConfiguration( @Bean @ConditionalOnBean(GoogleLLMClient::class) public fun googleExecutor(client: GoogleLLMClient): SingleLLMPromptExecutor { + logger.info("Creating SingleLLMPromptExecutor (googleExecutor) for GoogleLLMClient") return SingleLLMPromptExecutor(client.toRetryingClient(properties.retry)) } } diff --git a/koog-spring-boot-starter/src/main/kotlin/ai/koog/spring/prompt/executor/clients/ollama/OllamaLLMAutoConfiguration.kt b/koog-spring-boot-starter/src/main/kotlin/ai/koog/spring/prompt/executor/clients/ollama/OllamaLLMAutoConfiguration.kt index 798dc4bd14..212c7a3951 100644 --- a/koog-spring-boot-starter/src/main/kotlin/ai/koog/spring/prompt/executor/clients/ollama/OllamaLLMAutoConfiguration.kt +++ b/koog-spring-boot-starter/src/main/kotlin/ai/koog/spring/prompt/executor/clients/ollama/OllamaLLMAutoConfiguration.kt @@ -3,6 +3,7 @@ package ai.koog.spring.prompt.executor.clients.ollama import ai.koog.prompt.executor.llms.SingleLLMPromptExecutor import ai.koog.prompt.executor.ollama.client.OllamaClient import ai.koog.spring.prompt.executor.clients.toRetryingClient +import org.slf4j.LoggerFactory import org.springframework.boot.autoconfigure.AutoConfiguration import org.springframework.boot.autoconfigure.condition.ConditionalOnBean import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty @@ -36,21 +37,25 @@ import org.springframework.context.annotation.PropertySource @EnableConfigurationProperties( OllamaKoogProperties::class, ) -@ConditionalOnProperty(prefix = OllamaKoogProperties.PREFIX, name = ["enabled"], havingValue = "true") public class OllamaLLMAutoConfiguration( private val properties: OllamaKoogProperties ) { + private val logger = LoggerFactory.getLogger(OllamaLLMAutoConfiguration::class.java) + /** - * Creates and configures an instance of [OllamaClient] using the base URL from the provided properties. + * Creates an [OllamaClient] bean configured with application properties. * - * This client is used to communicate with the Ollama LLM service and is a prerequisite - * for executing prompts and other interactions with the service. + * This method initializes a [OllamaClient] using the API key and base URL + * specified in the application's configuration. It is only executed if the + * `koog.ai.ollama.enabled` property is set to `true` in the application configuration. * - * @return an [OllamaClient] configured with the base URL extracted from the application's properties. + * @return An [OllamaClient] instance configured with the provided settings. */ @Bean + @ConditionalOnProperty(prefix = OllamaKoogProperties.PREFIX, name = ["enabled"], havingValue = "true") public fun ollamaLLMClient(): OllamaClient { + logger.info("Creating OllamaClient with baseUrl=${properties.baseUrl}") return OllamaClient( baseUrl = properties.baseUrl, ) @@ -66,6 +71,7 @@ public class OllamaLLMAutoConfiguration( @Bean @ConditionalOnBean(OllamaClient::class) public fun ollamaExecutor(client: OllamaClient): SingleLLMPromptExecutor { + logger.info("Creating SingleLLMPromptExecutor (ollamaExecutor) for OllamaClient") return SingleLLMPromptExecutor(client.toRetryingClient(properties.retry)) } } diff --git a/koog-spring-boot-starter/src/main/kotlin/ai/koog/spring/prompt/executor/clients/openai/OpenAILLMAutoConfiguration.kt b/koog-spring-boot-starter/src/main/kotlin/ai/koog/spring/prompt/executor/clients/openai/OpenAILLMAutoConfiguration.kt index 93ad1a0713..ea69bbf97a 100644 --- a/koog-spring-boot-starter/src/main/kotlin/ai/koog/spring/prompt/executor/clients/openai/OpenAILLMAutoConfiguration.kt +++ b/koog-spring-boot-starter/src/main/kotlin/ai/koog/spring/prompt/executor/clients/openai/OpenAILLMAutoConfiguration.kt @@ -3,7 +3,9 @@ package ai.koog.spring.prompt.executor.clients.openai import ai.koog.prompt.executor.clients.openai.OpenAIClientSettings import ai.koog.prompt.executor.clients.openai.OpenAILLMClient import ai.koog.prompt.executor.llms.SingleLLMPromptExecutor +import ai.koog.spring.conditions.ConditionalOnPropertyNotEmpty import ai.koog.spring.prompt.executor.clients.toRetryingClient +import org.slf4j.LoggerFactory import org.springframework.boot.autoconfigure.AutoConfiguration import org.springframework.boot.autoconfigure.condition.ConditionalOnBean import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty @@ -38,20 +40,30 @@ import org.springframework.context.annotation.PropertySource @EnableConfigurationProperties( OpenAIKoogProperties::class, ) -@ConditionalOnProperty(prefix = OpenAIKoogProperties.PREFIX, name = ["api-key"]) -@ConditionalOnProperty(prefix = OpenAIKoogProperties.PREFIX, name = ["enabled"], havingValue = "true") public class OpenAILLMAutoConfiguration( private val properties: OpenAIKoogProperties ) { + private val logger = LoggerFactory.getLogger(OpenAILLMAutoConfiguration::class.java) + /** - * Creates and provides an instance of [OpenAILLMClient] as a Spring bean for use in the application context. - * The [OpenAILLMClient] is configured using API key and base URL from the associated properties. + * Creates an [OpenAILLMClient] bean configured with application properties. + * + * This method initializes a [OpenAILLMClient] using the API key and base URL + * specified in the application's configuration. It is only executed if the + * `koog.ai.openai.api-key` property is defined and `koog.ai.openai.enabled` property is set + * to `true` in the application configuration. * - * @return a configured instance of [OpenAILLMClient]. + * @return An [OpenAILLMClient] instance configured with the provided settings. */ @Bean + @ConditionalOnPropertyNotEmpty( + prefix = OpenAIKoogProperties.PREFIX, + name = "api-key" + ) + @ConditionalOnProperty(prefix = OpenAIKoogProperties.PREFIX, name = ["enabled"], havingValue = "true") public fun openAILLMClient(): OpenAILLMClient { + logger.info("Creating OpenAILLMClient client with baseUrl=${properties.baseUrl}") return OpenAILLMClient( apiKey = properties.apiKey, settings = OpenAIClientSettings(baseUrl = properties.baseUrl) @@ -68,6 +80,7 @@ public class OpenAILLMAutoConfiguration( @Bean @ConditionalOnBean(OpenAILLMClient::class) public fun openAIExecutor(client: OpenAILLMClient): SingleLLMPromptExecutor { + logger.info("Creating SingleLLMPromptExecutor (openAIExecutor) for OpenAILLMClient") return SingleLLMPromptExecutor(client.toRetryingClient(properties.retry)) } } diff --git a/koog-spring-boot-starter/src/main/kotlin/ai/koog/spring/prompt/executor/clients/openrouter/OpenRouterLLMAutoConfiguration.kt b/koog-spring-boot-starter/src/main/kotlin/ai/koog/spring/prompt/executor/clients/openrouter/OpenRouterLLMAutoConfiguration.kt index 6e8928adaa..a8b33ad5b2 100644 --- a/koog-spring-boot-starter/src/main/kotlin/ai/koog/spring/prompt/executor/clients/openrouter/OpenRouterLLMAutoConfiguration.kt +++ b/koog-spring-boot-starter/src/main/kotlin/ai/koog/spring/prompt/executor/clients/openrouter/OpenRouterLLMAutoConfiguration.kt @@ -3,7 +3,9 @@ package ai.koog.spring.prompt.executor.clients.openrouter import ai.koog.prompt.executor.clients.openrouter.OpenRouterClientSettings import ai.koog.prompt.executor.clients.openrouter.OpenRouterLLMClient import ai.koog.prompt.executor.llms.SingleLLMPromptExecutor +import ai.koog.spring.conditions.ConditionalOnPropertyNotEmpty import ai.koog.spring.prompt.executor.clients.toRetryingClient +import org.slf4j.LoggerFactory import org.springframework.boot.autoconfigure.AutoConfiguration import org.springframework.boot.autoconfigure.condition.ConditionalOnBean import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty @@ -30,21 +32,30 @@ import org.springframework.context.annotation.PropertySource @EnableConfigurationProperties( OpenRouterKoogProperties::class, ) -@ConditionalOnProperty(prefix = OpenRouterKoogProperties.PREFIX, name = ["api-key"]) -@ConditionalOnProperty(prefix = OpenRouterKoogProperties.PREFIX, name = ["enabled"], havingValue = "true") public class OpenRouterLLMAutoConfiguration( private val properties: OpenRouterKoogProperties ) { + private val logger = LoggerFactory.getLogger(OpenRouterLLMAutoConfiguration::class.java) + /** - * Creates and configures an instance of [OpenRouterLLMClient] as a Spring Bean. - * The client is initialized with the API key and settings (such as base URL) - * obtained from the provided `properties` configuration. + * Creates an [OpenRouterLLMClient] bean configured with application properties. + * + * This method initializes a [OpenRouterLLMClient] using the API key and base URL + * specified in the application's configuration. It is only executed if the + * `koog.ai.openrouter.api-key` property is defined and `koog.ai.openrouter.enabled` property is set + * to `true` in the application configuration. * - * @return An instance of [OpenRouterLLMClient] configured with the given properties. + * @return An [OpenRouterLLMClient] instance configured with the provided settings. */ @Bean + @ConditionalOnPropertyNotEmpty( + prefix = OpenRouterKoogProperties.PREFIX, + name = "api-key" + ) + @ConditionalOnProperty(prefix = OpenRouterKoogProperties.PREFIX, name = ["enabled"], havingValue = "true") public fun openRouterLLMClient(): OpenRouterLLMClient { + logger.info("Creating OpenRouterLLMClient with baseUrl=${properties.baseUrl}") return OpenRouterLLMClient( apiKey = properties.apiKey, settings = OpenRouterClientSettings(baseUrl = properties.baseUrl) @@ -62,6 +73,7 @@ public class OpenRouterLLMAutoConfiguration( @Bean @ConditionalOnBean(OpenRouterLLMClient::class) public fun openRouterExecutor(client: OpenRouterLLMClient): SingleLLMPromptExecutor { + logger.info("Creating SingleLLMPromptExecutor (openRouterExecutor) for OpenRouterLLMClient") return SingleLLMPromptExecutor(client.toRetryingClient(properties.retry)) } } diff --git a/koog-spring-boot-starter/src/main/kotlin/ai/koog/spring/prompt/executor/clients/utils.kt b/koog-spring-boot-starter/src/main/kotlin/ai/koog/spring/prompt/executor/clients/utils.kt index efebf0447a..16893377af 100644 --- a/koog-spring-boot-starter/src/main/kotlin/ai/koog/spring/prompt/executor/clients/utils.kt +++ b/koog-spring-boot-starter/src/main/kotlin/ai/koog/spring/prompt/executor/clients/utils.kt @@ -2,14 +2,14 @@ package ai.koog.spring.prompt.executor.clients import ai.koog.prompt.executor.clients.LLMClient import ai.koog.prompt.executor.clients.retry.RetryConfig -import ai.koog.prompt.executor.clients.retry.RetryingLLMClient +import ai.koog.prompt.executor.clients.retry.toRetryingClient import ai.koog.spring.RetryConfigKoogProperties import kotlin.time.toKotlinDuration internal fun LLMClient.toRetryingClient(properties: RetryConfigKoogProperties?): LLMClient { val self = this return if (properties?.enabled == true) { - val defaultConfig = RetryConfig() + val defaultConfig = RetryConfig.DEFAULT val retryConfig = RetryConfig( maxAttempts = properties.maxAttempts ?: defaultConfig.maxAttempts, initialDelay = properties.initialDelay?.toKotlinDuration() ?: defaultConfig.initialDelay, @@ -17,10 +17,7 @@ internal fun LLMClient.toRetryingClient(properties: RetryConfigKoogProperties?): backoffMultiplier = properties.backoffMultiplier ?: defaultConfig.backoffMultiplier, jitterFactor = properties.jitterFactor ?: defaultConfig.jitterFactor ) - RetryingLLMClient( - delegate = self, - config = retryConfig - ) + self.toRetryingClient(retryConfig) } else { self } diff --git a/koog-spring-boot-starter/src/main/resources/META-INF/config/koog/anthropic-llm.properties b/koog-spring-boot-starter/src/main/resources/META-INF/config/koog/anthropic-llm.properties index fa078f9422..6f77f4a40e 100644 --- a/koog-spring-boot-starter/src/main/resources/META-INF/config/koog/anthropic-llm.properties +++ b/koog-spring-boot-starter/src/main/resources/META-INF/config/koog/anthropic-llm.properties @@ -1,2 +1,3 @@ -ai.koog.anthropic.enabled=true +ai.koog.anthropic.api-key=${ANTHROPIC_API_KEY:} ai.koog.anthropic.base-url=https://api.anthropic.com +ai.koog.anthropic.enabled=true diff --git a/koog-spring-boot-starter/src/main/resources/META-INF/config/koog/deepseek-llm.properties b/koog-spring-boot-starter/src/main/resources/META-INF/config/koog/deepseek-llm.properties index a924235966..9479ac722f 100644 --- a/koog-spring-boot-starter/src/main/resources/META-INF/config/koog/deepseek-llm.properties +++ b/koog-spring-boot-starter/src/main/resources/META-INF/config/koog/deepseek-llm.properties @@ -1,2 +1,3 @@ -ai.koog.deepseek.enabled=true +ai.koog.deepseek.api-key=${DEEPSEEK_API_KEY:} ai.koog.deepseek.base-url=https://api.deepseek.com +ai.koog.deepseek.enabled=true diff --git a/koog-spring-boot-starter/src/main/resources/META-INF/config/koog/google-llm.properties b/koog-spring-boot-starter/src/main/resources/META-INF/config/koog/google-llm.properties index e6c0935123..2b2b1cbf73 100644 --- a/koog-spring-boot-starter/src/main/resources/META-INF/config/koog/google-llm.properties +++ b/koog-spring-boot-starter/src/main/resources/META-INF/config/koog/google-llm.properties @@ -1,2 +1,3 @@ -ai.koog.google.enabled=true +ai.koog.google.api-key=${GOOGLE_API_KEY:} ai.koog.google.base-url=https://generativelanguage.googleapis.com +ai.koog.google.enabled=true diff --git a/koog-spring-boot-starter/src/main/resources/META-INF/config/koog/ollama-llm.properties b/koog-spring-boot-starter/src/main/resources/META-INF/config/koog/ollama-llm.properties index 9f93516484..ff1168ed2e 100644 --- a/koog-spring-boot-starter/src/main/resources/META-INF/config/koog/ollama-llm.properties +++ b/koog-spring-boot-starter/src/main/resources/META-INF/config/koog/ollama-llm.properties @@ -1,2 +1,2 @@ ai.koog.ollama.enabled=false -ai.koog.ollama.base-url=http://localhost:11434 +ai.koog.ollama.base-url=http://127.0.0.1:11434 diff --git a/koog-spring-boot-starter/src/main/resources/META-INF/config/koog/openai-llm.properties b/koog-spring-boot-starter/src/main/resources/META-INF/config/koog/openai-llm.properties index 1ec885474d..e9aceb657b 100644 --- a/koog-spring-boot-starter/src/main/resources/META-INF/config/koog/openai-llm.properties +++ b/koog-spring-boot-starter/src/main/resources/META-INF/config/koog/openai-llm.properties @@ -1,2 +1,3 @@ ai.koog.openai.enabled=true ai.koog.openai.base-url=https://api.openai.com +ai.koog.openai.api-key=${OPENAI_API_KEY:} diff --git a/koog-spring-boot-starter/src/main/resources/META-INF/config/koog/openrouter-llm.properties b/koog-spring-boot-starter/src/main/resources/META-INF/config/koog/openrouter-llm.properties index 54c3588a7a..6e747ac5ec 100644 --- a/koog-spring-boot-starter/src/main/resources/META-INF/config/koog/openrouter-llm.properties +++ b/koog-spring-boot-starter/src/main/resources/META-INF/config/koog/openrouter-llm.properties @@ -1,2 +1,3 @@ -ai.koog.openrouter.enabled=true +ai.koog.openrouter.api-key=${OPENROUTER_API_KEY:} ai.koog.openrouter.base-url=https://openrouter.ai +ai.koog.openrouter.enabled=true diff --git a/koog-spring-boot-starter/src/test/kotlin/ai/koog/spring/KoogAutoConfigurationIntegrationTest.kt b/koog-spring-boot-starter/src/test/kotlin/ai/koog/spring/KoogAutoConfigurationIntegrationTest.kt new file mode 100644 index 0000000000..f2b1186fd3 --- /dev/null +++ b/koog-spring-boot-starter/src/test/kotlin/ai/koog/spring/KoogAutoConfigurationIntegrationTest.kt @@ -0,0 +1,106 @@ +package ai.koog.spring + +import ai.koog.prompt.executor.clients.LLMClient +import ai.koog.prompt.executor.clients.anthropic.AnthropicLLMClient +import ai.koog.prompt.executor.clients.deepseek.DeepSeekLLMClient +import ai.koog.prompt.executor.clients.google.GoogleLLMClient +import ai.koog.prompt.executor.clients.openai.OpenAILLMClient +import ai.koog.prompt.executor.clients.openrouter.OpenRouterLLMClient +import ai.koog.prompt.executor.llms.SingleLLMPromptExecutor +import ai.koog.prompt.executor.ollama.client.OllamaClient +import org.junit.jupiter.api.Assertions.assertTrue +import org.junit.jupiter.params.ParameterizedTest +import org.junit.jupiter.params.provider.ValueSource +import org.slf4j.LoggerFactory +import org.springframework.beans.factory.annotation.Autowired +import org.springframework.boot.autoconfigure.EnableAutoConfiguration +import org.springframework.boot.test.context.SpringBootTest +import org.springframework.context.ApplicationContext +import org.springframework.context.annotation.Configuration +import org.springframework.test.context.TestPropertySource + +@SpringBootTest( + classes = [ + KoogAutoConfigurationIntegrationTest.TestConfig::class, + ], + properties = [ + "debug=false", // set to true for troubleshooting + "spring.main.banner-mode=off" + ], + webEnvironment = SpringBootTest.WebEnvironment.NONE +) +@TestPropertySource( + locations = ["classpath:/it-application.properties"] +) +class KoogAutoConfigurationIntegrationTest { + + @Configuration + @EnableAutoConfiguration + @Suppress("unused") + private class TestConfig + + private val logger = LoggerFactory.getLogger(KoogAutoConfigurationIntegrationTest::class.java) + + @Autowired + private lateinit var applicationContext: ApplicationContext + + @ParameterizedTest + @ValueSource( + classes = [ + AnthropicLLMClient::class, + DeepSeekLLMClient::class, + GoogleLLMClient::class, + OpenAILLMClient::class, + OpenRouterLLMClient::class, + OllamaClient::class, + ] + ) + fun `Should register beans(classes)`(clazz: Class<*>) { + verifyBeanIsRegistered(clazz) + } + + @ParameterizedTest + @ValueSource( + strings = [ + "anthropicExecutor", + "deepSeekExecutor", + "googleExecutor", + "ollamaExecutor", + "openAIExecutor", + "openRouterExecutor", + ] + ) + fun `Should register SingleLLMExecutors`(beanName: String) { + val llmExecutorBeanNames = applicationContext.getBeanNamesForType( + SingleLLMPromptExecutor::class.java + ) + assertTrue(llmExecutorBeanNames.contains(beanName)) { + logger.info( + "Registered ${SingleLLMPromptExecutor::class.simpleName} beans:${ + llmExecutorBeanNames + .joinToString(separator = "\n\t", prefix = "\n\t") + }" + ) + + "Bean named `$beanName` should have been registered" + } + } + + private inline fun verifyBeanIsRegistered(clazz: Class<*>) { + assertTrue(applicationContext.getBeansOfType(clazz).size == 1) { + logger.info( + "Registered beans:${ + applicationContext.beanDefinitionNames + .joinToString(separator = "\n\t", prefix = "\n\t") + }" + ) + + "Bean of type ${clazz.simpleName} should have been registered" + } + + val bean = applicationContext.getBean(clazz) + assertTrue(bean is EXTRA) { + "Registered bean of type $clazz should be also a ${EXTRA::class}" + } + } +} diff --git a/koog-spring-boot-starter/src/test/kotlin/ai/koog/spring/KoogAutoConfigurationTest.kt b/koog-spring-boot-starter/src/test/kotlin/ai/koog/spring/KoogAutoConfigurationTest.kt index c986991b49..f9330469e5 100644 --- a/koog-spring-boot-starter/src/test/kotlin/ai/koog/spring/KoogAutoConfigurationTest.kt +++ b/koog-spring-boot-starter/src/test/kotlin/ai/koog/spring/KoogAutoConfigurationTest.kt @@ -45,21 +45,23 @@ private const val PROVIDERS = """ @TestInstance(TestInstance.Lifecycle.PER_CLASS) class KoogAutoConfigurationTest { - private val defaultRetryConfig = RetryConfig() - - private val allProvidersAutoConfigurations = AutoConfigurations.of( - AnthropicLLMAutoConfiguration::class.java, - GoogleLLMAutoConfiguration::class.java, - DeepSeekLLMAutoConfiguration::class.java, - OllamaLLMAutoConfiguration::class.java, - OpenAILLMAutoConfiguration::class.java, - OpenRouterLLMAutoConfiguration::class.java, - ) + private val defaultRetryConfig = RetryConfig.DEFAULT + + private fun createApplicationContextRunner(): ApplicationContextRunner = ApplicationContextRunner() + .withConfiguration( + AutoConfigurations.of( + AnthropicLLMAutoConfiguration::class.java, + GoogleLLMAutoConfiguration::class.java, + DeepSeekLLMAutoConfiguration::class.java, + OllamaLLMAutoConfiguration::class.java, + OpenAILLMAutoConfiguration::class.java, + OpenRouterLLMAutoConfiguration::class.java, + ) + ) @Test fun `should not supply executor beans if no apiKey is provided`() { - ApplicationContextRunner() - .withConfiguration(allProvidersAutoConfigurations) + createApplicationContextRunner() .run { context -> assertThrows { context.getBean() } } @@ -68,10 +70,8 @@ class KoogAutoConfigurationTest { @Test fun `should supply OpenAI executor bean with provided apiKey and default baseUrl`() { val configApiKey = "some_api_key" - ApplicationContextRunner() - .withConfiguration(allProvidersAutoConfigurations) + createApplicationContextRunner() .withPropertyValues( - "ai.koog.openai.enabled=true", "ai.koog.openai.api-key=$configApiKey" ) .run { context -> @@ -92,10 +92,8 @@ class KoogAutoConfigurationTest { @Test fun `should supply OpenAI executor bean with provided baseUrl`() { val configBaseUrl = "https://some-url.com" - ApplicationContextRunner() - .withConfiguration(AutoConfigurations.of(OpenAILLMAutoConfiguration::class.java)) + createApplicationContextRunner() .withPropertyValues( - "ai.koog.openai.enabled=true", "ai.koog.openai.api-key=some_api_key", "ai.koog.openai.base-url=$configBaseUrl", ) @@ -116,8 +114,7 @@ class KoogAutoConfigurationTest { provider: String, clazz: Class ) { - ApplicationContextRunner() - .withConfiguration(allProvidersAutoConfigurations) + createApplicationContextRunner() .withPropertyValues( "ai.koog.$provider.enabled=true", "ai.koog.$provider.api-key=some_api_key", @@ -152,8 +149,7 @@ class KoogAutoConfigurationTest { val maxDelay = 60 val backoffMultiplier = 5.0 val jitterFactor = 0.5 - ApplicationContextRunner() - .withConfiguration(allProvidersAutoConfigurations) + createApplicationContextRunner() .withPropertyValues( "ai.koog.$provider.enabled=true", "ai.koog.$provider.api-key=some_api_key", @@ -192,8 +188,7 @@ class KoogAutoConfigurationTest { val maxDelay = 60 val backoffMultiplier = 5.0 val jitterFactor = 0.5 - ApplicationContextRunner() - .withConfiguration(allProvidersAutoConfigurations) + createApplicationContextRunner() .withPropertyValues( "ai.koog.$provider.enabled=false", "ai.koog.$provider.api-key=some_api_key", @@ -220,8 +215,7 @@ class KoogAutoConfigurationTest { ) { val maxAttempts = 5 val initialDelay = 10 - ApplicationContextRunner() - .withConfiguration(allProvidersAutoConfigurations) + createApplicationContextRunner() .withPropertyValues( "ai.koog.$provider.enabled=true", "ai.koog.$provider.api-key=some_api_key", @@ -249,10 +243,8 @@ class KoogAutoConfigurationTest { @Test fun `should supply Anthropic executor bean with provided apiKey and default baseUrl`() { val configApiKey = "some_api_key" - ApplicationContextRunner() - .withConfiguration(allProvidersAutoConfigurations) + createApplicationContextRunner() .withPropertyValues( - "ai.koog.anthropic.enabled=true", "ai.koog.anthropic.api-key=$configApiKey" ) .run { context -> @@ -275,7 +267,6 @@ class KoogAutoConfigurationTest { ApplicationContextRunner() .withConfiguration(AutoConfigurations.of(AnthropicLLMAutoConfiguration::class.java)) .withPropertyValues( - "ai.koog.anthropic.enabled=true", "ai.koog.anthropic.api-key=some_api_key", "ai.koog.anthropic.retry.enabled=true" ) @@ -295,10 +286,8 @@ class KoogAutoConfigurationTest { @Test fun `should supply Anthropic executor bean with provided baseUrl`() { val configBaseUrl = "https://some-url.com" - ApplicationContextRunner() - .withConfiguration(AutoConfigurations.of(AnthropicLLMAutoConfiguration::class.java)) + createApplicationContextRunner() .withPropertyValues( - "ai.koog.anthropic.enabled=true", "ai.koog.anthropic.api-key=some_api_key", "ai.koog.anthropic.base-url=$configBaseUrl", ) @@ -316,10 +305,8 @@ class KoogAutoConfigurationTest { @Test fun `should supply Google executor bean with provided apiKey and default baseUrl`() { val configApiKey = "some_api_key" - ApplicationContextRunner() - .withConfiguration(AutoConfigurations.of(GoogleLLMAutoConfiguration::class.java)) + createApplicationContextRunner() .withPropertyValues( - "ai.koog.google.enabled=true", "ai.koog.google.api-key=$configApiKey" ) .run { context -> @@ -340,10 +327,8 @@ class KoogAutoConfigurationTest { @Test fun `should supply Google executor bean with provided baseUrl`() { val configBaseUrl = "https://some-url.com" - ApplicationContextRunner() - .withConfiguration(AutoConfigurations.of(GoogleLLMAutoConfiguration::class.java)) + createApplicationContextRunner() .withPropertyValues( - "ai.koog.google.enabled=true", "ai.koog.google.api-key=some_api_key", "ai.koog.google.base-url=$configBaseUrl", ) @@ -360,10 +345,8 @@ class KoogAutoConfigurationTest { @Test fun `should supply Google executor bean with retry client and default config`() { - ApplicationContextRunner() - .withConfiguration(AutoConfigurations.of(GoogleLLMAutoConfiguration::class.java)) + createApplicationContextRunner() .withPropertyValues( - "ai.koog.google.enabled=true", "ai.koog.google.api-key=some_api_key", "ai.koog.google.retry.enabled=true" ) @@ -383,8 +366,7 @@ class KoogAutoConfigurationTest { @Test fun `should supply OpenRouter executor bean with provided apiKey and default baseUrl`() { val configApiKey = "some_api_key" - ApplicationContextRunner() - .withConfiguration(AutoConfigurations.of(OpenRouterLLMAutoConfiguration::class.java)) + createApplicationContextRunner() .withPropertyValues( "ai.koog.openrouter.enabled=true", "ai.koog.openrouter.api-key=$configApiKey" @@ -407,8 +389,7 @@ class KoogAutoConfigurationTest { @Test fun `should supply OpenRouter executor bean with provided baseUrl`() { val configBaseUrl = "https://some-url.com" - ApplicationContextRunner() - .withConfiguration(AutoConfigurations.of(OpenRouterLLMAutoConfiguration::class.java)) + createApplicationContextRunner() .withPropertyValues( "ai.koog.openrouter.enabled=true", "ai.koog.openrouter.api-key=some_api_key", @@ -427,8 +408,7 @@ class KoogAutoConfigurationTest { @Test fun `should supply OpenRouter executor bean with retry client and default config`() { - ApplicationContextRunner() - .withConfiguration(AutoConfigurations.of(OpenRouterLLMAutoConfiguration::class.java)) + createApplicationContextRunner() .withPropertyValues( "ai.koog.openrouter.enabled=true", "ai.koog.openrouter.api-key=some_api_key", @@ -450,10 +430,8 @@ class KoogAutoConfigurationTest { @Test fun `should supply DeepSeek executor bean with provided apiKey and default baseUrl`() { val configApiKey = "some_api_key" - ApplicationContextRunner() - .withConfiguration(AutoConfigurations.of(DeepSeekLLMAutoConfiguration::class.java)) + createApplicationContextRunner() .withPropertyValues( - "ai.koog.deepseek.enabled=true", "ai.koog.deepseek.api-key=$configApiKey" ) .run { context -> @@ -474,13 +452,10 @@ class KoogAutoConfigurationTest { @Test fun `should supply DeepSeek executor bean with provided baseUrl`() { val configBaseUrl = "https://some-url.com" - ApplicationContextRunner() - .withConfiguration(AutoConfigurations.of(DeepSeekLLMAutoConfiguration::class.java)) - .withPropertyValues( - "ai.koog.deepseek.enabled=true", - "ai.koog.deepseek.api-key=some_api_key", - "ai.koog.deepseek.base-url=$configBaseUrl", - ) + createApplicationContextRunner().withPropertyValues( + "ai.koog.deepseek.api-key=some_api_key", + "ai.koog.deepseek.base-url=$configBaseUrl", + ) .run { context -> val executor = context.getBean() val llmClient = getPrivateFieldValue(executor, "llmClient") as DeepSeekLLMClient @@ -494,10 +469,8 @@ class KoogAutoConfigurationTest { @Test fun `should supply DeepSeek executor bean with retry client and default config`() { - ApplicationContextRunner() - .withConfiguration(AutoConfigurations.of(DeepSeekLLMAutoConfiguration::class.java)) + createApplicationContextRunner() .withPropertyValues( - "ai.koog.deepseek.enabled=true", "ai.koog.deepseek.api-key=some_api_key", "ai.koog.deepseek.retry.enabled=true" ) @@ -517,12 +490,10 @@ class KoogAutoConfigurationTest { @Test fun `should supply Ollama executor bean with provided baseUrl`() { val configBaseUrl = "https://some-url.com" - ApplicationContextRunner() - .withConfiguration(AutoConfigurations.of(OllamaLLMAutoConfiguration::class.java)) - .withPropertyValues( - "ai.koog.ollama.enabled=true", - "ai.koog.ollama.base-url=$configBaseUrl" - ) + createApplicationContextRunner().withPropertyValues( + "ai.koog.ollama.enabled=true", + "ai.koog.ollama.base-url=$configBaseUrl" + ) .run { context -> val executor = context.getBean() val llmClient = getPrivateFieldValue(executor, "llmClient") @@ -558,18 +529,11 @@ class KoogAutoConfigurationTest { @Test fun `should supply multiple executor beans`() { - ApplicationContextRunner() - .withConfiguration( - allProvidersAutoConfigurations - ) + createApplicationContextRunner() .withPropertyValues( - "ai.koog.openai.enabled=true", "ai.koog.openai.api-key=some_api_key", - "ai.koog.anthropic.enabled=true", "ai.koog.anthropic.api-key=some_api_key", - "ai.koog.google.enabled=true", "ai.koog.google.api-key=some_api_key", - "ai.koog.deepseek.enabled=true", "ai.koog.deepseek.api-key=some_api_key", "ai.koog.ollama.enabled=true", ) diff --git a/koog-spring-boot-starter/src/test/resources/it-application.properties b/koog-spring-boot-starter/src/test/resources/it-application.properties new file mode 100644 index 0000000000..2b7210053a --- /dev/null +++ b/koog-spring-boot-starter/src/test/resources/it-application.properties @@ -0,0 +1,6 @@ +ANTHROPIC_API_KEY=test_anthropic_api_key +DEEPSEEK_API_KEY=test_deepseek_api_key +GOOGLE_API_KEY=test_google_api_key +OPENAI_API_KEY=test_openai_api_key +OPENROUTER_API_KEY=test_openrouter_api_key +ai.koog.ollama.enabled=true diff --git a/prompt/prompt-executor/prompt-executor-clients/src/commonMain/kotlin/ai/koog/prompt/executor/clients/retry/RetryConfig.kt b/prompt/prompt-executor/prompt-executor-clients/src/commonMain/kotlin/ai/koog/prompt/executor/clients/retry/RetryConfig.kt index 1c57b9cffe..b865d5ea37 100644 --- a/prompt/prompt-executor/prompt-executor-clients/src/commonMain/kotlin/ai/koog/prompt/executor/clients/retry/RetryConfig.kt +++ b/prompt/prompt-executor/prompt-executor-clients/src/commonMain/kotlin/ai/koog/prompt/executor/clients/retry/RetryConfig.kt @@ -28,7 +28,9 @@ public data class RetryConfig( require(maxAttempts >= 1) { "maxAttempts must be at least 1" } require(backoffMultiplier >= 1.0) { "backoffMultiplier must be at least 1.0" } require(jitterFactor in 0.0..1.0) { "jitterFactor must be between 0.0 and 1.0" } - require(initialDelay <= maxDelay) { "initialDelay ($initialDelay) must not be greater than maxDelay ($maxDelay)" } + require(initialDelay <= maxDelay) { + "initialDelay ($initialDelay) must not be greater than maxDelay ($maxDelay)" + } } /** @@ -97,6 +99,13 @@ public data class RetryConfig( * No retry - effectively disables retry logic. */ public val DISABLED: RetryConfig = RetryConfig(maxAttempts = 1) + + /** + * The default retry configuration used by clients implementing retry logic. + * + * Suitable for general-purpose use cases where standard retry behavior is required. + */ + public val DEFAULT: RetryConfig = RetryConfig() } } diff --git a/prompt/prompt-executor/prompt-executor-clients/src/commonMain/kotlin/ai/koog/prompt/executor/clients/retry/RetryingLLMClient.kt b/prompt/prompt-executor/prompt-executor-clients/src/commonMain/kotlin/ai/koog/prompt/executor/clients/retry/RetryingLLMClient.kt index 621263e8bc..1056e4d14a 100644 --- a/prompt/prompt-executor/prompt-executor-clients/src/commonMain/kotlin/ai/koog/prompt/executor/clients/retry/RetryingLLMClient.kt +++ b/prompt/prompt-executor/prompt-executor-clients/src/commonMain/kotlin/ai/koog/prompt/executor/clients/retry/RetryingLLMClient.kt @@ -34,7 +34,7 @@ import kotlin.time.Duration.Companion.milliseconds */ public class RetryingLLMClient( private val delegate: LLMClient, - private val config: RetryConfig = RetryConfig() + internal val config: RetryConfig = RetryConfig() ) : LLMClient { private companion object { @@ -163,3 +163,17 @@ public class RetryingLLMClient( return finalMs.milliseconds } } + +/** + * Converts an instance of [LLMClient] into a retrying client with customizable retry behavior. + * + * @param retryConfig Configuration for retry behavior. Defaults to [RetryConfig.DEFAULT]. + * @return A new instance of [RetryingLLMClient] that adds retry logic to the provided client. + */ +public fun LLMClient.toRetryingClient( + retryConfig: RetryConfig = RetryConfig.DEFAULT +): RetryingLLMClient = + RetryingLLMClient( + delegate = this, + config = retryConfig + ) diff --git a/prompt/prompt-executor/prompt-executor-clients/src/commonTest/kotlin/ai/koog/prompt/executor/clients/retry/RetryingLLMClientTest.kt b/prompt/prompt-executor/prompt-executor-clients/src/commonTest/kotlin/ai/koog/prompt/executor/clients/retry/RetryingLLMClientTest.kt index 1c30949b1f..6f8c501af3 100644 --- a/prompt/prompt-executor/prompt-executor-clients/src/commonTest/kotlin/ai/koog/prompt/executor/clients/retry/RetryingLLMClientTest.kt +++ b/prompt/prompt-executor/prompt-executor-clients/src/commonTest/kotlin/ai/koog/prompt/executor/clients/retry/RetryingLLMClientTest.kt @@ -26,6 +26,7 @@ import kotlinx.datetime.Clock import kotlin.test.Test import kotlin.test.assertEquals import kotlin.test.assertFailsWith +import kotlin.test.assertSame import kotlin.time.Duration.Companion.milliseconds class RetryingLLMClientTest { @@ -65,6 +66,28 @@ class RetryingLLMClientTest { assertEquals(1, mockClient.executeCalls) } + @Test + fun testConvertLLMClientToRetryingClientWithDefaultConfig() = runTest { + val mockClient = MockLLMClient() + // when + val retryingClient = mockClient.toRetryingClient() + + // then + assertSame(actual = retryingClient.config, expected = RetryConfig.DEFAULT) + } + + @Test + fun testConvertLLMClientToRetryingClientWithCustomConfig() = runTest { + // given + val mockClient = MockLLMClient() + val retryConfig = RetryConfig(maxAttempts = 100500) + // when + val retryingClient = mockClient.toRetryingClient(retryConfig) + + // then + assertSame(actual = retryingClient.config, expected = retryConfig) + } + @Test fun testRetryOnRateLimitError() = runTest { val mockClient = MockLLMClient( From 91164952dba9b4a5f0ea9ab78d1b75969d35e225 Mon Sep 17 00:00:00 2001 From: Andrey Bragin Date: Mon, 29 Sep 2025 05:11:58 +0200 Subject: [PATCH 16/52] [a2a] Add JSON-RPC protocol foundation and module structure --- a2a/a2a-client/build.gradle.kts | 48 +++++++ .../kotlin/ai/koog/a2a/client/.gitkeep | 0 a2a/a2a-core/build.gradle.kts | 48 +++++++ .../commonMain/kotlin/ai/koog/a2a/.gitkeep | 0 .../kotlin/ai/koog/a2a/model/RequestId.kt | 17 +++ .../kotlin/ai/koog/a2a/model/Serialization.kt | 38 ++++++ .../koog/a2a/model/ModelSerializationTest.kt | 34 +++++ a2a/a2a-server/build.gradle.kts | 48 +++++++ .../kotlin/ai/koog/a2a/server/.gitkeep | 0 .../build.gradle.kts | 48 +++++++ .../transport/client/jsonrpc/http/.gitkeep | 0 .../build.gradle.kts | 48 +++++++ .../koog/a2a/transport/client/rest/.gitkeep | 0 .../build.gradle.kts | 50 +++++++ .../a2a/transport/jsonrpc/model/Messages.kt | 61 +++++++++ .../transport/jsonrpc/model/Serialization.kt | 43 ++++++ .../jsonrpc/model/JsonRpcSerializationTest.kt | 123 ++++++++++++++++++ .../a2a-transport-core-rest/build.gradle.kts | 48 +++++++ .../ai/koog/a2a/transport/rest/.gitkeep | 0 .../build.gradle.kts | 48 +++++++ .../transport/server/jsonrpc/http/.gitkeep | 0 .../build.gradle.kts | 48 +++++++ .../koog/a2a/transport/server/rest/.gitkeep | 0 koog-agents/build.gradle.kts | 11 ++ settings.gradle.kts | 10 ++ 25 files changed, 771 insertions(+) create mode 100644 a2a/a2a-client/build.gradle.kts create mode 100644 a2a/a2a-client/src/commonMain/kotlin/ai/koog/a2a/client/.gitkeep create mode 100644 a2a/a2a-core/build.gradle.kts create mode 100644 a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/.gitkeep create mode 100644 a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/model/RequestId.kt create mode 100644 a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/model/Serialization.kt create mode 100644 a2a/a2a-core/src/commonTest/kotlin/ai/koog/a2a/model/ModelSerializationTest.kt create mode 100644 a2a/a2a-server/build.gradle.kts create mode 100644 a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/.gitkeep create mode 100644 a2a/a2a-transport/a2a-transport-client-jsonrpc-http/build.gradle.kts create mode 100644 a2a/a2a-transport/a2a-transport-client-jsonrpc-http/src/commonMain/kotlin/ai/koog/a2a/transport/client/jsonrpc/http/.gitkeep create mode 100644 a2a/a2a-transport/a2a-transport-client-rest/build.gradle.kts create mode 100644 a2a/a2a-transport/a2a-transport-client-rest/src/commonMain/kotlin/ai/koog/a2a/transport/client/rest/.gitkeep create mode 100644 a2a/a2a-transport/a2a-transport-core-jsonrpc/build.gradle.kts create mode 100644 a2a/a2a-transport/a2a-transport-core-jsonrpc/src/commonMain/kotlin/ai/koog/a2a/transport/jsonrpc/model/Messages.kt create mode 100644 a2a/a2a-transport/a2a-transport-core-jsonrpc/src/commonMain/kotlin/ai/koog/a2a/transport/jsonrpc/model/Serialization.kt create mode 100644 a2a/a2a-transport/a2a-transport-core-jsonrpc/src/commonTest/kotlin/ai/koog/a2a/transport/jsonrpc/model/JsonRpcSerializationTest.kt create mode 100644 a2a/a2a-transport/a2a-transport-core-rest/build.gradle.kts create mode 100644 a2a/a2a-transport/a2a-transport-core-rest/src/commonMain/kotlin/ai/koog/a2a/transport/rest/.gitkeep create mode 100644 a2a/a2a-transport/a2a-transport-server-jsonrpc-http/build.gradle.kts create mode 100644 a2a/a2a-transport/a2a-transport-server-jsonrpc-http/src/commonMain/kotlin/ai/koog/a2a/transport/server/jsonrpc/http/.gitkeep create mode 100644 a2a/a2a-transport/a2a-transport-server-rest/build.gradle.kts create mode 100644 a2a/a2a-transport/a2a-transport-server-rest/src/commonMain/kotlin/ai/koog/a2a/transport/server/rest/.gitkeep diff --git a/a2a/a2a-client/build.gradle.kts b/a2a/a2a-client/build.gradle.kts new file mode 100644 index 0000000000..36e5aa832a --- /dev/null +++ b/a2a/a2a-client/build.gradle.kts @@ -0,0 +1,48 @@ +import ai.koog.gradle.publish.maven.Publishing.publishToMaven + +group = rootProject.group +version = rootProject.version + +plugins { + id("ai.kotlin.multiplatform") + alias(libs.plugins.kotlin.serialization) +} + +kotlin { + jvm() + + js(IR) { + browser() + } + + sourceSets { + commonMain { + dependencies { + api(libs.kotlinx.serialization.json) + api(libs.kotlinx.coroutines.core) + } + } + + commonTest { + dependencies { + implementation(kotlin("test")) + } + } + + jvmTest { + dependencies { + implementation(kotlin("test-junit5")) + } + } + + jsTest { + dependencies { + implementation(kotlin("test-js")) + } + } + } + + explicitApi() +} + +publishToMaven() diff --git a/a2a/a2a-client/src/commonMain/kotlin/ai/koog/a2a/client/.gitkeep b/a2a/a2a-client/src/commonMain/kotlin/ai/koog/a2a/client/.gitkeep new file mode 100644 index 0000000000..e69de29bb2 diff --git a/a2a/a2a-core/build.gradle.kts b/a2a/a2a-core/build.gradle.kts new file mode 100644 index 0000000000..36e5aa832a --- /dev/null +++ b/a2a/a2a-core/build.gradle.kts @@ -0,0 +1,48 @@ +import ai.koog.gradle.publish.maven.Publishing.publishToMaven + +group = rootProject.group +version = rootProject.version + +plugins { + id("ai.kotlin.multiplatform") + alias(libs.plugins.kotlin.serialization) +} + +kotlin { + jvm() + + js(IR) { + browser() + } + + sourceSets { + commonMain { + dependencies { + api(libs.kotlinx.serialization.json) + api(libs.kotlinx.coroutines.core) + } + } + + commonTest { + dependencies { + implementation(kotlin("test")) + } + } + + jvmTest { + dependencies { + implementation(kotlin("test-junit5")) + } + } + + jsTest { + dependencies { + implementation(kotlin("test-js")) + } + } + } + + explicitApi() +} + +publishToMaven() diff --git a/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/.gitkeep b/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/.gitkeep new file mode 100644 index 0000000000..e69de29bb2 diff --git a/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/model/RequestId.kt b/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/model/RequestId.kt new file mode 100644 index 0000000000..7e472d169a --- /dev/null +++ b/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/model/RequestId.kt @@ -0,0 +1,17 @@ +@file:Suppress("MissingKDocForPublicAPI") + +package ai.koog.a2a.model + +import kotlinx.serialization.Serializable + +/** + * A uniquely identifying ID for a request. + */ +@Serializable(with = RequestIdSerializer::class) +public sealed interface RequestId { + @Serializable + public data class StringId(val value: String) : RequestId + + @Serializable + public data class NumberId(val value: Long) : RequestId +} diff --git a/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/model/Serialization.kt b/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/model/Serialization.kt new file mode 100644 index 0000000000..554edabb28 --- /dev/null +++ b/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/model/Serialization.kt @@ -0,0 +1,38 @@ +package ai.koog.a2a.model + +import kotlinx.serialization.KSerializer +import kotlinx.serialization.descriptors.SerialDescriptor +import kotlinx.serialization.descriptors.buildClassSerialDescriptor +import kotlinx.serialization.encoding.Decoder +import kotlinx.serialization.encoding.Encoder +import kotlinx.serialization.json.JsonDecoder +import kotlinx.serialization.json.JsonEncoder +import kotlinx.serialization.json.JsonPrimitive +import kotlinx.serialization.json.long +import kotlinx.serialization.json.longOrNull + +internal object RequestIdSerializer : KSerializer { + override val descriptor: SerialDescriptor = buildClassSerialDescriptor("RequestId") + + override fun deserialize(decoder: Decoder): RequestId { + val jsonDecoder = decoder as? JsonDecoder ?: error("Can only deserialize JSON") + + return when (val element = jsonDecoder.decodeJsonElement()) { + is JsonPrimitive -> when { + element.isString -> RequestId.StringId(element.content) + element.longOrNull != null -> RequestId.NumberId(element.long) + else -> error("Invalid RequestId type") + } + + else -> error("Invalid RequestId format") + } + } + + override fun serialize(encoder: Encoder, value: RequestId) { + val jsonEncoder = encoder as? JsonEncoder ?: error("Can only serialize JSON") + when (value) { + is RequestId.StringId -> jsonEncoder.encodeString(value.value) + is RequestId.NumberId -> jsonEncoder.encodeLong(value.value) + } + } +} diff --git a/a2a/a2a-core/src/commonTest/kotlin/ai/koog/a2a/model/ModelSerializationTest.kt b/a2a/a2a-core/src/commonTest/kotlin/ai/koog/a2a/model/ModelSerializationTest.kt new file mode 100644 index 0000000000..3798c1af2f --- /dev/null +++ b/a2a/a2a-core/src/commonTest/kotlin/ai/koog/a2a/model/ModelSerializationTest.kt @@ -0,0 +1,34 @@ +package ai.koog.a2a.model + +import kotlinx.serialization.json.Json +import kotlin.test.Test +import kotlin.test.assertEquals + +class ModelSerializationTest { + @Suppress("PrivatePropertyName") + private val TestJson = Json + + @Test + fun testRequestIdStringId() { + val requestId = RequestId.StringId("test-id") + val requestIdJson = """"test-id"""" + + val serialized = TestJson.encodeToString(requestId) + assertEquals(requestIdJson, serialized) + + val deserialized = TestJson.decodeFromString(requestIdJson) + assertEquals(requestId, deserialized) + } + + @Test + fun testRequestIdNumberId() { + val requestId = RequestId.NumberId(123L) + val requestIdJson = """123""" + + val serialized = TestJson.encodeToString(requestId) + assertEquals(requestIdJson, serialized) + + val deserialized = TestJson.decodeFromString(requestIdJson) + assertEquals(requestId, deserialized) + } +} diff --git a/a2a/a2a-server/build.gradle.kts b/a2a/a2a-server/build.gradle.kts new file mode 100644 index 0000000000..36e5aa832a --- /dev/null +++ b/a2a/a2a-server/build.gradle.kts @@ -0,0 +1,48 @@ +import ai.koog.gradle.publish.maven.Publishing.publishToMaven + +group = rootProject.group +version = rootProject.version + +plugins { + id("ai.kotlin.multiplatform") + alias(libs.plugins.kotlin.serialization) +} + +kotlin { + jvm() + + js(IR) { + browser() + } + + sourceSets { + commonMain { + dependencies { + api(libs.kotlinx.serialization.json) + api(libs.kotlinx.coroutines.core) + } + } + + commonTest { + dependencies { + implementation(kotlin("test")) + } + } + + jvmTest { + dependencies { + implementation(kotlin("test-junit5")) + } + } + + jsTest { + dependencies { + implementation(kotlin("test-js")) + } + } + } + + explicitApi() +} + +publishToMaven() diff --git a/a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/.gitkeep b/a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/.gitkeep new file mode 100644 index 0000000000..e69de29bb2 diff --git a/a2a/a2a-transport/a2a-transport-client-jsonrpc-http/build.gradle.kts b/a2a/a2a-transport/a2a-transport-client-jsonrpc-http/build.gradle.kts new file mode 100644 index 0000000000..36e5aa832a --- /dev/null +++ b/a2a/a2a-transport/a2a-transport-client-jsonrpc-http/build.gradle.kts @@ -0,0 +1,48 @@ +import ai.koog.gradle.publish.maven.Publishing.publishToMaven + +group = rootProject.group +version = rootProject.version + +plugins { + id("ai.kotlin.multiplatform") + alias(libs.plugins.kotlin.serialization) +} + +kotlin { + jvm() + + js(IR) { + browser() + } + + sourceSets { + commonMain { + dependencies { + api(libs.kotlinx.serialization.json) + api(libs.kotlinx.coroutines.core) + } + } + + commonTest { + dependencies { + implementation(kotlin("test")) + } + } + + jvmTest { + dependencies { + implementation(kotlin("test-junit5")) + } + } + + jsTest { + dependencies { + implementation(kotlin("test-js")) + } + } + } + + explicitApi() +} + +publishToMaven() diff --git a/a2a/a2a-transport/a2a-transport-client-jsonrpc-http/src/commonMain/kotlin/ai/koog/a2a/transport/client/jsonrpc/http/.gitkeep b/a2a/a2a-transport/a2a-transport-client-jsonrpc-http/src/commonMain/kotlin/ai/koog/a2a/transport/client/jsonrpc/http/.gitkeep new file mode 100644 index 0000000000..e69de29bb2 diff --git a/a2a/a2a-transport/a2a-transport-client-rest/build.gradle.kts b/a2a/a2a-transport/a2a-transport-client-rest/build.gradle.kts new file mode 100644 index 0000000000..36e5aa832a --- /dev/null +++ b/a2a/a2a-transport/a2a-transport-client-rest/build.gradle.kts @@ -0,0 +1,48 @@ +import ai.koog.gradle.publish.maven.Publishing.publishToMaven + +group = rootProject.group +version = rootProject.version + +plugins { + id("ai.kotlin.multiplatform") + alias(libs.plugins.kotlin.serialization) +} + +kotlin { + jvm() + + js(IR) { + browser() + } + + sourceSets { + commonMain { + dependencies { + api(libs.kotlinx.serialization.json) + api(libs.kotlinx.coroutines.core) + } + } + + commonTest { + dependencies { + implementation(kotlin("test")) + } + } + + jvmTest { + dependencies { + implementation(kotlin("test-junit5")) + } + } + + jsTest { + dependencies { + implementation(kotlin("test-js")) + } + } + } + + explicitApi() +} + +publishToMaven() diff --git a/a2a/a2a-transport/a2a-transport-client-rest/src/commonMain/kotlin/ai/koog/a2a/transport/client/rest/.gitkeep b/a2a/a2a-transport/a2a-transport-client-rest/src/commonMain/kotlin/ai/koog/a2a/transport/client/rest/.gitkeep new file mode 100644 index 0000000000..e69de29bb2 diff --git a/a2a/a2a-transport/a2a-transport-core-jsonrpc/build.gradle.kts b/a2a/a2a-transport/a2a-transport-core-jsonrpc/build.gradle.kts new file mode 100644 index 0000000000..c8017f0eee --- /dev/null +++ b/a2a/a2a-transport/a2a-transport-core-jsonrpc/build.gradle.kts @@ -0,0 +1,50 @@ +import ai.koog.gradle.publish.maven.Publishing.publishToMaven + +group = rootProject.group +version = rootProject.version + +plugins { + id("ai.kotlin.multiplatform") + alias(libs.plugins.kotlin.serialization) +} + +kotlin { + jvm() + + js(IR) { + browser() + } + + sourceSets { + commonMain { + dependencies { + api(project(":a2a:a2a-core")) + + api(libs.kotlinx.serialization.json) + api(libs.kotlinx.coroutines.core) + } + } + + commonTest { + dependencies { + implementation(kotlin("test")) + } + } + + jvmTest { + dependencies { + implementation(kotlin("test-junit5")) + } + } + + jsTest { + dependencies { + implementation(kotlin("test-js")) + } + } + } + + explicitApi() +} + +publishToMaven() diff --git a/a2a/a2a-transport/a2a-transport-core-jsonrpc/src/commonMain/kotlin/ai/koog/a2a/transport/jsonrpc/model/Messages.kt b/a2a/a2a-transport/a2a-transport-core-jsonrpc/src/commonMain/kotlin/ai/koog/a2a/transport/jsonrpc/model/Messages.kt new file mode 100644 index 0000000000..7a2153f68c --- /dev/null +++ b/a2a/a2a-transport/a2a-transport-core-jsonrpc/src/commonMain/kotlin/ai/koog/a2a/transport/jsonrpc/model/Messages.kt @@ -0,0 +1,61 @@ +@file:Suppress("MissingKDocForPublicAPI") + +package ai.koog.a2a.transport.jsonrpc.model + +import ai.koog.a2a.model.RequestId +import kotlinx.serialization.EncodeDefault +import kotlinx.serialization.Serializable +import kotlinx.serialization.json.JsonElement + +/** + * Default JSON-RPC version. + */ +public const val JSONRPC_VERSION: String = "2.0" + +@Serializable +public data class JSONRPCError( + val code: Int, + val message: String, + val data: JsonElement? = null, +) + +@Serializable(with = JSONRPCMessageSerializer::class) +public sealed interface JSONRPCMessage { + public val jsonrpc: String +} + +@Serializable(with = JSONRPCResponseSerializer::class) +public sealed interface JSONRPCResponse : JSONRPCMessage + +@Serializable +public data class JSONRPCRequest( + public val id: RequestId, + val method: String, + val params: JsonElement?, + @EncodeDefault + override val jsonrpc: String = JSONRPC_VERSION, +) : JSONRPCMessage + +@Serializable +public data class JSONRPCNotification( + val method: String, + val params: JsonElement?, + @EncodeDefault + override val jsonrpc: String = JSONRPC_VERSION, +) : JSONRPCMessage + +@Serializable +public data class JSONRPCSuccessResponse( + public val id: RequestId, + public val result: JsonElement, + @EncodeDefault + override val jsonrpc: String = JSONRPC_VERSION, +) : JSONRPCResponse + +@Serializable +public data class JSONRPCErrorResponse( + public val id: RequestId?, + public val error: JSONRPCError, + @EncodeDefault + override val jsonrpc: String = JSONRPC_VERSION, +) : JSONRPCResponse diff --git a/a2a/a2a-transport/a2a-transport-core-jsonrpc/src/commonMain/kotlin/ai/koog/a2a/transport/jsonrpc/model/Serialization.kt b/a2a/a2a-transport/a2a-transport-core-jsonrpc/src/commonMain/kotlin/ai/koog/a2a/transport/jsonrpc/model/Serialization.kt new file mode 100644 index 0000000000..a10332d408 --- /dev/null +++ b/a2a/a2a-transport/a2a-transport-core-jsonrpc/src/commonMain/kotlin/ai/koog/a2a/transport/jsonrpc/model/Serialization.kt @@ -0,0 +1,43 @@ +@file:Suppress("MissingKDocForPublicAPI") + +package ai.koog.a2a.transport.jsonrpc.model + +import kotlinx.serialization.DeserializationStrategy +import kotlinx.serialization.json.Json +import kotlinx.serialization.json.JsonContentPolymorphicSerializer +import kotlinx.serialization.json.JsonElement +import kotlinx.serialization.json.jsonObject +import kotlin.collections.contains + +public val JSONRPCJson: Json = Json { + explicitNulls = false + encodeDefaults = false + ignoreUnknownKeys = true +} + +internal object JSONRPCMessageSerializer : JsonContentPolymorphicSerializer(JSONRPCMessage::class) { + override fun selectDeserializer(element: JsonElement): DeserializationStrategy { + val jsonObject = element.jsonObject + + return when { + "method" in jsonObject -> when { + "id" in jsonObject -> JSONRPCRequest.serializer() + else -> JSONRPCNotification.serializer() + } + + else -> JSONRPCResponseSerializer + } + } +} + +internal object JSONRPCResponseSerializer : JsonContentPolymorphicSerializer(JSONRPCResponse::class) { + override fun selectDeserializer(element: JsonElement): DeserializationStrategy { + val jsonObject = element.jsonObject + + return when { + "result" in jsonObject -> JSONRPCSuccessResponse.serializer() + "error" in jsonObject -> JSONRPCErrorResponse.serializer() + else -> error("Invalid JSON format") + } + } +} diff --git a/a2a/a2a-transport/a2a-transport-core-jsonrpc/src/commonTest/kotlin/ai/koog/a2a/transport/jsonrpc/model/JsonRpcSerializationTest.kt b/a2a/a2a-transport/a2a-transport-core-jsonrpc/src/commonTest/kotlin/ai/koog/a2a/transport/jsonrpc/model/JsonRpcSerializationTest.kt new file mode 100644 index 0000000000..92ccb5317a --- /dev/null +++ b/a2a/a2a-transport/a2a-transport-core-jsonrpc/src/commonTest/kotlin/ai/koog/a2a/transport/jsonrpc/model/JsonRpcSerializationTest.kt @@ -0,0 +1,123 @@ +package ai.koog.a2a.transport.jsonrpc.model + +import kotlinx.serialization.json.JsonPrimitive +import kotlin.test.Test +import kotlin.test.assertEquals + +class JsonRpcSerializationTest { + @Test + fun testJSONRPCError() { + val error = JSONRPCError(code = -32700, message = "Parse error", data = JsonPrimitive("some data")) + //language=JSON + val errorJson = """{"code":-32700,"message":"Parse error","data":"some data"}""" + + val serialized = JSONRPCJson.encodeToString(error) + assertEquals(errorJson, serialized) + + val deserialized = JSONRPCJson.decodeFromString(errorJson) + assertEquals(error, deserialized) + } + + @Test + fun testJSONRPCRequest() { + val request: JSONRPCMessage = JSONRPCRequest( + id = RequestId.NumberId(42), + method = "add", + params = null + ) + + //language=JSON + val requestJson = """{"id":42,"method":"add","jsonrpc":"2.0"}""" + + val serialized = JSONRPCJson.encodeToString(request) + assertEquals(requestJson, serialized) + + val deserialized = JSONRPCJson.decodeFromString(requestJson) + assertEquals(request, deserialized) + } + + @Test + fun testJSONRPCNotification() { + val request: JSONRPCMessage = JSONRPCNotification( + method = "update", + params = JsonPrimitive("notification-params") + ) + + //language=JSON + val notificationJson = """{"method":"update","params":"notification-params","jsonrpc":"2.0"}""" + + val serialized = JSONRPCJson.encodeToString(request) + assertEquals(notificationJson, serialized) + + val deserialized = JSONRPCJson.decodeFromString(notificationJson) + assertEquals(request, deserialized) + } + + @Test + fun testJSONRPCNotificationWithoutParams() { + val request: JSONRPCMessage = JSONRPCNotification( + method = "notify", + params = null + ) + + //language=JSON + val notificationWithoutParamsJson = """{"method":"notify","jsonrpc":"2.0"}""" + + val serialized = JSONRPCJson.encodeToString(request) + assertEquals(notificationWithoutParamsJson, serialized) + + val deserialized = JSONRPCJson.decodeFromString(notificationWithoutParamsJson) + assertEquals(request, deserialized) + } + + @Test + fun testJSONRPCSuccessResponse() { + val response: JSONRPCMessage = JSONRPCSuccessResponse( + id = RequestId.NumberId(99), + result = JsonPrimitive(100) + ) + + //language=JSON + val successResponseJson = """{"id":99,"result":100,"jsonrpc":"2.0"}""" + + val serialized = JSONRPCJson.encodeToString(response) + assertEquals(successResponseJson, serialized) + + val deserialized = JSONRPCJson.decodeFromString(successResponseJson) + assertEquals(response, deserialized) + } + + @Test + fun testJSONRPCErrorResponse() { + val response: JSONRPCMessage = JSONRPCErrorResponse( + id = RequestId.NumberId(123), + error = JSONRPCError(code = -32602, message = "Invalid params") + ) + + //language=JSON + val errorResponseJson = """{"id":123,"error":{"code":-32602,"message":"Invalid params"},"jsonrpc":"2.0"}""" + + val serialized = JSONRPCJson.encodeToString(response) + assertEquals(errorResponseJson, serialized) + + val deserialized = JSONRPCJson.decodeFromString(errorResponseJson) + assertEquals(response, deserialized) + } + + @Test + fun testJSONRPCErrorResponseWithoutId() { + val response: JSONRPCMessage = JSONRPCErrorResponse( + id = null, + error = JSONRPCError(code = -32700, message = "Parse error") + ) + + //language=JSON + val errorResponseWithoutIdJson = """{"error":{"code":-32700,"message":"Parse error"},"jsonrpc":"2.0"}""" + + val serialized = JSONRPCJson.encodeToString(response) + assertEquals(errorResponseWithoutIdJson, serialized) + + val deserialized = JSONRPCJson.decodeFromString(errorResponseWithoutIdJson) + assertEquals(response, deserialized) + } +} diff --git a/a2a/a2a-transport/a2a-transport-core-rest/build.gradle.kts b/a2a/a2a-transport/a2a-transport-core-rest/build.gradle.kts new file mode 100644 index 0000000000..36e5aa832a --- /dev/null +++ b/a2a/a2a-transport/a2a-transport-core-rest/build.gradle.kts @@ -0,0 +1,48 @@ +import ai.koog.gradle.publish.maven.Publishing.publishToMaven + +group = rootProject.group +version = rootProject.version + +plugins { + id("ai.kotlin.multiplatform") + alias(libs.plugins.kotlin.serialization) +} + +kotlin { + jvm() + + js(IR) { + browser() + } + + sourceSets { + commonMain { + dependencies { + api(libs.kotlinx.serialization.json) + api(libs.kotlinx.coroutines.core) + } + } + + commonTest { + dependencies { + implementation(kotlin("test")) + } + } + + jvmTest { + dependencies { + implementation(kotlin("test-junit5")) + } + } + + jsTest { + dependencies { + implementation(kotlin("test-js")) + } + } + } + + explicitApi() +} + +publishToMaven() diff --git a/a2a/a2a-transport/a2a-transport-core-rest/src/commonMain/kotlin/ai/koog/a2a/transport/rest/.gitkeep b/a2a/a2a-transport/a2a-transport-core-rest/src/commonMain/kotlin/ai/koog/a2a/transport/rest/.gitkeep new file mode 100644 index 0000000000..e69de29bb2 diff --git a/a2a/a2a-transport/a2a-transport-server-jsonrpc-http/build.gradle.kts b/a2a/a2a-transport/a2a-transport-server-jsonrpc-http/build.gradle.kts new file mode 100644 index 0000000000..36e5aa832a --- /dev/null +++ b/a2a/a2a-transport/a2a-transport-server-jsonrpc-http/build.gradle.kts @@ -0,0 +1,48 @@ +import ai.koog.gradle.publish.maven.Publishing.publishToMaven + +group = rootProject.group +version = rootProject.version + +plugins { + id("ai.kotlin.multiplatform") + alias(libs.plugins.kotlin.serialization) +} + +kotlin { + jvm() + + js(IR) { + browser() + } + + sourceSets { + commonMain { + dependencies { + api(libs.kotlinx.serialization.json) + api(libs.kotlinx.coroutines.core) + } + } + + commonTest { + dependencies { + implementation(kotlin("test")) + } + } + + jvmTest { + dependencies { + implementation(kotlin("test-junit5")) + } + } + + jsTest { + dependencies { + implementation(kotlin("test-js")) + } + } + } + + explicitApi() +} + +publishToMaven() diff --git a/a2a/a2a-transport/a2a-transport-server-jsonrpc-http/src/commonMain/kotlin/ai/koog/a2a/transport/server/jsonrpc/http/.gitkeep b/a2a/a2a-transport/a2a-transport-server-jsonrpc-http/src/commonMain/kotlin/ai/koog/a2a/transport/server/jsonrpc/http/.gitkeep new file mode 100644 index 0000000000..e69de29bb2 diff --git a/a2a/a2a-transport/a2a-transport-server-rest/build.gradle.kts b/a2a/a2a-transport/a2a-transport-server-rest/build.gradle.kts new file mode 100644 index 0000000000..36e5aa832a --- /dev/null +++ b/a2a/a2a-transport/a2a-transport-server-rest/build.gradle.kts @@ -0,0 +1,48 @@ +import ai.koog.gradle.publish.maven.Publishing.publishToMaven + +group = rootProject.group +version = rootProject.version + +plugins { + id("ai.kotlin.multiplatform") + alias(libs.plugins.kotlin.serialization) +} + +kotlin { + jvm() + + js(IR) { + browser() + } + + sourceSets { + commonMain { + dependencies { + api(libs.kotlinx.serialization.json) + api(libs.kotlinx.coroutines.core) + } + } + + commonTest { + dependencies { + implementation(kotlin("test")) + } + } + + jvmTest { + dependencies { + implementation(kotlin("test-junit5")) + } + } + + jsTest { + dependencies { + implementation(kotlin("test-js")) + } + } + } + + explicitApi() +} + +publishToMaven() diff --git a/a2a/a2a-transport/a2a-transport-server-rest/src/commonMain/kotlin/ai/koog/a2a/transport/server/rest/.gitkeep b/a2a/a2a-transport/a2a-transport-server-rest/src/commonMain/kotlin/ai/koog/a2a/transport/server/rest/.gitkeep new file mode 100644 index 0000000000..e69de29bb2 diff --git a/koog-agents/build.gradle.kts b/koog-agents/build.gradle.kts index c7e976ed55..08469c82b9 100644 --- a/koog-agents/build.gradle.kts +++ b/koog-agents/build.gradle.kts @@ -17,6 +17,17 @@ val excluded = setOf( ":koog-spring-boot-starter", ":koog-ktor", ":docs", + + ":a2a:a2a-core", + ":a2a:a2a-server", + ":a2a:a2a-client", + ":a2a:a2a-transport:a2a-transport-core-jsonrpc", + ":a2a:a2a-transport:a2a-transport-server-jsonrpc-http", + ":a2a:a2a-transport:a2a-transport-client-jsonrpc-http", + ":a2a:a2a-transport:a2a-transport-core-rest", + ":a2a:a2a-transport:a2a-transport-server-rest", + ":a2a:a2a-transport:a2a-transport-client-rest", + project.path, // the current project should not depend on itself ) diff --git a/settings.gradle.kts b/settings.gradle.kts index 54f3402496..69008fce6b 100755 --- a/settings.gradle.kts +++ b/settings.gradle.kts @@ -61,6 +61,16 @@ include(":embeddings:embeddings-llm") include(":rag:rag-base") include(":rag:vector-storage") +include(":a2a:a2a-core") +include(":a2a:a2a-server") +include(":a2a:a2a-client") +include(":a2a:a2a-transport:a2a-transport-core-jsonrpc") +include(":a2a:a2a-transport:a2a-transport-server-jsonrpc-http") +include(":a2a:a2a-transport:a2a-transport-client-jsonrpc-http") +include(":a2a:a2a-transport:a2a-transport-core-rest") +include(":a2a:a2a-transport:a2a-transport-server-rest") +include(":a2a:a2a-transport:a2a-transport-client-rest") + include(":koog-spring-boot-starter") include(":koog-ktor") From 5e16e0f33f2f78211f0bb7ca73e32a56969b0d8c Mon Sep 17 00:00:00 2001 From: Andrey Bragin Date: Wed, 27 Aug 2025 02:48:46 +0200 Subject: [PATCH 17/52] [a2a] Implement AgentCard and core A2A data models --- a2a/a2a-core/build.gradle.kts | 1 + .../commonMain/kotlin/ai/koog/a2a/.gitkeep | 0 .../kotlin/ai/koog/a2a/dsl/A2ADsl.kt | 7 + .../ai/koog/a2a/exceptions/Exceptions.kt | 110 ++++ .../kotlin/ai/koog/a2a/model/AgentCard.kt | 445 ++++++++++++++ .../kotlin/ai/koog/a2a/model/Artifact.kt | 29 + .../kotlin/ai/koog/a2a/model/Core.kt | 20 + .../kotlin/ai/koog/a2a/model/Message.kt | 50 ++ .../kotlin/ai/koog/a2a/model/Part.kt | 103 ++++ .../kotlin/ai/koog/a2a/model/RequestId.kt | 17 - .../kotlin/ai/koog/a2a/model/Serialization.kt | 97 ++- .../kotlin/ai/koog/a2a/model/Task.kt | 110 ++++ .../kotlin/ai/koog/a2a/model/TaskEvents.kt | 51 ++ .../a2a/model/TaskPushNotificationConfig.kt | 43 ++ .../ai/koog/a2a/transport/Serialization.kt | 38 ++ .../kotlin/ai/koog/a2a/transport/Structs.kt | 43 ++ .../a2a/model/AgentCardSerializationTest.kt | 578 ++++++++++++++++++ .../TransportSerializationTest.kt} | 4 +- .../a2a/transport/jsonrpc/model/Messages.kt | 16 +- .../jsonrpc/model/JsonRpcSerializationTest.kt | 1 + 20 files changed, 1706 insertions(+), 57 deletions(-) delete mode 100644 a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/.gitkeep create mode 100644 a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/dsl/A2ADsl.kt create mode 100644 a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/exceptions/Exceptions.kt create mode 100644 a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/model/AgentCard.kt create mode 100644 a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/model/Artifact.kt create mode 100644 a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/model/Core.kt create mode 100644 a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/model/Message.kt create mode 100644 a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/model/Part.kt delete mode 100644 a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/model/RequestId.kt create mode 100644 a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/model/Task.kt create mode 100644 a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/model/TaskEvents.kt create mode 100644 a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/model/TaskPushNotificationConfig.kt create mode 100644 a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/transport/Serialization.kt create mode 100644 a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/transport/Structs.kt create mode 100644 a2a/a2a-core/src/commonTest/kotlin/ai/koog/a2a/model/AgentCardSerializationTest.kt rename a2a/a2a-core/src/commonTest/kotlin/ai/koog/a2a/{model/ModelSerializationTest.kt => transport/TransportSerializationTest.kt} (93%) diff --git a/a2a/a2a-core/build.gradle.kts b/a2a/a2a-core/build.gradle.kts index 36e5aa832a..110474284b 100644 --- a/a2a/a2a-core/build.gradle.kts +++ b/a2a/a2a-core/build.gradle.kts @@ -20,6 +20,7 @@ kotlin { dependencies { api(libs.kotlinx.serialization.json) api(libs.kotlinx.coroutines.core) + api(libs.kotlinx.datetime) } } diff --git a/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/.gitkeep b/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/.gitkeep deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/dsl/A2ADsl.kt b/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/dsl/A2ADsl.kt new file mode 100644 index 0000000000..4570d8824c --- /dev/null +++ b/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/dsl/A2ADsl.kt @@ -0,0 +1,7 @@ +package ai.koog.a2a.dsl + +/** + * A2A DSL marker + */ +@DslMarker +public annotation class A2ADsl diff --git a/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/exceptions/Exceptions.kt b/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/exceptions/Exceptions.kt new file mode 100644 index 0000000000..d2a4f1d87e --- /dev/null +++ b/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/exceptions/Exceptions.kt @@ -0,0 +1,110 @@ +package ai.koog.a2a.exceptions + +/** + * Base class for all A2A exceptions. + */ +public sealed class A2AException( + message: String, + public val errorCode: Int +) : Exception(message) + +/** + * Server received JSON that was not well-formed. + */ +public class ParseException( + message: String = "Invalid JSON payload", +) : A2AException(message, errorCode = -32700) + +/** + * The JSON payload was valid JSON, but not a valid JSON-RPC Request object. + */ +public class InvalidRequestException( + message: String = "Invalid JSON-RPC Request", +) : A2AException(message, errorCode = -32600) + +/** + * The requested A2A RPC method does not exist or is not supported. + */ +public class MethodNotFoundException( + message: String = "Method not found", +) : A2AException(message, errorCode = -32601) + +/** + * The params provided for the method are invalid. + */ +public class InvalidParamsException( + message: String = "Invalid method parameters", +) : A2AException(message, errorCode = -32602) + +/** + * An unexpected error occurred on the server during processing. + */ +public class InternalErrorException( + message: String = "Internal server error", +) : A2AException(message, errorCode = -32603) + +/** + * Reserved for implementation-defined server exceptions. A2A-specific exceptions use this range. + */ +public open class A2AServerException( + message: String, + errorCode: Int, +) : A2AException(message, errorCode) { + init { + require(errorCode in -32000..-32099) { "Server error code must be in -32000..-32099" } + } +} + +/** + * The specified task id does not correspond to an existing or active task. + * It might be invalid, expired, or already completed and purged. + */ +public class TaskNotFoundException( + message: String = "Task not found", +) : A2AServerException(message, errorCode = -32001) + +/** + * An attempt was made to cancel a task that is not in a cancelable state. + * The task has already reached a terminal state like completed, failed, or canceled. + */ +public class TaskNotCancelableException( + message: String = "Task cannot be canceled", +) : A2AServerException(message, errorCode = -32002) + +/** + * Client attempted to use push notification features but the server agent does not support them. + * The server's AgentCard.capabilities.pushNotifications is false. + */ +public class PushNotificationNotSupportedException( + message: String = "Push Notification is not supported", +) : A2AServerException(message, errorCode = -32003) + +/** + * The requested operation or a specific aspect of it is not supported by this server agent implementation. + * This is broader than just method not found. + */ +public class UnsupportedOperationException( + message: String = "This operation is not supported", +) : A2AServerException(message, errorCode = -32004) + +/** + * A Media Type provided in the request's message.parts or implied for an artifact is not supported + * by the agent or the specific skill being invoked. + */ +public class ContentTypeNotSupportedException( + message: String = "Incompatible content types", +) : A2AServerException(message, errorCode = -32005) + +/** + * Agent generated an invalid response for the requested method. + */ +public class InvalidAgentResponseException( + message: String = "Invalid agent response type", +) : A2AServerException(message, errorCode = -32006) + +/** + * The agent does not have an Authenticated Extended Card configured. + */ +public class AuthenticatedExtendedCardNotConfiguredException( + message: String = "Authenticated Extended Card not configured", +) : A2AServerException(message, errorCode = -32007) diff --git a/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/model/AgentCard.kt b/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/model/AgentCard.kt new file mode 100644 index 0000000000..f939261586 --- /dev/null +++ b/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/model/AgentCard.kt @@ -0,0 +1,445 @@ +package ai.koog.a2a.model + +import kotlinx.serialization.EncodeDefault +import kotlinx.serialization.SerialName +import kotlinx.serialization.Serializable +import kotlinx.serialization.json.JsonElement +import kotlin.jvm.JvmInline + +/** + * The [AgentCard] is a self-describing manifest for an agent. It provides essential metadata including the agent's + * identity, capabilities, skills, supported communication methods, and security requirements. + * + * @property protocolVersion The version of the A2A protocol this agent supports. Default: "0.3.0". + * + * @property name A human-readable name for the agent. Examples: ["Recipe Agent"]. + * + * @property description A human-readable description of the agent, assisting users and other agents in understanding its purpose. + * + * Examples: ["Agent that helps users with recipes and cooking."]. + * + * @property url The preferred endpoint URL for interacting with the agent. This URL MUST support the transport specified by 'preferredTransport'. + * + * Examples: ["https://api.example.com/a2a/v1"]. + * + * @property preferredTransport The transport protocol for the preferred endpoint (the main 'url' field). If not specified, defaults to 'JSONRPC'. + * + * IMPORTANT: The transport specified here MUST be available at the main 'url'. This creates a binding between the main URL and its supported + * transport protocol. Clients should prefer this transport and URL combination when both are supported. + * + * Examples: ["JSONRPC", "GRPC", "HTTP+JSON"]. + * + * @property additionalInterfaces A list of additional supported interfaces (transport and URL combinations). This allows agents to expose multiple + * transports, potentially at different URLs. + * + * Best practices: + * - SHOULD include all supported transports + * - SHOULD include an entry matching the main 'url' and 'preferredTransport' + * - MAY reuse URLs if multiple transports are available at the same endpoint + * - MUST accurately declare the transport available at each URL. + * + * Clients can select any interface from this list based on their transport capabilities and preferences, enabling transport + * negotiation and fallback scenarios. + * + * @property iconUrl An optional URL to an icon for the agent. + * + * @property provider Information about the agent's service provider. + * + * @property version The agent's own version number. The format is defined by the provider. + * + * Examples: ["1.0.0"]. + * + * @property documentationUrl An optional URL to the agent's documentation. + * + * @property capabilities A declaration of optional capabilities supported by the agent. + * + * @property securitySchemes A declaration of the security schemes available to authorize requests. The key is the scheme name. + * Follows the OpenAPI 3.0 Security Scheme Object. + * + * @property security A list of security requirement objects that apply to all agent interactions. Each object lists security schemes that can be used. + * Follows the OpenAPI 3.0 Security Requirement Object. This list can be seen as an OR of ANDs. Each object in the list describes one possible set + * of security requirements that must be present on a request. This allows specifying, for example, "callers must either use OAuth OR an API Key AND mTLS." + * + * Examples: [{"oauth": ["read"]}, {"api-key": [], "mtls": []}]. + * + * @property defaultInputModes Default set of supported input MIME types for all skills, which can be overridden on a per-skill basis. + * + * @property defaultOutputModes Default set of supported output MIME types for all skills, which can be overridden on a per-skill basis. + * + * @property skills The set of skills, or distinct capabilities, that the agent can perform. + * + * @property supportsAuthenticatedExtendedCard If true, the agent can provide an extended agent card with additional details to authenticated users. Defaults to false. + * + * @property signatures JSON Web Signatures computed for this [AgentCard]. + */ +@Serializable +public data class AgentCard( + @EncodeDefault + public val protocolVersion: String = "0.3.0", + public val name: String, + public val description: String, + public val url: String, + @EncodeDefault + public val preferredTransport: TransportProtocol = TransportProtocol.JSONRPC, + public val additionalInterfaces: List? = null, + public val iconUrl: String? = null, + public val provider: AgentProvider? = null, + public val version: String, + public val documentationUrl: String? = null, + public val capabilities: AgentCapabilities, + public val securitySchemes: Map? = null, + public val security: List>>? = null, + public val defaultInputModes: List, + public val defaultOutputModes: List, + public val skills: List, + @EncodeDefault + public val supportsAuthenticatedExtendedCard: Boolean = false, + public val signatures: List? = null +) { + init { + additionalInterfaces?.let { interfaces -> + requireNotNull(interfaces.find { it.url == url && it.transport == preferredTransport }) { + "If additionalInterfaces are specified, they must include an entry matching the main 'url' and 'preferredTransport'." + } + } + } +} + +/** + * The transport protocol for an agent. + */ +@JvmInline +@Serializable +public value class TransportProtocol(public val value: String) { + @Suppress("MissingKDocForPublicAPI") + public companion object { + /** + * JSON-RPC protocol. + */ + public val JSONRPC: TransportProtocol = TransportProtocol("JSONRPC") + + /** + * HTTP+JSON/REST protocol. + */ + public val HTTP_JSON_REST: TransportProtocol = TransportProtocol("HTTP+JSON/REST") + + /** + * GRPC protocol. + */ + public val GRPC: TransportProtocol = TransportProtocol("GRPC") + } +} + +/** + * Declares a combination of a target URL and a transport protocol for interacting with the agent. + * This allows agents to expose the same functionality over multiple transport mechanisms. + * + * @property url The URL where this interface is available. Must be a valid absolute HTTPS URL in production. + * + * Examples: ["https://api.example.com/a2a/v1", "https://grpc.example.com/a2a", "https://rest.example.com/v1"]. + * + * @property transport The transport protocol supported at this URL. + * + * Examples: ["JSONRPC", "GRPC", "HTTP+JSON"]. + */ +@Serializable +public data class AgentInterface( + public val url: String, + public val transport: TransportProtocol +) + +/** + * Represents the service provider of an agent. + * + * @property organization The name of the agent provider's organization. + * @property url A URL for the agent provider's website or relevant documentation. + */ +@Serializable +public data class AgentProvider( + public val organization: String, + public val url: String +) + +/** + * Defines optional capabilities supported by an agent. + * + * @property streaming Indicates if the agent supports Server-Sent Events (SSE) for streaming responses. + * @property pushNotifications Indicates if the agent supports sending push notifications for asynchronous task updates. + * @property stateTransitionHistory Indicates if the agent provides a history of state transitions for a task. + * @property extensions A list of protocol extensions supported by the agent. + */ +@Serializable +public data class AgentCapabilities( + @EncodeDefault + public val streaming: Boolean = false, + @EncodeDefault + public val pushNotifications: Boolean = false, + @EncodeDefault + public val stateTransitionHistory: Boolean = false, + public val extensions: List? = null +) + +/** + * A declaration of a protocol extension supported by an Agent. + * + * @property uri The unique URI identifying the extension. + * @property description A human-readable description of how this agent uses the extension. + * @property required If true, the client must understand and comply with the extension's requirements to interact with the agent. + * @property params Optional, extension-specific configuration parameters. + */ +@Serializable +public data class AgentExtension( + public val uri: String, + public val description: String? = null, + public val required: Boolean? = null, + public val params: Map? = null +) + +/** + * Defines a security scheme that can be used to secure an agent's endpoints. + * This is a discriminated union type based on the OpenAPI 3.0 Security Scheme Object. + * + * @see [https://swagger.io/specification/#security-scheme-object] + */ +@Serializable(with = SecuritySchemeSerializer::class) +public sealed interface SecurityScheme { + /** + * The type of the security scheme, used as discriminator. + */ + public val type: String +} + +/** + * Defines a security scheme using an API key. + * + * @property description An optional description for the security scheme. + * @property in The location of the API key. + * @property name The name of the header, query, or cookie parameter to be used. + */ +@Serializable +public data class APIKeySecurityScheme( + @SerialName("in") + public val `in`: In, + public val name: String, + public val description: String? = null, +) : SecurityScheme { + @EncodeDefault + override val type: String = "apiKey" +} + +/** + * The location of the API key. + */ +@Serializable +public enum class In { + @SerialName("cookie") + Cookie, + + @SerialName("header") + Header, + + @SerialName("query") + Query +} + +/** + * Defines a security scheme using HTTP authentication. + * + * @property scheme The name of the HTTP Authentication scheme to be used in the Authorization header, + * as defined in RFC7235 (e.g., "Bearer"). This value should be registered in the IANA Authentication Scheme registry. + * @property bearerFormat A hint to the client to identify how the bearer token is formatted (e.g., "JWT"). + * This is primarily for documentation purposes. + * @property description An optional description for the security scheme. + */ +@Serializable +public data class HTTPAuthSecurityScheme( + public val scheme: String, + public val bearerFormat: String? = null, + public val description: String? = null, +) : SecurityScheme { + @EncodeDefault + override val type: String = "http" +} + +/** + * Defines a security scheme using OAuth 2.0. + * + * @property flows An object containing configuration information for the supported OAuth 2.0 flows. + * @property oauth2MetadataUrl URL to the oauth2 authorization server metadata + * [RFC8414](https://datatracker.ietf.org/doc/html/rfc8414). TLS is required. + * @property description An optional description for the security scheme. + */ +@Serializable +public data class OAuth2SecurityScheme( + public val flows: OAuthFlows, + public val oauth2MetadataUrl: String? = null, + public val description: String? = null, +) : SecurityScheme { + @EncodeDefault + override val type: String = "oauth2" +} + +/** + * Defines the configuration for the supported OAuth 2.0 flows. + * + * @property authorizationCode Configuration for the OAuth Authorization Code flow. Previously called accessCode in OpenAPI 2.0. + * @property clientCredentials Configuration for the OAuth Client Credentials flow. Previously called application in OpenAPI 2.0. + * @property implicit Configuration for the OAuth Implicit flow. + * @property password Configuration for the OAuth Resource Owner Password flow. + */ +@Serializable +public data class OAuthFlows( + public val authorizationCode: AuthorizationCodeOAuthFlow? = null, + public val clientCredentials: ClientCredentialsOAuthFlow? = null, + public val implicit: ImplicitOAuthFlow? = null, + public val password: PasswordOAuthFlow? = null +) + +/** + * Common interface for OAuth 2.0 flows. + */ +@Serializable +public sealed interface OAuthFlow { + /** + * The available scopes for the OAuth2 security scheme. A map between the scope name and a short description for it. + */ + public val scopes: Map + + /** + * The URL to be used for obtaining refresh tokens. This MUST be a URL and use TLS. + */ + public val refreshUrl: String? +} + +/** + * Defines configuration details for the OAuth 2.0 Authorization Code flow. + * + * @property authorizationUrl The authorization URL to be used for this flow. This MUST be a URL and use TLS. + * @property tokenUrl The token URL to be used for this flow. This MUST be a URL and use TLS. + */ +@Serializable +public data class AuthorizationCodeOAuthFlow( + public val authorizationUrl: String, + public val tokenUrl: String, + override val scopes: Map, + override val refreshUrl: String? = null +) : OAuthFlow + +/** + * Defines configuration details for the OAuth 2.0 Client Credentials flow. + * + * @property tokenUrl The token URL to be used for this flow. This MUST be a URL. + */ +@Serializable +public data class ClientCredentialsOAuthFlow( + public val tokenUrl: String, + override val scopes: Map, + override val refreshUrl: String? = null +) : OAuthFlow + +/** + * Defines configuration details for the OAuth 2.0 Implicit flow. + * + * @property authorizationUrl The authorization URL to be used for this flow. This MUST be a URL. + */ +@Serializable +public data class ImplicitOAuthFlow( + public val authorizationUrl: String, + override val scopes: Map, + override val refreshUrl: String? = null +) : OAuthFlow + +/** + * Defines configuration details for the OAuth 2.0 Resource Owner Password flow. + * + * @property tokenUrl The token URL to be used for this flow. This MUST be a URL. + */ +@Serializable +public data class PasswordOAuthFlow( + public val tokenUrl: String, + override val scopes: Map, + override val refreshUrl: String? = null +) : OAuthFlow + +/** + * Defines a security scheme using OpenID Connect. + * + * @property openIdConnectUrl The OpenID Connect Discovery URL for the OIDC provider's metadata. + * @property description An optional description for the security scheme. + */ +@Serializable +public data class OpenIdConnectSecurityScheme( + public val openIdConnectUrl: String, + public val description: String? = null, +) : SecurityScheme { + @EncodeDefault + override val type: String = "openIdConnect" +} + +/** + * Defines a security scheme using mTLS authentication. + * + * @property description An optional description for the security scheme. + */ +@Serializable +public data class MutualTLSSecurityScheme( + public val description: String? = null, +) : SecurityScheme { + @EncodeDefault + override val type: String = "mutualTLS" +} + +/** + * Represents a distinct capability or function that an agent can perform. + * + * @property id A unique identifier for the agent's skill. + * + * @property name A human-readable name for the skill. + * + * @property description A detailed description of the skill, intended to help clients or users understand its purpose and functionality. + * + * @property tags A set of keywords describing the skill's capabilities. + * + * Examples: ["cooking", "customer support", "billing"]. + * + * @property examples Example prompts or scenarios that this skill can handle. Provides a hint to the client on how to use the skill. + * + * Examples: ["I need a recipe for bread"]. + * + * @property inputModes The set of supported input MIME types for this skill, overriding the agent's defaults. + * + * @property outputModes The set of supported output MIME types for this skill, overriding the agent's defaults. + * + * @property security Security schemes necessary for the agent to leverage this skill. As in the overall AgentCard.security, + * this list represents a logical OR of security requirement objects. Each object is a set of security schemes that must be + * used together (a logical AND). + * + * Examples: [{"google": ["oidc"]}]. + */ +@Serializable +public data class AgentSkill( + public val id: String, + public val name: String, + public val description: String, + public val tags: List, + public val examples: List? = null, + public val inputModes: List? = null, + public val outputModes: List? = null, + public val security: List>>? = null +) + +/** + * AgentCardSignature represents a JWS signature of an AgentCard. + * This follows the JSON format of an RFC 7515 JSON Web Signature (JWS). + * + * @property protected The protected JWS header for the signature. This is a Base64url-encoded JSON object, as per RFC 7515. + * @property signature The computed signature, Base64url-encoded. + * @property header The unprotected JWS header values. + */ +@Serializable +public data class AgentCardSignature( + @SerialName("protected") + public val `protected`: String, + public val signature: String, + public val header: Map? = null +) diff --git a/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/model/Artifact.kt b/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/model/Artifact.kt new file mode 100644 index 0000000000..d00c893212 --- /dev/null +++ b/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/model/Artifact.kt @@ -0,0 +1,29 @@ +package ai.koog.a2a.model + +import kotlinx.serialization.EncodeDefault +import kotlinx.serialization.Serializable +import kotlinx.serialization.json.JsonObject +import kotlin.uuid.ExperimentalUuidApi +import kotlin.uuid.Uuid + +/** + * Represents a file, data structure, or other resource generated by an agent during a task. + * + * @property artifactId A unique identifier (e.g. UUID) for the artifact within the scope of the task. + * @property name An optional, human-readable name for the artifact. + * @property description An optional, human-readable description of the artifact. + * @property parts A list of content parts that make up the artifact. + * @property extensions Optional URIs of extensions that are relevant to this artifact. + * @property metadata Optional metadata for extensions. The key is an extension-specific identifier. + */ +@OptIn(ExperimentalUuidApi::class) +@Serializable +public data class Artifact( + @EncodeDefault + public val artifactId: String = Uuid.random().toString(), + public val name: String? = null, + public val description: String? = null, + public val parts: List, + public val extensions: List? = null, + public val metadata: JsonObject? = null, +) diff --git a/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/model/Core.kt b/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/model/Core.kt new file mode 100644 index 0000000000..9d4dd14391 --- /dev/null +++ b/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/model/Core.kt @@ -0,0 +1,20 @@ +package ai.koog.a2a.model + +import kotlinx.serialization.Serializable + +/** + * Base interface for events. + */ +@Serializable(with = EventSerializer::class) +public sealed interface Event { + /** + * The type used as discriminator. + */ + public val kind: String +} + +/** + * Base interface for communication units, such as messages or tasks. + */ +@Serializable(with = CommunicationSerializer::class) +public sealed interface Communication : Event diff --git a/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/model/Message.kt b/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/model/Message.kt new file mode 100644 index 0000000000..0e3656d962 --- /dev/null +++ b/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/model/Message.kt @@ -0,0 +1,50 @@ +package ai.koog.a2a.model + +import kotlinx.serialization.EncodeDefault +import kotlinx.serialization.SerialName +import kotlinx.serialization.Serializable +import kotlinx.serialization.json.JsonObject +import kotlin.uuid.ExperimentalUuidApi +import kotlin.uuid.Uuid + +/** + * Message role. + */ +@Serializable +public enum class Role { + @SerialName("user") + User, + + @SerialName("agent") + Agent +} + +/** + * Represents a single message in the conversation between a user and an agent. + * + * @property messageId A unique identifier for the message, typically a UUID, generated by the sender. + * @property role Identifies the sender of the message. `user` for the client, `agent` for the service. + * @property parts An array of content parts that form the message body. A message can be + * composed of multiple parts of different types (e.g., text and files). + * @property extensions The URIs of extensions that are relevant to this message. + * @property taskId The ID of the task this message is part of. Can be omitted for the first message of a new task. + * @property referenceTaskIds A list of other task IDs that this message references for additional context. + * @property contextId The context ID for this message, used to group related interactions. + * @property metadata Optional metadata for extensions. The key is an extension-specific identifier. + */ +@OptIn(ExperimentalUuidApi::class) +@Serializable +public data class Message( + @EncodeDefault + public val messageId: String = Uuid.random().toString(), + public val role: Role, + public val parts: List, + public val extensions: List? = null, + public val taskId: String? = null, + public val referenceTaskIds: List? = null, + public val contextId: String? = null, + public val metadata: JsonObject? = null, +) : Communication { + @EncodeDefault + override val kind: String = "message" +} diff --git a/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/model/Part.kt b/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/model/Part.kt new file mode 100644 index 0000000000..abbe4ac047 --- /dev/null +++ b/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/model/Part.kt @@ -0,0 +1,103 @@ +package ai.koog.a2a.model + +import kotlinx.serialization.EncodeDefault +import kotlinx.serialization.Serializable +import kotlinx.serialization.json.JsonObject + +/** + * Represents a part of a message or artifact. + */ +@Serializable(with = PartSerializer::class) +public sealed interface Part { + /** + * The type of the part, used as discriminator. + */ + public val kind: String + + /** + * Optional metadata associated with this part. + */ + public val metadata: JsonObject? +} + +/** + * Represents a text part. + * + * @property text The string content of the text part. + */ +@Serializable +public data class TextPart( + public val text: String, + override val metadata: JsonObject? = null, +) : Part { + @EncodeDefault + override val kind: String = "text" +} + +/** + * Represents a file part. The file content can be provided either directly as bytes or as a URI. + * + * @property file The file content. + */ +@Serializable +public data class FilePart( + public val file: File, + override val metadata: JsonObject? = null, +) : Part { + @EncodeDefault + override val kind: String = "file" +} + +/** + * Represents a file within a part. + */ +@Serializable(with = FileSerializer::class) +public sealed interface File { + /** + * An optional name for the file (e.g., "document.pdf"). + */ + public val name: String? + + /** + * An optional MIME type of the file (e.g., "application/pdf"). + */ + public val mimeType: String? +} + +/** + * Represents a file with its content provided directly as a base64-encoded string. + * + * @property bytes The base64-encoded content of the file. + */ +@Serializable +public data class FileWithBytes( + public val bytes: String, + override val name: String? = null, + override val mimeType: String? = null, +) : File + +/** + * Represents a file with its content located at a specific URI. + * + * @property uri A URL pointing to the file's content. + */ +@Serializable +public data class FileWithUri( + public val uri: String, + override val name: String? = null, + override val mimeType: String? = null, +) : File + +/** + * Represents a structured data part (e.g., JSON). + * + * @property data The structured data content. + */ +@Serializable +public data class DataPart( + public val data: JsonObject, + override val metadata: JsonObject? = null, +) : Part { + @EncodeDefault + override val kind: String = "data" +} diff --git a/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/model/RequestId.kt b/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/model/RequestId.kt deleted file mode 100644 index 7e472d169a..0000000000 --- a/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/model/RequestId.kt +++ /dev/null @@ -1,17 +0,0 @@ -@file:Suppress("MissingKDocForPublicAPI") - -package ai.koog.a2a.model - -import kotlinx.serialization.Serializable - -/** - * A uniquely identifying ID for a request. - */ -@Serializable(with = RequestIdSerializer::class) -public sealed interface RequestId { - @Serializable - public data class StringId(val value: String) : RequestId - - @Serializable - public data class NumberId(val value: Long) : RequestId -} diff --git a/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/model/Serialization.kt b/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/model/Serialization.kt index 554edabb28..5d9ccde594 100644 --- a/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/model/Serialization.kt +++ b/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/model/Serialization.kt @@ -1,38 +1,75 @@ package ai.koog.a2a.model -import kotlinx.serialization.KSerializer -import kotlinx.serialization.descriptors.SerialDescriptor -import kotlinx.serialization.descriptors.buildClassSerialDescriptor -import kotlinx.serialization.encoding.Decoder -import kotlinx.serialization.encoding.Encoder -import kotlinx.serialization.json.JsonDecoder -import kotlinx.serialization.json.JsonEncoder -import kotlinx.serialization.json.JsonPrimitive -import kotlinx.serialization.json.long -import kotlinx.serialization.json.longOrNull - -internal object RequestIdSerializer : KSerializer { - override val descriptor: SerialDescriptor = buildClassSerialDescriptor("RequestId") - - override fun deserialize(decoder: Decoder): RequestId { - val jsonDecoder = decoder as? JsonDecoder ?: error("Can only deserialize JSON") - - return when (val element = jsonDecoder.decodeJsonElement()) { - is JsonPrimitive -> when { - element.isString -> RequestId.StringId(element.content) - element.longOrNull != null -> RequestId.NumberId(element.long) - else -> error("Invalid RequestId type") - } - - else -> error("Invalid RequestId format") +import kotlinx.serialization.DeserializationStrategy +import kotlinx.serialization.json.JsonContentPolymorphicSerializer +import kotlinx.serialization.json.JsonElement +import kotlinx.serialization.json.jsonObject +import kotlinx.serialization.json.jsonPrimitive + +internal object SecuritySchemeSerializer : JsonContentPolymorphicSerializer(SecurityScheme::class) { + override fun selectDeserializer(element: JsonElement): DeserializationStrategy { + val jsonObject = element.jsonObject + val type = jsonObject["type"]?.jsonPrimitive?.content ?: error("Missing 'type' field in SecurityScheme") + + return when (type) { + "apiKey" -> APIKeySecurityScheme.serializer() + "http" -> HTTPAuthSecurityScheme.serializer() + "oauth2" -> OAuth2SecurityScheme.serializer() + "openIdConnect" -> OpenIdConnectSecurityScheme.serializer() + "mutualTLS" -> MutualTLSSecurityScheme.serializer() + else -> error("Unknown SecurityScheme type: $type") + } + } +} + +internal object PartSerializer : JsonContentPolymorphicSerializer(Part::class) { + override fun selectDeserializer(element: JsonElement): DeserializationStrategy { + val jsonObject = element.jsonObject + val kind = jsonObject["kind"]?.jsonPrimitive?.content ?: error("Missing 'kind' field in Part") + + return when (kind) { + "text" -> TextPart.serializer() + "file" -> FilePart.serializer() + "data" -> DataPart.serializer() + else -> error("Unknown Part kind: $kind") } } +} + +internal object FileSerializer : JsonContentPolymorphicSerializer(File::class) { + override fun selectDeserializer(element: JsonElement): DeserializationStrategy { + val jsonObject = element.jsonObject + + return when { + "bytes" in jsonObject -> FileWithBytes.serializer() + "uri" in jsonObject -> FileWithUri.serializer() + else -> error("Unknown File type") + } + } +} + +internal object EventSerializer : JsonContentPolymorphicSerializer(Event::class) { + override fun selectDeserializer(element: JsonElement): DeserializationStrategy { + val jsonObject = element.jsonObject + val kind = jsonObject["kind"]?.jsonPrimitive?.content ?: error("Missing 'kind' field in Event") + + return when (kind) { + "status-update" -> TaskStatusUpdateEvent.serializer() + "artifact-update" -> TaskArtifactUpdateEvent.serializer() + else -> CommunicationSerializer + } + } +} + +internal object CommunicationSerializer : JsonContentPolymorphicSerializer(Communication::class) { + override fun selectDeserializer(element: JsonElement): DeserializationStrategy { + val jsonObject = element.jsonObject + val kind = jsonObject["kind"]?.jsonPrimitive?.content ?: error("Missing 'kind' field in Communication") - override fun serialize(encoder: Encoder, value: RequestId) { - val jsonEncoder = encoder as? JsonEncoder ?: error("Can only serialize JSON") - when (value) { - is RequestId.StringId -> jsonEncoder.encodeString(value.value) - is RequestId.NumberId -> jsonEncoder.encodeLong(value.value) + return when (kind) { + "task" -> Task.serializer() + "message" -> Message.serializer() + else -> error("Unknown kind: $kind") } } } diff --git a/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/model/Task.kt b/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/model/Task.kt new file mode 100644 index 0000000000..fb8fff7702 --- /dev/null +++ b/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/model/Task.kt @@ -0,0 +1,110 @@ +package ai.koog.a2a.model + +import kotlinx.datetime.Instant +import kotlinx.serialization.EncodeDefault +import kotlinx.serialization.SerialName +import kotlinx.serialization.Serializable +import kotlinx.serialization.json.JsonObject +import kotlin.uuid.ExperimentalUuidApi +import kotlin.uuid.Uuid + +/** + * Represents a single, stateful operation or conversation between a client and an agent. + * + * @property id A unique identifier (e.g. UUID) for the task, generated by the server for a new task. + * @property contextId A server-generated unique identifier (e.g. UUID) for maintaining context across multiple related tasks or interactions. + * @property status The current status of the task, including its state and a descriptive message. + * @property history An array of messages exchanged during the task, representing the conversation history. + * @property artifacts A collection of artifacts generated by the agent during the execution of the task. + * @property metadata Optional metadata for extensions. The key is an extension-specific identifier. + */ +@OptIn(ExperimentalUuidApi::class) +@Serializable +public data class Task( + @EncodeDefault + public val id: String = Uuid.random().toString(), + public val contextId: String, + public val status: TaskStatus, + public val history: List? = null, + public val artifacts: List? = null, + public val metadata: JsonObject? = null, +) : Communication { + @EncodeDefault + override val kind: String = "task" +} + +/** + * Represents the status of a task at a specific point in time. + * + * @property state The current state of the task's lifecycle. + * @property message An optional, human-readable message providing more details about the current status. + * @property timestamp Timestamp indicating when this status was recorded. + */ +@Serializable +public data class TaskStatus( + public val state: TaskState, + public val message: Message? = null, + public val timestamp: Instant? = null, +) + +/** + * Defines the lifecycle states of a Task. + * + * @property terminal Indicates whether this is a terminal state. + */ +@Serializable +public enum class TaskState(public val terminal: Boolean) { + /** + * The task has been submitted and is awaiting execution. + */ + @SerialName("submitted") + Submitted(terminal = false), + + /** + * The agent is actively working on the task. + */ + @SerialName("working") + Working(terminal = false), + + /** + * The task is paused and waiting for input from the user. + */ + @SerialName("input-required") + InputRequired(terminal = false), + + /** + * The task has been successfully completed. + */ + @SerialName("completed") + Completed(terminal = true), + + /** + * The task has been canceled by the user. + */ + @SerialName("canceled") + Canceled(terminal = true), + + /** + * The task failed due to an error during execution. + */ + @SerialName("failed") + Failed(terminal = true), + + /** + * The task was rejected by the agent and was not started. + */ + @SerialName("rejected") + Rejected(terminal = true), + + /** + * The task requires authentication to proceed. + */ + @SerialName("auth-required") + AuthRequired(terminal = false), + + /* + * The task is in an unknown or indeterminate state. + */ + @SerialName("unknown") + Unknown(terminal = false), +} diff --git a/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/model/TaskEvents.kt b/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/model/TaskEvents.kt new file mode 100644 index 0000000000..b589b01dd0 --- /dev/null +++ b/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/model/TaskEvents.kt @@ -0,0 +1,51 @@ +package ai.koog.a2a.model + +import kotlinx.serialization.EncodeDefault +import kotlinx.serialization.Serializable +import kotlinx.serialization.json.JsonObject + +/** + * An event sent by the agent to notify the client of a change in a task's status. + * This is typically used in streaming or subscription models. + * + * @property taskId The ID of the task that was updated. + * @property contextId The context ID associated with the task. + * @property status The new status of the task. + * @property final If true, this is the final event in the stream for this interaction. + * @property metadata Optional metadata for extensions. + */ +@Serializable +public data class TaskStatusUpdateEvent( + public val taskId: String, + public val contextId: String, + public val status: TaskStatus, + public val final: Boolean, + public val metadata: JsonObject? = null, +) : Event { + @EncodeDefault + override val kind: String = "status-update" +} + +/** + * An event sent by the agent to notify the client that an artifact has been + * generated or updated. This is typically used in streaming models. + * + * @property taskId The ID of the task this artifact belongs to. + * @property contextId The context ID associated with the task. + * @property artifact The artifact that was generated or updated. + * @property append If true, the content of this artifact should be appended to a previously sent artifact with the same ID. + * @property lastChunk If true, this is the final chunk of the artifact. + * @property metadata Optional metadata for extensions. + */ +@Serializable +public data class TaskArtifactUpdateEvent( + public val taskId: String, + public val contextId: String, + public val artifact: Artifact, + public val append: Boolean? = null, + public val lastChunk: Boolean? = null, + public val metadata: JsonObject? = null, +) : Event { + @EncodeDefault + override val kind: String = "artifact-update" +} diff --git a/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/model/TaskPushNotificationConfig.kt b/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/model/TaskPushNotificationConfig.kt new file mode 100644 index 0000000000..f41cbbfe6d --- /dev/null +++ b/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/model/TaskPushNotificationConfig.kt @@ -0,0 +1,43 @@ +package ai.koog.a2a.model + +import kotlinx.serialization.Serializable + +/** + * A container associating a push notification configuration with a specific task. + * + * @property taskId The unique identifier (e.g. UUID) of the task. + * @property pushNotificationConfig The push notification configuration for this task. + */ +@Serializable +public data class TaskPushNotificationConfig( + public val taskId: String, + public val pushNotificationConfig: PushNotificationConfig, +) + +/** + * Defines the configuration for setting up push notifications for task updates. + * + * @property id A unique identifier (e.g. UUID) for the push notification configuration, set by the client to support multiple notification callbacks. + * @property url The callback URL where the agent should send push notifications. + * @property token A unique token for this task or session to validate incoming push notifications. + * @property authentication Optional authentication details for the agent to use when calling the notification URL. + */ +@Serializable +public data class PushNotificationConfig( + public val id: String? = null, + public val url: String, + public val token: String? = null, + public val authentication: PushNotificationAuthenticationInfo? = null, +) + +/** + * Defines authentication details for a push notification endpoint. + * + * @property schemes A list of supported authentication schemes (e.g., 'Basic', 'Bearer'). + * @property credentials Optional credentials required by the push notification endpoint. + */ +@Serializable +public data class PushNotificationAuthenticationInfo( + public val schemes: List, + public val credentials: String? = null, +) diff --git a/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/transport/Serialization.kt b/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/transport/Serialization.kt new file mode 100644 index 0000000000..ba8f1c7c0b --- /dev/null +++ b/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/transport/Serialization.kt @@ -0,0 +1,38 @@ +package ai.koog.a2a.transport + +import kotlinx.serialization.KSerializer +import kotlinx.serialization.descriptors.SerialDescriptor +import kotlinx.serialization.descriptors.buildClassSerialDescriptor +import kotlinx.serialization.encoding.Decoder +import kotlinx.serialization.encoding.Encoder +import kotlinx.serialization.json.JsonDecoder +import kotlinx.serialization.json.JsonEncoder +import kotlinx.serialization.json.JsonPrimitive +import kotlinx.serialization.json.long +import kotlinx.serialization.json.longOrNull + +internal object RequestIdSerializer : KSerializer { + override val descriptor: SerialDescriptor = buildClassSerialDescriptor("RequestId") + + override fun deserialize(decoder: Decoder): RequestId { + val jsonDecoder = decoder as? JsonDecoder ?: error("Can only deserialize JSON") + + return when (val element = jsonDecoder.decodeJsonElement()) { + is JsonPrimitive -> when { + element.isString -> RequestId.StringId(element.content) + element.longOrNull != null -> RequestId.NumberId(element.long) + else -> error("Invalid RequestId type") + } + + else -> error("Invalid RequestId format") + } + } + + override fun serialize(encoder: Encoder, value: RequestId) { + val jsonEncoder = encoder as? JsonEncoder ?: error("Can only serialize JSON") + when (value) { + is RequestId.StringId -> jsonEncoder.encodeString(value.value) + is RequestId.NumberId -> jsonEncoder.encodeLong(value.value) + } + } +} diff --git a/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/transport/Structs.kt b/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/transport/Structs.kt new file mode 100644 index 0000000000..655a57657d --- /dev/null +++ b/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/transport/Structs.kt @@ -0,0 +1,43 @@ +package ai.koog.a2a.transport + +import kotlinx.serialization.Serializable + +/** + * A uniquely identifying ID for a request. + */ +@Serializable(with = RequestIdSerializer::class) +public sealed interface RequestId { + /** + * A string representation of the ID. + */ + @Serializable + public data class StringId(val value: String) : RequestId + + /** + * A numeric representation of the ID. + */ + @Serializable + public data class NumberId(val value: Long) : RequestId +} + +/** + * Represents a request containing a unique identifier. + * + * @property id The unique identifier for the request. + * @property data The data payload of the request. + */ +public class Request( + public val id: RequestId, + public val data: T, +) + +/** + * Represents a response associated with a request identifier. + * + * @property id The unique identifier for the request associated with this response. + * @property data The response data payload. + */ +public class Response( + public val id: RequestId, + public val data: T, +) diff --git a/a2a/a2a-core/src/commonTest/kotlin/ai/koog/a2a/model/AgentCardSerializationTest.kt b/a2a/a2a-core/src/commonTest/kotlin/ai/koog/a2a/model/AgentCardSerializationTest.kt new file mode 100644 index 0000000000..6b3271f5fa --- /dev/null +++ b/a2a/a2a-core/src/commonTest/kotlin/ai/koog/a2a/model/AgentCardSerializationTest.kt @@ -0,0 +1,578 @@ +package ai.koog.a2a.model + +import kotlinx.serialization.json.Json +import kotlin.test.Test +import kotlin.test.assertEquals + +class AgentCardSerializationTest { + @Suppress("PrivatePropertyName") + private val TestJson = Json { + prettyPrint = true + encodeDefaults = false + } + + @Test + fun testMinimalAgentCardSerialization() { + val agentCard = AgentCard( + name = "Test Agent", + description = "A test agent", + url = "https://api.example.com/a2a", + version = "1.0.0", + capabilities = AgentCapabilities(), + defaultInputModes = listOf("text/plain"), + defaultOutputModes = listOf("text/plain"), + skills = listOf( + AgentSkill( + id = "test-skill", + name = "Test Skill", + description = "A test skill", + tags = listOf("test") + ) + ) + ) + + //language=JSON + val expectedJson = """ + { + "protocolVersion": "0.3.0", + "name": "Test Agent", + "description": "A test agent", + "url": "https://api.example.com/a2a", + "preferredTransport": "JSONRPC", + "version": "1.0.0", + "capabilities": { + "streaming": false, + "pushNotifications": false, + "stateTransitionHistory": false + }, + "defaultInputModes": [ + "text/plain" + ], + "defaultOutputModes": [ + "text/plain" + ], + "skills": [ + { + "id": "test-skill", + "name": "Test Skill", + "description": "A test skill", + "tags": [ + "test" + ] + } + ], + "supportsAuthenticatedExtendedCard": false + } + """.trimIndent() + + val actualJson = TestJson.encodeToString(agentCard) + assertEquals(expectedJson, actualJson) + + // Test deserialization + val deserializedCard = TestJson.decodeFromString(actualJson) + assertEquals(agentCard, deserializedCard) + } + + @Test + fun testFullAgentCardSerialization() { + val agentCard = AgentCard( + name = "GeoSpatial Route Planner Agent", + description = "Provides advanced route planning, traffic analysis, and custom map generation services. This agent can calculate optimal routes, estimate travel times considering real-time traffic, and create personalized maps with points of interest.", + url = "https://georoute-agent.example.com/a2a/v1", + additionalInterfaces = listOf( + AgentInterface( + url = "https://georoute-agent.example.com/a2a/v1", + transport = TransportProtocol.JSONRPC + ), + AgentInterface( + url = "https://georoute-agent.example.com/a2a/grpc", + transport = TransportProtocol.GRPC + ), + AgentInterface( + url = "https://georoute-agent.example.com/a2a/json", + transport = TransportProtocol.HTTP_JSON_REST + ) + ), + provider = AgentProvider( + organization = "Example Geo Services Inc.", + url = "https://www.examplegeoservices.com" + ), + iconUrl = "https://georoute-agent.example.com/icon.png", + version = "1.2.0", + documentationUrl = "https://docs.examplegeoservices.com/georoute-agent/api", + capabilities = AgentCapabilities( + streaming = true, + pushNotifications = true, + stateTransitionHistory = false + ), + securitySchemes = mapOf( + "google" to OpenIdConnectSecurityScheme( + openIdConnectUrl = "https://accounts.google.com/.well-known/openid-configuration" + ) + ), + security = listOf( + mapOf("google" to listOf("openid", "profile", "email")) + ), + defaultInputModes = listOf("application/json", "text/plain"), + defaultOutputModes = listOf("application/json", "image/png"), + skills = listOf( + AgentSkill( + id = "route-optimizer-traffic", + name = "Traffic-Aware Route Optimizer", + description = "Calculates the optimal driving route between two or more locations, taking into account real-time traffic conditions, road closures, and user preferences (e.g., avoid tolls, prefer highways).", + tags = listOf("maps", "routing", "navigation", "directions", "traffic"), + examples = listOf( + "Plan a route from '1600 Amphitheatre Parkway, Mountain View, CA' to 'San Francisco International Airport' avoiding tolls.", + "{\"origin\": {\"lat\": 37.422, \"lng\": -122.084}, \"destination\": {\"lat\": 37.7749, \"lng\": -122.4194}, \"preferences\": [\"avoid_ferries\"]}" + ), + inputModes = listOf("application/json", "text/plain"), + outputModes = listOf( + "application/json", + "application/vnd.geo+json", + "text/html" + ) + ), + AgentSkill( + id = "custom-map-generator", + name = "Personalized Map Generator", + description = "Creates custom map images or interactive map views based on user-defined points of interest, routes, and style preferences. Can overlay data layers.", + tags = listOf("maps", "customization", "visualization", "cartography"), + examples = listOf( + "Generate a map of my upcoming road trip with all planned stops highlighted.", + "Show me a map visualizing all coffee shops within a 1-mile radius of my current location." + ), + inputModes = listOf("application/json"), + outputModes = listOf( + "image/png", + "image/jpeg", + "application/json", + "text/html" + ) + ) + ), + supportsAuthenticatedExtendedCard = true, + signatures = listOf( + AgentCardSignature( + `protected` = "eyJhbGciOiJFUzI1NiIsInR5cCI6IkpPU0UiLCJraWQiOiJrZXktMSIsImprdSI6Imh0dHBzOi8vZXhhbXBsZS5jb20vYWdlbnQvandrcy5qc29uIn0", + signature = "QFdkNLNszlGj3z3u0YQGt_T9LixY3qtdQpZmsTdDHDe3fXV9y9-B3m2-XgCpzuhiLt8E0tV6HXoZKHv4GtHgKQ" + ) + ) + ) + + //language=JSON + val expectedJson = """ + { + "protocolVersion": "0.3.0", + "name": "GeoSpatial Route Planner Agent", + "description": "Provides advanced route planning, traffic analysis, and custom map generation services. This agent can calculate optimal routes, estimate travel times considering real-time traffic, and create personalized maps with points of interest.", + "url": "https://georoute-agent.example.com/a2a/v1", + "preferredTransport": "JSONRPC", + "additionalInterfaces": [ + { + "url": "https://georoute-agent.example.com/a2a/v1", + "transport": "JSONRPC" + }, + { + "url": "https://georoute-agent.example.com/a2a/grpc", + "transport": "GRPC" + }, + { + "url": "https://georoute-agent.example.com/a2a/json", + "transport": "HTTP+JSON/REST" + } + ], + "iconUrl": "https://georoute-agent.example.com/icon.png", + "provider": { + "organization": "Example Geo Services Inc.", + "url": "https://www.examplegeoservices.com" + }, + "version": "1.2.0", + "documentationUrl": "https://docs.examplegeoservices.com/georoute-agent/api", + "capabilities": { + "streaming": true, + "pushNotifications": true, + "stateTransitionHistory": false + }, + "securitySchemes": { + "google": { + "openIdConnectUrl": "https://accounts.google.com/.well-known/openid-configuration", + "type": "openIdConnect" + } + }, + "security": [ + { + "google": [ + "openid", + "profile", + "email" + ] + } + ], + "defaultInputModes": [ + "application/json", + "text/plain" + ], + "defaultOutputModes": [ + "application/json", + "image/png" + ], + "skills": [ + { + "id": "route-optimizer-traffic", + "name": "Traffic-Aware Route Optimizer", + "description": "Calculates the optimal driving route between two or more locations, taking into account real-time traffic conditions, road closures, and user preferences (e.g., avoid tolls, prefer highways).", + "tags": [ + "maps", + "routing", + "navigation", + "directions", + "traffic" + ], + "examples": [ + "Plan a route from '1600 Amphitheatre Parkway, Mountain View, CA' to 'San Francisco International Airport' avoiding tolls.", + "{\"origin\": {\"lat\": 37.422, \"lng\": -122.084}, \"destination\": {\"lat\": 37.7749, \"lng\": -122.4194}, \"preferences\": [\"avoid_ferries\"]}" + ], + "inputModes": [ + "application/json", + "text/plain" + ], + "outputModes": [ + "application/json", + "application/vnd.geo+json", + "text/html" + ] + }, + { + "id": "custom-map-generator", + "name": "Personalized Map Generator", + "description": "Creates custom map images or interactive map views based on user-defined points of interest, routes, and style preferences. Can overlay data layers.", + "tags": [ + "maps", + "customization", + "visualization", + "cartography" + ], + "examples": [ + "Generate a map of my upcoming road trip with all planned stops highlighted.", + "Show me a map visualizing all coffee shops within a 1-mile radius of my current location." + ], + "inputModes": [ + "application/json" + ], + "outputModes": [ + "image/png", + "image/jpeg", + "application/json", + "text/html" + ] + } + ], + "supportsAuthenticatedExtendedCard": true, + "signatures": [ + { + "protected": "eyJhbGciOiJFUzI1NiIsInR5cCI6IkpPU0UiLCJraWQiOiJrZXktMSIsImprdSI6Imh0dHBzOi8vZXhhbXBsZS5jb20vYWdlbnQvandrcy5qc29uIn0", + "signature": "QFdkNLNszlGj3z3u0YQGt_T9LixY3qtdQpZmsTdDHDe3fXV9y9-B3m2-XgCpzuhiLt8E0tV6HXoZKHv4GtHgKQ" + } + ] + } + """.trimIndent() + + val actualJson = TestJson.encodeToString(agentCard) + + assertEquals(expectedJson, actualJson) + } + + @Test + fun testTransportProtocolSerialization() { + val jsonRpc = TransportProtocol.JSONRPC + val httpJson = TransportProtocol.HTTP_JSON_REST + val grpc = TransportProtocol.GRPC + val custom = TransportProtocol("CUSTOM") + + assertEquals("\"JSONRPC\"", TestJson.encodeToString(jsonRpc)) + assertEquals("\"HTTP+JSON/REST\"", TestJson.encodeToString(httpJson)) + assertEquals("\"GRPC\"", TestJson.encodeToString(grpc)) + assertEquals("\"CUSTOM\"", TestJson.encodeToString(custom)) + + // Test deserialization + assertEquals(jsonRpc, TestJson.decodeFromString("\"JSONRPC\"")) + assertEquals(httpJson, TestJson.decodeFromString("\"HTTP+JSON/REST\"")) + assertEquals(grpc, TestJson.decodeFromString("\"GRPC\"")) + assertEquals(custom, TestJson.decodeFromString("\"CUSTOM\"")) + } + + @Test + fun testSecuritySchemeSerialization() { + // API Key Security Scheme + val apiKeyScheme = APIKeySecurityScheme( + `in` = In.Header, + name = "Authorization", + description = "Bearer token" + ) + //language=JSON + val apiKeyJson = """ + { + "in": "header", + "name": "Authorization", + "description": "Bearer token", + "type": "apiKey" + } + """.trimIndent() + assertEquals(apiKeyJson, TestJson.encodeToString(apiKeyScheme)) + + // HTTP Auth Security Scheme + val httpScheme = HTTPAuthSecurityScheme( + scheme = "Bearer", + bearerFormat = "JWT", + description = "JWT Bearer token" + ) + //language=JSON + val httpJson = """ + { + "scheme": "Bearer", + "bearerFormat": "JWT", + "description": "JWT Bearer token", + "type": "http" + } + """.trimIndent() + assertEquals(httpJson, TestJson.encodeToString(httpScheme)) + + // OAuth2 Security Scheme + val oauth2Scheme = OAuth2SecurityScheme( + flows = OAuthFlows( + authorizationCode = AuthorizationCodeOAuthFlow( + authorizationUrl = "https://auth.example.com/oauth/authorize", + tokenUrl = "https://auth.example.com/oauth/token", + scopes = mapOf("read" to "Read access", "write" to "Write access") + ) + ), + description = "OAuth2 with authorization code flow" + ) + //language=JSON + val expectedOAuth2Json = """ + { + "flows": { + "authorizationCode": { + "authorizationUrl": "https://auth.example.com/oauth/authorize", + "tokenUrl": "https://auth.example.com/oauth/token", + "scopes": { + "read": "Read access", + "write": "Write access" + } + } + }, + "description": "OAuth2 with authorization code flow", + "type": "oauth2" + } + """.trimIndent() + + val oauth2Json = TestJson.encodeToString(oauth2Scheme) + assertEquals(expectedOAuth2Json, oauth2Json) + + // OpenID Connect Security Scheme + val oidcScheme = OpenIdConnectSecurityScheme( + openIdConnectUrl = "https://auth.example.com/.well-known/openid_configuration" + ) + //language=JSON + val oidcJson = """ + { + "openIdConnectUrl": "https://auth.example.com/.well-known/openid_configuration", + "type": "openIdConnect" + } + """.trimIndent() + assertEquals(oidcJson, TestJson.encodeToString(oidcScheme)) + + // Mutual TLS Security Scheme + val mtlsScheme = MutualTLSSecurityScheme( + description = "Client certificate authentication" + ) + //language=JSON + val mtlsJson = """ + { + "description": "Client certificate authentication", + "type": "mutualTLS" + } + """.trimIndent() + assertEquals(mtlsJson, TestJson.encodeToString(mtlsScheme)) + + // Test deserialization + assertEquals(apiKeyScheme, TestJson.decodeFromString(apiKeyJson)) + assertEquals(httpScheme, TestJson.decodeFromString(httpJson)) + assertEquals(oauth2Scheme, TestJson.decodeFromString(expectedOAuth2Json)) + assertEquals(oidcScheme, TestJson.decodeFromString(oidcJson)) + assertEquals(mtlsScheme, TestJson.decodeFromString(mtlsJson)) + } + + @Test + fun testOAuthFlowsSerialization() { + val flows = OAuthFlows( + authorizationCode = AuthorizationCodeOAuthFlow( + authorizationUrl = "https://auth.example.com/oauth/authorize", + tokenUrl = "https://auth.example.com/oauth/token", + scopes = mapOf("read" to "Read access"), + refreshUrl = "https://auth.example.com/oauth/refresh" + ), + clientCredentials = ClientCredentialsOAuthFlow( + tokenUrl = "https://auth.example.com/oauth/token", + scopes = mapOf("admin" to "Admin access") + ), + implicit = ImplicitOAuthFlow( + authorizationUrl = "https://auth.example.com/oauth/authorize", + scopes = mapOf("read" to "Read access") + ), + password = PasswordOAuthFlow( + tokenUrl = "https://auth.example.com/oauth/token", + scopes = mapOf("user" to "User access") + ) + ) + + val json = TestJson.encodeToString(flows) + val deserialized = TestJson.decodeFromString(json) + assertEquals(flows, deserialized) + } + + @Test + fun testAgentCapabilitiesSerialization() { + // Test default capabilities + val defaultCapabilities = AgentCapabilities() + //language=JSON + val defaultJson = """ + { + "streaming": false, + "pushNotifications": false, + "stateTransitionHistory": false + } + """.trimIndent() + assertEquals(defaultJson, TestJson.encodeToString(defaultCapabilities)) + + // Test full capabilities + val fullCapabilities = AgentCapabilities( + streaming = true, + pushNotifications = true, + stateTransitionHistory = true, + extensions = listOf( + AgentExtension( + uri = "https://example.com/ext/v1", + description = "Test extension", + required = true + ) + ) + ) + //language=JSON + val expectedFullJson = """ + { + "streaming": true, + "pushNotifications": true, + "stateTransitionHistory": true, + "extensions": [ + { + "uri": "https://example.com/ext/v1", + "description": "Test extension", + "required": true + } + ] + } + """.trimIndent() + + val fullJson = TestJson.encodeToString(fullCapabilities) + assertEquals(expectedFullJson, fullJson) + + // Test deserialization + assertEquals(fullCapabilities, TestJson.decodeFromString(expectedFullJson)) + } + + @Test + fun testAgentSkillSerialization() { + val skill = AgentSkill( + id = "test-skill", + name = "Test Skill", + description = "A comprehensive test skill", + tags = listOf("test", "demo"), + examples = listOf("How to test?", "Show me a demo"), + inputModes = listOf("text/plain", "application/json"), + outputModes = listOf("text/plain"), + security = listOf( + mapOf("oauth" to listOf("read", "write")), + mapOf("api-key" to emptyList()) + ) + ) + + //language=JSON + val expectedJson = """ + { + "id": "test-skill", + "name": "Test Skill", + "description": "A comprehensive test skill", + "tags": [ + "test", + "demo" + ], + "examples": [ + "How to test?", + "Show me a demo" + ], + "inputModes": [ + "text/plain", + "application/json" + ], + "outputModes": [ + "text/plain" + ], + "security": [ + { + "oauth": [ + "read", + "write" + ] + }, + { + "api-key": [] + } + ] + } + """.trimIndent() + + val json = TestJson.encodeToString(skill) + assertEquals(expectedJson, json) + + val deserialized = TestJson.decodeFromString(json) + assertEquals(skill, deserialized) + } + + @Test + fun testEnumSerialization() { + assertEquals("\"cookie\"", TestJson.encodeToString(In.Cookie)) + assertEquals("\"header\"", TestJson.encodeToString(In.Header)) + assertEquals("\"query\"", TestJson.encodeToString(In.Query)) + + assertEquals(In.Cookie, TestJson.decodeFromString("\"cookie\"")) + assertEquals(In.Header, TestJson.decodeFromString("\"header\"")) + assertEquals(In.Query, TestJson.decodeFromString("\"query\"")) + } + + @Test + fun testAgentCardSignatureSerialization() { + val signature = AgentCardSignature( + `protected` = "eyJ0eXAiOiJKV1QiLCJhbGciOiJSUzI1NiJ9", + signature = "signature-data-here", + header = mapOf("kid" to kotlinx.serialization.json.JsonPrimitive("key-id-123")) + ) + + //language=JSON + val expectedJson = """ + { + "protected": "eyJ0eXAiOiJKV1QiLCJhbGciOiJSUzI1NiJ9", + "signature": "signature-data-here", + "header": { + "kid": "key-id-123" + } + } + """.trimIndent() + + val json = TestJson.encodeToString(signature) + assertEquals(expectedJson, json) + + val deserialized = TestJson.decodeFromString(json) + assertEquals(signature, deserialized) + } +} diff --git a/a2a/a2a-core/src/commonTest/kotlin/ai/koog/a2a/model/ModelSerializationTest.kt b/a2a/a2a-core/src/commonTest/kotlin/ai/koog/a2a/transport/TransportSerializationTest.kt similarity index 93% rename from a2a/a2a-core/src/commonTest/kotlin/ai/koog/a2a/model/ModelSerializationTest.kt rename to a2a/a2a-core/src/commonTest/kotlin/ai/koog/a2a/transport/TransportSerializationTest.kt index 3798c1af2f..903357f697 100644 --- a/a2a/a2a-core/src/commonTest/kotlin/ai/koog/a2a/model/ModelSerializationTest.kt +++ b/a2a/a2a-core/src/commonTest/kotlin/ai/koog/a2a/transport/TransportSerializationTest.kt @@ -1,10 +1,10 @@ -package ai.koog.a2a.model +package ai.koog.a2a.transport import kotlinx.serialization.json.Json import kotlin.test.Test import kotlin.test.assertEquals -class ModelSerializationTest { +class TransportSerializationTest { @Suppress("PrivatePropertyName") private val TestJson = Json diff --git a/a2a/a2a-transport/a2a-transport-core-jsonrpc/src/commonMain/kotlin/ai/koog/a2a/transport/jsonrpc/model/Messages.kt b/a2a/a2a-transport/a2a-transport-core-jsonrpc/src/commonMain/kotlin/ai/koog/a2a/transport/jsonrpc/model/Messages.kt index 7a2153f68c..68c29c4b93 100644 --- a/a2a/a2a-transport/a2a-transport-core-jsonrpc/src/commonMain/kotlin/ai/koog/a2a/transport/jsonrpc/model/Messages.kt +++ b/a2a/a2a-transport/a2a-transport-core-jsonrpc/src/commonMain/kotlin/ai/koog/a2a/transport/jsonrpc/model/Messages.kt @@ -2,7 +2,7 @@ package ai.koog.a2a.transport.jsonrpc.model -import ai.koog.a2a.model.RequestId +import ai.koog.a2a.transport.RequestId import kotlinx.serialization.EncodeDefault import kotlinx.serialization.Serializable import kotlinx.serialization.json.JsonElement @@ -12,13 +12,6 @@ import kotlinx.serialization.json.JsonElement */ public const val JSONRPC_VERSION: String = "2.0" -@Serializable -public data class JSONRPCError( - val code: Int, - val message: String, - val data: JsonElement? = null, -) - @Serializable(with = JSONRPCMessageSerializer::class) public sealed interface JSONRPCMessage { public val jsonrpc: String @@ -52,6 +45,13 @@ public data class JSONRPCSuccessResponse( override val jsonrpc: String = JSONRPC_VERSION, ) : JSONRPCResponse +@Serializable +public data class JSONRPCError( + val code: Int, + val message: String, + val data: JsonElement? = null, +) + @Serializable public data class JSONRPCErrorResponse( public val id: RequestId?, diff --git a/a2a/a2a-transport/a2a-transport-core-jsonrpc/src/commonTest/kotlin/ai/koog/a2a/transport/jsonrpc/model/JsonRpcSerializationTest.kt b/a2a/a2a-transport/a2a-transport-core-jsonrpc/src/commonTest/kotlin/ai/koog/a2a/transport/jsonrpc/model/JsonRpcSerializationTest.kt index 92ccb5317a..9d17c3e19c 100644 --- a/a2a/a2a-transport/a2a-transport-core-jsonrpc/src/commonTest/kotlin/ai/koog/a2a/transport/jsonrpc/model/JsonRpcSerializationTest.kt +++ b/a2a/a2a-transport/a2a-transport-core-jsonrpc/src/commonTest/kotlin/ai/koog/a2a/transport/jsonrpc/model/JsonRpcSerializationTest.kt @@ -1,5 +1,6 @@ package ai.koog.a2a.transport.jsonrpc.model +import ai.koog.a2a.model.RequestId import kotlinx.serialization.json.JsonPrimitive import kotlin.test.Test import kotlin.test.assertEquals From 4c7d84f6220903e04699c6d73477e717d5cbf07d Mon Sep 17 00:00:00 2001 From: Andrey Bragin Date: Sat, 30 Aug 2025 03:07:08 +0200 Subject: [PATCH 18/52] [a2a] Add client and server transport abstractions --- a2a/a2a-client/build.gradle.kts | 1 + .../kotlin/ai/koog/a2a/client/.gitkeep | 0 .../kotlin/ai/koog/a2a/client/A2AClient.kt | 10 ++ .../ai/koog/a2a/exceptions/Exceptions.kt | 52 +++++- .../kotlin/ai/koog/a2a/model/AgentCard.kt | 41 ++++- .../kotlin/ai/koog/a2a/model/Core.kt | 8 +- .../kotlin/ai/koog/a2a/model/Message.kt | 2 +- .../ai/koog/a2a/model/MessageSendParams.kt | 37 ++++ .../ai/koog/a2a/model/{Part.kt => Parts.kt} | 0 .../kotlin/ai/koog/a2a/model/Serialization.kt | 10 +- .../kotlin/ai/koog/a2a/model/Task.kt | 2 +- .../kotlin/ai/koog/a2a/model/TaskEvents.kt | 7 +- .../kotlin/ai/koog/a2a/model/TaskParams.kt | 44 +++++ .../ai/koog/a2a/transport/ClientTransport.kt | 131 ++++++++++++++ .../a2a/transport/{Structs.kt => Core.kt} | 0 .../ai/koog/a2a/transport/ServerTransport.kt | 142 +++++++++++++++ a2a/a2a-server/build.gradle.kts | 1 + .../kotlin/ai/koog/a2a/server/.gitkeep | 0 .../kotlin/ai/koog/a2a/server/A2AServer.kt | 84 +++++++++ .../koog/a2a/transport/jsonrpc/A2AMethod.kt | 16 ++ .../jsonrpc/JSONRCPServerTransport.kt | 106 +++++++++++ .../jsonrpc/JSONRPCClientTransport.kt | 168 ++++++++++++++++++ .../a2a/transport/jsonrpc/model/Messages.kt | 7 +- .../jsonrpc/model/JsonRpcSerializationTest.kt | 4 +- 24 files changed, 849 insertions(+), 24 deletions(-) delete mode 100644 a2a/a2a-client/src/commonMain/kotlin/ai/koog/a2a/client/.gitkeep create mode 100644 a2a/a2a-client/src/commonMain/kotlin/ai/koog/a2a/client/A2AClient.kt create mode 100644 a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/model/MessageSendParams.kt rename a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/model/{Part.kt => Parts.kt} (100%) create mode 100644 a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/model/TaskParams.kt create mode 100644 a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/transport/ClientTransport.kt rename a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/transport/{Structs.kt => Core.kt} (100%) create mode 100644 a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/transport/ServerTransport.kt delete mode 100644 a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/.gitkeep create mode 100644 a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/A2AServer.kt create mode 100644 a2a/a2a-transport/a2a-transport-core-jsonrpc/src/commonMain/kotlin/ai/koog/a2a/transport/jsonrpc/A2AMethod.kt create mode 100644 a2a/a2a-transport/a2a-transport-core-jsonrpc/src/commonMain/kotlin/ai/koog/a2a/transport/jsonrpc/JSONRCPServerTransport.kt create mode 100644 a2a/a2a-transport/a2a-transport-core-jsonrpc/src/commonMain/kotlin/ai/koog/a2a/transport/jsonrpc/JSONRPCClientTransport.kt diff --git a/a2a/a2a-client/build.gradle.kts b/a2a/a2a-client/build.gradle.kts index 36e5aa832a..ca41b397e3 100644 --- a/a2a/a2a-client/build.gradle.kts +++ b/a2a/a2a-client/build.gradle.kts @@ -18,6 +18,7 @@ kotlin { sourceSets { commonMain { dependencies { + api(project(":a2a:a2a-core")) api(libs.kotlinx.serialization.json) api(libs.kotlinx.coroutines.core) } diff --git a/a2a/a2a-client/src/commonMain/kotlin/ai/koog/a2a/client/.gitkeep b/a2a/a2a-client/src/commonMain/kotlin/ai/koog/a2a/client/.gitkeep deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/a2a/a2a-client/src/commonMain/kotlin/ai/koog/a2a/client/A2AClient.kt b/a2a/a2a-client/src/commonMain/kotlin/ai/koog/a2a/client/A2AClient.kt new file mode 100644 index 0000000000..82989d097c --- /dev/null +++ b/a2a/a2a-client/src/commonMain/kotlin/ai/koog/a2a/client/A2AClient.kt @@ -0,0 +1,10 @@ +package ai.koog.a2a.client + +import ai.koog.a2a.transport.ClientTransport + +/** + * A2A client responsible for sending requests to A2A server. + */ +public class A2AClient( + private val transport: ClientTransport, +) diff --git a/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/exceptions/Exceptions.kt b/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/exceptions/Exceptions.kt index d2a4f1d87e..40103ea52f 100644 --- a/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/exceptions/Exceptions.kt +++ b/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/exceptions/Exceptions.kt @@ -1,5 +1,23 @@ package ai.koog.a2a.exceptions +/** + * Enum containing all A2A error codes. + */ +public enum class A2AErrorCode(public val value: Int) { + PARSE_ERROR(-32700), + INVALID_REQUEST(-32600), + METHOD_NOT_FOUND(-32601), + INVALID_PARAMS(-32602), + INTERNAL_ERROR(-32603), + TASK_NOT_FOUND(-32001), + TASK_NOT_CANCELABLE(-32002), + PUSH_NOTIFICATION_NOT_SUPPORTED(-32003), + UNSUPPORTED_OPERATION(-32004), + CONTENT_TYPE_NOT_SUPPORTED(-32005), + INVALID_AGENT_RESPONSE(-32006), + AUTHENTICATED_EXTENDED_CARD_NOT_CONFIGURED(-32007) +} + /** * Base class for all A2A exceptions. */ @@ -46,7 +64,7 @@ public class InternalErrorException( /** * Reserved for implementation-defined server exceptions. A2A-specific exceptions use this range. */ -public open class A2AServerException( +public sealed class A2AServerException( message: String, errorCode: Int, ) : A2AException(message, errorCode) { @@ -108,3 +126,35 @@ public class InvalidAgentResponseException( public class AuthenticatedExtendedCardNotConfiguredException( message: String = "Authenticated Extended Card not configured", ) : A2AServerException(message, errorCode = -32007) + +/** + * Server returned some unknown error code. + */ +public class UnknownA2AException( + message: String, + errorCode: Int, +) : A2AException(message, errorCode) + +/** + * Create appropriate [A2AException] based on the provided errorCode. + */ +public fun createA2AException( + message: String, + errorCode: Int, +): A2AException { + return when (errorCode) { + A2AErrorCode.PARSE_ERROR.value -> ParseException(message) + A2AErrorCode.INVALID_REQUEST.value -> InvalidRequestException(message) + A2AErrorCode.METHOD_NOT_FOUND.value -> MethodNotFoundException(message) + A2AErrorCode.INVALID_PARAMS.value -> InvalidParamsException(message) + A2AErrorCode.INTERNAL_ERROR.value -> InternalErrorException(message) + A2AErrorCode.TASK_NOT_FOUND.value -> TaskNotFoundException(message) + A2AErrorCode.TASK_NOT_CANCELABLE.value -> TaskNotCancelableException(message) + A2AErrorCode.PUSH_NOTIFICATION_NOT_SUPPORTED.value -> PushNotificationNotSupportedException(message) + A2AErrorCode.UNSUPPORTED_OPERATION.value -> UnsupportedOperationException(message) + A2AErrorCode.CONTENT_TYPE_NOT_SUPPORTED.value -> ContentTypeNotSupportedException(message) + A2AErrorCode.INVALID_AGENT_RESPONSE.value -> InvalidAgentResponseException(message) + A2AErrorCode.AUTHENTICATED_EXTENDED_CARD_NOT_CONFIGURED.value -> AuthenticatedExtendedCardNotConfiguredException(message) + else -> UnknownA2AException(message, errorCode) + } +} diff --git a/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/model/AgentCard.kt b/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/model/AgentCard.kt index f939261586..bd41a84b4a 100644 --- a/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/model/AgentCard.kt +++ b/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/model/AgentCard.kt @@ -87,8 +87,8 @@ public data class AgentCard( public val version: String, public val documentationUrl: String? = null, public val capabilities: AgentCapabilities, - public val securitySchemes: Map? = null, - public val security: List>>? = null, + public val securitySchemes: SecuritySchemes? = null, + public val security: Security? = null, public val defaultInputModes: List, public val defaultOutputModes: List, public val skills: List, @@ -164,8 +164,16 @@ public data class AgentProvider( * Defines optional capabilities supported by an agent. * * @property streaming Indicates if the agent supports Server-Sent Events (SSE) for streaming responses. + * * @property pushNotifications Indicates if the agent supports sending push notifications for asynchronous task updates. + * * @property stateTransitionHistory Indicates if the agent provides a history of state transitions for a task. + * + * TODO: it's not clear from the specification and official Python SDK, what does this field control. + * It's not [Task.history], since it always should be present. + * There are no further mentions or usages of this field in the official sources. + * So currently in our implementation it does not control anything. + * * @property extensions A list of protocol extensions supported by the agent. */ @Serializable @@ -195,6 +203,33 @@ public data class AgentExtension( public val params: Map? = null ) +/** + * A declaration of the security schemes available to authorize requests. The key is the scheme name. The value is the + * declaration of the security scheme object, which follows the OpenAPI 3.0 Security Scheme Object. + */ +public typealias SecuritySchemes = Map + +/** + * A list of alternative security requirements (a logical OR). To authorize a request, a client must satisfy one of the + * [SecurityRequirement]s in this list. + * + * For example, `[{"oauth": ["read"]}, {"apiKey": [], "mtls": []}]` means a client can use either OAuth with the "read" scope + * or both an API key and mTLS. + * + * @see [https://swagger.io/specification/#security-requirement-object] + */ +public typealias Security = List + +/** + * A set of security schemes that must be satisfied together (a logical AND). The key is a security scheme name, and the + * value is a list of required scopes. + * + * For example, `{"apiKey": [], "mtls": []}` requires both an API key and mTLS. + * + * @see [https://swagger.io/specification/#security-requirement-object] + */ +public typealias SecurityRequirement = Map> + /** * Defines a security scheme that can be used to secure an agent's endpoints. * This is a discriminated union type based on the OpenAPI 3.0 Security Scheme Object. @@ -425,7 +460,7 @@ public data class AgentSkill( public val examples: List? = null, public val inputModes: List? = null, public val outputModes: List? = null, - public val security: List>>? = null + public val security: Security? = null ) /** diff --git a/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/model/Core.kt b/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/model/Core.kt index 9d4dd14391..1f06cf4d07 100644 --- a/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/model/Core.kt +++ b/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/model/Core.kt @@ -5,8 +5,8 @@ import kotlinx.serialization.Serializable /** * Base interface for events. */ -@Serializable(with = EventSerializer::class) -public sealed interface Event { +@Serializable(with = UpdateEventSerializer::class) +public sealed interface UpdateEvent { /** * The type used as discriminator. */ @@ -16,5 +16,5 @@ public sealed interface Event { /** * Base interface for communication units, such as messages or tasks. */ -@Serializable(with = CommunicationSerializer::class) -public sealed interface Communication : Event +@Serializable(with = CommunicationEventSerializer::class) +public sealed interface CommunicationEvent : UpdateEvent diff --git a/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/model/Message.kt b/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/model/Message.kt index 0e3656d962..4832145eb7 100644 --- a/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/model/Message.kt +++ b/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/model/Message.kt @@ -44,7 +44,7 @@ public data class Message( public val referenceTaskIds: List? = null, public val contextId: String? = null, public val metadata: JsonObject? = null, -) : Communication { +) : CommunicationEvent { @EncodeDefault override val kind: String = "message" } diff --git a/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/model/MessageSendParams.kt b/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/model/MessageSendParams.kt new file mode 100644 index 0000000000..e92c170a4d --- /dev/null +++ b/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/model/MessageSendParams.kt @@ -0,0 +1,37 @@ +package ai.koog.a2a.model + +import kotlinx.serialization.EncodeDefault +import kotlinx.serialization.Serializable +import kotlinx.serialization.json.JsonObject + +/** + * Defines the parameters for a request to send a message to an agent. This can be used + * to create a new task, continue an existing one, or restart a task. + * + * @property message The message object being sent to the agent. + * @property configuration Optional configuration for the send request. + * @property metadata Optional metadata for extensions. + */ +@Serializable +public data class MessageSendParams( + public val message: Message, + public val configuration: MessageSendConfiguration? = null, + public val metadata: JsonObject? = null, +) + +/** + * Defines configuration options for a `message/send` or `message/stream` request. + * + * @property acceptedOutputModes A list of output MIME types the client is prepared to accept in the response. + * @property historyLength The number of most recent messages from the task's history to retrieve in the response. + * @property pushNotificationConfig Configuration for the agent to send push notifications for updates after the initial response. + * @property blocking If true, the client will wait for the task to complete. The server may reject this if the task is long-running. + */ +@Serializable +public data class MessageSendConfiguration( + @EncodeDefault + public val blocking: Boolean = false, + public val acceptedOutputModes: List? = null, + public val historyLength: Int? = null, + public val pushNotificationConfig: PushNotificationConfig? = null, +) diff --git a/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/model/Part.kt b/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/model/Parts.kt similarity index 100% rename from a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/model/Part.kt rename to a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/model/Parts.kt diff --git a/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/model/Serialization.kt b/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/model/Serialization.kt index 5d9ccde594..a3f69b6346 100644 --- a/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/model/Serialization.kt +++ b/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/model/Serialization.kt @@ -48,21 +48,21 @@ internal object FileSerializer : JsonContentPolymorphicSerializer(File::cl } } -internal object EventSerializer : JsonContentPolymorphicSerializer(Event::class) { - override fun selectDeserializer(element: JsonElement): DeserializationStrategy { +internal object UpdateEventSerializer : JsonContentPolymorphicSerializer(UpdateEvent::class) { + override fun selectDeserializer(element: JsonElement): DeserializationStrategy { val jsonObject = element.jsonObject val kind = jsonObject["kind"]?.jsonPrimitive?.content ?: error("Missing 'kind' field in Event") return when (kind) { "status-update" -> TaskStatusUpdateEvent.serializer() "artifact-update" -> TaskArtifactUpdateEvent.serializer() - else -> CommunicationSerializer + else -> CommunicationEventSerializer } } } -internal object CommunicationSerializer : JsonContentPolymorphicSerializer(Communication::class) { - override fun selectDeserializer(element: JsonElement): DeserializationStrategy { +internal object CommunicationEventSerializer : JsonContentPolymorphicSerializer(CommunicationEvent::class) { + override fun selectDeserializer(element: JsonElement): DeserializationStrategy { val jsonObject = element.jsonObject val kind = jsonObject["kind"]?.jsonPrimitive?.content ?: error("Missing 'kind' field in Communication") diff --git a/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/model/Task.kt b/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/model/Task.kt index fb8fff7702..3c57061f05 100644 --- a/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/model/Task.kt +++ b/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/model/Task.kt @@ -28,7 +28,7 @@ public data class Task( public val history: List? = null, public val artifacts: List? = null, public val metadata: JsonObject? = null, -) : Communication { +) : CommunicationEvent { @EncodeDefault override val kind: String = "task" } diff --git a/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/model/TaskEvents.kt b/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/model/TaskEvents.kt index b589b01dd0..e4cb3178aa 100644 --- a/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/model/TaskEvents.kt +++ b/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/model/TaskEvents.kt @@ -21,7 +21,7 @@ public data class TaskStatusUpdateEvent( public val status: TaskStatus, public val final: Boolean, public val metadata: JsonObject? = null, -) : Event { +) : UpdateEvent { @EncodeDefault override val kind: String = "status-update" } @@ -42,10 +42,11 @@ public data class TaskArtifactUpdateEvent( public val taskId: String, public val contextId: String, public val artifact: Artifact, - public val append: Boolean? = null, + @EncodeDefault + public val append: Boolean = false, public val lastChunk: Boolean? = null, public val metadata: JsonObject? = null, -) : Event { +) : UpdateEvent { @EncodeDefault override val kind: String = "artifact-update" } diff --git a/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/model/TaskParams.kt b/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/model/TaskParams.kt new file mode 100644 index 0000000000..606d70806c --- /dev/null +++ b/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/model/TaskParams.kt @@ -0,0 +1,44 @@ +package ai.koog.a2a.model + +import kotlinx.serialization.Serializable +import kotlinx.serialization.json.JsonObject + +/** + * Defines parameters containing a task ID, used for simple task operations. + * + * @property id The unique identifier (e.g. UUID) of the task. + * @property metadata Optional metadata associated with this request. + */ +@Serializable +public data class TaskIdParams( + public val id: String, + public val metadata: JsonObject? = null, +) + +/** + * Defines parameters for querying a task, with an option to limit history length. + * + * @property id The unique identifier (e.g. UUID) of the task. + * @property historyLength The number of most recent messages from the task's history to retrieve. + * @property metadata Optional metadata associated with this request. + */ +@Serializable +public data class TaskQueryParams( + public val id: String, + public val historyLength: Int? = null, + public val metadata: JsonObject? = null, +) + +/** + * Defines parameters for fetching a specific push notification configuration for a task. + * + * @property id The unique identifier (e.g. UUID) of the task. + * @property pushNotificationConfigId The ID of the push notification configuration to retrieve. + * @property metadata Optional metadata associated with this request. + */ +@Serializable +public data class TaskPushNotificationConfigParams( + public val id: String, + public val pushNotificationConfigId: String? = null, + public val metadata: JsonObject? = null, +) diff --git a/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/transport/ClientTransport.kt b/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/transport/ClientTransport.kt new file mode 100644 index 0000000000..96a110edbb --- /dev/null +++ b/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/transport/ClientTransport.kt @@ -0,0 +1,131 @@ +package ai.koog.a2a.transport + +import ai.koog.a2a.exceptions.A2AException +import ai.koog.a2a.model.AgentCard +import ai.koog.a2a.model.CommunicationEvent +import ai.koog.a2a.model.MessageSendParams +import ai.koog.a2a.model.Task +import ai.koog.a2a.model.TaskIdParams +import ai.koog.a2a.model.TaskPushNotificationConfig +import ai.koog.a2a.model.TaskPushNotificationConfigParams +import ai.koog.a2a.model.TaskQueryParams +import ai.koog.a2a.model.UpdateEvent +import kotlinx.coroutines.flow.Flow +import kotlinx.serialization.SerializationException + +/** + * Client transport making requests to [A2A protocol methods](https://a2a-protocol.org/latest/specification/#7-protocol-rpc-methods) + * and handling responses from the server. + * + * Client transport must handle error responses from the server and convert them to appropriate [A2AException] + * (e.g. parsing error response data format like JSON error object and throwing corresponding [A2AException] based on the error code). + * It must preserve the [A2AException.errorCode] received from the [ServerTransport]. + * + * Client transport may throw exceptions other than [A2AException] for any transport-level errors (e.g. network failures, invalid responses, timeout), + * e.g. [SerializationException] + */ +public interface ClientTransport : AutoCloseable { + /** + * Implements [agent/getAuthenticatedExtendedCard](https://a2a-protocol.org/latest/specification/#710-agentgetauthenticatedextendedcard) + * + * @throws A2AException if server returned an error. + */ + public suspend fun getAuthenticatedExtendedAgentCard( + request: Request, + ctx: ClientCallContext = ClientCallContext.Default + ): Response + + /** + * Implements [message/send](https://a2a-protocol.org/latest/specification/#71-messagesend). + * + * @throws A2AException if server returned an error. + */ + public suspend fun sendMessage( + request: Request, + ctx: ClientCallContext = ClientCallContext.Default + ): Response + + /** + * Implements [message/stream](https://a2a-protocol.org/latest/specification/#72-messagestream) + * + * @throws A2AException if server returned an error. + */ + public suspend fun sendMessageStreaming( + request: Request, + ctx: ClientCallContext = ClientCallContext.Default + ): Flow> + + /** + * Implements [tasks/get](https://a2a-protocol.org/latest/specification/#73-tasksget) + * + * @throws A2AException if server returned an error. + */ + public suspend fun getTask( + request: Request, + ctx: ClientCallContext = ClientCallContext.Default + ): Response + + /** + * Implements [tasks/cancel](https://a2a-protocol.org/latest/specification/#74-taskscancel) + * + * @throws A2AException if server returned an error. + */ + public suspend fun cancelTask( + request: Request, + ctx: ClientCallContext = ClientCallContext.Default + ): Response + + /** + * Implements [tasks/pushNotificationConfig/set](https://a2a-protocol.org/latest/specification/#75-taskspushnotificationconfigset) + * + * @throws A2AException if server returned an error. + */ + public suspend fun setTaskPushNotificationConfig( + request: Request, + ctx: ClientCallContext = ClientCallContext.Default + ): Response + + /** + * Implements [tasks/pushNotificationConfig/get](https://a2a-protocol.org/latest/specification/#76-taskspushnotificationconfigget) + * + * @throws A2AException if server returned an error. + */ + public suspend fun getTaskPushNotificationConfig( + request: Request, + ctx: ClientCallContext = ClientCallContext.Default + ): Response + + /** + * Implements [tasks/pushNotificationConfig/list](https://a2a-protocol.org/latest/specification/#77-taskspushnotificationconfiglist) + * + * @throws A2AException if server returned an error. + */ + public suspend fun listTaskPushNotificationConfig( + request: Request, + ctx: ClientCallContext = ClientCallContext.Default + ): Response> + + /** + * Implements [tasks/pushNotificationConfig/delete](https://a2a-protocol.org/latest/specification/#78-taskspushnotificationconfigdelete) + * + * @throws A2AException if server returned an error. + */ + public suspend fun deleteTaskPushNotificationConfig( + request: Request, + ctx: ClientCallContext = ClientCallContext.Default + ): Response +} + +/** + * Represents the client context of a call. + * + * @property additionalHeaders Additional call-specific headers associated with the call. + */ +public class ClientCallContext( + public val additionalHeaders: Map = emptyMap(), +) { + @Suppress("MissingKDocForPublicAPI") + public companion object { + public val Default: ClientCallContext = ClientCallContext() + } +} diff --git a/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/transport/Structs.kt b/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/transport/Core.kt similarity index 100% rename from a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/transport/Structs.kt rename to a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/transport/Core.kt diff --git a/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/transport/ServerTransport.kt b/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/transport/ServerTransport.kt new file mode 100644 index 0000000000..10a8f21ddb --- /dev/null +++ b/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/transport/ServerTransport.kt @@ -0,0 +1,142 @@ +package ai.koog.a2a.transport + +import ai.koog.a2a.exceptions.A2AException +import ai.koog.a2a.exceptions.InternalErrorException +import ai.koog.a2a.model.AgentCard +import ai.koog.a2a.model.CommunicationEvent +import ai.koog.a2a.model.MessageSendParams +import ai.koog.a2a.model.Task +import ai.koog.a2a.model.TaskIdParams +import ai.koog.a2a.model.TaskPushNotificationConfig +import ai.koog.a2a.model.TaskPushNotificationConfigParams +import ai.koog.a2a.model.TaskQueryParams +import ai.koog.a2a.model.UpdateEvent +import kotlinx.coroutines.flow.Flow + +/** + * Server transport processing raw requests made to [A2A protocol methods](https://a2a-protocol.org/latest/specification/#7-protocol-rpc-methods) + * and delegating the processing to [RequestHandler]. + * + * Server transport must respond with appropriate [A2AException] in case of errors while processing the request + * (e.g. method not found or invalid method parameters). It must also handle [A2AException] thrown by the [RequestHandler] methods. + * In case non [A2AException] is thrown, it must be converted to [InternalErrorException] with appropriate message. + * + * Server transport must convert [A2AException] to appropriate response data format (e.g. JSON error object), + * preserving the [A2AException.errorCode] so that it can be properly handled by the [ClientTransport]. + */ +public interface ServerTransport { + /** + * Handler responsible for processing parsed A2A requests. + */ + public val requestHandler: RequestHandler +} + +/** + * Handler responsible for processing parsed A2A requests, implementing + * [A2A protocol methods](https://a2a-protocol.org/latest/specification/#7-protocol-rpc-methods). + */ +public interface RequestHandler { + /** + * Handles [agent/getAuthenticatedExtendedCard](https://a2a-protocol.org/latest/specification/#710-agentgetauthenticatedextendedcard) + * + * @throws A2AException if there is an error with processsing the request. + */ + public suspend fun onGetAuthenticatedExtendedAgentCard( + request: Request, + ctx: ServerCallContext + ): Response + + /** + * Handles [message/send](https://a2a-protocol.org/latest/specification/#71-messagesend). + * + * @throws A2AException if there is an error with processsing the request. + */ + public suspend fun onSendMessage( + request: Request, + ctx: ServerCallContext + ): Response + + /** + * Handles [message/stream](https://a2a-protocol.org/latest/specification/#72-messagestream) + * + * @throws A2AException if there is an error with processsing the request. + */ + public suspend fun onSendMessageStreaming( + request: Request, + ctx: ServerCallContext + ): Flow> + + /** + * Handles [tasks/get](https://a2a-protocol.org/latest/specification/#73-tasksget) + * + * @throws A2AException if there is an error with processsing the request. + */ + public suspend fun onGetTask( + request: Request, + ctx: ServerCallContext + ): Response + + /** + * Handles [tasks/cancel](https://a2a-protocol.org/latest/specification/#74-taskscancel) + * + * @throws A2AException if there is an error with processsing the request. + */ + public suspend fun onCancelTask( + request: Request, + ctx: ServerCallContext + ): Response + + /** + * Handles [tasks/pushNotificationConfig/set](https://a2a-protocol.org/latest/specification/#75-taskspushnotificationconfigset) + * + * @throws A2AException if there is an error with processsing the request. + */ + public suspend fun onSetTaskPushNotificationConfig( + request: Request, + ctx: ServerCallContext + ): Response + + /** + * Handles [tasks/pushNotificationConfig/get](https://a2a-protocol.org/latest/specification/#76-taskspushnotificationconfigget) + * + * @throws A2AException if there is an error with processsing the request. + */ + public suspend fun onGetTaskPushNotificationConfig( + request: Request, + ctx: ServerCallContext + ): Response + + /** + * Handles [tasks/pushNotificationConfig/list](https://a2a-protocol.org/latest/specification/#77-taskspushnotificationconfiglist) + * + * @throws A2AException if there is an error with processsing the request. + */ + public suspend fun onListTaskPushNotificationConfig( + request: Request, + ctx: ServerCallContext + ): Response> + + /** + * Handles [tasks/pushNotificationConfig/delete](https://a2a-protocol.org/latest/specification/#78-taskspushnotificationconfigdelete) + * + * @throws A2AException if there is an error with processsing the request. + */ + public suspend fun onDeleteTaskPushNotificationConfig( + request: Request, + ctx: ServerCallContext + ): Response +} + +/** + * Represents the server context of a call. + * + * @property headers Headers associated with the call. + */ +public class ServerCallContext( + public val headers: Map = emptyMap(), +) { + @Suppress("MissingKDocForPublicAPI") + public companion object { + public val Default: ServerCallContext = ServerCallContext() + } +} diff --git a/a2a/a2a-server/build.gradle.kts b/a2a/a2a-server/build.gradle.kts index 36e5aa832a..ca41b397e3 100644 --- a/a2a/a2a-server/build.gradle.kts +++ b/a2a/a2a-server/build.gradle.kts @@ -18,6 +18,7 @@ kotlin { sourceSets { commonMain { dependencies { + api(project(":a2a:a2a-core")) api(libs.kotlinx.serialization.json) api(libs.kotlinx.coroutines.core) } diff --git a/a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/.gitkeep b/a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/.gitkeep deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/A2AServer.kt b/a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/A2AServer.kt new file mode 100644 index 0000000000..067dc36c21 --- /dev/null +++ b/a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/A2AServer.kt @@ -0,0 +1,84 @@ +package ai.koog.a2a.server + +import ai.koog.a2a.model.AgentCard +import ai.koog.a2a.model.CommunicationEvent +import ai.koog.a2a.model.MessageSendParams +import ai.koog.a2a.model.Task +import ai.koog.a2a.model.TaskIdParams +import ai.koog.a2a.model.TaskPushNotificationConfig +import ai.koog.a2a.model.TaskPushNotificationConfigParams +import ai.koog.a2a.model.TaskQueryParams +import ai.koog.a2a.model.UpdateEvent +import ai.koog.a2a.transport.Request +import ai.koog.a2a.transport.RequestHandler +import ai.koog.a2a.transport.Response +import ai.koog.a2a.transport.ServerCallContext +import kotlinx.coroutines.flow.Flow + +/** + * A2A server responsible for handling requests from A2A clients. + */ +public class A2AServer : RequestHandler { + override suspend fun onGetAuthenticatedExtendedAgentCard( + request: Request, + ctx: ServerCallContext + ): Response { + TODO("Not yet implemented") + } + + override suspend fun onSendMessage( + request: Request, + ctx: ServerCallContext + ): Response { + TODO("Not yet implemented") + } + + override suspend fun onSendMessageStreaming( + request: Request, + ctx: ServerCallContext + ): Flow> { + TODO("Not yet implemented") + } + + override suspend fun onGetTask( + request: Request, + ctx: ServerCallContext + ): Response { + TODO("Not yet implemented") + } + + override suspend fun onCancelTask( + request: Request, + ctx: ServerCallContext + ): Response { + TODO("Not yet implemented") + } + + override suspend fun onSetTaskPushNotificationConfig( + request: Request, + ctx: ServerCallContext + ): Response { + TODO("Not yet implemented") + } + + override suspend fun onGetTaskPushNotificationConfig( + request: Request, + ctx: ServerCallContext + ): Response { + TODO("Not yet implemented") + } + + override suspend fun onListTaskPushNotificationConfig( + request: Request, + ctx: ServerCallContext + ): Response> { + TODO("Not yet implemented") + } + + override suspend fun onDeleteTaskPushNotificationConfig( + request: Request, + ctx: ServerCallContext + ): Response { + TODO("Not yet implemented") + } +} diff --git a/a2a/a2a-transport/a2a-transport-core-jsonrpc/src/commonMain/kotlin/ai/koog/a2a/transport/jsonrpc/A2AMethod.kt b/a2a/a2a-transport/a2a-transport-core-jsonrpc/src/commonMain/kotlin/ai/koog/a2a/transport/jsonrpc/A2AMethod.kt new file mode 100644 index 0000000000..0d0519abd8 --- /dev/null +++ b/a2a/a2a-transport/a2a-transport-core-jsonrpc/src/commonMain/kotlin/ai/koog/a2a/transport/jsonrpc/A2AMethod.kt @@ -0,0 +1,16 @@ +package ai.koog.a2a.transport.jsonrpc + +/** + * A2A JSON-RPC methods. + */ +public enum class A2AMethod(public val value: String) { + GetAuthenticatedExtendedAgentCard("agent/getAuthenticatedExtendedCard"), + SendMessage("message/send"), + SendMessageStreaming("message/stream"), + GetTask("tasks/get"), + CancelTask("tasks/cancel"), + SetTaskPushNotificationConfig("tasks/pushNotificationConfig/set"), + GetTaskPushNotificationConfig("tasks/pushNotificationConfig/get"), + ListTaskPushNotificationConfig("tasks/pushNotificationConfig/list"), + DeleteTaskPushNotificationConfig("tasks/pushNotificationConfig/delete"), +} diff --git a/a2a/a2a-transport/a2a-transport-core-jsonrpc/src/commonMain/kotlin/ai/koog/a2a/transport/jsonrpc/JSONRCPServerTransport.kt b/a2a/a2a-transport/a2a-transport-core-jsonrpc/src/commonMain/kotlin/ai/koog/a2a/transport/jsonrpc/JSONRCPServerTransport.kt new file mode 100644 index 0000000000..0d4d30a3a5 --- /dev/null +++ b/a2a/a2a-transport/a2a-transport-core-jsonrpc/src/commonMain/kotlin/ai/koog/a2a/transport/jsonrpc/JSONRCPServerTransport.kt @@ -0,0 +1,106 @@ +package ai.koog.a2a.transport.jsonrpc + +import ai.koog.a2a.exceptions.A2AException +import ai.koog.a2a.exceptions.InvalidParamsException +import ai.koog.a2a.exceptions.MethodNotFoundException +import ai.koog.a2a.transport.Request +import ai.koog.a2a.transport.Response +import ai.koog.a2a.transport.ServerCallContext +import ai.koog.a2a.transport.ServerTransport +import ai.koog.a2a.transport.jsonrpc.model.JSONRPCJson +import ai.koog.a2a.transport.jsonrpc.model.JSONRPCRequest +import ai.koog.a2a.transport.jsonrpc.model.JSONRPCSuccessResponse +import kotlinx.coroutines.flow.Flow +import kotlinx.coroutines.flow.map +import kotlinx.serialization.SerializationException +import kotlinx.serialization.json.decodeFromJsonElement +import kotlinx.serialization.json.encodeToJsonElement + +/** + * Abstract transport implementation for JSON-RPC-based server communication. + * Handles receiving JSON-RPC requests, processing them, and sending responses. + */ +public abstract class JSONRCPServerTransport : ServerTransport { + /** + * Handles a JSON-RPC request and returns the corresponding response. + * + * @throws A2AException if there's an error processing the request. + */ + public suspend fun onRequest( + request: JSONRPCRequest, + ctx: ServerCallContext, + ): JSONRPCSuccessResponse { + return when (request.method) { + A2AMethod.GetAuthenticatedExtendedAgentCard.value -> + requestHandler.onGetAuthenticatedExtendedAgentCard(request.toRequest(), ctx).toJSONRPCResponse() + + A2AMethod.SendMessage.value -> + requestHandler.onSendMessage(request.toRequest(), ctx).toJSONRPCResponse() + + A2AMethod.GetTask.value -> + requestHandler.onGetTask(request.toRequest(), ctx).toJSONRPCResponse() + + A2AMethod.CancelTask.value -> + requestHandler.onCancelTask(request.toRequest(), ctx).toJSONRPCResponse() + + A2AMethod.SetTaskPushNotificationConfig.value -> + requestHandler.onSetTaskPushNotificationConfig(request.toRequest(), ctx).toJSONRPCResponse() + + A2AMethod.GetTaskPushNotificationConfig.value -> + requestHandler.onGetTaskPushNotificationConfig(request.toRequest(), ctx).toJSONRPCResponse() + + A2AMethod.ListTaskPushNotificationConfig.value -> + requestHandler.onListTaskPushNotificationConfig(request.toRequest(), ctx).toJSONRPCResponse() + + A2AMethod.DeleteTaskPushNotificationConfig.value -> + requestHandler.onDeleteTaskPushNotificationConfig(request.toRequest(), ctx).toJSONRPCResponse() + + else -> throw MethodNotFoundException(request.method) + } + } + + /** + * Handles a JSON-RPC request and returns the corresponding response stream. + * + * @throws A2AException if there's an error processing the request. + */ + public suspend fun onRequestStreaming( + request: JSONRPCRequest, + ctx: ServerCallContext, + ): Flow { + return when (request.method) { + A2AMethod.SendMessageStreaming.value -> + requestHandler.onSendMessageStreaming(request.toRequest(), ctx).map { it.toJSONRPCResponse() } + + else -> throw MethodNotFoundException(request.method) + } + } + + /** + * Convert generic [JSONRPCRequest] to [Request]. + * + * @throws InvalidParamsException if request params cannot be parsed to [T]. + */ + protected inline fun JSONRPCRequest.toRequest(): Request { + val data = try { + JSONRPCJson.decodeFromJsonElement(params) + } catch (_: SerializationException) { + throw InvalidParamsException("Cannot parse request params to ${T::class}") + } + + return Request( + id = id, + data = data + ) + } + + /** + * Convert generic [Response] to [JSONRPCSuccessResponse]. + */ + protected inline fun Response.toJSONRPCResponse(): JSONRPCSuccessResponse { + return JSONRPCSuccessResponse( + id = id, + result = JSONRPCJson.encodeToJsonElement(data) + ) + } +} diff --git a/a2a/a2a-transport/a2a-transport-core-jsonrpc/src/commonMain/kotlin/ai/koog/a2a/transport/jsonrpc/JSONRPCClientTransport.kt b/a2a/a2a-transport/a2a-transport-core-jsonrpc/src/commonMain/kotlin/ai/koog/a2a/transport/jsonrpc/JSONRPCClientTransport.kt new file mode 100644 index 0000000000..3fb00feb2f --- /dev/null +++ b/a2a/a2a-transport/a2a-transport-core-jsonrpc/src/commonMain/kotlin/ai/koog/a2a/transport/jsonrpc/JSONRPCClientTransport.kt @@ -0,0 +1,168 @@ +package ai.koog.a2a.transport.jsonrpc + +import ai.koog.a2a.exceptions.A2AException +import ai.koog.a2a.exceptions.createA2AException +import ai.koog.a2a.model.AgentCard +import ai.koog.a2a.model.CommunicationEvent +import ai.koog.a2a.model.MessageSendParams +import ai.koog.a2a.model.Task +import ai.koog.a2a.model.TaskIdParams +import ai.koog.a2a.model.TaskPushNotificationConfig +import ai.koog.a2a.model.TaskPushNotificationConfigParams +import ai.koog.a2a.model.TaskQueryParams +import ai.koog.a2a.model.UpdateEvent +import ai.koog.a2a.transport.ClientCallContext +import ai.koog.a2a.transport.ClientTransport +import ai.koog.a2a.transport.Request +import ai.koog.a2a.transport.Response +import ai.koog.a2a.transport.jsonrpc.model.JSONRPCError +import ai.koog.a2a.transport.jsonrpc.model.JSONRPCErrorResponse +import ai.koog.a2a.transport.jsonrpc.model.JSONRPCJson +import ai.koog.a2a.transport.jsonrpc.model.JSONRPCRequest +import ai.koog.a2a.transport.jsonrpc.model.JSONRPCResponse +import ai.koog.a2a.transport.jsonrpc.model.JSONRPCSuccessResponse +import kotlinx.coroutines.flow.Flow +import kotlinx.coroutines.flow.map +import kotlinx.serialization.json.decodeFromJsonElement +import kotlinx.serialization.json.encodeToJsonElement + +/** + * Abstract transport implementation for JSON-RPC-based client communication. + * Handles sending JSON-RPC requests, processing responses, and mapping them to expected types. + */ +public abstract class JSONRPCClientTransport : ClientTransport { + /** + * Sends a JSON-RPC request and returns the corresponding response. + * + * @throws A2AException if server returned an error. + */ + public abstract suspend fun request( + request: JSONRPCRequest, + ctx: ClientCallContext, + ): JSONRPCResponse + + /** + * Sends a JSON-RPC request and returns the corresponding response stream. + * + * @throws A2AException if server returned an error. + */ + public abstract suspend fun requestStreaming( + request: JSONRPCRequest, + ctx: ClientCallContext, + ): Flow + + /** + * Convert generic [Request] to [JSONRPCRequest]. + */ + protected inline fun Request.toJSONRPCRequest(method: A2AMethod): JSONRPCRequest { + return JSONRPCRequest( + id = id, + method = method.value, + params = JSONRPCJson.encodeToJsonElement(data) + ) + } + + /** + * Convert [JSONRPCResponse] to generic [Response]. + * + * @throws A2AException if server returned an error. + */ + protected inline fun JSONRPCResponse.toResponse(): Response { + return when (this) { + is JSONRPCSuccessResponse -> Response( + id = id, + data = JSONRPCJson.decodeFromJsonElement(result), + ) + + is JSONRPCErrorResponse -> { + throw error.toA2AException() + } + } + } + + protected fun JSONRPCError.toA2AException(): A2AException { + return createA2AException(message, code) + } + + /** + * Generic request processing. + */ + protected suspend inline fun request( + method: A2AMethod, + request: Request, + ctx: ClientCallContext + ): Response { + val jsonrpcRequest = request.toJSONRPCRequest(method) + val jsonrpcResponse = request(jsonrpcRequest, ctx) + + return jsonrpcResponse.toResponse() + } + + /** + * Generic streaming request processing. + */ + protected suspend inline fun requestStreaming( + method: A2AMethod, + request: Request, + ctx: ClientCallContext + ): Flow> { + val jsonrpcRequest = request.toJSONRPCRequest(method) + val jsonrpcResponse = requestStreaming(jsonrpcRequest, ctx) + + return jsonrpcResponse.map { it.toResponse() } + } + + override suspend fun getAuthenticatedExtendedAgentCard( + request: Request, + ctx: ClientCallContext + ): Response = + request(A2AMethod.GetAuthenticatedExtendedAgentCard, request, ctx) + + override suspend fun sendMessage( + request: Request, + ctx: ClientCallContext + ): Response = + request(A2AMethod.SendMessage, request, ctx) + + override suspend fun sendMessageStreaming( + request: Request, + ctx: ClientCallContext + ): Flow> = + requestStreaming(A2AMethod.SendMessageStreaming, request, ctx) + + override suspend fun getTask( + request: Request, + ctx: ClientCallContext + ): Response = + request(A2AMethod.GetTask, request, ctx) + + override suspend fun cancelTask( + request: Request, + ctx: ClientCallContext + ): Response = + request(A2AMethod.CancelTask, request, ctx) + + override suspend fun setTaskPushNotificationConfig( + request: Request, + ctx: ClientCallContext + ): Response = + request(A2AMethod.SetTaskPushNotificationConfig, request, ctx) + + override suspend fun getTaskPushNotificationConfig( + request: Request, + ctx: ClientCallContext + ): Response = + request(A2AMethod.GetTaskPushNotificationConfig, request, ctx) + + override suspend fun listTaskPushNotificationConfig( + request: Request, + ctx: ClientCallContext + ): Response> = + request(A2AMethod.ListTaskPushNotificationConfig, request, ctx) + + override suspend fun deleteTaskPushNotificationConfig( + request: Request, + ctx: ClientCallContext + ): Response = + request(A2AMethod.DeleteTaskPushNotificationConfig, request, ctx) +} diff --git a/a2a/a2a-transport/a2a-transport-core-jsonrpc/src/commonMain/kotlin/ai/koog/a2a/transport/jsonrpc/model/Messages.kt b/a2a/a2a-transport/a2a-transport-core-jsonrpc/src/commonMain/kotlin/ai/koog/a2a/transport/jsonrpc/model/Messages.kt index 68c29c4b93..3f7707b3ad 100644 --- a/a2a/a2a-transport/a2a-transport-core-jsonrpc/src/commonMain/kotlin/ai/koog/a2a/transport/jsonrpc/model/Messages.kt +++ b/a2a/a2a-transport/a2a-transport-core-jsonrpc/src/commonMain/kotlin/ai/koog/a2a/transport/jsonrpc/model/Messages.kt @@ -6,6 +6,7 @@ import ai.koog.a2a.transport.RequestId import kotlinx.serialization.EncodeDefault import kotlinx.serialization.Serializable import kotlinx.serialization.json.JsonElement +import kotlinx.serialization.json.JsonNull /** * Default JSON-RPC version. @@ -24,7 +25,7 @@ public sealed interface JSONRPCResponse : JSONRPCMessage public data class JSONRPCRequest( public val id: RequestId, val method: String, - val params: JsonElement?, + val params: JsonElement = JsonNull, @EncodeDefault override val jsonrpc: String = JSONRPC_VERSION, ) : JSONRPCMessage @@ -32,7 +33,7 @@ public data class JSONRPCRequest( @Serializable public data class JSONRPCNotification( val method: String, - val params: JsonElement?, + val params: JsonElement = JsonNull, @EncodeDefault override val jsonrpc: String = JSONRPC_VERSION, ) : JSONRPCMessage @@ -49,7 +50,7 @@ public data class JSONRPCSuccessResponse( public data class JSONRPCError( val code: Int, val message: String, - val data: JsonElement? = null, + val data: JsonElement = JsonNull, ) @Serializable diff --git a/a2a/a2a-transport/a2a-transport-core-jsonrpc/src/commonTest/kotlin/ai/koog/a2a/transport/jsonrpc/model/JsonRpcSerializationTest.kt b/a2a/a2a-transport/a2a-transport-core-jsonrpc/src/commonTest/kotlin/ai/koog/a2a/transport/jsonrpc/model/JsonRpcSerializationTest.kt index 9d17c3e19c..268432d441 100644 --- a/a2a/a2a-transport/a2a-transport-core-jsonrpc/src/commonTest/kotlin/ai/koog/a2a/transport/jsonrpc/model/JsonRpcSerializationTest.kt +++ b/a2a/a2a-transport/a2a-transport-core-jsonrpc/src/commonTest/kotlin/ai/koog/a2a/transport/jsonrpc/model/JsonRpcSerializationTest.kt @@ -1,6 +1,6 @@ package ai.koog.a2a.transport.jsonrpc.model -import ai.koog.a2a.model.RequestId +import ai.koog.a2a.transport.RequestId import kotlinx.serialization.json.JsonPrimitive import kotlin.test.Test import kotlin.test.assertEquals @@ -24,7 +24,6 @@ class JsonRpcSerializationTest { val request: JSONRPCMessage = JSONRPCRequest( id = RequestId.NumberId(42), method = "add", - params = null ) //language=JSON @@ -58,7 +57,6 @@ class JsonRpcSerializationTest { fun testJSONRPCNotificationWithoutParams() { val request: JSONRPCMessage = JSONRPCNotification( method = "notify", - params = null ) //language=JSON From e044eae2449168f0b4eb8fcf2c6f8465ea1bb6a2 Mon Sep 17 00:00:00 2001 From: Andrey Bragin Date: Sun, 31 Aug 2025 23:03:19 +0200 Subject: [PATCH 19/52] [a2a] Implement HTTP JSON-RPC client transport with tests --- .../ai/koog/a2a/exceptions/Exceptions.kt | 54 +-- .../ai/koog/a2a/transport/ClientTransport.kt | 4 +- .../ai/koog/a2a/transport/ServerTransport.kt | 8 +- .../kotlin/ai/koog/a2a/utils/ResultUtils.kt | 18 + .../kotlin/ai/koog/a2a/server/A2AServer.kt | 2 +- .../build.gradle.kts | 6 + .../transport/client/jsonrpc/http/.gitkeep | 0 .../http/HttpJSONRPCClientTransport.kt | 101 ++++ .../http/HttpJSONRPCClientTransportTest.kt | 440 ++++++++++++++++++ .../jsonrpc/JSONRCPServerTransport.kt | 106 ----- .../jsonrpc/JSONRPCClientTransport.kt | 12 +- .../jsonrpc/JSONRPCServerTransport.kt | 140 ++++++ .../build.gradle.kts | 10 + .../transport/server/jsonrpc/http/.gitkeep | 0 .../http/HttpJSONRPCServerTransport.kt | 108 +++++ gradle/libs.versions.toml | 1 + 16 files changed, 862 insertions(+), 148 deletions(-) create mode 100644 a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/utils/ResultUtils.kt delete mode 100644 a2a/a2a-transport/a2a-transport-client-jsonrpc-http/src/commonMain/kotlin/ai/koog/a2a/transport/client/jsonrpc/http/.gitkeep create mode 100644 a2a/a2a-transport/a2a-transport-client-jsonrpc-http/src/commonMain/kotlin/ai/koog/a2a/transport/client/jsonrpc/http/HttpJSONRPCClientTransport.kt create mode 100644 a2a/a2a-transport/a2a-transport-client-jsonrpc-http/src/commonTest/kotlin/ai/koog/a2a/transport/client/jsonrpc/http/HttpJSONRPCClientTransportTest.kt delete mode 100644 a2a/a2a-transport/a2a-transport-core-jsonrpc/src/commonMain/kotlin/ai/koog/a2a/transport/jsonrpc/JSONRCPServerTransport.kt create mode 100644 a2a/a2a-transport/a2a-transport-core-jsonrpc/src/commonMain/kotlin/ai/koog/a2a/transport/jsonrpc/JSONRPCServerTransport.kt delete mode 100644 a2a/a2a-transport/a2a-transport-server-jsonrpc-http/src/commonMain/kotlin/ai/koog/a2a/transport/server/jsonrpc/http/.gitkeep create mode 100644 a2a/a2a-transport/a2a-transport-server-jsonrpc-http/src/jvmMain/kotlin/ai/koog/a2a/transport/server/jsonrpc/http/HttpJSONRPCServerTransport.kt diff --git a/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/exceptions/Exceptions.kt b/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/exceptions/Exceptions.kt index 40103ea52f..40aec678a7 100644 --- a/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/exceptions/Exceptions.kt +++ b/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/exceptions/Exceptions.kt @@ -22,42 +22,42 @@ public enum class A2AErrorCode(public val value: Int) { * Base class for all A2A exceptions. */ public sealed class A2AException( - message: String, + public override val message: String, public val errorCode: Int ) : Exception(message) /** * Server received JSON that was not well-formed. */ -public class ParseException( +public class A2AParseException( message: String = "Invalid JSON payload", ) : A2AException(message, errorCode = -32700) /** * The JSON payload was valid JSON, but not a valid JSON-RPC Request object. */ -public class InvalidRequestException( +public class A2AInvalidRequestException( message: String = "Invalid JSON-RPC Request", ) : A2AException(message, errorCode = -32600) /** * The requested A2A RPC method does not exist or is not supported. */ -public class MethodNotFoundException( +public class A2AMethodNotFoundException( message: String = "Method not found", ) : A2AException(message, errorCode = -32601) /** * The params provided for the method are invalid. */ -public class InvalidParamsException( +public class A2AInvalidParamsException( message: String = "Invalid method parameters", ) : A2AException(message, errorCode = -32602) /** * An unexpected error occurred on the server during processing. */ -public class InternalErrorException( +public class A2AInternalErrorException( message: String = "Internal server error", ) : A2AException(message, errorCode = -32603) @@ -77,7 +77,7 @@ public sealed class A2AServerException( * The specified task id does not correspond to an existing or active task. * It might be invalid, expired, or already completed and purged. */ -public class TaskNotFoundException( +public class A2ATaskNotFoundException( message: String = "Task not found", ) : A2AServerException(message, errorCode = -32001) @@ -85,7 +85,7 @@ public class TaskNotFoundException( * An attempt was made to cancel a task that is not in a cancelable state. * The task has already reached a terminal state like completed, failed, or canceled. */ -public class TaskNotCancelableException( +public class A2ATaskNotCancelableException( message: String = "Task cannot be canceled", ) : A2AServerException(message, errorCode = -32002) @@ -93,7 +93,7 @@ public class TaskNotCancelableException( * Client attempted to use push notification features but the server agent does not support them. * The server's AgentCard.capabilities.pushNotifications is false. */ -public class PushNotificationNotSupportedException( +public class A2APushNotificationNotSupportedException( message: String = "Push Notification is not supported", ) : A2AServerException(message, errorCode = -32003) @@ -101,7 +101,7 @@ public class PushNotificationNotSupportedException( * The requested operation or a specific aspect of it is not supported by this server agent implementation. * This is broader than just method not found. */ -public class UnsupportedOperationException( +public class A2AUnsupportedOperationException( message: String = "This operation is not supported", ) : A2AServerException(message, errorCode = -32004) @@ -109,28 +109,28 @@ public class UnsupportedOperationException( * A Media Type provided in the request's message.parts or implied for an artifact is not supported * by the agent or the specific skill being invoked. */ -public class ContentTypeNotSupportedException( +public class A2AContentTypeNotSupportedException( message: String = "Incompatible content types", ) : A2AServerException(message, errorCode = -32005) /** * Agent generated an invalid response for the requested method. */ -public class InvalidAgentResponseException( +public class A2AInvalidAgentResponseException( message: String = "Invalid agent response type", ) : A2AServerException(message, errorCode = -32006) /** * The agent does not have an Authenticated Extended Card configured. */ -public class AuthenticatedExtendedCardNotConfiguredException( +public class A2AAuthenticatedExtendedCardNotConfiguredException( message: String = "Authenticated Extended Card not configured", ) : A2AServerException(message, errorCode = -32007) /** * Server returned some unknown error code. */ -public class UnknownA2AException( +public class A2AUnknownException( message: String, errorCode: Int, ) : A2AException(message, errorCode) @@ -143,18 +143,18 @@ public fun createA2AException( errorCode: Int, ): A2AException { return when (errorCode) { - A2AErrorCode.PARSE_ERROR.value -> ParseException(message) - A2AErrorCode.INVALID_REQUEST.value -> InvalidRequestException(message) - A2AErrorCode.METHOD_NOT_FOUND.value -> MethodNotFoundException(message) - A2AErrorCode.INVALID_PARAMS.value -> InvalidParamsException(message) - A2AErrorCode.INTERNAL_ERROR.value -> InternalErrorException(message) - A2AErrorCode.TASK_NOT_FOUND.value -> TaskNotFoundException(message) - A2AErrorCode.TASK_NOT_CANCELABLE.value -> TaskNotCancelableException(message) - A2AErrorCode.PUSH_NOTIFICATION_NOT_SUPPORTED.value -> PushNotificationNotSupportedException(message) - A2AErrorCode.UNSUPPORTED_OPERATION.value -> UnsupportedOperationException(message) - A2AErrorCode.CONTENT_TYPE_NOT_SUPPORTED.value -> ContentTypeNotSupportedException(message) - A2AErrorCode.INVALID_AGENT_RESPONSE.value -> InvalidAgentResponseException(message) - A2AErrorCode.AUTHENTICATED_EXTENDED_CARD_NOT_CONFIGURED.value -> AuthenticatedExtendedCardNotConfiguredException(message) - else -> UnknownA2AException(message, errorCode) + A2AErrorCode.PARSE_ERROR.value -> A2AParseException(message) + A2AErrorCode.INVALID_REQUEST.value -> A2AInvalidRequestException(message) + A2AErrorCode.METHOD_NOT_FOUND.value -> A2AMethodNotFoundException(message) + A2AErrorCode.INVALID_PARAMS.value -> A2AInvalidParamsException(message) + A2AErrorCode.INTERNAL_ERROR.value -> A2AInternalErrorException(message) + A2AErrorCode.TASK_NOT_FOUND.value -> A2ATaskNotFoundException(message) + A2AErrorCode.TASK_NOT_CANCELABLE.value -> A2ATaskNotCancelableException(message) + A2AErrorCode.PUSH_NOTIFICATION_NOT_SUPPORTED.value -> A2APushNotificationNotSupportedException(message) + A2AErrorCode.UNSUPPORTED_OPERATION.value -> A2AUnsupportedOperationException(message) + A2AErrorCode.CONTENT_TYPE_NOT_SUPPORTED.value -> A2AContentTypeNotSupportedException(message) + A2AErrorCode.INVALID_AGENT_RESPONSE.value -> A2AInvalidAgentResponseException(message) + A2AErrorCode.AUTHENTICATED_EXTENDED_CARD_NOT_CONFIGURED.value -> A2AAuthenticatedExtendedCardNotConfiguredException(message) + else -> A2AUnknownException(message, errorCode) } } diff --git a/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/transport/ClientTransport.kt b/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/transport/ClientTransport.kt index 96a110edbb..339346be60 100644 --- a/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/transport/ClientTransport.kt +++ b/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/transport/ClientTransport.kt @@ -50,7 +50,7 @@ public interface ClientTransport : AutoCloseable { * * @throws A2AException if server returned an error. */ - public suspend fun sendMessageStreaming( + public fun sendMessageStreaming( request: Request, ctx: ClientCallContext = ClientCallContext.Default ): Flow> @@ -122,7 +122,7 @@ public interface ClientTransport : AutoCloseable { * @property additionalHeaders Additional call-specific headers associated with the call. */ public class ClientCallContext( - public val additionalHeaders: Map = emptyMap(), + public val additionalHeaders: Map> = emptyMap(), ) { @Suppress("MissingKDocForPublicAPI") public companion object { diff --git a/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/transport/ServerTransport.kt b/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/transport/ServerTransport.kt index 10a8f21ddb..0e31272b33 100644 --- a/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/transport/ServerTransport.kt +++ b/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/transport/ServerTransport.kt @@ -1,7 +1,7 @@ package ai.koog.a2a.transport import ai.koog.a2a.exceptions.A2AException -import ai.koog.a2a.exceptions.InternalErrorException +import ai.koog.a2a.exceptions.A2AInternalErrorException import ai.koog.a2a.model.AgentCard import ai.koog.a2a.model.CommunicationEvent import ai.koog.a2a.model.MessageSendParams @@ -19,7 +19,7 @@ import kotlinx.coroutines.flow.Flow * * Server transport must respond with appropriate [A2AException] in case of errors while processing the request * (e.g. method not found or invalid method parameters). It must also handle [A2AException] thrown by the [RequestHandler] methods. - * In case non [A2AException] is thrown, it must be converted to [InternalErrorException] with appropriate message. + * In case non [A2AException] is thrown, it must be converted to [A2AInternalErrorException] with appropriate message. * * Server transport must convert [A2AException] to appropriate response data format (e.g. JSON error object), * preserving the [A2AException.errorCode] so that it can be properly handled by the [ClientTransport]. @@ -61,7 +61,7 @@ public interface RequestHandler { * * @throws A2AException if there is an error with processsing the request. */ - public suspend fun onSendMessageStreaming( + public fun onSendMessageStreaming( request: Request, ctx: ServerCallContext ): Flow> @@ -133,7 +133,7 @@ public interface RequestHandler { * @property headers Headers associated with the call. */ public class ServerCallContext( - public val headers: Map = emptyMap(), + public val headers: Map> = emptyMap(), ) { @Suppress("MissingKDocForPublicAPI") public companion object { diff --git a/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/utils/ResultUtils.kt b/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/utils/ResultUtils.kt new file mode 100644 index 0000000000..cdf5d21dd7 --- /dev/null +++ b/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/utils/ResultUtils.kt @@ -0,0 +1,18 @@ +package ai.koog.a2a.utils + +import kotlinx.coroutines.CancellationException + +// FIXME copied from agents-core module, because a2a does not depend on other Koog modules. +// Do we want to make a global utils module for cases like this? +/** + * Same as [runCatching], but does not catch [CancellationException], throwing it instead, making it safe to use with coroutines. + */ +public inline fun runCatchingCancellable(block: () -> R): Result { + return try { + Result.success(block()) + } catch (ce: CancellationException) { + throw ce + } catch (e: Exception) { + Result.failure(e) + } +} diff --git a/a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/A2AServer.kt b/a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/A2AServer.kt index 067dc36c21..6e94c5a391 100644 --- a/a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/A2AServer.kt +++ b/a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/A2AServer.kt @@ -33,7 +33,7 @@ public class A2AServer : RequestHandler { TODO("Not yet implemented") } - override suspend fun onSendMessageStreaming( + override fun onSendMessageStreaming( request: Request, ctx: ServerCallContext ): Flow> { diff --git a/a2a/a2a-transport/a2a-transport-client-jsonrpc-http/build.gradle.kts b/a2a/a2a-transport/a2a-transport-client-jsonrpc-http/build.gradle.kts index 36e5aa832a..89286cc1c0 100644 --- a/a2a/a2a-transport/a2a-transport-client-jsonrpc-http/build.gradle.kts +++ b/a2a/a2a-transport/a2a-transport-client-jsonrpc-http/build.gradle.kts @@ -18,14 +18,20 @@ kotlin { sourceSets { commonMain { dependencies { + api(project(":a2a:a2a-transport:a2a-transport-core-jsonrpc")) api(libs.kotlinx.serialization.json) api(libs.kotlinx.coroutines.core) + api(libs.ktor.client.core) + api(libs.ktor.client.content.negotiation) + api(libs.ktor.serialization.kotlinx.json) } } commonTest { dependencies { implementation(kotlin("test")) + implementation(libs.kotlinx.coroutines.test) + implementation(libs.ktor.client.mock) } } diff --git a/a2a/a2a-transport/a2a-transport-client-jsonrpc-http/src/commonMain/kotlin/ai/koog/a2a/transport/client/jsonrpc/http/.gitkeep b/a2a/a2a-transport/a2a-transport-client-jsonrpc-http/src/commonMain/kotlin/ai/koog/a2a/transport/client/jsonrpc/http/.gitkeep deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/a2a/a2a-transport/a2a-transport-client-jsonrpc-http/src/commonMain/kotlin/ai/koog/a2a/transport/client/jsonrpc/http/HttpJSONRPCClientTransport.kt b/a2a/a2a-transport/a2a-transport-client-jsonrpc-http/src/commonMain/kotlin/ai/koog/a2a/transport/client/jsonrpc/http/HttpJSONRPCClientTransport.kt new file mode 100644 index 0000000000..77856488f7 --- /dev/null +++ b/a2a/a2a-transport/a2a-transport-client-jsonrpc-http/src/commonMain/kotlin/ai/koog/a2a/transport/client/jsonrpc/http/HttpJSONRPCClientTransport.kt @@ -0,0 +1,101 @@ +package ai.koog.a2a.transport.client.jsonrpc.http + +import ai.koog.a2a.transport.ClientCallContext +import ai.koog.a2a.transport.jsonrpc.JSONRPCClientTransport +import ai.koog.a2a.transport.jsonrpc.model.JSONRPCJson +import ai.koog.a2a.transport.jsonrpc.model.JSONRPCRequest +import ai.koog.a2a.transport.jsonrpc.model.JSONRPCResponse +import io.ktor.client.HttpClient +import io.ktor.client.call.body +import io.ktor.client.plugins.contentnegotiation.ContentNegotiation +import io.ktor.client.plugins.defaultRequest +import io.ktor.client.plugins.sse.SSE +import io.ktor.client.plugins.sse.sse +import io.ktor.client.request.accept +import io.ktor.client.request.headers +import io.ktor.client.request.post +import io.ktor.client.request.setBody +import io.ktor.http.ContentType +import io.ktor.http.HttpMethod +import io.ktor.http.contentType +import io.ktor.serialization.kotlinx.json.json +import kotlinx.coroutines.flow.Flow +import kotlinx.coroutines.flow.flow + +/** + * Implementation of a JSON-RPC client transport using HTTP as the underlying communication protocol. + * + * This transport sends JSON-RPC requests over HTTP and processes the responses. It also supports + * both standard requests and Server-Sent Events (SSE) for streaming responses. + * + * @param url The URL of the JSON-RPC server endpoint. + * @param baseHttpClient The base [HttpClient] instance, which will be configured internally. + */ +public class HttpJSONRPCClientTransport( + url: String, + baseHttpClient: HttpClient +) : JSONRPCClientTransport() { + private val httpClient: HttpClient = baseHttpClient.config { + defaultRequest { + url(url) + contentType(ContentType.Application.Json) + } + + install(ContentNegotiation) { + json(JSONRPCJson) + } + + install(SSE) + + expectSuccess = true + } + + override suspend fun request( + request: JSONRPCRequest, + ctx: ClientCallContext + ): JSONRPCResponse { + val response = httpClient.post { + headers { + ctx.additionalHeaders.forEach { (key, values) -> + appendAll(key, values) + } + } + + setBody(request) + } + + return response.body() + } + + override fun requestStreaming( + request: JSONRPCRequest, + ctx: ClientCallContext + ): Flow = flow { + httpClient.sse( + request = { + method = HttpMethod.Post + accept(ContentType.Text.EventStream) + + headers { + ctx.additionalHeaders.forEach { (key, values) -> + appendAll(key, values) + } + } + + setBody(request) + } + ) { + incoming.collect { event -> + requireNotNull(event.data) { "SSE data must not be null" } + .let { data -> + val response = JSONRPCJson.decodeFromString(data) + emit(response) + } + } + } + } + + override fun close() { + httpClient.close() + } +} diff --git a/a2a/a2a-transport/a2a-transport-client-jsonrpc-http/src/commonTest/kotlin/ai/koog/a2a/transport/client/jsonrpc/http/HttpJSONRPCClientTransportTest.kt b/a2a/a2a-transport/a2a-transport-client-jsonrpc-http/src/commonTest/kotlin/ai/koog/a2a/transport/client/jsonrpc/http/HttpJSONRPCClientTransportTest.kt new file mode 100644 index 0000000000..8018111014 --- /dev/null +++ b/a2a/a2a-transport/a2a-transport-client-jsonrpc-http/src/commonTest/kotlin/ai/koog/a2a/transport/client/jsonrpc/http/HttpJSONRPCClientTransportTest.kt @@ -0,0 +1,440 @@ +package ai.koog.a2a.transport.client.jsonrpc.http + +import ai.koog.a2a.exceptions.A2AErrorCode +import ai.koog.a2a.exceptions.A2AInvalidParamsException +import ai.koog.a2a.model.AgentCapabilities +import ai.koog.a2a.model.AgentCard +import ai.koog.a2a.model.AgentSkill +import ai.koog.a2a.model.CommunicationEvent +import ai.koog.a2a.model.Message +import ai.koog.a2a.model.MessageSendParams +import ai.koog.a2a.model.PushNotificationConfig +import ai.koog.a2a.model.Role +import ai.koog.a2a.model.Task +import ai.koog.a2a.model.TaskIdParams +import ai.koog.a2a.model.TaskPushNotificationConfig +import ai.koog.a2a.model.TaskPushNotificationConfigParams +import ai.koog.a2a.model.TaskQueryParams +import ai.koog.a2a.model.TaskState +import ai.koog.a2a.model.TaskStatus +import ai.koog.a2a.model.TextPart +import ai.koog.a2a.transport.ClientTransport +import ai.koog.a2a.transport.Request +import ai.koog.a2a.transport.RequestId +import ai.koog.a2a.transport.Response +import ai.koog.a2a.transport.jsonrpc.A2AMethod +import ai.koog.a2a.transport.jsonrpc.model.JSONRPCError +import ai.koog.a2a.transport.jsonrpc.model.JSONRPCErrorResponse +import ai.koog.a2a.transport.jsonrpc.model.JSONRPCJson +import ai.koog.a2a.transport.jsonrpc.model.JSONRPCRequest +import ai.koog.a2a.transport.jsonrpc.model.JSONRPCSuccessResponse +import io.ktor.client.HttpClient +import io.ktor.client.engine.mock.MockEngine +import io.ktor.client.engine.mock.respond +import io.ktor.http.ContentType +import io.ktor.http.HttpHeaders +import io.ktor.http.HttpMethod +import io.ktor.http.content.TextContent +import io.ktor.http.headersOf +import kotlinx.coroutines.test.runTest +import kotlinx.serialization.json.decodeFromJsonElement +import kotlinx.serialization.json.encodeToJsonElement +import kotlin.test.Ignore +import kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.fail + +class HttpJSONRPCClientTransportTest { + + private val json = JSONRPCJson + + private suspend inline fun testAPIMethod( + method: A2AMethod, + request: Request, + expectedResponse: Response, + noinline invoke: suspend ClientTransport.(Request) -> Response, + ) { + val mockEngine = MockEngine { receivedRequest -> + assertEquals(HttpMethod.Post, receivedRequest.method) + assertEquals(ContentType.Application.Json, receivedRequest.body.contentType) + + val requestBodyText = (receivedRequest.body as TextContent).text + val jsonRpcRequest = json.decodeFromString(requestBodyText) + + assertEquals(method.value, jsonRpcRequest.method) + assertEquals(request.id, jsonRpcRequest.id) + assertEquals(request.data, json.decodeFromJsonElement(jsonRpcRequest.params)) + + val jsonRpcResponse = JSONRPCSuccessResponse( + id = expectedResponse.id, + result = json.encodeToJsonElement(expectedResponse.data) + ) + + respond( + content = json.encodeToString(jsonRpcResponse), + headers = headersOf(HttpHeaders.ContentType, ContentType.Application.Json.toString()), + ) + } + + val httpClient = HttpClient(mockEngine) + val transport = HttpJSONRPCClientTransport("https://api.example.com/a2a", httpClient) + + val actualResponse = transport.invoke(request) + + assertEquals(expectedResponse.id, actualResponse.id) + assertEquals(expectedResponse.data, actualResponse.data) + + transport.close() + } + + @Test + fun testGetAuthenticatedExtendedAgentCard() = runTest { + val id = RequestId.StringId("test-1") + + val request = Request( + id = id, + data = null, + ) + + val expectedResponse = Response( + id = id, + data = AgentCard( + name = "Test Agent", + description = "A test agent", + url = "https://api.example.com/a2a", + version = "1.0.0", + capabilities = AgentCapabilities(), + defaultInputModes = listOf("text/plain"), + defaultOutputModes = listOf("text/plain"), + skills = listOf( + AgentSkill( + id = "test-skill", + name = "Test Skill", + description = "A test skill", + tags = listOf("test") + ) + ) + ) + ) + + testAPIMethod( + method = A2AMethod.GetAuthenticatedExtendedAgentCard, + request = request, + expectedResponse = expectedResponse, + invoke = { getAuthenticatedExtendedAgentCard(it) } + ) + } + + @Test + fun testSendMessage() = runTest { + val id = RequestId.StringId("test-2") + + val testMessage = Message( + role = Role.User, + parts = listOf(TextPart("Hello, agent!")), + taskId = "task-123" + ) + + val messageSendParams = MessageSendParams( + message = testMessage + ) + + val request = Request( + id = id, + data = messageSendParams, + ) + + val expectedResponse: Response = Response( + id = id, + data = Message( + messageId = "msg-456", + role = Role.Agent, + parts = listOf(TextPart("Hello, user! How can I help you?")), + taskId = "task-123" + ) + ) + + testAPIMethod( + method = A2AMethod.SendMessage, + request = request, + expectedResponse = expectedResponse, + invoke = { sendMessage(it) } + ) + } + + @Ignore + @Test + fun testSendMessageStreaming() = runTest { + // FIXME Can't test it, MockEngine doesn't support SSE capability + } + + @Test + fun testGetTask() = runTest { + val id = RequestId.StringId("test-3") + + val taskQueryParams = TaskQueryParams( + id = "task-123", + historyLength = 10 + ) + + val request = Request( + id = id, + data = taskQueryParams, + ) + + val expectedResponse = Response( + id = id, + data = Task( + id = "task-123", + contextId = "context-456", + status = TaskStatus( + state = TaskState.Working, + message = Message( + role = Role.Agent, + parts = listOf(TextPart("Working on your request...")) + ) + ), + history = listOf( + Message( + role = Role.User, + parts = listOf(TextPart("Hello, agent!")), + taskId = "task-123" + ) + ) + ) + ) + + testAPIMethod( + method = A2AMethod.GetTask, + request = request, + expectedResponse = expectedResponse, + invoke = { getTask(it) } + ) + } + + @Test + fun testCancelTask() = runTest { + val id = RequestId.StringId("test-4") + + val taskIdParams = TaskIdParams(id = "task-123") + + val request = Request( + id = id, + data = taskIdParams, + ) + + val expectedResponse = Response( + id = id, + data = Task( + id = "task-123", + contextId = "context-456", + status = TaskStatus( + state = TaskState.Canceled, + message = Message( + role = Role.Agent, + parts = listOf(TextPart("Task has been canceled.")) + ) + ) + ) + ) + + testAPIMethod( + method = A2AMethod.CancelTask, + request = request, + expectedResponse = expectedResponse, + invoke = { cancelTask(it) } + ) + } + + @Test + fun testSetTaskPushNotificationConfig() = runTest { + val id = RequestId.StringId("test-5") + + val pushNotificationConfig = TaskPushNotificationConfig( + taskId = "task-123", + pushNotificationConfig = PushNotificationConfig( + id = "notification-config-1", + url = "https://webhook.example.com/notifications", + token = "webhook-token-123" + ) + ) + + val request = Request( + id = id, + data = pushNotificationConfig, + ) + + val expectedResponse = Response( + id = id, + data = pushNotificationConfig + ) + + testAPIMethod( + method = A2AMethod.SetTaskPushNotificationConfig, + request = request, + expectedResponse = expectedResponse, + invoke = { setTaskPushNotificationConfig(it) } + ) + } + + @Test + fun testGetTaskPushNotificationConfig() = runTest { + val id = RequestId.StringId("test-6") + + val configParams = TaskPushNotificationConfigParams( + id = "task-123", + pushNotificationConfigId = "notification-config-1" + ) + + val request = Request( + id = id, + data = configParams, + ) + + val expectedResponse = Response( + id = id, + data = TaskPushNotificationConfig( + taskId = "task-123", + pushNotificationConfig = PushNotificationConfig( + id = "notification-config-1", + url = "https://webhook.example.com/notifications", + token = "webhook-token-123" + ) + ) + ) + + testAPIMethod( + method = A2AMethod.GetTaskPushNotificationConfig, + request = request, + expectedResponse = expectedResponse, + invoke = { getTaskPushNotificationConfig(it) } + ) + } + + @Test + fun testListTaskPushNotificationConfig() = runTest { + val id = RequestId.StringId("test-7") + + val taskIdParams = TaskIdParams(id = "task-123") + + val request = Request( + id = id, + data = taskIdParams, + ) + + val expectedResponse = Response( + id = id, + data = listOf( + TaskPushNotificationConfig( + taskId = "task-123", + pushNotificationConfig = PushNotificationConfig( + id = "notification-config-1", + url = "https://webhook.example.com/notifications", + token = "webhook-token-123" + ) + ), + TaskPushNotificationConfig( + taskId = "task-123", + pushNotificationConfig = PushNotificationConfig( + id = "notification-config-2", + url = "https://webhook2.example.com/notifications", + token = "webhook-token-456" + ) + ) + ) + ) + + testAPIMethod( + method = A2AMethod.ListTaskPushNotificationConfig, + request = request, + expectedResponse = expectedResponse, + invoke = { listTaskPushNotificationConfig(it) } + ) + } + + @Test + fun testDeleteTaskPushNotificationConfig() = runTest { + val id = RequestId.StringId("test-8") + + val configParams = TaskPushNotificationConfigParams( + id = "task-123", + pushNotificationConfigId = "notification-config-1" + ) + + val request = Request( + id = id, + data = configParams, + ) + + val expectedResponse = Response( + id = id, + data = Unit + ) + + testAPIMethod( + method = A2AMethod.DeleteTaskPushNotificationConfig, + request = request, + expectedResponse = expectedResponse, + invoke = { deleteTaskPushNotificationConfig(it) } + ) + } + + @Test + fun testSendMessageError() = runTest { + val id = RequestId.StringId("test-error-1") + + val testMessage = Message( + role = Role.User, + parts = listOf(TextPart("Hello, agent!")), + taskId = "invalid-task-id" + ) + + val messageSendParams = MessageSendParams( + message = testMessage + ) + + val request = Request( + id = id, + data = messageSendParams, + ) + + val mockEngine = MockEngine { receivedRequest -> + assertEquals(HttpMethod.Post, receivedRequest.method) + assertEquals(ContentType.Application.Json, receivedRequest.body.contentType) + + val requestBodyText = (receivedRequest.body as TextContent).text + val jsonRpcRequest = json.decodeFromString(requestBodyText) + + assertEquals(A2AMethod.SendMessage.value, jsonRpcRequest.method) + assertEquals(request.id, jsonRpcRequest.id) + assertEquals(request.data, json.decodeFromJsonElement(jsonRpcRequest.params)) + + val jsonRpcErrorResponse = JSONRPCErrorResponse( + id = id, + error = JSONRPCError( + code = A2AErrorCode.INVALID_PARAMS.value, + message = "Invalid method parameters", + data = json.encodeToJsonElement("The message parameters are invalid") + ) + ) + + respond( + content = json.encodeToString(jsonRpcErrorResponse), + headers = headersOf(HttpHeaders.ContentType, ContentType.Application.Json.toString()), + ) + } + + val httpClient = HttpClient(mockEngine) + val transport = HttpJSONRPCClientTransport("https://api.example.com/a2a", httpClient) + + try { + transport.sendMessage(request) + fail("Expected A2AInvalidParamsException to be thrown") + } catch (e: A2AInvalidParamsException) { + assertEquals("Invalid method parameters", e.message) + assertEquals(-32602, e.errorCode) + } + + transport.close() + } +} diff --git a/a2a/a2a-transport/a2a-transport-core-jsonrpc/src/commonMain/kotlin/ai/koog/a2a/transport/jsonrpc/JSONRCPServerTransport.kt b/a2a/a2a-transport/a2a-transport-core-jsonrpc/src/commonMain/kotlin/ai/koog/a2a/transport/jsonrpc/JSONRCPServerTransport.kt deleted file mode 100644 index 0d4d30a3a5..0000000000 --- a/a2a/a2a-transport/a2a-transport-core-jsonrpc/src/commonMain/kotlin/ai/koog/a2a/transport/jsonrpc/JSONRCPServerTransport.kt +++ /dev/null @@ -1,106 +0,0 @@ -package ai.koog.a2a.transport.jsonrpc - -import ai.koog.a2a.exceptions.A2AException -import ai.koog.a2a.exceptions.InvalidParamsException -import ai.koog.a2a.exceptions.MethodNotFoundException -import ai.koog.a2a.transport.Request -import ai.koog.a2a.transport.Response -import ai.koog.a2a.transport.ServerCallContext -import ai.koog.a2a.transport.ServerTransport -import ai.koog.a2a.transport.jsonrpc.model.JSONRPCJson -import ai.koog.a2a.transport.jsonrpc.model.JSONRPCRequest -import ai.koog.a2a.transport.jsonrpc.model.JSONRPCSuccessResponse -import kotlinx.coroutines.flow.Flow -import kotlinx.coroutines.flow.map -import kotlinx.serialization.SerializationException -import kotlinx.serialization.json.decodeFromJsonElement -import kotlinx.serialization.json.encodeToJsonElement - -/** - * Abstract transport implementation for JSON-RPC-based server communication. - * Handles receiving JSON-RPC requests, processing them, and sending responses. - */ -public abstract class JSONRCPServerTransport : ServerTransport { - /** - * Handles a JSON-RPC request and returns the corresponding response. - * - * @throws A2AException if there's an error processing the request. - */ - public suspend fun onRequest( - request: JSONRPCRequest, - ctx: ServerCallContext, - ): JSONRPCSuccessResponse { - return when (request.method) { - A2AMethod.GetAuthenticatedExtendedAgentCard.value -> - requestHandler.onGetAuthenticatedExtendedAgentCard(request.toRequest(), ctx).toJSONRPCResponse() - - A2AMethod.SendMessage.value -> - requestHandler.onSendMessage(request.toRequest(), ctx).toJSONRPCResponse() - - A2AMethod.GetTask.value -> - requestHandler.onGetTask(request.toRequest(), ctx).toJSONRPCResponse() - - A2AMethod.CancelTask.value -> - requestHandler.onCancelTask(request.toRequest(), ctx).toJSONRPCResponse() - - A2AMethod.SetTaskPushNotificationConfig.value -> - requestHandler.onSetTaskPushNotificationConfig(request.toRequest(), ctx).toJSONRPCResponse() - - A2AMethod.GetTaskPushNotificationConfig.value -> - requestHandler.onGetTaskPushNotificationConfig(request.toRequest(), ctx).toJSONRPCResponse() - - A2AMethod.ListTaskPushNotificationConfig.value -> - requestHandler.onListTaskPushNotificationConfig(request.toRequest(), ctx).toJSONRPCResponse() - - A2AMethod.DeleteTaskPushNotificationConfig.value -> - requestHandler.onDeleteTaskPushNotificationConfig(request.toRequest(), ctx).toJSONRPCResponse() - - else -> throw MethodNotFoundException(request.method) - } - } - - /** - * Handles a JSON-RPC request and returns the corresponding response stream. - * - * @throws A2AException if there's an error processing the request. - */ - public suspend fun onRequestStreaming( - request: JSONRPCRequest, - ctx: ServerCallContext, - ): Flow { - return when (request.method) { - A2AMethod.SendMessageStreaming.value -> - requestHandler.onSendMessageStreaming(request.toRequest(), ctx).map { it.toJSONRPCResponse() } - - else -> throw MethodNotFoundException(request.method) - } - } - - /** - * Convert generic [JSONRPCRequest] to [Request]. - * - * @throws InvalidParamsException if request params cannot be parsed to [T]. - */ - protected inline fun JSONRPCRequest.toRequest(): Request { - val data = try { - JSONRPCJson.decodeFromJsonElement(params) - } catch (_: SerializationException) { - throw InvalidParamsException("Cannot parse request params to ${T::class}") - } - - return Request( - id = id, - data = data - ) - } - - /** - * Convert generic [Response] to [JSONRPCSuccessResponse]. - */ - protected inline fun Response.toJSONRPCResponse(): JSONRPCSuccessResponse { - return JSONRPCSuccessResponse( - id = id, - result = JSONRPCJson.encodeToJsonElement(data) - ) - } -} diff --git a/a2a/a2a-transport/a2a-transport-core-jsonrpc/src/commonMain/kotlin/ai/koog/a2a/transport/jsonrpc/JSONRPCClientTransport.kt b/a2a/a2a-transport/a2a-transport-core-jsonrpc/src/commonMain/kotlin/ai/koog/a2a/transport/jsonrpc/JSONRPCClientTransport.kt index 3fb00feb2f..a0ef2867fc 100644 --- a/a2a/a2a-transport/a2a-transport-core-jsonrpc/src/commonMain/kotlin/ai/koog/a2a/transport/jsonrpc/JSONRPCClientTransport.kt +++ b/a2a/a2a-transport/a2a-transport-core-jsonrpc/src/commonMain/kotlin/ai/koog/a2a/transport/jsonrpc/JSONRPCClientTransport.kt @@ -33,20 +33,16 @@ import kotlinx.serialization.json.encodeToJsonElement public abstract class JSONRPCClientTransport : ClientTransport { /** * Sends a JSON-RPC request and returns the corresponding response. - * - * @throws A2AException if server returned an error. */ - public abstract suspend fun request( + protected abstract suspend fun request( request: JSONRPCRequest, ctx: ClientCallContext, ): JSONRPCResponse /** * Sends a JSON-RPC request and returns the corresponding response stream. - * - * @throws A2AException if server returned an error. */ - public abstract suspend fun requestStreaming( + protected abstract fun requestStreaming( request: JSONRPCRequest, ctx: ClientCallContext, ): Flow @@ -101,7 +97,7 @@ public abstract class JSONRPCClientTransport : ClientTransport { /** * Generic streaming request processing. */ - protected suspend inline fun requestStreaming( + protected inline fun requestStreaming( method: A2AMethod, request: Request, ctx: ClientCallContext @@ -124,7 +120,7 @@ public abstract class JSONRPCClientTransport : ClientTransport { ): Response = request(A2AMethod.SendMessage, request, ctx) - override suspend fun sendMessageStreaming( + override fun sendMessageStreaming( request: Request, ctx: ClientCallContext ): Flow> = diff --git a/a2a/a2a-transport/a2a-transport-core-jsonrpc/src/commonMain/kotlin/ai/koog/a2a/transport/jsonrpc/JSONRPCServerTransport.kt b/a2a/a2a-transport/a2a-transport-core-jsonrpc/src/commonMain/kotlin/ai/koog/a2a/transport/jsonrpc/JSONRPCServerTransport.kt new file mode 100644 index 0000000000..be15b76549 --- /dev/null +++ b/a2a/a2a-transport/a2a-transport-core-jsonrpc/src/commonMain/kotlin/ai/koog/a2a/transport/jsonrpc/JSONRPCServerTransport.kt @@ -0,0 +1,140 @@ +package ai.koog.a2a.transport.jsonrpc + +import ai.koog.a2a.exceptions.A2AException +import ai.koog.a2a.exceptions.A2AInternalErrorException +import ai.koog.a2a.exceptions.A2AInvalidParamsException +import ai.koog.a2a.exceptions.A2AMethodNotFoundException +import ai.koog.a2a.transport.Request +import ai.koog.a2a.transport.RequestId +import ai.koog.a2a.transport.Response +import ai.koog.a2a.transport.ServerCallContext +import ai.koog.a2a.transport.ServerTransport +import ai.koog.a2a.transport.jsonrpc.model.JSONRPCError +import ai.koog.a2a.transport.jsonrpc.model.JSONRPCErrorResponse +import ai.koog.a2a.transport.jsonrpc.model.JSONRPCJson +import ai.koog.a2a.transport.jsonrpc.model.JSONRPCRequest +import ai.koog.a2a.transport.jsonrpc.model.JSONRPCResponse +import ai.koog.a2a.transport.jsonrpc.model.JSONRPCSuccessResponse +import ai.koog.a2a.utils.runCatchingCancellable +import kotlinx.coroutines.flow.Flow +import kotlinx.coroutines.flow.catch +import kotlinx.coroutines.flow.map +import kotlinx.serialization.SerializationException +import kotlinx.serialization.json.decodeFromJsonElement +import kotlinx.serialization.json.encodeToJsonElement + +/** + * Abstract transport implementation for JSON-RPC-based server communication. + * Handles receiving JSON-RPC requests, processing them, and sending responses. + */ +public abstract class JSONRPCServerTransport : ServerTransport { + /** + * Handles a JSON-RPC request and returns the corresponding response + * Handles exceptions, mapping all non [A2AException]s to [A2AInternalErrorException], and then converting them to [JSONRPCErrorResponse]. + */ + protected suspend fun onRequest( + request: JSONRPCRequest, + ctx: ServerCallContext, + ): JSONRPCResponse { + return runCatchingCancellable { + when (request.method) { + A2AMethod.GetAuthenticatedExtendedAgentCard.value -> + requestHandler.onGetAuthenticatedExtendedAgentCard(request.toRequest(), ctx).toJSONRPCSuccessResponse() + + A2AMethod.SendMessage.value -> + requestHandler.onSendMessage(request.toRequest(), ctx).toJSONRPCSuccessResponse() + + A2AMethod.GetTask.value -> + requestHandler.onGetTask(request.toRequest(), ctx).toJSONRPCSuccessResponse() + + A2AMethod.CancelTask.value -> + requestHandler.onCancelTask(request.toRequest(), ctx).toJSONRPCSuccessResponse() + + A2AMethod.SetTaskPushNotificationConfig.value -> + requestHandler.onSetTaskPushNotificationConfig(request.toRequest(), ctx).toJSONRPCSuccessResponse() + + A2AMethod.GetTaskPushNotificationConfig.value -> + requestHandler.onGetTaskPushNotificationConfig(request.toRequest(), ctx).toJSONRPCSuccessResponse() + + A2AMethod.ListTaskPushNotificationConfig.value -> + requestHandler.onListTaskPushNotificationConfig(request.toRequest(), ctx).toJSONRPCSuccessResponse() + + A2AMethod.DeleteTaskPushNotificationConfig.value -> + requestHandler.onDeleteTaskPushNotificationConfig(request.toRequest(), ctx).toJSONRPCSuccessResponse() + + else -> throw A2AMethodNotFoundException(request.method) + } + }.getOrElse { it.toJSONRPCErrorResponse(request.id) } + } + + /** + * Handles a JSON-RPC request and returns the corresponding response stream. + * Handles exceptions, mapping all non [A2AException]s to [A2AInternalErrorException], and then converting them to [JSONRPCErrorResponse]. + * Terminates the flow after the first exception. + */ + protected fun onRequestStreaming( + request: JSONRPCRequest, + ctx: ServerCallContext, + ): Flow { + return when (request.method) { + A2AMethod.SendMessageStreaming.value -> + requestHandler + .onSendMessageStreaming(request.toRequest(), ctx) + .map { it.toJSONRPCSuccessResponse() as JSONRPCResponse } + .catch { emit(it.toJSONRPCErrorResponse(request.id)) } + + else -> throw A2AMethodNotFoundException(request.method) + } + } + + /** + * Convert generic [JSONRPCRequest] to [Request]. + * + * @throws A2AInvalidParamsException if request params cannot be parsed to [T]. + */ + protected inline fun JSONRPCRequest.toRequest(): Request { + val data = try { + JSONRPCJson.decodeFromJsonElement(params) + } catch (_: SerializationException) { + throw A2AInvalidParamsException("Cannot parse request params to ${T::class}") + } + + return Request( + id = id, + data = data + ) + } + + /** + * Convert generic [Response] to [JSONRPCSuccessResponse]. + */ + protected inline fun Response.toJSONRPCSuccessResponse(): JSONRPCSuccessResponse { + return JSONRPCSuccessResponse( + id = id, + result = JSONRPCJson.encodeToJsonElement(data) + ) + } + + /** + * Handles exceptions, mapping all non [A2AException]s to [A2AInternalErrorException], and then converting them to [JSONRPCErrorResponse]. + */ + protected fun Throwable.toJSONRPCErrorResponse(requestId: RequestId? = null): JSONRPCErrorResponse { + val a2aException: A2AException = when (this) { + is A2AException -> this + is Exception -> A2AInternalErrorException("Internal error: ${this.message}") + else -> throw this // Non-exception throwable shouldn't be handled, rethrowing it + } + + return JSONRPCErrorResponse( + id = requestId, + error = a2aException.toJSONRPCError() + ) + } + + protected fun A2AException.toJSONRPCError(): JSONRPCError { + return JSONRPCError( + code = errorCode, + message = message + ) + } +} diff --git a/a2a/a2a-transport/a2a-transport-server-jsonrpc-http/build.gradle.kts b/a2a/a2a-transport/a2a-transport-server-jsonrpc-http/build.gradle.kts index 36e5aa832a..4268bba174 100644 --- a/a2a/a2a-transport/a2a-transport-server-jsonrpc-http/build.gradle.kts +++ b/a2a/a2a-transport/a2a-transport-server-jsonrpc-http/build.gradle.kts @@ -18,11 +18,21 @@ kotlin { sourceSets { commonMain { dependencies { + api(project(":a2a:a2a-transport:a2a-transport-core-jsonrpc")) api(libs.kotlinx.serialization.json) api(libs.kotlinx.coroutines.core) } } + jvmMain { + dependencies { + api(libs.ktor.server.core) + api(libs.ktor.server.sse) + api(libs.ktor.server.content.negotiation) + api(libs.ktor.serialization.kotlinx.json) + } + } + commonTest { dependencies { implementation(kotlin("test")) diff --git a/a2a/a2a-transport/a2a-transport-server-jsonrpc-http/src/commonMain/kotlin/ai/koog/a2a/transport/server/jsonrpc/http/.gitkeep b/a2a/a2a-transport/a2a-transport-server-jsonrpc-http/src/commonMain/kotlin/ai/koog/a2a/transport/server/jsonrpc/http/.gitkeep deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/a2a/a2a-transport/a2a-transport-server-jsonrpc-http/src/jvmMain/kotlin/ai/koog/a2a/transport/server/jsonrpc/http/HttpJSONRPCServerTransport.kt b/a2a/a2a-transport/a2a-transport-server-jsonrpc-http/src/jvmMain/kotlin/ai/koog/a2a/transport/server/jsonrpc/http/HttpJSONRPCServerTransport.kt new file mode 100644 index 0000000000..8256b1d3d8 --- /dev/null +++ b/a2a/a2a-transport/a2a-transport-server-jsonrpc-http/src/jvmMain/kotlin/ai/koog/a2a/transport/server/jsonrpc/http/HttpJSONRPCServerTransport.kt @@ -0,0 +1,108 @@ +package ai.koog.a2a.transport.server.jsonrpc.http + +import ai.koog.a2a.exceptions.A2AInvalidRequestException +import ai.koog.a2a.exceptions.A2AParseException +import ai.koog.a2a.transport.RequestHandler +import ai.koog.a2a.transport.ServerCallContext +import ai.koog.a2a.transport.jsonrpc.JSONRPCServerTransport +import ai.koog.a2a.transport.jsonrpc.model.JSONRPCJson +import ai.koog.a2a.transport.jsonrpc.model.JSONRPCRequest +import ai.koog.a2a.utils.runCatchingCancellable +import io.ktor.serialization.kotlinx.json.json +import io.ktor.server.application.ApplicationCall +import io.ktor.server.application.pluginOrNull +import io.ktor.server.plugins.contentnegotiation.ContentNegotiation +import io.ktor.server.request.receive +import io.ktor.server.routing.Route +import io.ktor.server.routing.application +import io.ktor.server.routing.post +import io.ktor.server.routing.route +import io.ktor.server.sse.SSE +import io.ktor.server.sse.send +import io.ktor.server.sse.sse +import io.ktor.util.toMap +import kotlinx.serialization.SerializationException +import kotlinx.serialization.json.JsonElement +import kotlinx.serialization.json.decodeFromJsonElement +import kotlinx.serialization.serializer + +/** + * Implements JSON-RPC server transport over HTTP, supporting standard JSON-RPC + * request/response and server-sent events for streaming responses. + * + * The transport integrates with Ktor routing to define and manage routes for + * JSON-RPC endpoints. It ensures compliance with the A2A specification for + * error handling and JSON-RPC request processing. + * + * @property requestHandler The handler responsible for processing JSON-RPC requests + * received by the transport. + */ +public class HttpJSONRPCServerTransport( + override val requestHandler: RequestHandler, +) : JSONRPCServerTransport() { + + /** + * Routes for handling JSON-RPC HTTP requests. + * Follows A2A specification in error handling. + * + * @param path JSON-RPC endpoint path. + */ + public fun Route.transportRoutes(path: String): Route = route(path) { + if (application.pluginOrNull(SSE) == null) { + throw IllegalStateException("SSE plugin must be installed in the application to add these routes.") + } + + install(ContentNegotiation) { + json(JSONRPCJson) + } + + post { + runCatchingCancellable { + onRequest( + request = call.receiveJSONRPCRequest(), + ctx = call.toServerCallContext() + ) + }.getOrElse { it.toJSONRPCErrorResponse() } + } + + sse( + serialize = { typeInfo, it -> + val kType = typeInfo.kotlinType ?: throw IllegalArgumentException("Null KType for value: $it") + val serializer = JSONRPCJson.serializersModule.serializer(kType) + JSONRPCJson.encodeToString(serializer, it) + } + ) { + runCatchingCancellable { + onRequestStreaming( + request = call.receiveJSONRPCRequest(), + ctx = call.toServerCallContext() + ).collect { response -> send(response) } + }.getOrElse { + send(it.toJSONRPCErrorResponse()) + } + } + } + + /** + * Converts raw request body to [JSONRPCRequest], following A2A specification for error handling. + */ + private suspend fun ApplicationCall.receiveJSONRPCRequest(): JSONRPCRequest { + val jsonBody = try { + receive() + } catch (e: SerializationException) { + throw A2AParseException("Cannot parse request body to JSON:\n${e.message}") + } + + return try { + JSONRPCJson.decodeFromJsonElement(jsonBody) + } catch (e: SerializationException) { + throw A2AInvalidRequestException("Cannot parse request params to JSON-RPC request:\n${e.message}") + } + } + + private fun ApplicationCall.toServerCallContext(): ServerCallContext { + return ServerCallContext( + headers = request.headers.toMap() + ) + } +} diff --git a/gradle/libs.versions.toml b/gradle/libs.versions.toml index 920b3717a1..c5c31538df 100644 --- a/gradle/libs.versions.toml +++ b/gradle/libs.versions.toml @@ -69,6 +69,7 @@ ktor-server-core = { module = "io.ktor:ktor-server-core", version.ref = "ktor3" ktor-server-cio = { module = "io.ktor:ktor-server-cio", version.ref = "ktor3" } ktor-server-netty = { module = "io.ktor:ktor-server-netty-jvm", version.ref = "ktor3" } ktor-server-sse = { module = "io.ktor:ktor-server-sse", version.ref = "ktor3" } +ktor-server-content-negotiation = { module = "io.ktor:ktor-server-content-negotiation", version.ref = "ktor3" } lettuce-core = { module = "io.lettuce:lettuce-core", version.ref = "lettuce" } logback-classic = { module = "ch.qos.logback:logback-classic", version.ref = "logback" } oshai-kotlin-logging = { module = "io.github.oshai:kotlin-logging", version.ref = "oshai-logging" } From 52a46cfead263f505a274760a4ec8c1e51c31f02 Mon Sep 17 00:00:00 2001 From: Andrey Bragin Date: Thu, 4 Sep 2025 03:21:17 +0200 Subject: [PATCH 20/52] [a2a] Implement HTTP JSON-RPC server transport with full protocol support --- .gitignore | 1 + .../ai/koog/a2a/exceptions/Exceptions.kt | 79 +-- .../ai/koog/a2a/transport/ClientTransport.kt | 12 +- .../ai/koog/a2a/transport/ServerTransport.kt | 12 +- .../kotlin/ai/koog/a2a/server/A2AServer.kt | 9 +- .../http/HttpJSONRPCClientTransportTest.kt | 12 +- .../koog/a2a/transport/jsonrpc/A2AMethod.kt | 1 + .../jsonrpc/JSONRPCClientTransport.kt | 8 +- .../jsonrpc/JSONRPCServerTransport.kt | 28 +- .../build.gradle.kts | 8 +- .../http/HttpJSONRPCServerTransport.kt | 143 +++- .../http/HttpJSONRPCServerTransportTest.kt | 631 ++++++++++++++++++ gradle/libs.versions.toml | 2 + 13 files changed, 872 insertions(+), 74 deletions(-) create mode 100644 a2a/a2a-transport/a2a-transport-server-jsonrpc-http/src/jvmTest/kotlin/ai/koog/a2a/transport/server/jsonrpc/http/HttpJSONRPCServerTransportTest.kt diff --git a/.gitignore b/.gitignore index e24ef2425c..da43f86bfa 100644 --- a/.gitignore +++ b/.gitignore @@ -10,3 +10,4 @@ local.properties docs/src/main/kotlin/*.kt **/.env .venv +.DS_Store diff --git a/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/exceptions/Exceptions.kt b/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/exceptions/Exceptions.kt index 40aec678a7..10dec7caae 100644 --- a/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/exceptions/Exceptions.kt +++ b/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/exceptions/Exceptions.kt @@ -1,21 +1,22 @@ package ai.koog.a2a.exceptions /** - * Enum containing all A2A error codes. - */ -public enum class A2AErrorCode(public val value: Int) { - PARSE_ERROR(-32700), - INVALID_REQUEST(-32600), - METHOD_NOT_FOUND(-32601), - INVALID_PARAMS(-32602), - INTERNAL_ERROR(-32603), - TASK_NOT_FOUND(-32001), - TASK_NOT_CANCELABLE(-32002), - PUSH_NOTIFICATION_NOT_SUPPORTED(-32003), - UNSUPPORTED_OPERATION(-32004), - CONTENT_TYPE_NOT_SUPPORTED(-32005), - INVALID_AGENT_RESPONSE(-32006), - AUTHENTICATED_EXTENDED_CARD_NOT_CONFIGURED(-32007) + * Object containing all A2A error codes. + */ +@Suppress("MissingKDocForPublicAPI") +public object A2AErrorCodes { + public const val PARSE_ERROR: Int = -32700 + public const val INVALID_REQUEST: Int = -32600 + public const val METHOD_NOT_FOUND: Int = -32601 + public const val INVALID_PARAMS: Int = -32602 + public const val INTERNAL_ERROR: Int = -32603 + public const val TASK_NOT_FOUND: Int = -32001 + public const val TASK_NOT_CANCELABLE: Int = -32002 + public const val PUSH_NOTIFICATION_NOT_SUPPORTED: Int = -32003 + public const val UNSUPPORTED_OPERATION: Int = -32004 + public const val CONTENT_TYPE_NOT_SUPPORTED: Int = -32005 + public const val INVALID_AGENT_RESPONSE: Int = -32006 + public const val AUTHENTICATED_EXTENDED_CARD_NOT_CONFIGURED: Int = -32007 } /** @@ -31,35 +32,35 @@ public sealed class A2AException( */ public class A2AParseException( message: String = "Invalid JSON payload", -) : A2AException(message, errorCode = -32700) +) : A2AException(message, errorCode = A2AErrorCodes.PARSE_ERROR) /** * The JSON payload was valid JSON, but not a valid JSON-RPC Request object. */ public class A2AInvalidRequestException( message: String = "Invalid JSON-RPC Request", -) : A2AException(message, errorCode = -32600) +) : A2AException(message, errorCode = A2AErrorCodes.INVALID_REQUEST) /** * The requested A2A RPC method does not exist or is not supported. */ public class A2AMethodNotFoundException( message: String = "Method not found", -) : A2AException(message, errorCode = -32601) +) : A2AException(message, errorCode = A2AErrorCodes.METHOD_NOT_FOUND) /** * The params provided for the method are invalid. */ public class A2AInvalidParamsException( message: String = "Invalid method parameters", -) : A2AException(message, errorCode = -32602) +) : A2AException(message, errorCode = A2AErrorCodes.INVALID_PARAMS) /** * An unexpected error occurred on the server during processing. */ public class A2AInternalErrorException( message: String = "Internal server error", -) : A2AException(message, errorCode = -32603) +) : A2AException(message, errorCode = A2AErrorCodes.INTERNAL_ERROR) /** * Reserved for implementation-defined server exceptions. A2A-specific exceptions use this range. @@ -79,7 +80,7 @@ public sealed class A2AServerException( */ public class A2ATaskNotFoundException( message: String = "Task not found", -) : A2AServerException(message, errorCode = -32001) +) : A2AServerException(message, errorCode = A2AErrorCodes.TASK_NOT_FOUND) /** * An attempt was made to cancel a task that is not in a cancelable state. @@ -87,7 +88,7 @@ public class A2ATaskNotFoundException( */ public class A2ATaskNotCancelableException( message: String = "Task cannot be canceled", -) : A2AServerException(message, errorCode = -32002) +) : A2AServerException(message, errorCode = A2AErrorCodes.TASK_NOT_CANCELABLE) /** * Client attempted to use push notification features but the server agent does not support them. @@ -95,7 +96,7 @@ public class A2ATaskNotCancelableException( */ public class A2APushNotificationNotSupportedException( message: String = "Push Notification is not supported", -) : A2AServerException(message, errorCode = -32003) +) : A2AServerException(message, errorCode = A2AErrorCodes.PUSH_NOTIFICATION_NOT_SUPPORTED) /** * The requested operation or a specific aspect of it is not supported by this server agent implementation. @@ -103,7 +104,7 @@ public class A2APushNotificationNotSupportedException( */ public class A2AUnsupportedOperationException( message: String = "This operation is not supported", -) : A2AServerException(message, errorCode = -32004) +) : A2AServerException(message, errorCode = A2AErrorCodes.UNSUPPORTED_OPERATION) /** * A Media Type provided in the request's message.parts or implied for an artifact is not supported @@ -111,21 +112,21 @@ public class A2AUnsupportedOperationException( */ public class A2AContentTypeNotSupportedException( message: String = "Incompatible content types", -) : A2AServerException(message, errorCode = -32005) +) : A2AServerException(message, errorCode = A2AErrorCodes.CONTENT_TYPE_NOT_SUPPORTED) /** * Agent generated an invalid response for the requested method. */ public class A2AInvalidAgentResponseException( message: String = "Invalid agent response type", -) : A2AServerException(message, errorCode = -32006) +) : A2AServerException(message, errorCode = A2AErrorCodes.INVALID_AGENT_RESPONSE) /** * The agent does not have an Authenticated Extended Card configured. */ public class A2AAuthenticatedExtendedCardNotConfiguredException( message: String = "Authenticated Extended Card not configured", -) : A2AServerException(message, errorCode = -32007) +) : A2AServerException(message, errorCode = A2AErrorCodes.AUTHENTICATED_EXTENDED_CARD_NOT_CONFIGURED) /** * Server returned some unknown error code. @@ -143,18 +144,18 @@ public fun createA2AException( errorCode: Int, ): A2AException { return when (errorCode) { - A2AErrorCode.PARSE_ERROR.value -> A2AParseException(message) - A2AErrorCode.INVALID_REQUEST.value -> A2AInvalidRequestException(message) - A2AErrorCode.METHOD_NOT_FOUND.value -> A2AMethodNotFoundException(message) - A2AErrorCode.INVALID_PARAMS.value -> A2AInvalidParamsException(message) - A2AErrorCode.INTERNAL_ERROR.value -> A2AInternalErrorException(message) - A2AErrorCode.TASK_NOT_FOUND.value -> A2ATaskNotFoundException(message) - A2AErrorCode.TASK_NOT_CANCELABLE.value -> A2ATaskNotCancelableException(message) - A2AErrorCode.PUSH_NOTIFICATION_NOT_SUPPORTED.value -> A2APushNotificationNotSupportedException(message) - A2AErrorCode.UNSUPPORTED_OPERATION.value -> A2AUnsupportedOperationException(message) - A2AErrorCode.CONTENT_TYPE_NOT_SUPPORTED.value -> A2AContentTypeNotSupportedException(message) - A2AErrorCode.INVALID_AGENT_RESPONSE.value -> A2AInvalidAgentResponseException(message) - A2AErrorCode.AUTHENTICATED_EXTENDED_CARD_NOT_CONFIGURED.value -> A2AAuthenticatedExtendedCardNotConfiguredException(message) + A2AErrorCodes.PARSE_ERROR -> A2AParseException(message) + A2AErrorCodes.INVALID_REQUEST -> A2AInvalidRequestException(message) + A2AErrorCodes.METHOD_NOT_FOUND -> A2AMethodNotFoundException(message) + A2AErrorCodes.INVALID_PARAMS -> A2AInvalidParamsException(message) + A2AErrorCodes.INTERNAL_ERROR -> A2AInternalErrorException(message) + A2AErrorCodes.TASK_NOT_FOUND -> A2ATaskNotFoundException(message) + A2AErrorCodes.TASK_NOT_CANCELABLE -> A2ATaskNotCancelableException(message) + A2AErrorCodes.PUSH_NOTIFICATION_NOT_SUPPORTED -> A2APushNotificationNotSupportedException(message) + A2AErrorCodes.UNSUPPORTED_OPERATION -> A2AUnsupportedOperationException(message) + A2AErrorCodes.CONTENT_TYPE_NOT_SUPPORTED -> A2AContentTypeNotSupportedException(message) + A2AErrorCodes.INVALID_AGENT_RESPONSE -> A2AInvalidAgentResponseException(message) + A2AErrorCodes.AUTHENTICATED_EXTENDED_CARD_NOT_CONFIGURED -> A2AAuthenticatedExtendedCardNotConfiguredException(message) else -> A2AUnknownException(message, errorCode) } } diff --git a/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/transport/ClientTransport.kt b/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/transport/ClientTransport.kt index 339346be60..e5ab4ed312 100644 --- a/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/transport/ClientTransport.kt +++ b/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/transport/ClientTransport.kt @@ -75,6 +75,16 @@ public interface ClientTransport : AutoCloseable { ctx: ClientCallContext = ClientCallContext.Default ): Response + /** + * Implements [tasks/resubscribe](https://a2a-protocol.org/latest/specification/#79-tasksresubscribe) + * + * @throws A2AException if server returned an error. + */ + public fun resubscribeTask( + request: Request, + ctx: ClientCallContext = ClientCallContext.Default + ): Flow> + /** * Implements [tasks/pushNotificationConfig/set](https://a2a-protocol.org/latest/specification/#75-taskspushnotificationconfigset) * @@ -113,7 +123,7 @@ public interface ClientTransport : AutoCloseable { public suspend fun deleteTaskPushNotificationConfig( request: Request, ctx: ClientCallContext = ClientCallContext.Default - ): Response + ): Response } /** diff --git a/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/transport/ServerTransport.kt b/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/transport/ServerTransport.kt index 0e31272b33..fb09ac6a61 100644 --- a/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/transport/ServerTransport.kt +++ b/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/transport/ServerTransport.kt @@ -76,6 +76,16 @@ public interface RequestHandler { ctx: ServerCallContext ): Response + /** + * Handles [tasks/resubscribe](https://a2a-protocol.org/latest/specification/#79-tasksresubscribe) + * + * @throws A2AException if there is an error with processsing the request. + */ + public fun onResubscribeTask( + request: Request, + ctx: ServerCallContext + ): Flow> + /** * Handles [tasks/cancel](https://a2a-protocol.org/latest/specification/#74-taskscancel) * @@ -124,7 +134,7 @@ public interface RequestHandler { public suspend fun onDeleteTaskPushNotificationConfig( request: Request, ctx: ServerCallContext - ): Response + ): Response } /** diff --git a/a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/A2AServer.kt b/a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/A2AServer.kt index 6e94c5a391..4c901c109d 100644 --- a/a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/A2AServer.kt +++ b/a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/A2AServer.kt @@ -54,6 +54,13 @@ public class A2AServer : RequestHandler { TODO("Not yet implemented") } + override fun onResubscribeTask( + request: Request, + ctx: ServerCallContext + ): Flow> { + TODO("Not yet implemented") + } + override suspend fun onSetTaskPushNotificationConfig( request: Request, ctx: ServerCallContext @@ -78,7 +85,7 @@ public class A2AServer : RequestHandler { override suspend fun onDeleteTaskPushNotificationConfig( request: Request, ctx: ServerCallContext - ): Response { + ): Response { TODO("Not yet implemented") } } diff --git a/a2a/a2a-transport/a2a-transport-client-jsonrpc-http/src/commonTest/kotlin/ai/koog/a2a/transport/client/jsonrpc/http/HttpJSONRPCClientTransportTest.kt b/a2a/a2a-transport/a2a-transport-client-jsonrpc-http/src/commonTest/kotlin/ai/koog/a2a/transport/client/jsonrpc/http/HttpJSONRPCClientTransportTest.kt index 8018111014..f274142806 100644 --- a/a2a/a2a-transport/a2a-transport-client-jsonrpc-http/src/commonTest/kotlin/ai/koog/a2a/transport/client/jsonrpc/http/HttpJSONRPCClientTransportTest.kt +++ b/a2a/a2a-transport/a2a-transport-client-jsonrpc-http/src/commonTest/kotlin/ai/koog/a2a/transport/client/jsonrpc/http/HttpJSONRPCClientTransportTest.kt @@ -1,6 +1,6 @@ package ai.koog.a2a.transport.client.jsonrpc.http -import ai.koog.a2a.exceptions.A2AErrorCode +import ai.koog.a2a.exceptions.A2AErrorCodes import ai.koog.a2a.exceptions.A2AInvalidParamsException import ai.koog.a2a.model.AgentCapabilities import ai.koog.a2a.model.AgentCard @@ -246,6 +246,12 @@ class HttpJSONRPCClientTransportTest { ) } + @Ignore + @Test + fun testResubscribeTask() = runTest { + // FIXME Can't test it, MockEngine doesn't support SSE capability + } + @Test fun testSetTaskPushNotificationConfig() = runTest { val id = RequestId.StringId("test-5") @@ -368,7 +374,7 @@ class HttpJSONRPCClientTransportTest { val expectedResponse = Response( id = id, - data = Unit + data = null ) testAPIMethod( @@ -412,7 +418,7 @@ class HttpJSONRPCClientTransportTest { val jsonRpcErrorResponse = JSONRPCErrorResponse( id = id, error = JSONRPCError( - code = A2AErrorCode.INVALID_PARAMS.value, + code = A2AErrorCodes.INVALID_PARAMS, message = "Invalid method parameters", data = json.encodeToJsonElement("The message parameters are invalid") ) diff --git a/a2a/a2a-transport/a2a-transport-core-jsonrpc/src/commonMain/kotlin/ai/koog/a2a/transport/jsonrpc/A2AMethod.kt b/a2a/a2a-transport/a2a-transport-core-jsonrpc/src/commonMain/kotlin/ai/koog/a2a/transport/jsonrpc/A2AMethod.kt index 0d0519abd8..26f9e3ef28 100644 --- a/a2a/a2a-transport/a2a-transport-core-jsonrpc/src/commonMain/kotlin/ai/koog/a2a/transport/jsonrpc/A2AMethod.kt +++ b/a2a/a2a-transport/a2a-transport-core-jsonrpc/src/commonMain/kotlin/ai/koog/a2a/transport/jsonrpc/A2AMethod.kt @@ -9,6 +9,7 @@ public enum class A2AMethod(public val value: String) { SendMessageStreaming("message/stream"), GetTask("tasks/get"), CancelTask("tasks/cancel"), + ResubscribeTask("tasks/resubscribe"), SetTaskPushNotificationConfig("tasks/pushNotificationConfig/set"), GetTaskPushNotificationConfig("tasks/pushNotificationConfig/get"), ListTaskPushNotificationConfig("tasks/pushNotificationConfig/list"), diff --git a/a2a/a2a-transport/a2a-transport-core-jsonrpc/src/commonMain/kotlin/ai/koog/a2a/transport/jsonrpc/JSONRPCClientTransport.kt b/a2a/a2a-transport/a2a-transport-core-jsonrpc/src/commonMain/kotlin/ai/koog/a2a/transport/jsonrpc/JSONRPCClientTransport.kt index a0ef2867fc..dfc7f78904 100644 --- a/a2a/a2a-transport/a2a-transport-core-jsonrpc/src/commonMain/kotlin/ai/koog/a2a/transport/jsonrpc/JSONRPCClientTransport.kt +++ b/a2a/a2a-transport/a2a-transport-core-jsonrpc/src/commonMain/kotlin/ai/koog/a2a/transport/jsonrpc/JSONRPCClientTransport.kt @@ -138,6 +138,12 @@ public abstract class JSONRPCClientTransport : ClientTransport { ): Response = request(A2AMethod.CancelTask, request, ctx) + override fun resubscribeTask( + request: Request, + ctx: ClientCallContext + ): Flow> = + requestStreaming(A2AMethod.ResubscribeTask, request, ctx) + override suspend fun setTaskPushNotificationConfig( request: Request, ctx: ClientCallContext @@ -159,6 +165,6 @@ public abstract class JSONRPCClientTransport : ClientTransport { override suspend fun deleteTaskPushNotificationConfig( request: Request, ctx: ClientCallContext - ): Response = + ): Response = request(A2AMethod.DeleteTaskPushNotificationConfig, request, ctx) } diff --git a/a2a/a2a-transport/a2a-transport-core-jsonrpc/src/commonMain/kotlin/ai/koog/a2a/transport/jsonrpc/JSONRPCServerTransport.kt b/a2a/a2a-transport/a2a-transport-core-jsonrpc/src/commonMain/kotlin/ai/koog/a2a/transport/jsonrpc/JSONRPCServerTransport.kt index be15b76549..181f0b3dbf 100644 --- a/a2a/a2a-transport/a2a-transport-core-jsonrpc/src/commonMain/kotlin/ai/koog/a2a/transport/jsonrpc/JSONRPCServerTransport.kt +++ b/a2a/a2a-transport/a2a-transport-core-jsonrpc/src/commonMain/kotlin/ai/koog/a2a/transport/jsonrpc/JSONRPCServerTransport.kt @@ -18,6 +18,7 @@ import ai.koog.a2a.transport.jsonrpc.model.JSONRPCSuccessResponse import ai.koog.a2a.utils.runCatchingCancellable import kotlinx.coroutines.flow.Flow import kotlinx.coroutines.flow.catch +import kotlinx.coroutines.flow.flow import kotlinx.coroutines.flow.map import kotlinx.serialization.SerializationException import kotlinx.serialization.json.decodeFromJsonElement @@ -39,7 +40,8 @@ public abstract class JSONRPCServerTransport : ServerTransport { return runCatchingCancellable { when (request.method) { A2AMethod.GetAuthenticatedExtendedAgentCard.value -> - requestHandler.onGetAuthenticatedExtendedAgentCard(request.toRequest(), ctx).toJSONRPCSuccessResponse() + requestHandler.onGetAuthenticatedExtendedAgentCard(request.toRequest(), ctx) + .toJSONRPCSuccessResponse() A2AMethod.SendMessage.value -> requestHandler.onSendMessage(request.toRequest(), ctx).toJSONRPCSuccessResponse() @@ -60,9 +62,11 @@ public abstract class JSONRPCServerTransport : ServerTransport { requestHandler.onListTaskPushNotificationConfig(request.toRequest(), ctx).toJSONRPCSuccessResponse() A2AMethod.DeleteTaskPushNotificationConfig.value -> - requestHandler.onDeleteTaskPushNotificationConfig(request.toRequest(), ctx).toJSONRPCSuccessResponse() + requestHandler.onDeleteTaskPushNotificationConfig(request.toRequest(), ctx) + .toJSONRPCSuccessResponse() - else -> throw A2AMethodNotFoundException(request.method) + else -> + throw A2AMethodNotFoundException(request.method) } }.getOrElse { it.toJSONRPCErrorResponse(request.id) } } @@ -78,13 +82,15 @@ public abstract class JSONRPCServerTransport : ServerTransport { ): Flow { return when (request.method) { A2AMethod.SendMessageStreaming.value -> - requestHandler - .onSendMessageStreaming(request.toRequest(), ctx) - .map { it.toJSONRPCSuccessResponse() as JSONRPCResponse } - .catch { emit(it.toJSONRPCErrorResponse(request.id)) } + requestHandler.onSendMessageStreaming(request.toRequest(), ctx) - else -> throw A2AMethodNotFoundException(request.method) - } + A2AMethod.ResubscribeTask.value -> + requestHandler.onResubscribeTask(request.toRequest(), ctx) + + else -> + flow { throw A2AMethodNotFoundException(request.method) } + }.map { it.toJSONRPCSuccessResponse() as JSONRPCResponse } + .catch { emit(it.toJSONRPCErrorResponse(request.id)) } } /** @@ -95,8 +101,8 @@ public abstract class JSONRPCServerTransport : ServerTransport { protected inline fun JSONRPCRequest.toRequest(): Request { val data = try { JSONRPCJson.decodeFromJsonElement(params) - } catch (_: SerializationException) { - throw A2AInvalidParamsException("Cannot parse request params to ${T::class}") + } catch (e: SerializationException) { + throw A2AInvalidParamsException("Cannot parse request params:\n${e.message}") } return Request( diff --git a/a2a/a2a-transport/a2a-transport-server-jsonrpc-http/build.gradle.kts b/a2a/a2a-transport/a2a-transport-server-jsonrpc-http/build.gradle.kts index 4268bba174..2bcaef7fc4 100644 --- a/a2a/a2a-transport/a2a-transport-server-jsonrpc-http/build.gradle.kts +++ b/a2a/a2a-transport/a2a-transport-server-jsonrpc-http/build.gradle.kts @@ -27,9 +27,10 @@ kotlin { jvmMain { dependencies { api(libs.ktor.server.core) - api(libs.ktor.server.sse) - api(libs.ktor.server.content.negotiation) - api(libs.ktor.serialization.kotlinx.json) + implementation(libs.ktor.server.sse) + implementation(libs.ktor.server.content.negotiation) + implementation(libs.ktor.serialization.kotlinx.json) + implementation(libs.ktor.server.netty) } } @@ -42,6 +43,7 @@ kotlin { jvmTest { dependencies { implementation(kotlin("test-junit5")) + implementation(libs.ktor.server.test.host) } } diff --git a/a2a/a2a-transport/a2a-transport-server-jsonrpc-http/src/jvmMain/kotlin/ai/koog/a2a/transport/server/jsonrpc/http/HttpJSONRPCServerTransport.kt b/a2a/a2a-transport/a2a-transport-server-jsonrpc-http/src/jvmMain/kotlin/ai/koog/a2a/transport/server/jsonrpc/http/HttpJSONRPCServerTransport.kt index 8256b1d3d8..297b063eec 100644 --- a/a2a/a2a-transport/a2a-transport-server-jsonrpc-http/src/jvmMain/kotlin/ai/koog/a2a/transport/server/jsonrpc/http/HttpJSONRPCServerTransport.kt +++ b/a2a/a2a-transport/a2a-transport-server-jsonrpc-http/src/jvmMain/kotlin/ai/koog/a2a/transport/server/jsonrpc/http/HttpJSONRPCServerTransport.kt @@ -10,44 +10,152 @@ import ai.koog.a2a.transport.jsonrpc.model.JSONRPCRequest import ai.koog.a2a.utils.runCatchingCancellable import io.ktor.serialization.kotlinx.json.json import io.ktor.server.application.ApplicationCall +import io.ktor.server.application.install import io.ktor.server.application.pluginOrNull +import io.ktor.server.engine.EmbeddedServer +import io.ktor.server.engine.embeddedServer +import io.ktor.server.netty.Netty import io.ktor.server.plugins.contentnegotiation.ContentNegotiation -import io.ktor.server.request.receive +import io.ktor.server.request.receiveText +import io.ktor.server.response.respond import io.ktor.server.routing.Route import io.ktor.server.routing.application import io.ktor.server.routing.post import io.ktor.server.routing.route +import io.ktor.server.routing.routing import io.ktor.server.sse.SSE import io.ktor.server.sse.send import io.ktor.server.sse.sse import io.ktor.util.toMap +import kotlinx.coroutines.sync.Mutex +import kotlinx.coroutines.sync.withLock import kotlinx.serialization.SerializationException -import kotlinx.serialization.json.JsonElement import kotlinx.serialization.json.decodeFromJsonElement import kotlinx.serialization.serializer /** - * Implements JSON-RPC server transport over HTTP, supporting standard JSON-RPC - * request/response and server-sent events for streaming responses. + * Implements A2A JSON-RPC server transport over HTTP using Ktor server + * It ensures compliance with the A2A specification for error handling and JSON-RPC request processing. + * This transport can be used either as a standalone server or integrated into an existing Ktor application. * - * The transport integrates with Ktor routing to define and manage routes for - * JSON-RPC endpoints. It ensures compliance with the A2A specification for - * error handling and JSON-RPC request processing. + * Example usage as a standalone server: + * ```kotlin + * val transport = HttpJSONRPCServerTransport( + * requestHandler = A2AServer(...) + * ) * - * @property requestHandler The handler responsible for processing JSON-RPC requests - * received by the transport. + * transport.start(port = 8080, path = "/my-agent") + * transport.stop() + * ``` + * + * Example usage as an integration into an existing Ktor server. + * Can also be used to integrate multiple A2A server transports on the same server, to serve multiple A2A agents: + * ```kotlin + * val agentOneTransport = HttpJSONRPCServerTransport( + * requestHandler = A2AServer(...) + * ) + * val agentTwoTransport = HttpJSONRPCServerTransport( + * requestHandler = A2AServer(...) + * ) + * + * embeddedServer(Netty, port = 8080) { + * install(SSE) + * + * // Other configurations... + * + * routing { + * // Other routes... + * + * route("/a2a") { + * agentOneTransport.transportRoutes(this, "/agent-1") + * agentTwoTransport.transportRoutes(this, "/agent-2") + * } + * } + * }.startSuspend(wait = true) + * ``` + * + * @property requestHandler The handler responsible for processing A2A requests received by the transport. */ public class HttpJSONRPCServerTransport( override val requestHandler: RequestHandler, ) : JSONRPCServerTransport() { + /** + * Current running server instance if this transport is used as a standalone server. + */ + private var server: EmbeddedServer<*, *>? = null + private var serverMutex = Mutex() + + /** + * Starts Ktor embedded server with Netty engine to handle A2A JSON-RPC requests, using the specified port and endpoint path. + * Can be used to start a standalone server for quick prototyping or when no integration into the existing server is required. + * The routing consists only of [transportRoutes]. + * + * If you need to integrate A2A request handling logic into existing Ktor application, + * use [transportRoutes] method to mount the transport routes into existing [Route] configuration block. + * + * @param port The port on which the server will listen. + * @param path The JSON-RPC endpoint path to handle incoming requests. + * + * @throws IllegalStateException if the server is already running. + * + * @see [transportRoutes] + */ + public suspend fun start(port: Int, path: String): Unit = serverMutex.withLock { + check(server == null) { "Server is already configured and running. Stop it before starting a new one." } + + embeddedServer(Netty, port) { + install(SSE) + + routing { + transportRoutes(this, path) + } + }.startSuspend(wait = true) + } + + /** + * Stops the server gracefully within the specified time limits. + * + * @param gracePeriodMillis The time in milliseconds to allow ongoing requests to finish gracefully before shutting down. + * @param timeoutMillis The maximum time in milliseconds to wait for the server to stop. + * + * @throws IllegalStateException if the server is not configured or running. + */ + public suspend fun stop(gracePeriodMillis: Long = 1000, timeoutMillis: Long = 2000): Unit = serverMutex.withLock { + check(server != null) { "Server is not configured or running." } + + server?.stopSuspend(gracePeriodMillis, timeoutMillis) + server = null + } + /** * Routes for handling JSON-RPC HTTP requests. * Follows A2A specification in error handling. + * Allows mounting A2A requests handling into an existing Ktor server application. + * This can also be used to mount multiple A2A server transports on the same server, to serve multiple A2A agents. + * + * Example usage: + * ```kotlin + * embeddedServer(Netty, port = 8080) { + * install(SSE) + * + * // Other configurations... * - * @param path JSON-RPC endpoint path. + * routing { + * // Other routes... + * + * route("/a2a") { + * agentOneTransport.transportRoutes(this, "/agent-1") + * agentTwoTransport.transportRoutes(this, "/agent-2") + * } + * } + * }.startSuspend(wait = true) + * ``` + * + * @param route The base route to which the transport routes should be mounted. + * @param path JSON-RPC endpoint path that will be mounted under the base [route]. */ - public fun Route.transportRoutes(path: String): Route = route(path) { + public fun transportRoutes(route: Route, path: String): Route = route.route(path) { if (application.pluginOrNull(SSE) == null) { throw IllegalStateException("SSE plugin must be installed in the application to add these routes.") } @@ -56,15 +164,19 @@ public class HttpJSONRPCServerTransport( json(JSONRPCJson) } + // Regular JSON-RPC requests post { - runCatchingCancellable { + val response = runCatchingCancellable { onRequest( request = call.receiveJSONRPCRequest(), ctx = call.toServerCallContext() ) }.getOrElse { it.toJSONRPCErrorResponse() } + + call.respond(response) } + // Streaming JSON-RPC requests sse( serialize = { typeInfo, it -> val kType = typeInfo.kotlinType ?: throw IllegalArgumentException("Null KType for value: $it") @@ -76,7 +188,9 @@ public class HttpJSONRPCServerTransport( onRequestStreaming( request = call.receiveJSONRPCRequest(), ctx = call.toServerCallContext() - ).collect { response -> send(response) } + ).collect { response -> + send(response) + } }.getOrElse { send(it.toJSONRPCErrorResponse()) } @@ -88,7 +202,8 @@ public class HttpJSONRPCServerTransport( */ private suspend fun ApplicationCall.receiveJSONRPCRequest(): JSONRPCRequest { val jsonBody = try { - receive() + val rawBody = receiveText() + JSONRPCJson.parseToJsonElement(rawBody) } catch (e: SerializationException) { throw A2AParseException("Cannot parse request body to JSON:\n${e.message}") } diff --git a/a2a/a2a-transport/a2a-transport-server-jsonrpc-http/src/jvmTest/kotlin/ai/koog/a2a/transport/server/jsonrpc/http/HttpJSONRPCServerTransportTest.kt b/a2a/a2a-transport/a2a-transport-server-jsonrpc-http/src/jvmTest/kotlin/ai/koog/a2a/transport/server/jsonrpc/http/HttpJSONRPCServerTransportTest.kt new file mode 100644 index 0000000000..e6e03a9210 --- /dev/null +++ b/a2a/a2a-transport/a2a-transport-server-jsonrpc-http/src/jvmTest/kotlin/ai/koog/a2a/transport/server/jsonrpc/http/HttpJSONRPCServerTransportTest.kt @@ -0,0 +1,631 @@ +package ai.koog.a2a.transport.server.jsonrpc.http + +import ai.koog.a2a.exceptions.A2AErrorCodes +import ai.koog.a2a.model.AgentCapabilities +import ai.koog.a2a.model.AgentCard +import ai.koog.a2a.model.AgentSkill +import ai.koog.a2a.model.CommunicationEvent +import ai.koog.a2a.model.Message +import ai.koog.a2a.model.MessageSendParams +import ai.koog.a2a.model.PushNotificationConfig +import ai.koog.a2a.model.Role +import ai.koog.a2a.model.Task +import ai.koog.a2a.model.TaskIdParams +import ai.koog.a2a.model.TaskPushNotificationConfig +import ai.koog.a2a.model.TaskPushNotificationConfigParams +import ai.koog.a2a.model.TaskQueryParams +import ai.koog.a2a.model.TaskState +import ai.koog.a2a.model.TaskStatus +import ai.koog.a2a.model.TextPart +import ai.koog.a2a.model.UpdateEvent +import ai.koog.a2a.transport.Request +import ai.koog.a2a.transport.RequestHandler +import ai.koog.a2a.transport.RequestId +import ai.koog.a2a.transport.Response +import ai.koog.a2a.transport.ServerCallContext +import ai.koog.a2a.transport.jsonrpc.A2AMethod +import ai.koog.a2a.transport.jsonrpc.model.JSONRPCErrorResponse +import ai.koog.a2a.transport.jsonrpc.model.JSONRPCJson +import ai.koog.a2a.transport.jsonrpc.model.JSONRPCRequest +import ai.koog.a2a.transport.jsonrpc.model.JSONRPCSuccessResponse +import io.ktor.client.plugins.sse.sse +import io.ktor.client.request.post +import io.ktor.client.request.setBody +import io.ktor.client.statement.bodyAsText +import io.ktor.http.ContentType +import io.ktor.http.HttpStatusCode +import io.ktor.http.contentType +import io.ktor.server.sse.SSE +import io.ktor.server.testing.testApplication +import kotlinx.coroutines.flow.Flow +import kotlinx.coroutines.flow.asFlow +import kotlinx.coroutines.flow.map +import kotlinx.coroutines.test.runTest +import kotlinx.serialization.json.JsonNull +import kotlinx.serialization.json.decodeFromJsonElement +import kotlinx.serialization.json.encodeToJsonElement +import kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.assertNull +import io.ktor.client.plugins.sse.SSE as SSEClient + +class HttpJSONRPCServerTransportTest { + private object MockRequestHandler : RequestHandler { + val agentCard = AgentCard( + name = "Test Agent", + description = "A test agent", + url = "https://api.example.com/a2a", + version = "1.0.0", + capabilities = AgentCapabilities(), + defaultInputModes = listOf("text/plain"), + defaultOutputModes = listOf("text/plain"), + skills = listOf( + AgentSkill( + id = "test-skill", + name = "Test Skill", + description = "A test skill", + tags = listOf("test") + ) + ) + ) + + val communicationEvent = Message( + messageId = "message-1", + role = Role.Agent, + parts = listOf(TextPart("Response message.")), + taskId = "task-1" + ) + + val updateEvents = listOf( + Message( + messageId = "message-stream-1", + role = Role.Agent, + parts = listOf(TextPart("Streaming response part 1")), + taskId = "task-1" + ), + Message( + messageId = "message-stream-2", + role = Role.Agent, + parts = listOf(TextPart("Streaming response part 2")), + taskId = "task-1" + ) + ) + + val taskGet = Task( + id = "task-1", + contextId = "test-context-1", + status = TaskStatus( + state = TaskState.Working + ) + ) + + val taskCancel = Task( + id = "task-1", + contextId = "test-context-1", + status = TaskStatus( + state = TaskState.Canceled + ) + ) + + val taskPushNotificationConfig = TaskPushNotificationConfig( + taskId = "task-1", + pushNotificationConfig = PushNotificationConfig( + id = "notification-config-1", + url = "https://webhook.example.com", + token = "webhook-token-123" + ) + ) + + val taskPushNotificationConfigList = listOf(taskPushNotificationConfig) + + override suspend fun onGetAuthenticatedExtendedAgentCard( + request: Request, + ctx: ServerCallContext + ): Response { + return Response( + id = request.id, + data = agentCard + ) + } + + override suspend fun onSendMessage( + request: Request, + ctx: ServerCallContext + ): Response { + return Response( + id = request.id, + data = communicationEvent + ) + } + + override fun onSendMessageStreaming( + request: Request, + ctx: ServerCallContext + ): Flow> { + return updateEvents + .asFlow() + .map { + Response( + id = request.id, + data = it + ) + } + } + + override suspend fun onGetTask( + request: Request, + ctx: ServerCallContext + ): Response { + return Response( + id = request.id, + data = taskGet + ) + } + + override suspend fun onCancelTask( + request: Request, + ctx: ServerCallContext + ): Response { + return Response( + id = request.id, + data = taskCancel + ) + } + + override fun onResubscribeTask( + request: Request, + ctx: ServerCallContext + ): Flow> { + return updateEvents + .asFlow() + .map { + Response( + id = request.id, + data = it + ) + } + } + + override suspend fun onSetTaskPushNotificationConfig( + request: Request, + ctx: ServerCallContext + ): Response { + return Response( + id = request.id, + data = request.data + ) + } + + override suspend fun onGetTaskPushNotificationConfig( + request: Request, + ctx: ServerCallContext + ): Response { + return Response( + id = request.id, + data = taskPushNotificationConfig + ) + } + + override suspend fun onListTaskPushNotificationConfig( + request: Request, + ctx: ServerCallContext + ): Response> { + return Response( + id = request.id, + data = taskPushNotificationConfigList + ) + } + + override suspend fun onDeleteTaskPushNotificationConfig( + request: Request, + ctx: ServerCallContext + ): Response { + return Response( + id = request.id, + data = null + ) + } + } + + private val json = JSONRPCJson + + private inline fun testServerMethod( + method: A2AMethod, + request: Request, + expectedResponse: Response, + ) { + testApplication { + install(SSE) + + val transport = HttpJSONRPCServerTransport(MockRequestHandler) + + routing { + transport.transportRoutes(this, "/a2a") + } + + val jsonRpcRequest = JSONRPCRequest( + id = request.id, + method = method.value, + params = json.encodeToJsonElement(request.data) + ) + + val response = client.post("/a2a") { + contentType(ContentType.Application.Json) + setBody(json.encodeToString(jsonRpcRequest)) + } + + assertEquals(HttpStatusCode.OK, response.status) + + val jsonRpcResponse = json.decodeFromString(response.bodyAsText()) + val actualResponse = Response( + id = jsonRpcResponse.id, + data = json.decodeFromJsonElement(jsonRpcResponse.result) + ) + + assertEquals(expectedResponse.id, actualResponse.id) + assertEquals(expectedResponse.data, actualResponse.data) + } + } + + private inline fun testServerMethodStreaming( + method: A2AMethod, + request: Request, + expectedResponses: List>, + ) { + testApplication { + install(SSE) + + val client = createClient { + install(SSEClient) + } + + val transport = HttpJSONRPCServerTransport(MockRequestHandler) + + routing { + transport.transportRoutes(this, "/a2a") + } + + val jsonRpcRequest = JSONRPCRequest( + id = request.id, + method = method.value, + params = json.encodeToJsonElement(request.data) + ) + + val jsonrpcResponses = buildList { + client.sse( + urlString = "/a2a", + request = { + contentType(ContentType.Application.Json) + setBody(json.encodeToString(jsonRpcRequest)) + }, + ) { + assertEquals(HttpStatusCode.OK, call.response.status) + + incoming + .map { event -> JSONRPCJson.decodeFromString(event.data!!) } + .collect { add(it) } + } + } + + val actualResponses = jsonrpcResponses.map { + Response( + id = it.id, + data = json.decodeFromJsonElement(it.result) + ) + } + + assertEquals(expectedResponses.map { it.id }, actualResponses.map { it.id }) + assertEquals(expectedResponses.map { it.data }, actualResponses.map { it.data }) + } + } + + @Test + fun testGetAuthenticatedExtendedAgentCard() = runTest { + val requestId = RequestId.StringId("test-1") + + val request = Request( + id = requestId, + data = null, + ) + + val expectedResponse = Response( + id = requestId, + data = MockRequestHandler.agentCard, + ) + + testServerMethod( + method = A2AMethod.GetAuthenticatedExtendedAgentCard, + request = request, + expectedResponse = expectedResponse, + ) + } + + @Test + fun testSendMessage() = runTest { + val requestId = RequestId.StringId("test-2") + + val messageSendParams = MessageSendParams( + message = Message( + messageId = "msg-1", + role = Role.User, + parts = listOf(TextPart("Hello, agent!")), + taskId = "task-1" + ) + ) + + val request = Request( + id = requestId, + data = messageSendParams, + ) + + val expectedResponse = Response( + id = requestId, + data = MockRequestHandler.communicationEvent, + ) + + testServerMethod( + method = A2AMethod.SendMessage, + request = request, + expectedResponse = expectedResponse, + ) + } + + @Test + fun testSendMessageStreaming() = runTest { + val requestId = RequestId.StringId("test-2") + + val messageSendParams = MessageSendParams( + message = Message( + messageId = "msg-1", + role = Role.User, + parts = listOf(TextPart("Hello, agent!")), + taskId = "task-1" + ) + ) + + val request = Request( + id = requestId, + data = messageSendParams, + ) + + val expectedResponses = MockRequestHandler.updateEvents.map { + Response( + id = requestId, + data = it, + ) + } + + testServerMethodStreaming( + method = A2AMethod.SendMessageStreaming, + request = request, + expectedResponses = expectedResponses, + ) + } + + @Test + fun testGetTask() = runTest { + val requestId = RequestId.StringId("test-3") + val taskQueryParams = TaskQueryParams(id = "task-1") + + val request = Request( + id = requestId, + data = taskQueryParams, + ) + + val expectedResponse = Response( + id = requestId, + data = MockRequestHandler.taskGet, + ) + + testServerMethod( + method = A2AMethod.GetTask, + request = request, + expectedResponse = expectedResponse, + ) + } + + @Test + fun testCancelTask() = runTest { + val requestId = RequestId.StringId("test-4") + val taskIdParams = TaskIdParams(id = "task-1") + + val request = Request( + id = requestId, + data = taskIdParams, + ) + + val expectedResponse = Response( + id = requestId, + data = MockRequestHandler.taskCancel, + ) + + testServerMethod( + method = A2AMethod.CancelTask, + request = request, + expectedResponse = expectedResponse, + ) + } + + @Test + fun testResubscribeTask() = runTest { + val requestId = RequestId.StringId("test-7") + val taskIdParams = TaskIdParams(id = "task-1") + + val request = Request( + id = requestId, + data = taskIdParams, + ) + + val expectedResponses = MockRequestHandler.updateEvents.map { + Response( + id = requestId, + data = it, + ) + } + + testServerMethodStreaming( + method = A2AMethod.ResubscribeTask, + request = request, + expectedResponses = expectedResponses, + ) + } + + @Test + fun testSetTaskPushNotificationConfig() = runTest { + val requestId = RequestId.StringId("test-5") + + val pushNotificationConfig = TaskPushNotificationConfig( + taskId = "task-123", + pushNotificationConfig = PushNotificationConfig( + id = "notification-config-1", + url = "https://webhook.example.com/notifications", + token = "webhook-token-123" + ) + ) + + val request = Request( + id = requestId, + data = pushNotificationConfig, + ) + + val expectedResponse = Response( + id = requestId, + data = pushNotificationConfig, + ) + + testServerMethod( + method = A2AMethod.SetTaskPushNotificationConfig, + request = request, + expectedResponse = expectedResponse, + ) + } + + @Test + fun testGetTaskPushNotificationConfig() = runTest { + val requestId = RequestId.StringId("test-6") + + val configParams = TaskPushNotificationConfigParams( + id = "task-123", + pushNotificationConfigId = "notification-config-1" + ) + + val request = Request( + id = requestId, + data = configParams, + ) + + val expectedResponse = Response( + id = requestId, + data = MockRequestHandler.taskPushNotificationConfig, + ) + + testServerMethod( + method = A2AMethod.GetTaskPushNotificationConfig, + request = request, + expectedResponse = expectedResponse, + ) + } + + @Test + fun testListTaskPushNotificationConfig() = runTest { + val requestId = RequestId.StringId("test-7") + val taskIdParams = TaskIdParams(id = "task-1") + + val request = Request( + id = requestId, + data = taskIdParams, + ) + + val expectedResponse = Response( + id = requestId, + data = MockRequestHandler.taskPushNotificationConfigList, + ) + + testServerMethod( + method = A2AMethod.ListTaskPushNotificationConfig, + request = request, + expectedResponse = expectedResponse, + ) + } + + @Test + fun testDeleteTaskPushNotificationConfig() = runTest { + val requestId = RequestId.StringId("test-8") + + val configParams = TaskPushNotificationConfigParams( + id = "task-123", + pushNotificationConfigId = "notification-config-1" + ) + + val request = Request( + id = requestId, + data = configParams, + ) + + val expectedResponse = Response( + id = requestId, + data = null, + ) + + testServerMethod( + method = A2AMethod.DeleteTaskPushNotificationConfig, + request = request, + expectedResponse = expectedResponse, + ) + } + + @Test + fun testMethodNotFound() = runTest { + testApplication { + install(SSE) + + val transport = HttpJSONRPCServerTransport(MockRequestHandler) + + routing { + transport.transportRoutes(this, "/a2a") + } + + val requestId = RequestId.StringId("test-9") + val jsonRpcRequest = JSONRPCRequest( + id = requestId, + method = "unknown.method", + params = JsonNull + ) + + val response = client.post("/a2a") { + contentType(ContentType.Application.Json) + setBody(json.encodeToString(jsonRpcRequest)) + } + + assertEquals(HttpStatusCode.OK, response.status) + + val jsonRpcResponse = json.decodeFromString(response.bodyAsText()) + assertEquals(requestId, jsonRpcResponse.id) + assertEquals(A2AErrorCodes.METHOD_NOT_FOUND, jsonRpcResponse.error.code) + } + } + + @Test + fun testInvalidJsonRequest() = runTest { + testApplication { + install(SSE) + + val transport = HttpJSONRPCServerTransport(MockRequestHandler) + + routing { + transport.transportRoutes(this, "/a2a") + } + + val response = client.post("/a2a") { + contentType(ContentType.Application.Json) + setBody("invalid json") + } + + assertEquals(HttpStatusCode.OK, response.status) + + val jsonRpcResponse = json.decodeFromString(response.bodyAsText()) + assertNull(jsonRpcResponse.id) + assertEquals(A2AErrorCodes.PARSE_ERROR, jsonRpcResponse.error.code) + } + } +} diff --git a/gradle/libs.versions.toml b/gradle/libs.versions.toml index c5c31538df..380698c201 100644 --- a/gradle/libs.versions.toml +++ b/gradle/libs.versions.toml @@ -70,6 +70,8 @@ ktor-server-cio = { module = "io.ktor:ktor-server-cio", version.ref = "ktor3" } ktor-server-netty = { module = "io.ktor:ktor-server-netty-jvm", version.ref = "ktor3" } ktor-server-sse = { module = "io.ktor:ktor-server-sse", version.ref = "ktor3" } ktor-server-content-negotiation = { module = "io.ktor:ktor-server-content-negotiation", version.ref = "ktor3" } +ktor-server-netty = { module = "io.ktor:ktor-server-netty", version.ref = "ktor3" } +ktor-server-test-host = { module = "io.ktor:ktor-server-test-host", version.ref = "ktor3" } lettuce-core = { module = "io.lettuce:lettuce-core", version.ref = "lettuce" } logback-classic = { module = "ch.qos.logback:logback-classic", version.ref = "logback" } oshai-kotlin-logging = { module = "io.github.oshai:kotlin-logging", version.ref = "oshai-logging" } From 05a902fdef3db0327906589fe474c0cada70d5e0 Mon Sep 17 00:00:00 2001 From: Konstantin Pavlov <1517853+kpavlov@users.noreply.github.com> Date: Tue, 9 Sep 2025 12:06:05 +0300 Subject: [PATCH 21/52] Add integration tests for HttpJSONRPCClientTransport and update dependencies in build.gradle.kts --- .../build.gradle.kts | 3 + .../HttpJSONRPCClientTransportMokksyTest.kt | 202 ++++++++++++++++++ 2 files changed, 205 insertions(+) create mode 100644 a2a/a2a-transport/a2a-transport-client-jsonrpc-http/src/jvmTest/kotlin/ai/koog/a2a/transport/client/jsonrpc/http/HttpJSONRPCClientTransportMokksyTest.kt diff --git a/a2a/a2a-transport/a2a-transport-client-jsonrpc-http/build.gradle.kts b/a2a/a2a-transport/a2a-transport-client-jsonrpc-http/build.gradle.kts index 89286cc1c0..e6748d7b03 100644 --- a/a2a/a2a-transport/a2a-transport-client-jsonrpc-http/build.gradle.kts +++ b/a2a/a2a-transport/a2a-transport-client-jsonrpc-http/build.gradle.kts @@ -38,6 +38,9 @@ kotlin { jvmTest { dependencies { implementation(kotlin("test-junit5")) + implementation("me.kpavlov.aimocks:ai-mocks-a2a-jvm:0.5.0-Alpha1") + implementation(libs.ktor.client.cio) + runtimeOnly(libs.slf4j.simple) } } diff --git a/a2a/a2a-transport/a2a-transport-client-jsonrpc-http/src/jvmTest/kotlin/ai/koog/a2a/transport/client/jsonrpc/http/HttpJSONRPCClientTransportMokksyTest.kt b/a2a/a2a-transport/a2a-transport-client-jsonrpc-http/src/jvmTest/kotlin/ai/koog/a2a/transport/client/jsonrpc/http/HttpJSONRPCClientTransportMokksyTest.kt new file mode 100644 index 0000000000..e256992210 --- /dev/null +++ b/a2a/a2a-transport/a2a-transport-client-jsonrpc-http/src/jvmTest/kotlin/ai/koog/a2a/transport/client/jsonrpc/http/HttpJSONRPCClientTransportMokksyTest.kt @@ -0,0 +1,202 @@ +package ai.koog.a2a.transport.client.jsonrpc.http + +import ai.koog.a2a.model.Message +import ai.koog.a2a.model.MessageSendParams +import ai.koog.a2a.model.Role +import ai.koog.a2a.model.TaskIdParams +import ai.koog.a2a.model.TaskQueryParams +import ai.koog.a2a.model.TaskState +import ai.koog.a2a.model.TextPart +import ai.koog.a2a.transport.ClientCallContext +import ai.koog.a2a.transport.Request +import ai.koog.a2a.transport.RequestId +import io.kotest.assertions.throwables.shouldThrow +import io.kotest.matchers.nulls.shouldNotBeNull +import io.kotest.matchers.shouldBe +import io.ktor.client.HttpClient +import kotlinx.coroutines.test.runTest +import kotlinx.serialization.MissingFieldException +import me.kpavlov.aimocks.a2a.MockAgentServer +import me.kpavlov.aimocks.a2a.model.AgentCard +import me.kpavlov.aimocks.a2a.model.Task +import me.kpavlov.aimocks.a2a.model.TaskStatus +import me.kpavlov.aimocks.a2a.model.create +import kotlin.test.Ignore +import kotlin.test.Test +import kotlin.time.Duration.Companion.milliseconds + +class HttpJSONRPCClientTransportMokksyTest { + + val a2aServer = MockAgentServer(verbose = true) + + val client = HttpJSONRPCClientTransport(a2aServer.baseUrl(), HttpClient()) + + @Test + @Ignore() + // todo: implement client.getCard(...) + fun `Should get Card`() = runTest { + val agentCard = AgentCard.create { + name = "test-agent" + description = "test-agent-description" + url = a2aServer.baseUrl() + documentationUrl = "https://example.com/documentation" + version = "0.0.1" + provider { + organization = "Acme, Inc." + url = "https://example.com/organization" + } + authentication { + schemes = listOf("none", "bearer") + credentials = "test-token" + } + capabilities { + streaming = true + pushNotifications = true + stateTransitionHistory = true + } + skills += skill { + id = "walk" + name = "Walk the walk" + } + skills += skill { + id = "talk" + name = "Talk the talk" + } + } + + // Configure the mock server to respond with the AgentCard + a2aServer.agentCard() responds { + delay = 1.milliseconds + card = agentCard + } + + TODO("client.getCard(...)") + } + + @Test + fun `Should sendMessage`() = runTest { + a2aServer.sendMessage() responds { + id = "req_1234" + result = Task( + id = "tid_12345", + sessionId = "de38c76d-d54c-436c-8b9f-4c2703648d64", + contextId = "ctx_12345", + status = TaskStatus("submitted") + ) + } + + val response = client.sendMessage( + request = Request( + id = RequestId.StringId("req_1234"), + data = MessageSendParams( + message = Message( + role = Role.User, + parts = listOf( + TextPart("Tell me a joke") + ) + ) + ) + ), + ctx = ClientCallContext() + ) + + response shouldNotBeNull { + id shouldBe RequestId.StringId("req_1234") + (data as? ai.koog.a2a.model.Task) shouldNotBeNull { + id shouldBe "tid_12345" + contextId shouldBe "ctx_12345" + status shouldBe ai.koog.a2a.model.TaskStatus(TaskState.Submitted) + } + } + } + + @Test + fun `Should getTask`() = runTest { + a2aServer.getTask() responds { + id = 1 + result = Task( + id = "tid_12345", + contextId = "ctx_12345", + sessionId = "de38c76d-d54c-436c-8b9f-4c2703648d64", + status = TaskStatus("canceled") + ) + } + + val response = client.getTask( + request = Request( + id = RequestId.StringId("req_1234"), + data = TaskQueryParams(id = "tid_12345") + ), + ctx = ClientCallContext() + ) + + response shouldNotBeNull { + id shouldBe RequestId.NumberId(1) + data shouldNotBeNull { + id shouldBe "tid_12345" + contextId shouldBe "ctx_12345" + status shouldBe ai.koog.a2a.model.TaskStatus(TaskState.Canceled) + } + } + } + + @Test + fun `Should handle getTask with missing contextId`() = runTest { + a2aServer.getTask() responds { + id = 1 + result { + id = "tid_12345" + sessionId = "de38c76d-d54c-436c-8b9f-4c2703648d64" + status { + state = "completed" + } + artifacts += artifact { + name = "joke" + parts += textPart { + text = "This is a joke" + } + } + } + } + + shouldThrow { + client.getTask( + request = Request( + id = RequestId.StringId("req_1234"), + data = TaskQueryParams(id = "tid_12345") + ), + ctx = ClientCallContext() + ) + }.missingFields shouldBe listOf("contextId") + } + + @Test + fun `Should cancelTask`() = runTest { + a2aServer.cancelTask() responds { + id = "req_123" + result = Task( + id = "tid_12345", + contextId = "ctx_12345", + sessionId = "de38c76d-d54c-436c-8b9f-4c2703648d64", + status = TaskStatus("canceled") + ) + } + + val response = client.cancelTask( + request = Request( + id = RequestId.StringId("req_1233"), + data = TaskIdParams("tid_12345") + ), + ctx = ClientCallContext() + ) + + response shouldNotBeNull { + id shouldBe RequestId.StringId("req_123") + data shouldNotBeNull { + id shouldBe "tid_12345" + contextId shouldBe "ctx_12345" + status shouldBe ai.koog.a2a.model.TaskStatus(TaskState.Canceled) + } + } + } +} From 9f073d8d35477c7511d8d6f6df3a2e7e7304ad90 Mon Sep 17 00:00:00 2001 From: Andrey Bragin Date: Mon, 29 Sep 2025 05:16:27 +0200 Subject: [PATCH 22/52] [a2a] Implement A2AClient foundation with AgentCardResolver --- a2a/a2a-client/build.gradle.kts | 10 ++ .../kotlin/ai/koog/a2a/client/A2AClient.kt | 32 ++++- .../ai/koog/a2a/client/AgentCardResolver.kt | 54 ++++++++ .../ai/koog/a2a/client/A2AClientMokksyTest.kt | 115 ++++++++++++++++++ .../ai/koog/a2a/annotations/InternalA2AApi.kt | 10 ++ .../kotlin/ai/koog/a2a/model/AgentCard.kt | 2 +- .../kotlin/ai/koog/a2a/utils/RWLock.kt | 47 +++++++ .../build.gradle.kts | 3 +- .../HttpJSONRPCClientTransportMokksyTest.kt | 47 ------- .../http/HttpJSONRPCServerTransportTest.kt | 4 + gradle/libs.versions.toml | 3 +- 11 files changed, 275 insertions(+), 52 deletions(-) create mode 100644 a2a/a2a-client/src/commonMain/kotlin/ai/koog/a2a/client/AgentCardResolver.kt create mode 100644 a2a/a2a-client/src/jvmTest/kotlin/ai/koog/a2a/client/A2AClientMokksyTest.kt create mode 100644 a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/annotations/InternalA2AApi.kt create mode 100644 a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/utils/RWLock.kt diff --git a/a2a/a2a-client/build.gradle.kts b/a2a/a2a-client/build.gradle.kts index ca41b397e3..44297855e1 100644 --- a/a2a/a2a-client/build.gradle.kts +++ b/a2a/a2a-client/build.gradle.kts @@ -21,18 +21,28 @@ kotlin { api(project(":a2a:a2a-core")) api(libs.kotlinx.serialization.json) api(libs.kotlinx.coroutines.core) + api(libs.ktor.client.core) + api(libs.ktor.client.content.negotiation) + api(libs.ktor.serialization.kotlinx.json) } } commonTest { dependencies { implementation(kotlin("test")) + implementation(libs.kotlinx.coroutines.test) } } jvmTest { dependencies { implementation(kotlin("test-junit5")) + implementation(project(":a2a:a2a-transport:a2a-transport-client-jsonrpc-http")) + + implementation(libs.mokksy.a2a) + implementation(libs.ktor.client.cio) + implementation(libs.ktor.client.logging) + runtimeOnly(libs.slf4j.simple) } } diff --git a/a2a/a2a-client/src/commonMain/kotlin/ai/koog/a2a/client/A2AClient.kt b/a2a/a2a-client/src/commonMain/kotlin/ai/koog/a2a/client/A2AClient.kt index 82989d097c..b6d092af2f 100644 --- a/a2a/a2a-client/src/commonMain/kotlin/ai/koog/a2a/client/A2AClient.kt +++ b/a2a/a2a-client/src/commonMain/kotlin/ai/koog/a2a/client/A2AClient.kt @@ -1,10 +1,38 @@ +@file:Suppress("MissingKDocForPublicAPI") + package ai.koog.a2a.client +import ai.koog.a2a.model.AgentCard import ai.koog.a2a.transport.ClientTransport +import ai.koog.a2a.utils.RWLock +import kotlin.concurrent.Volatile /** * A2A client responsible for sending requests to A2A server. */ -public class A2AClient( +public open class A2AClient( private val transport: ClientTransport, -) + private val agentCardResolver: AgentCardResolver, +) { + @Volatile + public lateinit var agentCard: AgentCard + private set + + private val cardLock = RWLock() + + public suspend fun connect(): AgentCard = cardLock.withWriteLock { + agentCard = agentCardResolver.resolve() + agentCard + } +} + +public fun A2AClient( + transport: ClientTransport, + baseUrl: String, + cardPath: String = UrlAgentCardResolver.wellKnownPath, +): A2AClient { + return A2AClient( + transport = transport, + agentCardResolver = UrlAgentCardResolver(baseUrl, cardPath), + ) +} diff --git a/a2a/a2a-client/src/commonMain/kotlin/ai/koog/a2a/client/AgentCardResolver.kt b/a2a/a2a-client/src/commonMain/kotlin/ai/koog/a2a/client/AgentCardResolver.kt new file mode 100644 index 0000000000..c0ced94b8f --- /dev/null +++ b/a2a/a2a-client/src/commonMain/kotlin/ai/koog/a2a/client/AgentCardResolver.kt @@ -0,0 +1,54 @@ +@file:Suppress("MissingKDocForPublicAPI") + +package ai.koog.a2a.client + +import ai.koog.a2a.model.AgentCard +import io.ktor.client.HttpClient +import io.ktor.client.call.body +import io.ktor.client.plugins.contentnegotiation.ContentNegotiation +import io.ktor.client.plugins.defaultRequest +import io.ktor.client.request.get +import io.ktor.http.ContentType +import io.ktor.http.contentType +import io.ktor.serialization.kotlinx.json.json +import kotlinx.serialization.json.Json + +public interface AgentCardResolver { + public suspend fun resolve(): AgentCard +} + +public class ExplicitAgentCardResolver(public val agentCard: AgentCard) : AgentCardResolver { + override suspend fun resolve(): AgentCard = agentCard +} + +public class UrlAgentCardResolver( + public val baseUrl: String, + public val path: String = wellKnownPath, + baseHttpClient: HttpClient = HttpClient(), +) : AgentCardResolver { + public companion object { + @Suppress("ConstPropertyName") + public const val wellKnownPath: String = "/.well-known/agent-card.json" + } + + private val httpClient: HttpClient = baseHttpClient.config { + defaultRequest { + url(baseUrl) + contentType(ContentType.Application.Json) + } + + install(ContentNegotiation) { + json( + Json { + ignoreUnknownKeys = true + } + ) + } + + expectSuccess = true + } + + override suspend fun resolve(): AgentCard { + return httpClient.get(path).body() + } +} diff --git a/a2a/a2a-client/src/jvmTest/kotlin/ai/koog/a2a/client/A2AClientMokksyTest.kt b/a2a/a2a-client/src/jvmTest/kotlin/ai/koog/a2a/client/A2AClientMokksyTest.kt new file mode 100644 index 0000000000..30b75e8866 --- /dev/null +++ b/a2a/a2a-client/src/jvmTest/kotlin/ai/koog/a2a/client/A2AClientMokksyTest.kt @@ -0,0 +1,115 @@ +package ai.koog.a2a.client + +import ai.koog.a2a.transport.client.jsonrpc.http.HttpJSONRPCClientTransport +import io.kotest.matchers.collections.shouldHaveSize +import io.kotest.matchers.nulls.shouldNotBeNull +import io.kotest.matchers.shouldBe +import io.ktor.client.HttpClient +import io.ktor.client.plugins.logging.LogLevel +import io.ktor.client.plugins.logging.Logging +import kotlinx.coroutines.test.runTest +import me.kpavlov.aimocks.a2a.MockAgentServer +import me.kpavlov.aimocks.a2a.model.AgentCard +import me.kpavlov.aimocks.a2a.model.create +import kotlin.test.Test +import kotlin.time.Duration.Companion.milliseconds + +class A2AClientMokksyTest { + val a2aServer = MockAgentServer(verbose = false) + + private val httpClient = HttpClient { + install(Logging) { + this.level = LogLevel.BODY + } + } + + private val transport = HttpJSONRPCClientTransport( + url = a2aServer.baseUrl(), + baseHttpClient = httpClient + ) + + val client = A2AClient( + transport = transport, + agentCardResolver = UrlAgentCardResolver( + baseUrl = a2aServer.baseUrl(), + baseHttpClient = httpClient, + ), + ) + + @Test + fun `Should get Card`() = runTest { + // given + val agentCard = AgentCard.create { + name = "test-agent" + description = "test-agent-description" + url = a2aServer.baseUrl() + documentationUrl = "https://example.com/documentation" + version = "1.0.1" + security = listOf( + mapOf("oauth" to listOf("read")), + mapOf("api-key" to listOf("mtls")), + ) + provider { + organization = "Acme, Inc." + url = "https://example.com/organization" + } + capabilities { + streaming = true + pushNotifications = true + stateTransitionHistory = true + } + skills += skill { + id = "walk" + name = "Walk the walk" + description = "Walk the walk description" + tags = listOf("walk", "tag") + } + skills += skill { + id = "talk" + name = "Talk the talk" + description = "Talk the talk description" + tags = listOf("walk", "tag") + } + } + + // Configure the mock server to respond with the AgentCard + a2aServer.agentCard() responds { + delay = 1.milliseconds + card = agentCard + } + + // when + val actualAgentCard = client.connect() + + // then + actualAgentCard shouldNotBeNull { + name shouldBe agentCard.name + description shouldBe agentCard.description + url shouldBe agentCard.url + documentationUrl shouldBe agentCard.documentationUrl + version shouldBe agentCard.version + security shouldBe agentCard.security + + provider shouldNotBeNull { + organization shouldBe agentCard.provider?.organization + url shouldBe agentCard.provider?.url + } + + capabilities shouldNotBeNull { + streaming shouldBe agentCard.capabilities.streaming + pushNotifications shouldBe agentCard.capabilities.pushNotifications + stateTransitionHistory shouldBe agentCard.capabilities.stateTransitionHistory + } + + skills shouldHaveSize agentCard.skills.size + skills.zip(agentCard.skills).forEach { (actualSkill, skill) -> + actualSkill shouldNotBeNull { + id shouldBe skill.id + name shouldBe skill.name + description shouldBe skill.description + tags shouldBe skill.tags + } + } + } + } +} diff --git a/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/annotations/InternalA2AApi.kt b/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/annotations/InternalA2AApi.kt new file mode 100644 index 0000000000..d8aee42afd --- /dev/null +++ b/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/annotations/InternalA2AApi.kt @@ -0,0 +1,10 @@ +package ai.koog.a2a.annotations + +/** + * Marks an API as internal to the a2a module. This annotation indicates that the + * marked API is not intended for public use and is subject to change or removal + * without prior notice. It should be used with caution and only within the intended + * internal scope. + */ +@RequiresOptIn("This API is internal in a2a and should not be used. It could be removed or changed without notice.") +public annotation class InternalA2AApi diff --git a/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/model/AgentCard.kt b/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/model/AgentCard.kt index bd41a84b4a..6235b29abe 100644 --- a/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/model/AgentCard.kt +++ b/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/model/AgentCard.kt @@ -93,7 +93,7 @@ public data class AgentCard( public val defaultOutputModes: List, public val skills: List, @EncodeDefault - public val supportsAuthenticatedExtendedCard: Boolean = false, + public val supportsAuthenticatedExtendedCard: Boolean? = false, public val signatures: List? = null ) { init { diff --git a/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/utils/RWLock.kt b/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/utils/RWLock.kt new file mode 100644 index 0000000000..25c3813ce8 --- /dev/null +++ b/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/utils/RWLock.kt @@ -0,0 +1,47 @@ +package ai.koog.a2a.utils + +import kotlinx.coroutines.sync.Mutex +import kotlinx.coroutines.sync.withLock + +// FIXME copied from agents-core module, because a2a does not depend on other Koog modules. +// Do we want to make a global utils module for cases like this? +/** + * A KMP read-write lock implementation that allows concurrent read access but ensures exclusive write access. + * + * This implementation uses `kotlinx.coroutines.sync.Mutex` to coordinate access for both readers and writers. + */ +public class RWLock { + private val writeMutex = Mutex() + private var readersCount = 0 + private val readersCountMutex = Mutex() + + /** + * Run the given [block] of code while holding the read lock. + */ + public suspend fun withReadLock(block: suspend () -> T): T { + readersCountMutex.withLock { + if (++readersCount == 1) { + writeMutex.lock() + } + } + + return try { + block() + } finally { + readersCountMutex.withLock { + if (--readersCount == 0) { + writeMutex.unlock() + } + } + } + } + + /** + * Run the given [block] of code while holding the write lock. + */ + public suspend fun withWriteLock(block: suspend () -> T): T { + writeMutex.withLock { + return block() + } + } +} diff --git a/a2a/a2a-transport/a2a-transport-client-jsonrpc-http/build.gradle.kts b/a2a/a2a-transport/a2a-transport-client-jsonrpc-http/build.gradle.kts index e6748d7b03..dc383847a5 100644 --- a/a2a/a2a-transport/a2a-transport-client-jsonrpc-http/build.gradle.kts +++ b/a2a/a2a-transport/a2a-transport-client-jsonrpc-http/build.gradle.kts @@ -24,6 +24,7 @@ kotlin { api(libs.ktor.client.core) api(libs.ktor.client.content.negotiation) api(libs.ktor.serialization.kotlinx.json) + implementation(libs.oshai.kotlin.logging) } } @@ -38,7 +39,7 @@ kotlin { jvmTest { dependencies { implementation(kotlin("test-junit5")) - implementation("me.kpavlov.aimocks:ai-mocks-a2a-jvm:0.5.0-Alpha1") + implementation(libs.mokksy.a2a) implementation(libs.ktor.client.cio) runtimeOnly(libs.slf4j.simple) } diff --git a/a2a/a2a-transport/a2a-transport-client-jsonrpc-http/src/jvmTest/kotlin/ai/koog/a2a/transport/client/jsonrpc/http/HttpJSONRPCClientTransportMokksyTest.kt b/a2a/a2a-transport/a2a-transport-client-jsonrpc-http/src/jvmTest/kotlin/ai/koog/a2a/transport/client/jsonrpc/http/HttpJSONRPCClientTransportMokksyTest.kt index e256992210..6ef91f3dba 100644 --- a/a2a/a2a-transport/a2a-transport-client-jsonrpc-http/src/jvmTest/kotlin/ai/koog/a2a/transport/client/jsonrpc/http/HttpJSONRPCClientTransportMokksyTest.kt +++ b/a2a/a2a-transport/a2a-transport-client-jsonrpc-http/src/jvmTest/kotlin/ai/koog/a2a/transport/client/jsonrpc/http/HttpJSONRPCClientTransportMokksyTest.kt @@ -17,62 +17,15 @@ import io.ktor.client.HttpClient import kotlinx.coroutines.test.runTest import kotlinx.serialization.MissingFieldException import me.kpavlov.aimocks.a2a.MockAgentServer -import me.kpavlov.aimocks.a2a.model.AgentCard import me.kpavlov.aimocks.a2a.model.Task import me.kpavlov.aimocks.a2a.model.TaskStatus -import me.kpavlov.aimocks.a2a.model.create -import kotlin.test.Ignore import kotlin.test.Test -import kotlin.time.Duration.Companion.milliseconds class HttpJSONRPCClientTransportMokksyTest { - val a2aServer = MockAgentServer(verbose = true) val client = HttpJSONRPCClientTransport(a2aServer.baseUrl(), HttpClient()) - @Test - @Ignore() - // todo: implement client.getCard(...) - fun `Should get Card`() = runTest { - val agentCard = AgentCard.create { - name = "test-agent" - description = "test-agent-description" - url = a2aServer.baseUrl() - documentationUrl = "https://example.com/documentation" - version = "0.0.1" - provider { - organization = "Acme, Inc." - url = "https://example.com/organization" - } - authentication { - schemes = listOf("none", "bearer") - credentials = "test-token" - } - capabilities { - streaming = true - pushNotifications = true - stateTransitionHistory = true - } - skills += skill { - id = "walk" - name = "Walk the walk" - } - skills += skill { - id = "talk" - name = "Talk the talk" - } - } - - // Configure the mock server to respond with the AgentCard - a2aServer.agentCard() responds { - delay = 1.milliseconds - card = agentCard - } - - TODO("client.getCard(...)") - } - @Test fun `Should sendMessage`() = runTest { a2aServer.sendMessage() responds { diff --git a/a2a/a2a-transport/a2a-transport-server-jsonrpc-http/src/jvmTest/kotlin/ai/koog/a2a/transport/server/jsonrpc/http/HttpJSONRPCServerTransportTest.kt b/a2a/a2a-transport/a2a-transport-server-jsonrpc-http/src/jvmTest/kotlin/ai/koog/a2a/transport/server/jsonrpc/http/HttpJSONRPCServerTransportTest.kt index e6e03a9210..4d8bbe71a9 100644 --- a/a2a/a2a-transport/a2a-transport-server-jsonrpc-http/src/jvmTest/kotlin/ai/koog/a2a/transport/server/jsonrpc/http/HttpJSONRPCServerTransportTest.kt +++ b/a2a/a2a-transport/a2a-transport-server-jsonrpc-http/src/jvmTest/kotlin/ai/koog/a2a/transport/server/jsonrpc/http/HttpJSONRPCServerTransportTest.kt @@ -59,6 +59,10 @@ class HttpJSONRPCServerTransportTest { capabilities = AgentCapabilities(), defaultInputModes = listOf("text/plain"), defaultOutputModes = listOf("text/plain"), + security = listOf( + mapOf("oauth" to listOf("read")), + mapOf("api-key" to listOf("mtls")), + ), skills = listOf( AgentSkill( id = "test-skill", diff --git a/gradle/libs.versions.toml b/gradle/libs.versions.toml index 380698c201..f36c121977 100644 --- a/gradle/libs.versions.toml +++ b/gradle/libs.versions.toml @@ -35,6 +35,7 @@ spring-boot = "3.5.3" spring-management = "1.1.7" sqlite = "3.46.1.3" testcontainers = "1.19.7" +mokksy = "0.5.0-Alpha3" [libraries] jetbrains-annotations = { module = "org.jetbrains:annotations", version.ref = "annotations" } @@ -70,12 +71,12 @@ ktor-server-cio = { module = "io.ktor:ktor-server-cio", version.ref = "ktor3" } ktor-server-netty = { module = "io.ktor:ktor-server-netty-jvm", version.ref = "ktor3" } ktor-server-sse = { module = "io.ktor:ktor-server-sse", version.ref = "ktor3" } ktor-server-content-negotiation = { module = "io.ktor:ktor-server-content-negotiation", version.ref = "ktor3" } -ktor-server-netty = { module = "io.ktor:ktor-server-netty", version.ref = "ktor3" } ktor-server-test-host = { module = "io.ktor:ktor-server-test-host", version.ref = "ktor3" } lettuce-core = { module = "io.lettuce:lettuce-core", version.ref = "lettuce" } logback-classic = { module = "ch.qos.logback:logback-classic", version.ref = "logback" } oshai-kotlin-logging = { module = "io.github.oshai:kotlin-logging", version.ref = "oshai-logging" } mockk = { module = "io.mockk:mockk", version.ref = "mockk" } +mokksy-a2a = { module = "me.kpavlov.aimocks:ai-mocks-a2a", version.ref = "mokksy" } dokka-gradle-plugin = { module = "org.jetbrains.dokka:dokka-gradle-plugin", version.ref = "dokka" } mcp-client = { module = "io.modelcontextprotocol:kotlin-sdk-client", version.ref = "mcp" } mcp-server = { module = "io.modelcontextprotocol:kotlin-sdk-server", version.ref = "mcp" } From 1f523573cd12c9b2332e70952394cfa5080dbdca Mon Sep 17 00:00:00 2001 From: Andrey Bragin Date: Wed, 10 Sep 2025 23:32:32 +0200 Subject: [PATCH 23/52] [a2a] Update A2A models for spec compliance and add client methods --- .../kotlin/ai/koog/a2a/client/A2AClient.kt | 193 ++++++++++++++++-- .../ai/koog/a2a/client/A2AClientMokksyTest.kt | 4 +- .../kotlin/ai/koog/a2a/model/AgentCard.kt | 10 +- .../ai/koog/a2a/model/MessageSendParams.kt | 4 +- .../kotlin/ai/koog/a2a/model/TaskEvents.kt | 3 +- .../ai/koog/a2a/transport/ClientTransport.kt | 20 +- .../a2a/model/AgentCardSerializationTest.kt | 15 +- 7 files changed, 194 insertions(+), 55 deletions(-) diff --git a/a2a/a2a-client/src/commonMain/kotlin/ai/koog/a2a/client/A2AClient.kt b/a2a/a2a-client/src/commonMain/kotlin/ai/koog/a2a/client/A2AClient.kt index b6d092af2f..2c9db020c3 100644 --- a/a2a/a2a-client/src/commonMain/kotlin/ai/koog/a2a/client/A2AClient.kt +++ b/a2a/a2a-client/src/commonMain/kotlin/ai/koog/a2a/client/A2AClient.kt @@ -1,10 +1,20 @@ -@file:Suppress("MissingKDocForPublicAPI") - package ai.koog.a2a.client +import ai.koog.a2a.exceptions.A2AException import ai.koog.a2a.model.AgentCard +import ai.koog.a2a.model.CommunicationEvent +import ai.koog.a2a.model.MessageSendParams +import ai.koog.a2a.model.Task +import ai.koog.a2a.model.TaskIdParams +import ai.koog.a2a.model.TaskPushNotificationConfig +import ai.koog.a2a.model.TaskPushNotificationConfigParams +import ai.koog.a2a.model.TaskQueryParams +import ai.koog.a2a.model.UpdateEvent +import ai.koog.a2a.transport.ClientCallContext import ai.koog.a2a.transport.ClientTransport -import ai.koog.a2a.utils.RWLock +import ai.koog.a2a.transport.Request +import ai.koog.a2a.transport.Response +import kotlinx.coroutines.flow.Flow import kotlin.concurrent.Volatile /** @@ -14,25 +24,170 @@ public open class A2AClient( private val transport: ClientTransport, private val agentCardResolver: AgentCardResolver, ) { + /** + * Currently cached version of the agent card. + * Shouldn't be used directly to read values from it, since it can be updated by the [loadAgentCard] method. + * Always use [getAgentCard] instead. + */ @Volatile - public lateinit var agentCard: AgentCard - private set + protected open lateinit var agentCard: AgentCard + + /** + * Resolve agent card from the provided [agentCardResolver] and cache it. + * Can be called multiple times, to update cached version of the agent card. + */ + public open suspend fun loadAgentCard(): AgentCard { + return agentCardResolver.resolve().also { + agentCard = it + } + } + + /** + * Get current cached version of agent card. + */ + public open fun getAgentCard(): AgentCard = agentCard - private val cardLock = RWLock() + /** + * Calls [agent/getAuthenticatedExtendedCard](https://a2a-protocol.org/latest/specification/#710-agentgetauthenticatedextendedcard) + * + * @throws A2AException if server returned an error. + */ + public suspend fun getAuthenticatedExtendedAgentCard( + request: Request, + ctx: ClientCallContext = ClientCallContext.Default + ): Response { + check(getAgentCard().supportsAuthenticatedExtendedCard == true) { + "Agent card reports that authenticated extended agent card is not supported." + } - public suspend fun connect(): AgentCard = cardLock.withWriteLock { - agentCard = agentCardResolver.resolve() - agentCard + return transport.getAuthenticatedExtendedAgentCard(request, ctx).also { + agentCard = it.data + } } -} -public fun A2AClient( - transport: ClientTransport, - baseUrl: String, - cardPath: String = UrlAgentCardResolver.wellKnownPath, -): A2AClient { - return A2AClient( - transport = transport, - agentCardResolver = UrlAgentCardResolver(baseUrl, cardPath), - ) + /** + * Calls [message/send](https://a2a-protocol.org/latest/specification/#71-messagesend). + * + * @throws A2AException if server returned an error. + */ + public suspend fun sendMessage( + request: Request, + ctx: ClientCallContext = ClientCallContext.Default + ): Response { + return transport.sendMessage(request, ctx) + } + + /** + * Calls [message/stream](https://a2a-protocol.org/latest/specification/#72-messagestream) + * + * @throws A2AException if server returned an error. + */ + public fun sendMessageStreaming( + request: Request, + ctx: ClientCallContext = ClientCallContext.Default + ): Flow> { + check(getAgentCard().capabilities.streaming == true) { + "Agent card reports that streaming is not supported." + } + + return transport.sendMessageStreaming(request, ctx) + } + + /** + * Calls [tasks/get](https://a2a-protocol.org/latest/specification/#73-tasksget) + * + * @throws A2AException if server returned an error. + */ + public suspend fun getTask( + request: Request, + ctx: ClientCallContext = ClientCallContext.Default + ): Response { + return transport.getTask(request, ctx) + } + + /** + * Calls [tasks/cancel](https://a2a-protocol.org/latest/specification/#74-taskscancel) + * + * @throws A2AException if server returned an error. + */ + public suspend fun cancelTask( + request: Request, + ctx: ClientCallContext = ClientCallContext.Default + ): Response { + return transport.cancelTask(request, ctx) + } + + /** + * Calls [tasks/resubscribe](https://a2a-protocol.org/latest/specification/#79-tasksresubscribe) + * + * @throws A2AException if server returned an error. + */ + public fun resubscribeTask( + request: Request, + ctx: ClientCallContext = ClientCallContext.Default + ): Flow> { + return transport.resubscribeTask(request, ctx) + } + + /** + * Calls [tasks/pushNotificationConfig/set](https://a2a-protocol.org/latest/specification/#75-taskspushnotificationconfigset) + * + * @throws A2AException if server returned an error. + */ + public suspend fun setTaskPushNotificationConfig( + request: Request, + ctx: ClientCallContext = ClientCallContext.Default + ): Response { + checkPushNotificationsSupported() + + return transport.setTaskPushNotificationConfig(request, ctx) + } + + /** + * Calls [tasks/pushNotificationConfig/get](https://a2a-protocol.org/latest/specification/#76-taskspushnotificationconfigget) + * + * @throws A2AException if server returned an error. + */ + public suspend fun getTaskPushNotificationConfig( + request: Request, + ctx: ClientCallContext = ClientCallContext.Default + ): Response { + checkPushNotificationsSupported() + + return transport.getTaskPushNotificationConfig(request, ctx) + } + + /** + * Calls [tasks/pushNotificationConfig/list](https://a2a-protocol.org/latest/specification/#77-taskspushnotificationconfiglist) + * + * @throws A2AException if server returned an error. + */ + public suspend fun listTaskPushNotificationConfig( + request: Request, + ctx: ClientCallContext = ClientCallContext.Default + ): Response> { + checkPushNotificationsSupported() + + return transport.listTaskPushNotificationConfig(request, ctx) + } + + /** + * Calls [tasks/pushNotificationConfig/delete](https://a2a-protocol.org/latest/specification/#78-taskspushnotificationconfigdelete) + * + * @throws A2AException if server returned an error. + */ + public suspend fun deleteTaskPushNotificationConfig( + request: Request, + ctx: ClientCallContext = ClientCallContext.Default + ): Response { + checkPushNotificationsSupported() + + return transport.deleteTaskPushNotificationConfig(request, ctx) + } + + private fun checkPushNotificationsSupported() { + check(getAgentCard().capabilities.pushNotifications == true) { + "Agent card reports that push notifications are not supported." + } + } } diff --git a/a2a/a2a-client/src/jvmTest/kotlin/ai/koog/a2a/client/A2AClientMokksyTest.kt b/a2a/a2a-client/src/jvmTest/kotlin/ai/koog/a2a/client/A2AClientMokksyTest.kt index 30b75e8866..4ac690618f 100644 --- a/a2a/a2a-client/src/jvmTest/kotlin/ai/koog/a2a/client/A2AClientMokksyTest.kt +++ b/a2a/a2a-client/src/jvmTest/kotlin/ai/koog/a2a/client/A2AClientMokksyTest.kt @@ -37,7 +37,7 @@ class A2AClientMokksyTest { ) @Test - fun `Should get Card`() = runTest { + fun `should get card`() = runTest { // given val agentCard = AgentCard.create { name = "test-agent" @@ -79,7 +79,7 @@ class A2AClientMokksyTest { } // when - val actualAgentCard = client.connect() + val actualAgentCard = client.loadAgentCard() // then actualAgentCard shouldNotBeNull { diff --git a/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/model/AgentCard.kt b/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/model/AgentCard.kt index 6235b29abe..0f36f87973 100644 --- a/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/model/AgentCard.kt +++ b/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/model/AgentCard.kt @@ -92,7 +92,6 @@ public data class AgentCard( public val defaultInputModes: List, public val defaultOutputModes: List, public val skills: List, - @EncodeDefault public val supportsAuthenticatedExtendedCard: Boolean? = false, public val signatures: List? = null ) { @@ -178,12 +177,9 @@ public data class AgentProvider( */ @Serializable public data class AgentCapabilities( - @EncodeDefault - public val streaming: Boolean = false, - @EncodeDefault - public val pushNotifications: Boolean = false, - @EncodeDefault - public val stateTransitionHistory: Boolean = false, + public val streaming: Boolean? = null, + public val pushNotifications: Boolean? = null, + public val stateTransitionHistory: Boolean? = null, public val extensions: List? = null ) diff --git a/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/model/MessageSendParams.kt b/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/model/MessageSendParams.kt index e92c170a4d..2d72fc5e45 100644 --- a/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/model/MessageSendParams.kt +++ b/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/model/MessageSendParams.kt @@ -1,6 +1,5 @@ package ai.koog.a2a.model -import kotlinx.serialization.EncodeDefault import kotlinx.serialization.Serializable import kotlinx.serialization.json.JsonObject @@ -29,8 +28,7 @@ public data class MessageSendParams( */ @Serializable public data class MessageSendConfiguration( - @EncodeDefault - public val blocking: Boolean = false, + public val blocking: Boolean? = null, public val acceptedOutputModes: List? = null, public val historyLength: Int? = null, public val pushNotificationConfig: PushNotificationConfig? = null, diff --git a/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/model/TaskEvents.kt b/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/model/TaskEvents.kt index e4cb3178aa..ef28cc7533 100644 --- a/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/model/TaskEvents.kt +++ b/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/model/TaskEvents.kt @@ -42,8 +42,7 @@ public data class TaskArtifactUpdateEvent( public val taskId: String, public val contextId: String, public val artifact: Artifact, - @EncodeDefault - public val append: Boolean = false, + public val append: Boolean? = null, public val lastChunk: Boolean? = null, public val metadata: JsonObject? = null, ) : UpdateEvent { diff --git a/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/transport/ClientTransport.kt b/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/transport/ClientTransport.kt index e5ab4ed312..ff548e7a5e 100644 --- a/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/transport/ClientTransport.kt +++ b/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/transport/ClientTransport.kt @@ -26,7 +26,7 @@ import kotlinx.serialization.SerializationException */ public interface ClientTransport : AutoCloseable { /** - * Implements [agent/getAuthenticatedExtendedCard](https://a2a-protocol.org/latest/specification/#710-agentgetauthenticatedextendedcard) + * Calls [agent/getAuthenticatedExtendedCard](https://a2a-protocol.org/latest/specification/#710-agentgetauthenticatedextendedcard) * * @throws A2AException if server returned an error. */ @@ -36,7 +36,7 @@ public interface ClientTransport : AutoCloseable { ): Response /** - * Implements [message/send](https://a2a-protocol.org/latest/specification/#71-messagesend). + * Calls [message/send](https://a2a-protocol.org/latest/specification/#71-messagesend). * * @throws A2AException if server returned an error. */ @@ -46,7 +46,7 @@ public interface ClientTransport : AutoCloseable { ): Response /** - * Implements [message/stream](https://a2a-protocol.org/latest/specification/#72-messagestream) + * Calls [message/stream](https://a2a-protocol.org/latest/specification/#72-messagestream) * * @throws A2AException if server returned an error. */ @@ -56,7 +56,7 @@ public interface ClientTransport : AutoCloseable { ): Flow> /** - * Implements [tasks/get](https://a2a-protocol.org/latest/specification/#73-tasksget) + * Calls [tasks/get](https://a2a-protocol.org/latest/specification/#73-tasksget) * * @throws A2AException if server returned an error. */ @@ -66,7 +66,7 @@ public interface ClientTransport : AutoCloseable { ): Response /** - * Implements [tasks/cancel](https://a2a-protocol.org/latest/specification/#74-taskscancel) + * Calls [tasks/cancel](https://a2a-protocol.org/latest/specification/#74-taskscancel) * * @throws A2AException if server returned an error. */ @@ -76,7 +76,7 @@ public interface ClientTransport : AutoCloseable { ): Response /** - * Implements [tasks/resubscribe](https://a2a-protocol.org/latest/specification/#79-tasksresubscribe) + * Calls [tasks/resubscribe](https://a2a-protocol.org/latest/specification/#79-tasksresubscribe) * * @throws A2AException if server returned an error. */ @@ -86,7 +86,7 @@ public interface ClientTransport : AutoCloseable { ): Flow> /** - * Implements [tasks/pushNotificationConfig/set](https://a2a-protocol.org/latest/specification/#75-taskspushnotificationconfigset) + * Calls [tasks/pushNotificationConfig/set](https://a2a-protocol.org/latest/specification/#75-taskspushnotificationconfigset) * * @throws A2AException if server returned an error. */ @@ -96,7 +96,7 @@ public interface ClientTransport : AutoCloseable { ): Response /** - * Implements [tasks/pushNotificationConfig/get](https://a2a-protocol.org/latest/specification/#76-taskspushnotificationconfigget) + * Calls [tasks/pushNotificationConfig/get](https://a2a-protocol.org/latest/specification/#76-taskspushnotificationconfigget) * * @throws A2AException if server returned an error. */ @@ -106,7 +106,7 @@ public interface ClientTransport : AutoCloseable { ): Response /** - * Implements [tasks/pushNotificationConfig/list](https://a2a-protocol.org/latest/specification/#77-taskspushnotificationconfiglist) + * Calls [tasks/pushNotificationConfig/list](https://a2a-protocol.org/latest/specification/#77-taskspushnotificationconfiglist) * * @throws A2AException if server returned an error. */ @@ -116,7 +116,7 @@ public interface ClientTransport : AutoCloseable { ): Response> /** - * Implements [tasks/pushNotificationConfig/delete](https://a2a-protocol.org/latest/specification/#78-taskspushnotificationconfigdelete) + * Calls [tasks/pushNotificationConfig/delete](https://a2a-protocol.org/latest/specification/#78-taskspushnotificationconfigdelete) * * @throws A2AException if server returned an error. */ diff --git a/a2a/a2a-core/src/commonTest/kotlin/ai/koog/a2a/model/AgentCardSerializationTest.kt b/a2a/a2a-core/src/commonTest/kotlin/ai/koog/a2a/model/AgentCardSerializationTest.kt index 6b3271f5fa..a0e410b2cc 100644 --- a/a2a/a2a-core/src/commonTest/kotlin/ai/koog/a2a/model/AgentCardSerializationTest.kt +++ b/a2a/a2a-core/src/commonTest/kotlin/ai/koog/a2a/model/AgentCardSerializationTest.kt @@ -40,11 +40,7 @@ class AgentCardSerializationTest { "url": "https://api.example.com/a2a", "preferredTransport": "JSONRPC", "version": "1.0.0", - "capabilities": { - "streaming": false, - "pushNotifications": false, - "stateTransitionHistory": false - }, + "capabilities": {}, "defaultInputModes": [ "text/plain" ], @@ -60,8 +56,7 @@ class AgentCardSerializationTest { "test" ] } - ], - "supportsAuthenticatedExtendedCard": false + ] } """.trimIndent() @@ -437,11 +432,7 @@ class AgentCardSerializationTest { val defaultCapabilities = AgentCapabilities() //language=JSON val defaultJson = """ - { - "streaming": false, - "pushNotifications": false, - "stateTransitionHistory": false - } + {} """.trimIndent() assertEquals(defaultJson, TestJson.encodeToString(defaultCapabilities)) From b071e85423f3bb0eb67bb9425558eb227b0335a9 Mon Sep 17 00:00:00 2001 From: Andrey Bragin Date: Thu, 11 Sep 2025 01:48:49 +0200 Subject: [PATCH 24/52] [a2a] Add comprehensive A2A client test suite with Python test server --- a2a/a2a-client/build.gradle.kts | 14 +- .../kotlin/ai/koog/a2a/client/A2AClient.kt | 9 +- .../a2a/client/A2AClientIntegrationTest.kt | 469 ++++++++++++++++++ .../ai/koog/a2a/client/A2AClientMokksyTest.kt | 115 ----- .../ai/koog/a2a/exceptions/Exceptions.kt | 2 +- .../kotlin/ai/koog/a2a/transport/Core.kt | 8 +- .../a2a/transport/jsonrpc/model/Messages.kt | 2 +- .../transport/jsonrpc/model/Serialization.kt | 3 +- a2a/test-python-a2a-server/.gitignore | 2 + a2a/test-python-a2a-server/.python-version | 1 + a2a/test-python-a2a-server/Dockerfile | 7 + a2a/test-python-a2a-server/pyproject.toml | 9 + .../src/agent_executor.py | 165 ++++++ a2a/test-python-a2a-server/src/main.py | 79 +++ a2a/test-python-a2a-server/uv.lock | 468 +++++++++++++++++ gradle/libs.versions.toml | 3 + 16 files changed, 1230 insertions(+), 126 deletions(-) create mode 100644 a2a/a2a-client/src/jvmTest/kotlin/ai/koog/a2a/client/A2AClientIntegrationTest.kt delete mode 100644 a2a/a2a-client/src/jvmTest/kotlin/ai/koog/a2a/client/A2AClientMokksyTest.kt create mode 100644 a2a/test-python-a2a-server/.gitignore create mode 100644 a2a/test-python-a2a-server/.python-version create mode 100644 a2a/test-python-a2a-server/Dockerfile create mode 100644 a2a/test-python-a2a-server/pyproject.toml create mode 100644 a2a/test-python-a2a-server/src/agent_executor.py create mode 100644 a2a/test-python-a2a-server/src/main.py create mode 100644 a2a/test-python-a2a-server/uv.lock diff --git a/a2a/a2a-client/build.gradle.kts b/a2a/a2a-client/build.gradle.kts index 44297855e1..217c0385e3 100644 --- a/a2a/a2a-client/build.gradle.kts +++ b/a2a/a2a-client/build.gradle.kts @@ -30,6 +30,7 @@ kotlin { commonTest { dependencies { implementation(kotlin("test")) + implementation(libs.kotest.assertions) implementation(libs.kotlinx.coroutines.test) } } @@ -39,9 +40,9 @@ kotlin { implementation(kotlin("test-junit5")) implementation(project(":a2a:a2a-transport:a2a-transport-client-jsonrpc-http")) - implementation(libs.mokksy.a2a) implementation(libs.ktor.client.cio) implementation(libs.ktor.client.logging) + implementation(libs.testcontainers.junit) runtimeOnly(libs.slf4j.simple) } } @@ -57,3 +58,14 @@ kotlin { } publishToMaven() + +tasks.register("dockerBuildTestPythonA2AServer") { + group = "docker" + description = "Build Python A2A test server image" + workingDir = file("../test-python-a2a-server") + commandLine = listOf("docker", "build", "-t", "test-python-a2a-server", ".") +} + +tasks.named("jvmTest") { + dependsOn("dockerBuildTestPythonA2AServer") +} diff --git a/a2a/a2a-client/src/commonMain/kotlin/ai/koog/a2a/client/A2AClient.kt b/a2a/a2a-client/src/commonMain/kotlin/ai/koog/a2a/client/A2AClient.kt index 2c9db020c3..ae6a0decf4 100644 --- a/a2a/a2a-client/src/commonMain/kotlin/ai/koog/a2a/client/A2AClient.kt +++ b/a2a/a2a-client/src/commonMain/kotlin/ai/koog/a2a/client/A2AClient.kt @@ -29,8 +29,9 @@ public open class A2AClient( * Shouldn't be used directly to read values from it, since it can be updated by the [loadAgentCard] method. * Always use [getAgentCard] instead. */ + @Suppress("PropertyName") @Volatile - protected open lateinit var agentCard: AgentCard + protected open lateinit var _agentCard: AgentCard /** * Resolve agent card from the provided [agentCardResolver] and cache it. @@ -38,14 +39,14 @@ public open class A2AClient( */ public open suspend fun loadAgentCard(): AgentCard { return agentCardResolver.resolve().also { - agentCard = it + _agentCard = it } } /** * Get current cached version of agent card. */ - public open fun getAgentCard(): AgentCard = agentCard + public open fun getAgentCard(): AgentCard = _agentCard /** * Calls [agent/getAuthenticatedExtendedCard](https://a2a-protocol.org/latest/specification/#710-agentgetauthenticatedextendedcard) @@ -61,7 +62,7 @@ public open class A2AClient( } return transport.getAuthenticatedExtendedAgentCard(request, ctx).also { - agentCard = it.data + _agentCard = it.data } } diff --git a/a2a/a2a-client/src/jvmTest/kotlin/ai/koog/a2a/client/A2AClientIntegrationTest.kt b/a2a/a2a-client/src/jvmTest/kotlin/ai/koog/a2a/client/A2AClientIntegrationTest.kt new file mode 100644 index 0000000000..bf4cd468fa --- /dev/null +++ b/a2a/a2a-client/src/jvmTest/kotlin/ai/koog/a2a/client/A2AClientIntegrationTest.kt @@ -0,0 +1,469 @@ +package ai.koog.a2a.client + +import ai.koog.a2a.exceptions.A2AInternalErrorException +import ai.koog.a2a.model.AgentCapabilities +import ai.koog.a2a.model.AgentCard +import ai.koog.a2a.model.AgentSkill +import ai.koog.a2a.model.Message +import ai.koog.a2a.model.MessageSendConfiguration +import ai.koog.a2a.model.MessageSendParams +import ai.koog.a2a.model.PushNotificationAuthenticationInfo +import ai.koog.a2a.model.PushNotificationConfig +import ai.koog.a2a.model.Role +import ai.koog.a2a.model.Task +import ai.koog.a2a.model.TaskIdParams +import ai.koog.a2a.model.TaskPushNotificationConfig +import ai.koog.a2a.model.TaskPushNotificationConfigParams +import ai.koog.a2a.model.TaskQueryParams +import ai.koog.a2a.model.TaskState +import ai.koog.a2a.model.TaskStatusUpdateEvent +import ai.koog.a2a.model.TextPart +import ai.koog.a2a.model.TransportProtocol +import ai.koog.a2a.transport.Request +import ai.koog.a2a.transport.client.jsonrpc.http.HttpJSONRPCClientTransport +import io.kotest.assertions.throwables.shouldThrowExactly +import io.kotest.inspectors.shouldForAll +import io.kotest.matchers.collections.shouldHaveSize +import io.kotest.matchers.collections.shouldNotBeEmpty +import io.kotest.matchers.nulls.shouldNotBeNull +import io.kotest.matchers.should +import io.kotest.matchers.shouldBe +import io.kotest.matchers.string.shouldStartWith +import io.kotest.matchers.types.shouldBeInstanceOf +import io.ktor.client.HttpClient +import io.ktor.client.plugins.logging.LogLevel +import io.ktor.client.plugins.logging.Logging +import kotlinx.coroutines.flow.toList +import kotlinx.coroutines.test.runTest +import org.testcontainers.containers.GenericContainer +import org.testcontainers.containers.wait.strategy.Wait +import org.testcontainers.junit.jupiter.Container +import org.testcontainers.junit.jupiter.Testcontainers +import kotlin.test.BeforeTest +import kotlin.test.Test + +@Testcontainers +class A2AClientIntegrationTest { + companion object { + @Container + val testA2AServer: GenericContainer<*> = + GenericContainer("test-python-a2a-server") + .withExposedPorts(9999) + .waitingFor(Wait.forListeningPort()) + } + + private val httpClient = HttpClient { + install(Logging) { + level = LogLevel.BODY + } + } + + @Suppress("HttpUrlsUsage") + private val agentUrl by lazy { "http://${testA2AServer.host}:${testA2AServer.getMappedPort(9999)}" } + + private val transport by lazy { + HttpJSONRPCClientTransport( + url = agentUrl, + baseHttpClient = httpClient + ) + } + + private val client by lazy { + A2AClient( + transport = transport, + agentCardResolver = UrlAgentCardResolver( + baseUrl = agentUrl, + baseHttpClient = httpClient, + ), + ) + } + + @BeforeTest + fun initClient() = runTest { + client.loadAgentCard() + } + + @Test + fun `test get agent card`() = runTest { + val agentCard = client.getAgentCard() + + // Assert on the full AgentCard structure + val expectedAgentCard = AgentCard( + protocolVersion = "0.3.0", + name = "Hello World Agent", + description = "Just a hello world agent", + url = "http://localhost:9999/", + preferredTransport = TransportProtocol.JSONRPC, + additionalInterfaces = null, + iconUrl = null, + provider = null, + version = "1.0.0", + documentationUrl = null, + capabilities = AgentCapabilities( + streaming = true, + pushNotifications = true, + stateTransitionHistory = null, + extensions = null + ), + securitySchemes = null, + security = null, + defaultInputModes = listOf("text"), + defaultOutputModes = listOf("text"), + skills = listOf( + AgentSkill( + id = "hello_world", + name = "Returns hello world", + description = "just returns hello world", + tags = listOf("hello world"), + examples = listOf("hi", "hello world"), + inputModes = null, + outputModes = null, + security = null + ) + ), + supportsAuthenticatedExtendedCard = true, + signatures = null + ) + + agentCard shouldBe expectedAgentCard + } + + @Test + fun `test get authenticated extended agent card`() = runTest { + val request = Request(data = null) + + val response = client.getAuthenticatedExtendedAgentCard(request) + + // Assert on the extended agent card structure + val expectedExtendedAgentCard = AgentCard( + protocolVersion = "0.3.0", + name = "Hello World Agent - Extended Edition", + description = "The full-featured hello world agent for authenticated users.", + url = "http://localhost:9999/", + preferredTransport = TransportProtocol.JSONRPC, + additionalInterfaces = null, + iconUrl = null, + provider = null, + version = "1.0.1", + documentationUrl = null, + capabilities = AgentCapabilities( + streaming = true, + pushNotifications = true, + stateTransitionHistory = null, + extensions = null + ), + securitySchemes = null, + security = null, + defaultInputModes = listOf("text"), + defaultOutputModes = listOf("text"), + skills = listOf( + AgentSkill( + id = "hello_world", + name = "Returns hello world", + description = "just returns hello world", + tags = listOf("hello world"), + examples = listOf("hi", "hello world"), + inputModes = null, + outputModes = null, + security = null + ), + AgentSkill( + id = "super_hello_world", + name = "Returns a SUPER Hello World", + description = "A more enthusiastic greeting, only for authenticated users.", + tags = listOf("hello world", "super", "extended"), + examples = listOf("super hi", "give me a super hello"), + inputModes = null, + outputModes = null, + security = null + ) + ), + supportsAuthenticatedExtendedCard = true, + signatures = null + ) + + response.data shouldBe expectedExtendedAgentCard + } + + @Test + fun `test send message`() = runTest { + val request = Request( + data = MessageSendParams( + message = Message( + role = Role.User, + parts = listOf( + TextPart("hello world"), + TextPart("How are you doing?"), + ), + contextId = "test-context" + ), + ) + ) + + val response = client.sendMessage(request) + + response should { + it.id shouldBe request.id + + it.data.shouldBeInstanceOf { + it.role shouldBe Role.Agent + it.parts shouldBe listOf(TextPart("Hello World")) + it.contextId shouldBe "test-context" + } + } + } + + @Test + fun `test send message streaming`() = runTest { + val createTaskRequest = Request( + data = MessageSendParams( + message = Message( + role = Role.User, + parts = listOf( + TextPart("do task"), + ), + contextId = "test-context" + ), + ), + ) + + val events = client + .sendMessageStreaming(createTaskRequest) + .toList() + .map { it.data } + + events shouldHaveSize 3 + events[0].shouldBeInstanceOf { + it.contextId shouldBe "test-context" + it.status should { + it.state shouldBe TaskState.Submitted + } + + it.history shouldNotBeNull { + this shouldHaveSize 1 + + this[0] should { + it.role shouldBe Role.User + it.parts shouldBe listOf(TextPart("do task")) + } + } + } + + events[1].shouldBeInstanceOf { + it.contextId shouldBe "test-context" + + it.status should { + it.state shouldBe TaskState.Working + it.message shouldNotBeNull { + role shouldBe Role.Agent + parts shouldBe listOf(TextPart("Working on task")) + } + } + } + + events[2].shouldBeInstanceOf { + it.contextId shouldBe "test-context" + + it.status should { + it.state shouldBe TaskState.Completed + it.message shouldNotBeNull { + role shouldBe Role.Agent + parts shouldBe listOf(TextPart("Task completed")) + } + } + } + } + + @Test + fun `test get task`() = runTest { + val createTaskRequest = Request( + data = MessageSendParams( + message = Message( + role = Role.User, + parts = listOf( + TextPart("do task"), + ), + contextId = "test-context" + ), + ), + ) + + val taskId = (client.sendMessage(createTaskRequest).data as Task).id + + val getTaskRequest = Request( + data = TaskQueryParams( + id = taskId, + historyLength = 1 + ) + ) + + val response = client.getTask(getTaskRequest) + + response.data should { + it.id shouldBe taskId + it.contextId shouldBe "test-context" + it.status should { + it.state shouldBe TaskState.Completed + } + } + } + + @Test + fun `test cancel task`() = runTest { + val createTaskRequest = Request( + data = MessageSendParams( + message = Message( + role = Role.User, + parts = listOf( + TextPart("do cancelable task"), + ), + contextId = "test-context" + ), + ), + ) + + val taskId = (client.sendMessage(createTaskRequest).data as Task).id + + val cancelTaskRequest = Request( + data = TaskIdParams( + id = taskId, + ) + ) + + val response = client.cancelTask(cancelTaskRequest) + + response.data should { + it.id shouldBe taskId + it.contextId shouldBe "test-context" + it.status should { + it.state shouldBe TaskState.Canceled + it.message shouldNotBeNull { + role shouldBe Role.Agent + parts shouldBe listOf(TextPart("Task canceled")) + } + } + } + } + + @Test + fun `test resubscribe task`() = runTest { + val createTaskRequest = Request( + data = MessageSendParams( + message = Message( + role = Role.User, + parts = listOf( + TextPart("do long-running task"), + ), + contextId = "test-context" + ), + configuration = MessageSendConfiguration( + blocking = false + ) + ), + ) + + val taskId = (client.sendMessage(createTaskRequest).data as Task).id + + val resubscribeTaskRequest = Request( + data = TaskIdParams( + id = taskId, + ) + ) + + val events = client + .resubscribeTask(resubscribeTaskRequest) + .toList() + .map { it.data } + + events.shouldNotBeEmpty() + + events.shouldForAll { + it.shouldBeInstanceOf { + it.taskId shouldBe taskId + it.contextId shouldBe "test-context" + + it.status should { + it.state shouldBe TaskState.Working + it.message shouldNotBeNull { + role shouldBe Role.Agent + + parts.shouldForAll { + it.shouldBeInstanceOf { + it.text shouldStartWith "Still working" + } + } + } + } + } + } + } + + @Test + fun `test push notification configs`() = runTest { + val createTaskRequest = Request( + data = MessageSendParams( + message = Message( + role = Role.User, + parts = listOf( + TextPart("do long-running task"), + ), + contextId = "test-context" + ), + ), + ) + + val taskId = (client.sendMessage(createTaskRequest).data as Task).id + + val pushConfig = TaskPushNotificationConfig( + taskId = taskId, + pushNotificationConfig = PushNotificationConfig( + id = "push-id", + url = "https://localhost:3000", + token = "push-token", + authentication = PushNotificationAuthenticationInfo( + schemes = listOf("bearer"), + credentials = "very-secret-credential" + ) + ) + ) + + val request = Request( + data = pushConfig + ) + + val setPushConfigResponse = client.setTaskPushNotificationConfig(request) + setPushConfigResponse.data shouldBe pushConfig + + val getPushConfigRequest = Request( + data = TaskPushNotificationConfigParams( + id = taskId, + pushNotificationConfigId = pushConfig.pushNotificationConfig.id, + ) + ) + + val getPushConfigResponse = client.getTaskPushNotificationConfig(getPushConfigRequest) + getPushConfigResponse.data shouldBe pushConfig + + val listPushConfigRequest = Request( + data = TaskIdParams( + id = taskId, + ) + ) + + val listPushConfigResponse = client.listTaskPushNotificationConfig(listPushConfigRequest) + listPushConfigResponse.data shouldBe listOf(pushConfig) + + val deletePushConfigRequest = Request( + data = TaskPushNotificationConfigParams( + id = taskId, + pushNotificationConfigId = pushConfig.pushNotificationConfig.id, + ) + ) + + client.deleteTaskPushNotificationConfig(deletePushConfigRequest) + + shouldThrowExactly { + client.getTaskPushNotificationConfig(getPushConfigRequest) + } + } +} diff --git a/a2a/a2a-client/src/jvmTest/kotlin/ai/koog/a2a/client/A2AClientMokksyTest.kt b/a2a/a2a-client/src/jvmTest/kotlin/ai/koog/a2a/client/A2AClientMokksyTest.kt deleted file mode 100644 index 4ac690618f..0000000000 --- a/a2a/a2a-client/src/jvmTest/kotlin/ai/koog/a2a/client/A2AClientMokksyTest.kt +++ /dev/null @@ -1,115 +0,0 @@ -package ai.koog.a2a.client - -import ai.koog.a2a.transport.client.jsonrpc.http.HttpJSONRPCClientTransport -import io.kotest.matchers.collections.shouldHaveSize -import io.kotest.matchers.nulls.shouldNotBeNull -import io.kotest.matchers.shouldBe -import io.ktor.client.HttpClient -import io.ktor.client.plugins.logging.LogLevel -import io.ktor.client.plugins.logging.Logging -import kotlinx.coroutines.test.runTest -import me.kpavlov.aimocks.a2a.MockAgentServer -import me.kpavlov.aimocks.a2a.model.AgentCard -import me.kpavlov.aimocks.a2a.model.create -import kotlin.test.Test -import kotlin.time.Duration.Companion.milliseconds - -class A2AClientMokksyTest { - val a2aServer = MockAgentServer(verbose = false) - - private val httpClient = HttpClient { - install(Logging) { - this.level = LogLevel.BODY - } - } - - private val transport = HttpJSONRPCClientTransport( - url = a2aServer.baseUrl(), - baseHttpClient = httpClient - ) - - val client = A2AClient( - transport = transport, - agentCardResolver = UrlAgentCardResolver( - baseUrl = a2aServer.baseUrl(), - baseHttpClient = httpClient, - ), - ) - - @Test - fun `should get card`() = runTest { - // given - val agentCard = AgentCard.create { - name = "test-agent" - description = "test-agent-description" - url = a2aServer.baseUrl() - documentationUrl = "https://example.com/documentation" - version = "1.0.1" - security = listOf( - mapOf("oauth" to listOf("read")), - mapOf("api-key" to listOf("mtls")), - ) - provider { - organization = "Acme, Inc." - url = "https://example.com/organization" - } - capabilities { - streaming = true - pushNotifications = true - stateTransitionHistory = true - } - skills += skill { - id = "walk" - name = "Walk the walk" - description = "Walk the walk description" - tags = listOf("walk", "tag") - } - skills += skill { - id = "talk" - name = "Talk the talk" - description = "Talk the talk description" - tags = listOf("walk", "tag") - } - } - - // Configure the mock server to respond with the AgentCard - a2aServer.agentCard() responds { - delay = 1.milliseconds - card = agentCard - } - - // when - val actualAgentCard = client.loadAgentCard() - - // then - actualAgentCard shouldNotBeNull { - name shouldBe agentCard.name - description shouldBe agentCard.description - url shouldBe agentCard.url - documentationUrl shouldBe agentCard.documentationUrl - version shouldBe agentCard.version - security shouldBe agentCard.security - - provider shouldNotBeNull { - organization shouldBe agentCard.provider?.organization - url shouldBe agentCard.provider?.url - } - - capabilities shouldNotBeNull { - streaming shouldBe agentCard.capabilities.streaming - pushNotifications shouldBe agentCard.capabilities.pushNotifications - stateTransitionHistory shouldBe agentCard.capabilities.stateTransitionHistory - } - - skills shouldHaveSize agentCard.skills.size - skills.zip(agentCard.skills).forEach { (actualSkill, skill) -> - actualSkill shouldNotBeNull { - id shouldBe skill.id - name shouldBe skill.name - description shouldBe skill.description - tags shouldBe skill.tags - } - } - } - } -} diff --git a/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/exceptions/Exceptions.kt b/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/exceptions/Exceptions.kt index 10dec7caae..798ed3766a 100644 --- a/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/exceptions/Exceptions.kt +++ b/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/exceptions/Exceptions.kt @@ -70,7 +70,7 @@ public sealed class A2AServerException( errorCode: Int, ) : A2AException(message, errorCode) { init { - require(errorCode in -32000..-32099) { "Server error code must be in -32000..-32099" } + require(errorCode in -32099..-32000) { "Server error code must be in -32099..-32000" } } } diff --git a/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/transport/Core.kt b/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/transport/Core.kt index 655a57657d..ad4c2e9437 100644 --- a/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/transport/Core.kt +++ b/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/transport/Core.kt @@ -1,6 +1,8 @@ package ai.koog.a2a.transport import kotlinx.serialization.Serializable +import kotlin.uuid.ExperimentalUuidApi +import kotlin.uuid.Uuid /** * A uniquely identifying ID for a request. @@ -26,9 +28,10 @@ public sealed interface RequestId { * @property id The unique identifier for the request. * @property data The data payload of the request. */ +@OptIn(ExperimentalUuidApi::class) public class Request( - public val id: RequestId, public val data: T, + public val id: RequestId = RequestId.StringId(Uuid.random().toString()), ) /** @@ -37,7 +40,8 @@ public class Request( * @property id The unique identifier for the request associated with this response. * @property data The response data payload. */ +@OptIn(ExperimentalUuidApi::class) public class Response( - public val id: RequestId, public val data: T, + public val id: RequestId = RequestId.StringId(Uuid.random().toString()), ) diff --git a/a2a/a2a-transport/a2a-transport-core-jsonrpc/src/commonMain/kotlin/ai/koog/a2a/transport/jsonrpc/model/Messages.kt b/a2a/a2a-transport/a2a-transport-core-jsonrpc/src/commonMain/kotlin/ai/koog/a2a/transport/jsonrpc/model/Messages.kt index 3f7707b3ad..8238c2c565 100644 --- a/a2a/a2a-transport/a2a-transport-core-jsonrpc/src/commonMain/kotlin/ai/koog/a2a/transport/jsonrpc/model/Messages.kt +++ b/a2a/a2a-transport/a2a-transport-core-jsonrpc/src/commonMain/kotlin/ai/koog/a2a/transport/jsonrpc/model/Messages.kt @@ -41,7 +41,7 @@ public data class JSONRPCNotification( @Serializable public data class JSONRPCSuccessResponse( public val id: RequestId, - public val result: JsonElement, + public val result: JsonElement = JsonNull, @EncodeDefault override val jsonrpc: String = JSONRPC_VERSION, ) : JSONRPCResponse diff --git a/a2a/a2a-transport/a2a-transport-core-jsonrpc/src/commonMain/kotlin/ai/koog/a2a/transport/jsonrpc/model/Serialization.kt b/a2a/a2a-transport/a2a-transport-core-jsonrpc/src/commonMain/kotlin/ai/koog/a2a/transport/jsonrpc/model/Serialization.kt index a10332d408..9763919e93 100644 --- a/a2a/a2a-transport/a2a-transport-core-jsonrpc/src/commonMain/kotlin/ai/koog/a2a/transport/jsonrpc/model/Serialization.kt +++ b/a2a/a2a-transport/a2a-transport-core-jsonrpc/src/commonMain/kotlin/ai/koog/a2a/transport/jsonrpc/model/Serialization.kt @@ -35,9 +35,8 @@ internal object JSONRPCResponseSerializer : JsonContentPolymorphicSerializer JSONRPCSuccessResponse.serializer() "error" in jsonObject -> JSONRPCErrorResponse.serializer() - else -> error("Invalid JSON format") + else -> JSONRPCSuccessResponse.serializer() } } } diff --git a/a2a/test-python-a2a-server/.gitignore b/a2a/test-python-a2a-server/.gitignore new file mode 100644 index 0000000000..605d1308d1 --- /dev/null +++ b/a2a/test-python-a2a-server/.gitignore @@ -0,0 +1,2 @@ +.venv +*.iml diff --git a/a2a/test-python-a2a-server/.python-version b/a2a/test-python-a2a-server/.python-version new file mode 100644 index 0000000000..e4fba21835 --- /dev/null +++ b/a2a/test-python-a2a-server/.python-version @@ -0,0 +1 @@ +3.12 diff --git a/a2a/test-python-a2a-server/Dockerfile b/a2a/test-python-a2a-server/Dockerfile new file mode 100644 index 0000000000..87f268e7e2 --- /dev/null +++ b/a2a/test-python-a2a-server/Dockerfile @@ -0,0 +1,7 @@ +FROM ghcr.io/astral-sh/uv:python3.12-alpine +WORKDIR /app +COPY pyproject.toml uv.lock ./ +RUN uv sync --frozen +COPY src/ ./src/ +EXPOSE 9999 +CMD ["uv", "run", "--no-sync", "src/main.py"] diff --git a/a2a/test-python-a2a-server/pyproject.toml b/a2a/test-python-a2a-server/pyproject.toml new file mode 100644 index 0000000000..2ce64bc7b0 --- /dev/null +++ b/a2a/test-python-a2a-server/pyproject.toml @@ -0,0 +1,9 @@ +[project] +name = "test-python-a2a-server" +version = "0.1.0" +description = "Python A2A server for integration tests." +requires-python = ">=3.12" +dependencies = [ + "a2a-sdk[http-server]==0.3.5", + "uvicorn==0.35.0", +] diff --git a/a2a/test-python-a2a-server/src/agent_executor.py b/a2a/test-python-a2a-server/src/agent_executor.py new file mode 100644 index 0000000000..906eb4cef0 --- /dev/null +++ b/a2a/test-python-a2a-server/src/agent_executor.py @@ -0,0 +1,165 @@ +import asyncio + +from a2a.server.agent_execution import AgentExecutor, RequestContext +from a2a.server.events import EventQueue +from a2a.types import ( + Message, + TaskStatusUpdateEvent, + TaskStatus, + TaskState, + Task, +) +from a2a.utils import ( + new_agent_text_message, + new_task +) + + +async def say_hello( + event_queue: EventQueue, + message: Message +) -> None: + await event_queue.enqueue_event( + new_agent_text_message( + text="Hello World", + context_id=message.context_id, + task_id=message.task_id + ) + ) + + +async def do_task( + event_queue: EventQueue, + message: Message +) -> None: + task = new_task(message) + + # noinspection PyTypeChecker + events = [ + task, + + TaskStatusUpdateEvent( + context_id=task.context_id, + task_id=task.id, + status=TaskStatus( + state=TaskState.working, + message=new_agent_text_message( + text="Working on task", + context_id=task.context_id, + task_id=task.id + ) + ), + final=False + ), + + TaskStatusUpdateEvent( + context_id=task.context_id, + task_id=task.id, + status=TaskStatus( + state=TaskState.completed, + message=new_agent_text_message( + text="Task completed", + context_id=task.context_id, + task_id=task.id + ) + ), + final=True + ) + ] + + for event in events: + await event_queue.enqueue_event(event) + + +async def do_cancelable_task( + event_queue: EventQueue, + message: Message, +): + await event_queue.enqueue_event( + new_task(message), + ) + +async def do_long_running_task( + event_queue: EventQueue, + message: Message +): + task = Task( + id=message.task_id, + context_id=message.context_id, + status=TaskStatus( + state=TaskState.working, + message=message + ) + ) + + await event_queue.enqueue_event(task) + + # Simulate long-running task + for i in range(4): + await asyncio.sleep(0.2) + + # noinspection PyTypeChecker + await event_queue.enqueue_event( + TaskStatusUpdateEvent( + task_id=task.id, + context_id=task.context_id, + status=TaskStatus( + state=TaskState.working, + message=new_agent_text_message( + text=f"Still working {i}", + context_id=task.context_id, + task_id=task.id + ) + ), + final=False + ) + ) + + +class HelloWorldAgentExecutor(AgentExecutor): + """Test AgentProxy Implementation.""" + + async def execute( + self, + context: RequestContext, + event_queue: EventQueue, + ) -> None: + # Test scenarios to test various aspects of A2A + if "hello world" in context.get_user_input(): + await say_hello(event_queue, context.message) + + elif "do task" in context.get_user_input(): + await do_task(event_queue, context.message) + + elif "do cancelable task" in context.get_user_input(): + await do_cancelable_task(event_queue, context.message) + + elif "do long-running task" in context.get_user_input(): + await do_long_running_task(event_queue, context.message) + + else: + await event_queue.enqueue_event( + new_agent_text_message("Sorry, I don't understand you") + ) + + async def cancel( + self, + context: RequestContext, + event_queue: EventQueue + ) -> None: + # noinspection PyTypeChecker + await event_queue.enqueue_event( + TaskStatusUpdateEvent( + context_id=context.context_id, + task_id=context.task_id, + status=TaskStatus( + state=TaskState.canceled, + message=new_agent_text_message( + text="Task canceled", + context_id=context.context_id, + task_id=context.task_id + ) + ), + final=True, + ) + ) diff --git a/a2a/test-python-a2a-server/src/main.py b/a2a/test-python-a2a-server/src/main.py new file mode 100644 index 0000000000..516ce0fa34 --- /dev/null +++ b/a2a/test-python-a2a-server/src/main.py @@ -0,0 +1,79 @@ +import uvicorn + +from a2a.server.apps import A2AStarletteApplication +from a2a.server.request_handlers import DefaultRequestHandler +from a2a.server.tasks import ( + InMemoryTaskStore, + InMemoryPushNotificationConfigStore +) +from a2a.types import ( + AgentCapabilities, + AgentCard, + AgentSkill, +) +from agent_executor import ( + HelloWorldAgentExecutor, +) + + +if __name__ == '__main__': + skill = AgentSkill( + id='hello_world', + name='Returns hello world', + description='just returns hello world', + tags=['hello world'], + examples=['hi', 'hello world'], + ) + + extended_skill = AgentSkill( + id='super_hello_world', + name='Returns a SUPER Hello World', + description='A more enthusiastic greeting, only for authenticated users.', + tags=['hello world', 'super', 'extended'], + examples=['super hi', 'give me a super hello'], + ) + + public_agent_card = AgentCard( + name='Hello World Agent', + description='Just a hello world agent', + url='http://localhost:9999/', + version='1.0.0', + default_input_modes=['text'], + default_output_modes=['text'], + capabilities=AgentCapabilities( + streaming=True, + push_notifications=True, + ), + skills=[skill], # Only the basic skill for the public card + supports_authenticated_extended_card=True, + ) + + # This will be the authenticated extended agent card + # It includes the additional 'extended_skill' + specific_extended_agent_card = public_agent_card.model_copy( + update={ + 'name': 'Hello World Agent - Extended Edition', # Different name for clarity + 'description': 'The full-featured hello world agent for authenticated users.', + 'version': '1.0.1', # Could even be a different version + # Capabilities and other fields like url, default_input_modes, default_output_modes, + # supports_authenticated_extended_card are inherited from public_agent_card unless specified here. + 'skills': [ + skill, + extended_skill, + ], # Both skills for the extended card + } + ) + + request_handler = DefaultRequestHandler( + agent_executor=HelloWorldAgentExecutor(), + task_store=InMemoryTaskStore(), + push_config_store=InMemoryPushNotificationConfigStore() + ) + + server = A2AStarletteApplication( + agent_card=public_agent_card, + http_handler=request_handler, + extended_agent_card=specific_extended_agent_card, + ) + + uvicorn.run(server.build(), host='0.0.0.0', port=9999) diff --git a/a2a/test-python-a2a-server/uv.lock b/a2a/test-python-a2a-server/uv.lock new file mode 100644 index 0000000000..6bbaa6a180 --- /dev/null +++ b/a2a/test-python-a2a-server/uv.lock @@ -0,0 +1,468 @@ +version = 1 +revision = 2 +requires-python = ">=3.12" +resolution-markers = [ + "python_full_version >= '3.13'", + "python_full_version < '3.13'", +] + +[[package]] +name = "a2a-sdk" +version = "0.3.5" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "google-api-core" }, + { name = "httpx" }, + { name = "httpx-sse" }, + { name = "protobuf" }, + { name = "pydantic" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/cf/0d/12ebef081b096ca5fafd1ec8cc589739abba07b46ae7899c7420e599f2a6/a2a_sdk-0.3.5.tar.gz", hash = "sha256:48cf37dedeb63cf0a072512221a12ed4b3950df695c9d65eadb839a99392c3e5", size = 222064, upload-time = "2025-09-08T17:30:35.826Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/4c/96/c33802d929b0f884cb6e509195d69914632536256d273bd7127e900d79ea/a2a_sdk-0.3.5-py3-none-any.whl", hash = "sha256:fd85b1e4e7be18a89b5d723e4013171510150a235275876f98de9e1ba869457e", size = 136911, upload-time = "2025-09-08T17:30:34.091Z" }, +] + +[package.optional-dependencies] +http-server = [ + { name = "fastapi" }, + { name = "sse-starlette" }, + { name = "starlette" }, +] + +[[package]] +name = "annotated-types" +version = "0.7.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/ee/67/531ea369ba64dcff5ec9c3402f9f51bf748cec26dde048a2f973a4eea7f5/annotated_types-0.7.0.tar.gz", hash = "sha256:aff07c09a53a08bc8cfccb9c85b05f1aa9a2a6f23728d790723543408344ce89", size = 16081, upload-time = "2024-05-20T21:33:25.928Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/78/b6/6307fbef88d9b5ee7421e68d78a9f162e0da4900bc5f5793f6d3d0e34fb8/annotated_types-0.7.0-py3-none-any.whl", hash = "sha256:1f02e8b43a8fbbc3f3e0d4f0f4bfc8131bcb4eebe8849b8e5c773f3a1c582a53", size = 13643, upload-time = "2024-05-20T21:33:24.1Z" }, +] + +[[package]] +name = "anyio" +version = "4.10.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "idna" }, + { name = "sniffio" }, + { name = "typing-extensions", marker = "python_full_version < '3.13'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/f1/b4/636b3b65173d3ce9a38ef5f0522789614e590dab6a8d505340a4efe4c567/anyio-4.10.0.tar.gz", hash = "sha256:3f3fae35c96039744587aa5b8371e7e8e603c0702999535961dd336026973ba6", size = 213252, upload-time = "2025-08-04T08:54:26.451Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/6f/12/e5e0282d673bb9746bacfb6e2dba8719989d3660cdb2ea79aee9a9651afb/anyio-4.10.0-py3-none-any.whl", hash = "sha256:60e474ac86736bbfd6f210f7a61218939c318f43f9972497381f1c5e930ed3d1", size = 107213, upload-time = "2025-08-04T08:54:24.882Z" }, +] + +[[package]] +name = "cachetools" +version = "5.5.2" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/6c/81/3747dad6b14fa2cf53fcf10548cf5aea6913e96fab41a3c198676f8948a5/cachetools-5.5.2.tar.gz", hash = "sha256:1a661caa9175d26759571b2e19580f9d6393969e5dfca11fdb1f947a23e640d4", size = 28380, upload-time = "2025-02-20T21:01:19.524Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/72/76/20fa66124dbe6be5cafeb312ece67de6b61dd91a0247d1ea13db4ebb33c2/cachetools-5.5.2-py3-none-any.whl", hash = "sha256:d26a22bcc62eb95c3beabd9f1ee5e820d3d2704fe2967cbe350e20c8ffcd3f0a", size = 10080, upload-time = "2025-02-20T21:01:16.647Z" }, +] + +[[package]] +name = "certifi" +version = "2025.8.3" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/dc/67/960ebe6bf230a96cda2e0abcf73af550ec4f090005363542f0765df162e0/certifi-2025.8.3.tar.gz", hash = "sha256:e564105f78ded564e3ae7c923924435e1daa7463faeab5bb932bc53ffae63407", size = 162386, upload-time = "2025-08-03T03:07:47.08Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/e5/48/1549795ba7742c948d2ad169c1c8cdbae65bc450d6cd753d124b17c8cd32/certifi-2025.8.3-py3-none-any.whl", hash = "sha256:f6c12493cfb1b06ba2ff328595af9350c65d6644968e5d3a2ffd78699af217a5", size = 161216, upload-time = "2025-08-03T03:07:45.777Z" }, +] + +[[package]] +name = "charset-normalizer" +version = "3.4.3" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/83/2d/5fd176ceb9b2fc619e63405525573493ca23441330fcdaee6bef9460e924/charset_normalizer-3.4.3.tar.gz", hash = "sha256:6fce4b8500244f6fcb71465d4a4930d132ba9ab8e71a7859e6a5d59851068d14", size = 122371, upload-time = "2025-08-09T07:57:28.46Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/e9/5e/14c94999e418d9b87682734589404a25854d5f5d0408df68bc15b6ff54bb/charset_normalizer-3.4.3-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:e28e334d3ff134e88989d90ba04b47d84382a828c061d0d1027b1b12a62b39b1", size = 205655, upload-time = "2025-08-09T07:56:08.475Z" }, + { url = "https://files.pythonhosted.org/packages/7d/a8/c6ec5d389672521f644505a257f50544c074cf5fc292d5390331cd6fc9c3/charset_normalizer-3.4.3-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:0cacf8f7297b0c4fcb74227692ca46b4a5852f8f4f24b3c766dd94a1075c4884", size = 146223, upload-time = "2025-08-09T07:56:09.708Z" }, + { url = "https://files.pythonhosted.org/packages/fc/eb/a2ffb08547f4e1e5415fb69eb7db25932c52a52bed371429648db4d84fb1/charset_normalizer-3.4.3-cp312-cp312-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:c6fd51128a41297f5409deab284fecbe5305ebd7e5a1f959bee1c054622b7018", size = 159366, upload-time = "2025-08-09T07:56:11.326Z" }, + { url = "https://files.pythonhosted.org/packages/82/10/0fd19f20c624b278dddaf83b8464dcddc2456cb4b02bb902a6da126b87a1/charset_normalizer-3.4.3-cp312-cp312-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:3cfb2aad70f2c6debfbcb717f23b7eb55febc0bb23dcffc0f076009da10c6392", size = 157104, upload-time = "2025-08-09T07:56:13.014Z" }, + { url = "https://files.pythonhosted.org/packages/16/ab/0233c3231af734f5dfcf0844aa9582d5a1466c985bbed6cedab85af9bfe3/charset_normalizer-3.4.3-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:1606f4a55c0fd363d754049cdf400175ee96c992b1f8018b993941f221221c5f", size = 151830, upload-time = "2025-08-09T07:56:14.428Z" }, + { url = "https://files.pythonhosted.org/packages/ae/02/e29e22b4e02839a0e4a06557b1999d0a47db3567e82989b5bb21f3fbbd9f/charset_normalizer-3.4.3-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:027b776c26d38b7f15b26a5da1044f376455fb3766df8fc38563b4efbc515154", size = 148854, upload-time = "2025-08-09T07:56:16.051Z" }, + { url = "https://files.pythonhosted.org/packages/05/6b/e2539a0a4be302b481e8cafb5af8792da8093b486885a1ae4d15d452bcec/charset_normalizer-3.4.3-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:42e5088973e56e31e4fa58eb6bd709e42fc03799c11c42929592889a2e54c491", size = 160670, upload-time = "2025-08-09T07:56:17.314Z" }, + { url = "https://files.pythonhosted.org/packages/31/e7/883ee5676a2ef217a40ce0bffcc3d0dfbf9e64cbcfbdf822c52981c3304b/charset_normalizer-3.4.3-cp312-cp312-musllinux_1_2_s390x.whl", hash = "sha256:cc34f233c9e71701040d772aa7490318673aa7164a0efe3172b2981218c26d93", size = 158501, upload-time = "2025-08-09T07:56:18.641Z" }, + { url = "https://files.pythonhosted.org/packages/c1/35/6525b21aa0db614cf8b5792d232021dca3df7f90a1944db934efa5d20bb1/charset_normalizer-3.4.3-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:320e8e66157cc4e247d9ddca8e21f427efc7a04bbd0ac8a9faf56583fa543f9f", size = 153173, upload-time = "2025-08-09T07:56:20.289Z" }, + { url = "https://files.pythonhosted.org/packages/50/ee/f4704bad8201de513fdc8aac1cabc87e38c5818c93857140e06e772b5892/charset_normalizer-3.4.3-cp312-cp312-win32.whl", hash = "sha256:fb6fecfd65564f208cbf0fba07f107fb661bcd1a7c389edbced3f7a493f70e37", size = 99822, upload-time = "2025-08-09T07:56:21.551Z" }, + { url = "https://files.pythonhosted.org/packages/39/f5/3b3836ca6064d0992c58c7561c6b6eee1b3892e9665d650c803bd5614522/charset_normalizer-3.4.3-cp312-cp312-win_amd64.whl", hash = "sha256:86df271bf921c2ee3818f0522e9a5b8092ca2ad8b065ece5d7d9d0e9f4849bcc", size = 107543, upload-time = "2025-08-09T07:56:23.115Z" }, + { url = "https://files.pythonhosted.org/packages/65/ca/2135ac97709b400c7654b4b764daf5c5567c2da45a30cdd20f9eefe2d658/charset_normalizer-3.4.3-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:14c2a87c65b351109f6abfc424cab3927b3bdece6f706e4d12faaf3d52ee5efe", size = 205326, upload-time = "2025-08-09T07:56:24.721Z" }, + { url = "https://files.pythonhosted.org/packages/71/11/98a04c3c97dd34e49c7d247083af03645ca3730809a5509443f3c37f7c99/charset_normalizer-3.4.3-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:41d1fc408ff5fdfb910200ec0e74abc40387bccb3252f3f27c0676731df2b2c8", size = 146008, upload-time = "2025-08-09T07:56:26.004Z" }, + { url = "https://files.pythonhosted.org/packages/60/f5/4659a4cb3c4ec146bec80c32d8bb16033752574c20b1252ee842a95d1a1e/charset_normalizer-3.4.3-cp313-cp313-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:1bb60174149316da1c35fa5233681f7c0f9f514509b8e399ab70fea5f17e45c9", size = 159196, upload-time = "2025-08-09T07:56:27.25Z" }, + { url = "https://files.pythonhosted.org/packages/86/9e/f552f7a00611f168b9a5865a1414179b2c6de8235a4fa40189f6f79a1753/charset_normalizer-3.4.3-cp313-cp313-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:30d006f98569de3459c2fc1f2acde170b7b2bd265dc1943e87e1a4efe1b67c31", size = 156819, upload-time = "2025-08-09T07:56:28.515Z" }, + { url = "https://files.pythonhosted.org/packages/7e/95/42aa2156235cbc8fa61208aded06ef46111c4d3f0de233107b3f38631803/charset_normalizer-3.4.3-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:416175faf02e4b0810f1f38bcb54682878a4af94059a1cd63b8747244420801f", size = 151350, upload-time = "2025-08-09T07:56:29.716Z" }, + { url = "https://files.pythonhosted.org/packages/c2/a9/3865b02c56f300a6f94fc631ef54f0a8a29da74fb45a773dfd3dcd380af7/charset_normalizer-3.4.3-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:6aab0f181c486f973bc7262a97f5aca3ee7e1437011ef0c2ec04b5a11d16c927", size = 148644, upload-time = "2025-08-09T07:56:30.984Z" }, + { url = "https://files.pythonhosted.org/packages/77/d9/cbcf1a2a5c7d7856f11e7ac2d782aec12bdfea60d104e60e0aa1c97849dc/charset_normalizer-3.4.3-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:fdabf8315679312cfa71302f9bd509ded4f2f263fb5b765cf1433b39106c3cc9", size = 160468, upload-time = "2025-08-09T07:56:32.252Z" }, + { url = "https://files.pythonhosted.org/packages/f6/42/6f45efee8697b89fda4d50580f292b8f7f9306cb2971d4b53f8914e4d890/charset_normalizer-3.4.3-cp313-cp313-musllinux_1_2_s390x.whl", hash = "sha256:bd28b817ea8c70215401f657edef3a8aa83c29d447fb0b622c35403780ba11d5", size = 158187, upload-time = "2025-08-09T07:56:33.481Z" }, + { url = "https://files.pythonhosted.org/packages/70/99/f1c3bdcfaa9c45b3ce96f70b14f070411366fa19549c1d4832c935d8e2c3/charset_normalizer-3.4.3-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:18343b2d246dc6761a249ba1fb13f9ee9a2bcd95decc767319506056ea4ad4dc", size = 152699, upload-time = "2025-08-09T07:56:34.739Z" }, + { url = "https://files.pythonhosted.org/packages/a3/ad/b0081f2f99a4b194bcbb1934ef3b12aa4d9702ced80a37026b7607c72e58/charset_normalizer-3.4.3-cp313-cp313-win32.whl", hash = "sha256:6fb70de56f1859a3f71261cbe41005f56a7842cc348d3aeb26237560bfa5e0ce", size = 99580, upload-time = "2025-08-09T07:56:35.981Z" }, + { url = "https://files.pythonhosted.org/packages/9a/8f/ae790790c7b64f925e5c953b924aaa42a243fb778fed9e41f147b2a5715a/charset_normalizer-3.4.3-cp313-cp313-win_amd64.whl", hash = "sha256:cf1ebb7d78e1ad8ec2a8c4732c7be2e736f6e5123a4146c5b89c9d1f585f8cef", size = 107366, upload-time = "2025-08-09T07:56:37.339Z" }, + { url = "https://files.pythonhosted.org/packages/8e/91/b5a06ad970ddc7a0e513112d40113e834638f4ca1120eb727a249fb2715e/charset_normalizer-3.4.3-cp314-cp314-macosx_10_13_universal2.whl", hash = "sha256:3cd35b7e8aedeb9e34c41385fda4f73ba609e561faedfae0a9e75e44ac558a15", size = 204342, upload-time = "2025-08-09T07:56:38.687Z" }, + { url = "https://files.pythonhosted.org/packages/ce/ec/1edc30a377f0a02689342f214455c3f6c2fbedd896a1d2f856c002fc3062/charset_normalizer-3.4.3-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:b89bc04de1d83006373429975f8ef9e7932534b8cc9ca582e4db7d20d91816db", size = 145995, upload-time = "2025-08-09T07:56:40.048Z" }, + { url = "https://files.pythonhosted.org/packages/17/e5/5e67ab85e6d22b04641acb5399c8684f4d37caf7558a53859f0283a650e9/charset_normalizer-3.4.3-cp314-cp314-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:2001a39612b241dae17b4687898843f254f8748b796a2e16f1051a17078d991d", size = 158640, upload-time = "2025-08-09T07:56:41.311Z" }, + { url = "https://files.pythonhosted.org/packages/f1/e5/38421987f6c697ee3722981289d554957c4be652f963d71c5e46a262e135/charset_normalizer-3.4.3-cp314-cp314-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:8dcfc373f888e4fb39a7bc57e93e3b845e7f462dacc008d9749568b1c4ece096", size = 156636, upload-time = "2025-08-09T07:56:43.195Z" }, + { url = "https://files.pythonhosted.org/packages/a0/e4/5a075de8daa3ec0745a9a3b54467e0c2967daaaf2cec04c845f73493e9a1/charset_normalizer-3.4.3-cp314-cp314-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:18b97b8404387b96cdbd30ad660f6407799126d26a39ca65729162fd810a99aa", size = 150939, upload-time = "2025-08-09T07:56:44.819Z" }, + { url = "https://files.pythonhosted.org/packages/02/f7/3611b32318b30974131db62b4043f335861d4d9b49adc6d57c1149cc49d4/charset_normalizer-3.4.3-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:ccf600859c183d70eb47e05a44cd80a4ce77394d1ac0f79dbd2dd90a69a3a049", size = 148580, upload-time = "2025-08-09T07:56:46.684Z" }, + { url = "https://files.pythonhosted.org/packages/7e/61/19b36f4bd67f2793ab6a99b979b4e4f3d8fc754cbdffb805335df4337126/charset_normalizer-3.4.3-cp314-cp314-musllinux_1_2_ppc64le.whl", hash = "sha256:53cd68b185d98dde4ad8990e56a58dea83a4162161b1ea9272e5c9182ce415e0", size = 159870, upload-time = "2025-08-09T07:56:47.941Z" }, + { url = "https://files.pythonhosted.org/packages/06/57/84722eefdd338c04cf3030ada66889298eaedf3e7a30a624201e0cbe424a/charset_normalizer-3.4.3-cp314-cp314-musllinux_1_2_s390x.whl", hash = "sha256:30a96e1e1f865f78b030d65241c1ee850cdf422d869e9028e2fc1d5e4db73b92", size = 157797, upload-time = "2025-08-09T07:56:49.756Z" }, + { url = "https://files.pythonhosted.org/packages/72/2a/aff5dd112b2f14bcc3462c312dce5445806bfc8ab3a7328555da95330e4b/charset_normalizer-3.4.3-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:d716a916938e03231e86e43782ca7878fb602a125a91e7acb8b5112e2e96ac16", size = 152224, upload-time = "2025-08-09T07:56:51.369Z" }, + { url = "https://files.pythonhosted.org/packages/b7/8c/9839225320046ed279c6e839d51f028342eb77c91c89b8ef2549f951f3ec/charset_normalizer-3.4.3-cp314-cp314-win32.whl", hash = "sha256:c6dbd0ccdda3a2ba7c2ecd9d77b37f3b5831687d8dc1b6ca5f56a4880cc7b7ce", size = 100086, upload-time = "2025-08-09T07:56:52.722Z" }, + { url = "https://files.pythonhosted.org/packages/ee/7a/36fbcf646e41f710ce0a563c1c9a343c6edf9be80786edeb15b6f62e17db/charset_normalizer-3.4.3-cp314-cp314-win_amd64.whl", hash = "sha256:73dc19b562516fc9bcf6e5d6e596df0b4eb98d87e4f79f3ae71840e6ed21361c", size = 107400, upload-time = "2025-08-09T07:56:55.172Z" }, + { url = "https://files.pythonhosted.org/packages/8a/1f/f041989e93b001bc4e44bb1669ccdcf54d3f00e628229a85b08d330615c5/charset_normalizer-3.4.3-py3-none-any.whl", hash = "sha256:ce571ab16d890d23b5c278547ba694193a45011ff86a9162a71307ed9f86759a", size = 53175, upload-time = "2025-08-09T07:57:26.864Z" }, +] + +[[package]] +name = "click" +version = "8.2.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "colorama", marker = "sys_platform == 'win32'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/60/6c/8ca2efa64cf75a977a0d7fac081354553ebe483345c734fb6b6515d96bbc/click-8.2.1.tar.gz", hash = "sha256:27c491cc05d968d271d5a1db13e3b5a184636d9d930f148c50b038f0d0646202", size = 286342, upload-time = "2025-05-20T23:19:49.832Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/85/32/10bb5764d90a8eee674e9dc6f4db6a0ab47c8c4d0d83c27f7c39ac415a4d/click-8.2.1-py3-none-any.whl", hash = "sha256:61a3265b914e850b85317d0b3109c7f8cd35a670f963866005d6ef1d5175a12b", size = 102215, upload-time = "2025-05-20T23:19:47.796Z" }, +] + +[[package]] +name = "colorama" +version = "0.4.6" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/d8/53/6f443c9a4a8358a93a6792e2acffb9d9d5cb0a5cfd8802644b7b1c9a02e4/colorama-0.4.6.tar.gz", hash = "sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44", size = 27697, upload-time = "2022-10-25T02:36:22.414Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/d1/d6/3965ed04c63042e047cb6a3e6ed1a63a35087b6a609aa3a15ed8ac56c221/colorama-0.4.6-py2.py3-none-any.whl", hash = "sha256:4f1d9991f5acc0ca119f9d443620b77f9d6b33703e51011c16baf57afb285fc6", size = 25335, upload-time = "2022-10-25T02:36:20.889Z" }, +] + +[[package]] +name = "fastapi" +version = "0.116.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "pydantic" }, + { name = "starlette" }, + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/78/d7/6c8b3bfe33eeffa208183ec037fee0cce9f7f024089ab1c5d12ef04bd27c/fastapi-0.116.1.tar.gz", hash = "sha256:ed52cbf946abfd70c5a0dccb24673f0670deeb517a88b3544d03c2a6bf283143", size = 296485, upload-time = "2025-07-11T16:22:32.057Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/e5/47/d63c60f59a59467fda0f93f46335c9d18526d7071f025cb5b89d5353ea42/fastapi-0.116.1-py3-none-any.whl", hash = "sha256:c46ac7c312df840f0c9e220f7964bada936781bc4e2e6eb71f1c4d7553786565", size = 95631, upload-time = "2025-07-11T16:22:30.485Z" }, +] + +[[package]] +name = "google-api-core" +version = "2.25.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "google-auth" }, + { name = "googleapis-common-protos" }, + { name = "proto-plus" }, + { name = "protobuf" }, + { name = "requests" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/dc/21/e9d043e88222317afdbdb567165fdbc3b0aad90064c7e0c9eb0ad9955ad8/google_api_core-2.25.1.tar.gz", hash = "sha256:d2aaa0b13c78c61cb3f4282c464c046e45fbd75755683c9c525e6e8f7ed0a5e8", size = 165443, upload-time = "2025-06-12T20:52:20.439Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/14/4b/ead00905132820b623732b175d66354e9d3e69fcf2a5dcdab780664e7896/google_api_core-2.25.1-py3-none-any.whl", hash = "sha256:8a2a56c1fef82987a524371f99f3bd0143702fecc670c72e600c1cda6bf8dbb7", size = 160807, upload-time = "2025-06-12T20:52:19.334Z" }, +] + +[[package]] +name = "google-auth" +version = "2.40.3" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "cachetools" }, + { name = "pyasn1-modules" }, + { name = "rsa" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/9e/9b/e92ef23b84fa10a64ce4831390b7a4c2e53c0132568d99d4ae61d04c8855/google_auth-2.40.3.tar.gz", hash = "sha256:500c3a29adedeb36ea9cf24b8d10858e152f2412e3ca37829b3fa18e33d63b77", size = 281029, upload-time = "2025-06-04T18:04:57.577Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/17/63/b19553b658a1692443c62bd07e5868adaa0ad746a0751ba62c59568cd45b/google_auth-2.40.3-py2.py3-none-any.whl", hash = "sha256:1370d4593e86213563547f97a92752fc658456fe4514c809544f330fed45a7ca", size = 216137, upload-time = "2025-06-04T18:04:55.573Z" }, +] + +[[package]] +name = "googleapis-common-protos" +version = "1.70.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "protobuf" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/39/24/33db22342cf4a2ea27c9955e6713140fedd51e8b141b5ce5260897020f1a/googleapis_common_protos-1.70.0.tar.gz", hash = "sha256:0e1b44e0ea153e6594f9f394fef15193a68aaaea2d843f83e2742717ca753257", size = 145903, upload-time = "2025-04-14T10:17:02.924Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/86/f1/62a193f0227cf15a920390abe675f386dec35f7ae3ffe6da582d3ade42c7/googleapis_common_protos-1.70.0-py3-none-any.whl", hash = "sha256:b8bfcca8c25a2bb253e0e0b0adaf8c00773e5e6af6fd92397576680b807e0fd8", size = 294530, upload-time = "2025-04-14T10:17:01.271Z" }, +] + +[[package]] +name = "h11" +version = "0.16.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/01/ee/02a2c011bdab74c6fb3c75474d40b3052059d95df7e73351460c8588d963/h11-0.16.0.tar.gz", hash = "sha256:4e35b956cf45792e4caa5885e69fba00bdbc6ffafbfa020300e549b208ee5ff1", size = 101250, upload-time = "2025-04-24T03:35:25.427Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/04/4b/29cac41a4d98d144bf5f6d33995617b185d14b22401f75ca86f384e87ff1/h11-0.16.0-py3-none-any.whl", hash = "sha256:63cf8bbe7522de3bf65932fda1d9c2772064ffb3dae62d55932da54b31cb6c86", size = 37515, upload-time = "2025-04-24T03:35:24.344Z" }, +] + +[[package]] +name = "httpcore" +version = "1.0.9" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "certifi" }, + { name = "h11" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/06/94/82699a10bca87a5556c9c59b5963f2d039dbd239f25bc2a63907a05a14cb/httpcore-1.0.9.tar.gz", hash = "sha256:6e34463af53fd2ab5d807f399a9b45ea31c3dfa2276f15a2c3f00afff6e176e8", size = 85484, upload-time = "2025-04-24T22:06:22.219Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/7e/f5/f66802a942d491edb555dd61e3a9961140fd64c90bce1eafd741609d334d/httpcore-1.0.9-py3-none-any.whl", hash = "sha256:2d400746a40668fc9dec9810239072b40b4484b640a8c38fd654a024c7a1bf55", size = 78784, upload-time = "2025-04-24T22:06:20.566Z" }, +] + +[[package]] +name = "httpx" +version = "0.28.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "anyio" }, + { name = "certifi" }, + { name = "httpcore" }, + { name = "idna" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/b1/df/48c586a5fe32a0f01324ee087459e112ebb7224f646c0b5023f5e79e9956/httpx-0.28.1.tar.gz", hash = "sha256:75e98c5f16b0f35b567856f597f06ff2270a374470a5c2392242528e3e3e42fc", size = 141406, upload-time = "2024-12-06T15:37:23.222Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/2a/39/e50c7c3a983047577ee07d2a9e53faf5a69493943ec3f6a384bdc792deb2/httpx-0.28.1-py3-none-any.whl", hash = "sha256:d909fcccc110f8c7faf814ca82a9a4d816bc5a6dbfea25d6591d6985b8ba59ad", size = 73517, upload-time = "2024-12-06T15:37:21.509Z" }, +] + +[[package]] +name = "httpx-sse" +version = "0.4.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/6e/fa/66bd985dd0b7c109a3bcb89272ee0bfb7e2b4d06309ad7b38ff866734b2a/httpx_sse-0.4.1.tar.gz", hash = "sha256:8f44d34414bc7b21bf3602713005c5df4917884f76072479b21f68befa4ea26e", size = 12998, upload-time = "2025-06-24T13:21:05.71Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/25/0a/6269e3473b09aed2dab8aa1a600c70f31f00ae1349bee30658f7e358a159/httpx_sse-0.4.1-py3-none-any.whl", hash = "sha256:cba42174344c3a5b06f255ce65b350880f962d99ead85e776f23c6618a377a37", size = 8054, upload-time = "2025-06-24T13:21:04.772Z" }, +] + +[[package]] +name = "idna" +version = "3.10" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/f1/70/7703c29685631f5a7590aa73f1f1d3fa9a380e654b86af429e0934a32f7d/idna-3.10.tar.gz", hash = "sha256:12f65c9b470abda6dc35cf8e63cc574b1c52b11df2c86030af0ac09b01b13ea9", size = 190490, upload-time = "2024-09-15T18:07:39.745Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/76/c6/c88e154df9c4e1a2a66ccf0005a88dfb2650c1dffb6f5ce603dfbd452ce3/idna-3.10-py3-none-any.whl", hash = "sha256:946d195a0d259cbba61165e88e65941f16e9b36ea6ddb97f00452bae8b1287d3", size = 70442, upload-time = "2024-09-15T18:07:37.964Z" }, +] + +[[package]] +name = "proto-plus" +version = "1.26.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "protobuf" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/f4/ac/87285f15f7cce6d4a008f33f1757fb5a13611ea8914eb58c3d0d26243468/proto_plus-1.26.1.tar.gz", hash = "sha256:21a515a4c4c0088a773899e23c7bbade3d18f9c66c73edd4c7ee3816bc96a012", size = 56142, upload-time = "2025-03-10T15:54:38.843Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/4e/6d/280c4c2ce28b1593a19ad5239c8b826871fc6ec275c21afc8e1820108039/proto_plus-1.26.1-py3-none-any.whl", hash = "sha256:13285478c2dcf2abb829db158e1047e2f1e8d63a077d94263c2b88b043c75a66", size = 50163, upload-time = "2025-03-10T15:54:37.335Z" }, +] + +[[package]] +name = "protobuf" +version = "6.32.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/c0/df/fb4a8eeea482eca989b51cffd274aac2ee24e825f0bf3cbce5281fa1567b/protobuf-6.32.0.tar.gz", hash = "sha256:a81439049127067fc49ec1d36e25c6ee1d1a2b7be930675f919258d03c04e7d2", size = 440614, upload-time = "2025-08-14T21:21:25.015Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/33/18/df8c87da2e47f4f1dcc5153a81cd6bca4e429803f4069a299e236e4dd510/protobuf-6.32.0-cp310-abi3-win32.whl", hash = "sha256:84f9e3c1ff6fb0308dbacb0950d8aa90694b0d0ee68e75719cb044b7078fe741", size = 424409, upload-time = "2025-08-14T21:21:12.366Z" }, + { url = "https://files.pythonhosted.org/packages/e1/59/0a820b7310f8139bd8d5a9388e6a38e1786d179d6f33998448609296c229/protobuf-6.32.0-cp310-abi3-win_amd64.whl", hash = "sha256:a8bdbb2f009cfc22a36d031f22a625a38b615b5e19e558a7b756b3279723e68e", size = 435735, upload-time = "2025-08-14T21:21:15.046Z" }, + { url = "https://files.pythonhosted.org/packages/cc/5b/0d421533c59c789e9c9894683efac582c06246bf24bb26b753b149bd88e4/protobuf-6.32.0-cp39-abi3-macosx_10_9_universal2.whl", hash = "sha256:d52691e5bee6c860fff9a1c86ad26a13afbeb4b168cd4445c922b7e2cf85aaf0", size = 426449, upload-time = "2025-08-14T21:21:16.687Z" }, + { url = "https://files.pythonhosted.org/packages/ec/7b/607764ebe6c7a23dcee06e054fd1de3d5841b7648a90fd6def9a3bb58c5e/protobuf-6.32.0-cp39-abi3-manylinux2014_aarch64.whl", hash = "sha256:501fe6372fd1c8ea2a30b4d9be8f87955a64d6be9c88a973996cef5ef6f0abf1", size = 322869, upload-time = "2025-08-14T21:21:18.282Z" }, + { url = "https://files.pythonhosted.org/packages/40/01/2e730bd1c25392fc32e3268e02446f0d77cb51a2c3a8486b1798e34d5805/protobuf-6.32.0-cp39-abi3-manylinux2014_x86_64.whl", hash = "sha256:75a2aab2bd1aeb1f5dc7c5f33bcb11d82ea8c055c9becbb41c26a8c43fd7092c", size = 322009, upload-time = "2025-08-14T21:21:19.893Z" }, + { url = "https://files.pythonhosted.org/packages/9c/f2/80ffc4677aac1bc3519b26bc7f7f5de7fce0ee2f7e36e59e27d8beb32dd1/protobuf-6.32.0-py3-none-any.whl", hash = "sha256:ba377e5b67b908c8f3072a57b63e2c6a4cbd18aea4ed98d2584350dbf46f2783", size = 169287, upload-time = "2025-08-14T21:21:23.515Z" }, +] + +[[package]] +name = "pyasn1" +version = "0.6.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/ba/e9/01f1a64245b89f039897cb0130016d79f77d52669aae6ee7b159a6c4c018/pyasn1-0.6.1.tar.gz", hash = "sha256:6f580d2bdd84365380830acf45550f2511469f673cb4a5ae3857a3170128b034", size = 145322, upload-time = "2024-09-10T22:41:42.55Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/c8/f1/d6a797abb14f6283c0ddff96bbdd46937f64122b8c925cab503dd37f8214/pyasn1-0.6.1-py3-none-any.whl", hash = "sha256:0d632f46f2ba09143da3a8afe9e33fb6f92fa2320ab7e886e2d0f7672af84629", size = 83135, upload-time = "2024-09-11T16:00:36.122Z" }, +] + +[[package]] +name = "pyasn1-modules" +version = "0.4.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "pyasn1" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/e9/e6/78ebbb10a8c8e4b61a59249394a4a594c1a7af95593dc933a349c8d00964/pyasn1_modules-0.4.2.tar.gz", hash = "sha256:677091de870a80aae844b1ca6134f54652fa2c8c5a52aa396440ac3106e941e6", size = 307892, upload-time = "2025-03-28T02:41:22.17Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/47/8d/d529b5d697919ba8c11ad626e835d4039be708a35b0d22de83a269a6682c/pyasn1_modules-0.4.2-py3-none-any.whl", hash = "sha256:29253a9207ce32b64c3ac6600edc75368f98473906e8fd1043bd6b5b1de2c14a", size = 181259, upload-time = "2025-03-28T02:41:19.028Z" }, +] + +[[package]] +name = "pydantic" +version = "2.11.7" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "annotated-types" }, + { name = "pydantic-core" }, + { name = "typing-extensions" }, + { name = "typing-inspection" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/00/dd/4325abf92c39ba8623b5af936ddb36ffcfe0beae70405d456ab1fb2f5b8c/pydantic-2.11.7.tar.gz", hash = "sha256:d989c3c6cb79469287b1569f7447a17848c998458d49ebe294e975b9baf0f0db", size = 788350, upload-time = "2025-06-14T08:33:17.137Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/6a/c0/ec2b1c8712ca690e5d61979dee872603e92b8a32f94cc1b72d53beab008a/pydantic-2.11.7-py3-none-any.whl", hash = "sha256:dde5df002701f6de26248661f6835bbe296a47bf73990135c7d07ce741b9623b", size = 444782, upload-time = "2025-06-14T08:33:14.905Z" }, +] + +[[package]] +name = "pydantic-core" +version = "2.33.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/ad/88/5f2260bdfae97aabf98f1778d43f69574390ad787afb646292a638c923d4/pydantic_core-2.33.2.tar.gz", hash = "sha256:7cb8bc3605c29176e1b105350d2e6474142d7c1bd1d9327c4a9bdb46bf827acc", size = 435195, upload-time = "2025-04-23T18:33:52.104Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/18/8a/2b41c97f554ec8c71f2a8a5f85cb56a8b0956addfe8b0efb5b3d77e8bdc3/pydantic_core-2.33.2-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:a7ec89dc587667f22b6a0b6579c249fca9026ce7c333fc142ba42411fa243cdc", size = 2009000, upload-time = "2025-04-23T18:31:25.863Z" }, + { url = "https://files.pythonhosted.org/packages/a1/02/6224312aacb3c8ecbaa959897af57181fb6cf3a3d7917fd44d0f2917e6f2/pydantic_core-2.33.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:3c6db6e52c6d70aa0d00d45cdb9b40f0433b96380071ea80b09277dba021ddf7", size = 1847996, upload-time = "2025-04-23T18:31:27.341Z" }, + { url = "https://files.pythonhosted.org/packages/d6/46/6dcdf084a523dbe0a0be59d054734b86a981726f221f4562aed313dbcb49/pydantic_core-2.33.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4e61206137cbc65e6d5256e1166f88331d3b6238e082d9f74613b9b765fb9025", size = 1880957, upload-time = "2025-04-23T18:31:28.956Z" }, + { url = "https://files.pythonhosted.org/packages/ec/6b/1ec2c03837ac00886ba8160ce041ce4e325b41d06a034adbef11339ae422/pydantic_core-2.33.2-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:eb8c529b2819c37140eb51b914153063d27ed88e3bdc31b71198a198e921e011", size = 1964199, upload-time = "2025-04-23T18:31:31.025Z" }, + { url = "https://files.pythonhosted.org/packages/2d/1d/6bf34d6adb9debd9136bd197ca72642203ce9aaaa85cfcbfcf20f9696e83/pydantic_core-2.33.2-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:c52b02ad8b4e2cf14ca7b3d918f3eb0ee91e63b3167c32591e57c4317e134f8f", size = 2120296, upload-time = "2025-04-23T18:31:32.514Z" }, + { url = "https://files.pythonhosted.org/packages/e0/94/2bd0aaf5a591e974b32a9f7123f16637776c304471a0ab33cf263cf5591a/pydantic_core-2.33.2-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:96081f1605125ba0855dfda83f6f3df5ec90c61195421ba72223de35ccfb2f88", size = 2676109, upload-time = "2025-04-23T18:31:33.958Z" }, + { url = "https://files.pythonhosted.org/packages/f9/41/4b043778cf9c4285d59742281a769eac371b9e47e35f98ad321349cc5d61/pydantic_core-2.33.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8f57a69461af2a5fa6e6bbd7a5f60d3b7e6cebb687f55106933188e79ad155c1", size = 2002028, upload-time = "2025-04-23T18:31:39.095Z" }, + { url = "https://files.pythonhosted.org/packages/cb/d5/7bb781bf2748ce3d03af04d5c969fa1308880e1dca35a9bd94e1a96a922e/pydantic_core-2.33.2-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:572c7e6c8bb4774d2ac88929e3d1f12bc45714ae5ee6d9a788a9fb35e60bb04b", size = 2100044, upload-time = "2025-04-23T18:31:41.034Z" }, + { url = "https://files.pythonhosted.org/packages/fe/36/def5e53e1eb0ad896785702a5bbfd25eed546cdcf4087ad285021a90ed53/pydantic_core-2.33.2-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:db4b41f9bd95fbe5acd76d89920336ba96f03e149097365afe1cb092fceb89a1", size = 2058881, upload-time = "2025-04-23T18:31:42.757Z" }, + { url = "https://files.pythonhosted.org/packages/01/6c/57f8d70b2ee57fc3dc8b9610315949837fa8c11d86927b9bb044f8705419/pydantic_core-2.33.2-cp312-cp312-musllinux_1_1_armv7l.whl", hash = "sha256:fa854f5cf7e33842a892e5c73f45327760bc7bc516339fda888c75ae60edaeb6", size = 2227034, upload-time = "2025-04-23T18:31:44.304Z" }, + { url = "https://files.pythonhosted.org/packages/27/b9/9c17f0396a82b3d5cbea4c24d742083422639e7bb1d5bf600e12cb176a13/pydantic_core-2.33.2-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:5f483cfb75ff703095c59e365360cb73e00185e01aaea067cd19acffd2ab20ea", size = 2234187, upload-time = "2025-04-23T18:31:45.891Z" }, + { url = "https://files.pythonhosted.org/packages/b0/6a/adf5734ffd52bf86d865093ad70b2ce543415e0e356f6cacabbc0d9ad910/pydantic_core-2.33.2-cp312-cp312-win32.whl", hash = "sha256:9cb1da0f5a471435a7bc7e439b8a728e8b61e59784b2af70d7c169f8dd8ae290", size = 1892628, upload-time = "2025-04-23T18:31:47.819Z" }, + { url = "https://files.pythonhosted.org/packages/43/e4/5479fecb3606c1368d496a825d8411e126133c41224c1e7238be58b87d7e/pydantic_core-2.33.2-cp312-cp312-win_amd64.whl", hash = "sha256:f941635f2a3d96b2973e867144fde513665c87f13fe0e193c158ac51bfaaa7b2", size = 1955866, upload-time = "2025-04-23T18:31:49.635Z" }, + { url = "https://files.pythonhosted.org/packages/0d/24/8b11e8b3e2be9dd82df4b11408a67c61bb4dc4f8e11b5b0fc888b38118b5/pydantic_core-2.33.2-cp312-cp312-win_arm64.whl", hash = "sha256:cca3868ddfaccfbc4bfb1d608e2ccaaebe0ae628e1416aeb9c4d88c001bb45ab", size = 1888894, upload-time = "2025-04-23T18:31:51.609Z" }, + { url = "https://files.pythonhosted.org/packages/46/8c/99040727b41f56616573a28771b1bfa08a3d3fe74d3d513f01251f79f172/pydantic_core-2.33.2-cp313-cp313-macosx_10_12_x86_64.whl", hash = "sha256:1082dd3e2d7109ad8b7da48e1d4710c8d06c253cbc4a27c1cff4fbcaa97a9e3f", size = 2015688, upload-time = "2025-04-23T18:31:53.175Z" }, + { url = "https://files.pythonhosted.org/packages/3a/cc/5999d1eb705a6cefc31f0b4a90e9f7fc400539b1a1030529700cc1b51838/pydantic_core-2.33.2-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:f517ca031dfc037a9c07e748cefd8d96235088b83b4f4ba8939105d20fa1dcd6", size = 1844808, upload-time = "2025-04-23T18:31:54.79Z" }, + { url = "https://files.pythonhosted.org/packages/6f/5e/a0a7b8885c98889a18b6e376f344da1ef323d270b44edf8174d6bce4d622/pydantic_core-2.33.2-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0a9f2c9dd19656823cb8250b0724ee9c60a82f3cdf68a080979d13092a3b0fef", size = 1885580, upload-time = "2025-04-23T18:31:57.393Z" }, + { url = "https://files.pythonhosted.org/packages/3b/2a/953581f343c7d11a304581156618c3f592435523dd9d79865903272c256a/pydantic_core-2.33.2-cp313-cp313-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:2b0a451c263b01acebe51895bfb0e1cc842a5c666efe06cdf13846c7418caa9a", size = 1973859, upload-time = "2025-04-23T18:31:59.065Z" }, + { url = "https://files.pythonhosted.org/packages/e6/55/f1a813904771c03a3f97f676c62cca0c0a4138654107c1b61f19c644868b/pydantic_core-2.33.2-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:1ea40a64d23faa25e62a70ad163571c0b342b8bf66d5fa612ac0dec4f069d916", size = 2120810, upload-time = "2025-04-23T18:32:00.78Z" }, + { url = "https://files.pythonhosted.org/packages/aa/c3/053389835a996e18853ba107a63caae0b9deb4a276c6b472931ea9ae6e48/pydantic_core-2.33.2-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:0fb2d542b4d66f9470e8065c5469ec676978d625a8b7a363f07d9a501a9cb36a", size = 2676498, upload-time = "2025-04-23T18:32:02.418Z" }, + { url = "https://files.pythonhosted.org/packages/eb/3c/f4abd740877a35abade05e437245b192f9d0ffb48bbbbd708df33d3cda37/pydantic_core-2.33.2-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9fdac5d6ffa1b5a83bca06ffe7583f5576555e6c8b3a91fbd25ea7780f825f7d", size = 2000611, upload-time = "2025-04-23T18:32:04.152Z" }, + { url = "https://files.pythonhosted.org/packages/59/a7/63ef2fed1837d1121a894d0ce88439fe3e3b3e48c7543b2a4479eb99c2bd/pydantic_core-2.33.2-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:04a1a413977ab517154eebb2d326da71638271477d6ad87a769102f7c2488c56", size = 2107924, upload-time = "2025-04-23T18:32:06.129Z" }, + { url = "https://files.pythonhosted.org/packages/04/8f/2551964ef045669801675f1cfc3b0d74147f4901c3ffa42be2ddb1f0efc4/pydantic_core-2.33.2-cp313-cp313-musllinux_1_1_aarch64.whl", hash = "sha256:c8e7af2f4e0194c22b5b37205bfb293d166a7344a5b0d0eaccebc376546d77d5", size = 2063196, upload-time = "2025-04-23T18:32:08.178Z" }, + { url = "https://files.pythonhosted.org/packages/26/bd/d9602777e77fc6dbb0c7db9ad356e9a985825547dce5ad1d30ee04903918/pydantic_core-2.33.2-cp313-cp313-musllinux_1_1_armv7l.whl", hash = "sha256:5c92edd15cd58b3c2d34873597a1e20f13094f59cf88068adb18947df5455b4e", size = 2236389, upload-time = "2025-04-23T18:32:10.242Z" }, + { url = "https://files.pythonhosted.org/packages/42/db/0e950daa7e2230423ab342ae918a794964b053bec24ba8af013fc7c94846/pydantic_core-2.33.2-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:65132b7b4a1c0beded5e057324b7e16e10910c106d43675d9bd87d4f38dde162", size = 2239223, upload-time = "2025-04-23T18:32:12.382Z" }, + { url = "https://files.pythonhosted.org/packages/58/4d/4f937099c545a8a17eb52cb67fe0447fd9a373b348ccfa9a87f141eeb00f/pydantic_core-2.33.2-cp313-cp313-win32.whl", hash = "sha256:52fb90784e0a242bb96ec53f42196a17278855b0f31ac7c3cc6f5c1ec4811849", size = 1900473, upload-time = "2025-04-23T18:32:14.034Z" }, + { url = "https://files.pythonhosted.org/packages/a0/75/4a0a9bac998d78d889def5e4ef2b065acba8cae8c93696906c3a91f310ca/pydantic_core-2.33.2-cp313-cp313-win_amd64.whl", hash = "sha256:c083a3bdd5a93dfe480f1125926afcdbf2917ae714bdb80b36d34318b2bec5d9", size = 1955269, upload-time = "2025-04-23T18:32:15.783Z" }, + { url = "https://files.pythonhosted.org/packages/f9/86/1beda0576969592f1497b4ce8e7bc8cbdf614c352426271b1b10d5f0aa64/pydantic_core-2.33.2-cp313-cp313-win_arm64.whl", hash = "sha256:e80b087132752f6b3d714f041ccf74403799d3b23a72722ea2e6ba2e892555b9", size = 1893921, upload-time = "2025-04-23T18:32:18.473Z" }, + { url = "https://files.pythonhosted.org/packages/a4/7d/e09391c2eebeab681df2b74bfe6c43422fffede8dc74187b2b0bf6fd7571/pydantic_core-2.33.2-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:61c18fba8e5e9db3ab908620af374db0ac1baa69f0f32df4f61ae23f15e586ac", size = 1806162, upload-time = "2025-04-23T18:32:20.188Z" }, + { url = "https://files.pythonhosted.org/packages/f1/3d/847b6b1fed9f8ed3bb95a9ad04fbd0b212e832d4f0f50ff4d9ee5a9f15cf/pydantic_core-2.33.2-cp313-cp313t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:95237e53bb015f67b63c91af7518a62a8660376a6a0db19b89acc77a4d6199f5", size = 1981560, upload-time = "2025-04-23T18:32:22.354Z" }, + { url = "https://files.pythonhosted.org/packages/6f/9a/e73262f6c6656262b5fdd723ad90f518f579b7bc8622e43a942eec53c938/pydantic_core-2.33.2-cp313-cp313t-win_amd64.whl", hash = "sha256:c2fc0a768ef76c15ab9238afa6da7f69895bb5d1ee83aeea2e3509af4472d0b9", size = 1935777, upload-time = "2025-04-23T18:32:25.088Z" }, +] + +[[package]] +name = "requests" +version = "2.32.5" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "certifi" }, + { name = "charset-normalizer" }, + { name = "idna" }, + { name = "urllib3" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/c9/74/b3ff8e6c8446842c3f5c837e9c3dfcfe2018ea6ecef224c710c85ef728f4/requests-2.32.5.tar.gz", hash = "sha256:dbba0bac56e100853db0ea71b82b4dfd5fe2bf6d3754a8893c3af500cec7d7cf", size = 134517, upload-time = "2025-08-18T20:46:02.573Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/1e/db/4254e3eabe8020b458f1a747140d32277ec7a271daf1d235b70dc0b4e6e3/requests-2.32.5-py3-none-any.whl", hash = "sha256:2462f94637a34fd532264295e186976db0f5d453d1cdd31473c85a6a161affb6", size = 64738, upload-time = "2025-08-18T20:46:00.542Z" }, +] + +[[package]] +name = "rsa" +version = "4.9.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "pyasn1" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/da/8a/22b7beea3ee0d44b1916c0c1cb0ee3af23b700b6da9f04991899d0c555d4/rsa-4.9.1.tar.gz", hash = "sha256:e7bdbfdb5497da4c07dfd35530e1a902659db6ff241e39d9953cad06ebd0ae75", size = 29034, upload-time = "2025-04-16T09:51:18.218Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/64/8d/0133e4eb4beed9e425d9a98ed6e081a55d195481b7632472be1af08d2f6b/rsa-4.9.1-py3-none-any.whl", hash = "sha256:68635866661c6836b8d39430f97a996acbd61bfa49406748ea243539fe239762", size = 34696, upload-time = "2025-04-16T09:51:17.142Z" }, +] + +[[package]] +name = "sniffio" +version = "1.3.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/a2/87/a6771e1546d97e7e041b6ae58d80074f81b7d5121207425c964ddf5cfdbd/sniffio-1.3.1.tar.gz", hash = "sha256:f4324edc670a0f49750a81b895f35c3adb843cca46f0530f79fc1babb23789dc", size = 20372, upload-time = "2024-02-25T23:20:04.057Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/e9/44/75a9c9421471a6c4805dbf2356f7c181a29c1879239abab1ea2cc8f38b40/sniffio-1.3.1-py3-none-any.whl", hash = "sha256:2f6da418d1f1e0fddd844478f41680e794e6051915791a034ff65e5f100525a2", size = 10235, upload-time = "2024-02-25T23:20:01.196Z" }, +] + +[[package]] +name = "sse-starlette" +version = "3.0.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "anyio" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/42/6f/22ed6e33f8a9e76ca0a412405f31abb844b779d52c5f96660766edcd737c/sse_starlette-3.0.2.tar.gz", hash = "sha256:ccd60b5765ebb3584d0de2d7a6e4f745672581de4f5005ab31c3a25d10b52b3a", size = 20985, upload-time = "2025-07-27T09:07:44.565Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/ef/10/c78f463b4ef22eef8491f218f692be838282cd65480f6e423d7730dfd1fb/sse_starlette-3.0.2-py3-none-any.whl", hash = "sha256:16b7cbfddbcd4eaca11f7b586f3b8a080f1afe952c15813455b162edea619e5a", size = 11297, upload-time = "2025-07-27T09:07:43.268Z" }, +] + +[[package]] +name = "starlette" +version = "0.47.3" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "anyio" }, + { name = "typing-extensions", marker = "python_full_version < '3.13'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/15/b9/cc3017f9a9c9b6e27c5106cc10cc7904653c3eec0729793aec10479dd669/starlette-0.47.3.tar.gz", hash = "sha256:6bc94f839cc176c4858894f1f8908f0ab79dfec1a6b8402f6da9be26ebea52e9", size = 2584144, upload-time = "2025-08-24T13:36:42.122Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/ce/fd/901cfa59aaa5b30a99e16876f11abe38b59a1a2c51ffb3d7142bb6089069/starlette-0.47.3-py3-none-any.whl", hash = "sha256:89c0778ca62a76b826101e7c709e70680a1699ca7da6b44d38eb0a7e61fe4b51", size = 72991, upload-time = "2025-08-24T13:36:40.887Z" }, +] + +[[package]] +name = "test-python-a2a-server" +version = "0.1.0" +source = { virtual = "." } +dependencies = [ + { name = "a2a-sdk", extra = ["http-server"] }, + { name = "uvicorn" }, +] + +[package.metadata] +requires-dist = [ + { name = "a2a-sdk", extras = ["http-server"], specifier = "==0.3.5" }, + { name = "uvicorn", specifier = "==0.35.0" }, +] + +[[package]] +name = "typing-extensions" +version = "4.15.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/72/94/1a15dd82efb362ac84269196e94cf00f187f7ed21c242792a923cdb1c61f/typing_extensions-4.15.0.tar.gz", hash = "sha256:0cea48d173cc12fa28ecabc3b837ea3cf6f38c6d1136f85cbaaf598984861466", size = 109391, upload-time = "2025-08-25T13:49:26.313Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/18/67/36e9267722cc04a6b9f15c7f3441c2363321a3ea07da7ae0c0707beb2a9c/typing_extensions-4.15.0-py3-none-any.whl", hash = "sha256:f0fa19c6845758ab08074a0cfa8b7aecb71c999ca73d62883bc25cc018c4e548", size = 44614, upload-time = "2025-08-25T13:49:24.86Z" }, +] + +[[package]] +name = "typing-inspection" +version = "0.4.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/f8/b1/0c11f5058406b3af7609f121aaa6b609744687f1d158b3c3a5bf4cc94238/typing_inspection-0.4.1.tar.gz", hash = "sha256:6ae134cc0203c33377d43188d4064e9b357dba58cff3185f22924610e70a9d28", size = 75726, upload-time = "2025-05-21T18:55:23.885Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/17/69/cd203477f944c353c31bade965f880aa1061fd6bf05ded0726ca845b6ff7/typing_inspection-0.4.1-py3-none-any.whl", hash = "sha256:389055682238f53b04f7badcb49b989835495a96700ced5dab2d8feae4b26f51", size = 14552, upload-time = "2025-05-21T18:55:22.152Z" }, +] + +[[package]] +name = "urllib3" +version = "2.5.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/15/22/9ee70a2574a4f4599c47dd506532914ce044817c7752a79b6a51286319bc/urllib3-2.5.0.tar.gz", hash = "sha256:3fc47733c7e419d4bc3f6b3dc2b4f890bb743906a30d56ba4a5bfa4bbff92760", size = 393185, upload-time = "2025-06-18T14:07:41.644Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/a7/c2/fe1e52489ae3122415c51f387e221dd0773709bad6c6cdaa599e8a2c5185/urllib3-2.5.0-py3-none-any.whl", hash = "sha256:e6b01673c0fa6a13e374b50871808eb3bf7046c4b125b216f6bf1cc604cff0dc", size = 129795, upload-time = "2025-06-18T14:07:40.39Z" }, +] + +[[package]] +name = "uvicorn" +version = "0.35.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "click" }, + { name = "h11" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/5e/42/e0e305207bb88c6b8d3061399c6a961ffe5fbb7e2aa63c9234df7259e9cd/uvicorn-0.35.0.tar.gz", hash = "sha256:bc662f087f7cf2ce11a1d7fd70b90c9f98ef2e2831556dd078d131b96cc94a01", size = 78473, upload-time = "2025-06-28T16:15:46.058Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/d2/e2/dc81b1bd1dcfe91735810265e9d26bc8ec5da45b4c0f6237e286819194c3/uvicorn-0.35.0-py3-none-any.whl", hash = "sha256:197535216b25ff9b785e29a0b79199f55222193d47f820816e7da751e9bc8d4a", size = 66406, upload-time = "2025-06-28T16:15:44.816Z" }, +] diff --git a/gradle/libs.versions.toml b/gradle/libs.versions.toml index f36c121977..59517d9ebf 100644 --- a/gradle/libs.versions.toml +++ b/gradle/libs.versions.toml @@ -36,6 +36,7 @@ spring-management = "1.1.7" sqlite = "3.46.1.3" testcontainers = "1.19.7" mokksy = "0.5.0-Alpha3" +kotest = "6.0.3" [libraries] jetbrains-annotations = { module = "org.jetbrains:annotations", version.ref = "annotations" } @@ -82,7 +83,9 @@ mcp-client = { module = "io.modelcontextprotocol:kotlin-sdk-client", version.ref mcp-server = { module = "io.modelcontextprotocol:kotlin-sdk-server", version.ref = "mcp" } slf4j-simple = { module = "org.slf4j:slf4j-simple", version.ref = "slf4j" } jetsign-gradle-plugin = { module = "com.jetbrains:jet-sign", version.ref = "jetsign" } +kotest-assertions = { module = "io.kotest:kotest-assertions-core", version.ref = "kotest" } testcontainers = { module = "org.testcontainers:testcontainers", version.ref = "testcontainers" } +testcontainers-junit = { module = "org.testcontainers:junit-jupiter", version.ref = "testcontainers" } testcontainers-postgresql = { module = "org.testcontainers:postgresql", version.ref = "testcontainers" } testcontainers-mysql = { module = "org.testcontainers:mysql", version.ref = "testcontainers" } exposed-core = { module = "org.jetbrains.exposed:exposed-core", version.ref = "exposed" } From dca4e11c080beceea189291369dd1fbd7d57ef07 Mon Sep 17 00:00:00 2001 From: Andrey Bragin Date: Fri, 12 Sep 2025 20:34:59 +0200 Subject: [PATCH 25/52] [a2a] Improve agent card loading and resolution --- .../kotlin/ai/koog/a2a/client/A2AClient.kt | 41 +++++++++++-------- .../ai/koog/a2a/client/AgentCardResolver.kt | 17 +++++++- .../a2a/client/A2AClientIntegrationTest.kt | 2 +- 3 files changed, 40 insertions(+), 20 deletions(-) diff --git a/a2a/a2a-client/src/commonMain/kotlin/ai/koog/a2a/client/A2AClient.kt b/a2a/a2a-client/src/commonMain/kotlin/ai/koog/a2a/client/A2AClient.kt index ae6a0decf4..1c320e7eb4 100644 --- a/a2a/a2a-client/src/commonMain/kotlin/ai/koog/a2a/client/A2AClient.kt +++ b/a2a/a2a-client/src/commonMain/kotlin/ai/koog/a2a/client/A2AClient.kt @@ -24,32 +24,39 @@ public open class A2AClient( private val transport: ClientTransport, private val agentCardResolver: AgentCardResolver, ) { + @Volatile + protected var agentCard: AgentCard? = null + /** - * Currently cached version of the agent card. - * Shouldn't be used directly to read values from it, since it can be updated by the [loadAgentCard] method. - * Always use [getAgentCard] instead. + * Performs initialization logic. + * Currently only retrieves the [AgentCard]. */ - @Suppress("PropertyName") - @Volatile - protected open lateinit var _agentCard: AgentCard + public open suspend fun connect() { + getAgentCard() + } /** - * Resolve agent card from the provided [agentCardResolver] and cache it. - * Can be called multiple times, to update cached version of the agent card. + * Retrieves [AgentCard] by calling [AgentCardResolver.resolve]. + * Saves it to the cache. */ - public open suspend fun loadAgentCard(): AgentCard { + public open suspend fun getAgentCard(): AgentCard { return agentCardResolver.resolve().also { - _agentCard = it + agentCard = it } } /** - * Get current cached version of agent card. + * Retrieves currently cached [AgentCard] + * + * @throws [IllegalStateException] if it's not initialized */ - public open fun getAgentCard(): AgentCard = _agentCard + public open fun cachedAgentCard(): AgentCard { + return checkNotNull(agentCard) { "Agent card is not initialized." } + } /** - * Calls [agent/getAuthenticatedExtendedCard](https://a2a-protocol.org/latest/specification/#710-agentgetauthenticatedextendedcard) + * Calls [agent/getAuthenticatedExtendedCard](https://a2a-protocol.org/latest/specification/#710-agentgetauthenticatedextendedcard). + * Updates cached [AgentCard]. * * @throws A2AException if server returned an error. */ @@ -62,7 +69,7 @@ public open class A2AClient( } return transport.getAuthenticatedExtendedAgentCard(request, ctx).also { - _agentCard = it.data + agentCard = it.data } } @@ -87,7 +94,7 @@ public open class A2AClient( request: Request, ctx: ClientCallContext = ClientCallContext.Default ): Flow> { - check(getAgentCard().capabilities.streaming == true) { + check(cachedAgentCard().capabilities.streaming == true) { "Agent card reports that streaming is not supported." } @@ -186,8 +193,8 @@ public open class A2AClient( return transport.deleteTaskPushNotificationConfig(request, ctx) } - private fun checkPushNotificationsSupported() { - check(getAgentCard().capabilities.pushNotifications == true) { + protected fun checkPushNotificationsSupported() { + check(cachedAgentCard().capabilities.pushNotifications == true) { "Agent card reports that push notifications are not supported." } } diff --git a/a2a/a2a-client/src/commonMain/kotlin/ai/koog/a2a/client/AgentCardResolver.kt b/a2a/a2a-client/src/commonMain/kotlin/ai/koog/a2a/client/AgentCardResolver.kt index c0ced94b8f..7eddaf49cd 100644 --- a/a2a/a2a-client/src/commonMain/kotlin/ai/koog/a2a/client/AgentCardResolver.kt +++ b/a2a/a2a-client/src/commonMain/kotlin/ai/koog/a2a/client/AgentCardResolver.kt @@ -1,5 +1,3 @@ -@file:Suppress("MissingKDocForPublicAPI") - package ai.koog.a2a.client import ai.koog.a2a.model.AgentCard @@ -13,14 +11,29 @@ import io.ktor.http.contentType import io.ktor.serialization.kotlinx.json.json import kotlinx.serialization.json.Json +/** + * Represents a resolver capable of fetching an [AgentCard]. + * + * Implementations of this interface are responsible for providing the mechanism to retrieve + * the [AgentCard], which may include network requests, local lookups, or other means of resolution. + */ public interface AgentCardResolver { + /** + * Resolves and retrieves an [AgentCard]. + */ public suspend fun resolve(): AgentCard } +/** + * An [AgentCardResolver] that always returns the provided [agentCard]. + */ public class ExplicitAgentCardResolver(public val agentCard: AgentCard) : AgentCardResolver { override suspend fun resolve(): AgentCard = agentCard } +/** + * An [AgentCardResolver] that fetches the [AgentCard] from the provided [baseUrl] at [path]. + */ public class UrlAgentCardResolver( public val baseUrl: String, public val path: String = wellKnownPath, diff --git a/a2a/a2a-client/src/jvmTest/kotlin/ai/koog/a2a/client/A2AClientIntegrationTest.kt b/a2a/a2a-client/src/jvmTest/kotlin/ai/koog/a2a/client/A2AClientIntegrationTest.kt index bf4cd468fa..fe37a3aa2a 100644 --- a/a2a/a2a-client/src/jvmTest/kotlin/ai/koog/a2a/client/A2AClientIntegrationTest.kt +++ b/a2a/a2a-client/src/jvmTest/kotlin/ai/koog/a2a/client/A2AClientIntegrationTest.kt @@ -80,7 +80,7 @@ class A2AClientIntegrationTest { @BeforeTest fun initClient() = runTest { - client.loadAgentCard() + client.connect() } @Test From f1a8c52c28d9d60eb66dfc7cc5b2e99d94d7ee26 Mon Sep 17 00:00:00 2001 From: Andrey Bragin Date: Sat, 13 Sep 2025 15:41:21 +0200 Subject: [PATCH 26/52] [a2a] Implement A2AServer core with session management and storage --- .../kotlin/ai/koog/a2a/client/A2AClient.kt | 6 +- .../ai/koog/a2a/client/AgentCardResolver.kt | 8 +- .../kotlin/ai/koog/a2a/consts/A2AConsts.kt | 13 + .../kotlin/ai/koog/a2a/model/Core.kt | 24 +- .../kotlin/ai/koog/a2a/model/Serialization.kt | 24 +- .../kotlin/ai/koog/a2a/model/Task.kt | 6 +- .../kotlin/ai/koog/a2a/model/TaskEvents.kt | 12 +- .../ai/koog/a2a/transport/ClientTransport.kt | 6 +- .../ai/koog/a2a/transport/ServerTransport.kt | 53 +++- .../kotlin/ai/koog/a2a/utils/RWLock.kt | 4 +- .../kotlin/ai/koog/a2a/utils/ResultUtils.kt | 2 + a2a/a2a-server/build.gradle.kts | 1 + .../kotlin/ai/koog/a2a/server/A2AServer.kt | 217 ++++++++++++++- .../ai/koog/a2a/server/agent/AgentExecutor.kt | 60 +++++ .../koog/a2a/server/exceptions/Exceptions.kt | 18 ++ .../server/messages/InMemoryMessageStorage.kt | 47 ++++ .../a2a/server/messages/MessageStorage.kt | 101 +++++++ .../koog/a2a/server/session/RequestContext.kt | 23 ++ .../server/session/SessionEventProcessor.kt | 212 +++++++++++++++ .../koog/a2a/server/session/SessionManager.kt | 94 +++++++ .../a2a/server/tasks/InMemoryTaskStorage.kt | 148 ++++++++++ .../ai/koog/a2a/server/tasks/TaskStorage.kt | 154 +++++++++++ .../messages/InMemoryMessageStorageTest.kt | 99 +++++++ .../server/tasks/InMemoryTaskStorageTest.kt | 253 ++++++++++++++++++ .../HttpJSONRPCClientTransportMokksyTest.kt | 155 ----------- .../jsonrpc/JSONRPCClientTransport.kt | 6 +- .../jsonrpc/JSONRPCServerTransport.kt | 2 + .../http/HttpJSONRPCServerTransport.kt | 33 ++- .../http/HttpJSONRPCServerTransportTest.kt | 6 +- .../src/agent_executor.py | 11 +- 30 files changed, 1582 insertions(+), 216 deletions(-) create mode 100644 a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/consts/A2AConsts.kt create mode 100644 a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/agent/AgentExecutor.kt create mode 100644 a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/exceptions/Exceptions.kt create mode 100644 a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/messages/InMemoryMessageStorage.kt create mode 100644 a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/messages/MessageStorage.kt create mode 100644 a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/session/RequestContext.kt create mode 100644 a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/session/SessionEventProcessor.kt create mode 100644 a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/session/SessionManager.kt create mode 100644 a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/tasks/InMemoryTaskStorage.kt create mode 100644 a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/tasks/TaskStorage.kt create mode 100644 a2a/a2a-server/src/commonTest/kotlin/ai/koog/a2a/server/messages/InMemoryMessageStorageTest.kt create mode 100644 a2a/a2a-server/src/commonTest/kotlin/ai/koog/a2a/server/tasks/InMemoryTaskStorageTest.kt delete mode 100644 a2a/a2a-transport/a2a-transport-client-jsonrpc-http/src/jvmTest/kotlin/ai/koog/a2a/transport/client/jsonrpc/http/HttpJSONRPCClientTransportMokksyTest.kt diff --git a/a2a/a2a-client/src/commonMain/kotlin/ai/koog/a2a/client/A2AClient.kt b/a2a/a2a-client/src/commonMain/kotlin/ai/koog/a2a/client/A2AClient.kt index 1c320e7eb4..ce991559b8 100644 --- a/a2a/a2a-client/src/commonMain/kotlin/ai/koog/a2a/client/A2AClient.kt +++ b/a2a/a2a-client/src/commonMain/kotlin/ai/koog/a2a/client/A2AClient.kt @@ -3,13 +3,13 @@ package ai.koog.a2a.client import ai.koog.a2a.exceptions.A2AException import ai.koog.a2a.model.AgentCard import ai.koog.a2a.model.CommunicationEvent +import ai.koog.a2a.model.Event import ai.koog.a2a.model.MessageSendParams import ai.koog.a2a.model.Task import ai.koog.a2a.model.TaskIdParams import ai.koog.a2a.model.TaskPushNotificationConfig import ai.koog.a2a.model.TaskPushNotificationConfigParams import ai.koog.a2a.model.TaskQueryParams -import ai.koog.a2a.model.UpdateEvent import ai.koog.a2a.transport.ClientCallContext import ai.koog.a2a.transport.ClientTransport import ai.koog.a2a.transport.Request @@ -93,7 +93,7 @@ public open class A2AClient( public fun sendMessageStreaming( request: Request, ctx: ClientCallContext = ClientCallContext.Default - ): Flow> { + ): Flow> { check(cachedAgentCard().capabilities.streaming == true) { "Agent card reports that streaming is not supported." } @@ -133,7 +133,7 @@ public open class A2AClient( public fun resubscribeTask( request: Request, ctx: ClientCallContext = ClientCallContext.Default - ): Flow> { + ): Flow> { return transport.resubscribeTask(request, ctx) } diff --git a/a2a/a2a-client/src/commonMain/kotlin/ai/koog/a2a/client/AgentCardResolver.kt b/a2a/a2a-client/src/commonMain/kotlin/ai/koog/a2a/client/AgentCardResolver.kt index 7eddaf49cd..7a90b89a5a 100644 --- a/a2a/a2a-client/src/commonMain/kotlin/ai/koog/a2a/client/AgentCardResolver.kt +++ b/a2a/a2a-client/src/commonMain/kotlin/ai/koog/a2a/client/AgentCardResolver.kt @@ -1,5 +1,6 @@ package ai.koog.a2a.client +import ai.koog.a2a.consts.A2AConsts import ai.koog.a2a.model.AgentCard import io.ktor.client.HttpClient import io.ktor.client.call.body @@ -36,14 +37,9 @@ public class ExplicitAgentCardResolver(public val agentCard: AgentCard) : AgentC */ public class UrlAgentCardResolver( public val baseUrl: String, - public val path: String = wellKnownPath, + public val path: String = A2AConsts.AGENT_CARD_WELL_KNOWN_PATH, baseHttpClient: HttpClient = HttpClient(), ) : AgentCardResolver { - public companion object { - @Suppress("ConstPropertyName") - public const val wellKnownPath: String = "/.well-known/agent-card.json" - } - private val httpClient: HttpClient = baseHttpClient.config { defaultRequest { url(baseUrl) diff --git a/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/consts/A2AConsts.kt b/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/consts/A2AConsts.kt new file mode 100644 index 0000000000..4863d3b801 --- /dev/null +++ b/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/consts/A2AConsts.kt @@ -0,0 +1,13 @@ +package ai.koog.a2a.consts + +/** + * Some common global A2A constants. + */ +public object A2AConsts { + /** + * The well-known agent card URL following RFC 8615. + * + * More info: https://a2a-protocol.org/latest/specification/#53-recommended-location + */ + public const val AGENT_CARD_WELL_KNOWN_PATH: String = "/.well-known/agent-card.json" +} diff --git a/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/model/Core.kt b/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/model/Core.kt index 1f06cf4d07..ab3de34149 100644 --- a/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/model/Core.kt +++ b/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/model/Core.kt @@ -5,8 +5,8 @@ import kotlinx.serialization.Serializable /** * Base interface for events. */ -@Serializable(with = UpdateEventSerializer::class) -public sealed interface UpdateEvent { +@Serializable(with = EventSerializer::class) +public sealed interface Event { /** * The type used as discriminator. */ @@ -14,7 +14,23 @@ public sealed interface UpdateEvent { } /** - * Base interface for communication units, such as messages or tasks. + * Base interface for communication events, such as messages or tasks. */ @Serializable(with = CommunicationEventSerializer::class) -public sealed interface CommunicationEvent : UpdateEvent +public sealed interface CommunicationEvent : Event + +/** + * Base interface for task events. + */ +@Serializable(with = TaskEventSerializer::class) +public sealed interface TaskEvent : Event { + /** + * The ID of the task associated with this event. + */ + public val taskId: String + + /** + * The ID of the context associated with this event. + */ + public val contextId: String +} diff --git a/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/model/Serialization.kt b/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/model/Serialization.kt index a3f69b6346..bdf6bde1a4 100644 --- a/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/model/Serialization.kt +++ b/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/model/Serialization.kt @@ -48,15 +48,17 @@ internal object FileSerializer : JsonContentPolymorphicSerializer(File::cl } } -internal object UpdateEventSerializer : JsonContentPolymorphicSerializer(UpdateEvent::class) { - override fun selectDeserializer(element: JsonElement): DeserializationStrategy { +internal object EventSerializer : JsonContentPolymorphicSerializer(Event::class) { + override fun selectDeserializer(element: JsonElement): DeserializationStrategy { val jsonObject = element.jsonObject val kind = jsonObject["kind"]?.jsonPrimitive?.content ?: error("Missing 'kind' field in Event") return when (kind) { "status-update" -> TaskStatusUpdateEvent.serializer() "artifact-update" -> TaskArtifactUpdateEvent.serializer() - else -> CommunicationEventSerializer + "task" -> Task.serializer() + "message" -> Message.serializer() + else -> error("Unknown kind: $kind") } } } @@ -64,7 +66,7 @@ internal object UpdateEventSerializer : JsonContentPolymorphicSerializer(CommunicationEvent::class) { override fun selectDeserializer(element: JsonElement): DeserializationStrategy { val jsonObject = element.jsonObject - val kind = jsonObject["kind"]?.jsonPrimitive?.content ?: error("Missing 'kind' field in Communication") + val kind = jsonObject["kind"]?.jsonPrimitive?.content ?: error("Missing 'kind' field in CommunicationEvent") return when (kind) { "task" -> Task.serializer() @@ -73,3 +75,17 @@ internal object CommunicationEventSerializer : JsonContentPolymorphicSerializer< } } } + +internal object TaskEventSerializer : JsonContentPolymorphicSerializer(TaskEvent::class) { + override fun selectDeserializer(element: JsonElement): DeserializationStrategy { + val jsonObject = element.jsonObject + val kind = jsonObject["kind"]?.jsonPrimitive?.content ?: error("Missing 'kind' field in TaskEvent") + + return when (kind) { + "task" -> Task.serializer() + "status-update" -> TaskStatusUpdateEvent.serializer() + "artifact-update" -> TaskArtifactUpdateEvent.serializer() + else -> error("Unknown kind: $kind") + } + } +} diff --git a/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/model/Task.kt b/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/model/Task.kt index 3c57061f05..1bde51d937 100644 --- a/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/model/Task.kt +++ b/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/model/Task.kt @@ -23,14 +23,16 @@ import kotlin.uuid.Uuid public data class Task( @EncodeDefault public val id: String = Uuid.random().toString(), - public val contextId: String, + override val contextId: String, public val status: TaskStatus, public val history: List? = null, public val artifacts: List? = null, public val metadata: JsonObject? = null, -) : CommunicationEvent { +) : CommunicationEvent, TaskEvent { @EncodeDefault override val kind: String = "task" + + override val taskId: String get() = id } /** diff --git a/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/model/TaskEvents.kt b/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/model/TaskEvents.kt index ef28cc7533..2187b90145 100644 --- a/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/model/TaskEvents.kt +++ b/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/model/TaskEvents.kt @@ -16,12 +16,12 @@ import kotlinx.serialization.json.JsonObject */ @Serializable public data class TaskStatusUpdateEvent( - public val taskId: String, - public val contextId: String, + override val taskId: String, + override val contextId: String, public val status: TaskStatus, public val final: Boolean, public val metadata: JsonObject? = null, -) : UpdateEvent { +) : TaskEvent { @EncodeDefault override val kind: String = "status-update" } @@ -39,13 +39,13 @@ public data class TaskStatusUpdateEvent( */ @Serializable public data class TaskArtifactUpdateEvent( - public val taskId: String, - public val contextId: String, + override val taskId: String, + override val contextId: String, public val artifact: Artifact, public val append: Boolean? = null, public val lastChunk: Boolean? = null, public val metadata: JsonObject? = null, -) : UpdateEvent { +) : TaskEvent { @EncodeDefault override val kind: String = "artifact-update" } diff --git a/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/transport/ClientTransport.kt b/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/transport/ClientTransport.kt index ff548e7a5e..6d1b5d1a89 100644 --- a/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/transport/ClientTransport.kt +++ b/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/transport/ClientTransport.kt @@ -3,13 +3,13 @@ package ai.koog.a2a.transport import ai.koog.a2a.exceptions.A2AException import ai.koog.a2a.model.AgentCard import ai.koog.a2a.model.CommunicationEvent +import ai.koog.a2a.model.Event import ai.koog.a2a.model.MessageSendParams import ai.koog.a2a.model.Task import ai.koog.a2a.model.TaskIdParams import ai.koog.a2a.model.TaskPushNotificationConfig import ai.koog.a2a.model.TaskPushNotificationConfigParams import ai.koog.a2a.model.TaskQueryParams -import ai.koog.a2a.model.UpdateEvent import kotlinx.coroutines.flow.Flow import kotlinx.serialization.SerializationException @@ -53,7 +53,7 @@ public interface ClientTransport : AutoCloseable { public fun sendMessageStreaming( request: Request, ctx: ClientCallContext = ClientCallContext.Default - ): Flow> + ): Flow> /** * Calls [tasks/get](https://a2a-protocol.org/latest/specification/#73-tasksget) @@ -83,7 +83,7 @@ public interface ClientTransport : AutoCloseable { public fun resubscribeTask( request: Request, ctx: ClientCallContext = ClientCallContext.Default - ): Flow> + ): Flow> /** * Calls [tasks/pushNotificationConfig/set](https://a2a-protocol.org/latest/specification/#75-taskspushnotificationconfigset) diff --git a/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/transport/ServerTransport.kt b/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/transport/ServerTransport.kt index fb09ac6a61..4de35fa21a 100644 --- a/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/transport/ServerTransport.kt +++ b/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/transport/ServerTransport.kt @@ -4,13 +4,13 @@ import ai.koog.a2a.exceptions.A2AException import ai.koog.a2a.exceptions.A2AInternalErrorException import ai.koog.a2a.model.AgentCard import ai.koog.a2a.model.CommunicationEvent +import ai.koog.a2a.model.Event import ai.koog.a2a.model.MessageSendParams import ai.koog.a2a.model.Task import ai.koog.a2a.model.TaskIdParams import ai.koog.a2a.model.TaskPushNotificationConfig import ai.koog.a2a.model.TaskPushNotificationConfigParams import ai.koog.a2a.model.TaskQueryParams -import ai.koog.a2a.model.UpdateEvent import kotlinx.coroutines.flow.Flow /** @@ -64,7 +64,7 @@ public interface RequestHandler { public fun onSendMessageStreaming( request: Request, ctx: ServerCallContext - ): Flow> + ): Flow> /** * Handles [tasks/get](https://a2a-protocol.org/latest/specification/#73-tasksget) @@ -84,7 +84,7 @@ public interface RequestHandler { public fun onResubscribeTask( request: Request, ctx: ServerCallContext - ): Flow> + ): Flow> /** * Handles [tasks/cancel](https://a2a-protocol.org/latest/specification/#74-taskscancel) @@ -141,12 +141,53 @@ public interface RequestHandler { * Represents the server context of a call. * * @property headers Headers associated with the call. + * @property state State associated with the call, allows storing arbitrary values. To get typed value from the state, + * use [getFromState] or [getFromStateOrNull] with appropriate [StateKey]. */ public class ServerCallContext( public val headers: Map> = emptyMap(), + public val state: Map, Any> = emptyMap() ) { - @Suppress("MissingKDocForPublicAPI") - public companion object { - public val Default: ServerCallContext = ServerCallContext() + /** + * Retrieves a value of type [T] associated with the specified [key] from the [state] map. + * If the [key] is not found in the state, returns `null`. + * + * Performs unsafe cast under the hood, so make sure the value is of the expected type. + * + * @param key The state key for which the associated value needs to be retrieved. + */ + public fun getFromStateOrNull(key: StateKey): T? { + return state[key]?.let { + @Suppress("UNCHECKED_CAST") + it as T + } } + + /** + * Retrieves a value of type [T] associated with the specified [key] from the [state] map. + * + * Performs unsafe cast under the hood, so make sure the value is of the expected type. + * + * @param key The state key for which the associated value needs to be retrieved. + * @throws NoSuchElementException if the [key] is not found in the state. + */ + public fun getFromState(key: StateKey): T { + return getFromStateOrNull(key) ?: throw NoSuchElementException("State key $key not found") + } + + /** + * Creates a copy of this [ServerCallContext]. + */ + public fun copy( + headers: Map> = this.headers, + state: Map, Any> = this.state, + ): ServerCallContext = ServerCallContext(headers, state) +} + +/** + * Helper class to be used with [ServerCallContext.state] to store and retrieve values associated with a key in a typed + * manner. + */ +public class StateKey<@Suppress("unused") T>(public val name: String) { + override fun toString(): String = "${super.toString()}(name=$name)" } diff --git a/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/utils/RWLock.kt b/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/utils/RWLock.kt index 25c3813ce8..6705868ffa 100644 --- a/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/utils/RWLock.kt +++ b/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/utils/RWLock.kt @@ -1,5 +1,6 @@ package ai.koog.a2a.utils +import ai.koog.a2a.annotations.InternalA2AApi import kotlinx.coroutines.sync.Mutex import kotlinx.coroutines.sync.withLock @@ -8,8 +9,9 @@ import kotlinx.coroutines.sync.withLock /** * A KMP read-write lock implementation that allows concurrent read access but ensures exclusive write access. * - * This implementation uses `kotlinx.coroutines.sync.Mutex` to coordinate access for both readers and writers. + * This implementation uses [Mutex] to coordinate access for both readers and writers. */ +@InternalA2AApi public class RWLock { private val writeMutex = Mutex() private var readersCount = 0 diff --git a/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/utils/ResultUtils.kt b/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/utils/ResultUtils.kt index cdf5d21dd7..ff73526be5 100644 --- a/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/utils/ResultUtils.kt +++ b/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/utils/ResultUtils.kt @@ -1,5 +1,6 @@ package ai.koog.a2a.utils +import ai.koog.a2a.annotations.InternalA2AApi import kotlinx.coroutines.CancellationException // FIXME copied from agents-core module, because a2a does not depend on other Koog modules. @@ -7,6 +8,7 @@ import kotlinx.coroutines.CancellationException /** * Same as [runCatching], but does not catch [CancellationException], throwing it instead, making it safe to use with coroutines. */ +@InternalA2AApi public inline fun runCatchingCancellable(block: () -> R): Result { return try { Result.success(block()) diff --git a/a2a/a2a-server/build.gradle.kts b/a2a/a2a-server/build.gradle.kts index ca41b397e3..29059c9caf 100644 --- a/a2a/a2a-server/build.gradle.kts +++ b/a2a/a2a-server/build.gradle.kts @@ -27,6 +27,7 @@ kotlin { commonTest { dependencies { implementation(kotlin("test")) + implementation(libs.kotlinx.coroutines.test) } } diff --git a/a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/A2AServer.kt b/a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/A2AServer.kt index 4c901c109d..47a83e6921 100644 --- a/a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/A2AServer.kt +++ b/a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/A2AServer.kt @@ -1,64 +1,261 @@ package ai.koog.a2a.server +import ai.koog.a2a.exceptions.A2AInternalErrorException +import ai.koog.a2a.exceptions.A2AInvalidParamsException +import ai.koog.a2a.exceptions.A2ATaskNotFoundException +import ai.koog.a2a.exceptions.A2AUnsupportedOperationException import ai.koog.a2a.model.AgentCard import ai.koog.a2a.model.CommunicationEvent +import ai.koog.a2a.model.Event +import ai.koog.a2a.model.Message import ai.koog.a2a.model.MessageSendParams import ai.koog.a2a.model.Task +import ai.koog.a2a.model.TaskEvent import ai.koog.a2a.model.TaskIdParams import ai.koog.a2a.model.TaskPushNotificationConfig import ai.koog.a2a.model.TaskPushNotificationConfigParams import ai.koog.a2a.model.TaskQueryParams -import ai.koog.a2a.model.UpdateEvent +import ai.koog.a2a.model.TaskState +import ai.koog.a2a.model.TaskStatus +import ai.koog.a2a.model.TaskStatusUpdateEvent +import ai.koog.a2a.server.agent.AgentExecutor +import ai.koog.a2a.server.messages.ContextMessageStorage +import ai.koog.a2a.server.messages.InMemoryMessageStorage +import ai.koog.a2a.server.messages.MessageStorage +import ai.koog.a2a.server.session.RequestContext +import ai.koog.a2a.server.session.SessionEventProcessor +import ai.koog.a2a.server.session.SessionManager +import ai.koog.a2a.server.tasks.ContextTaskStorage +import ai.koog.a2a.server.tasks.InMemoryTaskStorage +import ai.koog.a2a.server.tasks.TaskStorage import ai.koog.a2a.transport.Request import ai.koog.a2a.transport.RequestHandler import ai.koog.a2a.transport.Response import ai.koog.a2a.transport.ServerCallContext +import kotlinx.coroutines.CoroutineScope +import kotlinx.coroutines.SupervisorJob import kotlinx.coroutines.flow.Flow +import kotlinx.coroutines.flow.channelFlow +import kotlinx.coroutines.flow.emitAll +import kotlinx.coroutines.flow.first +import kotlinx.coroutines.flow.flow +import kotlinx.coroutines.flow.last +import kotlinx.coroutines.flow.map +import kotlinx.coroutines.launch +import kotlinx.datetime.Clock +import kotlin.uuid.ExperimentalUuidApi +import kotlin.uuid.Uuid /** * A2A server responsible for handling requests from A2A clients. */ -public class A2AServer : RequestHandler { +public open class A2AServer( + protected val agentExecutor: AgentExecutor, + protected val agentCard: AgentCard, + protected val agentCardExtended: AgentCard? = null, + protected val taskStorage: TaskStorage = InMemoryTaskStorage(), + protected val messageStorage: MessageStorage = InMemoryMessageStorage(), + protected val coroutineScope: CoroutineScope = CoroutineScope(SupervisorJob()), + protected val clock: Clock = Clock.System, +) : RequestHandler { + protected val sessionManager: SessionManager = SessionManager(coroutineScope) + override suspend fun onGetAuthenticatedExtendedAgentCard( request: Request, ctx: ServerCallContext ): Response { - TODO("Not yet implemented") + // Default server implementation does not provide authorization, return extended card directly if present + return Response( + data = agentCardExtended ?: agentCard, + id = request.id + ) } override suspend fun onSendMessage( request: Request, ctx: ServerCallContext ): Response { - TODO("Not yet implemented") + val messageConfiguration = request.data.configuration + // Reusing streaming logic here, because it's essentially the same, only we need some particular event from the stream + val eventStream = onSendMessageStreaming(request, ctx) + + return if (messageConfiguration?.blocking == true) { + // If blocking is requested, attempt to wait for the last event, until the current turn of the agent execution is finished. + val lastEventResponse = eventStream.last() + + when (val eventData = lastEventResponse.data) { + is Message -> Response(data = eventData, id = lastEventResponse.id) + is TaskEvent -> + taskStorage + .get( + eventData.taskId, + historyLength = messageConfiguration.historyLength, + includeArtifacts = true + ) + ?.let { Response(data = it, id = lastEventResponse.id) } + ?: throw A2ATaskNotFoundException("Task '${eventData.taskId}' not found after the agent execution") + } + } else { + // Else read the first event from the stream, check that it's a proper communication event and return it. + val firstEventResponse = eventStream.first() + + when (val eventData = firstEventResponse.data) { + is Message -> Response(data = eventData, id = firstEventResponse.id) + is Task -> Response(data = eventData, id = firstEventResponse.id) + else -> throw A2AInternalErrorException("Got unexpected event type from the agent '${eventData::class.simpleName}'") + } + } } override fun onSendMessageStreaming( request: Request, ctx: ServerCallContext - ): Flow> { - TODO("Not yet implemented") + ): Flow> = channelFlow { + val message = request.data.message + val taskId = message.taskId + + // Check if message links to a task. + val eventProcessor = if (taskId != null) { + // Check if the task is still in progress, no message can be sent. + if (sessionManager.processorForTask(taskId) != null) { + throw A2AUnsupportedOperationException("Task '$taskId' is still running, can't send messages to the task that has not yielded control") + } + + // Check if the specified task exists and message context id matches the task context id. + val task = taskStorage.get(taskId) ?: throw A2ATaskNotFoundException("Task '$taskId' not found") + if (message.contextId != task.contextId) { + throw A2AInvalidParamsException("Message context id '${message.contextId}' doesn't match task context id '${task.contextId}'") + } + + // Create new event processor for the task. + SessionEventProcessor( + contextId = task.contextId, + taskStorage = taskStorage, + coroutineScope = coroutineScope, + currentTask = task + ) + } else { + // Create new event processor without task specified. + @OptIn(ExperimentalUuidApi::class) + SessionEventProcessor( + contextId = message.contextId ?: Uuid.random().toString(), + taskStorage = taskStorage, + // Use specified context id or generate a new random one. + coroutineScope = coroutineScope, + ) + }.also { + sessionManager.addProcessor(it) + } + + // Create request context based on the request information. + val requestContext = RequestContext( + contextId = eventProcessor.contextId, + callContext = ctx, + params = request.data, + taskStorage = ContextTaskStorage(eventProcessor.contextId, taskStorage), + messageStorage = ContextMessageStorage(eventProcessor.contextId, messageStorage), + ) + + // Subscribe to events stream and start emitting them. + val collectionJob = launch { + eventProcessor.events + .collect { event -> + send(Response(data = event, id = request.id)) + } + } + + // Execute the agent. + agentExecutor.execute(requestContext, eventProcessor) + + // Close event processor session and collecting job + eventProcessor.close() + collectionJob.cancel() } override suspend fun onGetTask( request: Request, ctx: ServerCallContext ): Response { - TODO("Not yet implemented") + val taskParams = request.data + + return Response( + data = taskStorage.get(taskParams.id, historyLength = taskParams.historyLength) + ?: throw A2ATaskNotFoundException("Task '${taskParams.id}' not found"), + id = request.id, + ) } override suspend fun onCancelTask( request: Request, ctx: ServerCallContext ): Response { - TODO("Not yet implemented") + val taskParams = request.data + val eventProcessor = sessionManager.processorForTask(taskParams.id) + + // Task is not running, check if it exists in the storage. + if (eventProcessor == null) { + val task = taskStorage.get(taskParams.id, historyLength = 0, includeArtifacts = true) + ?: throw A2ATaskNotFoundException("Task '${taskParams.id}' not found") + + // Task exists but not running - check if it is already canceled + if (task.status.state == TaskState.Canceled) { + return Response(data = task, id = request.id) + } + + // If the task is not canceled and in the terminal state, throw. + if (task.status.state.terminal) { + throw A2AUnsupportedOperationException("Task '${taskParams.id}' is already in terminal state ${task.status.state}") + } + + // Proceed to mark the task as canceled. + taskStorage.update( + TaskStatusUpdateEvent( + taskId = task.id, + contextId = task.contextId, + status = TaskStatus( + state = TaskState.Canceled, + timestamp = clock.now() + ), + final = true + ) + ) + } else { + // Create request context based on the request information. + val requestContext = RequestContext( + contextId = taskParams.id, + callContext = ctx, + params = request.data, + taskStorage = ContextTaskStorage(eventProcessor.contextId, taskStorage), + messageStorage = ContextMessageStorage(eventProcessor.contextId, messageStorage), + ) + + // Attempt to cancel the agent execution and wait until it's finished. + agentExecutor.cancel(requestContext, eventProcessor) + + // If `cancel` finished without exceptions, assume the cancellation was successful and close event processor session too. + eventProcessor.close() + } + + // Return the final task state. + return Response( + data = taskStorage.get(taskParams.id, historyLength = 0, includeArtifacts = true) + ?: throw A2ATaskNotFoundException("Task '${taskParams.id}' not found"), + id = request.id, + ) } override fun onResubscribeTask( request: Request, ctx: ServerCallContext - ): Flow> { - TODO("Not yet implemented") + ): Flow> = flow { + val taskParams = request.data + val eventProcessor = sessionManager.processorForTask(taskParams.id) + ?: throw A2AUnsupportedOperationException("Task '${taskParams.id}' is not currently running or does not exist") + + emitAll( + eventProcessor.events + .map { event -> Response(data = event, id = request.id) } + ) } override suspend fun onSetTaskPushNotificationConfig( diff --git a/a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/agent/AgentExecutor.kt b/a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/agent/AgentExecutor.kt new file mode 100644 index 0000000000..6b60dbb07e --- /dev/null +++ b/a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/agent/AgentExecutor.kt @@ -0,0 +1,60 @@ +package ai.koog.a2a.server.agent + +import ai.koog.a2a.exceptions.A2AContentTypeNotSupportedException +import ai.koog.a2a.exceptions.A2AUnsupportedOperationException +import ai.koog.a2a.model.Message +import ai.koog.a2a.model.MessageSendParams +import ai.koog.a2a.model.TaskEvent +import ai.koog.a2a.model.TaskIdParams +import ai.koog.a2a.model.TaskState +import ai.koog.a2a.model.TaskStatusUpdateEvent +import ai.koog.a2a.server.session.RequestContext +import ai.koog.a2a.server.session.SessionEventProcessor +import kotlin.jvm.JvmName + +/** + * Implementations of this interface contain the core logic of the agent, + * executing actions based on requests and publishing updates to an event processor. + */ +public interface AgentExecutor { + /** + * Execute the agent's logic for a given request context. + * + * The agent should read necessary information from the [context] and publish [TaskEvent] or [Message] events to + * the [eventProcessor]. This method should return once the agent's execution for this request is complete or + * yields control (e.g., enters an [TaskState.InputRequired] state). + * + * Can throw an exception if the input is invalid or the agent fails to execute the request. + * + * @throws Exception if something goes wrong during execution. Should prefer more specific exceptions when available, + * e.g., [A2AContentTypeNotSupportedException], [A2AUnsupportedOperationException], etc. See full list of available + * A2A exceptions in [ai.koog.a2a.exceptions]. + */ + public suspend fun execute(context: RequestContext, eventProcessor: SessionEventProcessor) + + /** + * Request the agent to cancel an ongoing task. + * + * The agent should attempt to stop the task identified by the task id in the context and publish a [TaskStatusUpdateEvent] with state + * [TaskState.Canceled] to the [eventProcessor]. + * + * Can throw an exception if the agent fails to cancel the task. + * + * @throws Exception if something goes wrong during execution. Should prefer more specific exceptions when available, + * e.g., [A2AContentTypeNotSupportedException], [A2AUnsupportedOperationException], etc. See full list of available + * A2A exceptions in [ai.koog.a2a.exceptions]. + */ + public suspend fun cancel(context: RequestContext, eventProcessor: SessionEventProcessor) +} + +/** + * Returns the task id from the [MessageSendParams] in the [RequestContext]. + */ +@get:JvmName("getMessageTaskId") +public val RequestContext.taskId: String? get() = params.message.taskId + +/** + * Returns the task id from the [TaskIdParams] in the [RequestContext]. + */ +@get:JvmName("getTaskIdParamsTaskId") +public val RequestContext.taskId: String get() = params.id diff --git a/a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/exceptions/Exceptions.kt b/a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/exceptions/Exceptions.kt new file mode 100644 index 0000000000..957ba80034 --- /dev/null +++ b/a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/exceptions/Exceptions.kt @@ -0,0 +1,18 @@ +package ai.koog.a2a.server.exceptions + +import ai.koog.a2a.server.session.SessionEventProcessor + +/** + * Indicates an error with task-related operations. + */ +public class TaskOperationException(message: String, cause: Throwable? = null) : Exception(message, cause) + +/** + * Indicates an error with message-related operations. + */ +public class MessageOperationException(message: String, cause: Throwable? = null) : Exception(message, cause) + +/** + * Indicates a failure in sending an event through the [SessionEventProcessor] because of invalid event. + */ +public class InvalidEventException(message: String, cause: Throwable? = null) : Exception(message, cause) diff --git a/a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/messages/InMemoryMessageStorage.kt b/a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/messages/InMemoryMessageStorage.kt new file mode 100644 index 0000000000..6960c19e15 --- /dev/null +++ b/a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/messages/InMemoryMessageStorage.kt @@ -0,0 +1,47 @@ +package ai.koog.a2a.server.messages + +import ai.koog.a2a.annotations.InternalA2AApi +import ai.koog.a2a.model.Message +import ai.koog.a2a.server.exceptions.MessageOperationException +import ai.koog.a2a.utils.RWLock + +/** + * In-memory implementation of [MessageStorage] using a thread-safe map. + * + * This implementation stores messages in memory grouped by context ID and provides + * concurrency safety through mutex locks. + */ +@OptIn(InternalA2AApi::class) +public class InMemoryMessageStorage : MessageStorage { + private val messagesByContext = mutableMapOf>() + private val rwLock = RWLock() + + override suspend fun save(message: Message): Unit = rwLock.withWriteLock { + val contextId = message.contextId + ?: throw MessageOperationException("Message must have a contextId to be saved") + + messagesByContext.getOrPut(contextId) { mutableListOf() }.add(message) + } + + override suspend fun getByContext(contextId: String): List = rwLock.withReadLock { + messagesByContext[contextId]?.toList() ?: emptyList() + } + + override suspend fun deleteByContext(contextId: String): Unit = rwLock.withReadLock { + messagesByContext -= contextId + } + + override suspend fun replaceByContext(contextId: String, messages: List): Unit = rwLock.withWriteLock { + // Validate that all messages have the correct contextId + val invalidMessages = messages.filter { it.contextId != contextId } + if (invalidMessages.isNotEmpty()) { + throw MessageOperationException( + "All messages must have contextId '$contextId', but found messages with different contextIds: " + + invalidMessages.map { it.contextId }.distinct().joinToString() + ) + } + + // Replace all messages for the context + messagesByContext[contextId] = messages.toMutableList() + } +} diff --git a/a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/messages/MessageStorage.kt b/a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/messages/MessageStorage.kt new file mode 100644 index 0000000000..d722f13b38 --- /dev/null +++ b/a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/messages/MessageStorage.kt @@ -0,0 +1,101 @@ +package ai.koog.a2a.server.messages + +import ai.koog.a2a.model.Message +import ai.koog.a2a.server.exceptions.MessageOperationException + +/** + * Storage interface for messages associated with a particular context. + * Can be used to keep track of a conversation history. + * + * Implementations must ensure concurrency safety. + */ +public interface MessageStorage { + /** + * Saves a message to the storage. + * + * @param message the message to save + * @throws MessageOperationException if the message cannot be saved + */ + public suspend fun save(message: Message) + + /** + * Retrieves all messages associated with the given context. + * + * @param contextId the context identifier + */ + public suspend fun getByContext(contextId: String): List + + /** + * Deletes all messages associated with the given context. + * + * @param contextId the context identifier + * @throws MessageOperationException if some messages cannot be deleted + */ + public suspend fun deleteByContext(contextId: String) + + /** + * Replaces all messages associated with the given context. + * + * @param contextId the context identifier + * @throws MessageOperationException if context cannot be replaced + */ + public suspend fun replaceByContext(contextId: String, messages: List) +} + +/** + * Wrapper class around [MessageStorage] for interacting with a particular context. + * Provides convenience methods and verification for context ID. + * + * @param contextId the context identifier + * @param messageStorage the underlying [MessageStorage] implementation + */ +public class ContextMessageStorage( + private val contextId: String, + private val messageStorage: MessageStorage, +) { + /** + * Saves a message to the storage. + * + * @param message the message to save + * @see [MessageStorage.save] + */ + public suspend fun save(message: Message) { + require(message.contextId == contextId) { + "contextId of message must be same as current contextId" + } + + messageStorage.save(message) + } + + /** + * Retrieves all messages associated with the current context. + * + * @see [MessageStorage.getByContext] + */ + public suspend fun getAll(): List { + return messageStorage.getByContext(contextId) + } + + /** + * Deletes all messages associated with the current context. + * + * @see [MessageStorage.deleteByContext] + */ + public suspend fun deleteAll() { + messageStorage.deleteByContext(contextId) + } + + /** + * Replaces all messages associated with the current context. + * + * @param messages the list of messages to replace + * @see [MessageStorage.replaceByContext] + */ + public suspend fun replaceAll(messages: List) { + require(messages.all { it.contextId == contextId }) { + "contextId of messages must be same as current contextId" + } + + messageStorage.replaceByContext(contextId, messages) + } +} diff --git a/a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/session/RequestContext.kt b/a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/session/RequestContext.kt new file mode 100644 index 0000000000..8672c7b128 --- /dev/null +++ b/a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/session/RequestContext.kt @@ -0,0 +1,23 @@ +package ai.koog.a2a.server.session + +import ai.koog.a2a.server.messages.ContextMessageStorage +import ai.koog.a2a.server.tasks.ContextTaskStorage +import ai.koog.a2a.transport.ServerCallContext + +/** + * Request context associated with each A2A agent-related request, providing essential information and repositories to + * the agent executor. + * + * @param contextId Context ID associated with this request. + * @param callContext [ServerCallContext] associated with the request. + * @param params Parameters associated with the request. + * @param taskStorage [ContextTaskStorage] associated with the request. + * @param messageStorage [ContextMessageStorage] associated with the request. + */ +public class RequestContext( + public val contextId: String, + public val callContext: ServerCallContext, + public val params: T, + public val taskStorage: ContextTaskStorage, + public val messageStorage: ContextMessageStorage, +) diff --git a/a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/session/SessionEventProcessor.kt b/a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/session/SessionEventProcessor.kt new file mode 100644 index 0000000000..f419ab17fe --- /dev/null +++ b/a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/session/SessionEventProcessor.kt @@ -0,0 +1,212 @@ +package ai.koog.a2a.server.session + +import ai.koog.a2a.model.Event +import ai.koog.a2a.model.Message +import ai.koog.a2a.model.Task +import ai.koog.a2a.model.TaskArtifactUpdateEvent +import ai.koog.a2a.model.TaskEvent +import ai.koog.a2a.model.TaskState +import ai.koog.a2a.model.TaskStatusUpdateEvent +import ai.koog.a2a.server.exceptions.InvalidEventException +import ai.koog.a2a.server.tasks.TaskStorage +import kotlinx.coroutines.CoroutineScope +import kotlinx.coroutines.channels.Channel +import kotlinx.coroutines.flow.SharedFlow +import kotlinx.coroutines.flow.SharingStarted +import kotlinx.coroutines.flow.receiveAsFlow +import kotlinx.coroutines.flow.shareIn +import kotlinx.coroutines.sync.Mutex +import kotlinx.coroutines.sync.withLock + +/** + * A session processor responsible for handling session events. + * It validates the events, emits them to the subscribers via [events] and updates session state. + * All valid [TaskEvent] events that are sent using [sendTaskEvent] will also be saved to the [taskStorage] provided. + * + * Validation logic attempts to verify that the number, type and order of events comply to what is expected from a proper + * A2A server implementation. + * These are the main rules: + * + * - **Session type exclusivity**: A session can only handle either [Message] events or [TaskEvent] events, never both + * - **Context ID validation**: All events must have the same contextId as the session + * - **Single message limit**: Only one [Message] can be sent per session, after which the session becomes terminal + * - **Task initialization order**: For new tasks, the first [TaskEvent] must be of type [Task] to create the task + * - **Task ID consistency**: Once a task session is initialized, only [TaskEvent]s with the same taskId are allowed + * - **Final event enforcement**: After a [TaskStatusUpdateEvent] with `final=true` is sent, no more events are permitted + * - **Terminal state blocking**: No events can be sent when the task is already in a terminal state + * - **Final flag requirement**: [TaskStatusUpdateEvent]s that set the task to a terminal state must have `final=true` + * + * @property contextId The contextId of the session. + * @param taskStorage The storage for tasks where task events will be saved. + * @param coroutineScope The scope in which the event flow will be shared + * @param currentTask The current task associated with the session, if it is a continuation of a previous task session. + * + * @property events A shared flow of session events that can be subscribed. The flow will be closed when the session is closed. + */ +public class SessionEventProcessor( + public val contextId: String, + private val taskStorage: TaskStorage, + coroutineScope: CoroutineScope, + currentTask: Task? = null, +) : AutoCloseable { + private companion object { + private const val MESSAGE_SENT = + "Message has already been sent in this session. Sending message is a terminal operation and no more events " + + "are allowed to be sent, the session must terminate ASAP" + + private const val TASK_INITIALIZED = + "Task has already been initialized in this sessions, only TaskEvent's with the same taskId can be sent from now on." + + private const val TASK_EVENT_FINAL_SENT = + "Final TaskEvent has already been sent in this session. Sending final event is a terminal operation " + + "and no more events are allowed to be sent, the session must terminate ASAP" + + private const val TASK_EVENT_TERMINAL_STATE = + "TaskEvent's cannot be sent when the task transitioned to the terminal state." + + private const val TASK_EVENT_FINAL_REQUIRED = + "TaskEvent final parameter is required to be set to 'true' when setting task state to the terminal state" + + private const val TASK_DOES_NOT_EXIST = + "Task associated with the taskId in TaskEvent does not exist yet and the event was not Task. Creating new " + + "task should always start with Task event." + + private const val INVALID_CONTEXT_ID = "Event contextId must be same as current contextId" + } + + private sealed interface SessionType { + object MessageSession : SessionType + + class TaskSession( + val taskId: String, + var taskState: TaskState? = null, + var finalEventReceived: Boolean = false, + ) : SessionType + } + + private val _events = Channel() + public val events: SharedFlow = _events + .receiveAsFlow() + .shareIn(scope = coroutineScope, started = SharingStarted.Eagerly) + + private val sessionMutex = Mutex() + private var sessionType: SessionType? = currentTask?.let { + SessionType.TaskSession( + taskId = it.id, + taskState = it.status.state + ) + } + + /** + * Sends a [Message] to the session event processor. Validates the message against the session context and updates + * the session state accordingly. + * + * @param message The message to be sent. Contains details such as message content, context ID, and metadata. + * @throws [InvalidEventException] for invalid events. + * Check [SessionEventProcessor] docs from info about valid events. + */ + public suspend fun sendMessage(message: Message): Unit = sessionMutex.withLock { + if (message.contextId != contextId) { + throw InvalidEventException(INVALID_CONTEXT_ID) + } + + when (sessionType) { + is SessionType.MessageSession -> throw InvalidEventException(MESSAGE_SENT) + + is SessionType.TaskSession -> throw InvalidEventException(TASK_INITIALIZED) + + null -> { + _events.send(message) + sessionType = SessionType.MessageSession + } + } + } + + /** + * Sends a [TaskEvent] to the session event processor. Validates the event against the session context and updates + * the session state and [taskStorage] accordingly. + * + * @param event The event to be sent. Contains details such as task ID, context ID, and metadata. + * @throws [InvalidEventException] for invalid events. + * Check [SessionEventProcessor] docs from info about valid events. + */ + public suspend fun sendTaskEvent(event: TaskEvent): Unit = sessionMutex.withLock { + if (event.contextId != contextId) { + throw InvalidEventException(INVALID_CONTEXT_ID) + } + /* + The first set of checks, to get initial task session type if it is allowed here. + */ + val taskSessionType: SessionType.TaskSession = when (sessionType) { + is SessionType.MessageSession -> throw InvalidEventException(MESSAGE_SENT) + + is SessionType.TaskSession -> sessionType as SessionType.TaskSession + + null -> { + val savedTask = taskStorage.get(event.taskId) + + SessionType.TaskSession( + taskId = event.taskId, + taskState = savedTask?.status?.state, // null - new task + finalEventReceived = false + ).also { + sessionType = it + } + } + } + + /* + The second set of checks to check various aspects of the current task and session state and guide the user to emit + only allowed events. + */ + when { + /** + * If the task does not exist yet, the first [TaskEvent] should be only of type Task, to create the task itself + */ + taskSessionType.taskState == null && event !is Task -> + throw InvalidEventException(TASK_DOES_NOT_EXIST) + + /** + * If there was already a [TaskStatusUpdateEvent] with [TaskStatusUpdateEvent.final] set to true, no more events are expected + */ + taskSessionType.finalEventReceived -> + throw InvalidEventException(TASK_EVENT_FINAL_SENT) + + /** + * If the task is already in a terminal state, no more events are expected + */ + taskSessionType.taskState?.terminal == true -> + throw InvalidEventException(TASK_EVENT_TERMINAL_STATE) + + /** + * If the event is a [TaskStatusUpdateEvent] attempting to set a task to a terminal state, + * then [TaskStatusUpdateEvent.final] must be set to true + */ + event is TaskStatusUpdateEvent && event.status.state.terminal && !event.final -> + throw InvalidEventException(TASK_EVENT_FINAL_REQUIRED) + } + + // Only if all checks passed, attempt to update and emit the event + taskStorage.update(event) + _events.send(event) + + when (event) { + is TaskStatusUpdateEvent -> taskSessionType.apply { + taskState = event.status.state + finalEventReceived = event.final + } + + is Task -> taskSessionType.apply { + taskState = event.status.state + } + + is TaskArtifactUpdateEvent -> { + // do nothing, condition is left here for clarity + } + } + } + + override fun close() { + _events.close() + } +} diff --git a/a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/session/SessionManager.kt b/a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/session/SessionManager.kt new file mode 100644 index 0000000000..b7c3cf84df --- /dev/null +++ b/a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/session/SessionManager.kt @@ -0,0 +1,94 @@ +package ai.koog.a2a.server.session + +import ai.koog.a2a.annotations.InternalA2AApi +import ai.koog.a2a.model.Message +import ai.koog.a2a.model.TaskEvent +import ai.koog.a2a.utils.RWLock +import kotlinx.coroutines.CoroutineScope +import kotlinx.coroutines.flow.collect +import kotlinx.coroutines.flow.first +import kotlinx.coroutines.flow.onCompletion +import kotlinx.coroutines.launch + +/** + * Manages a set of active instances of [SessionEventProcessor]. + * + * Each added processor is monitored for task id associated with this session, if any, i.e., the session is processing a task, + * and if it is a task-related session, the processor is added to the task sessions map. + * + * Automatically removes the processor when the session is closed. + * + * @param coroutineScope The scope in which the monitoring jobs will be launched. + */ +@OptIn(InternalA2AApi::class) +public class SessionManager( + private val coroutineScope: CoroutineScope, +) { + private val allProcessors = mutableSetOf() + private val taskProcessors = mutableMapOf() + private val rwLock = RWLock() + + /** + * Adds a session event processor to a set of active processors. + * If the first event in the processor is of type [TaskEvent], the processor is added to the task sessions map too. + * Handles cleanup by removing the processor when the session is closed. + * + * @param eventProcessor The session event processor to be added. + */ + public fun addProcessor(eventProcessor: SessionEventProcessor) { + coroutineScope.launch { + // Check if the first event in the session processor is task related and add this processor to the task sessions map. + when (val firstEvent = eventProcessor.events.first()) { + is TaskEvent -> { + val taskId = firstEvent.taskId + + rwLock.withWriteLock { + check(taskId !in taskProcessors) { + "SessionEventProcessor for taskId '${firstEvent.taskId}' already exists." + } + + allProcessors += eventProcessor + taskProcessors[firstEvent.taskId] = eventProcessor + } + + // Wait for the session to close and remove the processor from collections. + eventProcessor.events + .onCompletion { + rwLock.withWriteLock { + allProcessors -= eventProcessor + taskProcessors -= taskId + } + } + .collect() + } + + is Message -> { + allProcessors += eventProcessor + + // Wait for the session to close and remove the processor from collections. + eventProcessor.events + .onCompletion { + rwLock.withWriteLock { + allProcessors -= eventProcessor + } + } + .collect() + } + } + } + } + + /** + * Returns the session event processor for the given task id, if any. + */ + public suspend fun processorForTask(taskId: String): SessionEventProcessor? = rwLock.withReadLock { + taskProcessors[taskId] + } + + /** + * Returns the number of active session event processors. + */ + public suspend fun activeProcessors(): Int = rwLock.withReadLock { + allProcessors.size + } +} diff --git a/a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/tasks/InMemoryTaskStorage.kt b/a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/tasks/InMemoryTaskStorage.kt new file mode 100644 index 0000000000..14ddd44e81 --- /dev/null +++ b/a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/tasks/InMemoryTaskStorage.kt @@ -0,0 +1,148 @@ +package ai.koog.a2a.server.tasks + +import ai.koog.a2a.annotations.InternalA2AApi +import ai.koog.a2a.model.Task +import ai.koog.a2a.model.TaskArtifactUpdateEvent +import ai.koog.a2a.model.TaskEvent +import ai.koog.a2a.model.TaskStatusUpdateEvent +import ai.koog.a2a.server.exceptions.TaskOperationException +import ai.koog.a2a.utils.RWLock + +/** + * In-memory implementation of [TaskStorage] using a thread-safe map. + * + * This implementation stores tasks in memory and provides concurrency safety through mutex locks. + * Tasks are indexed by both ID and context ID for efficient retrieval. + */ +@OptIn(InternalA2AApi::class) +public class InMemoryTaskStorage : TaskStorage { + private val tasks = mutableMapOf() + private val tasksByContext = mutableMapOf>() + private val rwLock = RWLock() + + override suspend fun get( + taskId: String, + historyLength: Int?, + includeArtifacts: Boolean + ): Task? = rwLock.withReadLock { + historyLength?.let { + require(it >= 0) { "historyLength must be non-negative" } + } + + val task = tasks[taskId] ?: return@withReadLock null + // Need to modify the original full task object to remove some information + val isModificationNeeded = historyLength != null || !includeArtifacts + + if (isModificationNeeded) { + task.copy( + history = if (historyLength != null) { + task.history?.takeLast(historyLength) + } else { + task.history + }, + artifacts = if (includeArtifacts) { + task.artifacts + } else { + null + } + ) + } else { + task + } + } + + override suspend fun getAll( + taskIds: List, + historyLength: Int?, + includeArtefacts: Boolean + ): List = rwLock.withReadLock { + taskIds.mapNotNull { taskId -> + get(taskId, historyLength, includeArtefacts) + } + } + + override suspend fun getByContext( + contextId: String, + historyLength: Int?, + includeArtefacts: Boolean + ): List = rwLock.withReadLock { + val contextTaskIds = tasksByContext[contextId] ?: emptySet() + contextTaskIds.mapNotNull { taskId -> + get(taskId, historyLength, includeArtefacts) + } + } + + override suspend fun update(event: TaskEvent): Unit = rwLock.withWriteLock { + when (event) { + is Task -> { + // Store or replace the task + val oldTask = tasks[event.id] + tasks[event.id] = event + + // Update context index + tasksByContext.getOrPut(event.contextId) { mutableSetOf() }.add(event.id) + + // Remove from old context if it changed + if (oldTask != null && oldTask.contextId != event.contextId) { + tasksByContext[oldTask.contextId]?.remove(event.id) + if (tasksByContext[oldTask.contextId]?.isEmpty() == true) { + tasksByContext.remove(oldTask.contextId) + } + } + } + + is TaskStatusUpdateEvent -> { + val existingTask = tasks[event.taskId] + ?: throw TaskOperationException("Cannot update status for non-existing task: ${event.taskId}") + + val updatedTask = existingTask.copy(status = event.status) + tasks[event.taskId] = updatedTask + } + + is TaskArtifactUpdateEvent -> { + val existingTask = tasks[event.taskId] + ?: throw TaskOperationException("Cannot update artifact for non-existing task: ${event.taskId}") + + val currentArtifacts = existingTask.artifacts?.toMutableList() ?: mutableListOf() + val existingArtifactIndex = currentArtifacts.indexOfFirst { it.artifactId == event.artifact.artifactId } + + if (existingArtifactIndex >= 0) { + val existingArtifact = currentArtifacts[existingArtifactIndex] + + currentArtifacts[existingArtifactIndex] = if (event.append == true) { + existingArtifact.copy(parts = existingArtifact.parts + event.artifact.parts) + } else { + event.artifact + } + } else { + currentArtifacts.add(event.artifact) + } + + val updatedTask = existingTask.copy(artifacts = currentArtifacts) + tasks[event.taskId] = updatedTask + } + } + } + + override suspend fun delete(taskId: String): Unit = rwLock.withWriteLock { + tasks.remove(taskId)?.let { task -> + // Remove from context index + tasksByContext[task.contextId]?.remove(taskId) + if (tasksByContext[task.contextId]?.isEmpty() == true) { + tasksByContext.remove(task.contextId) + } + } + } + + override suspend fun deleteAll(taskIds: List): Unit = rwLock.withWriteLock { + taskIds.forEach { taskId -> + tasks.remove(taskId)?.let { task -> + // Remove from context index + tasksByContext[task.contextId]?.remove(taskId) + if (tasksByContext[task.contextId]?.isEmpty() == true) { + tasksByContext.remove(task.contextId) + } + } + } + } +} diff --git a/a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/tasks/TaskStorage.kt b/a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/tasks/TaskStorage.kt new file mode 100644 index 0000000000..0f76c9353f --- /dev/null +++ b/a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/tasks/TaskStorage.kt @@ -0,0 +1,154 @@ +package ai.koog.a2a.server.tasks + +import ai.koog.a2a.model.Task +import ai.koog.a2a.model.TaskEvent +import ai.koog.a2a.server.exceptions.TaskOperationException + +/** + * Storage interface for managing tasks and their lifecycle events. + * + * Implementations must ensure concurrency safety. + */ +public interface TaskStorage { + /** + * Retrieves a task by ID. + * + * @param taskId the unique task identifier + * @param historyLength the maximum number of messages in conversation history to include in the response + * Set to `null` to include all messages. Defaults to `0`. + * @param includeArtifacts whether to include artifacts in the response. Default is `false`. + * + * @return the task or null if not found + */ + public suspend fun get( + taskId: String, + historyLength: Int? = 0, + includeArtifacts: Boolean = false + ): Task? + + /** + * Retrieves multiple tasks by their IDs. + * + * @param taskIds list of task identifiers + * @param historyLength the maximum number of messages in conversation history to include in the response + * Set to `null` to include all messages. Defaults to `0`. + * @param includeArtifacts whether to include artifacts in the response. Default is `false`. + * @return list of found tasks (may be fewer than requested) + */ + public suspend fun getAll( + taskIds: List, + historyLength: Int? = 0, + includeArtifacts: Boolean = false + ): List + + /** + * Retrieves all tasks associated with a specific context ID. + * + * @param contextId context identifier + * @param historyLength the maximum number of messages in conversation history to include in the response + * Set to `null` to include all messages. Defaults to `0`. + * @param includeArtifacts whether to include artifacts in the response. Default is `false`. + * @return A list of tasks that match the specified context ID. + */ + public suspend fun getByContext( + contextId: String, + historyLength: Int? = 0, + includeArtifacts: Boolean = false + ): List + + /** + * Updates task state based on an event, or creates the task if it doesn't exist. + * + * When the event is a [Task], it will be stored as a new task or replace an existing one. + * When the event is a status/artifact update, it modifies the existing task state. + * All tasks must be created before they can be updated with the [Task] event first. + * Attempts to send a non-[Task] [TaskEvent] for a non-existing task will result in error. + * + * @param event the update event to apply (creation or modification) + * @throws TaskOperationException if the task cannot be created or updated, e.g. [TaskEvent] that is not [Task] is sent for non-existing task id + */ + public suspend fun update(event: TaskEvent) + + /** + * Deletes a task by ID. + * + * @param taskId the task identifier to delete + * @throws TaskOperationException if the task cannot be deleted, e.g. it doesn't exist + */ + public suspend fun delete(taskId: String) + + /** + * Deletes multiple tasks by their IDs. + * + * @param taskIds list of task identifiers to delete + * @throws TaskOperationException if some tasks cannot be deleted, e.g. they don't exist + */ + public suspend fun deleteAll(taskIds: List) +} + +public class ContextTaskStorage( + private val contextId: String, + private val taskStorage: TaskStorage, +) { + /** + * Retrieves a task by ID. + * + * @see [TaskStorage.get] + */ + public suspend fun get( + taskId: String, + historyLength: Int? = 0, + includeArtifacts: Boolean = false + ): Task? = taskStorage.get(taskId, historyLength, includeArtifacts) + + /** + * Retrieves multiple tasks by their IDs. + * + * @see [TaskStorage.getAll] + */ + public suspend fun getAll( + taskIds: List, + historyLength: Int? = 0, + includeArtifacts: Boolean = false + ): List = taskStorage.getAll(taskIds, historyLength, includeArtifacts) + + /** + * Retrieves all tasks associated with the current context ID. + * + * @see [TaskStorage.getByContext] + */ + public suspend fun getByContext( + historyLength: Int? = 0, + includeArtifacts: Boolean = false + ): List = taskStorage.getByContext(contextId, historyLength, includeArtifacts) + + /** + * Deletes a task by ID, checking that it belongs to the current context ID. + * + * @see [TaskStorage.delete] + */ + public suspend fun delete(taskId: String) { + get(taskId, historyLength = 0, includeArtifacts = false)?.let { task -> + require(task.contextId == contextId) { + "contextId of the task requested to be deleted must be same as current contextId" + } + + taskStorage.delete(taskId) + } + } + + /** + * Deletes multiple tasks by their IDs, checking that they belong to the current context ID. + * + * @see [TaskStorage.deleteAll] + */ + public suspend fun deleteAll(taskIds: List) { + getAll(taskIds, historyLength = 0, includeArtifacts = false).let { tasks -> + require(tasks.all { it.contextId == contextId }) { + "contextId of the tasks requested to be deleted must be same as current contextId" + } + + taskStorage.deleteAll(taskIds) + } + } +} diff --git a/a2a/a2a-server/src/commonTest/kotlin/ai/koog/a2a/server/messages/InMemoryMessageStorageTest.kt b/a2a/a2a-server/src/commonTest/kotlin/ai/koog/a2a/server/messages/InMemoryMessageStorageTest.kt new file mode 100644 index 0000000000..5395731022 --- /dev/null +++ b/a2a/a2a-server/src/commonTest/kotlin/ai/koog/a2a/server/messages/InMemoryMessageStorageTest.kt @@ -0,0 +1,99 @@ +package ai.koog.a2a.server.messages + +import ai.koog.a2a.model.Message +import ai.koog.a2a.model.Role +import ai.koog.a2a.model.TextPart +import ai.koog.a2a.server.exceptions.MessageOperationException +import kotlinx.coroutines.test.runTest +import kotlin.test.BeforeTest +import kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.assertFailsWith +import kotlin.test.assertTrue + +class InMemoryMessageStorageTest { + private lateinit var storage: InMemoryMessageStorage + + @BeforeTest + fun setUp() { + storage = InMemoryMessageStorage() + } + + @Test + fun testSaveMessage() = runTest { + val message = createMessage("msg-1", "context-1", "Hello, world!") + + storage.save(message) + + val messages = storage.getByContext("context-1") + + assertEquals(1, messages.size) + assertEquals(message, messages[0]) + } + + @Test + fun testSaveMessageWithoutContextId() = runTest { + val message = createMessage("msg-1", null, "Hello, world!") + + assertFailsWith { + storage.save(message) + } + } + + @Test + fun testSaveMultipleMessages() = runTest { + val message1 = createMessage("msg-1", "context-1", "First message") + val message2 = createMessage("msg-2", "context-1", "Second message") + val message3 = createMessage("msg-3", "context-2", "Different context") + + storage.save(message1) + storage.save(message2) + storage.save(message3) + + val context1Messages = storage.getByContext("context-1") + assertEquals(2, context1Messages.size) + assertEquals(message1, context1Messages[0]) + assertEquals(message2, context1Messages[1]) + + val context2Messages = storage.getByContext("context-2") + assertEquals(1, context2Messages.size) + assertEquals(message3, context2Messages[0]) + } + + @Test + fun testGetByNonExistentContext() = runTest { + val messages = storage.getByContext("non-existent-context") + assertTrue(messages.isEmpty()) + } + + @Test + fun testDeleteByContext() = runTest { + val message1 = createMessage("msg-1", "context-1", "Message 1") + val message2 = createMessage("msg-2", "context-1", "Message 2") + val message3 = createMessage("msg-3", "context-2", "Different context") + + storage.save(message1) + storage.save(message2) + storage.save(message3) + + storage.deleteByContext("context-1") + + val context1Messages = storage.getByContext("context-1") + assertTrue(context1Messages.isEmpty()) + + val context2Messages = storage.getByContext("context-2") + assertEquals(1, context2Messages.size) + assertEquals(message3, context2Messages[0]) + } + + private fun createMessage( + messageId: String, + contextId: String?, + content: String = "test content" + ) = Message( + messageId = messageId, + role = Role.User, + parts = listOf(TextPart(content)), + contextId = contextId + ) +} diff --git a/a2a/a2a-server/src/commonTest/kotlin/ai/koog/a2a/server/tasks/InMemoryTaskStorageTest.kt b/a2a/a2a-server/src/commonTest/kotlin/ai/koog/a2a/server/tasks/InMemoryTaskStorageTest.kt new file mode 100644 index 0000000000..5cdf815282 --- /dev/null +++ b/a2a/a2a-server/src/commonTest/kotlin/ai/koog/a2a/server/tasks/InMemoryTaskStorageTest.kt @@ -0,0 +1,253 @@ +package ai.koog.a2a.server.tasks + +import ai.koog.a2a.model.Artifact +import ai.koog.a2a.model.Message +import ai.koog.a2a.model.Role +import ai.koog.a2a.model.Task +import ai.koog.a2a.model.TaskArtifactUpdateEvent +import ai.koog.a2a.model.TaskState +import ai.koog.a2a.model.TaskStatus +import ai.koog.a2a.model.TaskStatusUpdateEvent +import ai.koog.a2a.model.TextPart +import ai.koog.a2a.server.exceptions.TaskOperationException +import kotlinx.coroutines.test.runTest +import kotlinx.datetime.Instant +import kotlin.test.BeforeTest +import kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.assertFailsWith +import kotlin.test.assertNotNull +import kotlin.test.assertNull +import kotlin.test.assertTrue + +class InMemoryTaskStorageTest { + private lateinit var storage: InMemoryTaskStorage + + @BeforeTest + fun setUp() { + storage = InMemoryTaskStorage() + } + + @Test + fun testGetNonExistentTask() = runTest { + val result = storage.get("non-existent-id") + assertNull(result) + } + + @Test + fun testStoreAndRetrieveTask() = runTest { + val task = createTask(id = "task-1", contextId = "context-1") + + storage.update(task) + + val retrieved = storage.get("task-1") + + assertNotNull(retrieved) + assertEquals(task.id, retrieved.id) + assertEquals(task.contextId, retrieved.contextId) + } + + @Test + fun testDeleteTask() = runTest { + val task = createTask(id = "task-1", contextId = "context-1") + storage.update(task) + + storage.delete("task-1") + + val retrieved = storage.get("task-1") + assertNull(retrieved) + } + + @Test + fun testGetByContext() = runTest { + val task1 = createTask(id = "task-1", contextId = "context-1") + val task2 = createTask(id = "task-2", contextId = "context-1") + val task3 = createTask(id = "task-3", contextId = "context-2") + + storage.update(task1) + storage.update(task2) + storage.update(task3) + + val result = storage.getByContext("context-1") + + assertEquals(2, result.size) + assertTrue(result.all { it.contextId == "context-1" }) + assertTrue(result.any { it.id == "task-1" }) + assertTrue(result.any { it.id == "task-2" }) + } + + @Test + fun testTaskStatusUpdateEvent() = runTest { + // Create and store initial task + val task = createTask(id = "task-1", contextId = "context-1") + storage.update(task) + + // Create a status update event + val newMessage = createUserMessage("status-msg", "context-1", "Task completed successfully") + val newStatus = TaskStatus( + state = TaskState.Completed, + message = newMessage, + timestamp = Instant.parse("2023-01-01T12:00:00Z") + ) + val statusUpdateEvent = TaskStatusUpdateEvent( + taskId = "task-1", + contextId = "context-1", + status = newStatus, + final = true + ) + + // Update task status + storage.update(statusUpdateEvent) + + // Verify the status was updated + val retrieved = storage.get("task-1") + assertEquals(newStatus, retrieved?.status) + } + + @Test + fun testTaskStatusUpdateEventForNonExistentTask() = runTest { + val statusUpdateEvent = TaskStatusUpdateEvent( + taskId = "non-existent", + contextId = "context-1", + status = TaskStatus(state = TaskState.Completed), + final = true + ) + + assertFailsWith { + storage.update(statusUpdateEvent) + } + } + + @Test + fun testTaskArtifactUpdateEventNewArtifact() = runTest { + // Create and store initial task + val task = createTask(id = "task-1", contextId = "context-1") + storage.update(task) + + // Create artifact update event with new artifact + val artifact = Artifact( + artifactId = "artifact-1", + parts = listOf(TextPart("Initial content")) + ) + val artifactUpdateEvent = TaskArtifactUpdateEvent( + taskId = "task-1", + contextId = "context-1", + artifact = artifact, + append = false + ) + + // Update task with artifact + storage.update(artifactUpdateEvent) + + // Verify the artifact was added + val retrieved = storage.get("task-1", includeArtifacts = true) + assertEquals(listOf(artifact), retrieved?.artifacts) + } + + @Test + fun testTaskArtifactUpdateEventReplaceExisting() = runTest { + // Create and store initial task with artifact + val initialArtifact = Artifact( + artifactId = "artifact-1", + parts = listOf(TextPart("Initial content")) + ) + val task = createTask(id = "task-1", contextId = "context-1", artifacts = listOf(initialArtifact)) + storage.update(task) + + // Create artifact update event to replace existing artifact + val newArtifact = Artifact( + artifactId = "artifact-1", + parts = listOf(TextPart("Replaced content")) + ) + val artifactUpdateEvent = TaskArtifactUpdateEvent( + taskId = "task-1", + contextId = "context-1", + artifact = newArtifact, + append = false + ) + + // Update task with new artifact + storage.update(artifactUpdateEvent) + + // Verify the artifact was replaced + val retrieved = storage.get("task-1", includeArtifacts = true) + assertEquals(listOf(newArtifact), retrieved?.artifacts) + } + + @Test + fun testTaskArtifactUpdateEventAppendToExisting() = runTest { + // Create and store initial task with artifact + val initialArtifact = Artifact( + artifactId = "artifact-1", + parts = listOf(TextPart("Initial content")) + ) + val task = createTask(id = "task-1", contextId = "context-1", artifacts = listOf(initialArtifact)) + storage.update(task) + + // Create artifact update event to append to existing artifact + val appendArtifact = Artifact( + artifactId = "artifact-1", + parts = listOf(TextPart(" Additional content")) + ) + val artifactUpdateEvent = TaskArtifactUpdateEvent( + taskId = "task-1", + contextId = "context-1", + artifact = appendArtifact, + append = true + ) + + val resultingArtifact = initialArtifact.copy(parts = initialArtifact.parts + appendArtifact.parts) + + // Update task with appended artifact + storage.update(artifactUpdateEvent) + + // Verify the content was appended + val retrieved = storage.get("task-1", includeArtifacts = true) + assertEquals(listOf(resultingArtifact), retrieved?.artifacts) + } + + @Test + fun testTaskArtifactUpdateEventForNonExistentTask() = runTest { + val artifact = Artifact( + artifactId = "artifact-1", + parts = listOf(TextPart("Content")) + ) + val artifactUpdateEvent = TaskArtifactUpdateEvent( + taskId = "non-existent", + contextId = "context-1", + artifact = artifact, + append = false + ) + + assertFailsWith { + storage.update(artifactUpdateEvent) + } + } + + private fun createUserMessage( + messageId: String, + contextId: String, + content: String + ) = Message( + messageId = messageId, + role = Role.User, + parts = listOf(TextPart(content)), + contextId = contextId + ) + + private fun createTask( + id: String, + contextId: String, + history: List? = null, + artifacts: List? = null + ) = Task( + id = id, + contextId = contextId, + status = TaskStatus( + state = TaskState.Submitted, + timestamp = Instant.parse("2023-01-01T10:00:00Z") + ), + history = history, + artifacts = artifacts + ) +} diff --git a/a2a/a2a-transport/a2a-transport-client-jsonrpc-http/src/jvmTest/kotlin/ai/koog/a2a/transport/client/jsonrpc/http/HttpJSONRPCClientTransportMokksyTest.kt b/a2a/a2a-transport/a2a-transport-client-jsonrpc-http/src/jvmTest/kotlin/ai/koog/a2a/transport/client/jsonrpc/http/HttpJSONRPCClientTransportMokksyTest.kt deleted file mode 100644 index 6ef91f3dba..0000000000 --- a/a2a/a2a-transport/a2a-transport-client-jsonrpc-http/src/jvmTest/kotlin/ai/koog/a2a/transport/client/jsonrpc/http/HttpJSONRPCClientTransportMokksyTest.kt +++ /dev/null @@ -1,155 +0,0 @@ -package ai.koog.a2a.transport.client.jsonrpc.http - -import ai.koog.a2a.model.Message -import ai.koog.a2a.model.MessageSendParams -import ai.koog.a2a.model.Role -import ai.koog.a2a.model.TaskIdParams -import ai.koog.a2a.model.TaskQueryParams -import ai.koog.a2a.model.TaskState -import ai.koog.a2a.model.TextPart -import ai.koog.a2a.transport.ClientCallContext -import ai.koog.a2a.transport.Request -import ai.koog.a2a.transport.RequestId -import io.kotest.assertions.throwables.shouldThrow -import io.kotest.matchers.nulls.shouldNotBeNull -import io.kotest.matchers.shouldBe -import io.ktor.client.HttpClient -import kotlinx.coroutines.test.runTest -import kotlinx.serialization.MissingFieldException -import me.kpavlov.aimocks.a2a.MockAgentServer -import me.kpavlov.aimocks.a2a.model.Task -import me.kpavlov.aimocks.a2a.model.TaskStatus -import kotlin.test.Test - -class HttpJSONRPCClientTransportMokksyTest { - val a2aServer = MockAgentServer(verbose = true) - - val client = HttpJSONRPCClientTransport(a2aServer.baseUrl(), HttpClient()) - - @Test - fun `Should sendMessage`() = runTest { - a2aServer.sendMessage() responds { - id = "req_1234" - result = Task( - id = "tid_12345", - sessionId = "de38c76d-d54c-436c-8b9f-4c2703648d64", - contextId = "ctx_12345", - status = TaskStatus("submitted") - ) - } - - val response = client.sendMessage( - request = Request( - id = RequestId.StringId("req_1234"), - data = MessageSendParams( - message = Message( - role = Role.User, - parts = listOf( - TextPart("Tell me a joke") - ) - ) - ) - ), - ctx = ClientCallContext() - ) - - response shouldNotBeNull { - id shouldBe RequestId.StringId("req_1234") - (data as? ai.koog.a2a.model.Task) shouldNotBeNull { - id shouldBe "tid_12345" - contextId shouldBe "ctx_12345" - status shouldBe ai.koog.a2a.model.TaskStatus(TaskState.Submitted) - } - } - } - - @Test - fun `Should getTask`() = runTest { - a2aServer.getTask() responds { - id = 1 - result = Task( - id = "tid_12345", - contextId = "ctx_12345", - sessionId = "de38c76d-d54c-436c-8b9f-4c2703648d64", - status = TaskStatus("canceled") - ) - } - - val response = client.getTask( - request = Request( - id = RequestId.StringId("req_1234"), - data = TaskQueryParams(id = "tid_12345") - ), - ctx = ClientCallContext() - ) - - response shouldNotBeNull { - id shouldBe RequestId.NumberId(1) - data shouldNotBeNull { - id shouldBe "tid_12345" - contextId shouldBe "ctx_12345" - status shouldBe ai.koog.a2a.model.TaskStatus(TaskState.Canceled) - } - } - } - - @Test - fun `Should handle getTask with missing contextId`() = runTest { - a2aServer.getTask() responds { - id = 1 - result { - id = "tid_12345" - sessionId = "de38c76d-d54c-436c-8b9f-4c2703648d64" - status { - state = "completed" - } - artifacts += artifact { - name = "joke" - parts += textPart { - text = "This is a joke" - } - } - } - } - - shouldThrow { - client.getTask( - request = Request( - id = RequestId.StringId("req_1234"), - data = TaskQueryParams(id = "tid_12345") - ), - ctx = ClientCallContext() - ) - }.missingFields shouldBe listOf("contextId") - } - - @Test - fun `Should cancelTask`() = runTest { - a2aServer.cancelTask() responds { - id = "req_123" - result = Task( - id = "tid_12345", - contextId = "ctx_12345", - sessionId = "de38c76d-d54c-436c-8b9f-4c2703648d64", - status = TaskStatus("canceled") - ) - } - - val response = client.cancelTask( - request = Request( - id = RequestId.StringId("req_1233"), - data = TaskIdParams("tid_12345") - ), - ctx = ClientCallContext() - ) - - response shouldNotBeNull { - id shouldBe RequestId.StringId("req_123") - data shouldNotBeNull { - id shouldBe "tid_12345" - contextId shouldBe "ctx_12345" - status shouldBe ai.koog.a2a.model.TaskStatus(TaskState.Canceled) - } - } - } -} diff --git a/a2a/a2a-transport/a2a-transport-core-jsonrpc/src/commonMain/kotlin/ai/koog/a2a/transport/jsonrpc/JSONRPCClientTransport.kt b/a2a/a2a-transport/a2a-transport-core-jsonrpc/src/commonMain/kotlin/ai/koog/a2a/transport/jsonrpc/JSONRPCClientTransport.kt index dfc7f78904..9d9f330d37 100644 --- a/a2a/a2a-transport/a2a-transport-core-jsonrpc/src/commonMain/kotlin/ai/koog/a2a/transport/jsonrpc/JSONRPCClientTransport.kt +++ b/a2a/a2a-transport/a2a-transport-core-jsonrpc/src/commonMain/kotlin/ai/koog/a2a/transport/jsonrpc/JSONRPCClientTransport.kt @@ -4,13 +4,13 @@ import ai.koog.a2a.exceptions.A2AException import ai.koog.a2a.exceptions.createA2AException import ai.koog.a2a.model.AgentCard import ai.koog.a2a.model.CommunicationEvent +import ai.koog.a2a.model.Event import ai.koog.a2a.model.MessageSendParams import ai.koog.a2a.model.Task import ai.koog.a2a.model.TaskIdParams import ai.koog.a2a.model.TaskPushNotificationConfig import ai.koog.a2a.model.TaskPushNotificationConfigParams import ai.koog.a2a.model.TaskQueryParams -import ai.koog.a2a.model.UpdateEvent import ai.koog.a2a.transport.ClientCallContext import ai.koog.a2a.transport.ClientTransport import ai.koog.a2a.transport.Request @@ -123,7 +123,7 @@ public abstract class JSONRPCClientTransport : ClientTransport { override fun sendMessageStreaming( request: Request, ctx: ClientCallContext - ): Flow> = + ): Flow> = requestStreaming(A2AMethod.SendMessageStreaming, request, ctx) override suspend fun getTask( @@ -141,7 +141,7 @@ public abstract class JSONRPCClientTransport : ClientTransport { override fun resubscribeTask( request: Request, ctx: ClientCallContext - ): Flow> = + ): Flow> = requestStreaming(A2AMethod.ResubscribeTask, request, ctx) override suspend fun setTaskPushNotificationConfig( diff --git a/a2a/a2a-transport/a2a-transport-core-jsonrpc/src/commonMain/kotlin/ai/koog/a2a/transport/jsonrpc/JSONRPCServerTransport.kt b/a2a/a2a-transport/a2a-transport-core-jsonrpc/src/commonMain/kotlin/ai/koog/a2a/transport/jsonrpc/JSONRPCServerTransport.kt index 181f0b3dbf..e68efbc003 100644 --- a/a2a/a2a-transport/a2a-transport-core-jsonrpc/src/commonMain/kotlin/ai/koog/a2a/transport/jsonrpc/JSONRPCServerTransport.kt +++ b/a2a/a2a-transport/a2a-transport-core-jsonrpc/src/commonMain/kotlin/ai/koog/a2a/transport/jsonrpc/JSONRPCServerTransport.kt @@ -1,5 +1,6 @@ package ai.koog.a2a.transport.jsonrpc +import ai.koog.a2a.annotations.InternalA2AApi import ai.koog.a2a.exceptions.A2AException import ai.koog.a2a.exceptions.A2AInternalErrorException import ai.koog.a2a.exceptions.A2AInvalidParamsException @@ -33,6 +34,7 @@ public abstract class JSONRPCServerTransport : ServerTransport { * Handles a JSON-RPC request and returns the corresponding response * Handles exceptions, mapping all non [A2AException]s to [A2AInternalErrorException], and then converting them to [JSONRPCErrorResponse]. */ + @OptIn(InternalA2AApi::class) protected suspend fun onRequest( request: JSONRPCRequest, ctx: ServerCallContext, diff --git a/a2a/a2a-transport/a2a-transport-server-jsonrpc-http/src/jvmMain/kotlin/ai/koog/a2a/transport/server/jsonrpc/http/HttpJSONRPCServerTransport.kt b/a2a/a2a-transport/a2a-transport-server-jsonrpc-http/src/jvmMain/kotlin/ai/koog/a2a/transport/server/jsonrpc/http/HttpJSONRPCServerTransport.kt index 297b063eec..00004fff26 100644 --- a/a2a/a2a-transport/a2a-transport-server-jsonrpc-http/src/jvmMain/kotlin/ai/koog/a2a/transport/server/jsonrpc/http/HttpJSONRPCServerTransport.kt +++ b/a2a/a2a-transport/a2a-transport-server-jsonrpc-http/src/jvmMain/kotlin/ai/koog/a2a/transport/server/jsonrpc/http/HttpJSONRPCServerTransport.kt @@ -1,7 +1,10 @@ package ai.koog.a2a.transport.server.jsonrpc.http +import ai.koog.a2a.annotations.InternalA2AApi +import ai.koog.a2a.consts.A2AConsts import ai.koog.a2a.exceptions.A2AInvalidRequestException import ai.koog.a2a.exceptions.A2AParseException +import ai.koog.a2a.model.AgentCard import ai.koog.a2a.transport.RequestHandler import ai.koog.a2a.transport.ServerCallContext import ai.koog.a2a.transport.jsonrpc.JSONRPCServerTransport @@ -20,6 +23,7 @@ import io.ktor.server.request.receiveText import io.ktor.server.response.respond import io.ktor.server.routing.Route import io.ktor.server.routing.application +import io.ktor.server.routing.get import io.ktor.server.routing.post import io.ktor.server.routing.route import io.ktor.server.routing.routing @@ -44,7 +48,7 @@ import kotlinx.serialization.serializer * requestHandler = A2AServer(...) * ) * - * transport.start(port = 8080, path = "/my-agent") + * transport.start(port = 8080, path = "/my-agent", agentCard = AgentCard(...), agentCardPath = "/my-agent-card.json") * transport.stop() * ``` * @@ -91,26 +95,42 @@ public class HttpJSONRPCServerTransport( * Can be used to start a standalone server for quick prototyping or when no integration into the existing server is required. * The routing consists only of [transportRoutes]. * + * Can also optionally serve [AgentCard] at the specified [agentCardPath]. + * * If you need to integrate A2A request handling logic into existing Ktor application, * use [transportRoutes] method to mount the transport routes into existing [Route] configuration block. * - * @param port The port on which the server will listen. - * @param path The JSON-RPC endpoint path to handle incoming requests. + * @param port A port on which the server will listen. + * @param path A JSON-RPC endpoint path to handle incoming requests. + * @param agentCard An optional [AgentCard] that will be served at the specified [agentCardPath]. + * @param agentCardPath The path at which the [agentCard] will be served, if specified. + * Defaults to [A2AConsts.AGENT_CARD_WELL_KNOWN_PATH]. * * @throws IllegalStateException if the server is already running. * * @see [transportRoutes] */ - public suspend fun start(port: Int, path: String): Unit = serverMutex.withLock { + public suspend fun start( + port: Int, + path: String, + agentCard: AgentCard? = null, + agentCardPath: String = A2AConsts.AGENT_CARD_WELL_KNOWN_PATH, + ): Unit = serverMutex.withLock { check(server == null) { "Server is already configured and running. Stop it before starting a new one." } - embeddedServer(Netty, port) { + server = embeddedServer(Netty, port) { install(SSE) routing { transportRoutes(this, path) + + if (agentCard != null) { + get(agentCardPath) { + call.respond(agentCard) + } + } } - }.startSuspend(wait = true) + }.startSuspend(wait = false) } /** @@ -155,6 +175,7 @@ public class HttpJSONRPCServerTransport( * @param route The base route to which the transport routes should be mounted. * @param path JSON-RPC endpoint path that will be mounted under the base [route]. */ + @OptIn(InternalA2AApi::class) public fun transportRoutes(route: Route, path: String): Route = route.route(path) { if (application.pluginOrNull(SSE) == null) { throw IllegalStateException("SSE plugin must be installed in the application to add these routes.") diff --git a/a2a/a2a-transport/a2a-transport-server-jsonrpc-http/src/jvmTest/kotlin/ai/koog/a2a/transport/server/jsonrpc/http/HttpJSONRPCServerTransportTest.kt b/a2a/a2a-transport/a2a-transport-server-jsonrpc-http/src/jvmTest/kotlin/ai/koog/a2a/transport/server/jsonrpc/http/HttpJSONRPCServerTransportTest.kt index 4d8bbe71a9..4b492bbc3a 100644 --- a/a2a/a2a-transport/a2a-transport-server-jsonrpc-http/src/jvmTest/kotlin/ai/koog/a2a/transport/server/jsonrpc/http/HttpJSONRPCServerTransportTest.kt +++ b/a2a/a2a-transport/a2a-transport-server-jsonrpc-http/src/jvmTest/kotlin/ai/koog/a2a/transport/server/jsonrpc/http/HttpJSONRPCServerTransportTest.kt @@ -5,6 +5,7 @@ import ai.koog.a2a.model.AgentCapabilities import ai.koog.a2a.model.AgentCard import ai.koog.a2a.model.AgentSkill import ai.koog.a2a.model.CommunicationEvent +import ai.koog.a2a.model.Event import ai.koog.a2a.model.Message import ai.koog.a2a.model.MessageSendParams import ai.koog.a2a.model.PushNotificationConfig @@ -17,7 +18,6 @@ import ai.koog.a2a.model.TaskQueryParams import ai.koog.a2a.model.TaskState import ai.koog.a2a.model.TaskStatus import ai.koog.a2a.model.TextPart -import ai.koog.a2a.model.UpdateEvent import ai.koog.a2a.transport.Request import ai.koog.a2a.transport.RequestHandler import ai.koog.a2a.transport.RequestId @@ -145,7 +145,7 @@ class HttpJSONRPCServerTransportTest { override fun onSendMessageStreaming( request: Request, ctx: ServerCallContext - ): Flow> { + ): Flow> { return updateEvents .asFlow() .map { @@ -179,7 +179,7 @@ class HttpJSONRPCServerTransportTest { override fun onResubscribeTask( request: Request, ctx: ServerCallContext - ): Flow> { + ): Flow> { return updateEvents .asFlow() .map { diff --git a/a2a/test-python-a2a-server/src/agent_executor.py b/a2a/test-python-a2a-server/src/agent_executor.py index 906eb4cef0..7ebb0f3f13 100644 --- a/a2a/test-python-a2a-server/src/agent_executor.py +++ b/a2a/test-python-a2a-server/src/agent_executor.py @@ -79,6 +79,7 @@ async def do_cancelable_task( new_task(message), ) + async def do_long_running_task( event_queue: EventQueue, message: Message @@ -124,17 +125,19 @@ async def execute( context: RequestContext, event_queue: EventQueue, ) -> None: + user_input = context.get_user_input() + # Test scenarios to test various aspects of A2A - if "hello world" in context.get_user_input(): + if "hello world" in user_input: await say_hello(event_queue, context.message) - elif "do task" in context.get_user_input(): + elif "do task" in user_input: await do_task(event_queue, context.message) - elif "do cancelable task" in context.get_user_input(): + elif "do cancelable task" in user_input: await do_cancelable_task(event_queue, context.message) - elif "do long-running task" in context.get_user_input(): + elif "do long-running task" in user_input: await do_long_running_task(event_queue, context.message) else: From 2c4c65753bc16154ae4bd1e28bb693aa42e8da3c Mon Sep 17 00:00:00 2001 From: Andrey Bragin Date: Fri, 19 Sep 2025 23:38:16 +0200 Subject: [PATCH 27/52] [a2a] Implement session management with push notifications --- a2a/CLAUDE.md | 253 ++++++++++++++++++ .../kotlin/ai/koog/a2a/model/Task.kt | 3 +- a2a/a2a-server/build.gradle.kts | 4 + .../kotlin/ai/koog/a2a/server/A2AServer.kt | 221 ++++++++++----- .../ai/koog/a2a/server/agent/AgentExecutor.kt | 68 ++++- .../koog/a2a/server/exceptions/Exceptions.kt | 5 + .../InMemoryPushNotificationConfigStorage.kt | 50 ++++ .../PushNotificationConfigStorage.kt | 43 +++ .../notifications/PushNotificationSender.kt | 26 ++ .../SimplePushNotificationSender.kt | 67 +++++ .../ai/koog/a2a/server/session/Session.kt | 38 +++ .../server/session/SessionEventProcessor.kt | 2 +- .../koog/a2a/server/session/SessionManager.kt | 102 ++++--- .../a2a/server/tasks/InMemoryTaskStorage.kt | 40 +-- ...MemoryPushNotificationConfigStorageTest.kt | 92 +++++++ .../server/tasks/InMemoryTaskStorageTest.kt | 34 ++- .../build.gradle.kts | 12 +- .../http/HttpJSONRPCServerTransport.kt | 11 +- 18 files changed, 920 insertions(+), 151 deletions(-) create mode 100644 a2a/CLAUDE.md create mode 100644 a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/notifications/InMemoryPushNotificationConfigStorage.kt create mode 100644 a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/notifications/PushNotificationConfigStorage.kt create mode 100644 a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/notifications/PushNotificationSender.kt create mode 100644 a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/notifications/SimplePushNotificationSender.kt create mode 100644 a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/session/Session.kt create mode 100644 a2a/a2a-server/src/commonTest/kotlin/ai/koog/a2a/server/notifications/InMemoryPushNotificationConfigStorageTest.kt rename a2a/a2a-transport/a2a-transport-server-jsonrpc-http/src/{jvmMain => commonMain}/kotlin/ai/koog/a2a/transport/server/jsonrpc/http/HttpJSONRPCServerTransport.kt (94%) diff --git a/a2a/CLAUDE.md b/a2a/CLAUDE.md new file mode 100644 index 0000000000..154c7311c0 --- /dev/null +++ b/a2a/CLAUDE.md @@ -0,0 +1,253 @@ +# A2A Module Development Guidelines + +## Module Overview + +The A2A (Agent-to-Agent) module is a **meta-module** within the larger Koog project that implements a comprehensive client and server library for the A2A protocol, a standardized communication protocol for AI agents based on the specification at https://a2a-protocol.org/latest/specification/. + +**Important**: Since this is a meta-module inside a bigger root project, the Gradle wrapper executable is located one directory up. All Gradle commands must use `../gradlew` when running from the a2a directory. + +### Purpose +- **Agent Discovery**: Through AgentCard manifests describing agent capabilities and interfaces +- **Message Exchange**: Standardized message format with support for different content types +- **Task Management**: Long-running operations with state tracking and lifecycle management +- **Push Notifications**: Asynchronous notifications for task updates +- **Multiple Transport Protocols**: JSON-RPC, HTTP+JSON/REST with extensibility for gRPC +- **Authentication & Security**: OpenAPI 3.0 compatible security schemes + +### Submodules Architecture + +1. **a2a-core**: Core abstractions, data models, and transport interfaces + - AgentCard, Message, Task, Event hierarchy + - Transport interfaces: ClientTransport, ServerTransport, RequestHandler + - No external dependencies beyond Kotlin stdlib and serialization + +2. **a2a-client**: High-level client library for communicating with A2A servers + - A2AClient wrapper, AgentCardResolver + - Capability validation, request/response handling + - Depends on: a2a-core, Ktor HTTP client + +3. **a2a-server**: Server-side implementation for hosting A2A agents + - A2AServer, AgentExecutor interface, session management + - Storage abstractions with in-memory implementations + - Depends on: a2a-core, coroutines, logging + +4. **a2a-transport**: Multiple transport protocol implementations + - a2a-transport-core-jsonrpc: JSON-RPC protocol base classes + - a2a-transport-client-jsonrpc-http: HTTP JSON-RPC client transport + - a2a-transport-server-jsonrpc-http: HTTP JSON-RPC server transport + - a2a-transport-*-rest: HTTP+JSON/REST protocol implementations + +## Technologies & Libraries + +### Core Dependencies (from gradle/libs.versions.toml) +- **Kotlin Multiplatform**: JVM + JS (IR) support +- **kotlinx-serialization**: JSON serialization for protocol messages +- **kotlinx-coroutines**: Async/concurrent programming, Flow APIs +- **kotlinx-datetime**: Timestamp handling in protocol messages +- **ktor3**: HTTP client/server for transport implementations +- **oshai-kotlin-logging**: Structured logging + +### Testing Libraries +- **kotlin-test**: Multiplatform test framework +- **kotest-assertions**: Rich assertion library for complex objects +- **kotlinx-coroutines-test**: Testing coroutines with runTest +- **testcontainers**: Docker-based integration testing +- **slf4j-simple**: Runtime logging for tests + +### Platform Support +- **JVM**: Full server and client support +- **JS (Browser)**: Client-only support +- **Future platforms**: Architecture ready for native support + +## Development Guidelines + +### Architecture Decisions +⚠️ **CRITICAL**: Never make design and architecture decisions independently. Always ask the user for confirmation before: +- Adding new transport protocols +- Changing storage interfaces +- Modifying core protocol message formats +- Adding new authentication schemes +- Changing session management behavior +- Doing other sorts of major architectural changes + +### Code Style Conventions + +#### API Visibility +- Use `explicitApi()` - all public APIs must have explicit visibility modifiers +- Prefer `public` declarations for APIs, `internal` for implementation details +- Use `@InternalA2AApi` annotation for APIs that are public but internal to the module + +#### Naming Conventions +- Classes: PascalCase (`A2AServer`, `SessionManager`, `AgentExecutor`) +- Interfaces: Same as classes, often ending in -or/-er for behaviors (`AgentExecutor`) +- Functions: camelCase with descriptive names (`onSendMessage`, `getByContext`) +- Properties: camelCase (`agentCard`, `sessionMutex`) + +#### Async Patterns +- Use `suspend` functions for all async operations +- Prefer `Flow` for streaming APIs (task events, message streams) +- Use `RWLock` pattern: `rwLock.withReadLock/withWriteLock` for concurrent access +- Use `Mutex` for single-threaded critical sections: `mutex.withLock` + +#### Error Handling +- Use domain-specific exceptions (`A2AInternalErrorException`, `TaskOperationException`) +- Propagate errors properly in async contexts +- Include contextual information in exception messages + +## Testing Requirements + +### Mandatory Test Execution +⚠️ **ALWAYS** run the specific test suite using Gradle's jvmTest task to ensure changes work correctly: + +```bash +# Run all tests in a specific module +../gradlew :a2a:a2a-server:jvmTest + +# Run a specific test class +../gradlew :a2a:a2a-server:jvmTest --tests "ai.koog.a2a.server.tasks.InMemoryTaskStorageTest" + +# Run a specific test method +../gradlew :a2a:a2a-client:jvmTest --tests "ai.koog.a2a.client.A2AClientIntegrationTest.test get agent card" +``` + +### Testing Approach: Crucial Minimum +- **Focus on core functionality** - test essential behavior, not implementation details +- **Separate dedicated tests** - one test method per core scenario/edge case +- **Avoid verbose, unnecessary tests** - don't test trivial getters/setters +- **Test error conditions** - verify proper exception handling and error states + +#### Preferred Assertion Pattern +**✅ DO**: Assert on whole objects when possible (especially data classes): +```kotlin +assertEquals(expectedTask, actualTask) // Compares all properties +assertEquals(listOf(expectedArtifact), retrieved?.artifacts) +``` + +**❌ AVOID**: Bunch of individual field assertions unless necessary: +```kotlin +// Avoid this pattern: +assertEquals(expected.id, actual.id) +assertEquals(expected.contextId, actual.contextId) +assertEquals(expected.status, actual.status) +// ... many more individual assertions +``` + +### Integration Tests +- **Docker-based testing**: a2a-client uses testcontainers with Python A2A server +- **Docker build dependency**: Integration tests automatically build required containers +- **Test isolation**: Each test should be independent and clean up after itself + +### Test Structure Examples + +**Unit Tests** (InMemoryTaskStorageTest pattern): +```kotlin +class InMemoryTaskStorageTest { + private lateinit var storage: InMemoryTaskStorage + + @BeforeTest + fun setUp() { + storage = InMemoryTaskStorage() + } + + @Test + fun testStoreAndRetrieveTask() = runTest { + val task = createTask(id = "task-1", contextId = "context-1") + storage.update(task) + val retrieved = storage.get("task-1") + + assertNotNull(retrieved) + assertEquals(task, retrieved) // Whole object assertion + } +} +``` + +**Integration Tests** (A2AClientIntegrationTest pattern): +```kotlin +@Testcontainers +class A2AClientIntegrationTest { + @Container + val testA2AServer = GenericContainer("test-python-a2a-server") + .withExposedPorts(9999) + .waitingFor(Wait.forListeningPort()) + + @Test + fun `test get agent card`() = runTest { + val agentCard = client.getAgentCard() + val expectedAgentCard = AgentCard(...) + assertEquals(expectedAgentCard, agentCard) // Full object comparison + } +} +``` + +## Implementation Patterns + +### Transport Layer Abstraction +- Implement `ClientTransport` interface for new client protocols +- Implement `ServerTransport` and `RequestHandler` for new server protocols +- Keep protocol logic separate from transport mechanism +- Use suspend functions for all transport operations + +### Storage Interface Pattern +```kotlin +public interface SomeStorage { + public suspend fun save(item: Item) + public suspend fun get(id: String): Item? + public suspend fun delete(id: String) +} + +// In-memory implementation for testing/development +@OptIn(InternalA2AApi::class) +public class InMemorySomeStorage : SomeStorage { + private val rwLock = RWLock() + private val items = mutableMapOf() + + override suspend fun save(item: Item) = rwLock.withWriteLock { + items[item.id] = item + } +} +``` + +### Session Management Pattern +- Use `SessionEventProcessor` for event-driven session handling +- Implement proper cleanup with `Closeable` interface +- Use structured concurrency with `CoroutineScope` +- Handle session lifecycle properly (start → active → closed) + +### Agent Executor Implementation +```kotlin +public class MyAgentExecutor : AgentExecutor { + override suspend fun execute( + context: RequestContext, + eventProcessor: SessionEventProcessor + ) { + // Process request and send events through eventProcessor + eventProcessor.sendTaskEvent(/* task event */) + } +} +``` + +## Some Common Commands + +```bash +# Build specific modules (a2a is a meta-module, build individual modules) +../gradlew :a2a:a2a-core:assemble +../gradlew :a2a:a2a-client:assemble +../gradlew :a2a:a2a-server:assemble + +# Run JVM tests for specific modules +../gradlew :a2a:a2a-core:jvmTest +../gradlew :a2a:a2a-client:jvmTest +../gradlew :a2a:a2a-server:jvmTest +../gradlew :a2a:a2a-transport:a2a-transport-core-jsonrpc:jvmTest + +# Run JS tests for specific modules +../gradlew :a2a:a2a-core:jsTest +../gradlew :a2a:a2a-client:jsTest +../gradlew :a2a:a2a-server:jsTest + +# Build all non-transport a2a modules at once +../gradlew :a2a:a2a-core:assemble :a2a:a2a-client:assemble :a2a:a2a-server:assemble + +# Run all JVM tests across a2a modules +../gradlew :a2a:a2a-core:jvmTest :a2a:a2a-client:jvmTest :a2a:a2a-server:jvmTest +``` diff --git a/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/model/Task.kt b/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/model/Task.kt index 1bde51d937..3a861c24c7 100644 --- a/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/model/Task.kt +++ b/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/model/Task.kt @@ -1,5 +1,6 @@ package ai.koog.a2a.model +import kotlinx.datetime.Clock import kotlinx.datetime.Instant import kotlinx.serialization.EncodeDefault import kotlinx.serialization.SerialName @@ -46,7 +47,7 @@ public data class Task( public data class TaskStatus( public val state: TaskState, public val message: Message? = null, - public val timestamp: Instant? = null, + public val timestamp: Instant? = Clock.System.now(), ) /** diff --git a/a2a/a2a-server/build.gradle.kts b/a2a/a2a-server/build.gradle.kts index 29059c9caf..5d156ddd6d 100644 --- a/a2a/a2a-server/build.gradle.kts +++ b/a2a/a2a-server/build.gradle.kts @@ -21,6 +21,10 @@ kotlin { api(project(":a2a:a2a-core")) api(libs.kotlinx.serialization.json) api(libs.kotlinx.coroutines.core) + api(libs.ktor.client.core) + api(libs.ktor.client.content.negotiation) + api(libs.ktor.serialization.kotlinx.json) + implementation(libs.oshai.kotlin.logging) } } diff --git a/a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/A2AServer.kt b/a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/A2AServer.kt index 47a83e6921..5c04be2341 100644 --- a/a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/A2AServer.kt +++ b/a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/A2AServer.kt @@ -1,7 +1,9 @@ package ai.koog.a2a.server +import ai.koog.a2a.exceptions.A2AAuthenticatedExtendedCardNotConfiguredException import ai.koog.a2a.exceptions.A2AInternalErrorException import ai.koog.a2a.exceptions.A2AInvalidParamsException +import ai.koog.a2a.exceptions.A2APushNotificationNotSupportedException import ai.koog.a2a.exceptions.A2ATaskNotFoundException import ai.koog.a2a.exceptions.A2AUnsupportedOperationException import ai.koog.a2a.model.AgentCard @@ -22,7 +24,10 @@ import ai.koog.a2a.server.agent.AgentExecutor import ai.koog.a2a.server.messages.ContextMessageStorage import ai.koog.a2a.server.messages.InMemoryMessageStorage import ai.koog.a2a.server.messages.MessageStorage +import ai.koog.a2a.server.notifications.PushNotificationConfigStorage +import ai.koog.a2a.server.notifications.PushNotificationSender import ai.koog.a2a.server.session.RequestContext +import ai.koog.a2a.server.session.Session import ai.koog.a2a.server.session.SessionEventProcessor import ai.koog.a2a.server.session.SessionManager import ai.koog.a2a.server.tasks.ContextTaskStorage @@ -32,8 +37,10 @@ import ai.koog.a2a.transport.Request import ai.koog.a2a.transport.RequestHandler import ai.koog.a2a.transport.Response import ai.koog.a2a.transport.ServerCallContext +import kotlinx.coroutines.CancellationException import kotlinx.coroutines.CoroutineScope import kotlinx.coroutines.SupervisorJob +import kotlinx.coroutines.cancel import kotlinx.coroutines.flow.Flow import kotlinx.coroutines.flow.channelFlow import kotlinx.coroutines.flow.emitAll @@ -55,59 +62,41 @@ public open class A2AServer( protected val agentCardExtended: AgentCard? = null, protected val taskStorage: TaskStorage = InMemoryTaskStorage(), protected val messageStorage: MessageStorage = InMemoryMessageStorage(), + protected val pushConfigStorage: PushNotificationConfigStorage? = null, + protected val pushSender: PushNotificationSender? = null, protected val coroutineScope: CoroutineScope = CoroutineScope(SupervisorJob()), protected val clock: Clock = Clock.System, ) : RequestHandler { - protected val sessionManager: SessionManager = SessionManager(coroutineScope) + protected val sessionManager: SessionManager = SessionManager( + coroutineScope = coroutineScope, + taskStorage = taskStorage, + pushConfigStorage = pushConfigStorage, + pushSender = pushSender, + ) override suspend fun onGetAuthenticatedExtendedAgentCard( request: Request, ctx: ServerCallContext ): Response { + if (agentCard.supportsAuthenticatedExtendedCard != true) { + throw A2AAuthenticatedExtendedCardNotConfiguredException("Extended agent card is not supported") + } + // Default server implementation does not provide authorization, return extended card directly if present return Response( - data = agentCardExtended ?: agentCard, + data = agentCardExtended + ?: throw A2AAuthenticatedExtendedCardNotConfiguredException("Extended agent card is supported but not configured on the server"), id = request.id ) } - override suspend fun onSendMessage( - request: Request, - ctx: ServerCallContext - ): Response { - val messageConfiguration = request.data.configuration - // Reusing streaming logic here, because it's essentially the same, only we need some particular event from the stream - val eventStream = onSendMessageStreaming(request, ctx) - - return if (messageConfiguration?.blocking == true) { - // If blocking is requested, attempt to wait for the last event, until the current turn of the agent execution is finished. - val lastEventResponse = eventStream.last() - - when (val eventData = lastEventResponse.data) { - is Message -> Response(data = eventData, id = lastEventResponse.id) - is TaskEvent -> - taskStorage - .get( - eventData.taskId, - historyLength = messageConfiguration.historyLength, - includeArtifacts = true - ) - ?.let { Response(data = it, id = lastEventResponse.id) } - ?: throw A2ATaskNotFoundException("Task '${eventData.taskId}' not found after the agent execution") - } - } else { - // Else read the first event from the stream, check that it's a proper communication event and return it. - val firstEventResponse = eventStream.first() - - when (val eventData = firstEventResponse.data) { - is Message -> Response(data = eventData, id = firstEventResponse.id) - is Task -> Response(data = eventData, id = firstEventResponse.id) - else -> throw A2AInternalErrorException("Got unexpected event type from the agent '${eventData::class.simpleName}'") - } - } - } - - override fun onSendMessageStreaming( + /** + * Common logic for handling incoming messages and starting the agent execution. + * Does all the setup and validation, creates event stream. + * + * @return A stream of events from the agent + */ + protected fun onSendMessageCommon( request: Request, ctx: ServerCallContext ): Flow> = channelFlow { @@ -117,12 +106,14 @@ public open class A2AServer( // Check if message links to a task. val eventProcessor = if (taskId != null) { // Check if the task is still in progress, no message can be sent. - if (sessionManager.processorForTask(taskId) != null) { + if (sessionManager.sessionForTask(taskId) != null) { throw A2AUnsupportedOperationException("Task '$taskId' is still running, can't send messages to the task that has not yielded control") } // Check if the specified task exists and message context id matches the task context id. - val task = taskStorage.get(taskId) ?: throw A2ATaskNotFoundException("Task '$taskId' not found") + val task = taskStorage.get(taskId, historyLength = 0, includeArtifacts = false) + ?: throw A2ATaskNotFoundException("Task '$taskId' not found") + if (message.contextId != task.contextId) { throw A2AInvalidParamsException("Message context id '${message.contextId}' doesn't match task context id '${task.contextId}'") } @@ -143,8 +134,6 @@ public open class A2AServer( // Use specified context id or generate a new random one. coroutineScope = coroutineScope, ) - }.also { - sessionManager.addProcessor(it) } // Create request context based on the request information. @@ -156,20 +145,68 @@ public open class A2AServer( messageStorage = ContextMessageStorage(eventProcessor.contextId, messageStorage), ) + // Create agent execution session + val session = Session(coroutineScope, eventProcessor) { + agentExecutor.execute(requestContext, eventProcessor) + } + // Subscribe to events stream and start emitting them. - val collectionJob = launch { - eventProcessor.events + launch { + session.events .collect { event -> send(Response(data = event, id = request.id)) } } - // Execute the agent. - agentExecutor.execute(requestContext, eventProcessor) + // Add to session manager, it will handle monitoring and closing once the session is completed (successfully or not). + sessionManager.addSession(session) + + // Start the session to execute the agent and wait for it to finish. + session.join() + } + + override suspend fun onSendMessage( + request: Request, + ctx: ServerCallContext + ): Response { + val messageConfiguration = request.data.configuration + // Reusing streaming logic here, because it's essentially the same, only we need some particular event from the stream + val eventStream = onSendMessageCommon(request, ctx) + + return if (messageConfiguration?.blocking == true) { + // If blocking is requested, attempt to wait for the last event, until the current turn of the agent execution is finished. + val lastEventResponse = eventStream.last() + + when (val eventData = lastEventResponse.data) { + is Message -> Response(data = eventData, id = lastEventResponse.id) + is TaskEvent -> + taskStorage + .get( + eventData.taskId, + historyLength = messageConfiguration.historyLength, + includeArtifacts = true + ) + ?.let { Response(data = it, id = lastEventResponse.id) } + ?: throw A2ATaskNotFoundException("Task '${eventData.taskId}' not found after the agent execution") + } + } else { + // Else read the first event from the stream, check that it's a proper communication event and return it. + val firstEventResponse = eventStream.first() - // Close event processor session and collecting job - eventProcessor.close() - collectionJob.cancel() + when (val eventData = firstEventResponse.data) { + is Message -> Response(data = eventData, id = firstEventResponse.id) + is Task -> Response(data = eventData, id = firstEventResponse.id) + else -> throw A2AInternalErrorException("Got unexpected event type from the agent '${eventData::class.simpleName}'") + } + } + } + + override fun onSendMessageStreaming( + request: Request, + ctx: ServerCallContext + ): Flow> = flow { + checkStreamingSupport() + emitAll(onSendMessageCommon(request, ctx)) } override suspend fun onGetTask( @@ -179,7 +216,7 @@ public open class A2AServer( val taskParams = request.data return Response( - data = taskStorage.get(taskParams.id, historyLength = taskParams.historyLength) + data = taskStorage.get(taskParams.id, historyLength = taskParams.historyLength, includeArtifacts = false) ?: throw A2ATaskNotFoundException("Task '${taskParams.id}' not found"), id = request.id, ) @@ -190,14 +227,14 @@ public open class A2AServer( ctx: ServerCallContext ): Response { val taskParams = request.data - val eventProcessor = sessionManager.processorForTask(taskParams.id) + val session = sessionManager.sessionForTask(taskParams.id) // Task is not running, check if it exists in the storage. - if (eventProcessor == null) { + if (session == null) { val task = taskStorage.get(taskParams.id, historyLength = 0, includeArtifacts = true) ?: throw A2ATaskNotFoundException("Task '${taskParams.id}' not found") - // Task exists but not running - check if it is already canceled + // Task exists but not running - check if it is already canceled. if (task.status.state == TaskState.Canceled) { return Response(data = task, id = request.id) } @@ -225,15 +262,15 @@ public open class A2AServer( contextId = taskParams.id, callContext = ctx, params = request.data, - taskStorage = ContextTaskStorage(eventProcessor.contextId, taskStorage), - messageStorage = ContextMessageStorage(eventProcessor.contextId, messageStorage), + taskStorage = ContextTaskStorage(session.contextId, taskStorage), + messageStorage = ContextMessageStorage(session.contextId, messageStorage), ) // Attempt to cancel the agent execution and wait until it's finished. - agentExecutor.cancel(requestContext, eventProcessor) + agentExecutor.cancel(requestContext, session) - // If `cancel` finished without exceptions, assume the cancellation was successful and close event processor session too. - eventProcessor.close() + // If cancel finished without exception, assume the cancellation was successful and close the session explicitly. + session.close() } // Return the final task state. @@ -248,12 +285,14 @@ public open class A2AServer( request: Request, ctx: ServerCallContext ): Flow> = flow { + checkStreamingSupport() + val taskParams = request.data - val eventProcessor = sessionManager.processorForTask(taskParams.id) + val session = sessionManager.sessionForTask(taskParams.id) ?: throw A2AUnsupportedOperationException("Task '${taskParams.id}' is not currently running or does not exist") emitAll( - eventProcessor.events + session.events .map { event -> Response(data = event, id = request.id) } ) } @@ -262,27 +301,79 @@ public open class A2AServer( request: Request, ctx: ServerCallContext ): Response { - TODO("Not yet implemented") + val pushStorage = storageIfPushNotificationSupported() + val taskPushConfig = request.data + + pushStorage.save(taskPushConfig.taskId, taskPushConfig.pushNotificationConfig) + + return Response(data = taskPushConfig, id = request.id) } override suspend fun onGetTaskPushNotificationConfig( request: Request, ctx: ServerCallContext ): Response { - TODO("Not yet implemented") + val pushStorage = storageIfPushNotificationSupported() + val pushConfigParams = request.data + + val pushConfig = pushStorage.get(pushConfigParams.id, pushConfigParams.pushNotificationConfigId) + ?: throw NoSuchElementException("Can't find push notification config with id '${pushConfigParams.pushNotificationConfigId}' for task '${pushConfigParams.id}'") + + return Response( + data = TaskPushNotificationConfig( + taskId = pushConfigParams.id, + pushNotificationConfig = pushConfig + ), + id = request.id + ) } override suspend fun onListTaskPushNotificationConfig( request: Request, ctx: ServerCallContext ): Response> { - TODO("Not yet implemented") + val pushStorage = storageIfPushNotificationSupported() + val taskParams = request.data + + return Response( + data = pushStorage + .getAll(taskParams.id) + .map { TaskPushNotificationConfig(taskId = taskParams.id, pushNotificationConfig = it) }, + id = request.id + ) } override suspend fun onDeleteTaskPushNotificationConfig( request: Request, ctx: ServerCallContext ): Response { - TODO("Not yet implemented") + val pushStorage = storageIfPushNotificationSupported() + val taskPushConfigParams = request.data + + pushStorage.delete(taskPushConfigParams.id, taskPushConfigParams.pushNotificationConfigId) + + return Response(data = null, id = request.id) + } + + protected fun checkStreamingSupport() { + if (agentCard.capabilities.streaming != true) { + throw A2AUnsupportedOperationException("Streaming is not supported by the server") + } + } + + protected fun storageIfPushNotificationSupported(): PushNotificationConfigStorage { + if (agentCard.capabilities.pushNotifications != true) { + throw A2APushNotificationNotSupportedException("Push notifications are not supported by the server") + } + + if (pushConfigStorage == null) { + throw A2APushNotificationNotSupportedException("Push notifications are supported, but not configured on the server") + } + + return pushConfigStorage + } + + public fun cancel(cause: CancellationException? = null) { + coroutineScope.cancel(cause) } } diff --git a/a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/agent/AgentExecutor.kt b/a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/agent/AgentExecutor.kt index 6b60dbb07e..18b66ef4f9 100644 --- a/a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/agent/AgentExecutor.kt +++ b/a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/agent/AgentExecutor.kt @@ -1,14 +1,15 @@ package ai.koog.a2a.server.agent import ai.koog.a2a.exceptions.A2AContentTypeNotSupportedException +import ai.koog.a2a.exceptions.A2ATaskNotCancelableException import ai.koog.a2a.exceptions.A2AUnsupportedOperationException import ai.koog.a2a.model.Message import ai.koog.a2a.model.MessageSendParams import ai.koog.a2a.model.TaskEvent import ai.koog.a2a.model.TaskIdParams import ai.koog.a2a.model.TaskState -import ai.koog.a2a.model.TaskStatusUpdateEvent import ai.koog.a2a.server.session.RequestContext +import ai.koog.a2a.server.session.Session import ai.koog.a2a.server.session.SessionEventProcessor import kotlin.jvm.JvmName @@ -26,25 +27,72 @@ public interface AgentExecutor { * * Can throw an exception if the input is invalid or the agent fails to execute the request. * - * @throws Exception if something goes wrong during execution. Should prefer more specific exceptions when available, + * @param context The context containing the necessary information and accessors for executing the agent. + * @param eventProcessor The event processor to publish events to. + * @throws Exception if something goes wrong during execution. Should prefer more specific exceptions when possible, * e.g., [A2AContentTypeNotSupportedException], [A2AUnsupportedOperationException], etc. See full list of available * A2A exceptions in [ai.koog.a2a.exceptions]. */ public suspend fun execute(context: RequestContext, eventProcessor: SessionEventProcessor) /** - * Request the agent to cancel an ongoing task. + * Request to cancel an ongoing task in the running [session]. * - * The agent should attempt to stop the task identified by the task id in the context and publish a [TaskStatusUpdateEvent] with state - * [TaskState.Canceled] to the [eventProcessor]. + * The executor should attempt to stop the task identified by the task id in the [context] or throw an exception if + * cancellation is not supported or not possible, e.g. [A2ATaskNotCancelableException]. * - * Can throw an exception if the agent fails to cancel the task. + * If this method finishes normally, it will be considered successful cancellation and the [session] will be explicitly closed. + * This means the agent execution job (the code running in the [execute]) will be canceled, and + * [SessionEventProcessor] associated with this session will be closed. * - * @throws Exception if something goes wrong during execution. Should prefer more specific exceptions when available, - * e.g., [A2AContentTypeNotSupportedException], [A2AUnsupportedOperationException], etc. See full list of available - * A2A exceptions in [ai.koog.a2a.exceptions]. + * Implementations can call [Session.close] explicitly themselves if they want to stop the agent execution first and + * then perform some cleanup afterwards, e.g., closing connection to external resources. + * + * Must throw an exception if the cancellation fails or is impossible. + * + * Default implementation does nothing, meaning cancellations will always be successful and the [session] will be closed + * immediately. + * + * Example simple implementation: + * ```kotlin + * // Explicitly close the session to stop the agent execution job and event processor + * session.close() + * // Log the fact that the task was canceled + * log.info("Task '${context.taskId}' canceled") + * ``` + * + * Example more advanced implementation: + * ```kotlin + * // Cancel only the agent execution job to terminate the agent run, but keep event processor running. + * session.agentJob.cancel() + * // Send task cancellation event with custom message to event processor + * session.eventProcessor.sendTaskEvent( + * TaskStatusUpdateEvent( + * taskId = context.taskId, + * contextId = context.contextId, + * status = TaskStatus( + * state = TaskState.Canceled, + * message = Message( + * role = Role.Agent, + * taskId = context.taskId, + * contextId = context.contextId, + * parts = listOf( + * TextPart("Task was canceled by the user") + * ) + * ) + * ), + * final = true, + * ) + * ) + * // Close the session completely + * session.close() + * ``` + * + * @throws Exception if something goes wrong during execution or the cancellation is impossible. Should prefer more + * specific exceptions when available, e.g., [A2ATaskNotCancelableException], [A2AUnsupportedOperationException], etc. + * See full list of available A2A exceptions in [ai.koog.a2a.exceptions]. */ - public suspend fun cancel(context: RequestContext, eventProcessor: SessionEventProcessor) + public suspend fun cancel(context: RequestContext, session: Session) {} } /** diff --git a/a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/exceptions/Exceptions.kt b/a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/exceptions/Exceptions.kt index 957ba80034..8d1ad19e46 100644 --- a/a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/exceptions/Exceptions.kt +++ b/a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/exceptions/Exceptions.kt @@ -16,3 +16,8 @@ public class MessageOperationException(message: String, cause: Throwable? = null * Indicates a failure in sending an event through the [SessionEventProcessor] because of invalid event. */ public class InvalidEventException(message: String, cause: Throwable? = null) : Exception(message, cause) + +/** + * An exception that is thrown to indicate errors occurring during push notification operations. + */ +public class PushNotificationException(message: String, cause: Throwable? = null) : Exception(message, cause) diff --git a/a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/notifications/InMemoryPushNotificationConfigStorage.kt b/a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/notifications/InMemoryPushNotificationConfigStorage.kt new file mode 100644 index 0000000000..ca197fc5e3 --- /dev/null +++ b/a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/notifications/InMemoryPushNotificationConfigStorage.kt @@ -0,0 +1,50 @@ +package ai.koog.a2a.server.notifications + +import ai.koog.a2a.annotations.InternalA2AApi +import ai.koog.a2a.model.PushNotificationConfig +import ai.koog.a2a.utils.RWLock + +/** + * In-memory implementation of [PushNotificationConfigStorage] using a thread-safe map. + * + * This implementation stores push notification configurations in memory grouped by task ID + * and provides concurrency safety through read-write locks. + */ +@OptIn(InternalA2AApi::class) +public class InMemoryPushNotificationConfigStorage : PushNotificationConfigStorage { + private val configsByTaskId = mutableMapOf>() + private val rwLock = RWLock() + + override suspend fun save(taskId: String, pushNotificationConfig: PushNotificationConfig): Unit = + rwLock.withWriteLock { + val configId = pushNotificationConfig.id + val taskConfigs = configsByTaskId.getOrPut(taskId) { mutableMapOf() } + + taskConfigs[configId] = pushNotificationConfig + } + + override suspend fun getAll(taskId: String): List = rwLock.withReadLock { + configsByTaskId[taskId]?.values?.toList() ?: emptyList() + } + + override suspend fun get(taskId: String, configId: String?): PushNotificationConfig? = rwLock.withReadLock { + configsByTaskId[taskId]?.get(configId) + } + + override suspend fun delete(taskId: String, configId: String?): Unit = rwLock.withWriteLock { + if (configId == null) { + // Delete all configurations for the task + configsByTaskId.remove(taskId) + } else { + configsByTaskId[taskId]?.let { taskConfigs -> + // Delete specific configuration + taskConfigs.remove(configId) + + // Clean up empty task entry + if (taskConfigs.isEmpty()) { + configsByTaskId.remove(taskId) + } + } + } + } +} diff --git a/a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/notifications/PushNotificationConfigStorage.kt b/a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/notifications/PushNotificationConfigStorage.kt new file mode 100644 index 0000000000..4f5ea18053 --- /dev/null +++ b/a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/notifications/PushNotificationConfigStorage.kt @@ -0,0 +1,43 @@ +package ai.koog.a2a.server.notifications + +import ai.koog.a2a.model.PushNotificationConfig +import ai.koog.a2a.server.exceptions.PushNotificationException + +/** + * Interface for managing the storage of push notification configurations associated with task updates. + * + * Implementations must ensure concurrency safety. + */ +public interface PushNotificationConfigStorage { + /** + * Saves a push notification configuration for a specified task ID. + * + * @param taskId Task ID for which to save the configration. + * @param pushNotificationConfig Config instance containing the details of the notification setup. + * @throws PushNotificationException if the configuration cannot be saved. + */ + public suspend fun save(taskId: String, pushNotificationConfig: PushNotificationConfig) + + /** + * Retrieves a push notification configuration for a specified task ID and configuration ID. + * @param taskId Task ID for which to retrieve the configuration. + * @param configId Configuration ID for which to retrieve the configuration. + */ + public suspend fun get(taskId: String, configId: String?): PushNotificationConfig? + + /** + * Retrieves all push notification configurations associated with the given task ID. + * + * @param taskId Task ID for which to retrieve the configurations. + */ + public suspend fun getAll(taskId: String): List + + /** + * Deletes all push notification configurations for a specified task ID, optionally deleting a specific configuration instead. + * + * @param taskId Task ID for which to delete the configurations. + * @param configId Optional configuration ID to delete. Defaults to `null`, meaning all configurations for the task will be deleted. + * @throws PushNotificationException if the configuration cannot be deleted. + */ + public suspend fun delete(taskId: String, configId: String? = null) +} diff --git a/a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/notifications/PushNotificationSender.kt b/a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/notifications/PushNotificationSender.kt new file mode 100644 index 0000000000..2be5fcbd74 --- /dev/null +++ b/a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/notifications/PushNotificationSender.kt @@ -0,0 +1,26 @@ +package ai.koog.a2a.server.notifications + +import ai.koog.a2a.model.PushNotificationConfig +import ai.koog.a2a.model.Task + +/** + * Interface for sending push notifications. + * + * [More info on push notifications in specification](https://a2a-protocol.org/latest/specification/#95-push-notification-setup-and-usage) + */ +public interface PushNotificationSender { + public companion object { + /** + * Represents a custom optional HTTP header used to include a token for authenticating A2A notifications. + */ + public const val A2A_NOTIFICATION_TOKEN_HEADER: String = "X-A2A-Notification-Token" + } + + /** + * Sends a push notification. + * + * @param config Push notification configuration. + * @param task Task object to send in the notification. + */ + public suspend fun send(config: PushNotificationConfig, task: Task) +} diff --git a/a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/notifications/SimplePushNotificationSender.kt b/a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/notifications/SimplePushNotificationSender.kt new file mode 100644 index 0000000000..2593b453c9 --- /dev/null +++ b/a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/notifications/SimplePushNotificationSender.kt @@ -0,0 +1,67 @@ +package ai.koog.a2a.server.notifications + +import ai.koog.a2a.model.PushNotificationConfig +import ai.koog.a2a.model.Task +import io.github.oshai.kotlinlogging.KotlinLogging +import io.ktor.client.HttpClient +import io.ktor.client.plugins.contentnegotiation.ContentNegotiation +import io.ktor.client.request.post +import io.ktor.client.request.setBody +import io.ktor.http.HttpHeaders +import io.ktor.serialization.kotlinx.json.json +import io.ktor.utils.io.core.Closeable +import kotlinx.serialization.json.Json + +/** + * Simple implementation of a notification sender. + * Doesn't perform any configuration validation. + * Always takes the first authentication scheme provided in [PushNotificationConfig.authentication] + */ +public class SimplePushNotificationSender( + baseHttpClient: HttpClient, + json: Json = Json, +) : PushNotificationSender, Closeable { + private companion object { + private val logger = KotlinLogging.logger {} + } + + private val httpClient = baseHttpClient.config { + install(ContentNegotiation) { + json(json) + } + + expectSuccess = true + } + + override suspend fun send(config: PushNotificationConfig, task: Task) { + try { + logger.debug { "Sending push notification configId='${config.id} for taskId='${task.id}'" } + + httpClient.post(config.url) { + config.authentication?.let { auth -> + // Simple sender always takes the first scheme from the list + val schema = auth.schemes.firstOrNull() + val credentials = auth.credentials + + if (schema != null && credentials != null) { + headers[HttpHeaders.Authorization] = "$schema $credentials" + } + } + + config.token?.let { token -> + headers[PushNotificationSender.A2A_NOTIFICATION_TOKEN_HEADER] = token + } + + setBody(task) + } + + logger.debug { "Sent push notification successfully configId='${config.id} for taskId='${task.id}'" } + } catch (e: Exception) { + logger.warn(e) { "Failed to send push notification configId='${config.id} for taskId='${task.id}'" } + } + } + + override fun close() { + httpClient.close() + } +} diff --git a/a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/session/Session.kt b/a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/session/Session.kt new file mode 100644 index 0000000000..7766e5f837 --- /dev/null +++ b/a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/session/Session.kt @@ -0,0 +1,38 @@ +package ai.koog.a2a.server.session + +import ai.koog.a2a.model.Event +import kotlinx.coroutines.CoroutineScope +import kotlinx.coroutines.CoroutineStart +import kotlinx.coroutines.Job +import kotlinx.coroutines.flow.SharedFlow +import kotlinx.coroutines.launch + +public class Session( + public val eventProcessor: SessionEventProcessor, + public val agentJob: Job +) { + public val events: SharedFlow get() = eventProcessor.events + public val contextId: String get() = eventProcessor.contextId + + public fun start() { + agentJob.start() + } + + public suspend fun join() { + agentJob.join() + } + + public fun close() { + agentJob.cancel() + eventProcessor.close() + } +} + +public fun Session( + coroutineScope: CoroutineScope, + eventProcessor: SessionEventProcessor, + agentAction: suspend CoroutineScope.() -> Unit +): Session { + val agentJob = coroutineScope.launch(start = CoroutineStart.LAZY, block = agentAction) + return Session(eventProcessor, agentJob) +} diff --git a/a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/session/SessionEventProcessor.kt b/a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/session/SessionEventProcessor.kt index f419ab17fe..4352c8ca57 100644 --- a/a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/session/SessionEventProcessor.kt +++ b/a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/session/SessionEventProcessor.kt @@ -143,7 +143,7 @@ public class SessionEventProcessor( is SessionType.TaskSession -> sessionType as SessionType.TaskSession null -> { - val savedTask = taskStorage.get(event.taskId) + val savedTask = taskStorage.get(event.taskId, historyLength = 0, includeArtifacts = false) SessionType.TaskSession( taskId = event.taskId, diff --git a/a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/session/SessionManager.kt b/a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/session/SessionManager.kt index b7c3cf84df..94aed0e3bc 100644 --- a/a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/session/SessionManager.kt +++ b/a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/session/SessionManager.kt @@ -3,92 +3,114 @@ package ai.koog.a2a.server.session import ai.koog.a2a.annotations.InternalA2AApi import ai.koog.a2a.model.Message import ai.koog.a2a.model.TaskEvent +import ai.koog.a2a.server.notifications.PushNotificationConfigStorage +import ai.koog.a2a.server.notifications.PushNotificationSender +import ai.koog.a2a.server.tasks.TaskStorage import ai.koog.a2a.utils.RWLock import kotlinx.coroutines.CoroutineScope -import kotlinx.coroutines.flow.collect import kotlinx.coroutines.flow.first -import kotlinx.coroutines.flow.onCompletion import kotlinx.coroutines.launch /** - * Manages a set of active instances of [SessionEventProcessor]. + * Manages a set of active instances of [Session], sends push notifications if configured after each session completes. * - * Each added processor is monitored for task id associated with this session, if any, i.e., the session is processing a task, - * and if it is a task-related session, the processor is added to the task sessions map. + * Each session's event stream is monitored for task id associated with this session, if any, i.e., the session is processing a task, + * and if it is a task-related session, it is added to the task sessions map. * - * Automatically removes the processor when the session is closed. + * Automatically closes and removes the session when it is completed (whether successfully or not). + * + * Additionally, if push notifications are configured, after each task session completes, push notifications are sent with + * the current task state. * * @param coroutineScope The scope in which the monitoring jobs will be launched. + * @param taskStorage The storage for tasks. + * @param pushConfigStorage The storage for push notification configurations. + * @param pushSender The push notification sender. */ @OptIn(InternalA2AApi::class) public class SessionManager( private val coroutineScope: CoroutineScope, + private val taskStorage: TaskStorage, + private val pushConfigStorage: PushNotificationConfigStorage? = null, + private val pushSender: PushNotificationSender? = null, ) { - private val allProcessors = mutableSetOf() - private val taskProcessors = mutableMapOf() + private val allSessions = mutableSetOf() + private val taskSessions = mutableMapOf() private val rwLock = RWLock() /** - * Adds a session event processor to a set of active processors. - * If the first event in the processor is of type [TaskEvent], the processor is added to the task sessions map too. - * Handles cleanup by removing the processor when the session is closed. + * Adds a session to a set of active sessions. + * If the first event in the session events stream is of type [TaskEvent], the session is added to the task sessions map too. + * + * Handles cleanup by closing and removing the session when it is completed (whether successfully or not). * - * @param eventProcessor The session event processor to be added. + * @param session The session to add. */ - public fun addProcessor(eventProcessor: SessionEventProcessor) { + public fun addSession(session: Session) { coroutineScope.launch { - // Check if the first event in the session processor is task related and add this processor to the task sessions map. - when (val firstEvent = eventProcessor.events.first()) { + // Check if the first event is task related and add this session to the task sessions map. + when (val firstEvent = session.events.first()) { is TaskEvent -> { val taskId = firstEvent.taskId rwLock.withWriteLock { - check(taskId !in taskProcessors) { + check(taskId !in taskSessions) { "SessionEventProcessor for taskId '${firstEvent.taskId}' already exists." } - allProcessors += eventProcessor - taskProcessors[firstEvent.taskId] = eventProcessor + allSessions += session + taskSessions[firstEvent.taskId] = session + } + + // Wait for the session to complete, then close and remove it from collections. + session.join() + + rwLock.withWriteLock { + session.close() + allSessions -= session + taskSessions -= taskId } - // Wait for the session to close and remove the processor from collections. - eventProcessor.events - .onCompletion { - rwLock.withWriteLock { - allProcessors -= eventProcessor - taskProcessors -= taskId + // Send push notifications with the current state of the task, after the session completion, if configured. + if (pushSender != null && pushConfigStorage != null) { + val task = taskStorage.get(taskId, historyLength = 0) + + if (task != null) { + pushConfigStorage.getAll(taskId).forEach { config -> + pushSender.send(config, task) } } - .collect() + } } is Message -> { - allProcessors += eventProcessor + rwLock.withWriteLock { + allSessions += session + } - // Wait for the session to close and remove the processor from collections. - eventProcessor.events - .onCompletion { - rwLock.withWriteLock { - allProcessors -= eventProcessor - } - } - .collect() + // Wait for the session to complete, then close and remove it from collection. + session.join() + + rwLock.withWriteLock { + session.close() + allSessions -= session + } } } } } /** - * Returns the session event processor for the given task id, if any. + * Returns the session for the given task id, if any. */ - public suspend fun processorForTask(taskId: String): SessionEventProcessor? = rwLock.withReadLock { - taskProcessors[taskId] + public suspend fun sessionForTask(taskId: String): Session? = rwLock.withReadLock { + taskSessions[taskId] } /** - * Returns the number of active session event processors. + * Returns the number of active sessions. */ - public suspend fun activeProcessors(): Int = rwLock.withReadLock { - allProcessors.size + public suspend fun activeSessions(): Int = rwLock.withReadLock { + allSessions.size } } diff --git a/a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/tasks/InMemoryTaskStorage.kt b/a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/tasks/InMemoryTaskStorage.kt index 14ddd44e81..52a4ced6cb 100644 --- a/a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/tasks/InMemoryTaskStorage.kt +++ b/a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/tasks/InMemoryTaskStorage.kt @@ -7,6 +7,7 @@ import ai.koog.a2a.model.TaskEvent import ai.koog.a2a.model.TaskStatusUpdateEvent import ai.koog.a2a.server.exceptions.TaskOperationException import ai.koog.a2a.utils.RWLock +import kotlinx.serialization.json.JsonObject /** * In-memory implementation of [TaskStorage] using a thread-safe map. @@ -54,48 +55,51 @@ public class InMemoryTaskStorage : TaskStorage { override suspend fun getAll( taskIds: List, historyLength: Int?, - includeArtefacts: Boolean + includeArtifacts: Boolean ): List = rwLock.withReadLock { taskIds.mapNotNull { taskId -> - get(taskId, historyLength, includeArtefacts) + get(taskId, historyLength, includeArtifacts) } } override suspend fun getByContext( contextId: String, historyLength: Int?, - includeArtefacts: Boolean + includeArtifacts: Boolean ): List = rwLock.withReadLock { val contextTaskIds = tasksByContext[contextId] ?: emptySet() contextTaskIds.mapNotNull { taskId -> - get(taskId, historyLength, includeArtefacts) + get(taskId, historyLength, includeArtifacts) } } override suspend fun update(event: TaskEvent): Unit = rwLock.withWriteLock { when (event) { is Task -> { - // Store or replace the task val oldTask = tasks[event.id] + + if (oldTask != null && event.contextId != oldTask.contextId) { + throw TaskOperationException("Cannot change context for existing task: ${event.id}") + } + + // Store or replace the task tasks[event.id] = event // Update context index tasksByContext.getOrPut(event.contextId) { mutableSetOf() }.add(event.id) - - // Remove from old context if it changed - if (oldTask != null && oldTask.contextId != event.contextId) { - tasksByContext[oldTask.contextId]?.remove(event.id) - if (tasksByContext[oldTask.contextId]?.isEmpty() == true) { - tasksByContext.remove(oldTask.contextId) - } - } } is TaskStatusUpdateEvent -> { val existingTask = tasks[event.taskId] ?: throw TaskOperationException("Cannot update status for non-existing task: ${event.taskId}") - val updatedTask = existingTask.copy(status = event.status) + val updatedTask = existingTask.copy( + status = event.status, + metadata = existingTask.metadata + ?.let { JsonObject(it + event.metadata.orEmpty()) } + ?: event.metadata + ) + tasks[event.taskId] = updatedTask } @@ -118,7 +122,13 @@ public class InMemoryTaskStorage : TaskStorage { currentArtifacts.add(event.artifact) } - val updatedTask = existingTask.copy(artifacts = currentArtifacts) + val updatedTask = existingTask.copy( + artifacts = currentArtifacts, + metadata = existingTask.metadata + ?.let { JsonObject(it + event.metadata.orEmpty()) } + ?: event.metadata + ) + tasks[event.taskId] = updatedTask } } diff --git a/a2a/a2a-server/src/commonTest/kotlin/ai/koog/a2a/server/notifications/InMemoryPushNotificationConfigStorageTest.kt b/a2a/a2a-server/src/commonTest/kotlin/ai/koog/a2a/server/notifications/InMemoryPushNotificationConfigStorageTest.kt new file mode 100644 index 0000000000..38e3ac8abf --- /dev/null +++ b/a2a/a2a-server/src/commonTest/kotlin/ai/koog/a2a/server/notifications/InMemoryPushNotificationConfigStorageTest.kt @@ -0,0 +1,92 @@ +package ai.koog.a2a.server.notifications + +import ai.koog.a2a.model.PushNotificationConfig +import kotlinx.coroutines.test.runTest +import kotlin.test.BeforeTest +import kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.assertTrue + +class InMemoryPushNotificationConfigStorageTest { + private lateinit var storage: InMemoryPushNotificationConfigStorage + + @BeforeTest + fun setUp() { + storage = InMemoryPushNotificationConfigStorage() + } + + @Test + fun testSaveMultipleConfigsWithMixedIds() = runTest { + val configWithId = PushNotificationConfig( + id = "config-1", + url = "https://webhook1.example.com" + ) + val configWithoutId = PushNotificationConfig( + id = null, + url = "https://webhook2.example.com" + ) + + storage.save("task-1", configWithId) + storage.save("task-1", configWithoutId) + + val retrieved = storage.getAll("task-1") + assertEquals(setOf(configWithId, configWithoutId), retrieved.toSet()) + } + + @Test + fun testOverwriteExistingConfig() = runTest { + val originalConfig = PushNotificationConfig( + id = "config-1", + url = "https://webhook1.example.com" + ) + val updatedConfig = PushNotificationConfig( + id = "config-1", + url = "https://webhook-updated.example.com" + ) + + storage.save("task-1", originalConfig) + storage.save("task-1", updatedConfig) + + val retrieved = storage.getAll("task-1") + assertEquals(setOf(updatedConfig), retrieved.toSet()) + } + + @Test + fun testDeleteAllConfigsForTask() = runTest { + val config = PushNotificationConfig( + id = "config-1", + url = "https://webhook.example.com" + ) + + storage.save("task-1", config) + storage.delete("task-1", null) + + val remaining = storage.getAll("task-1") + assertTrue(remaining.isEmpty()) + } + + @Test + fun testDeleteSpecificConfig() = runTest { + val config1 = PushNotificationConfig( + id = "config-1", + url = "https://webhook1.example.com" + ) + val config2 = PushNotificationConfig( + id = "config-2", + url = "https://webhook2.example.com" + ) + + storage.save("task-1", config1) + storage.save("task-1", config2) + storage.delete("task-1", "config-1") + + val remaining = storage.getAll("task-1") + assertEquals(setOf(config2), remaining.toSet()) + } + + @Test + fun testGetAllForNonExistentTask() = runTest { + val configs = storage.getAll("non-existent-task") + assertTrue(configs.isEmpty()) + } +} diff --git a/a2a/a2a-server/src/commonTest/kotlin/ai/koog/a2a/server/tasks/InMemoryTaskStorageTest.kt b/a2a/a2a-server/src/commonTest/kotlin/ai/koog/a2a/server/tasks/InMemoryTaskStorageTest.kt index 5cdf815282..93279e406e 100644 --- a/a2a/a2a-server/src/commonTest/kotlin/ai/koog/a2a/server/tasks/InMemoryTaskStorageTest.kt +++ b/a2a/a2a-server/src/commonTest/kotlin/ai/koog/a2a/server/tasks/InMemoryTaskStorageTest.kt @@ -12,6 +12,9 @@ import ai.koog.a2a.model.TextPart import ai.koog.a2a.server.exceptions.TaskOperationException import kotlinx.coroutines.test.runTest import kotlinx.datetime.Instant +import kotlinx.serialization.json.JsonObject +import kotlinx.serialization.json.JsonPrimitive +import kotlinx.serialization.json.buildJsonObject import kotlin.test.BeforeTest import kotlin.test.Test import kotlin.test.assertEquals @@ -78,11 +81,19 @@ class InMemoryTaskStorageTest { @Test fun testTaskStatusUpdateEvent() = runTest { - // Create and store initial task - val task = createTask(id = "task-1", contextId = "context-1") + // Create and store initial task with metadata + val initialMetadata = buildJsonObject { + put("initialKey", JsonPrimitive("initialValue")) + put("sharedKey", JsonPrimitive("originalValue")) + } + val task = createTask(id = "task-1", contextId = "context-1", metadata = initialMetadata) storage.update(task) - // Create a status update event + // Create a status update event with additional metadata + val updateMetadata = buildJsonObject { + put("newKey", JsonPrimitive("newValue")) + put("sharedKey", JsonPrimitive("updatedValue")) + } val newMessage = createUserMessage("status-msg", "context-1", "Task completed successfully") val newStatus = TaskStatus( state = TaskState.Completed, @@ -93,15 +104,24 @@ class InMemoryTaskStorageTest { taskId = "task-1", contextId = "context-1", status = newStatus, + metadata = updateMetadata, final = true ) // Update task status storage.update(statusUpdateEvent) - // Verify the status was updated + // Verify the status was updated and metadata was merged val retrieved = storage.get("task-1") assertEquals(newStatus, retrieved?.status) + + // Verify metadata merging: original + new with updates overriding + val expectedMetadata = buildJsonObject { + put("initialKey", JsonPrimitive("initialValue")) // preserved from original + put("sharedKey", JsonPrimitive("updatedValue")) // updated from event + put("newKey", JsonPrimitive("newValue")) // added from event + } + assertEquals(expectedMetadata, retrieved?.metadata) } @Test @@ -239,7 +259,8 @@ class InMemoryTaskStorageTest { id: String, contextId: String, history: List? = null, - artifacts: List? = null + artifacts: List? = null, + metadata: JsonObject? = null ) = Task( id = id, contextId = contextId, @@ -248,6 +269,7 @@ class InMemoryTaskStorageTest { timestamp = Instant.parse("2023-01-01T10:00:00Z") ), history = history, - artifacts = artifacts + artifacts = artifacts, + metadata = metadata ) } diff --git a/a2a/a2a-transport/a2a-transport-server-jsonrpc-http/build.gradle.kts b/a2a/a2a-transport/a2a-transport-server-jsonrpc-http/build.gradle.kts index 2bcaef7fc4..592b576661 100644 --- a/a2a/a2a-transport/a2a-transport-server-jsonrpc-http/build.gradle.kts +++ b/a2a/a2a-transport/a2a-transport-server-jsonrpc-http/build.gradle.kts @@ -19,18 +19,12 @@ kotlin { commonMain { dependencies { api(project(":a2a:a2a-transport:a2a-transport-core-jsonrpc")) - api(libs.kotlinx.serialization.json) api(libs.kotlinx.coroutines.core) - } - } - - jvmMain { - dependencies { + api(libs.kotlinx.serialization.json) api(libs.ktor.server.core) - implementation(libs.ktor.server.sse) - implementation(libs.ktor.server.content.negotiation) implementation(libs.ktor.serialization.kotlinx.json) - implementation(libs.ktor.server.netty) + implementation(libs.ktor.server.content.negotiation) + implementation(libs.ktor.server.sse) } } diff --git a/a2a/a2a-transport/a2a-transport-server-jsonrpc-http/src/jvmMain/kotlin/ai/koog/a2a/transport/server/jsonrpc/http/HttpJSONRPCServerTransport.kt b/a2a/a2a-transport/a2a-transport-server-jsonrpc-http/src/commonMain/kotlin/ai/koog/a2a/transport/server/jsonrpc/http/HttpJSONRPCServerTransport.kt similarity index 94% rename from a2a/a2a-transport/a2a-transport-server-jsonrpc-http/src/jvmMain/kotlin/ai/koog/a2a/transport/server/jsonrpc/http/HttpJSONRPCServerTransport.kt rename to a2a/a2a-transport/a2a-transport-server-jsonrpc-http/src/commonMain/kotlin/ai/koog/a2a/transport/server/jsonrpc/http/HttpJSONRPCServerTransport.kt index 00004fff26..9b9ceb6eb3 100644 --- a/a2a/a2a-transport/a2a-transport-server-jsonrpc-http/src/jvmMain/kotlin/ai/koog/a2a/transport/server/jsonrpc/http/HttpJSONRPCServerTransport.kt +++ b/a2a/a2a-transport/a2a-transport-server-jsonrpc-http/src/commonMain/kotlin/ai/koog/a2a/transport/server/jsonrpc/http/HttpJSONRPCServerTransport.kt @@ -15,9 +15,10 @@ import io.ktor.serialization.kotlinx.json.json import io.ktor.server.application.ApplicationCall import io.ktor.server.application.install import io.ktor.server.application.pluginOrNull +import io.ktor.server.engine.ApplicationEngine +import io.ktor.server.engine.ApplicationEngineFactory import io.ktor.server.engine.EmbeddedServer import io.ktor.server.engine.embeddedServer -import io.ktor.server.netty.Netty import io.ktor.server.plugins.contentnegotiation.ContentNegotiation import io.ktor.server.request.receiveText import io.ktor.server.response.respond @@ -48,7 +49,7 @@ import kotlinx.serialization.serializer * requestHandler = A2AServer(...) * ) * - * transport.start(port = 8080, path = "/my-agent", agentCard = AgentCard(...), agentCardPath = "/my-agent-card.json") + * transport.start(Netty, 8080, "/my-agent", agentCard = AgentCard(...), agentCardPath = "/my-agent-card.json") * transport.stop() * ``` * @@ -100,6 +101,7 @@ public class HttpJSONRPCServerTransport( * If you need to integrate A2A request handling logic into existing Ktor application, * use [transportRoutes] method to mount the transport routes into existing [Route] configuration block. * + * @param engineFactory An application engine factory. * @param port A port on which the server will listen. * @param path A JSON-RPC endpoint path to handle incoming requests. * @param agentCard An optional [AgentCard] that will be served at the specified [agentCardPath]. @@ -110,7 +112,8 @@ public class HttpJSONRPCServerTransport( * * @see [transportRoutes] */ - public suspend fun start( + public suspend fun start( + engineFactory: ApplicationEngineFactory, port: Int, path: String, agentCard: AgentCard? = null, @@ -118,7 +121,7 @@ public class HttpJSONRPCServerTransport( ): Unit = serverMutex.withLock { check(server == null) { "Server is already configured and running. Stop it before starting a new one." } - server = embeddedServer(Netty, port) { + server = embeddedServer(engineFactory, port) { install(SSE) routing { From 69259af12fb6107399b2b0ad053e14431613a3f9 Mon Sep 17 00:00:00 2001 From: Andrey Bragin Date: Sat, 20 Sep 2025 22:21:44 +0200 Subject: [PATCH 28/52] [a2a] Add comprehensive documentation for server components --- a2a/CLAUDE.md | 10 + .../ai/koog/a2a/transport/ServerTransport.kt | 44 ++- .../kotlin/ai/koog/a2a/server/A2AServer.kt | 263 +++++++++++++++++- .../ai/koog/a2a/server/session/Session.kt | 25 ++ .../http/HttpJSONRPCServerTransport.kt | 4 +- 5 files changed, 339 insertions(+), 7 deletions(-) diff --git a/a2a/CLAUDE.md b/a2a/CLAUDE.md index 154c7311c0..2f564e43d0 100644 --- a/a2a/CLAUDE.md +++ b/a2a/CLAUDE.md @@ -94,6 +94,16 @@ The A2A (Agent-to-Agent) module is a **meta-module** within the larger Koog proj - Propagate errors properly in async contexts - Include contextual information in exception messages +#### KDoc Documentation +- **Placement**: KDoc directly above declarations (classes, functions, properties) +- **Constructor properties**: Document using `@param` tags in class KDoc +- **Public class properties**: Document using `@property` tags in class KDoc +- **Cross-references**: Use `[ClassName]`, `[ClassName.propertyName]` syntax for linking components +- **Examples**: Include practical code examples for complex APIs +- **Validation rules**: Document constraints with bullet points and clear explanations +- **Exception documentation**: Use `@throws` with specific conditions +- **Required**: All public APIs must have KDoc (enforced by `explicitApi()`) + ## Testing Requirements ### Mandatory Test Execution diff --git a/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/transport/ServerTransport.kt b/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/transport/ServerTransport.kt index 4de35fa21a..416b0cd80f 100644 --- a/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/transport/ServerTransport.kt +++ b/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/transport/ServerTransport.kt @@ -140,6 +140,39 @@ public interface RequestHandler { /** * Represents the server context of a call. * + * This context has [state] associated with it, which is essentially an untyped map. It can be used to store arbitrary + * user-defined data. This is useful for extending the base logic with business-dependent logic, e.g., storing user + * information to authorize particular requests. This untyped [state] map has typed accessors for more convenient access, + * so it is recommended to use them when reading from state: [getFromState], [getFromStateOrNull]. + * + * **Note**: Make sure the types of [StateKey] and the value match when populating [state], otherwise [getFromState] + * and [getFromStateOrNull] will throw [IllegalStateException]. + * + * Example usage: + * ```kotlin + * // User-defined data class + * data class User(val id: String) + * + * // Collection of user-defined state keys + * object StateKeys { + * val USER_KEY = StateKey("42") + * } + * + * // On the handler side - copying supplied context and populating state + * override suspend fun onSendMessage( + * request: Request, + * ctx: ServerCallContext + * ): Response { + * val user = ctx.headers.getValue("user-id").let { User(it) } + * val newCtx = ctx.copy(state = ctx.state + (StateKeys.USER_KEY to user)) + * + * super.onSendMessage(request, newCtx) + * } + * + * // On the business logic side - retrieving user data from context + * val user = ctx.getFromState(StateKeys.USER_KEY) + * ``` + * * @property headers Headers associated with the call. * @property state State associated with the call, allows storing arbitrary values. To get typed value from the state, * use [getFromState] or [getFromStateOrNull] with appropriate [StateKey]. @@ -156,10 +189,9 @@ public class ServerCallContext( * * @param key The state key for which the associated value needs to be retrieved. */ - public fun getFromStateOrNull(key: StateKey): T? { + public inline fun getFromStateOrNull(key: StateKey): T? { return state[key]?.let { - @Suppress("UNCHECKED_CAST") - it as T + it as? T ?: throw IllegalStateException("State value for key $key is not of expected type ${T::class}") } } @@ -171,8 +203,8 @@ public class ServerCallContext( * @param key The state key for which the associated value needs to be retrieved. * @throws NoSuchElementException if the [key] is not found in the state. */ - public fun getFromState(key: StateKey): T { - return getFromStateOrNull(key) ?: throw NoSuchElementException("State key $key not found") + public inline fun getFromState(key: StateKey): T { + return getFromStateOrNull(key) ?: throw NoSuchElementException("State key $key not found or null") } /** @@ -187,6 +219,8 @@ public class ServerCallContext( /** * Helper class to be used with [ServerCallContext.state] to store and retrieve values associated with a key in a typed * manner. + * + * @see ServerCallContext */ public class StateKey<@Suppress("unused") T>(public val name: String) { override fun toString(): String = "${super.toString()}(name=$name)" diff --git a/a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/A2AServer.kt b/a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/A2AServer.kt index 5c04be2341..a59adacbe0 100644 --- a/a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/A2AServer.kt +++ b/a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/A2AServer.kt @@ -54,7 +54,263 @@ import kotlin.uuid.ExperimentalUuidApi import kotlin.uuid.Uuid /** - * A2A server responsible for handling requests from A2A clients. + * Default implementation of A2A server responsible for handling requests from A2A clients according to the + * [A2A protocol specification](https://a2a-protocol.org/latest/specification/). + * + * This class provides a complete implementation of all A2A protocol methods including message sending, task management, + * and push notifications. However, it **does not** provide any authorization, authentication, or custom validation + * logic. For production use, you should extend this class and add your own security and business logic. + * + * The A2AServer orchestrates the interaction between transport layer, agent executor, and storage components: + * - Receives requests from [RequestHandler] interface methods + * - Delegates actual agent logic to [AgentExecutor] + * - Delegates event processing and persisting to [SessionEventProcessor] + * - Delegates session management to [SessionManager] + * - Handles push notifications using [PushNotificationSender] + * + * ## Production usage with authorization + * + * For production deployments, extend this class to add authorization and custom validation. You can leverage + * [ServerCallContext.state] to pass user-defined data through the request pipeline to the [AgentExecutor]: + * + * ```kotlin + * // Define your user data and state keys + * data class AuthenticatedUser(val id: String, val permissions: Set) + * + * object AuthStateKeys { + * val USER = StateKey("authenticated_user") + * } + * + * // Extend A2AServer with authorization + * class AuthorizedA2AServer( + * agentExecutor: AgentExecutor, + * agentCard: AgentCard, + * agentCardExtended: AgentCard? = null, + * taskStorage: TaskStorage = InMemoryTaskStorage(), + * messageStorage: MessageStorage = InMemoryMessageStorage(), + * private val authService: AuthService, // Your auth service + * ) : A2AServer( + * agentExecutor = agentExecutor, + * agentCard = agentCard, + * agentCardExtended = agentCardExtended, + * taskStorage = taskStorage, + * messageStorage = messageStorage, + * ) { + * // Helper method for common auth pattern + * private suspend fun authenticateAndAuthorize( + * ctx: ServerCallContext, + * requiredPermission: String + * ): AuthenticatedUser { + * val token = ctx.headers["Authorization"]?.firstOrNull() + * ?: throw A2AInvalidParamsException("Missing authorization token") + * + * val user = authService.authenticate(token) + * ?: throw A2AInvalidParamsException("Invalid token") + * + * if (!user.permissions.contains(requiredPermission)) { + * throw A2AUnsupportedOperationException("Insufficient permissions") + * } + * + * return user + * } + * + * override suspend fun onSendMessage( + * request: Request, + * ctx: ServerCallContext + * ): Response { + * val user = authenticateAndAuthorize(ctx, requiredPermission = "send_message") + * + * // Pass user data to the agent executor via context state + * val enrichedCtx = ctx.copy( + * state = ctx.state + (AuthStateKeys.USER to user) + * ) + * + * // Delegate to parent implementation with enriched context + * return super.onSendMessage(request, enrichedCtx) + * } + * + * override suspend fun onGetTask( + * request: Request, + * ctx: ServerCallContext + * ): Response { + * val user = authenticateAndAuthorize(ctx, requiredPermission = "read_task") + * + * // Optionally validate task ownership + * val task = taskStorage.get(request.data.id, historyLength = 0, includeArtifacts = false) + * if (task?.metadata?.get("owner_id") != user.id) { + * throw A2AUnsupportedOperationException("Access denied to task ${request.data.id}") + * } + * + * val enrichedCtx = ctx.copy( + * state = ctx.state + (AuthStateKeys.USER to user) + * ) + * + * return super.onGetTask(request, enrichedCtx) + * } + * } + * ``` + * + * ## Accessing user data in AgentExecutor + * + * The authenticated user data passed through [ServerCallContext.state] can be accessed in your [AgentExecutor]: + * + * ```kotlin + * class MyAgentExecutor : AgentExecutor { + * override suspend fun execute( + * context: RequestContext, + * eventProcessor: SessionEventProcessor + * ) { + * // Retrieve authenticated user from the context + * val user = context.callContext.getFromState(AuthStateKeys.USER) + * + * // Use user information for personalized agent behavior + * eventProcessor.sendMessage( + * Message( + * role = Role.Agent, + * contextId = context.contextId, + * parts = listOf( + * TextPart("Hello ${user.id}, how can I help you today?") + * ) + * ) + * ) + * } + * + * override suspend fun cancel( + * context: RequestContext, + * session: Session + * ) { + * // Access user data for audit logging + * val user = context.callContext.getFromStateOrNull(AuthStateKeys.USER) + * log.info("Task ${context.taskId} canceled by user ${user?.id}") + * + * // Default cancellation behavior + * super.cancel(context, session) + * } + * } + * ``` + * + * ## Complete server setup example + * + * Here's a complete example of setting up and running an A2A server from scratch: + * + * ```kotlin + * // 1. Create your agent executor with business logic + * val agentExecutor = object : AgentExecutor { + * override suspend fun execute( + * context: RequestContext, + * eventProcessor: SessionEventProcessor + * ) { + * val userMessage = context.params.message + * + * // Process the message and create a task + * val task = Task( + * contextId = context.contextId, + * status = TaskStatus( + * state = TaskState.Working, + * // Mark this message as belonging to the created task + * message = message.copy(taskId = task.id) + * timestamp = Clock.System.now() + * ), + * ) + * + * // Send task creation event + * eventProcessor.sendTaskEvent(task) + * + * // Simulate some work + * delay(1000) + * + * // Mark task as completed + * eventProcessor.sendTaskEvent( + * TaskStatusUpdateEvent( + * taskId = task.id, + * contextId = task.contextId, + * status = TaskStatus( + * state = TaskState.Completed, + * message = Message( + * role = Role.Agent, + * contextId = context.contextId, + * taskId = task.id, + * parts = listOf( + * TextPart("Task completed successfully!") + * ) + * ), + * timestamp = Clock.System.now() + * ), + * final = true + * ) + * ) + * } + * } + * + * // 2. Define your agent card describing capabilities + * val agentCard = AgentCard(...) + * + * // 3. Create the A2AServer instance (or your extended version) + * val a2aServer = A2AServer(...) + * + * // 4. Create HTTP JSON-RPC transport + * val transport = HttpJSONRPCServerTransport( + * requestHandler = a2aServer + * ) + * + * // 5. Start the server + * transport.start( + * engineFactory = Netty, + * port = 8080, + * path = "/a2a", + * wait = true, + * agentCard = agentCard, + * agentCardPath = "/.well-known/a2a/agent-card.json" + * ) + * ``` + * + * ## Integration with existing Ktor application + * + * If you have an existing Ktor application, you can integrate the A2A server as a route: + * + * ```kotlin + * val agentCard = AgentCard(...) + * val a2aServer = A2AServer(...) + * val transport = HttpJSONRPCServerTransport(a2aServer) + * + * embeddedServer(Netty, port = 8080) { + * install(SSE) // Required for streaming support + * + * // To serve AgentCard instance + * install(ContentNegotiation) { + * json(Json) + * } + * + * routing { + * // Your existing routes... + * + * // Mount A2A JSON-RPC server transport + * transport.transportRoutes(this, "/a2a") + * + * // Serve agent card + * get("/a2a/agent-card.json") { + * call.respond(agentCard) + * } + * } + * }.start(wait = true) + * ``` + * + * @param agentExecutor The executor containing the core agent logic + * @param agentCard The agent card describing this agent's capabilities and metadata + * @param agentCardExtended Optional extended agent card for authenticated requests + * @param taskStorage Storage implementation for persisting tasks (defaults to in-memory) + * @param messageStorage Storage implementation for persisting messages (defaults to in-memory) + * @param pushConfigStorage Optional storage for push notification configurations + * @param pushSender Optional push notification sender implementation + * @param coroutineScope Scope for managing all sessions, agent jobs, event processing, etc. + * @param clock Clock instance for timestamp generation (defaults to [Clock.System]) + * + * @see AgentExecutor for implementing agent business logic + * @see TaskStorage for persisting tasks + * @see MessageStorage for persisting messages + * @see PushNotificationConfigStorage for persisting push notification configurations + * @see PushNotificationSender for sending push notifications + * @see ServerCallContext for passing custom state through the request pipeline */ public open class A2AServer( protected val agentExecutor: AgentExecutor, @@ -373,6 +629,11 @@ public open class A2AServer( return pushConfigStorage } + /** + * Cancels [coroutineScope] associated with this server, essentially cancelling all running jobs and sessions. + * + * @param cause Optional cause of the cancellation + */ public fun cancel(cause: CancellationException? = null) { coroutineScope.cancel(cause) } diff --git a/a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/session/Session.kt b/a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/session/Session.kt index 7766e5f837..eab861892b 100644 --- a/a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/session/Session.kt +++ b/a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/session/Session.kt @@ -7,6 +7,14 @@ import kotlinx.coroutines.Job import kotlinx.coroutines.flow.SharedFlow import kotlinx.coroutines.launch +/** + * Represents an active agent execution session with lifecycle management. + * + * @property eventProcessor Handles session events and provides event streaming + * @property agentJob The coroutine job executing the agent logic + * @property events A stream of events generated during this session + * @property contextId Unique context identifier for this session + */ public class Session( public val eventProcessor: SessionEventProcessor, public val agentJob: Job @@ -14,20 +22,37 @@ public class Session( public val events: SharedFlow get() = eventProcessor.events public val contextId: String get() = eventProcessor.contextId + /** + * Starts the agent execution job. + */ public fun start() { agentJob.start() } + /* + * Suspends until the agent job completes + */ public suspend fun join() { agentJob.join() } + /** + * Cancels the agent job and closes the event processor + */ public fun close() { agentJob.cancel() eventProcessor.close() } } +/** + * Creates a new [Session] with lazy-started agent execution. + * + * @param coroutineScope The scope for launching the agent coroutine + * @param eventProcessor The session event processor + * @param agentAction The agent logic to execute + * @return A new session instance + */ public fun Session( coroutineScope: CoroutineScope, eventProcessor: SessionEventProcessor, diff --git a/a2a/a2a-transport/a2a-transport-server-jsonrpc-http/src/commonMain/kotlin/ai/koog/a2a/transport/server/jsonrpc/http/HttpJSONRPCServerTransport.kt b/a2a/a2a-transport/a2a-transport-server-jsonrpc-http/src/commonMain/kotlin/ai/koog/a2a/transport/server/jsonrpc/http/HttpJSONRPCServerTransport.kt index 9b9ceb6eb3..74fc961a43 100644 --- a/a2a/a2a-transport/a2a-transport-server-jsonrpc-http/src/commonMain/kotlin/ai/koog/a2a/transport/server/jsonrpc/http/HttpJSONRPCServerTransport.kt +++ b/a2a/a2a-transport/a2a-transport-server-jsonrpc-http/src/commonMain/kotlin/ai/koog/a2a/transport/server/jsonrpc/http/HttpJSONRPCServerTransport.kt @@ -104,6 +104,7 @@ public class HttpJSONRPCServerTransport( * @param engineFactory An application engine factory. * @param port A port on which the server will listen. * @param path A JSON-RPC endpoint path to handle incoming requests. + * @param wait If true, the server will block until it is stopped. Defaults to false. * @param agentCard An optional [AgentCard] that will be served at the specified [agentCardPath]. * @param agentCardPath The path at which the [agentCard] will be served, if specified. * Defaults to [A2AConsts.AGENT_CARD_WELL_KNOWN_PATH]. @@ -116,6 +117,7 @@ public class HttpJSONRPCServerTransport( engineFactory: ApplicationEngineFactory, port: Int, path: String, + wait: Boolean = false, agentCard: AgentCard? = null, agentCardPath: String = A2AConsts.AGENT_CARD_WELL_KNOWN_PATH, ): Unit = serverMutex.withLock { @@ -133,7 +135,7 @@ public class HttpJSONRPCServerTransport( } } } - }.startSuspend(wait = false) + }.startSuspend(wait = wait) } /** From 51f85c779d751df712f2f868fd5208ab1c420830 Mon Sep 17 00:00:00 2001 From: Andrey Bragin Date: Mon, 29 Sep 2025 05:21:50 +0200 Subject: [PATCH 29/52] [a2a] Introduce a2a-test module, create BaseA2AProtocolTest --- a2a/a2a-client/build.gradle.kts | 1 + .../a2a/client/A2AClientIntegrationTest.kt | 420 +---------------- a2a/a2a-test/build.gradle.kts | 34 ++ .../ai/koog/a2a/test/BaseA2AProtocolTest.kt | 436 ++++++++++++++++++ gradle/libs.versions.toml | 1 - koog-agents/build.gradle.kts | 1 + settings.gradle.kts | 1 + 7 files changed, 476 insertions(+), 418 deletions(-) create mode 100644 a2a/a2a-test/build.gradle.kts create mode 100644 a2a/a2a-test/src/commonMain/kotlin/ai/koog/a2a/test/BaseA2AProtocolTest.kt diff --git a/a2a/a2a-client/build.gradle.kts b/a2a/a2a-client/build.gradle.kts index 217c0385e3..818b07d2cd 100644 --- a/a2a/a2a-client/build.gradle.kts +++ b/a2a/a2a-client/build.gradle.kts @@ -38,6 +38,7 @@ kotlin { jvmTest { dependencies { implementation(kotlin("test-junit5")) + implementation(project(":a2a:a2a-test")) implementation(project(":a2a:a2a-transport:a2a-transport-client-jsonrpc-http")) implementation(libs.ktor.client.cio) diff --git a/a2a/a2a-client/src/jvmTest/kotlin/ai/koog/a2a/client/A2AClientIntegrationTest.kt b/a2a/a2a-client/src/jvmTest/kotlin/ai/koog/a2a/client/A2AClientIntegrationTest.kt index fe37a3aa2a..214ee2f865 100644 --- a/a2a/a2a-client/src/jvmTest/kotlin/ai/koog/a2a/client/A2AClientIntegrationTest.kt +++ b/a2a/a2a-client/src/jvmTest/kotlin/ai/koog/a2a/client/A2AClientIntegrationTest.kt @@ -1,49 +1,19 @@ package ai.koog.a2a.client -import ai.koog.a2a.exceptions.A2AInternalErrorException -import ai.koog.a2a.model.AgentCapabilities -import ai.koog.a2a.model.AgentCard -import ai.koog.a2a.model.AgentSkill -import ai.koog.a2a.model.Message -import ai.koog.a2a.model.MessageSendConfiguration -import ai.koog.a2a.model.MessageSendParams -import ai.koog.a2a.model.PushNotificationAuthenticationInfo -import ai.koog.a2a.model.PushNotificationConfig -import ai.koog.a2a.model.Role -import ai.koog.a2a.model.Task -import ai.koog.a2a.model.TaskIdParams -import ai.koog.a2a.model.TaskPushNotificationConfig -import ai.koog.a2a.model.TaskPushNotificationConfigParams -import ai.koog.a2a.model.TaskQueryParams -import ai.koog.a2a.model.TaskState -import ai.koog.a2a.model.TaskStatusUpdateEvent -import ai.koog.a2a.model.TextPart -import ai.koog.a2a.model.TransportProtocol -import ai.koog.a2a.transport.Request +import ai.koog.a2a.test.BaseA2AProtocolTest import ai.koog.a2a.transport.client.jsonrpc.http.HttpJSONRPCClientTransport -import io.kotest.assertions.throwables.shouldThrowExactly -import io.kotest.inspectors.shouldForAll -import io.kotest.matchers.collections.shouldHaveSize -import io.kotest.matchers.collections.shouldNotBeEmpty -import io.kotest.matchers.nulls.shouldNotBeNull -import io.kotest.matchers.should -import io.kotest.matchers.shouldBe -import io.kotest.matchers.string.shouldStartWith -import io.kotest.matchers.types.shouldBeInstanceOf import io.ktor.client.HttpClient import io.ktor.client.plugins.logging.LogLevel import io.ktor.client.plugins.logging.Logging -import kotlinx.coroutines.flow.toList import kotlinx.coroutines.test.runTest import org.testcontainers.containers.GenericContainer import org.testcontainers.containers.wait.strategy.Wait import org.testcontainers.junit.jupiter.Container import org.testcontainers.junit.jupiter.Testcontainers import kotlin.test.BeforeTest -import kotlin.test.Test @Testcontainers -class A2AClientIntegrationTest { +class A2AClientIntegrationTest : BaseA2AProtocolTest() { companion object { @Container val testA2AServer: GenericContainer<*> = @@ -68,7 +38,7 @@ class A2AClientIntegrationTest { ) } - private val client by lazy { + override val client by lazy { A2AClient( transport = transport, agentCardResolver = UrlAgentCardResolver( @@ -82,388 +52,4 @@ class A2AClientIntegrationTest { fun initClient() = runTest { client.connect() } - - @Test - fun `test get agent card`() = runTest { - val agentCard = client.getAgentCard() - - // Assert on the full AgentCard structure - val expectedAgentCard = AgentCard( - protocolVersion = "0.3.0", - name = "Hello World Agent", - description = "Just a hello world agent", - url = "http://localhost:9999/", - preferredTransport = TransportProtocol.JSONRPC, - additionalInterfaces = null, - iconUrl = null, - provider = null, - version = "1.0.0", - documentationUrl = null, - capabilities = AgentCapabilities( - streaming = true, - pushNotifications = true, - stateTransitionHistory = null, - extensions = null - ), - securitySchemes = null, - security = null, - defaultInputModes = listOf("text"), - defaultOutputModes = listOf("text"), - skills = listOf( - AgentSkill( - id = "hello_world", - name = "Returns hello world", - description = "just returns hello world", - tags = listOf("hello world"), - examples = listOf("hi", "hello world"), - inputModes = null, - outputModes = null, - security = null - ) - ), - supportsAuthenticatedExtendedCard = true, - signatures = null - ) - - agentCard shouldBe expectedAgentCard - } - - @Test - fun `test get authenticated extended agent card`() = runTest { - val request = Request(data = null) - - val response = client.getAuthenticatedExtendedAgentCard(request) - - // Assert on the extended agent card structure - val expectedExtendedAgentCard = AgentCard( - protocolVersion = "0.3.0", - name = "Hello World Agent - Extended Edition", - description = "The full-featured hello world agent for authenticated users.", - url = "http://localhost:9999/", - preferredTransport = TransportProtocol.JSONRPC, - additionalInterfaces = null, - iconUrl = null, - provider = null, - version = "1.0.1", - documentationUrl = null, - capabilities = AgentCapabilities( - streaming = true, - pushNotifications = true, - stateTransitionHistory = null, - extensions = null - ), - securitySchemes = null, - security = null, - defaultInputModes = listOf("text"), - defaultOutputModes = listOf("text"), - skills = listOf( - AgentSkill( - id = "hello_world", - name = "Returns hello world", - description = "just returns hello world", - tags = listOf("hello world"), - examples = listOf("hi", "hello world"), - inputModes = null, - outputModes = null, - security = null - ), - AgentSkill( - id = "super_hello_world", - name = "Returns a SUPER Hello World", - description = "A more enthusiastic greeting, only for authenticated users.", - tags = listOf("hello world", "super", "extended"), - examples = listOf("super hi", "give me a super hello"), - inputModes = null, - outputModes = null, - security = null - ) - ), - supportsAuthenticatedExtendedCard = true, - signatures = null - ) - - response.data shouldBe expectedExtendedAgentCard - } - - @Test - fun `test send message`() = runTest { - val request = Request( - data = MessageSendParams( - message = Message( - role = Role.User, - parts = listOf( - TextPart("hello world"), - TextPart("How are you doing?"), - ), - contextId = "test-context" - ), - ) - ) - - val response = client.sendMessage(request) - - response should { - it.id shouldBe request.id - - it.data.shouldBeInstanceOf { - it.role shouldBe Role.Agent - it.parts shouldBe listOf(TextPart("Hello World")) - it.contextId shouldBe "test-context" - } - } - } - - @Test - fun `test send message streaming`() = runTest { - val createTaskRequest = Request( - data = MessageSendParams( - message = Message( - role = Role.User, - parts = listOf( - TextPart("do task"), - ), - contextId = "test-context" - ), - ), - ) - - val events = client - .sendMessageStreaming(createTaskRequest) - .toList() - .map { it.data } - - events shouldHaveSize 3 - events[0].shouldBeInstanceOf { - it.contextId shouldBe "test-context" - it.status should { - it.state shouldBe TaskState.Submitted - } - - it.history shouldNotBeNull { - this shouldHaveSize 1 - - this[0] should { - it.role shouldBe Role.User - it.parts shouldBe listOf(TextPart("do task")) - } - } - } - - events[1].shouldBeInstanceOf { - it.contextId shouldBe "test-context" - - it.status should { - it.state shouldBe TaskState.Working - it.message shouldNotBeNull { - role shouldBe Role.Agent - parts shouldBe listOf(TextPart("Working on task")) - } - } - } - - events[2].shouldBeInstanceOf { - it.contextId shouldBe "test-context" - - it.status should { - it.state shouldBe TaskState.Completed - it.message shouldNotBeNull { - role shouldBe Role.Agent - parts shouldBe listOf(TextPart("Task completed")) - } - } - } - } - - @Test - fun `test get task`() = runTest { - val createTaskRequest = Request( - data = MessageSendParams( - message = Message( - role = Role.User, - parts = listOf( - TextPart("do task"), - ), - contextId = "test-context" - ), - ), - ) - - val taskId = (client.sendMessage(createTaskRequest).data as Task).id - - val getTaskRequest = Request( - data = TaskQueryParams( - id = taskId, - historyLength = 1 - ) - ) - - val response = client.getTask(getTaskRequest) - - response.data should { - it.id shouldBe taskId - it.contextId shouldBe "test-context" - it.status should { - it.state shouldBe TaskState.Completed - } - } - } - - @Test - fun `test cancel task`() = runTest { - val createTaskRequest = Request( - data = MessageSendParams( - message = Message( - role = Role.User, - parts = listOf( - TextPart("do cancelable task"), - ), - contextId = "test-context" - ), - ), - ) - - val taskId = (client.sendMessage(createTaskRequest).data as Task).id - - val cancelTaskRequest = Request( - data = TaskIdParams( - id = taskId, - ) - ) - - val response = client.cancelTask(cancelTaskRequest) - - response.data should { - it.id shouldBe taskId - it.contextId shouldBe "test-context" - it.status should { - it.state shouldBe TaskState.Canceled - it.message shouldNotBeNull { - role shouldBe Role.Agent - parts shouldBe listOf(TextPart("Task canceled")) - } - } - } - } - - @Test - fun `test resubscribe task`() = runTest { - val createTaskRequest = Request( - data = MessageSendParams( - message = Message( - role = Role.User, - parts = listOf( - TextPart("do long-running task"), - ), - contextId = "test-context" - ), - configuration = MessageSendConfiguration( - blocking = false - ) - ), - ) - - val taskId = (client.sendMessage(createTaskRequest).data as Task).id - - val resubscribeTaskRequest = Request( - data = TaskIdParams( - id = taskId, - ) - ) - - val events = client - .resubscribeTask(resubscribeTaskRequest) - .toList() - .map { it.data } - - events.shouldNotBeEmpty() - - events.shouldForAll { - it.shouldBeInstanceOf { - it.taskId shouldBe taskId - it.contextId shouldBe "test-context" - - it.status should { - it.state shouldBe TaskState.Working - it.message shouldNotBeNull { - role shouldBe Role.Agent - - parts.shouldForAll { - it.shouldBeInstanceOf { - it.text shouldStartWith "Still working" - } - } - } - } - } - } - } - - @Test - fun `test push notification configs`() = runTest { - val createTaskRequest = Request( - data = MessageSendParams( - message = Message( - role = Role.User, - parts = listOf( - TextPart("do long-running task"), - ), - contextId = "test-context" - ), - ), - ) - - val taskId = (client.sendMessage(createTaskRequest).data as Task).id - - val pushConfig = TaskPushNotificationConfig( - taskId = taskId, - pushNotificationConfig = PushNotificationConfig( - id = "push-id", - url = "https://localhost:3000", - token = "push-token", - authentication = PushNotificationAuthenticationInfo( - schemes = listOf("bearer"), - credentials = "very-secret-credential" - ) - ) - ) - - val request = Request( - data = pushConfig - ) - - val setPushConfigResponse = client.setTaskPushNotificationConfig(request) - setPushConfigResponse.data shouldBe pushConfig - - val getPushConfigRequest = Request( - data = TaskPushNotificationConfigParams( - id = taskId, - pushNotificationConfigId = pushConfig.pushNotificationConfig.id, - ) - ) - - val getPushConfigResponse = client.getTaskPushNotificationConfig(getPushConfigRequest) - getPushConfigResponse.data shouldBe pushConfig - - val listPushConfigRequest = Request( - data = TaskIdParams( - id = taskId, - ) - ) - - val listPushConfigResponse = client.listTaskPushNotificationConfig(listPushConfigRequest) - listPushConfigResponse.data shouldBe listOf(pushConfig) - - val deletePushConfigRequest = Request( - data = TaskPushNotificationConfigParams( - id = taskId, - pushNotificationConfigId = pushConfig.pushNotificationConfig.id, - ) - ) - - client.deleteTaskPushNotificationConfig(deletePushConfigRequest) - - shouldThrowExactly { - client.getTaskPushNotificationConfig(getPushConfigRequest) - } - } } diff --git a/a2a/a2a-test/build.gradle.kts b/a2a/a2a-test/build.gradle.kts new file mode 100644 index 0000000000..56670dc29f --- /dev/null +++ b/a2a/a2a-test/build.gradle.kts @@ -0,0 +1,34 @@ +plugins { + id("ai.kotlin.multiplatform") + alias(libs.plugins.kotlin.serialization) +} + +kotlin { + sourceSets { + commonMain { + dependencies { + api(project(":a2a:a2a-core")) + api(project(":a2a:a2a-client")) + api(kotlin("test")) + api(kotlin("test-annotations-common")) + api(libs.kotest.assertions) + api(libs.kotlinx.coroutines.core) + api(libs.kotlinx.coroutines.test) + api(libs.kotlinx.serialization.json) + implementation(libs.kotlinx.coroutines.test) + } + } + + jvmMain { + dependencies { + api(kotlin("test-junit5")) + } + } + + jsMain { + dependencies { + api(kotlin("test-js")) + } + } + } +} diff --git a/a2a/a2a-test/src/commonMain/kotlin/ai/koog/a2a/test/BaseA2AProtocolTest.kt b/a2a/a2a-test/src/commonMain/kotlin/ai/koog/a2a/test/BaseA2AProtocolTest.kt new file mode 100644 index 0000000000..b9c60da7c9 --- /dev/null +++ b/a2a/a2a-test/src/commonMain/kotlin/ai/koog/a2a/test/BaseA2AProtocolTest.kt @@ -0,0 +1,436 @@ +package ai.koog.a2a.test + +import ai.koog.a2a.client.A2AClient +import ai.koog.a2a.exceptions.A2AInternalErrorException +import ai.koog.a2a.model.AgentCapabilities +import ai.koog.a2a.model.AgentCard +import ai.koog.a2a.model.AgentSkill +import ai.koog.a2a.model.Message +import ai.koog.a2a.model.MessageSendConfiguration +import ai.koog.a2a.model.MessageSendParams +import ai.koog.a2a.model.PushNotificationAuthenticationInfo +import ai.koog.a2a.model.PushNotificationConfig +import ai.koog.a2a.model.Role +import ai.koog.a2a.model.Task +import ai.koog.a2a.model.TaskIdParams +import ai.koog.a2a.model.TaskPushNotificationConfig +import ai.koog.a2a.model.TaskPushNotificationConfigParams +import ai.koog.a2a.model.TaskQueryParams +import ai.koog.a2a.model.TaskState +import ai.koog.a2a.model.TaskStatusUpdateEvent +import ai.koog.a2a.model.TextPart +import ai.koog.a2a.model.TransportProtocol +import ai.koog.a2a.transport.Request +import io.kotest.assertions.throwables.shouldThrowExactly +import io.kotest.inspectors.shouldForAll +import io.kotest.matchers.collections.shouldHaveSize +import io.kotest.matchers.collections.shouldNotBeEmpty +import io.kotest.matchers.nulls.shouldNotBeNull +import io.kotest.matchers.should +import io.kotest.matchers.shouldBe +import io.kotest.matchers.string.shouldStartWith +import io.kotest.matchers.types.shouldBeInstanceOf +import kotlinx.coroutines.flow.toList +import kotlinx.coroutines.test.runTest +import kotlin.test.Test + +/** + * Abstract base class containing transport-agnostic A2A protocol compliance tests. + * + * Concrete test classes should inherit from this class and provide the [client] property + * to run the same test suite against different A2A implementations. + * + * @property client The A2A client instance to test against. Should be connected and ready to use. + */ +@Suppress("FunctionName") +abstract class BaseA2AProtocolTest { + + /** + * The A2A client instance to test. Must be connected and ready to use. + */ + protected abstract val client: A2AClient + + @Test + fun `test get agent card`() = runTest { + val agentCard = client.getAgentCard() + + // Assert on the full AgentCard structure + val expectedAgentCard = AgentCard( + protocolVersion = "0.3.0", + name = "Hello World Agent", + description = "Just a hello world agent", + url = "http://localhost:9999/", + preferredTransport = TransportProtocol.JSONRPC, + additionalInterfaces = null, + iconUrl = null, + provider = null, + version = "1.0.0", + documentationUrl = null, + capabilities = AgentCapabilities( + streaming = true, + pushNotifications = true, + stateTransitionHistory = null, + extensions = null + ), + securitySchemes = null, + security = null, + defaultInputModes = listOf("text"), + defaultOutputModes = listOf("text"), + skills = listOf( + AgentSkill( + id = "hello_world", + name = "Returns hello world", + description = "just returns hello world", + tags = listOf("hello world"), + examples = listOf("hi", "hello world"), + inputModes = null, + outputModes = null, + security = null + ) + ), + supportsAuthenticatedExtendedCard = true, + signatures = null + ) + + agentCard shouldBe expectedAgentCard + } + + @Test + fun `test get authenticated extended agent card`() = runTest { + val request = Request(data = null) + + val response = client.getAuthenticatedExtendedAgentCard(request) + + // Assert on the extended agent card structure + val expectedExtendedAgentCard = AgentCard( + protocolVersion = "0.3.0", + name = "Hello World Agent - Extended Edition", + description = "The full-featured hello world agent for authenticated users.", + url = "http://localhost:9999/", + preferredTransport = TransportProtocol.JSONRPC, + additionalInterfaces = null, + iconUrl = null, + provider = null, + version = "1.0.1", + documentationUrl = null, + capabilities = AgentCapabilities( + streaming = true, + pushNotifications = true, + stateTransitionHistory = null, + extensions = null + ), + securitySchemes = null, + security = null, + defaultInputModes = listOf("text"), + defaultOutputModes = listOf("text"), + skills = listOf( + AgentSkill( + id = "hello_world", + name = "Returns hello world", + description = "just returns hello world", + tags = listOf("hello world"), + examples = listOf("hi", "hello world"), + inputModes = null, + outputModes = null, + security = null + ), + AgentSkill( + id = "super_hello_world", + name = "Returns a SUPER Hello World", + description = "A more enthusiastic greeting, only for authenticated users.", + tags = listOf("hello world", "super", "extended"), + examples = listOf("super hi", "give me a super hello"), + inputModes = null, + outputModes = null, + security = null + ) + ), + supportsAuthenticatedExtendedCard = true, + signatures = null + ) + + response.data shouldBe expectedExtendedAgentCard + } + + @Test + fun `test send message`() = runTest { + val request = Request( + data = MessageSendParams( + message = Message( + role = Role.User, + parts = listOf( + TextPart("hello world"), + TextPart("How are you doing?"), + ), + contextId = "test-context" + ), + ) + ) + + val response = client.sendMessage(request) + + response should { + it.id shouldBe request.id + + it.data.shouldBeInstanceOf { + it.role shouldBe Role.Agent + it.parts shouldBe listOf(TextPart("Hello World")) + it.contextId shouldBe "test-context" + } + } + } + + @Test + fun `test send message streaming`() = runTest { + val createTaskRequest = Request( + data = MessageSendParams( + message = Message( + role = Role.User, + parts = listOf( + TextPart("do task"), + ), + contextId = "test-context" + ), + ), + ) + + val events = client + .sendMessageStreaming(createTaskRequest) + .toList() + .map { it.data } + + events shouldHaveSize 3 + events[0].shouldBeInstanceOf { + it.contextId shouldBe "test-context" + it.status should { + it.state shouldBe TaskState.Submitted + } + + it.history shouldNotBeNull { + this shouldHaveSize 1 + + this[0] should { + it.role shouldBe Role.User + it.parts shouldBe listOf(TextPart("do task")) + } + } + } + + events[1].shouldBeInstanceOf { + it.contextId shouldBe "test-context" + + it.status should { + it.state shouldBe TaskState.Working + it.message shouldNotBeNull { + role shouldBe Role.Agent + parts shouldBe listOf(TextPart("Working on task")) + } + } + } + + events[2].shouldBeInstanceOf { + it.contextId shouldBe "test-context" + + it.status should { + it.state shouldBe TaskState.Completed + it.message shouldNotBeNull { + role shouldBe Role.Agent + parts shouldBe listOf(TextPart("Task completed")) + } + } + } + } + + @Test + fun `test get task`() = runTest { + val createTaskRequest = Request( + data = MessageSendParams( + message = Message( + role = Role.User, + parts = listOf( + TextPart("do task"), + ), + contextId = "test-context" + ), + ), + ) + + val taskId = (client.sendMessage(createTaskRequest).data as Task).id + + val getTaskRequest = Request( + data = TaskQueryParams( + id = taskId, + historyLength = 1 + ) + ) + + val response = client.getTask(getTaskRequest) + + response.data should { + it.id shouldBe taskId + it.contextId shouldBe "test-context" + it.status should { + it.state shouldBe TaskState.Completed + } + } + } + + @Test + fun `test cancel task`() = runTest { + val createTaskRequest = Request( + data = MessageSendParams( + message = Message( + role = Role.User, + parts = listOf( + TextPart("do cancelable task"), + ), + contextId = "test-context" + ), + ), + ) + + val taskId = (client.sendMessage(createTaskRequest).data as Task).id + + val cancelTaskRequest = Request( + data = TaskIdParams( + id = taskId, + ) + ) + + val response = client.cancelTask(cancelTaskRequest) + + response.data should { + it.id shouldBe taskId + it.contextId shouldBe "test-context" + it.status should { + it.state shouldBe TaskState.Canceled + it.message shouldNotBeNull { + role shouldBe Role.Agent + parts shouldBe listOf(TextPart("Task canceled")) + } + } + } + } + + @Test + fun `test resubscribe task`() = runTest { + val createTaskRequest = Request( + data = MessageSendParams( + message = Message( + role = Role.User, + parts = listOf( + TextPart("do long-running task"), + ), + contextId = "test-context" + ), + configuration = MessageSendConfiguration( + blocking = false + ) + ), + ) + + val taskId = (client.sendMessage(createTaskRequest).data as Task).id + + val resubscribeTaskRequest = Request( + data = TaskIdParams( + id = taskId, + ) + ) + + val events = client + .resubscribeTask(resubscribeTaskRequest) + .toList() + .map { it.data } + + events.shouldNotBeEmpty() + + events.shouldForAll { + it.shouldBeInstanceOf { + it.taskId shouldBe taskId + it.contextId shouldBe "test-context" + + it.status should { + it.state shouldBe TaskState.Working + it.message shouldNotBeNull { + role shouldBe Role.Agent + + parts.shouldForAll { + it.shouldBeInstanceOf { + it.text shouldStartWith "Still working" + } + } + } + } + } + } + } + + @Test + fun `test push notification configs`() = runTest { + val createTaskRequest = Request( + data = MessageSendParams( + message = Message( + role = Role.User, + parts = listOf( + TextPart("do long-running task"), + ), + contextId = "test-context" + ), + ), + ) + + val taskId = (client.sendMessage(createTaskRequest).data as Task).id + + val pushConfig = TaskPushNotificationConfig( + taskId = taskId, + pushNotificationConfig = PushNotificationConfig( + id = "push-id", + url = "https://localhost:3000", + token = "push-token", + authentication = PushNotificationAuthenticationInfo( + schemes = listOf("bearer"), + credentials = "very-secret-credential" + ) + ) + ) + + val request = Request( + data = pushConfig + ) + + val setPushConfigResponse = client.setTaskPushNotificationConfig(request) + setPushConfigResponse.data shouldBe pushConfig + + val getPushConfigRequest = Request( + data = TaskPushNotificationConfigParams( + id = taskId, + pushNotificationConfigId = pushConfig.pushNotificationConfig.id, + ) + ) + + val getPushConfigResponse = client.getTaskPushNotificationConfig(getPushConfigRequest) + getPushConfigResponse.data shouldBe pushConfig + + val listPushConfigRequest = Request( + data = TaskIdParams( + id = taskId, + ) + ) + + val listPushConfigResponse = client.listTaskPushNotificationConfig(listPushConfigRequest) + listPushConfigResponse.data shouldBe listOf(pushConfig) + + val deletePushConfigRequest = Request( + data = TaskPushNotificationConfigParams( + id = taskId, + pushNotificationConfigId = pushConfig.pushNotificationConfig.id, + ) + ) + + client.deleteTaskPushNotificationConfig(deletePushConfigRequest) + + shouldThrowExactly { + client.getTaskPushNotificationConfig(getPushConfigRequest) + } + } +} diff --git a/gradle/libs.versions.toml b/gradle/libs.versions.toml index 59517d9ebf..68c59d1e5d 100644 --- a/gradle/libs.versions.toml +++ b/gradle/libs.versions.toml @@ -36,7 +36,6 @@ spring-management = "1.1.7" sqlite = "3.46.1.3" testcontainers = "1.19.7" mokksy = "0.5.0-Alpha3" -kotest = "6.0.3" [libraries] jetbrains-annotations = { module = "org.jetbrains:annotations", version.ref = "annotations" } diff --git a/koog-agents/build.gradle.kts b/koog-agents/build.gradle.kts index 08469c82b9..54f589d063 100644 --- a/koog-agents/build.gradle.kts +++ b/koog-agents/build.gradle.kts @@ -11,6 +11,7 @@ val excluded = setOf( ":agents:agents-test", ":agents:agents-ext", ":agents:agents-features:agents-features-sql", // Optional SQL persistence provider + ":a2a:a2a-test", // Testing utilities for A2A protocol compliance ":agents:agents-mcp-server", ":integration-tests", ":test-utils", diff --git a/settings.gradle.kts b/settings.gradle.kts index 69008fce6b..de8bdfe3a6 100755 --- a/settings.gradle.kts +++ b/settings.gradle.kts @@ -64,6 +64,7 @@ include(":rag:vector-storage") include(":a2a:a2a-core") include(":a2a:a2a-server") include(":a2a:a2a-client") +include(":a2a:a2a-test") include(":a2a:a2a-transport:a2a-transport-core-jsonrpc") include(":a2a:a2a-transport:a2a-transport-server-jsonrpc-http") include(":a2a:a2a-transport:a2a-transport-client-jsonrpc-http") From cef7465eac28147431471907da2f657998a97a10 Mon Sep 17 00:00:00 2001 From: Andrey Bragin Date: Sun, 21 Sep 2025 22:01:00 +0200 Subject: [PATCH 30/52] [a2a] Fix SessionEventProcessor concurrency and event management --- .gitignore | 1 + a2a/CLAUDE.md | 2 +- a2a/a2a-client/build.gradle.kts | 6 - a2a/a2a-server/build.gradle.kts | 6 - .../kotlin/ai/koog/a2a/server/A2AServer.kt | 81 ++-- .../ai/koog/a2a/server/agent/AgentExecutor.kt | 15 +- .../koog/a2a/server/exceptions/Exceptions.kt | 6 + .../ai/koog/a2a/server/session/IdGenerator.kt | 57 +++ .../koog/a2a/server/session/RequestContext.kt | 22 +- .../ai/koog/a2a/server/session/Session.kt | 16 +- .../server/session/SessionEventProcessor.kt | 106 +++-- .../koog/a2a/server/session/SessionManager.kt | 84 ++-- .../ai/koog/a2a/server/tasks/TaskStorage.kt | 7 + .../session/SessionEventProcessorTest.kt | 441 ++++++++++++++++++ .../a2a/server/session/SessionManagerTest.kt | 257 ++++++++++ .../build.gradle.kts | 6 - .../build.gradle.kts | 6 - .../build.gradle.kts | 6 - .../a2a-transport-core-rest/build.gradle.kts | 6 - .../build.gradle.kts | 6 - .../build.gradle.kts | 6 - 21 files changed, 946 insertions(+), 197 deletions(-) create mode 100644 a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/session/IdGenerator.kt create mode 100644 a2a/a2a-server/src/commonTest/kotlin/ai/koog/a2a/server/session/SessionEventProcessorTest.kt create mode 100644 a2a/a2a-server/src/commonTest/kotlin/ai/koog/a2a/server/session/SessionManagerTest.kt diff --git a/.gitignore b/.gitignore index da43f86bfa..5053c8f2fe 100644 --- a/.gitignore +++ b/.gitignore @@ -11,3 +11,4 @@ docs/src/main/kotlin/*.kt **/.env .venv .DS_Store +**/kotlin-js-store diff --git a/a2a/CLAUDE.md b/a2a/CLAUDE.md index 2f564e43d0..e5b76f949c 100644 --- a/a2a/CLAUDE.md +++ b/a2a/CLAUDE.md @@ -40,7 +40,7 @@ The A2A (Agent-to-Agent) module is a **meta-module** within the larger Koog proj ## Technologies & Libraries ### Core Dependencies (from gradle/libs.versions.toml) -- **Kotlin Multiplatform**: JVM + JS (IR) support +- **Kotlin Multiplatform** - **kotlinx-serialization**: JSON serialization for protocol messages - **kotlinx-coroutines**: Async/concurrent programming, Flow APIs - **kotlinx-datetime**: Timestamp handling in protocol messages diff --git a/a2a/a2a-client/build.gradle.kts b/a2a/a2a-client/build.gradle.kts index 818b07d2cd..a1d4d9e660 100644 --- a/a2a/a2a-client/build.gradle.kts +++ b/a2a/a2a-client/build.gradle.kts @@ -9,12 +9,6 @@ plugins { } kotlin { - jvm() - - js(IR) { - browser() - } - sourceSets { commonMain { dependencies { diff --git a/a2a/a2a-server/build.gradle.kts b/a2a/a2a-server/build.gradle.kts index 5d156ddd6d..1c94f80da0 100644 --- a/a2a/a2a-server/build.gradle.kts +++ b/a2a/a2a-server/build.gradle.kts @@ -9,12 +9,6 @@ plugins { } kotlin { - jvm() - - js(IR) { - browser() - } - sourceSets { commonMain { dependencies { diff --git a/a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/A2AServer.kt b/a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/A2AServer.kt index a59adacbe0..d460c16401 100644 --- a/a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/A2AServer.kt +++ b/a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/A2AServer.kt @@ -26,10 +26,12 @@ import ai.koog.a2a.server.messages.InMemoryMessageStorage import ai.koog.a2a.server.messages.MessageStorage import ai.koog.a2a.server.notifications.PushNotificationConfigStorage import ai.koog.a2a.server.notifications.PushNotificationSender +import ai.koog.a2a.server.session.IdGenerator import ai.koog.a2a.server.session.RequestContext import ai.koog.a2a.server.session.Session import ai.koog.a2a.server.session.SessionEventProcessor import ai.koog.a2a.server.session.SessionManager +import ai.koog.a2a.server.session.UuidIdGenerator import ai.koog.a2a.server.tasks.ContextTaskStorage import ai.koog.a2a.server.tasks.InMemoryTaskStorage import ai.koog.a2a.server.tasks.TaskStorage @@ -50,8 +52,6 @@ import kotlinx.coroutines.flow.last import kotlinx.coroutines.flow.map import kotlinx.coroutines.launch import kotlinx.datetime.Clock -import kotlin.uuid.ExperimentalUuidApi -import kotlin.uuid.Uuid /** * Default implementation of A2A server responsible for handling requests from A2A clients according to the @@ -203,18 +203,19 @@ import kotlin.uuid.Uuid * val userMessage = context.params.message * * // Process the message and create a task - * val task = Task( - * contextId = context.contextId, - * status = TaskStatus( - * state = TaskState.Working, - * // Mark this message as belonging to the created task - * message = message.copy(taskId = task.id) - * timestamp = Clock.System.now() - * ), - * ) - * * // Send task creation event - * eventProcessor.sendTaskEvent(task) + * eventProcessor.sendTaskEvent( + * Task( + * id = context.taskId, + * contextId = context.contextId, + * status = TaskStatus( + * state = TaskState.Working, + * // Mark this message as belonging to the created task + * message = message.copy(taskId = context.taskId) + * timestamp = Clock.System.now() + * ), + * ) + * ) * * // Simulate some work * delay(1000) @@ -222,14 +223,14 @@ import kotlin.uuid.Uuid * // Mark task as completed * eventProcessor.sendTaskEvent( * TaskStatusUpdateEvent( - * taskId = task.id, - * contextId = task.contextId, + * taskId = context.taskId, + * contextId = context.contextId, * status = TaskStatus( * state = TaskState.Completed, * message = Message( * role = Role.Agent, * contextId = context.contextId, - * taskId = task.id, + * taskId = context.taskId, * parts = listOf( * TextPart("Task completed successfully!") * ) @@ -302,6 +303,7 @@ import kotlin.uuid.Uuid * @param messageStorage Storage implementation for persisting messages (defaults to in-memory) * @param pushConfigStorage Optional storage for push notification configurations * @param pushSender Optional push notification sender implementation + * @param idGenerator Generator for new task and context IDs (defaults to UUID) * @param coroutineScope Scope for managing all sessions, agent jobs, event processing, etc. * @param clock Clock instance for timestamp generation (defaults to [Clock.System]) * @@ -320,6 +322,7 @@ public open class A2AServer( protected val messageStorage: MessageStorage = InMemoryMessageStorage(), protected val pushConfigStorage: PushNotificationConfigStorage? = null, protected val pushSender: PushNotificationSender? = null, + protected val idGenerator: IdGenerator = UuidIdGenerator, protected val coroutineScope: CoroutineScope = CoroutineScope(SupervisorJob()), protected val clock: Clock = Clock.System, ) : RequestHandler { @@ -357,10 +360,9 @@ public open class A2AServer( ctx: ServerCallContext ): Flow> = channelFlow { val message = request.data.message - val taskId = message.taskId // Check if message links to a task. - val eventProcessor = if (taskId != null) { + val task: Task? = message.taskId?.let { taskId -> // Check if the task is still in progress, no message can be sent. if (sessionManager.sessionForTask(taskId) != null) { throw A2AUnsupportedOperationException("Task '$taskId' is still running, can't send messages to the task that has not yielded control") @@ -374,31 +376,28 @@ public open class A2AServer( throw A2AInvalidParamsException("Message context id '${message.contextId}' doesn't match task context id '${task.contextId}'") } - // Create new event processor for the task. - SessionEventProcessor( - contextId = task.contextId, - taskStorage = taskStorage, - coroutineScope = coroutineScope, - currentTask = task - ) - } else { - // Create new event processor without task specified. - @OptIn(ExperimentalUuidApi::class) - SessionEventProcessor( - contextId = message.contextId ?: Uuid.random().toString(), - taskStorage = taskStorage, - // Use specified context id or generate a new random one. - coroutineScope = coroutineScope, - ) + task } + // Create event processor for the session based on the input data. + val eventProcessor = SessionEventProcessor( + contextId = task?.contextId + ?: message.contextId + ?: idGenerator.generateContextId(message), + taskId = task?.id ?: idGenerator.generateTaskId(message), + taskStorage = taskStorage, + task = null, + ) + // Create request context based on the request information. val requestContext = RequestContext( - contextId = eventProcessor.contextId, callContext = ctx, params = request.data, taskStorage = ContextTaskStorage(eventProcessor.contextId, taskStorage), messageStorage = ContextMessageStorage(eventProcessor.contextId, messageStorage), + contextId = eventProcessor.contextId, + taskId = eventProcessor.taskId, + task = task, ) // Create agent execution session @@ -483,13 +482,13 @@ public open class A2AServer( ctx: ServerCallContext ): Response { val taskParams = request.data + val session = sessionManager.sessionForTask(taskParams.id) + val task = taskStorage.get(taskParams.id, historyLength = 0, includeArtifacts = true) + ?: throw A2ATaskNotFoundException("Task '${taskParams.id}' not found") // Task is not running, check if it exists in the storage. if (session == null) { - val task = taskStorage.get(taskParams.id, historyLength = 0, includeArtifacts = true) - ?: throw A2ATaskNotFoundException("Task '${taskParams.id}' not found") - // Task exists but not running - check if it is already canceled. if (task.status.state == TaskState.Canceled) { return Response(data = task, id = request.id) @@ -515,11 +514,13 @@ public open class A2AServer( } else { // Create request context based on the request information. val requestContext = RequestContext( - contextId = taskParams.id, callContext = ctx, params = request.data, taskStorage = ContextTaskStorage(session.contextId, taskStorage), messageStorage = ContextMessageStorage(session.contextId, messageStorage), + contextId = session.contextId, + taskId = session.taskId, + task = task, ) // Attempt to cancel the agent execution and wait until it's finished. @@ -545,7 +546,7 @@ public open class A2AServer( val taskParams = request.data val session = sessionManager.sessionForTask(taskParams.id) - ?: throw A2AUnsupportedOperationException("Task '${taskParams.id}' is not currently running or does not exist") + ?: throw A2AUnsupportedOperationException("Session for task '${taskParams.id}' is not currently running or task does not exist") emitAll( session.events diff --git a/a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/agent/AgentExecutor.kt b/a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/agent/AgentExecutor.kt index 18b66ef4f9..16cd41536e 100644 --- a/a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/agent/AgentExecutor.kt +++ b/a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/agent/AgentExecutor.kt @@ -11,7 +11,6 @@ import ai.koog.a2a.model.TaskState import ai.koog.a2a.server.session.RequestContext import ai.koog.a2a.server.session.Session import ai.koog.a2a.server.session.SessionEventProcessor -import kotlin.jvm.JvmName /** * Implementations of this interface contain the core logic of the agent, @@ -25,6 +24,8 @@ public interface AgentExecutor { * the [eventProcessor]. This method should return once the agent's execution for this request is complete or * yields control (e.g., enters an [TaskState.InputRequired] state). * + * All events must have context id from [RequestContext.contextId] and for task events task id from [RequestContext.taskId]. + * * Can throw an exception if the input is invalid or the agent fails to execute the request. * * @param context The context containing the necessary information and accessors for executing the agent. @@ -94,15 +95,3 @@ public interface AgentExecutor { */ public suspend fun cancel(context: RequestContext, session: Session) {} } - -/** - * Returns the task id from the [MessageSendParams] in the [RequestContext]. - */ -@get:JvmName("getMessageTaskId") -public val RequestContext.taskId: String? get() = params.message.taskId - -/** - * Returns the task id from the [TaskIdParams] in the [RequestContext]. - */ -@get:JvmName("getTaskIdParamsTaskId") -public val RequestContext.taskId: String get() = params.id diff --git a/a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/exceptions/Exceptions.kt b/a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/exceptions/Exceptions.kt index 8d1ad19e46..ed65db9868 100644 --- a/a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/exceptions/Exceptions.kt +++ b/a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/exceptions/Exceptions.kt @@ -1,5 +1,6 @@ package ai.koog.a2a.server.exceptions +import ai.koog.a2a.server.session.Session import ai.koog.a2a.server.session.SessionEventProcessor /** @@ -21,3 +22,8 @@ public class InvalidEventException(message: String, cause: Throwable? = null) : * An exception that is thrown to indicate errors occurring during push notification operations. */ public class PushNotificationException(message: String, cause: Throwable? = null) : Exception(message, cause) + +/** + * An exception that is thrown to indicate that a [Session] has been closed. + */ +public class SessionClosedException(message: String, cause: Throwable? = null) : Exception(message, cause) diff --git a/a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/session/IdGenerator.kt b/a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/session/IdGenerator.kt new file mode 100644 index 0000000000..e773699356 --- /dev/null +++ b/a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/session/IdGenerator.kt @@ -0,0 +1,57 @@ +package ai.koog.a2a.server.session + +import ai.koog.a2a.model.Message +import kotlin.uuid.ExperimentalUuidApi +import kotlin.uuid.Uuid + +/** + * Interface for generating unique IDs for new contexts and tasks. + * + * Called by the server to provide unique IDs when messages lack contextId or taskId, + * preventing race conditions during concurrent agent execution. These generated IDs + * are enforced when agents reply in newly created contexts or create tasks based on + * incoming messages without existing task IDs. + */ +public interface IdGenerator { + /** + * Generates a unique context ID based on the given message. + * + * @param message The message for which the context ID is being generated. + * @return A unique string representing the context ID. + */ + public suspend fun generateContextId(message: Message): String + + /** + * Generates a unique task ID based on the given message. + * + * @param message The message for which the task ID is being generated. + * @return A unique string representing the task ID. + */ + public suspend fun generateTaskId(message: Message): String +} + +/** + * Implementation of the [IdGenerator] interface that generates unique identifiers using UUIDs. + * + * This generator ensures that each context ID and task ID is uniquely identified, leveraging UUIDs + * for randomness and collision resistance. IDs are generated only if the relevant existing ID + * (contextId or taskId) is null in the incoming message. + */ +@OptIn(ExperimentalUuidApi::class) +public object UuidIdGenerator : IdGenerator { + override suspend fun generateContextId(message: Message): String { + require(message.contextId == null) { + "Can't generate contextId for message with existing contextId" + } + + return Uuid.random().toString() + } + + override suspend fun generateTaskId(message: Message): String { + require(message.taskId == null) { + "Can't generate taskId for message with existing taskId" + } + + return Uuid.random().toString() + } +} diff --git a/a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/session/RequestContext.kt b/a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/session/RequestContext.kt index 8672c7b128..ae51c604b2 100644 --- a/a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/session/RequestContext.kt +++ b/a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/session/RequestContext.kt @@ -1,5 +1,6 @@ package ai.koog.a2a.server.session +import ai.koog.a2a.model.Task import ai.koog.a2a.server.messages.ContextMessageStorage import ai.koog.a2a.server.tasks.ContextTaskStorage import ai.koog.a2a.transport.ServerCallContext @@ -8,16 +9,25 @@ import ai.koog.a2a.transport.ServerCallContext * Request context associated with each A2A agent-related request, providing essential information and repositories to * the agent executor. * - * @param contextId Context ID associated with this request. - * @param callContext [ServerCallContext] associated with the request. - * @param params Parameters associated with the request. - * @param taskStorage [ContextTaskStorage] associated with the request. - * @param messageStorage [ContextMessageStorage] associated with the request. + * @property callContext [ServerCallContext] associated with the request. + * @property params Parameters associated with the request. + * @property taskStorage [ContextTaskStorage] associated with the request. + * @property messageStorage [ContextMessageStorage] associated with the request. + * @property contextId The context ID representing either an existing context from the incoming request in [params], + * or a newly generated ID that the agent must use if it decides to reply. + * @property taskId The task ID representing either an existing task from the incoming request in [params], + * or a newly generated ID that the agent must use if it decides to create a new task. + * @property task Optional shallow version of the current task (without message history or artifacts) + * providing lightweight access to general task state information. Present only if the incoming request + * was associated with an existing task. For detailed task information or referenced tasks, + * query [taskStorage] directly. */ public class RequestContext( - public val contextId: String, public val callContext: ServerCallContext, public val params: T, public val taskStorage: ContextTaskStorage, public val messageStorage: ContextMessageStorage, + public val contextId: String, + public val taskId: String, + public val task: Task?, ) diff --git a/a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/session/Session.kt b/a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/session/Session.kt index eab861892b..fedb1fa0dc 100644 --- a/a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/session/Session.kt +++ b/a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/session/Session.kt @@ -4,7 +4,8 @@ import ai.koog.a2a.model.Event import kotlinx.coroutines.CoroutineScope import kotlinx.coroutines.CoroutineStart import kotlinx.coroutines.Job -import kotlinx.coroutines.flow.SharedFlow +import kotlinx.coroutines.flow.Flow +import kotlinx.coroutines.flow.collect import kotlinx.coroutines.launch /** @@ -12,15 +13,17 @@ import kotlinx.coroutines.launch * * @property eventProcessor Handles session events and provides event streaming * @property agentJob The coroutine job executing the agent logic - * @property events A stream of events generated during this session - * @property contextId Unique context identifier for this session + * @property contextId Unique context ID associated with this session, delegates to [SessionEventProcessor.contextId] + * @property taskId Unique task ID associated with this session, delegates to [SessionEventProcessor.contextId] + * @property events A stream of events generated during this session, delegates to [SessionEventProcessor.events] */ public class Session( public val eventProcessor: SessionEventProcessor, public val agentJob: Job ) { - public val events: SharedFlow get() = eventProcessor.events public val contextId: String get() = eventProcessor.contextId + public val taskId: String get() = eventProcessor.taskId + public val events: Flow get() = eventProcessor.events /** * Starts the agent execution job. @@ -30,16 +33,17 @@ public class Session( } /* - * Suspends until the agent job completes + * Suspends until the session, i.e., agent job and event stream, complete. */ public suspend fun join() { agentJob.join() + events.collect() } /** * Cancels the agent job and closes the event processor */ - public fun close() { + public suspend fun close() { agentJob.cancel() eventProcessor.close() } diff --git a/a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/session/SessionEventProcessor.kt b/a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/session/SessionEventProcessor.kt index 4352c8ca57..cdc9c3765f 100644 --- a/a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/session/SessionEventProcessor.kt +++ b/a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/session/SessionEventProcessor.kt @@ -8,15 +8,19 @@ import ai.koog.a2a.model.TaskEvent import ai.koog.a2a.model.TaskState import ai.koog.a2a.model.TaskStatusUpdateEvent import ai.koog.a2a.server.exceptions.InvalidEventException +import ai.koog.a2a.server.exceptions.SessionClosedException import ai.koog.a2a.server.tasks.TaskStorage -import kotlinx.coroutines.CoroutineScope -import kotlinx.coroutines.channels.Channel -import kotlinx.coroutines.flow.SharedFlow -import kotlinx.coroutines.flow.SharingStarted -import kotlinx.coroutines.flow.receiveAsFlow -import kotlinx.coroutines.flow.shareIn +import kotlinx.coroutines.flow.Flow +import kotlinx.coroutines.flow.MutableSharedFlow +import kotlinx.coroutines.flow.MutableStateFlow +import kotlinx.coroutines.flow.emitAll +import kotlinx.coroutines.flow.filterIsInstance +import kotlinx.coroutines.flow.flow +import kotlinx.coroutines.flow.map +import kotlinx.coroutines.flow.takeWhile import kotlinx.coroutines.sync.Mutex import kotlinx.coroutines.sync.withLock +import kotlin.jvm.JvmInline /** * A session processor responsible for handling session events. @@ -31,25 +35,35 @@ import kotlinx.coroutines.sync.withLock * - **Context ID validation**: All events must have the same contextId as the session * - **Single message limit**: Only one [Message] can be sent per session, after which the session becomes terminal * - **Task initialization order**: For new tasks, the first [TaskEvent] must be of type [Task] to create the task - * - **Task ID consistency**: Once a task session is initialized, only [TaskEvent]s with the same taskId are allowed + * - **Task ID consistency**: [TaskEvent] events must have task ids equal to [taskId] provided for this session. * - **Final event enforcement**: After a [TaskStatusUpdateEvent] with `final=true` is sent, no more events are permitted * - **Terminal state blocking**: No events can be sent when the task is already in a terminal state * - **Final flag requirement**: [TaskStatusUpdateEvent]s that set the task to a terminal state must have `final=true` * - * @property contextId The contextId of the session. + * @property contextId The contextId associated with this session, representing either an existing context + * from the incoming request or a newly generated ID that must be used for all events in this session. + * @property taskId The taskId associated with this session, representing either an existing task + * from the incoming request or a newly generated ID that must be used if creating a new task. + * Note: This taskId might not correspond to an actually existing task initially - it serves as the + * identifier that will be validated against all [TaskEvent] in this session. * @param taskStorage The storage for tasks where task events will be saved. - * @param coroutineScope The scope in which the event flow will be shared - * @param currentTask The current task associated with the session, if it is a continuation of a previous task session. + * @param task The initial task associated with the session, if it is a continuation of a previous task session. * - * @property events A shared flow of session events that can be subscribed. The flow will be closed when the session is closed. + * @property events A hot flow of events in this session that can be subscribed to. */ public class SessionEventProcessor( public val contextId: String, + public val taskId: String, private val taskStorage: TaskStorage, - coroutineScope: CoroutineScope, - currentTask: Task? = null, -) : AutoCloseable { + private val task: Task? = null, +) { private companion object { + private const val SESSION_CLOSED = "Session event processor is closed, can't send events" + + private const val INVALID_CONTEXT_ID = "Event contextId must be same as provided contextId" + + private const val INVALID_TASK_ID = "Event taskId must be same as provided taskId" + private const val MESSAGE_SENT = "Message has already been sent in this session. Sending message is a terminal operation and no more events " + "are allowed to be sent, the session must terminate ASAP" @@ -70,10 +84,11 @@ public class SessionEventProcessor( private const val TASK_DOES_NOT_EXIST = "Task associated with the taskId in TaskEvent does not exist yet and the event was not Task. Creating new " + "task should always start with Task event." - - private const val INVALID_CONTEXT_ID = "Event contextId must be same as current contextId" } + /** + * Helper interface to handle different session types. + */ private sealed interface SessionType { object MessageSession : SessionType @@ -84,13 +99,32 @@ public class SessionEventProcessor( ) : SessionType } - private val _events = Channel() - public val events: SharedFlow = _events - .receiveAsFlow() - .shareIn(scope = coroutineScope, started = SharingStarted.Eagerly) + /** + * Helper interface to send actual events or termination signal to cancel events stream on session closure. + */ + private sealed interface FlowEvent { + @JvmInline + value class Data(val data: Event) : FlowEvent + object Cancel : FlowEvent + } + + private val isClosed = MutableStateFlow(false) + + private val _events = MutableSharedFlow() + public val events: Flow + get() = flow { + if (!isClosed.value) { + emitAll( + _events + .takeWhile { !isClosed.value } + .filterIsInstance() + .map { it.data } + ) + } + } private val sessionMutex = Mutex() - private var sessionType: SessionType? = currentTask?.let { + private var sessionType: SessionType? = task?.let { SessionType.TaskSession( taskId = it.id, taskState = it.status.state @@ -101,11 +135,15 @@ public class SessionEventProcessor( * Sends a [Message] to the session event processor. Validates the message against the session context and updates * the session state accordingly. * - * @param message The message to be sent. Contains details such as message content, context ID, and metadata. + * @param message The message to be sent. * @throws [InvalidEventException] for invalid events. * Check [SessionEventProcessor] docs from info about valid events. */ public suspend fun sendMessage(message: Message): Unit = sessionMutex.withLock { + if (isClosed.value) { + throw SessionClosedException(SESSION_CLOSED) + } + if (message.contextId != contextId) { throw InvalidEventException(INVALID_CONTEXT_ID) } @@ -116,7 +154,7 @@ public class SessionEventProcessor( is SessionType.TaskSession -> throw InvalidEventException(TASK_INITIALIZED) null -> { - _events.send(message) + _events.emit(FlowEvent.Data(message)) sessionType = SessionType.MessageSession } } @@ -126,14 +164,23 @@ public class SessionEventProcessor( * Sends a [TaskEvent] to the session event processor. Validates the event against the session context and updates * the session state and [taskStorage] accordingly. * - * @param event The event to be sent. Contains details such as task ID, context ID, and metadata. + * @param event The event to be sent. * @throws [InvalidEventException] for invalid events. * Check [SessionEventProcessor] docs from info about valid events. */ public suspend fun sendTaskEvent(event: TaskEvent): Unit = sessionMutex.withLock { + if (isClosed.value) { + throw SessionClosedException(SESSION_CLOSED) + } + if (event.contextId != contextId) { throw InvalidEventException(INVALID_CONTEXT_ID) } + + if (event.taskId != taskId) { + throw InvalidEventException(INVALID_TASK_ID) + } + /* The first set of checks, to get initial task session type if it is allowed here. */ @@ -143,11 +190,9 @@ public class SessionEventProcessor( is SessionType.TaskSession -> sessionType as SessionType.TaskSession null -> { - val savedTask = taskStorage.get(event.taskId, historyLength = 0, includeArtifacts = false) - SessionType.TaskSession( taskId = event.taskId, - taskState = savedTask?.status?.state, // null - new task + taskState = task?.status?.state, // null - new task finalEventReceived = false ).also { sessionType = it @@ -188,7 +233,7 @@ public class SessionEventProcessor( // Only if all checks passed, attempt to update and emit the event taskStorage.update(event) - _events.send(event) + _events.emit(FlowEvent.Data(event)) when (event) { is TaskStatusUpdateEvent -> taskSessionType.apply { @@ -206,7 +251,8 @@ public class SessionEventProcessor( } } - override fun close() { - _events.close() + public suspend fun close(): Unit = sessionMutex.withLock { + isClosed.value = true + _events.emit(FlowEvent.Cancel) } } diff --git a/a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/session/SessionManager.kt b/a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/session/SessionManager.kt index 94aed0e3bc..6e8087391f 100644 --- a/a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/session/SessionManager.kt +++ b/a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/session/SessionManager.kt @@ -1,23 +1,18 @@ package ai.koog.a2a.server.session import ai.koog.a2a.annotations.InternalA2AApi -import ai.koog.a2a.model.Message import ai.koog.a2a.model.TaskEvent import ai.koog.a2a.server.notifications.PushNotificationConfigStorage import ai.koog.a2a.server.notifications.PushNotificationSender import ai.koog.a2a.server.tasks.TaskStorage import ai.koog.a2a.utils.RWLock import kotlinx.coroutines.CoroutineScope -import kotlinx.coroutines.flow.first +import kotlinx.coroutines.flow.firstOrNull import kotlinx.coroutines.launch /** * Manages a set of active instances of [Session], sends push notifications if configured after each session completes. - * - * Each session's event stream is monitored for task id associated with this session, if any, i.e., the session is processing a task, - * and if it is a task-related session, it is added to the task sessions map. - * - * Automatically closes and removes the session when it is completed (whether successfully or not). + * Automatically closes and removes the session when agent job is completed (whether successfully or not). * * Additionally, if push notifications are configured, after each task session completes, push notifications are sent with * the current task state. @@ -34,8 +29,10 @@ public class SessionManager( private val pushConfigStorage: PushNotificationConfigStorage? = null, private val pushSender: PushNotificationSender? = null, ) { - private val allSessions = mutableSetOf() - private val taskSessions = mutableMapOf() + /** + * Map of task id to session. All sessions have task id associated with them, even if the task won't be created. + */ + private val sessions = mutableMapOf() private val rwLock = RWLock() /** @@ -46,56 +43,37 @@ public class SessionManager( * * @param session The session to add. */ - public fun addSession(session: Session) { - coroutineScope.launch { - // Check if the first event is task related and add this session to the task sessions map. - when (val firstEvent = session.events.first()) { - is TaskEvent -> { - val taskId = firstEvent.taskId - - rwLock.withWriteLock { - check(taskId !in taskSessions) { - "SessionEventProcessor for taskId '${firstEvent.taskId}' already exists." - } + public suspend fun addSession(session: Session) { + rwLock.withWriteLock { + check(session.taskId !in sessions) { + "SessionEventProcessor for taskId '${session.taskId}' already exists." + } - allSessions += session - taskSessions[firstEvent.taskId] = session - } + sessions[session.taskId] = session + } - // Wait for the session to complete, then close and remove it from collections. - session.join() + // Monitor for agent job completion to send push notifications and remove session from the map. + coroutineScope.launch { + val firstEvent = session.events.firstOrNull() - rwLock.withWriteLock { - session.close() - allSessions -= session - taskSessions -= taskId - } + // Wait for agent job to complete + session.agentJob.join() - // Send push notifications with the current state of the task, after the session completion, if configured. - if (pushSender != null && pushConfigStorage != null) { - val task = taskStorage.get(taskId, historyLength = 0) + // Send push notifications with the current state of the task, after the session completion, if configured. + if (firstEvent is TaskEvent && pushSender != null && pushConfigStorage != null) { + val task = taskStorage.get(session.taskId, historyLength = 0, includeArtifacts = false) - if (task != null) { - pushConfigStorage.getAll(taskId).forEach { config -> - pushSender.send(config, task) - } - } + if (task != null) { + pushConfigStorage.getAll(session.taskId).forEach { config -> + pushSender.send(config, task) } } + } - is Message -> { - rwLock.withWriteLock { - allSessions += session - } - - // Wait for the session to complete, then close and remove it from collection. - session.join() - - rwLock.withWriteLock { - session.close() - allSessions -= session - } - } + // Close the session completely and remove it from the sessions map. + rwLock.withWriteLock { + sessions -= session.taskId + session.close() } } } @@ -104,13 +82,13 @@ public class SessionManager( * Returns the session for the given task id, if any. */ public suspend fun sessionForTask(taskId: String): Session? = rwLock.withReadLock { - taskSessions[taskId] + sessions[taskId] } /** * Returns the number of active sessions. */ public suspend fun activeSessions(): Int = rwLock.withReadLock { - allSessions.size + sessions.size } } diff --git a/a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/tasks/TaskStorage.kt b/a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/tasks/TaskStorage.kt index 0f76c9353f..b6ac235bbb 100644 --- a/a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/tasks/TaskStorage.kt +++ b/a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/tasks/TaskStorage.kt @@ -86,6 +86,13 @@ public interface TaskStorage { public suspend fun deleteAll(taskIds: List) } +/** + * A specialized wrapper around [TaskStorage] for providing access to the tasks within a specific context. + * + * @param contextId the unique identifier for the current context + * @param taskStorage the underlying task storage implementation + * @see [TaskStorage] + */ public class ContextTaskStorage( private val contextId: String, private val taskStorage: TaskStorage, diff --git a/a2a/a2a-server/src/commonTest/kotlin/ai/koog/a2a/server/session/SessionEventProcessorTest.kt b/a2a/a2a-server/src/commonTest/kotlin/ai/koog/a2a/server/session/SessionEventProcessorTest.kt new file mode 100644 index 0000000000..f6421e4b76 --- /dev/null +++ b/a2a/a2a-server/src/commonTest/kotlin/ai/koog/a2a/server/session/SessionEventProcessorTest.kt @@ -0,0 +1,441 @@ +package ai.koog.a2a.server.session + +import ai.koog.a2a.model.Artifact +import ai.koog.a2a.model.Event +import ai.koog.a2a.model.Message +import ai.koog.a2a.model.Role +import ai.koog.a2a.model.Task +import ai.koog.a2a.model.TaskArtifactUpdateEvent +import ai.koog.a2a.model.TaskState +import ai.koog.a2a.model.TaskStatus +import ai.koog.a2a.model.TaskStatusUpdateEvent +import ai.koog.a2a.model.TextPart +import ai.koog.a2a.server.exceptions.InvalidEventException +import ai.koog.a2a.server.exceptions.SessionClosedException +import ai.koog.a2a.server.tasks.InMemoryTaskStorage +import kotlinx.coroutines.flow.lastOrNull +import kotlinx.coroutines.launch +import kotlinx.coroutines.test.runTest +import kotlinx.coroutines.yield +import kotlinx.datetime.Instant +import kotlin.test.BeforeTest +import kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.assertFailsWith +import kotlin.test.assertNull +import kotlin.time.Duration.Companion.seconds + +class SessionEventProcessorTest { + private companion object { + private val TEST_TIMEOUT = 5.seconds + } + + private lateinit var taskStorage: InMemoryTaskStorage + private val contextId = "test-context-1" + private val taskId = "task-1" + + @BeforeTest + fun setUp() { + taskStorage = InMemoryTaskStorage() + } + + private fun createMessage( + messageId: String, + contextId: String, + content: String + ) = Message( + messageId = messageId, + role = Role.User, + parts = listOf(TextPart(content)), + contextId = contextId + ) + + private fun createTask( + id: String, + contextId: String, + state: TaskState = TaskState.Submitted + ) = Task( + id = id, + contextId = contextId, + status = TaskStatus( + state = state, + timestamp = Instant.parse("2023-01-01T10:00:00Z") + ) + ) + + private fun createProcessor( + contextId: String, + taskId: String, + task: Task? = null + ) = SessionEventProcessor( + contextId = contextId, + taskId = taskId, + taskStorage = taskStorage, + task = task + ) + + @Test + fun message_testSendMessageWithInvalidContextId() = runTest(timeout = TEST_TIMEOUT) { + val processor = createProcessor(contextId, taskId) + val message = createMessage("msg-1", "different-context", "Hello") + + assertFailsWith { + processor.sendMessage(message) + } + + processor.close() + } + + @Test + fun message_testSendSecondMessageFails() = runTest(timeout = TEST_TIMEOUT) { + val processor = createProcessor(contextId, taskId) + val message1 = createMessage("msg-1", contextId, "Hello") + val message2 = createMessage("msg-2", contextId, "World") + + processor.sendMessage(message1) + + assertFailsWith { + processor.sendMessage(message2) + } + + processor.close() + } + + @Test + fun message_testSendTaskEventAfterMessageFails() = runTest(timeout = TEST_TIMEOUT) { + val processor = createProcessor(contextId, taskId) + val message = createMessage("msg-1", contextId, "Hello") + val task = createTask(taskId, contextId) + + processor.sendMessage(message) + + assertFailsWith { + processor.sendTaskEvent(task) + } + + processor.close() + } + + // Task session tests + + @Test + fun task_testSendTaskEventNewTask() = runTest(timeout = TEST_TIMEOUT) { + val processor = createProcessor(contextId, taskId) + val task = createTask(taskId, contextId) + + // Start collecting events before sending + val events = mutableListOf() + val eventsJob = launch { + processor.events.collect { + events.add(it) + } + } + + // Let the eventJob job actually start + yield() + + processor.sendTaskEvent(task) + processor.close() + + // Wait for event collection to complete + eventsJob.join() + + // Verify task was stored and collected + val storedTask = taskStorage.get(taskId) + + assertEquals(task, storedTask) + assertEquals(listOf(task), events.toList()) + } + + @Test + fun task_testSendTaskEventWithExistingTask() = runTest(timeout = TEST_TIMEOUT) { + // Store a task first + val existingTask = createTask(taskId, contextId) + taskStorage.update(existingTask) + + // Create processor with existing task + val processor = createProcessor(contextId, taskId, existingTask) + + val statusUpdate = TaskStatusUpdateEvent( + taskId = taskId, + contextId = contextId, + status = TaskStatus(state = TaskState.Working), + final = false + ) + + processor.sendTaskEvent(statusUpdate) + processor.close() + + // Verify event was processed + val updatedTask = taskStorage.get(taskId) + assertEquals(statusUpdate.status, updatedTask?.status) + } + + @Test + fun task_testSendTaskEventWithInvalidContextId() = runTest(timeout = TEST_TIMEOUT) { + val processor = createProcessor(contextId, taskId) + val task = createTask(taskId, "different-context") + + assertFailsWith { + processor.sendTaskEvent(task) + } + + processor.close() + } + + @Test + fun task_testSendNonTaskEventForNewTaskFails() = runTest(timeout = TEST_TIMEOUT) { + val processor = createProcessor(contextId, taskId) + val statusUpdate = TaskStatusUpdateEvent( + taskId = taskId, + contextId = contextId, + status = TaskStatus(state = TaskState.Working), + final = false + ) + + assertFailsWith { + processor.sendTaskEvent(statusUpdate) + } + + processor.close() + } + + @Test + fun task_testSendEventAfterFinalEventFails() = runTest(timeout = TEST_TIMEOUT) { + val processor = createProcessor(contextId, taskId) + val task = createTask(taskId, contextId) + val finalStatusUpdate = TaskStatusUpdateEvent( + taskId = taskId, + contextId = contextId, + status = TaskStatus(state = TaskState.Completed), + final = true + ) + val anotherEvent = TaskArtifactUpdateEvent( + taskId = taskId, + contextId = contextId, + artifact = Artifact( + artifactId = "artifact-1", + parts = listOf(TextPart("content")) + ), + append = false + ) + + processor.sendTaskEvent(task) + processor.sendTaskEvent(finalStatusUpdate) + + assertFailsWith { + processor.sendTaskEvent(anotherEvent) + } + + processor.close() + } + + @Test + fun task_testSendEventWhenTaskInTerminalStateFails() = runTest(timeout = TEST_TIMEOUT) { + // Create task in terminal state + val completedTask = createTask(taskId, contextId, TaskState.Completed) + taskStorage.update(completedTask) + + val processor = createProcessor(contextId, taskId) + + val artifactEvent = TaskArtifactUpdateEvent( + taskId = taskId, + contextId = contextId, + artifact = Artifact( + artifactId = "artifact-1", + parts = listOf(TextPart("content")) + ), + append = false + ) + + assertFailsWith { + processor.sendTaskEvent(artifactEvent) + } + + processor.close() + } + + @Test + fun task_testTerminalStatusUpdateWithoutFinalFlagFails() = runTest(timeout = TEST_TIMEOUT) { + val processor = createProcessor(contextId, taskId) + val task = createTask(taskId, contextId) + val statusUpdate = TaskStatusUpdateEvent( + taskId = taskId, + contextId = contextId, + status = TaskStatus(state = TaskState.Failed), + final = false // This should be true for terminal state + ) + + processor.sendTaskEvent(task) + + assertFailsWith { + processor.sendTaskEvent(statusUpdate) + } + + processor.close() + } + + @Test + fun task_testSendMessageAfterTaskEventFails() = runTest(timeout = TEST_TIMEOUT) { + val processor = createProcessor(contextId, taskId) + val task = createTask(taskId, contextId) + val message = createMessage("msg-1", contextId, "Hello") + + processor.sendTaskEvent(task) + + assertFailsWith { + processor.sendMessage(message) + } + + processor.close() + } + + @Test + fun task_testTaskArtifactUpdateEvent() = runTest(timeout = TEST_TIMEOUT) { + val processor = createProcessor(contextId, taskId) + val task = createTask(taskId, contextId) + val artifact = Artifact( + artifactId = "artifact-1", + parts = listOf(TextPart("content")) + ) + val artifactEvent = TaskArtifactUpdateEvent( + taskId = taskId, + contextId = contextId, + artifact = artifact, + append = false + ) + + processor.sendTaskEvent(task) + processor.sendTaskEvent(artifactEvent) + processor.close() + + // Verify artifact was stored + val storedTask = taskStorage.get("task-1", includeArtifacts = true) + assertEquals(listOf(artifact), storedTask?.artifacts) + } + + @Test + fun task_testCompleteTaskLifecycle() = runTest(timeout = TEST_TIMEOUT) { + val processor = createProcessor(contextId, taskId) + + // Create task + val task = createTask(taskId, contextId) + processor.sendTaskEvent(task) + + // Update status to working + val workingUpdate = TaskStatusUpdateEvent( + taskId = taskId, + contextId = contextId, + status = TaskStatus(state = TaskState.Working), + final = false + ) + processor.sendTaskEvent(workingUpdate) + + // Add artifact + val artifact = Artifact( + artifactId = "artifact-1", + parts = listOf(TextPart("result")) + ) + val artifactEvent = TaskArtifactUpdateEvent( + taskId = taskId, + contextId = contextId, + artifact = artifact, + append = false + ) + processor.sendTaskEvent(artifactEvent) + + // Complete task + val completedUpdate = TaskStatusUpdateEvent( + taskId = taskId, + contextId = contextId, + status = TaskStatus(state = TaskState.Completed), + final = true + ) + processor.sendTaskEvent(completedUpdate) + + processor.close() + + // Verify final state + val finalTask = taskStorage.get(taskId, includeArtifacts = true) + assertEquals(TaskState.Completed, finalTask?.status?.state) + assertEquals(listOf(artifact), finalTask?.artifacts) + } + + // Concurrent scenarios + + @Test + fun concurrent_message_testSendMessageBroadcast() = runTest(timeout = TEST_TIMEOUT) { + val processor = createProcessor(contextId, taskId) + val message = createMessage("msg-1", contextId, "Hello") + + fun collectionJob(events: MutableList) = launch { + processor.events.collect { + events.add(it) + } + } + + // Set up two collectors to test that events are broadcasted properly + val eventsOne = mutableListOf() + val eventsJobOne = collectionJob(eventsOne) + + val eventsTwo = mutableListOf() + val eventsJobTwo = collectionJob(eventsTwo) + + // Let event jobs actually start + yield() + + processor.sendMessage(message) + processor.close() + + eventsJobOne.join() + eventsJobTwo.join() + + assertEquals(listOf(message), eventsOne.toList(), "First collector should collect the message") + assertEquals(listOf(message), eventsTwo.toList(), "Second collector should collect the message") + } + + @Test + fun concurrent_message_testClosedProcessorSendMessageFailsAndEventStreamIsEmpty() = + runTest(timeout = TEST_TIMEOUT) { + val processor = createProcessor(contextId, taskId) + val message1 = createMessage("msg-1", contextId, "Hello") + val message2 = createMessage("msg-2", contextId, "World") + + // Send first message + processor.sendMessage(message1) + + // Close processor and then attempt to send more events + processor.close() + + assertFailsWith("Should not be possible to send events to closed session") { + processor.sendMessage(message2) + } + + assertNull(processor.events.lastOrNull(), "Events stream should be empty after closing") + } + + @Test + fun concurrent_task_testClosedProcessorSendTaskEventFailsAndEventStreamIsEmpty() = runTest(timeout = TEST_TIMEOUT) { + val processor = createProcessor(contextId, taskId) + val task = createTask(taskId, contextId) + + // Send first task event + processor.sendTaskEvent(task) + + // Close processor and then attempt to send more events + processor.close() + + val workingUpdate = TaskStatusUpdateEvent( + taskId = taskId, + contextId = contextId, + status = TaskStatus(state = TaskState.Working), + final = false + ) + + assertFailsWith("Should not be possible to send events to closed session") { + processor.sendTaskEvent(workingUpdate) + } + + assertNull(processor.events.lastOrNull(), "Events stream should be empty after closing") + } +} diff --git a/a2a/a2a-server/src/commonTest/kotlin/ai/koog/a2a/server/session/SessionManagerTest.kt b/a2a/a2a-server/src/commonTest/kotlin/ai/koog/a2a/server/session/SessionManagerTest.kt new file mode 100644 index 0000000000..94161bdf29 --- /dev/null +++ b/a2a/a2a-server/src/commonTest/kotlin/ai/koog/a2a/server/session/SessionManagerTest.kt @@ -0,0 +1,257 @@ +package ai.koog.a2a.server.session + +import ai.koog.a2a.model.Message +import ai.koog.a2a.model.PushNotificationConfig +import ai.koog.a2a.model.Role +import ai.koog.a2a.model.Task +import ai.koog.a2a.model.TaskState +import ai.koog.a2a.model.TaskStatus +import ai.koog.a2a.model.TaskStatusUpdateEvent +import ai.koog.a2a.model.TextPart +import ai.koog.a2a.server.notifications.InMemoryPushNotificationConfigStorage +import ai.koog.a2a.server.notifications.PushNotificationSender +import ai.koog.a2a.server.tasks.InMemoryTaskStorage +import kotlinx.coroutines.CoroutineScope +import kotlinx.coroutines.delay +import kotlinx.coroutines.test.runTest +import kotlinx.datetime.Instant +import kotlin.test.BeforeTest +import kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.assertNull + +class SessionManagerTest { + private lateinit var taskStorage: InMemoryTaskStorage + private lateinit var pushConfigStorage: InMemoryPushNotificationConfigStorage + private lateinit var pushSender: MockPushNotificationSender + + private val contextId = "test-context-1" + private val taskId = "task-1" + + private class MockPushNotificationSender : PushNotificationSender { + val sentNotifications = mutableListOf>() + + override suspend fun send(config: PushNotificationConfig, task: Task) { + sentNotifications.add(config to task) + } + } + + @BeforeTest + fun setUp() { + taskStorage = InMemoryTaskStorage() + pushConfigStorage = InMemoryPushNotificationConfigStorage() + pushSender = MockPushNotificationSender() + } + + private fun createMessage( + messageId: String, + contextId: String, + content: String + ) = Message( + messageId = messageId, + role = Role.User, + parts = listOf(TextPart(content)), + contextId = contextId + ) + + private fun createTask( + id: String, + contextId: String, + state: TaskState = TaskState.Submitted + ) = Task( + id = id, + contextId = contextId, + status = TaskStatus( + state = state, + timestamp = Instant.parse("2023-01-01T10:00:00Z") + ) + ) + + private fun createProcessor( + contextId: String, + taskId: String, + task: Task? = null + ) = SessionEventProcessor( + contextId = contextId, + taskId = taskId, + taskStorage = taskStorage, + task = task + ) + + private fun createManager( + coroutineScope: CoroutineScope, + ) = SessionManager( + coroutineScope = coroutineScope, + taskStorage = taskStorage, + pushConfigStorage = pushConfigStorage, + pushSender = pushSender, + ) + + @Test + fun testSessionManagerCreation() = runTest { + val sessionManager = SessionManager( + coroutineScope = this, + taskStorage = taskStorage + ) + + assertEquals(0, sessionManager.activeSessions()) + assertNull(sessionManager.sessionForTask("any-task-id")) + } + + @Test + fun testAddMessageSession() = runTest { + val sessionManager = createManager(this) + val eventProcessor = createProcessor(contextId, taskId) + + val message = createMessage("msg-1", contextId, "Hello") + + val session = Session( + coroutineScope = this, + eventProcessor = eventProcessor + ) { + eventProcessor.sendMessage(message) + } + + // Start session and wait for completion + sessionManager.addSession(session) + session.join() + + // Session should be automatically cleaned up after completion + assertEquals(0, sessionManager.activeSessions()) + } + + @Test + fun testAddTaskSession() = runTest { + val sessionManager = createManager(this) + val eventProcessor = createProcessor(contextId, taskId) + + val session = Session( + coroutineScope = this, + eventProcessor = eventProcessor + ) { + val task = createTask(taskId, contextId) + eventProcessor.sendTaskEvent(task) + + // Simulate work + delay(400) + + val statusUpdate = TaskStatusUpdateEvent( + taskId = taskId, + contextId = contextId, + status = TaskStatus(state = TaskState.Completed), + final = true + ) + eventProcessor.sendTaskEvent(statusUpdate) + } + + sessionManager.addSession(session) + session.start() + + assertEquals(session, sessionManager.sessionForTask(taskId)) + + session.join() + + // Session should be automatically cleaned up after completion + assertEquals(0, sessionManager.activeSessions()) + } + + @Test + fun testMultipleSessions() = runTest { + val sessionManager = createManager(this) + + // Create two task sessions + val eventProcessor1 = createProcessor("context-1", "task-1") + val eventProcessor2 = createProcessor("context-2", "task-2") + + val session1 = Session( + coroutineScope = this, + eventProcessor = eventProcessor1 + ) { + val task = createTask("task-1", "context-1") + eventProcessor1.sendTaskEvent(task) + + // Simulate work + delay(150) + + val statusUpdate = TaskStatusUpdateEvent( + taskId = "task-1", + contextId = "context-1", + status = TaskStatus(state = TaskState.Completed), + final = true + ) + eventProcessor1.sendTaskEvent(statusUpdate) + } + + val session2 = Session( + coroutineScope = this, + eventProcessor = eventProcessor2 + ) { + val task = createTask("task-2", "context-2") + eventProcessor2.sendTaskEvent(task) + + // Simulate work + delay(150) + + val statusUpdate = TaskStatusUpdateEvent( + taskId = "task-2", + contextId = "context-2", + status = TaskStatus(state = TaskState.Completed), + final = true + ) + eventProcessor2.sendTaskEvent(statusUpdate) + } + + sessionManager.addSession(session1) + sessionManager.addSession(session2) + session1.start() + session2.start() + + assertEquals(session1, sessionManager.sessionForTask("task-1")) + assertEquals(session2, sessionManager.sessionForTask("task-2")) + + session1.join() + session2.join() + + // All sessions should be automatically cleaned up + assertEquals(0, sessionManager.activeSessions()) + } + + @Test + fun testSessionWithPushNotifications() = runTest { + val sessionManager = createManager(this) + val eventProcessor = createProcessor(contextId, taskId) + + // Configure push notification + val config = PushNotificationConfig( + id = "config-1", + url = "https://example.com/webhook" + ) + pushConfigStorage.save("task-1", config) + + val task = createTask("task-1", contextId) + + val session = Session( + coroutineScope = this, + eventProcessor = eventProcessor + ) { + eventProcessor.sendTaskEvent(task) + + val statusUpdate = TaskStatusUpdateEvent( + taskId = "task-1", + contextId = contextId, + status = TaskStatus(state = TaskState.Completed), + final = true + ) + eventProcessor.sendTaskEvent(statusUpdate) + } + + sessionManager.addSession(session) + session.join() + + // Verify push notification was sent + assertEquals(1, pushSender.sentNotifications.size) + val (sentConfig, sentTask) = pushSender.sentNotifications[0] + assertEquals(config, sentConfig) + assertEquals(TaskState.Completed, sentTask.status.state) + } +} diff --git a/a2a/a2a-transport/a2a-transport-client-jsonrpc-http/build.gradle.kts b/a2a/a2a-transport/a2a-transport-client-jsonrpc-http/build.gradle.kts index dc383847a5..be7918cfff 100644 --- a/a2a/a2a-transport/a2a-transport-client-jsonrpc-http/build.gradle.kts +++ b/a2a/a2a-transport/a2a-transport-client-jsonrpc-http/build.gradle.kts @@ -9,12 +9,6 @@ plugins { } kotlin { - jvm() - - js(IR) { - browser() - } - sourceSets { commonMain { dependencies { diff --git a/a2a/a2a-transport/a2a-transport-client-rest/build.gradle.kts b/a2a/a2a-transport/a2a-transport-client-rest/build.gradle.kts index 36e5aa832a..a124396355 100644 --- a/a2a/a2a-transport/a2a-transport-client-rest/build.gradle.kts +++ b/a2a/a2a-transport/a2a-transport-client-rest/build.gradle.kts @@ -9,12 +9,6 @@ plugins { } kotlin { - jvm() - - js(IR) { - browser() - } - sourceSets { commonMain { dependencies { diff --git a/a2a/a2a-transport/a2a-transport-core-jsonrpc/build.gradle.kts b/a2a/a2a-transport/a2a-transport-core-jsonrpc/build.gradle.kts index c8017f0eee..591006e83b 100644 --- a/a2a/a2a-transport/a2a-transport-core-jsonrpc/build.gradle.kts +++ b/a2a/a2a-transport/a2a-transport-core-jsonrpc/build.gradle.kts @@ -9,12 +9,6 @@ plugins { } kotlin { - jvm() - - js(IR) { - browser() - } - sourceSets { commonMain { dependencies { diff --git a/a2a/a2a-transport/a2a-transport-core-rest/build.gradle.kts b/a2a/a2a-transport/a2a-transport-core-rest/build.gradle.kts index 36e5aa832a..a124396355 100644 --- a/a2a/a2a-transport/a2a-transport-core-rest/build.gradle.kts +++ b/a2a/a2a-transport/a2a-transport-core-rest/build.gradle.kts @@ -9,12 +9,6 @@ plugins { } kotlin { - jvm() - - js(IR) { - browser() - } - sourceSets { commonMain { dependencies { diff --git a/a2a/a2a-transport/a2a-transport-server-jsonrpc-http/build.gradle.kts b/a2a/a2a-transport/a2a-transport-server-jsonrpc-http/build.gradle.kts index 592b576661..de343e79f1 100644 --- a/a2a/a2a-transport/a2a-transport-server-jsonrpc-http/build.gradle.kts +++ b/a2a/a2a-transport/a2a-transport-server-jsonrpc-http/build.gradle.kts @@ -9,12 +9,6 @@ plugins { } kotlin { - jvm() - - js(IR) { - browser() - } - sourceSets { commonMain { dependencies { diff --git a/a2a/a2a-transport/a2a-transport-server-rest/build.gradle.kts b/a2a/a2a-transport/a2a-transport-server-rest/build.gradle.kts index 36e5aa832a..a124396355 100644 --- a/a2a/a2a-transport/a2a-transport-server-rest/build.gradle.kts +++ b/a2a/a2a-transport/a2a-transport-server-rest/build.gradle.kts @@ -9,12 +9,6 @@ plugins { } kotlin { - jvm() - - js(IR) { - browser() - } - sourceSets { commonMain { dependencies { From dfc5a7948f4c7252edc1b461718cdef2e55c9c80 Mon Sep 17 00:00:00 2001 From: Andrey Bragin Date: Wed, 24 Sep 2025 14:53:35 +0200 Subject: [PATCH 31/52] [a2a] Refine AgentExecutor interface and add test utilities --- a2a/a2a-server/build.gradle.kts | 1 + .../ai/koog/a2a/server/agent/AgentExecutor.kt | 44 ++++ .../a2a/server/session/SessionManagerTest.kt | 2 +- .../ai/koog/a2a/server/TestAgentExecutor.kt | 219 ++++++++++++++++++ .../http/HttpJSONRPCClientTransport.kt | 5 - .../jsonrpc/JSONRPCServerTransport.kt | 4 +- 6 files changed, 267 insertions(+), 8 deletions(-) create mode 100644 a2a/a2a-server/src/jvmTest/kotlin/ai/koog/a2a/server/TestAgentExecutor.kt diff --git a/a2a/a2a-server/build.gradle.kts b/a2a/a2a-server/build.gradle.kts index 1c94f80da0..de3c676265 100644 --- a/a2a/a2a-server/build.gradle.kts +++ b/a2a/a2a-server/build.gradle.kts @@ -24,6 +24,7 @@ kotlin { commonTest { dependencies { + implementation(project(":a2a:a2a-test")) implementation(kotlin("test")) implementation(libs.kotlinx.coroutines.test) } diff --git a/a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/agent/AgentExecutor.kt b/a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/agent/AgentExecutor.kt index 16cd41536e..0d52f92139 100644 --- a/a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/agent/AgentExecutor.kt +++ b/a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/agent/AgentExecutor.kt @@ -28,6 +28,50 @@ public interface AgentExecutor { * * Can throw an exception if the input is invalid or the agent fails to execute the request. * + * Example implementation: + * ```kotlin + * val userMessage = context.params.message + * + * // Process the message and create a task + * // Send task creation event + * eventProcessor.sendTaskEvent( + * Task( + * id = context.taskId, + * contextId = context.contextId, + * status = TaskStatus( + * state = TaskState.Working, + * // Mark this message as belonging to the created task + * message = message.copy(taskId = context.taskId) + * timestamp = Clock.System.now() + * ), + * ) + * ) + * + * // Simulate some work + * delay(1000) + * + * // Mark task as completed + * eventProcessor.sendTaskEvent( + * TaskStatusUpdateEvent( + * taskId = context.taskId, + * contextId = context.contextId, + * status = TaskStatus( + * state = TaskState.Completed, + * message = Message( + * role = Role.Agent, + * contextId = context.contextId, + * taskId = context.taskId, + * parts = listOf( + * TextPart("Task completed successfully!") + * ) + * ), + * timestamp = Clock.System.now() + * ), + * final = true + * ) + * ) + * ``` + * * @param context The context containing the necessary information and accessors for executing the agent. * @param eventProcessor The event processor to publish events to. * @throws Exception if something goes wrong during execution. Should prefer more specific exceptions when possible, diff --git a/a2a/a2a-server/src/commonTest/kotlin/ai/koog/a2a/server/session/SessionManagerTest.kt b/a2a/a2a-server/src/commonTest/kotlin/ai/koog/a2a/server/session/SessionManagerTest.kt index 94161bdf29..81b2c519ed 100644 --- a/a2a/a2a-server/src/commonTest/kotlin/ai/koog/a2a/server/session/SessionManagerTest.kt +++ b/a2a/a2a-server/src/commonTest/kotlin/ai/koog/a2a/server/session/SessionManagerTest.kt @@ -25,7 +25,7 @@ class SessionManagerTest { private lateinit var pushConfigStorage: InMemoryPushNotificationConfigStorage private lateinit var pushSender: MockPushNotificationSender - private val contextId = "test-context-1" + private val contextId = "context-1" private val taskId = "task-1" private class MockPushNotificationSender : PushNotificationSender { diff --git a/a2a/a2a-server/src/jvmTest/kotlin/ai/koog/a2a/server/TestAgentExecutor.kt b/a2a/a2a-server/src/jvmTest/kotlin/ai/koog/a2a/server/TestAgentExecutor.kt new file mode 100644 index 0000000000..242b443783 --- /dev/null +++ b/a2a/a2a-server/src/jvmTest/kotlin/ai/koog/a2a/server/TestAgentExecutor.kt @@ -0,0 +1,219 @@ +package ai.koog.a2a.server + +import ai.koog.a2a.model.Message +import ai.koog.a2a.model.MessageSendParams +import ai.koog.a2a.model.Role +import ai.koog.a2a.model.Task +import ai.koog.a2a.model.TaskIdParams +import ai.koog.a2a.model.TaskState +import ai.koog.a2a.model.TaskStatus +import ai.koog.a2a.model.TaskStatusUpdateEvent +import ai.koog.a2a.model.TextPart +import ai.koog.a2a.server.agent.AgentExecutor +import ai.koog.a2a.server.session.RequestContext +import ai.koog.a2a.server.session.Session +import ai.koog.a2a.server.session.SessionEventProcessor +import kotlinx.coroutines.delay +import kotlinx.datetime.Clock + +private suspend fun sayHello( + context: RequestContext, + eventProcessor: SessionEventProcessor, +) { + eventProcessor.sendMessage( + Message( + role = Role.Agent, + parts = listOf(TextPart("Hello World")), + contextId = context.contextId, + taskId = context.taskId + ) + ) +} + +private suspend fun doTask( + context: RequestContext, + eventProcessor: SessionEventProcessor, +) { + val task = Task( + id = context.taskId, + contextId = context.contextId, + status = TaskStatus( + state = TaskState.Working, + message = Message( + role = Role.Agent, + parts = listOf(TextPart("Task created")), + contextId = context.contextId, + taskId = context.taskId + ), + timestamp = Clock.System.now() + ) + ) + + // Send initial task event + eventProcessor.sendTaskEvent(task) + + // Send task working status update + eventProcessor.sendTaskEvent( + TaskStatusUpdateEvent( + contextId = context.contextId, + taskId = context.taskId, + status = TaskStatus( + state = TaskState.Working, + message = Message( + role = Role.Agent, + parts = listOf(TextPart("Working on task")), + contextId = context.contextId, + taskId = context.taskId + ), + timestamp = Clock.System.now() + ), + final = false + ) + ) + + // Send task completion status update + eventProcessor.sendTaskEvent( + TaskStatusUpdateEvent( + contextId = context.contextId, + taskId = context.taskId, + status = TaskStatus( + state = TaskState.Completed, + message = Message( + role = Role.Agent, + parts = listOf(TextPart("Task completed")), + contextId = context.contextId, + taskId = context.taskId + ), + timestamp = Clock.System.now() + ), + final = true + ) + ) +} + +private suspend fun doCancelableTask( + context: RequestContext, + eventProcessor: SessionEventProcessor, +) { + val task = Task( + id = context.taskId, + contextId = context.contextId, + status = TaskStatus( + state = TaskState.Working, + message = Message( + role = Role.Agent, + parts = listOf(TextPart("Cancelable task created")), + contextId = context.contextId, + taskId = context.taskId + ), + timestamp = Clock.System.now() + ) + ) + + eventProcessor.sendTaskEvent(task) +} + +private suspend fun doLongRunningTask( + context: RequestContext, + eventProcessor: SessionEventProcessor, +) { + val task = Task( + id = context.taskId, + contextId = context.contextId, + status = TaskStatus( + state = TaskState.Working, + message = Message( + role = Role.Agent, + parts = listOf(TextPart("Long running task started")), + contextId = context.contextId, + taskId = context.taskId + ), + timestamp = Clock.System.now() + ) + ) + + eventProcessor.sendTaskEvent(task) + + // Simulate long-running task + repeat(4) { + delay(200) + + eventProcessor.sendTaskEvent( + TaskStatusUpdateEvent( + taskId = context.taskId, + contextId = context.contextId, + status = TaskStatus( + state = TaskState.Working, + message = Message( + role = Role.Agent, + parts = listOf(TextPart("Still working $it")), + contextId = context.contextId, + taskId = context.taskId + ), + timestamp = Clock.System.now() + ), + final = false + ) + ) + } +} + +class TestAgentExecutor : AgentExecutor { + override suspend fun execute(context: RequestContext, eventProcessor: SessionEventProcessor) { + val userMessage = context.params.message + val userInput = userMessage.parts.filterIsInstance() + .joinToString(" ") { it.text } + .lowercase() + + // Test scenarios to test various aspects of A2A + when { + "hello world" in userInput -> { + sayHello(context, eventProcessor) + } + + "do task" in userInput -> { + doTask(context, eventProcessor) + } + + "do cancelable task" in userInput -> { + doCancelableTask(context, eventProcessor) + } + + "do long-running task" in userInput -> { + doLongRunningTask(context, eventProcessor) + } + + else -> { + eventProcessor.sendMessage( + Message( + role = Role.Agent, + parts = listOf(TextPart("Sorry, I don't understand you")), + contextId = context.contextId + ) + ) + } + } + } + + override suspend fun cancel(context: RequestContext, session: Session) { + session.agentJob.cancel() + + session.eventProcessor.sendTaskEvent( + TaskStatusUpdateEvent( + contextId = context.contextId, + taskId = context.taskId, + status = TaskStatus( + state = TaskState.Canceled, + message = Message( + role = Role.Agent, + parts = listOf(TextPart("Task canceled")), + contextId = context.contextId, + taskId = context.taskId + ), + timestamp = Clock.System.now() + ), + final = true + ) + ) + } +} diff --git a/a2a/a2a-transport/a2a-transport-client-jsonrpc-http/src/commonMain/kotlin/ai/koog/a2a/transport/client/jsonrpc/http/HttpJSONRPCClientTransport.kt b/a2a/a2a-transport/a2a-transport-client-jsonrpc-http/src/commonMain/kotlin/ai/koog/a2a/transport/client/jsonrpc/http/HttpJSONRPCClientTransport.kt index 77856488f7..294de0d628 100644 --- a/a2a/a2a-transport/a2a-transport-client-jsonrpc-http/src/commonMain/kotlin/ai/koog/a2a/transport/client/jsonrpc/http/HttpJSONRPCClientTransport.kt +++ b/a2a/a2a-transport/a2a-transport-client-jsonrpc-http/src/commonMain/kotlin/ai/koog/a2a/transport/client/jsonrpc/http/HttpJSONRPCClientTransport.kt @@ -11,12 +11,10 @@ import io.ktor.client.plugins.contentnegotiation.ContentNegotiation import io.ktor.client.plugins.defaultRequest import io.ktor.client.plugins.sse.SSE import io.ktor.client.plugins.sse.sse -import io.ktor.client.request.accept import io.ktor.client.request.headers import io.ktor.client.request.post import io.ktor.client.request.setBody import io.ktor.http.ContentType -import io.ktor.http.HttpMethod import io.ktor.http.contentType import io.ktor.serialization.kotlinx.json.json import kotlinx.coroutines.flow.Flow @@ -73,9 +71,6 @@ public class HttpJSONRPCClientTransport( ): Flow = flow { httpClient.sse( request = { - method = HttpMethod.Post - accept(ContentType.Text.EventStream) - headers { ctx.additionalHeaders.forEach { (key, values) -> appendAll(key, values) diff --git a/a2a/a2a-transport/a2a-transport-core-jsonrpc/src/commonMain/kotlin/ai/koog/a2a/transport/jsonrpc/JSONRPCServerTransport.kt b/a2a/a2a-transport/a2a-transport-core-jsonrpc/src/commonMain/kotlin/ai/koog/a2a/transport/jsonrpc/JSONRPCServerTransport.kt index e68efbc003..8eaac54109 100644 --- a/a2a/a2a-transport/a2a-transport-core-jsonrpc/src/commonMain/kotlin/ai/koog/a2a/transport/jsonrpc/JSONRPCServerTransport.kt +++ b/a2a/a2a-transport/a2a-transport-core-jsonrpc/src/commonMain/kotlin/ai/koog/a2a/transport/jsonrpc/JSONRPCServerTransport.kt @@ -68,7 +68,7 @@ public abstract class JSONRPCServerTransport : ServerTransport { .toJSONRPCSuccessResponse() else -> - throw A2AMethodNotFoundException(request.method) + throw A2AMethodNotFoundException("Method not found: ${request.method}") } }.getOrElse { it.toJSONRPCErrorResponse(request.id) } } @@ -90,7 +90,7 @@ public abstract class JSONRPCServerTransport : ServerTransport { requestHandler.onResubscribeTask(request.toRequest(), ctx) else -> - flow { throw A2AMethodNotFoundException(request.method) } + flow { throw A2AMethodNotFoundException("Method not found: ${request.method}") } }.map { it.toJSONRPCSuccessResponse() as JSONRPCResponse } .catch { emit(it.toJSONRPCErrorResponse(request.id)) } } From 03e2241a180487b260c2ca5fc7c912e814e55c9e Mon Sep 17 00:00:00 2001 From: Andrey Bragin Date: Wed, 24 Sep 2025 19:53:12 +0200 Subject: [PATCH 32/52] [a2a] Add task locking, add A2A server integration tests --- ....kt => A2AClientJsonRpcIntegrationTest.kt} | 33 +++- a2a/a2a-server/build.gradle.kts | 9 + .../kotlin/ai/koog/a2a/server/A2AServer.kt | 180 ++++++++--------- .../ai/koog/a2a/server/agent/AgentExecutor.kt | 58 +++--- .../ai/koog/a2a/server/session/Session.kt | 34 ++-- .../koog/a2a/server/session/SessionManager.kt | 114 +++++++++-- .../a2a/server/tasks/InMemoryTaskStorage.kt | 5 +- .../a2a/server/session/SessionManagerTest.kt | 184 +++++++++++++++-- .../server/tasks/InMemoryTaskStorageTest.kt | 84 +++++++- .../server/A2AServerJsonRpcIntegrationTest.kt | 187 ++++++++++++++++++ .../ai/koog/a2a/server/TestAgentExecutor.kt | 55 +++--- .../ai/koog/a2a/test/BaseA2AProtocolTest.kt | 3 +- .../http/HttpJSONRPCClientTransport.kt | 5 +- .../koog/a2a/transport/jsonrpc/A2AMethod.kt | 9 +- .../jsonrpc/JSONRPCServerTransport.kt | 14 +- .../http/HttpJSONRPCServerTransport.kt | 100 +++++++--- .../http/HttpJSONRPCServerTransportTest.kt | 3 + .../src/agent_executor.py | 69 +++++-- 18 files changed, 882 insertions(+), 264 deletions(-) rename a2a/a2a-client/src/jvmTest/kotlin/ai/koog/a2a/client/{A2AClientIntegrationTest.kt => A2AClientJsonRpcIntegrationTest.kt} (64%) create mode 100644 a2a/a2a-server/src/jvmTest/kotlin/ai/koog/a2a/server/A2AServerJsonRpcIntegrationTest.kt diff --git a/a2a/a2a-client/src/jvmTest/kotlin/ai/koog/a2a/client/A2AClientIntegrationTest.kt b/a2a/a2a-client/src/jvmTest/kotlin/ai/koog/a2a/client/A2AClientJsonRpcIntegrationTest.kt similarity index 64% rename from a2a/a2a-client/src/jvmTest/kotlin/ai/koog/a2a/client/A2AClientIntegrationTest.kt rename to a2a/a2a-client/src/jvmTest/kotlin/ai/koog/a2a/client/A2AClientJsonRpcIntegrationTest.kt index 214ee2f865..7799a7e47f 100644 --- a/a2a/a2a-client/src/jvmTest/kotlin/ai/koog/a2a/client/A2AClientIntegrationTest.kt +++ b/a2a/a2a-client/src/jvmTest/kotlin/ai/koog/a2a/client/A2AClientJsonRpcIntegrationTest.kt @@ -6,14 +6,22 @@ import io.ktor.client.HttpClient import io.ktor.client.plugins.logging.LogLevel import io.ktor.client.plugins.logging.Logging import kotlinx.coroutines.test.runTest +import org.junit.jupiter.api.AfterAll +import org.junit.jupiter.api.BeforeAll +import org.junit.jupiter.api.TestInstance import org.testcontainers.containers.GenericContainer import org.testcontainers.containers.wait.strategy.Wait import org.testcontainers.junit.jupiter.Container import org.testcontainers.junit.jupiter.Testcontainers -import kotlin.test.BeforeTest +/** + * Integration test class for testing the JSON-RPC HTTP communication in the A2A client context. + * This class ensures the proper functioning and correctness of the A2A protocol over HTTP + * using the JSON-RPC standard. + */ +@TestInstance(TestInstance.Lifecycle.PER_CLASS) @Testcontainers -class A2AClientIntegrationTest : BaseA2AProtocolTest() { +class A2AClientJsonRpcIntegrationTest : BaseA2AProtocolTest() { companion object { @Container val testA2AServer: GenericContainer<*> = @@ -31,25 +39,30 @@ class A2AClientIntegrationTest : BaseA2AProtocolTest() { @Suppress("HttpUrlsUsage") private val agentUrl by lazy { "http://${testA2AServer.host}:${testA2AServer.getMappedPort(9999)}" } - private val transport by lazy { - HttpJSONRPCClientTransport( + private lateinit var transport: HttpJSONRPCClientTransport + + override lateinit var client: A2AClient + + @BeforeAll + fun setUp() = runTest { + transport = HttpJSONRPCClientTransport( url = agentUrl, baseHttpClient = httpClient ) - } - override val client by lazy { - A2AClient( + client = A2AClient( transport = transport, agentCardResolver = UrlAgentCardResolver( baseUrl = agentUrl, baseHttpClient = httpClient, ), ) - } - @BeforeTest - fun initClient() = runTest { client.connect() } + + @AfterAll + fun tearDown() = runTest { + transport.close() + } } diff --git a/a2a/a2a-server/build.gradle.kts b/a2a/a2a-server/build.gradle.kts index de3c676265..68186fac97 100644 --- a/a2a/a2a-server/build.gradle.kts +++ b/a2a/a2a-server/build.gradle.kts @@ -33,6 +33,15 @@ kotlin { jvmTest { dependencies { implementation(kotlin("test-junit5")) + implementation(project(":a2a:a2a-test")) + implementation(project(":a2a:a2a-client")) + implementation(project(":a2a:a2a-transport:a2a-transport-client-jsonrpc-http")) + implementation(project(":a2a:a2a-transport:a2a-transport-server-jsonrpc-http")) + + implementation(libs.ktor.client.cio) + implementation(libs.ktor.client.logging) + implementation(libs.ktor.server.netty) + runtimeOnly(libs.slf4j.simple) } } diff --git a/a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/A2AServer.kt b/a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/A2AServer.kt index d460c16401..2cd19c9bd0 100644 --- a/a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/A2AServer.kt +++ b/a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/A2AServer.kt @@ -4,6 +4,7 @@ import ai.koog.a2a.exceptions.A2AAuthenticatedExtendedCardNotConfiguredException import ai.koog.a2a.exceptions.A2AInternalErrorException import ai.koog.a2a.exceptions.A2AInvalidParamsException import ai.koog.a2a.exceptions.A2APushNotificationNotSupportedException +import ai.koog.a2a.exceptions.A2ATaskNotCancelableException import ai.koog.a2a.exceptions.A2ATaskNotFoundException import ai.koog.a2a.exceptions.A2AUnsupportedOperationException import ai.koog.a2a.model.AgentCard @@ -18,20 +19,19 @@ import ai.koog.a2a.model.TaskPushNotificationConfig import ai.koog.a2a.model.TaskPushNotificationConfigParams import ai.koog.a2a.model.TaskQueryParams import ai.koog.a2a.model.TaskState -import ai.koog.a2a.model.TaskStatus -import ai.koog.a2a.model.TaskStatusUpdateEvent import ai.koog.a2a.server.agent.AgentExecutor import ai.koog.a2a.server.messages.ContextMessageStorage import ai.koog.a2a.server.messages.InMemoryMessageStorage import ai.koog.a2a.server.messages.MessageStorage import ai.koog.a2a.server.notifications.PushNotificationConfigStorage import ai.koog.a2a.server.notifications.PushNotificationSender +import ai.koog.a2a.server.session.AgentSession import ai.koog.a2a.server.session.IdGenerator import ai.koog.a2a.server.session.RequestContext -import ai.koog.a2a.server.session.Session import ai.koog.a2a.server.session.SessionEventProcessor import ai.koog.a2a.server.session.SessionManager import ai.koog.a2a.server.session.UuidIdGenerator +import ai.koog.a2a.server.session.withTaskLock import ai.koog.a2a.server.tasks.ContextTaskStorage import ai.koog.a2a.server.tasks.InMemoryTaskStorage import ai.koog.a2a.server.tasks.TaskStorage @@ -261,7 +261,7 @@ import kotlinx.datetime.Clock * path = "/a2a", * wait = true, * agentCard = agentCard, - * agentCardPath = "/.well-known/a2a/agent-card.json" + * agentCardPath = A2AConsts.AGENT_CARD_WELL_KNOWN_PATH * ) * ``` * @@ -299,13 +299,14 @@ import kotlinx.datetime.Clock * @param agentExecutor The executor containing the core agent logic * @param agentCard The agent card describing this agent's capabilities and metadata * @param agentCardExtended Optional extended agent card for authenticated requests - * @param taskStorage Storage implementation for persisting tasks (defaults to in-memory) - * @param messageStorage Storage implementation for persisting messages (defaults to in-memory) - * @param pushConfigStorage Optional storage for push notification configurations - * @param pushSender Optional push notification sender implementation - * @param idGenerator Generator for new task and context IDs (defaults to UUID) + * @param taskStorage Storage implementation for persisting tasks (defaults to [InMemoryTaskStorage]) + * @param messageStorage Storage implementation for persisting messages (defaults to [InMemoryMessageStorage]) + * @param pushConfigStorage Optional storage for push notification configurations (defaults to `null`) + * @param pushSender Optional push notification sender implementation (defaults to `null`) + * @param idGenerator Generator for new task and context IDs (defaults to [UuidIdGenerator]) * @param coroutineScope Scope for managing all sessions, agent jobs, event processing, etc. * @param clock Clock instance for timestamp generation (defaults to [Clock.System]) + * @param sessionManager Manager for managing agent sessions (defaults to [SessionManager]) * * @see AgentExecutor for implementing agent business logic * @see TaskStorage for persisting tasks @@ -326,7 +327,7 @@ public open class A2AServer( protected val coroutineScope: CoroutineScope = CoroutineScope(SupervisorJob()), protected val clock: Clock = Clock.System, ) : RequestHandler { - protected val sessionManager: SessionManager = SessionManager( + protected open val sessionManager: SessionManager = SessionManager( coroutineScope = coroutineScope, taskStorage = taskStorage, pushConfigStorage = pushConfigStorage, @@ -355,54 +356,60 @@ public open class A2AServer( * * @return A stream of events from the agent */ - protected fun onSendMessageCommon( + protected open fun onSendMessageCommon( request: Request, ctx: ServerCallContext ): Flow> = channelFlow { val message = request.data.message + val contextId = message.contextId ?: idGenerator.generateContextId(message) + val taskId = message.taskId ?: idGenerator.generateTaskId(message) + + val session = sessionManager.withTaskLock(taskId) { + // Check if message links to a task. + val task: Task? = message.taskId?.let { taskId -> + // Check if the specified task exists and message context id matches the task context id. + val task = taskStorage.get(taskId, historyLength = 0, includeArtifacts = false) + ?: throw A2ATaskNotFoundException("Task '$taskId' not found") + + if (message.contextId != task.contextId) { + throw A2AInvalidParamsException("Message context id '${message.contextId}' doesn't match task context id '${task.contextId}'") + } - // Check if message links to a task. - val task: Task? = message.taskId?.let { taskId -> - // Check if the task is still in progress, no message can be sent. - if (sessionManager.sessionForTask(taskId) != null) { - throw A2AUnsupportedOperationException("Task '$taskId' is still running, can't send messages to the task that has not yielded control") - } - - // Check if the specified task exists and message context id matches the task context id. - val task = taskStorage.get(taskId, historyLength = 0, includeArtifacts = false) - ?: throw A2ATaskNotFoundException("Task '$taskId' not found") - - if (message.contextId != task.contextId) { - throw A2AInvalidParamsException("Message context id '${message.contextId}' doesn't match task context id '${task.contextId}'") + task } - task - } - - // Create event processor for the session based on the input data. - val eventProcessor = SessionEventProcessor( - contextId = task?.contextId - ?: message.contextId - ?: idGenerator.generateContextId(message), - taskId = task?.id ?: idGenerator.generateTaskId(message), - taskStorage = taskStorage, - task = null, - ) + // Create event processor for the session based on the input data. + val eventProcessor = SessionEventProcessor( + contextId = contextId, + taskId = taskId, + taskStorage = taskStorage, + task = task, + ) - // Create request context based on the request information. - val requestContext = RequestContext( - callContext = ctx, - params = request.data, - taskStorage = ContextTaskStorage(eventProcessor.contextId, taskStorage), - messageStorage = ContextMessageStorage(eventProcessor.contextId, messageStorage), - contextId = eventProcessor.contextId, - taskId = eventProcessor.taskId, - task = task, - ) + // Create request context based on the request information. + val requestContext = RequestContext( + callContext = ctx, + params = request.data, + taskStorage = ContextTaskStorage(eventProcessor.contextId, taskStorage), + messageStorage = ContextMessageStorage(eventProcessor.contextId, messageStorage), + contextId = eventProcessor.contextId, + taskId = eventProcessor.taskId, + task = task, + ) - // Create agent execution session - val session = Session(coroutineScope, eventProcessor) { - agentExecutor.execute(requestContext, eventProcessor) + // Create agent execution session + AgentSession(coroutineScope, eventProcessor) { + agentExecutor.execute(requestContext, eventProcessor) + }.also { + try { + // Add to session manager, it will handle monitoring and closing once the session is completed (successfully or not). + sessionManager.addSession(it) + } catch (_: IllegalArgumentException) { + throw A2AUnsupportedOperationException( + "Task '${request.data.message.taskId}' is already running, can't send messages to the task that hasn't yielded control." + ) + } + } } // Subscribe to events stream and start emitting them. @@ -413,9 +420,6 @@ public open class A2AServer( } } - // Add to session manager, it will handle monitoring and closing once the session is completed (successfully or not). - sessionManager.addSession(session) - // Start the session to execute the agent and wait for it to finish. session.join() } @@ -482,57 +486,57 @@ public open class A2AServer( ctx: ServerCallContext ): Response { val taskParams = request.data + val taskId = taskParams.id - val session = sessionManager.sessionForTask(taskParams.id) - val task = taskStorage.get(taskParams.id, historyLength = 0, includeArtifacts = true) - ?: throw A2ATaskNotFoundException("Task '${taskParams.id}' not found") + sessionManager.withTaskLock(taskId) { + val session = sessionManager.getSession(taskParams.id) - // Task is not running, check if it exists in the storage. - if (session == null) { - // Task exists but not running - check if it is already canceled. - if (task.status.state == TaskState.Canceled) { - return Response(data = task, id = request.id) - } + val task = taskStorage.get(taskParams.id, historyLength = 0, includeArtifacts = true) + ?: throw A2ATaskNotFoundException("Task '${taskParams.id}' not found") + + // Task is not running, check if it exists in the storage. + if (session == null) { + // Task exists but not running - check if it is already canceled. + if (task.status.state == TaskState.Canceled) { + return Response(data = task, id = request.id) + } - // If the task is not canceled and in the terminal state, throw. - if (task.status.state.terminal) { - throw A2AUnsupportedOperationException("Task '${taskParams.id}' is already in terminal state ${task.status.state}") + // If the task is not canceled and in the terminal state, throw. + if (task.status.state.terminal) { + throw A2ATaskNotCancelableException("Task '${taskParams.id}' is already in terminal state ${task.status.state}") + } } - // Proceed to mark the task as canceled. - taskStorage.update( - TaskStatusUpdateEvent( - taskId = task.id, - contextId = task.contextId, - status = TaskStatus( - state = TaskState.Canceled, - timestamp = clock.now() - ), - final = true - ) + val eventProcessor = session?.eventProcessor ?: SessionEventProcessor( + contextId = task.contextId, + taskId = task.id, + taskStorage = taskStorage, + task = task, ) - } else { + // Create request context based on the request information. val requestContext = RequestContext( callContext = ctx, params = request.data, - taskStorage = ContextTaskStorage(session.contextId, taskStorage), - messageStorage = ContextMessageStorage(session.contextId, messageStorage), - contextId = session.contextId, - taskId = session.taskId, + taskStorage = ContextTaskStorage(eventProcessor.contextId, taskStorage), + messageStorage = ContextMessageStorage(eventProcessor.contextId, messageStorage), + contextId = eventProcessor.contextId, + taskId = eventProcessor.taskId, task = task, ) // Attempt to cancel the agent execution and wait until it's finished. - agentExecutor.cancel(requestContext, session) - - // If cancel finished without exception, assume the cancellation was successful and close the session explicitly. - session.close() + agentExecutor.cancel(requestContext, eventProcessor, session?.agentJob) } // Return the final task state. return Response( data = taskStorage.get(taskParams.id, historyLength = 0, includeArtifacts = true) + ?.also { + if (it.status.state != TaskState.Canceled) { + throw A2ATaskNotCancelableException("Task '${taskParams.id}' was not canceled successfully, current state is ${it.status.state}") + } + } ?: throw A2ATaskNotFoundException("Task '${taskParams.id}' not found"), id = request.id, ) @@ -545,7 +549,7 @@ public open class A2AServer( checkStreamingSupport() val taskParams = request.data - val session = sessionManager.sessionForTask(taskParams.id) + val session = sessionManager.getSession(taskParams.id) ?: throw A2AUnsupportedOperationException("Session for task '${taskParams.id}' is not currently running or task does not exist") emitAll( @@ -612,13 +616,13 @@ public open class A2AServer( return Response(data = null, id = request.id) } - protected fun checkStreamingSupport() { + protected open fun checkStreamingSupport() { if (agentCard.capabilities.streaming != true) { throw A2AUnsupportedOperationException("Streaming is not supported by the server") } } - protected fun storageIfPushNotificationSupported(): PushNotificationConfigStorage { + protected open fun storageIfPushNotificationSupported(): PushNotificationConfigStorage { if (agentCard.capabilities.pushNotifications != true) { throw A2APushNotificationNotSupportedException("Push notifications are not supported by the server") } @@ -635,7 +639,7 @@ public open class A2AServer( * * @param cause Optional cause of the cancellation */ - public fun cancel(cause: CancellationException? = null) { + public open fun cancel(cause: CancellationException? = null) { coroutineScope.cancel(cause) } } diff --git a/a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/agent/AgentExecutor.kt b/a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/agent/AgentExecutor.kt index 0d52f92139..e163178c75 100644 --- a/a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/agent/AgentExecutor.kt +++ b/a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/agent/AgentExecutor.kt @@ -9,8 +9,8 @@ import ai.koog.a2a.model.TaskEvent import ai.koog.a2a.model.TaskIdParams import ai.koog.a2a.model.TaskState import ai.koog.a2a.server.session.RequestContext -import ai.koog.a2a.server.session.Session import ai.koog.a2a.server.session.SessionEventProcessor +import kotlinx.coroutines.Job /** * Implementations of this interface contain the core logic of the agent, @@ -22,7 +22,7 @@ public interface AgentExecutor { * * The agent should read necessary information from the [context] and publish [TaskEvent] or [Message] events to * the [eventProcessor]. This method should return once the agent's execution for this request is complete or - * yields control (e.g., enters an [TaskState.InputRequired] state). + * yields control (e.g., enters a [TaskState.InputRequired] state). * * All events must have context id from [RequestContext.contextId] and for task events task id from [RequestContext.taskId]. * @@ -41,7 +41,7 @@ public interface AgentExecutor { * status = TaskStatus( * state = TaskState.Working, * // Mark this message as belonging to the created task - * message = message.copy(taskId = context.taskId) + * message = userMessage.copy(taskId = context.taskId) * timestamp = Clock.System.now() * ), * ) @@ -81,37 +81,26 @@ public interface AgentExecutor { public suspend fun execute(context: RequestContext, eventProcessor: SessionEventProcessor) /** - * Request to cancel an ongoing task in the running [session]. + * Request to cancel a task. * - * The executor should attempt to stop the task identified by the task id in the [context] or throw an exception if - * cancellation is not supported or not possible, e.g. [A2ATaskNotCancelableException]. + * Must throw an exception if the cancellation fails or is impossible. The executor should attempt to stop the task + * identified by the task id in the [context] or throw an exception if cancellation is not supported or not possible, + * e.g., [A2ATaskNotCancelableException]. * - * If this method finishes normally, it will be considered successful cancellation and the [session] will be explicitly closed. - * This means the agent execution job (the code running in the [execute]) will be canceled, and - * [SessionEventProcessor] associated with this session will be closed. + * Can also publish [TaskEvent]s to the [eventProcessor] to update the task state. Must ensure that the final + * task state will be [TaskState.Canceled], otherwise the task will not be considered canceled, and the requester will + * get [A2ATaskNotCancelableException]. * - * Implementations can call [Session.close] explicitly themselves if they want to stop the agent execution first and - * then perform some cleanup afterwards, e.g., closing connection to external resources. + * **IMPORTANT**: This should execute quickly as it runs synchronously with the request. * - * Must throw an exception if the cancellation fails or is impossible. + * Default implementation throws [A2ATaskNotCancelableException], meaning cancellation is not supported by default. * - * Default implementation does nothing, meaning cancellations will always be successful and the [session] will be closed - * immediately. - * - * Example simple implementation: - * ```kotlin - * // Explicitly close the session to stop the agent execution job and event processor - * session.close() - * // Log the fact that the task was canceled - * log.info("Task '${context.taskId}' canceled") - * ``` - * - * Example more advanced implementation: + * Example implementation: * ```kotlin - * // Cancel only the agent execution job to terminate the agent run, but keep event processor running. - * session.agentJob.cancel() + * // Cancel agent execution job, if the agent is currently running, to terminate it. + * agentJob?.cancel() * // Send task cancellation event with custom message to event processor - * session.eventProcessor.sendTaskEvent( + * eventProcessor.sendTaskEvent( * TaskStatusUpdateEvent( * taskId = context.taskId, * contextId = context.contextId, @@ -129,13 +118,20 @@ public interface AgentExecutor { * final = true, * ) * ) - * // Close the session completely - * session.close() * ``` * + * @param context The context containing the necessary information and accessors for executing the agent. + * @param eventProcessor The event processor to publish events to. + * @param agentJob Optional [Job] executing the agent logic, if the agent is currently running. * @throws Exception if something goes wrong during execution or the cancellation is impossible. Should prefer more - * specific exceptions when available, e.g., [A2ATaskNotCancelableException], [A2AUnsupportedOperationException], etc. + * specific exceptions if possible, e.g., [A2ATaskNotCancelableException], [A2AUnsupportedOperationException], etc. * See full list of available A2A exceptions in [ai.koog.a2a.exceptions]. */ - public suspend fun cancel(context: RequestContext, session: Session) {} + public suspend fun cancel( + context: RequestContext, + eventProcessor: SessionEventProcessor, + agentJob: Job?, + ) { + throw A2ATaskNotCancelableException("Cancellation is not supported") + } } diff --git a/a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/session/Session.kt b/a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/session/Session.kt index fedb1fa0dc..f8ed81dcfc 100644 --- a/a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/session/Session.kt +++ b/a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/session/Session.kt @@ -4,6 +4,7 @@ import ai.koog.a2a.model.Event import kotlinx.coroutines.CoroutineScope import kotlinx.coroutines.CoroutineStart import kotlinx.coroutines.Job +import kotlinx.coroutines.cancelAndJoin import kotlinx.coroutines.flow.Flow import kotlinx.coroutines.flow.collect import kotlinx.coroutines.launch @@ -11,11 +12,11 @@ import kotlinx.coroutines.launch /** * Represents an active agent execution session with lifecycle management. * - * @property eventProcessor Handles session events and provides event streaming - * @property agentJob The coroutine job executing the agent logic - * @property contextId Unique context ID associated with this session, delegates to [SessionEventProcessor.contextId] - * @property taskId Unique task ID associated with this session, delegates to [SessionEventProcessor.contextId] - * @property events A stream of events generated during this session, delegates to [SessionEventProcessor.events] + * @property eventProcessor The session event processor + * @property agentJob The job executing the agent logic + * @property contextId Unique context ID associated with this session + * @property taskId Unique task ID associated with this session + * @property events A stream of events generated during this session */ public class Session( public val eventProcessor: SessionEventProcessor, @@ -26,7 +27,7 @@ public class Session( public val events: Flow get() = eventProcessor.events /** - * Starts the agent execution job. + * Starts the [agentJob], if it hasn't already been started. */ public fun start() { agentJob.start() @@ -41,27 +42,34 @@ public class Session( } /** - * Cancels the agent job and closes the event processor + * Cancels the agent job, waiting for it to complete, and then closes event processor. */ - public suspend fun close() { - agentJob.cancel() + public suspend fun cancel() { + agentJob.cancelAndJoin() eventProcessor.close() } } /** - * Creates a new [Session] with lazy-started agent execution. + * Factory function that creates a new [Session] with lazy-started [agentAction]. * * @param coroutineScope The scope for launching the agent coroutine * @param eventProcessor The session event processor * @param agentAction The agent logic to execute * @return A new session instance */ -public fun Session( +@Suppress("ktlint:standard:function-naming", "FunctionName") +public fun AgentSession( coroutineScope: CoroutineScope, eventProcessor: SessionEventProcessor, agentAction: suspend CoroutineScope.() -> Unit ): Session { - val agentJob = coroutineScope.launch(start = CoroutineStart.LAZY, block = agentAction) - return Session(eventProcessor, agentJob) + val agentJob = coroutineScope.launch(start = CoroutineStart.LAZY) { + agentAction() + } + + return Session( + eventProcessor = eventProcessor, + agentJob = agentJob + ) } diff --git a/a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/session/SessionManager.kt b/a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/session/SessionManager.kt index 6e8087391f..64750a4361 100644 --- a/a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/session/SessionManager.kt +++ b/a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/session/SessionManager.kt @@ -9,6 +9,11 @@ import ai.koog.a2a.utils.RWLock import kotlinx.coroutines.CoroutineScope import kotlinx.coroutines.flow.firstOrNull import kotlinx.coroutines.launch +import kotlinx.coroutines.sync.Mutex +import kotlinx.coroutines.sync.withLock +import kotlin.contracts.ExperimentalContracts +import kotlin.contracts.InvocationKind +import kotlin.contracts.contract /** * Manages a set of active instances of [Session], sends push notifications if configured after each session completes. @@ -17,6 +22,8 @@ import kotlinx.coroutines.launch * Additionally, if push notifications are configured, after each task session completes, push notifications are sent with * the current task state. * + * Provides the ability to lock a task id. + * * @param coroutineScope The scope in which the monitoring jobs will be launched. * @param taskStorage The storage for tasks. * @param pushConfigStorage The storage for push notification configurations. @@ -29,22 +36,26 @@ public class SessionManager( private val pushConfigStorage: PushNotificationConfigStorage? = null, private val pushSender: PushNotificationSender? = null, ) { + /** * Map of task id to session. All sessions have task id associated with them, even if the task won't be created. */ private val sessions = mutableMapOf() - private val rwLock = RWLock() + private val sessionsRwLock = RWLock() + + private val taskMutexes = mutableMapOf() + private val taskMutexesLock = Mutex() /** * Adds a session to a set of active sessions. - * If the first event in the session events stream is of type [TaskEvent], the session is added to the task sessions map too. - * * Handles cleanup by closing and removing the session when it is completed (whether successfully or not). + * Sends push notifications if configured after each session completes. * * @param session The session to add. + * @throws IllegalArgumentException if a session for the same task id already exists. */ public suspend fun addSession(session: Session) { - rwLock.withWriteLock { + sessionsRwLock.withWriteLock { check(session.taskId !in sessions) { "SessionEventProcessor for taskId '${session.taskId}' already exists." } @@ -56,9 +67,20 @@ public class SessionManager( coroutineScope.launch { val firstEvent = session.events.firstOrNull() - // Wait for agent job to complete + // Wait for the agent job to finish session.agentJob.join() + /* + Check and wait if the task lock is free (e.g., there's a cancellation request for this task running now and still publishing some events). + Then remove it from the sessions map. + */ + withTaskLock(session.taskId) { + sessionsRwLock.withWriteLock { + session.cancel() + sessions -= session.taskId + } + } + // Send push notifications with the current state of the task, after the session completion, if configured. if (firstEvent is TaskEvent && pushSender != null && pushConfigStorage != null) { val task = taskStorage.get(session.taskId, historyLength = 0, includeArtifacts = false) @@ -69,26 +91,90 @@ public class SessionManager( } } } - - // Close the session completely and remove it from the sessions map. - rwLock.withWriteLock { - sessions -= session.taskId - session.close() - } } } /** - * Returns the session for the given task id, if any. + * Returns the session for the given task id, if it exists. */ - public suspend fun sessionForTask(taskId: String): Session? = rwLock.withReadLock { + public suspend fun getSession(taskId: String): Session? = sessionsRwLock.withReadLock { sessions[taskId] } /** * Returns the number of active sessions. */ - public suspend fun activeSessions(): Int = rwLock.withReadLock { + public suspend fun activeSessions(): Int = sessionsRwLock.withReadLock { sessions.size } + + /** + * Acquires a lock for the specified task ID. + * Useful for maintaining concurrency safety in task-related operations. + * + * @param taskId The unique identifier of the task to be locked. + */ + public suspend fun taskLock(taskId: String) { + val mutex = taskMutexesLock.withLock { + taskMutexes.getOrPut(taskId) { Mutex() } + } + mutex.lock() + } + + /** + * Releases the lock for the specified task ID. + * Useful for maintaining concurrency safety in task-related operations. + * + * @param taskId The unique identifier of the task to be unlocked. + * @throws IllegalStateException if the lock for the task cannot be released. + */ + public suspend fun taskUnlock(taskId: String) { + val mutex = taskMutexesLock.withLock { + taskMutexes[taskId] + } ?: throw IllegalStateException("Task '$taskId' was never locked") + + if (!mutex.isLocked) { + throw IllegalStateException("Task '$taskId' is not currently locked") + } + + mutex.unlock() + + // Clean up unused mutexes + taskMutexesLock.withLock { + if (!mutex.isLocked && taskMutexes[taskId] === mutex) { + taskMutexes.remove(taskId) + } + } + } + + /** + * Returns true if the task ID is locked, false otherwise. + */ + public suspend fun isTaskLocked(taskId: String): Boolean { + return taskMutexesLock.withLock { + taskMutexes[taskId]?.isLocked == true + } + } +} + +/** + * Executes the given block of code while holding a lock for the specified task ID. + * Useful for maintaining concurrency safety in task-related operations. + * + * @param taskId The ID of the task to be locked. + * @param action The block of code to be executed. + * @return The result of [action] + */ +@OptIn(ExperimentalContracts::class) +public suspend inline fun SessionManager.withTaskLock(taskId: String, action: suspend () -> T): T { + contract { + callsInPlace(action, InvocationKind.EXACTLY_ONCE) + } + + taskLock(taskId) + return try { + action() + } finally { + taskUnlock(taskId) + } } diff --git a/a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/tasks/InMemoryTaskStorage.kt b/a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/tasks/InMemoryTaskStorage.kt index 52a4ced6cb..cfa9a8d8dc 100644 --- a/a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/tasks/InMemoryTaskStorage.kt +++ b/a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/tasks/InMemoryTaskStorage.kt @@ -95,9 +95,12 @@ public class InMemoryTaskStorage : TaskStorage { val updatedTask = existingTask.copy( status = event.status, + history = existingTask.status.message + ?.let { existingTask.history.orEmpty() + it } + ?: existingTask.history, metadata = existingTask.metadata ?.let { JsonObject(it + event.metadata.orEmpty()) } - ?: event.metadata + ?: event.metadata, ) tasks[event.taskId] = updatedTask diff --git a/a2a/a2a-server/src/commonTest/kotlin/ai/koog/a2a/server/session/SessionManagerTest.kt b/a2a/a2a-server/src/commonTest/kotlin/ai/koog/a2a/server/session/SessionManagerTest.kt index 81b2c519ed..4f04ca3f40 100644 --- a/a2a/a2a-server/src/commonTest/kotlin/ai/koog/a2a/server/session/SessionManagerTest.kt +++ b/a2a/a2a-server/src/commonTest/kotlin/ai/koog/a2a/server/session/SessionManagerTest.kt @@ -13,14 +13,25 @@ import ai.koog.a2a.server.notifications.PushNotificationSender import ai.koog.a2a.server.tasks.InMemoryTaskStorage import kotlinx.coroutines.CoroutineScope import kotlinx.coroutines.delay +import kotlinx.coroutines.joinAll +import kotlinx.coroutines.launch import kotlinx.coroutines.test.runTest +import kotlinx.coroutines.yield import kotlinx.datetime.Instant import kotlin.test.BeforeTest import kotlin.test.Test import kotlin.test.assertEquals +import kotlin.test.assertFailsWith +import kotlin.test.assertFalse import kotlin.test.assertNull +import kotlin.test.assertTrue +import kotlin.time.Duration.Companion.seconds class SessionManagerTest { + private companion object Companion { + private val TEST_TIMEOUT = 5.seconds + } + private lateinit var taskStorage: InMemoryTaskStorage private lateinit var pushConfigStorage: InMemoryPushNotificationConfigStorage private lateinit var pushSender: MockPushNotificationSender @@ -63,7 +74,7 @@ class SessionManagerTest { contextId = contextId, status = TaskStatus( state = state, - timestamp = Instant.parse("2023-01-01T10:00:00Z") + timestamp = Instant.Companion.parse("2023-01-01T10:00:00Z") ) ) @@ -88,24 +99,24 @@ class SessionManagerTest { ) @Test - fun testSessionManagerCreation() = runTest { + fun testSessionManagerCreation() = runTest(timeout = TEST_TIMEOUT) { val sessionManager = SessionManager( coroutineScope = this, taskStorage = taskStorage ) assertEquals(0, sessionManager.activeSessions()) - assertNull(sessionManager.sessionForTask("any-task-id")) + assertNull(sessionManager.getSession("any-task-id")) } @Test - fun testAddMessageSession() = runTest { + fun testAddMessageSession() = runTest(timeout = TEST_TIMEOUT) { val sessionManager = createManager(this) val eventProcessor = createProcessor(contextId, taskId) val message = createMessage("msg-1", contextId, "Hello") - val session = Session( + val session = AgentSession( coroutineScope = this, eventProcessor = eventProcessor ) { @@ -116,16 +127,19 @@ class SessionManagerTest { sessionManager.addSession(session) session.join() + // Let the session manager process it + yield() + // Session should be automatically cleaned up after completion assertEquals(0, sessionManager.activeSessions()) } @Test - fun testAddTaskSession() = runTest { + fun testAddTaskSession() = runTest(timeout = TEST_TIMEOUT) { val sessionManager = createManager(this) val eventProcessor = createProcessor(contextId, taskId) - val session = Session( + val session = AgentSession( coroutineScope = this, eventProcessor = eventProcessor ) { @@ -147,23 +161,29 @@ class SessionManagerTest { sessionManager.addSession(session) session.start() - assertEquals(session, sessionManager.sessionForTask(taskId)) + // Let the session manager process it + yield() + + assertEquals(session, sessionManager.getSession(taskId)) session.join() + // Let the session manager process it + yield() + // Session should be automatically cleaned up after completion assertEquals(0, sessionManager.activeSessions()) } @Test - fun testMultipleSessions() = runTest { + fun testMultipleSessions() = runTest(timeout = TEST_TIMEOUT) { val sessionManager = createManager(this) // Create two task sessions val eventProcessor1 = createProcessor("context-1", "task-1") val eventProcessor2 = createProcessor("context-2", "task-2") - val session1 = Session( + val session1 = AgentSession( coroutineScope = this, eventProcessor = eventProcessor1 ) { @@ -182,7 +202,7 @@ class SessionManagerTest { eventProcessor1.sendTaskEvent(statusUpdate) } - val session2 = Session( + val session2 = AgentSession( coroutineScope = this, eventProcessor = eventProcessor2 ) { @@ -206,18 +226,24 @@ class SessionManagerTest { session1.start() session2.start() - assertEquals(session1, sessionManager.sessionForTask("task-1")) - assertEquals(session2, sessionManager.sessionForTask("task-2")) + // Let the session manager process it + yield() + + assertEquals(session1, sessionManager.getSession("task-1")) + assertEquals(session2, sessionManager.getSession("task-2")) session1.join() session2.join() + // Let the session manager process it + yield() + // All sessions should be automatically cleaned up assertEquals(0, sessionManager.activeSessions()) } @Test - fun testSessionWithPushNotifications() = runTest { + fun testSessionWithPushNotifications() = runTest(timeout = TEST_TIMEOUT) { val sessionManager = createManager(this) val eventProcessor = createProcessor(contextId, taskId) @@ -230,7 +256,7 @@ class SessionManagerTest { val task = createTask("task-1", contextId) - val session = Session( + val session = AgentSession( coroutineScope = this, eventProcessor = eventProcessor ) { @@ -248,10 +274,138 @@ class SessionManagerTest { sessionManager.addSession(session) session.join() + // Let the session manager process it + yield() + // Verify push notification was sent assertEquals(1, pushSender.sentNotifications.size) val (sentConfig, sentTask) = pushSender.sentNotifications[0] assertEquals(config, sentConfig) assertEquals(TaskState.Completed, sentTask.status.state) } + + @Test + fun testTaskLockMultipleTasks() = runTest { + val sessionManager = createManager(this) + + val taskId1 = "test-task-1" + val taskId2 = "test-task-2" + + // Lock both tasks + sessionManager.taskLock(taskId1) + sessionManager.taskLock(taskId2) + + assertTrue(sessionManager.isTaskLocked(taskId1)) + assertTrue(sessionManager.isTaskLocked(taskId2)) + + // Unlock first task + sessionManager.taskUnlock(taskId1) + assertFalse(sessionManager.isTaskLocked(taskId1)) + assertTrue(sessionManager.isTaskLocked(taskId2)) + + // Unlock second task + sessionManager.taskUnlock(taskId2) + assertFalse(sessionManager.isTaskLocked(taskId2)) + } + + @Test + fun testConcurrentTaskLocking() = runTest { + val sessionManager = createManager(this) + val taskId = "concurrent-task" + val results = mutableListOf() + + // First coroutine locks the task + val job1 = launch { + sessionManager.taskLock(taskId) + results.add("job1-locked") + delay(100) // Hold the lock for some time + results.add("job1-working") + sessionManager.taskUnlock(taskId) + results.add("job1-unlocked") + } + + // Second coroutine tries to lock the same task + val job2 = launch { + delay(50) // Start after job1 has locked + results.add("job2-attempting-lock") + sessionManager.taskLock(taskId) // Should wait for job1 to unlock + results.add("job2-locked") + sessionManager.taskUnlock(taskId) + results.add("job2-unlocked") + } + + joinAll(job1, job2) + + // Verify the order of execution + assertEquals( + listOf( + "job1-locked", + "job2-attempting-lock", + "job1-working", + "job1-unlocked", + "job2-locked", + "job2-unlocked" + ), + results + ) + } + + @Test + fun testUnlockNeverLockedTaskThrowsException() = runTest { + val sessionManager = createManager(this) + val taskId = "never-locked-task" + + val exception = assertFailsWith { + sessionManager.taskUnlock(taskId) + } + + assertEquals("Task '$taskId' was never locked", exception.message) + } + + @Test + fun testUnlockAlreadyUnlockedTaskThrowsException() = runTest { + val sessionManager = createManager(this) + val taskId = "already-unlocked-task" + + // Lock and unlock the task + sessionManager.taskLock(taskId) + sessionManager.taskUnlock(taskId) + + // Try to unlock again + val exception = assertFailsWith { + sessionManager.taskUnlock(taskId) + } + + assertEquals("Task '$taskId' was never locked", exception.message) + } + + @Test + fun testSameLockMultipleTimes() = runTest { + val sessionManager = createManager(this) + val taskId = "same-lock-task" + + // First lock + sessionManager.taskLock(taskId) + assertTrue(sessionManager.isTaskLocked(taskId)) + + // Trying to lock the same task again should suspend indefinitely + // We'll test this with a timeout + val job = launch { + sessionManager.taskLock(taskId) // This should suspend + } + + delay(100) // Give some time for the second lock attempt + assertTrue(job.isActive) // Job should still be waiting + + // Unlock the first lock + sessionManager.taskUnlock(taskId) + + // Now the second lock should proceed + job.join() + assertTrue(sessionManager.isTaskLocked(taskId)) + + // Unlock the second lock + sessionManager.taskUnlock(taskId) + assertFalse(sessionManager.isTaskLocked(taskId)) + } } diff --git a/a2a/a2a-server/src/commonTest/kotlin/ai/koog/a2a/server/tasks/InMemoryTaskStorageTest.kt b/a2a/a2a-server/src/commonTest/kotlin/ai/koog/a2a/server/tasks/InMemoryTaskStorageTest.kt index 93279e406e..258ca53eda 100644 --- a/a2a/a2a-server/src/commonTest/kotlin/ai/koog/a2a/server/tasks/InMemoryTaskStorageTest.kt +++ b/a2a/a2a-server/src/commonTest/kotlin/ai/koog/a2a/server/tasks/InMemoryTaskStorageTest.kt @@ -244,6 +244,84 @@ class InMemoryTaskStorageTest { } } + @Test + fun testTaskStatusHistoryPreservation() = runTest { + // Initial task with no message in status + val initialTask = createTask( + id = "task-1", + contextId = "context-1" + ) + storage.update(initialTask) + + // Verify initial task - should have the initial status message but no history yet + val initialTaskFromStorage = storage.get("task-1", historyLength = null) + assertNotNull(initialTaskFromStorage) + assertNull(initialTaskFromStorage.history) + + // Update status with a new message - history is still empty + val firstUpdateMessage = createUserMessage("update-msg-1", "context-1", "Making progress") + val firstUpdateStatus = TaskStatus( + state = TaskState.Working, + message = firstUpdateMessage, + timestamp = Instant.parse("2023-01-01T11:00:00Z") + ) + val firstUpdate = TaskStatusUpdateEvent( + taskId = "task-1", + contextId = "context-1", + status = firstUpdateStatus, + final = false + ) + + storage.update(firstUpdate) + + val afterFirstUpdate = storage.get("task-1", historyLength = null) // Get full history + assertNotNull(afterFirstUpdate) + assertEquals(firstUpdateStatus, afterFirstUpdate.status) + assertNull(afterFirstUpdate.history) + + // Second status update - this should add the previous status message to history + val secondUpdateMessage = createUserMessage("update-msg-2", "context-1", "Almost done") + val secondUpdateStatus = TaskStatus( + state = TaskState.Working, + message = secondUpdateMessage, + timestamp = Instant.parse("2023-01-01T12:00:00Z") + ) + val secondUpdate = TaskStatusUpdateEvent( + taskId = "task-1", + contextId = "context-1", + status = secondUpdateStatus, + final = false + ) + storage.update(secondUpdate) + + // Verify second update + val afterSecondUpdate = storage.get("task-1", historyLength = null) + assertNotNull(afterSecondUpdate) + assertEquals(secondUpdateStatus, afterSecondUpdate.status) + assertEquals(listOf(firstUpdateMessage), afterSecondUpdate.history) + + // Final status update + val completionMessage = createUserMessage("completion-msg", "context-1", "Task completed successfully") + val completionStatus = TaskStatus( + state = TaskState.Completed, + message = completionMessage, + timestamp = Instant.parse("2023-01-01T13:00:00Z") + ) + val completionUpdate = TaskStatusUpdateEvent( + taskId = "task-1", + contextId = "context-1", + status = completionStatus, + final = true + ) + storage.update(completionUpdate) + + // Verify final update + val finalTask = storage.get("task-1", historyLength = null) + assertNotNull(finalTask) + assertEquals(completionStatus, finalTask.status) + assertEquals(listOf(firstUpdateMessage, secondUpdateMessage), finalTask.history) + } + private fun createUserMessage( messageId: String, contextId: String, @@ -258,16 +336,14 @@ class InMemoryTaskStorageTest { private fun createTask( id: String, contextId: String, + status: TaskStatus = TaskStatus(state = TaskState.Submitted), history: List? = null, artifacts: List? = null, metadata: JsonObject? = null ) = Task( id = id, contextId = contextId, - status = TaskStatus( - state = TaskState.Submitted, - timestamp = Instant.parse("2023-01-01T10:00:00Z") - ), + status = status, history = history, artifacts = artifacts, metadata = metadata diff --git a/a2a/a2a-server/src/jvmTest/kotlin/ai/koog/a2a/server/A2AServerJsonRpcIntegrationTest.kt b/a2a/a2a-server/src/jvmTest/kotlin/ai/koog/a2a/server/A2AServerJsonRpcIntegrationTest.kt new file mode 100644 index 0000000000..aaad141b91 --- /dev/null +++ b/a2a/a2a-server/src/jvmTest/kotlin/ai/koog/a2a/server/A2AServerJsonRpcIntegrationTest.kt @@ -0,0 +1,187 @@ +package ai.koog.a2a.server + +import ai.koog.a2a.client.A2AClient +import ai.koog.a2a.client.UrlAgentCardResolver +import ai.koog.a2a.consts.A2AConsts +import ai.koog.a2a.model.AgentCapabilities +import ai.koog.a2a.model.AgentCard +import ai.koog.a2a.model.AgentSkill +import ai.koog.a2a.model.TransportProtocol +import ai.koog.a2a.server.notifications.InMemoryPushNotificationConfigStorage +import ai.koog.a2a.test.BaseA2AProtocolTest +import ai.koog.a2a.transport.client.jsonrpc.http.HttpJSONRPCClientTransport +import ai.koog.a2a.transport.server.jsonrpc.http.HttpJSONRPCServerTransport +import io.ktor.client.HttpClient +import io.ktor.client.engine.cio.CIO +import io.ktor.client.plugins.logging.LogLevel +import io.ktor.client.plugins.logging.Logging +import io.ktor.server.netty.Netty +import kotlinx.coroutines.runBlocking +import org.junit.jupiter.api.AfterAll +import org.junit.jupiter.api.BeforeAll +import org.junit.jupiter.api.TestInstance +import kotlin.test.BeforeTest + +/** + * Integration test class for testing the JSON-RPC HTTP communication in the A2A server context. + * This class ensures the proper functioning and correctness of the A2A protocol over HTTP + * using the JSON-RPC standard. + */ +@TestInstance(TestInstance.Lifecycle.PER_CLASS) +class A2AServerJsonRpcIntegrationTest : BaseA2AProtocolTest() { + + companion object { + private const val TEST_PORT = 9999 + private const val TEST_PATH = "/a2a" + private const val SERVER_URL = "http://localhost:$TEST_PORT$TEST_PATH" + } + + private lateinit var serverTransport: HttpJSONRPCServerTransport + private lateinit var clientTransport: HttpJSONRPCClientTransport + private lateinit var httpClient: HttpClient + + override lateinit var client: A2AClient + + @BeforeAll + fun setup(): Unit = runBlocking { + // Create agent cards + val agentCard = createAgentCard() + val agentCardExtended = createExtendedAgentCard() + + // Create test agent executor + val testAgentExecutor = TestAgentExecutor() + + // Create A2A server + val a2aServer = A2AServer( + agentExecutor = testAgentExecutor, + agentCard = agentCard, + agentCardExtended = agentCardExtended, + pushConfigStorage = InMemoryPushNotificationConfigStorage() + ) + + // Create server transport + serverTransport = HttpJSONRPCServerTransport(a2aServer) + + // Start server + serverTransport.start( + engineFactory = Netty, + port = TEST_PORT, + path = TEST_PATH, + wait = false, + agentCard = agentCard, + agentCardPath = A2AConsts.AGENT_CARD_WELL_KNOWN_PATH, + ) + + // Create client transport + httpClient = HttpClient(CIO) { + install(Logging) { + level = LogLevel.ALL + } + } + + clientTransport = HttpJSONRPCClientTransport(SERVER_URL, httpClient) + + client = A2AClient( + transport = clientTransport, + agentCardResolver = UrlAgentCardResolver( + baseUrl = SERVER_URL, + path = A2AConsts.AGENT_CARD_WELL_KNOWN_PATH + ) + ) + } + + @BeforeTest + fun initClient(): Unit = runBlocking { + client.connect() + } + + @AfterAll + fun tearDown(): Unit = runBlocking { + clientTransport.close() + serverTransport.stop() + } + + private fun createAgentCard(): AgentCard = AgentCard( + protocolVersion = "0.3.0", + name = "Hello World Agent", + description = "Just a hello world agent", + url = "http://localhost:9999/", + preferredTransport = TransportProtocol.JSONRPC, + additionalInterfaces = null, + iconUrl = null, + provider = null, + version = "1.0.0", + documentationUrl = null, + capabilities = AgentCapabilities( + streaming = true, + pushNotifications = true, + stateTransitionHistory = null, + extensions = null + ), + securitySchemes = null, + security = null, + defaultInputModes = listOf("text"), + defaultOutputModes = listOf("text"), + skills = listOf( + AgentSkill( + id = "hello_world", + name = "Returns hello world", + description = "just returns hello world", + tags = listOf("hello world"), + examples = listOf("hi", "hello world"), + inputModes = null, + outputModes = null, + security = null + ) + ), + supportsAuthenticatedExtendedCard = true, + signatures = null + ) + + private fun createExtendedAgentCard(): AgentCard = AgentCard( + protocolVersion = "0.3.0", + name = "Hello World Agent - Extended Edition", + description = "The full-featured hello world agent for authenticated users.", + url = "http://localhost:9999/", + preferredTransport = TransportProtocol.JSONRPC, + additionalInterfaces = null, + iconUrl = null, + provider = null, + version = "1.0.1", + documentationUrl = null, + capabilities = AgentCapabilities( + streaming = true, + pushNotifications = true, + stateTransitionHistory = null, + extensions = null + ), + securitySchemes = null, + security = null, + defaultInputModes = listOf("text"), + defaultOutputModes = listOf("text"), + skills = listOf( + AgentSkill( + id = "hello_world", + name = "Returns hello world", + description = "just returns hello world", + tags = listOf("hello world"), + examples = listOf("hi", "hello world"), + inputModes = null, + outputModes = null, + security = null + ), + AgentSkill( + id = "super_hello_world", + name = "Returns a SUPER Hello World", + description = "A more enthusiastic greeting, only for authenticated users.", + tags = listOf("hello world", "super", "extended"), + examples = listOf("super hi", "give me a super hello"), + inputModes = null, + outputModes = null, + security = null + ) + ), + supportsAuthenticatedExtendedCard = true, + signatures = null + ) +} diff --git a/a2a/a2a-server/src/jvmTest/kotlin/ai/koog/a2a/server/TestAgentExecutor.kt b/a2a/a2a-server/src/jvmTest/kotlin/ai/koog/a2a/server/TestAgentExecutor.kt index 242b443783..97843c2748 100644 --- a/a2a/a2a-server/src/jvmTest/kotlin/ai/koog/a2a/server/TestAgentExecutor.kt +++ b/a2a/a2a-server/src/jvmTest/kotlin/ai/koog/a2a/server/TestAgentExecutor.kt @@ -11,8 +11,8 @@ import ai.koog.a2a.model.TaskStatusUpdateEvent import ai.koog.a2a.model.TextPart import ai.koog.a2a.server.agent.AgentExecutor import ai.koog.a2a.server.session.RequestContext -import ai.koog.a2a.server.session.Session import ai.koog.a2a.server.session.SessionEventProcessor +import kotlinx.coroutines.Job import kotlinx.coroutines.delay import kotlinx.datetime.Clock @@ -38,15 +38,10 @@ private suspend fun doTask( id = context.taskId, contextId = context.contextId, status = TaskStatus( - state = TaskState.Working, - message = Message( - role = Role.Agent, - parts = listOf(TextPart("Task created")), - contextId = context.contextId, - taskId = context.taskId - ), + state = TaskState.Submitted, timestamp = Clock.System.now() - ) + ), + history = listOf(context.params.message) ) // Send initial task event @@ -99,15 +94,10 @@ private suspend fun doCancelableTask( id = context.taskId, contextId = context.contextId, status = TaskStatus( - state = TaskState.Working, - message = Message( - role = Role.Agent, - parts = listOf(TextPart("Cancelable task created")), - contextId = context.contextId, - taskId = context.taskId - ), + state = TaskState.Submitted, timestamp = Clock.System.now() - ) + ), + history = listOf(context.params.message) ) eventProcessor.sendTaskEvent(task) @@ -121,15 +111,10 @@ private suspend fun doLongRunningTask( id = context.taskId, contextId = context.contextId, status = TaskStatus( - state = TaskState.Working, - message = Message( - role = Role.Agent, - parts = listOf(TextPart("Long running task started")), - contextId = context.contextId, - taskId = context.taskId - ), + state = TaskState.Submitted, timestamp = Clock.System.now() - ) + ), + history = listOf(context.params.message) ) eventProcessor.sendTaskEvent(task) @@ -166,20 +151,20 @@ class TestAgentExecutor : AgentExecutor { .lowercase() // Test scenarios to test various aspects of A2A - when { - "hello world" in userInput -> { + when (userInput) { + "hello world" -> { sayHello(context, eventProcessor) } - "do task" in userInput -> { + "do task" -> { doTask(context, eventProcessor) } - "do cancelable task" in userInput -> { + "do cancelable task" -> { doCancelableTask(context, eventProcessor) } - "do long-running task" in userInput -> { + "do long-running task" -> { doLongRunningTask(context, eventProcessor) } @@ -195,10 +180,14 @@ class TestAgentExecutor : AgentExecutor { } } - override suspend fun cancel(context: RequestContext, session: Session) { - session.agentJob.cancel() + override suspend fun cancel( + context: RequestContext, + eventProcessor: SessionEventProcessor, + agentJob: Job? + ) { + agentJob?.cancel() - session.eventProcessor.sendTaskEvent( + eventProcessor.sendTaskEvent( TaskStatusUpdateEvent( contextId = context.contextId, taskId = context.taskId, diff --git a/a2a/a2a-test/src/commonMain/kotlin/ai/koog/a2a/test/BaseA2AProtocolTest.kt b/a2a/a2a-test/src/commonMain/kotlin/ai/koog/a2a/test/BaseA2AProtocolTest.kt index b9c60da7c9..9e3883470d 100644 --- a/a2a/a2a-test/src/commonMain/kotlin/ai/koog/a2a/test/BaseA2AProtocolTest.kt +++ b/a2a/a2a-test/src/commonMain/kotlin/ai/koog/a2a/test/BaseA2AProtocolTest.kt @@ -48,7 +48,7 @@ abstract class BaseA2AProtocolTest { /** * The A2A client instance to test. Must be connected and ready to use. */ - protected abstract val client: A2AClient + protected abstract var client: A2AClient @Test fun `test get agent card`() = runTest { @@ -160,7 +160,6 @@ abstract class BaseA2AProtocolTest { role = Role.User, parts = listOf( TextPart("hello world"), - TextPart("How are you doing?"), ), contextId = "test-context" ), diff --git a/a2a/a2a-transport/a2a-transport-client-jsonrpc-http/src/commonMain/kotlin/ai/koog/a2a/transport/client/jsonrpc/http/HttpJSONRPCClientTransport.kt b/a2a/a2a-transport/a2a-transport-client-jsonrpc-http/src/commonMain/kotlin/ai/koog/a2a/transport/client/jsonrpc/http/HttpJSONRPCClientTransport.kt index 294de0d628..f009adb00a 100644 --- a/a2a/a2a-transport/a2a-transport-client-jsonrpc-http/src/commonMain/kotlin/ai/koog/a2a/transport/client/jsonrpc/http/HttpJSONRPCClientTransport.kt +++ b/a2a/a2a-transport/a2a-transport-client-jsonrpc-http/src/commonMain/kotlin/ai/koog/a2a/transport/client/jsonrpc/http/HttpJSONRPCClientTransport.kt @@ -15,6 +15,7 @@ import io.ktor.client.request.headers import io.ktor.client.request.post import io.ktor.client.request.setBody import io.ktor.http.ContentType +import io.ktor.http.HttpMethod import io.ktor.http.contentType import io.ktor.serialization.kotlinx.json.json import kotlinx.coroutines.flow.Flow @@ -31,7 +32,7 @@ import kotlinx.coroutines.flow.flow */ public class HttpJSONRPCClientTransport( url: String, - baseHttpClient: HttpClient + baseHttpClient: HttpClient = HttpClient() ) : JSONRPCClientTransport() { private val httpClient: HttpClient = baseHttpClient.config { defaultRequest { @@ -71,6 +72,8 @@ public class HttpJSONRPCClientTransport( ): Flow = flow { httpClient.sse( request = { + method = HttpMethod.Post + headers { ctx.additionalHeaders.forEach { (key, values) -> appendAll(key, values) diff --git a/a2a/a2a-transport/a2a-transport-core-jsonrpc/src/commonMain/kotlin/ai/koog/a2a/transport/jsonrpc/A2AMethod.kt b/a2a/a2a-transport/a2a-transport-core-jsonrpc/src/commonMain/kotlin/ai/koog/a2a/transport/jsonrpc/A2AMethod.kt index 26f9e3ef28..c7b02861b5 100644 --- a/a2a/a2a-transport/a2a-transport-core-jsonrpc/src/commonMain/kotlin/ai/koog/a2a/transport/jsonrpc/A2AMethod.kt +++ b/a2a/a2a-transport/a2a-transport-core-jsonrpc/src/commonMain/kotlin/ai/koog/a2a/transport/jsonrpc/A2AMethod.kt @@ -3,13 +3,16 @@ package ai.koog.a2a.transport.jsonrpc /** * A2A JSON-RPC methods. */ -public enum class A2AMethod(public val value: String) { +public enum class A2AMethod( + public val value: String, + public val streaming: Boolean = false +) { GetAuthenticatedExtendedAgentCard("agent/getAuthenticatedExtendedCard"), SendMessage("message/send"), - SendMessageStreaming("message/stream"), + SendMessageStreaming("message/stream", streaming = true), GetTask("tasks/get"), CancelTask("tasks/cancel"), - ResubscribeTask("tasks/resubscribe"), + ResubscribeTask("tasks/resubscribe", streaming = true), SetTaskPushNotificationConfig("tasks/pushNotificationConfig/set"), GetTaskPushNotificationConfig("tasks/pushNotificationConfig/get"), ListTaskPushNotificationConfig("tasks/pushNotificationConfig/list"), diff --git a/a2a/a2a-transport/a2a-transport-core-jsonrpc/src/commonMain/kotlin/ai/koog/a2a/transport/jsonrpc/JSONRPCServerTransport.kt b/a2a/a2a-transport/a2a-transport-core-jsonrpc/src/commonMain/kotlin/ai/koog/a2a/transport/jsonrpc/JSONRPCServerTransport.kt index 8eaac54109..c3b1ab6b83 100644 --- a/a2a/a2a-transport/a2a-transport-core-jsonrpc/src/commonMain/kotlin/ai/koog/a2a/transport/jsonrpc/JSONRPCServerTransport.kt +++ b/a2a/a2a-transport/a2a-transport-core-jsonrpc/src/commonMain/kotlin/ai/koog/a2a/transport/jsonrpc/JSONRPCServerTransport.kt @@ -30,6 +30,16 @@ import kotlinx.serialization.json.encodeToJsonElement * Handles receiving JSON-RPC requests, processing them, and sending responses. */ public abstract class JSONRPCServerTransport : ServerTransport { + /** + * Parses [A2AMethod] from the given [JSONRPCRequest]. + * + * @throws A2AMethodNotFoundException if method is not found. + */ + protected fun parseA2AMethod(request: JSONRPCRequest): A2AMethod { + return A2AMethod.entries.find { it.value == request.method } + ?: throw A2AMethodNotFoundException("Method not found: ${request.method}") + } + /** * Handles a JSON-RPC request and returns the corresponding response * Handles exceptions, mapping all non [A2AException]s to [A2AInternalErrorException], and then converting them to [JSONRPCErrorResponse]. @@ -68,7 +78,7 @@ public abstract class JSONRPCServerTransport : ServerTransport { .toJSONRPCSuccessResponse() else -> - throw A2AMethodNotFoundException("Method not found: ${request.method}") + throw A2AMethodNotFoundException("Non-streaming method not found: ${request.method}") } }.getOrElse { it.toJSONRPCErrorResponse(request.id) } } @@ -90,7 +100,7 @@ public abstract class JSONRPCServerTransport : ServerTransport { requestHandler.onResubscribeTask(request.toRequest(), ctx) else -> - flow { throw A2AMethodNotFoundException("Method not found: ${request.method}") } + flow { throw A2AMethodNotFoundException("Streaming method not found: ${request.method}") } }.map { it.toJSONRPCSuccessResponse() as JSONRPCResponse } .catch { emit(it.toJSONRPCErrorResponse(request.id)) } } diff --git a/a2a/a2a-transport/a2a-transport-server-jsonrpc-http/src/commonMain/kotlin/ai/koog/a2a/transport/server/jsonrpc/http/HttpJSONRPCServerTransport.kt b/a2a/a2a-transport/a2a-transport-server-jsonrpc-http/src/commonMain/kotlin/ai/koog/a2a/transport/server/jsonrpc/http/HttpJSONRPCServerTransport.kt index 74fc961a43..9db317abf9 100644 --- a/a2a/a2a-transport/a2a-transport-server-jsonrpc-http/src/commonMain/kotlin/ai/koog/a2a/transport/server/jsonrpc/http/HttpJSONRPCServerTransport.kt +++ b/a2a/a2a-transport/a2a-transport-server-jsonrpc-http/src/commonMain/kotlin/ai/koog/a2a/transport/server/jsonrpc/http/HttpJSONRPCServerTransport.kt @@ -11,32 +11,34 @@ import ai.koog.a2a.transport.jsonrpc.JSONRPCServerTransport import ai.koog.a2a.transport.jsonrpc.model.JSONRPCJson import ai.koog.a2a.transport.jsonrpc.model.JSONRPCRequest import ai.koog.a2a.utils.runCatchingCancellable +import io.ktor.http.ContentType +import io.ktor.http.HttpHeaders import io.ktor.serialization.kotlinx.json.json import io.ktor.server.application.ApplicationCall import io.ktor.server.application.install -import io.ktor.server.application.pluginOrNull import io.ktor.server.engine.ApplicationEngine import io.ktor.server.engine.ApplicationEngineFactory import io.ktor.server.engine.EmbeddedServer import io.ktor.server.engine.embeddedServer import io.ktor.server.plugins.contentnegotiation.ContentNegotiation import io.ktor.server.request.receiveText +import io.ktor.server.response.header import io.ktor.server.response.respond import io.ktor.server.routing.Route -import io.ktor.server.routing.application +import io.ktor.server.routing.RoutingContext import io.ktor.server.routing.get import io.ktor.server.routing.post import io.ktor.server.routing.route import io.ktor.server.routing.routing import io.ktor.server.sse.SSE -import io.ktor.server.sse.send -import io.ktor.server.sse.sse +import io.ktor.server.sse.SSEServerContent +import io.ktor.server.sse.ServerSSESession +import io.ktor.sse.ServerSentEvent import io.ktor.util.toMap import kotlinx.coroutines.sync.Mutex import kotlinx.coroutines.sync.withLock import kotlinx.serialization.SerializationException import kotlinx.serialization.json.decodeFromJsonElement -import kotlinx.serialization.serializer /** * Implements A2A JSON-RPC server transport over HTTP using Ktor server @@ -81,6 +83,7 @@ import kotlinx.serialization.serializer * * @property requestHandler The handler responsible for processing A2A requests received by the transport. */ +@OptIn(InternalA2AApi::class) public class HttpJSONRPCServerTransport( override val requestHandler: RequestHandler, ) : JSONRPCServerTransport() { @@ -127,6 +130,10 @@ public class HttpJSONRPCServerTransport( install(SSE) routing { + install(ContentNegotiation) { + json(JSONRPCJson) + } + transportRoutes(this, path) if (agentCard != null) { @@ -180,47 +187,82 @@ public class HttpJSONRPCServerTransport( * @param route The base route to which the transport routes should be mounted. * @param path JSON-RPC endpoint path that will be mounted under the base [route]. */ - @OptIn(InternalA2AApi::class) public fun transportRoutes(route: Route, path: String): Route = route.route(path) { - if (application.pluginOrNull(SSE) == null) { - throw IllegalStateException("SSE plugin must be installed in the application to add these routes.") - } + plugin(SSE) install(ContentNegotiation) { json(JSONRPCJson) } - // Regular JSON-RPC requests + // Handle incoming JSON-RPC requests, both regular and streaming post { - val response = runCatchingCancellable { - onRequest( - request = call.receiveJSONRPCRequest(), - ctx = call.toServerCallContext() - ) - }.getOrElse { it.toJSONRPCErrorResponse() } + runCatchingCancellable { + val request: JSONRPCRequest = call.receiveJSONRPCRequest() + val ctx: ServerCallContext = call.toServerCallContext() - call.respond(response) - } + runCatchingCancellable { + val a2aMethod = parseA2AMethod(request) - // Streaming JSON-RPC requests - sse( - serialize = { typeInfo, it -> - val kType = typeInfo.kotlinType ?: throw IllegalArgumentException("Null KType for value: $it") - val serializer = JSONRPCJson.serializersModule.serializer(kType) - JSONRPCJson.encodeToString(serializer, it) + if (a2aMethod.streaming) { + handleRequestStreaming(request, ctx) + } else { + handleRequest(request, ctx) + } + }.getOrElse { + call.respond(it.toJSONRPCErrorResponse(request.id)) + } + }.getOrElse { + call.respond(it.toJSONRPCErrorResponse()) } - ) { + } + } + + /** + * Handling A2A requests to regular methods. + */ + private suspend fun RoutingContext.handleRequest(request: JSONRPCRequest, ctx: ServerCallContext) { + val response = runCatchingCancellable { + onRequest( + request = request, + ctx = ctx + ) + }.getOrElse { it.toJSONRPCErrorResponse() } + + call.respond(response) + } + + /** + * Handling A2A requests to streaming methods. + */ + private suspend fun RoutingContext.handleRequestStreaming(request: JSONRPCRequest, ctx: ServerCallContext) { + val handle: suspend ServerSSESession.() -> Unit = { runCatchingCancellable { onRequestStreaming( - request = call.receiveJSONRPCRequest(), - ctx = call.toServerCallContext() + request = request, + ctx = ctx ).collect { response -> - send(response) + send( + ServerSentEvent(JSONRPCJson.encodeToString(response)) + ) } }.getOrElse { - send(it.toJSONRPCErrorResponse()) + send( + ServerSentEvent( + JSONRPCJson.encodeToString(it.toJSONRPCErrorResponse()) + ) + ) } } + + // Reply with SSE (implementation copied from SSE plugin code) + call.response.apply { + header(HttpHeaders.ContentType, ContentType.Text.EventStream.toString()) + header(HttpHeaders.CacheControl, "no-store") + header(HttpHeaders.Connection, "keep-alive") + header("X-Accel-Buffering", "no") + } + + call.respond(SSEServerContent(call, handle)) } /** diff --git a/a2a/a2a-transport/a2a-transport-server-jsonrpc-http/src/jvmTest/kotlin/ai/koog/a2a/transport/server/jsonrpc/http/HttpJSONRPCServerTransportTest.kt b/a2a/a2a-transport/a2a-transport-server-jsonrpc-http/src/jvmTest/kotlin/ai/koog/a2a/transport/server/jsonrpc/http/HttpJSONRPCServerTransportTest.kt index 4b492bbc3a..fe3dd62cff 100644 --- a/a2a/a2a-transport/a2a-transport-server-jsonrpc-http/src/jvmTest/kotlin/ai/koog/a2a/transport/server/jsonrpc/http/HttpJSONRPCServerTransportTest.kt +++ b/a2a/a2a-transport/a2a-transport-server-jsonrpc-http/src/jvmTest/kotlin/ai/koog/a2a/transport/server/jsonrpc/http/HttpJSONRPCServerTransportTest.kt @@ -33,6 +33,7 @@ import io.ktor.client.request.post import io.ktor.client.request.setBody import io.ktor.client.statement.bodyAsText import io.ktor.http.ContentType +import io.ktor.http.HttpMethod import io.ktor.http.HttpStatusCode import io.ktor.http.contentType import io.ktor.server.sse.SSE @@ -299,6 +300,8 @@ class HttpJSONRPCServerTransportTest { client.sse( urlString = "/a2a", request = { + this.method = HttpMethod.Post + contentType(ContentType.Application.Json) setBody(json.encodeToString(jsonRpcRequest)) }, diff --git a/a2a/test-python-a2a-server/src/agent_executor.py b/a2a/test-python-a2a-server/src/agent_executor.py index 7ebb0f3f13..0a3ad4d5d1 100644 --- a/a2a/test-python-a2a-server/src/agent_executor.py +++ b/a2a/test-python-a2a-server/src/agent_executor.py @@ -13,12 +13,20 @@ new_agent_text_message, new_task ) +from datetime import datetime, timezone + + +def get_current_timestamp(): + """Get current timestamp in ISO 8601 format (UTC)""" + return datetime.now(timezone.utc).isoformat() async def say_hello( event_queue: EventQueue, - message: Message + context: RequestContext, ) -> None: + message = context.message + await event_queue.enqueue_event( new_agent_text_message( text="Hello World", @@ -30,9 +38,20 @@ async def say_hello( async def do_task( event_queue: EventQueue, - message: Message + context: RequestContext, ) -> None: - task = new_task(message) + message = context.message + + # noinspection PyTypeChecker + task = Task( + id=message.task_id, + context_id=message.context_id, + status=TaskStatus( + state=TaskState.submitted, + timestamp=get_current_timestamp() + ), + history=[message] + ) # noinspection PyTypeChecker events = [ @@ -73,24 +92,38 @@ async def do_task( async def do_cancelable_task( event_queue: EventQueue, - message: Message, + context: RequestContext, ): - await event_queue.enqueue_event( - new_task(message), + message = context.message + + # noinspection PyTypeChecker + task = Task( + id=message.task_id, + context_id=message.context_id, + status=TaskStatus( + state=TaskState.submitted, + timestamp=get_current_timestamp() + ), + history=[message] ) + await event_queue.enqueue_event(task) async def do_long_running_task( event_queue: EventQueue, - message: Message + context: RequestContext, ): + message = context.message + + # noinspection PyTypeChecker task = Task( id=message.task_id, context_id=message.context_id, status=TaskStatus( - state=TaskState.working, - message=message - ) + state=TaskState.submitted, + timestamp=get_current_timestamp() + ), + history=[message] ) await event_queue.enqueue_event(task) @@ -128,17 +161,17 @@ async def execute( user_input = context.get_user_input() # Test scenarios to test various aspects of A2A - if "hello world" in user_input: - await say_hello(event_queue, context.message) + if user_input == "hello world": + await say_hello(event_queue, context) - elif "do task" in user_input: - await do_task(event_queue, context.message) + elif user_input == "do task": + await do_task(event_queue, context) - elif "do cancelable task" in user_input: - await do_cancelable_task(event_queue, context.message) + elif user_input == "do cancelable task": + await do_cancelable_task(event_queue, context) - elif "do long-running task" in user_input: - await do_long_running_task(event_queue, context.message) + elif user_input == "do long-running task": + await do_long_running_task(event_queue, context) else: await event_queue.enqueue_event( From 0c95155bbe41c7847c62ac5ec1435082fefb8628 Mon Sep 17 00:00:00 2001 From: Andrey Bragin Date: Mon, 29 Sep 2025 05:24:16 +0200 Subject: [PATCH 33/52] [a2a] Implement full TCK compliance with all tests passing --- .github/workflows/a2a-tck-test.yml | 112 +++++ a2a/CLAUDE.md | 2 +- a2a/a2a-client/build.gradle.kts | 2 +- .../client/A2AClientJsonRpcIntegrationTest.kt | 3 + a2a/a2a-core/build.gradle.kts | 1 + .../kotlin/ai/koog/a2a/dsl/A2ADsl.kt | 7 - .../ai/koog/a2a/exceptions/Exceptions.kt | 74 ++-- .../kotlin/ai/koog/a2a/model/Artifact.kt | 7 +- .../kotlin/ai/koog/a2a/model/Message.kt | 6 +- .../kotlin/ai/koog/a2a/model/Serialization.kt | 23 +- .../kotlin/ai/koog/a2a/model/Task.kt | 6 +- .../ai/koog/a2a/transport/Serialization.kt | 7 +- .../kotlin/ai/koog/a2a/utils/KeyedMutex.kt | 146 ++++++ .../ai/koog/a2a/utils/KeyedMutexTest.kt | 416 ++++++++++++++++++ a2a/a2a-server/build.gradle.kts | 3 +- .../kotlin/ai/koog/a2a/server/A2AServer.kt | 226 +++++----- .../ai/koog/a2a/server/agent/AgentExecutor.kt | 8 +- .../koog/a2a/server/exceptions/Exceptions.kt | 11 +- .../ai/koog/a2a/server/session/Session.kt | 47 +- .../server/session/SessionEventProcessor.kt | 231 +++------- .../koog/a2a/server/session/SessionManager.kt | 113 +---- .../session/SessionEventProcessorTest.kt | 92 +--- .../a2a/server/session/SessionManagerTest.kt | 163 +------ .../server/A2AServerJsonRpcIntegrationTest.kt | 206 ++++++++- .../ai/koog/a2a/server/TestAgentExecutor.kt | 17 +- .../src/jvmTest/resources/logback.xml | 11 + a2a/a2a-test/build.gradle.kts | 2 +- .../ai/koog/a2a/test/BaseA2AProtocolTest.kt | 27 +- .../build.gradle.kts | 2 +- .../http/HttpJSONRPCClientTransport.kt | 14 +- .../http/HttpJSONRPCClientTransportTest.kt | 15 +- .../build.gradle.kts | 3 + .../jsonrpc/JSONRPCClientTransport.kt | 21 +- .../jsonrpc/JSONRPCServerTransport.kt | 89 +++- .../a2a/transport/jsonrpc/model/Messages.kt | 13 +- .../jsonrpc/model/JsonRpcSerializationTest.kt | 14 +- .../http/HttpJSONRPCServerTransport.kt | 28 +- .../http/HttpJSONRPCServerTransportTest.kt | 10 +- a2a/test-tck/.gitignore | 2 + a2a/test-tck/README.md | 33 ++ .../a2a-test-server-tck/build.gradle.kts | 22 + .../main/kotlin/ai/koog/a2a/test/tck/Main.kt | 114 +++++ .../ai/koog/a2a/test/tck/TckAgentExecutor.kt | 229 ++++++++++ .../src/main/resources/logback.xml | 11 + a2a/test-tck/run_sut.sh | 11 + a2a/test-tck/run_tck.sh | 17 + a2a/test-tck/setup_tck.sh | 26 ++ koog-agents/build.gradle.kts | 3 +- settings.gradle.kts | 1 + 49 files changed, 1871 insertions(+), 776 deletions(-) create mode 100644 .github/workflows/a2a-tck-test.yml delete mode 100644 a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/dsl/A2ADsl.kt create mode 100644 a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/utils/KeyedMutex.kt create mode 100644 a2a/a2a-core/src/commonTest/kotlin/ai/koog/a2a/utils/KeyedMutexTest.kt create mode 100644 a2a/a2a-server/src/jvmTest/resources/logback.xml create mode 100644 a2a/test-tck/.gitignore create mode 100644 a2a/test-tck/README.md create mode 100644 a2a/test-tck/a2a-test-server-tck/build.gradle.kts create mode 100644 a2a/test-tck/a2a-test-server-tck/src/main/kotlin/ai/koog/a2a/test/tck/Main.kt create mode 100644 a2a/test-tck/a2a-test-server-tck/src/main/kotlin/ai/koog/a2a/test/tck/TckAgentExecutor.kt create mode 100644 a2a/test-tck/a2a-test-server-tck/src/main/resources/logback.xml create mode 100755 a2a/test-tck/run_sut.sh create mode 100755 a2a/test-tck/run_tck.sh create mode 100755 a2a/test-tck/setup_tck.sh diff --git a/.github/workflows/a2a-tck-test.yml b/.github/workflows/a2a-tck-test.yml new file mode 100644 index 0000000000..6ebd15ad0c --- /dev/null +++ b/.github/workflows/a2a-tck-test.yml @@ -0,0 +1,112 @@ +name: A2A TCK Test + +on: + push: + paths: + - 'a2a/**' + pull_request: + paths: + - 'a2a/**' + workflow_dispatch: + +concurrency: + group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }} + cancel-in-progress: ${{ github.ref != 'refs/heads/main' && github.ref != 'refs/heads/develop' }} + +env: + JAVA_VERSION: 17 + JAVA_DISTRIBUTION: 'corretto' + +jobs: + a2a-tck-test: + runs-on: ubuntu-latest + timeout-minutes: 10 + + permissions: + contents: read + + steps: + - name: Configure Git + run: | + git config --global core.autocrlf input + + - uses: actions/checkout@v5 + + - name: Set up JDK ${{ env.JAVA_VERSION }} + uses: actions/setup-java@v5 + with: + java-version: ${{ env.JAVA_VERSION }} + distribution: ${{ env.JAVA_DISTRIBUTION }} + + - name: Setup Gradle + uses: gradle/actions/setup-gradle@v4 + + - name: Set up Python 3.12 + uses: actions/setup-python@v6 + with: + python-version: '3.12' + + - name: Install uv + uses: astral-sh/setup-uv@v6 + with: + version: "0.8.15" + enable-cache: true + + - name: Setup A2A TCK + working-directory: ./a2a/test-tck + run: ./setup_tck.sh + + - name: Build test server + run: ./gradlew :a2a:test-tck:a2a-test-server-tck:assemble + + - name: Start test server in background + working-directory: ./a2a/test-tck + run: | + ./run_sut.sh > server.log 2>&1 & + SERVER_PID=$! + echo "SERVER_PID=$SERVER_PID" >> $GITHUB_ENV + + # Wait for server to start (max 60 seconds) + timeout 60 bash -c ' + while ! grep -q "Responding at http://0.0.0.0:9999" server.log 2>/dev/null; do + sleep 1 + done + ' + + echo "Server started successfully" + + - name: Run TCK tests - Mandatory + working-directory: ./a2a/test-tck + run: ./run_tck.sh --sut-url http://localhost:9999/a2a --category mandatory --report + + - name: Run TCK tests - Capabilities + working-directory: ./a2a/test-tck + run: ./run_tck.sh --sut-url http://localhost:9999/a2a --category capabilities --report + + - name: Run TCK tests - Transport Equivalence + working-directory: ./a2a/test-tck + run: ./run_tck.sh --sut-url http://localhost:9999/a2a --category transport-equivalence --report + + - name: Run TCK tests - Quality + working-directory: ./a2a/test-tck + run: ./run_tck.sh --sut-url http://localhost:9999/a2a --category quality --report + + - name: Run TCK tests - Features + working-directory: ./a2a/test-tck + run: ./run_tck.sh --sut-url http://localhost:9999/a2a --category features --report + + - name: Stop test server + if: always() + run: | + if [ ! -z "$SERVER_PID" ]; then + kill $SERVER_PID || true + wait $SERVER_PID 2>/dev/null || true + fi + + - name: Upload TCK reports + if: always() + uses: actions/upload-artifact@v4 + with: + name: a2a-tck-reports + path: a2a/test-tck/a2a-tck/reports/*.html + if-no-files-found: warn diff --git a/a2a/CLAUDE.md b/a2a/CLAUDE.md index e5b76f949c..0f4d2ca415 100644 --- a/a2a/CLAUDE.md +++ b/a2a/CLAUDE.md @@ -52,7 +52,7 @@ The A2A (Agent-to-Agent) module is a **meta-module** within the larger Koog proj - **kotest-assertions**: Rich assertion library for complex objects - **kotlinx-coroutines-test**: Testing coroutines with runTest - **testcontainers**: Docker-based integration testing -- **slf4j-simple**: Runtime logging for tests +- **logback-classic**: Runtime logging for tests ### Platform Support - **JVM**: Full server and client support diff --git a/a2a/a2a-client/build.gradle.kts b/a2a/a2a-client/build.gradle.kts index a1d4d9e660..4c0a68f204 100644 --- a/a2a/a2a-client/build.gradle.kts +++ b/a2a/a2a-client/build.gradle.kts @@ -38,7 +38,7 @@ kotlin { implementation(libs.ktor.client.cio) implementation(libs.ktor.client.logging) implementation(libs.testcontainers.junit) - runtimeOnly(libs.slf4j.simple) + runtimeOnly(libs.logback.classic) } } diff --git a/a2a/a2a-client/src/jvmTest/kotlin/ai/koog/a2a/client/A2AClientJsonRpcIntegrationTest.kt b/a2a/a2a-client/src/jvmTest/kotlin/ai/koog/a2a/client/A2AClientJsonRpcIntegrationTest.kt index 7799a7e47f..cb95bc4618 100644 --- a/a2a/a2a-client/src/jvmTest/kotlin/ai/koog/a2a/client/A2AClientJsonRpcIntegrationTest.kt +++ b/a2a/a2a-client/src/jvmTest/kotlin/ai/koog/a2a/client/A2AClientJsonRpcIntegrationTest.kt @@ -13,6 +13,7 @@ import org.testcontainers.containers.GenericContainer import org.testcontainers.containers.wait.strategy.Wait import org.testcontainers.junit.jupiter.Container import org.testcontainers.junit.jupiter.Testcontainers +import kotlin.time.Duration.Companion.seconds /** * Integration test class for testing the JSON-RPC HTTP communication in the A2A client context. @@ -30,6 +31,8 @@ class A2AClientJsonRpcIntegrationTest : BaseA2AProtocolTest() { .waitingFor(Wait.forListeningPort()) } + override val testTimeout = 10.seconds + private val httpClient = HttpClient { install(Logging) { level = LogLevel.BODY diff --git a/a2a/a2a-core/build.gradle.kts b/a2a/a2a-core/build.gradle.kts index 110474284b..7e99e5353e 100644 --- a/a2a/a2a-core/build.gradle.kts +++ b/a2a/a2a-core/build.gradle.kts @@ -27,6 +27,7 @@ kotlin { commonTest { dependencies { implementation(kotlin("test")) + implementation(libs.kotlinx.coroutines.test) } } diff --git a/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/dsl/A2ADsl.kt b/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/dsl/A2ADsl.kt deleted file mode 100644 index 4570d8824c..0000000000 --- a/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/dsl/A2ADsl.kt +++ /dev/null @@ -1,7 +0,0 @@ -package ai.koog.a2a.dsl - -/** - * A2A DSL marker - */ -@DslMarker -public annotation class A2ADsl diff --git a/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/exceptions/Exceptions.kt b/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/exceptions/Exceptions.kt index 798ed3766a..7620ba56cf 100644 --- a/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/exceptions/Exceptions.kt +++ b/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/exceptions/Exceptions.kt @@ -1,5 +1,7 @@ package ai.koog.a2a.exceptions +import ai.koog.a2a.transport.RequestId + /** * Object containing all A2A error codes. */ @@ -24,7 +26,8 @@ public object A2AErrorCodes { */ public sealed class A2AException( public override val message: String, - public val errorCode: Int + public val errorCode: Int, + public val requestId: RequestId? = null, ) : Exception(message) /** @@ -32,35 +35,40 @@ public sealed class A2AException( */ public class A2AParseException( message: String = "Invalid JSON payload", -) : A2AException(message, errorCode = A2AErrorCodes.PARSE_ERROR) + requestId: RequestId? = null, +) : A2AException(message, A2AErrorCodes.PARSE_ERROR, requestId) /** * The JSON payload was valid JSON, but not a valid JSON-RPC Request object. */ public class A2AInvalidRequestException( message: String = "Invalid JSON-RPC Request", -) : A2AException(message, errorCode = A2AErrorCodes.INVALID_REQUEST) + requestId: RequestId? = null, +) : A2AException(message, A2AErrorCodes.INVALID_REQUEST, requestId) /** * The requested A2A RPC method does not exist or is not supported. */ public class A2AMethodNotFoundException( message: String = "Method not found", -) : A2AException(message, errorCode = A2AErrorCodes.METHOD_NOT_FOUND) + requestId: RequestId? = null, +) : A2AException(message, A2AErrorCodes.METHOD_NOT_FOUND, requestId) /** * The params provided for the method are invalid. */ public class A2AInvalidParamsException( message: String = "Invalid method parameters", -) : A2AException(message, errorCode = A2AErrorCodes.INVALID_PARAMS) + requestId: RequestId? = null, +) : A2AException(message, A2AErrorCodes.INVALID_PARAMS, requestId) /** * An unexpected error occurred on the server during processing. */ public class A2AInternalErrorException( message: String = "Internal server error", -) : A2AException(message, errorCode = A2AErrorCodes.INTERNAL_ERROR) + requestId: RequestId? = null, +) : A2AException(message, A2AErrorCodes.INTERNAL_ERROR, requestId) /** * Reserved for implementation-defined server exceptions. A2A-specific exceptions use this range. @@ -68,7 +76,8 @@ public class A2AInternalErrorException( public sealed class A2AServerException( message: String, errorCode: Int, -) : A2AException(message, errorCode) { + requestId: RequestId? = null, +) : A2AException(message, errorCode, requestId) { init { require(errorCode in -32099..-32000) { "Server error code must be in -32099..-32000" } } @@ -80,7 +89,8 @@ public sealed class A2AServerException( */ public class A2ATaskNotFoundException( message: String = "Task not found", -) : A2AServerException(message, errorCode = A2AErrorCodes.TASK_NOT_FOUND) + requestId: RequestId? = null, +) : A2AServerException(message, A2AErrorCodes.TASK_NOT_FOUND, requestId) /** * An attempt was made to cancel a task that is not in a cancelable state. @@ -88,7 +98,8 @@ public class A2ATaskNotFoundException( */ public class A2ATaskNotCancelableException( message: String = "Task cannot be canceled", -) : A2AServerException(message, errorCode = A2AErrorCodes.TASK_NOT_CANCELABLE) + requestId: RequestId? = null, +) : A2AServerException(message, A2AErrorCodes.TASK_NOT_CANCELABLE, requestId) /** * Client attempted to use push notification features but the server agent does not support them. @@ -96,7 +107,8 @@ public class A2ATaskNotCancelableException( */ public class A2APushNotificationNotSupportedException( message: String = "Push Notification is not supported", -) : A2AServerException(message, errorCode = A2AErrorCodes.PUSH_NOTIFICATION_NOT_SUPPORTED) + requestId: RequestId? = null, +) : A2AServerException(message, A2AErrorCodes.PUSH_NOTIFICATION_NOT_SUPPORTED, requestId) /** * The requested operation or a specific aspect of it is not supported by this server agent implementation. @@ -104,7 +116,8 @@ public class A2APushNotificationNotSupportedException( */ public class A2AUnsupportedOperationException( message: String = "This operation is not supported", -) : A2AServerException(message, errorCode = A2AErrorCodes.UNSUPPORTED_OPERATION) + requestId: RequestId? = null, +) : A2AServerException(message, A2AErrorCodes.UNSUPPORTED_OPERATION, requestId) /** * A Media Type provided in the request's message.parts or implied for an artifact is not supported @@ -112,21 +125,24 @@ public class A2AUnsupportedOperationException( */ public class A2AContentTypeNotSupportedException( message: String = "Incompatible content types", -) : A2AServerException(message, errorCode = A2AErrorCodes.CONTENT_TYPE_NOT_SUPPORTED) + requestId: RequestId? = null, +) : A2AServerException(message, A2AErrorCodes.CONTENT_TYPE_NOT_SUPPORTED, requestId) /** * Agent generated an invalid response for the requested method. */ public class A2AInvalidAgentResponseException( message: String = "Invalid agent response type", -) : A2AServerException(message, errorCode = A2AErrorCodes.INVALID_AGENT_RESPONSE) + requestId: RequestId? = null, +) : A2AServerException(message, A2AErrorCodes.INVALID_AGENT_RESPONSE, requestId) /** * The agent does not have an Authenticated Extended Card configured. */ public class A2AAuthenticatedExtendedCardNotConfiguredException( message: String = "Authenticated Extended Card not configured", -) : A2AServerException(message, errorCode = A2AErrorCodes.AUTHENTICATED_EXTENDED_CARD_NOT_CONFIGURED) + requestId: RequestId? = null, +) : A2AServerException(message, A2AErrorCodes.AUTHENTICATED_EXTENDED_CARD_NOT_CONFIGURED, requestId) /** * Server returned some unknown error code. @@ -134,7 +150,8 @@ public class A2AAuthenticatedExtendedCardNotConfiguredException( public class A2AUnknownException( message: String, errorCode: Int, -) : A2AException(message, errorCode) + requestId: RequestId? = null, +) : A2AException(message, errorCode, requestId) /** * Create appropriate [A2AException] based on the provided errorCode. @@ -142,20 +159,21 @@ public class A2AUnknownException( public fun createA2AException( message: String, errorCode: Int, + requestId: RequestId?, ): A2AException { return when (errorCode) { - A2AErrorCodes.PARSE_ERROR -> A2AParseException(message) - A2AErrorCodes.INVALID_REQUEST -> A2AInvalidRequestException(message) - A2AErrorCodes.METHOD_NOT_FOUND -> A2AMethodNotFoundException(message) - A2AErrorCodes.INVALID_PARAMS -> A2AInvalidParamsException(message) - A2AErrorCodes.INTERNAL_ERROR -> A2AInternalErrorException(message) - A2AErrorCodes.TASK_NOT_FOUND -> A2ATaskNotFoundException(message) - A2AErrorCodes.TASK_NOT_CANCELABLE -> A2ATaskNotCancelableException(message) - A2AErrorCodes.PUSH_NOTIFICATION_NOT_SUPPORTED -> A2APushNotificationNotSupportedException(message) - A2AErrorCodes.UNSUPPORTED_OPERATION -> A2AUnsupportedOperationException(message) - A2AErrorCodes.CONTENT_TYPE_NOT_SUPPORTED -> A2AContentTypeNotSupportedException(message) - A2AErrorCodes.INVALID_AGENT_RESPONSE -> A2AInvalidAgentResponseException(message) - A2AErrorCodes.AUTHENTICATED_EXTENDED_CARD_NOT_CONFIGURED -> A2AAuthenticatedExtendedCardNotConfiguredException(message) - else -> A2AUnknownException(message, errorCode) + A2AErrorCodes.PARSE_ERROR -> A2AParseException(message, requestId) + A2AErrorCodes.INVALID_REQUEST -> A2AInvalidRequestException(message, requestId) + A2AErrorCodes.METHOD_NOT_FOUND -> A2AMethodNotFoundException(message, requestId) + A2AErrorCodes.INVALID_PARAMS -> A2AInvalidParamsException(message, requestId) + A2AErrorCodes.INTERNAL_ERROR -> A2AInternalErrorException(message, requestId) + A2AErrorCodes.TASK_NOT_FOUND -> A2ATaskNotFoundException(message, requestId) + A2AErrorCodes.TASK_NOT_CANCELABLE -> A2ATaskNotCancelableException(message, requestId) + A2AErrorCodes.PUSH_NOTIFICATION_NOT_SUPPORTED -> A2APushNotificationNotSupportedException(message, requestId) + A2AErrorCodes.UNSUPPORTED_OPERATION -> A2AUnsupportedOperationException(message, requestId) + A2AErrorCodes.CONTENT_TYPE_NOT_SUPPORTED -> A2AContentTypeNotSupportedException(message, requestId) + A2AErrorCodes.INVALID_AGENT_RESPONSE -> A2AInvalidAgentResponseException(message, requestId) + A2AErrorCodes.AUTHENTICATED_EXTENDED_CARD_NOT_CONFIGURED -> A2AAuthenticatedExtendedCardNotConfiguredException(message, requestId) + else -> A2AUnknownException(message, errorCode, requestId) } } diff --git a/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/model/Artifact.kt b/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/model/Artifact.kt index d00c893212..74237903e9 100644 --- a/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/model/Artifact.kt +++ b/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/model/Artifact.kt @@ -1,10 +1,7 @@ package ai.koog.a2a.model -import kotlinx.serialization.EncodeDefault import kotlinx.serialization.Serializable import kotlinx.serialization.json.JsonObject -import kotlin.uuid.ExperimentalUuidApi -import kotlin.uuid.Uuid /** * Represents a file, data structure, or other resource generated by an agent during a task. @@ -16,11 +13,9 @@ import kotlin.uuid.Uuid * @property extensions Optional URIs of extensions that are relevant to this artifact. * @property metadata Optional metadata for extensions. The key is an extension-specific identifier. */ -@OptIn(ExperimentalUuidApi::class) @Serializable public data class Artifact( - @EncodeDefault - public val artifactId: String = Uuid.random().toString(), + public val artifactId: String, public val name: String? = null, public val description: String? = null, public val parts: List, diff --git a/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/model/Message.kt b/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/model/Message.kt index 4832145eb7..b479a67b64 100644 --- a/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/model/Message.kt +++ b/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/model/Message.kt @@ -4,8 +4,6 @@ import kotlinx.serialization.EncodeDefault import kotlinx.serialization.SerialName import kotlinx.serialization.Serializable import kotlinx.serialization.json.JsonObject -import kotlin.uuid.ExperimentalUuidApi -import kotlin.uuid.Uuid /** * Message role. @@ -32,11 +30,9 @@ public enum class Role { * @property contextId The context ID for this message, used to group related interactions. * @property metadata Optional metadata for extensions. The key is an extension-specific identifier. */ -@OptIn(ExperimentalUuidApi::class) @Serializable public data class Message( - @EncodeDefault - public val messageId: String = Uuid.random().toString(), + public val messageId: String, public val role: Role, public val parts: List, public val extensions: List? = null, diff --git a/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/model/Serialization.kt b/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/model/Serialization.kt index bdf6bde1a4..78c72f911a 100644 --- a/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/model/Serialization.kt +++ b/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/model/Serialization.kt @@ -1,6 +1,7 @@ package ai.koog.a2a.model import kotlinx.serialization.DeserializationStrategy +import kotlinx.serialization.SerializationException import kotlinx.serialization.json.JsonContentPolymorphicSerializer import kotlinx.serialization.json.JsonElement import kotlinx.serialization.json.jsonObject @@ -9,7 +10,7 @@ import kotlinx.serialization.json.jsonPrimitive internal object SecuritySchemeSerializer : JsonContentPolymorphicSerializer(SecurityScheme::class) { override fun selectDeserializer(element: JsonElement): DeserializationStrategy { val jsonObject = element.jsonObject - val type = jsonObject["type"]?.jsonPrimitive?.content ?: error("Missing 'type' field in SecurityScheme") + val type = jsonObject["type"]?.jsonPrimitive?.content ?: throw SerializationException("Missing 'type' field in SecurityScheme") return when (type) { "apiKey" -> APIKeySecurityScheme.serializer() @@ -17,7 +18,7 @@ internal object SecuritySchemeSerializer : JsonContentPolymorphicSerializer OAuth2SecurityScheme.serializer() "openIdConnect" -> OpenIdConnectSecurityScheme.serializer() "mutualTLS" -> MutualTLSSecurityScheme.serializer() - else -> error("Unknown SecurityScheme type: $type") + else -> throw SerializationException("Unknown SecurityScheme type: $type") } } } @@ -25,13 +26,13 @@ internal object SecuritySchemeSerializer : JsonContentPolymorphicSerializer(Part::class) { override fun selectDeserializer(element: JsonElement): DeserializationStrategy { val jsonObject = element.jsonObject - val kind = jsonObject["kind"]?.jsonPrimitive?.content ?: error("Missing 'kind' field in Part") + val kind = jsonObject["kind"]?.jsonPrimitive?.content ?: throw SerializationException("Missing 'kind' field in Part") return when (kind) { "text" -> TextPart.serializer() "file" -> FilePart.serializer() "data" -> DataPart.serializer() - else -> error("Unknown Part kind: $kind") + else -> throw SerializationException("Unknown Part kind: $kind") } } } @@ -43,7 +44,7 @@ internal object FileSerializer : JsonContentPolymorphicSerializer(File::cl return when { "bytes" in jsonObject -> FileWithBytes.serializer() "uri" in jsonObject -> FileWithUri.serializer() - else -> error("Unknown File type") + else -> throw SerializationException("Unknown File type") } } } @@ -51,14 +52,14 @@ internal object FileSerializer : JsonContentPolymorphicSerializer(File::cl internal object EventSerializer : JsonContentPolymorphicSerializer(Event::class) { override fun selectDeserializer(element: JsonElement): DeserializationStrategy { val jsonObject = element.jsonObject - val kind = jsonObject["kind"]?.jsonPrimitive?.content ?: error("Missing 'kind' field in Event") + val kind = jsonObject["kind"]?.jsonPrimitive?.content ?: throw SerializationException("Missing 'kind' field in Event") return when (kind) { "status-update" -> TaskStatusUpdateEvent.serializer() "artifact-update" -> TaskArtifactUpdateEvent.serializer() "task" -> Task.serializer() "message" -> Message.serializer() - else -> error("Unknown kind: $kind") + else -> throw SerializationException("Unknown kind: $kind") } } } @@ -66,12 +67,12 @@ internal object EventSerializer : JsonContentPolymorphicSerializer(Event: internal object CommunicationEventSerializer : JsonContentPolymorphicSerializer(CommunicationEvent::class) { override fun selectDeserializer(element: JsonElement): DeserializationStrategy { val jsonObject = element.jsonObject - val kind = jsonObject["kind"]?.jsonPrimitive?.content ?: error("Missing 'kind' field in CommunicationEvent") + val kind = jsonObject["kind"]?.jsonPrimitive?.content ?: throw SerializationException("Missing 'kind' field in CommunicationEvent") return when (kind) { "task" -> Task.serializer() "message" -> Message.serializer() - else -> error("Unknown kind: $kind") + else -> throw SerializationException("Unknown kind: $kind") } } } @@ -79,13 +80,13 @@ internal object CommunicationEventSerializer : JsonContentPolymorphicSerializer< internal object TaskEventSerializer : JsonContentPolymorphicSerializer(TaskEvent::class) { override fun selectDeserializer(element: JsonElement): DeserializationStrategy { val jsonObject = element.jsonObject - val kind = jsonObject["kind"]?.jsonPrimitive?.content ?: error("Missing 'kind' field in TaskEvent") + val kind = jsonObject["kind"]?.jsonPrimitive?.content ?: throw SerializationException("Missing 'kind' field in TaskEvent") return when (kind) { "task" -> Task.serializer() "status-update" -> TaskStatusUpdateEvent.serializer() "artifact-update" -> TaskArtifactUpdateEvent.serializer() - else -> error("Unknown kind: $kind") + else -> throw SerializationException("Unknown kind: $kind") } } } diff --git a/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/model/Task.kt b/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/model/Task.kt index 3a861c24c7..aa007b68a7 100644 --- a/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/model/Task.kt +++ b/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/model/Task.kt @@ -6,8 +6,6 @@ import kotlinx.serialization.EncodeDefault import kotlinx.serialization.SerialName import kotlinx.serialization.Serializable import kotlinx.serialization.json.JsonObject -import kotlin.uuid.ExperimentalUuidApi -import kotlin.uuid.Uuid /** * Represents a single, stateful operation or conversation between a client and an agent. @@ -19,11 +17,9 @@ import kotlin.uuid.Uuid * @property artifacts A collection of artifacts generated by the agent during the execution of the task. * @property metadata Optional metadata for extensions. The key is an extension-specific identifier. */ -@OptIn(ExperimentalUuidApi::class) @Serializable public data class Task( - @EncodeDefault - public val id: String = Uuid.random().toString(), + public val id: String, override val contextId: String, public val status: TaskStatus, public val history: List? = null, diff --git a/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/transport/Serialization.kt b/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/transport/Serialization.kt index ba8f1c7c0b..b55cec4339 100644 --- a/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/transport/Serialization.kt +++ b/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/transport/Serialization.kt @@ -1,6 +1,7 @@ package ai.koog.a2a.transport import kotlinx.serialization.KSerializer +import kotlinx.serialization.SerializationException import kotlinx.serialization.descriptors.SerialDescriptor import kotlinx.serialization.descriptors.buildClassSerialDescriptor import kotlinx.serialization.encoding.Decoder @@ -21,15 +22,15 @@ internal object RequestIdSerializer : KSerializer { is JsonPrimitive -> when { element.isString -> RequestId.StringId(element.content) element.longOrNull != null -> RequestId.NumberId(element.long) - else -> error("Invalid RequestId type") + else -> throw SerializationException("Invalid RequestId type") } - else -> error("Invalid RequestId format") + else -> throw SerializationException("Invalid RequestId format") } } override fun serialize(encoder: Encoder, value: RequestId) { - val jsonEncoder = encoder as? JsonEncoder ?: error("Can only serialize JSON") + val jsonEncoder = encoder as? JsonEncoder ?: throw SerializationException("Can only serialize JSON") when (value) { is RequestId.StringId -> jsonEncoder.encodeString(value.value) is RequestId.NumberId -> jsonEncoder.encodeLong(value.value) diff --git a/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/utils/KeyedMutex.kt b/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/utils/KeyedMutex.kt new file mode 100644 index 0000000000..73e13f7001 --- /dev/null +++ b/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/utils/KeyedMutex.kt @@ -0,0 +1,146 @@ +package ai.koog.a2a.utils + +import kotlinx.coroutines.sync.Mutex +import kotlinx.coroutines.sync.withLock +import kotlin.contracts.ExperimentalContracts +import kotlin.contracts.InvocationKind +import kotlin.contracts.contract + +/** + * A keyed, suspend-friendly mutex for Kotlin Multiplatform. + * + * Guarantees mutual exclusion per key, without blocking threads. + * API mirrors kotlinx.coroutines Mutex: + * - lock(key, owner) + * - tryLock(key, owner) + * - unlock(key, owner) + * - withLock(key, owner) { ... } + * + * Internals: + * - A per-key entry holds a Mutex and a reference count (holders/waiters). + * - We increment refs before suspending for lock(key) so entries aren’t removed while waiting. + * - Cleanup removes the entry only when refs == 0 and the mutex is not locked. + */ +public class KeyedMutex { + private class Entry( + val mutex: Mutex = Mutex(), + var refs: Int = 0 + ) + + /** + * Protects access to [entries] and [Entry.refs] updates. + */ + private val mapMutex = Mutex() + private val entries = mutableMapOf() + + /** + * Suspends until the lock for [key] is acquired. + * Not re-entrant for the same key within the same coroutine. + */ + public suspend fun lock(key: K, owner: Any? = null) { + val entry = mapMutex.withLock { + val e = entries.getOrPut(key) { Entry() } + e.refs += 1 + e + } + + try { + entry.mutex.lock(owner) + } catch (t: Throwable) { + // If lock failed/cancelled before acquiring, roll back the ref and maybe cleanup. + mapMutex.withLock { + entry.refs -= 1 + if (entry.refs == 0 && !entry.mutex.isLocked && entries[key] == entry) { + entries.remove(key) + } + } + throw t + } + } + + /** + * Attempts to acquire the lock for [key] without suspension. + * Returns true if lock was acquired. + */ + public suspend fun tryLock(key: K, owner: Any? = null): Boolean { + return mapMutex.withLock { + val existing = entries[key] + if (existing != null) { + if (existing.mutex.tryLock(owner)) { + existing.refs += 1 + true + } else { + false + } + } else { + // Avoid inserting if we cannot lock; create a fresh entry and try immediately. + val e = Entry() + val locked = e.mutex.tryLock(owner) + if (locked) { + e.refs = 1 + entries[key] = e + true + } else { + // Unlikely with a fresh Mutex, but don't insert on failure. + false + } + } + } + } + + /** + * Releases the lock for [key]. + * Must be called exactly once per successful lock/tryLock. + * @throws IllegalStateException on misuse (mirrors Mutex behavior). + */ + public suspend fun unlock(key: K, owner: Any? = null) { + val entry = mapMutex.withLock { + entries[key] ?: throw IllegalStateException("Unlock requested for key without active entry") + } + + // Perform the actual unlock; may throw if not locked or wrong owner. + entry.mutex.unlock(owner) + + // Decrement refs and cleanup if safe. + mapMutex.withLock { + entry.refs -= 1 + if (entry.refs == 0 && !entry.mutex.isLocked && entries[key] == entry) { + entries.remove(key) + } + } + } + + /** + * Optional: observe whether a key appears locked. + * For diagnostics/metrics only (not a synchronization primitive). + */ + public suspend fun isLocked(key: K): Boolean = + mapMutex.withLock { entries[key]?.mutex?.isLocked == true } + + /** + * Checks whether this key is locked by the specified owner. + */ + public suspend fun holdsLock(key: K, owner: Any): Boolean = + mapMutex.withLock { entries[key]?.mutex?.holdsLock(owner) == true } +} + +/** + * Convenience function mirroring [Mutex.withLock] + */ +@OptIn(ExperimentalContracts::class) +public suspend inline fun KeyedMutex.withLock( + key: K, + owner: Any? = null, + action: suspend () -> T +): T { + contract { + callsInPlace(action, InvocationKind.EXACTLY_ONCE) + } + + lock(key, owner) + try { + return action() + } finally { + unlock(key, owner) + } +} diff --git a/a2a/a2a-core/src/commonTest/kotlin/ai/koog/a2a/utils/KeyedMutexTest.kt b/a2a/a2a-core/src/commonTest/kotlin/ai/koog/a2a/utils/KeyedMutexTest.kt new file mode 100644 index 0000000000..0c23d6a8c6 --- /dev/null +++ b/a2a/a2a-core/src/commonTest/kotlin/ai/koog/a2a/utils/KeyedMutexTest.kt @@ -0,0 +1,416 @@ +package ai.koog.a2a.utils + +import kotlinx.coroutines.async +import kotlinx.coroutines.awaitAll +import kotlinx.coroutines.delay +import kotlinx.coroutines.sync.withLock +import kotlinx.coroutines.test.runTest +import kotlinx.coroutines.yield +import kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.assertFailsWith +import kotlin.test.assertFalse +import kotlin.test.assertTrue + +class KeyedMutexTest { + + @Test + fun basicLockUnlock() = runTest { + val mutex = KeyedMutex() + val key = "test-key" + + assertFalse(mutex.isLocked(key), "Key should not be locked initially") + + mutex.lock(key) + assertTrue(mutex.isLocked(key), "Key should be locked after lock()") + + mutex.unlock(key) + assertFalse(mutex.isLocked(key), "Key should be unlocked after unlock()") + } + + @Test + fun basicTryLock() = runTest { + val mutex = KeyedMutex() + val key = "test-key" + + assertTrue(mutex.tryLock(key), "tryLock should succeed on unlocked key") + assertTrue(mutex.isLocked(key), "Key should be locked after tryLock") + assertFalse(mutex.tryLock(key), "tryLock should fail on locked key") + + mutex.unlock(key) + assertFalse(mutex.isLocked(key), "Key should be unlocked after unlock") + } + + @Test + fun withLockConvenienceFunction() = runTest { + val mutex = KeyedMutex() + val key = "test-key" + var executed = false + + val result = mutex.withLock(key) { + assertTrue(mutex.isLocked(key), "Key should be locked inside withLock block") + executed = true + "result" + } + + assertEquals("result", result, "withLock should return block result") + assertTrue(executed, "Block should have been executed") + assertFalse(mutex.isLocked(key), "Key should be unlocked after withLock") + } + + @Test + fun differentKeysCanBeLocked() = runTest { + val mutex = KeyedMutex() + val key1 = "key1" + val key2 = "key2" + + mutex.lock(key1) + mutex.lock(key2) + + assertTrue(mutex.isLocked(key1), "Key1 should be locked") + assertTrue(mutex.isLocked(key2), "Key2 should be locked") + + mutex.unlock(key1) + assertFalse(mutex.isLocked(key1), "Key1 should be unlocked") + assertTrue(mutex.isLocked(key2), "Key2 should still be locked") + + mutex.unlock(key2) + assertFalse(mutex.isLocked(key2), "Key2 should be unlocked") + } + + @Test + fun concurrentAccessSameKey() = runTest { + val mutex = KeyedMutex() + val key = "test-key" + val results = mutableListOf() + val accessOrder = mutableListOf() + + val jobs = (1..3).map { i -> + this.async { + mutex.withLock(key) { + accessOrder.add(i) + delay(10) // Simulate some work + results.add(i) + } + } + } + + jobs.awaitAll() + + assertEquals(3, results.size, "All coroutines should complete") + assertEquals(accessOrder, results, "Access order should match completion order due to mutex") + assertFalse(mutex.isLocked(key), "Key should be unlocked after all operations") + } + + @Test + fun parallelAccessDifferentKeys() = runTest { + val mutex = KeyedMutex() + val results = mutableListOf() + val completionOrder = mutableListOf() + + // Use different keys with different work durations to verify parallelism + val keyWorkMap = mapOf( + "fast" to 10L, + "medium" to 15L, + "slow" to 25L + ) + + val jobs = keyWorkMap.map { (key, workDuration) -> + this.async { + mutex.withLock(key) { + results.add("$key-started") + delay(workDuration) + results.add("$key-finished") + completionOrder.add(key) + } + } + } + + jobs.awaitAll() + + assertEquals(6, results.size, "All operations should complete") + assertEquals(3, completionOrder.size, "All keys should complete") + + // If running in parallel, fast operations should finish before slow ones + // even if they started later. Check that "fast" completes before "slow" + val fastIndex = completionOrder.indexOf("fast") + val slowIndex = completionOrder.indexOf("slow") + assertTrue( + fastIndex < slowIndex, + "Fast operation should complete before slow operation if running in parallel. " + + "Completion order: $completionOrder" + ) + } + + @Test + fun unlockWithoutLockThrowsException() = runTest { + val mutex = KeyedMutex() + val key = "test-key" + + assertFailsWith("Should throw when unlocking non-existent key") { + mutex.unlock(key) + } + } + + @Test + fun unlockAfterAlreadyUnlockedThrowsException() = runTest { + val mutex = KeyedMutex() + val key = "test-key" + + mutex.lock(key) + mutex.unlock(key) + + assertFailsWith("Should throw when unlocking already unlocked key") { + mutex.unlock(key) + } + } + + @Test + fun ownerParameterEnforcement() = runTest { + val mutex = KeyedMutex() + val key = "test-key" + val owner1 = "owner1" + val owner2 = "owner2" + + mutex.lock(key, owner1) + + assertFailsWith("Should throw when unlocking with wrong owner") { + mutex.unlock(key, owner2) + } + + // Should succeed with correct owner + mutex.unlock(key, owner1) + assertFalse(mutex.isLocked(key), "Key should be unlocked") + } + + @Test + fun tryLockWithOwnerParameterEnforcement() = runTest { + val mutex = KeyedMutex() + val key = "test-key" + val owner1 = "owner1" + val owner2 = "owner2" + + assertTrue(mutex.tryLock(key, owner1), "First tryLock should succeed") + assertFalse(mutex.tryLock(key, owner2), "Second tryLock with different owner should fail") + + mutex.unlock(key, owner1) + assertTrue(mutex.tryLock(key, owner2), "tryLock should succeed after unlock") + mutex.unlock(key, owner2) + } + + @Test + fun exceptionDuringLockRollsBackRefCount() = runTest { + val mutex = KeyedMutex() + val key = "test-key" + + // First, acquire the lock to create an entry + mutex.lock(key) + assertTrue(mutex.isLocked(key), "Key should be locked") + + // Start another coroutine that will try to lock and fail + val job = this.async { + try { + // This should wait since the key is already locked + mutex.lock(key, "owner") + // If we somehow get here, unlock to clean up + mutex.unlock(key, "owner") + } catch (e: Exception) { + // Expected to be cancelled + } + } + + yield() // Let the other coroutine start waiting + + // Cancel the waiting coroutine to simulate an exception during lock + job.cancel() + + // Unlock the original lock + mutex.unlock(key) + + // The entry should be cleaned up properly despite the cancelled waiter + assertFalse(mutex.isLocked(key), "Key should be unlocked and cleaned up") + + // Should be able to lock again normally + assertTrue(mutex.tryLock(key), "Should be able to lock after cleanup") + mutex.unlock(key) + } + + @Test + fun memoryLeakPrevention() = runTest { + val mutex = KeyedMutex() + val keys = (1..100).map { "key$it" } + + // Lock and unlock many keys + keys.forEach { key -> + mutex.withLock(key) { + // Do nothing, just acquire and release + } + } + + // All keys should be unlocked and entries cleaned up + keys.forEach { key -> + assertFalse(mutex.isLocked(key), "Key $key should be unlocked") + } + + // Test with concurrent access to same key + repeat(50) { + val key = "shared-key" + val jobs = (1..10).map { + async { + mutex.withLock(key) { + delay(1) + } + } + } + jobs.awaitAll() + assertFalse(mutex.isLocked(key), "Shared key should be unlocked after round $it") + } + } + + @Test + fun highConcurrencyStressTest() = runTest { + val mutex = KeyedMutex() + val sharedCounter = mutableMapOf() + val keys = listOf("key1", "key2", "key3") + + // Initialize counters + keys.forEach { key -> sharedCounter[key] = 0 } + + val jobs = (1..300).map { i -> + this.async { + val key = keys[i % keys.size] + mutex.withLock(key) { + val current = sharedCounter[key]!! + yield() // Force context switch to test race conditions + sharedCounter[key] = current + 1 + } + } + } + + jobs.awaitAll() + + // Verify that each counter was incremented exactly the right number of times + keys.forEach { key -> + val expected = 300 / keys.size + assertEquals(expected, sharedCounter[key], "Counter for $key should be exactly $expected") + assertFalse(mutex.isLocked(key), "Key $key should be unlocked") + } + } + + @Test + fun isLockedObservationOnly() = runTest { + val mutex = KeyedMutex() + val key = "test-key" + + // isLocked should work for non-existent keys + assertFalse(mutex.isLocked(key), "Non-existent key should not be locked") + + mutex.lock(key) + assertTrue(mutex.isLocked(key), "Locked key should return true") + + // isLocked should be usable from concurrent contexts + val job = this.async { + mutex.isLocked(key) + } + + assertTrue(job.await(), "isLocked should work from concurrent context") + mutex.unlock(key) + } + + @Test + fun holdsLockOwnershipCheck() = runTest { + val mutex = KeyedMutex() + val key = "test-key" + val owner1 = "owner1" + val owner2 = "owner2" + + // holdsLock should return false for non-existent keys + assertFalse(mutex.holdsLock(key, owner1), "Non-existent key should not be held by any owner") + + // Lock with owner1 + mutex.lock(key, owner1) + assertTrue(mutex.holdsLock(key, owner1), "owner1 should hold the lock") + assertFalse(mutex.holdsLock(key, owner2), "owner2 should not hold the lock") + + // After unlock, no one should hold the lock + mutex.unlock(key, owner1) + assertFalse(mutex.holdsLock(key, owner1), "owner1 should not hold the lock after unlock") + assertFalse(mutex.holdsLock(key, owner2), "owner2 should not hold the lock after unlock") + } + + @Test + fun holdsLockWithTryLock() = runTest { + val mutex = KeyedMutex() + val key = "test-key" + val owner1 = "owner1" + val owner2 = "owner2" + + // Acquire lock with tryLock + assertTrue(mutex.tryLock(key, owner1), "tryLock should succeed") + assertTrue(mutex.holdsLock(key, owner1), "owner1 should hold the lock after tryLock") + assertFalse(mutex.holdsLock(key, owner2), "owner2 should not hold the lock") + + // Second tryLock with different owner should fail + assertFalse(mutex.tryLock(key, owner2), "tryLock with different owner should fail") + assertTrue(mutex.holdsLock(key, owner1), "owner1 should still hold the lock") + assertFalse(mutex.holdsLock(key, owner2), "owner2 should still not hold the lock") + + mutex.unlock(key, owner1) + assertFalse(mutex.holdsLock(key, owner1), "owner1 should not hold lock after unlock") + } + + @Test + fun holdsLockConcurrentAccess() = runTest { + val mutex = KeyedMutex() + val key = "test-key" + val owner = "test-owner" + + mutex.lock(key, owner) + + // holdsLock should be usable from concurrent contexts + val job = this.async { + mutex.holdsLock(key, owner) + } + + assertTrue(job.await(), "holdsLock should work from concurrent context") + mutex.unlock(key, owner) + } + + @Test + fun reentrantBehaviorShouldDeadlock() = runTest { + val mutex = KeyedMutex() + val key = "test-key" + var innerReached = false + + // This test verifies that the mutex is NOT re-entrant + // According to the documentation, it's "Not re-entrant for the same key within the same coroutine" + + val job = this.async { + mutex.withLock(key) { + try { + // This should deadlock since we're trying to lock the same key again + // in the same coroutine. We use a timeout to detect this. + mutex.withLock(key) { + innerReached = true + } + } catch (e: Exception) { + // May throw if timeout or cancellation occurs + } + } + } + + // Give it a short time to potentially complete + delay(100) + + // The job should still be running (deadlocked) + assertFalse(job.isCompleted, "Re-entrant lock should deadlock") + assertFalse(innerReached, "Inner block should not be reached") + + // Clean up by cancelling + job.cancel() + + // Verify the key gets cleaned up properly after cancellation + delay(10) + assertFalse(mutex.isLocked(key), "Key should be cleaned up after cancellation") + } +} diff --git a/a2a/a2a-server/build.gradle.kts b/a2a/a2a-server/build.gradle.kts index 68186fac97..5479619ece 100644 --- a/a2a/a2a-server/build.gradle.kts +++ b/a2a/a2a-server/build.gradle.kts @@ -27,6 +27,7 @@ kotlin { implementation(project(":a2a:a2a-test")) implementation(kotlin("test")) implementation(libs.kotlinx.coroutines.test) + implementation(libs.kotest.assertions) } } @@ -41,7 +42,7 @@ kotlin { implementation(libs.ktor.client.cio) implementation(libs.ktor.client.logging) implementation(libs.ktor.server.netty) - runtimeOnly(libs.slf4j.simple) + runtimeOnly(libs.logback.classic) } } diff --git a/a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/A2AServer.kt b/a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/A2AServer.kt index 2cd19c9bd0..f045bcad05 100644 --- a/a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/A2AServer.kt +++ b/a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/A2AServer.kt @@ -1,7 +1,6 @@ package ai.koog.a2a.server import ai.koog.a2a.exceptions.A2AAuthenticatedExtendedCardNotConfiguredException -import ai.koog.a2a.exceptions.A2AInternalErrorException import ai.koog.a2a.exceptions.A2AInvalidParamsException import ai.koog.a2a.exceptions.A2APushNotificationNotSupportedException import ai.koog.a2a.exceptions.A2ATaskNotCancelableException @@ -25,13 +24,12 @@ import ai.koog.a2a.server.messages.InMemoryMessageStorage import ai.koog.a2a.server.messages.MessageStorage import ai.koog.a2a.server.notifications.PushNotificationConfigStorage import ai.koog.a2a.server.notifications.PushNotificationSender -import ai.koog.a2a.server.session.AgentSession import ai.koog.a2a.server.session.IdGenerator +import ai.koog.a2a.server.session.LazySession import ai.koog.a2a.server.session.RequestContext import ai.koog.a2a.server.session.SessionEventProcessor import ai.koog.a2a.server.session.SessionManager import ai.koog.a2a.server.session.UuidIdGenerator -import ai.koog.a2a.server.session.withTaskLock import ai.koog.a2a.server.tasks.ContextTaskStorage import ai.koog.a2a.server.tasks.InMemoryTaskStorage import ai.koog.a2a.server.tasks.TaskStorage @@ -39,19 +37,19 @@ import ai.koog.a2a.transport.Request import ai.koog.a2a.transport.RequestHandler import ai.koog.a2a.transport.Response import ai.koog.a2a.transport.ServerCallContext +import ai.koog.a2a.utils.KeyedMutex +import ai.koog.a2a.utils.withLock import kotlinx.coroutines.CancellationException import kotlinx.coroutines.CoroutineScope import kotlinx.coroutines.SupervisorJob import kotlinx.coroutines.cancel import kotlinx.coroutines.flow.Flow import kotlinx.coroutines.flow.channelFlow -import kotlinx.coroutines.flow.emitAll import kotlinx.coroutines.flow.first import kotlinx.coroutines.flow.flow import kotlinx.coroutines.flow.last import kotlinx.coroutines.flow.map import kotlinx.coroutines.launch -import kotlinx.datetime.Clock /** * Default implementation of A2A server responsible for handling requests from A2A clients according to the @@ -177,14 +175,13 @@ import kotlinx.datetime.Clock * * override suspend fun cancel( * context: RequestContext, - * session: Session + * eventProcessor: SessionEventProcessor, + * agentJob: Deferred?, * ) { + * agentJob?.cancelAndJoin() * // Access user data for audit logging * val user = context.callContext.getFromStateOrNull(AuthStateKeys.USER) * log.info("Task ${context.taskId} canceled by user ${user?.id}") - * - * // Default cancellation behavior - * super.cancel(context, session) * } * } * ``` @@ -305,8 +302,6 @@ import kotlinx.datetime.Clock * @param pushSender Optional push notification sender implementation (defaults to `null`) * @param idGenerator Generator for new task and context IDs (defaults to [UuidIdGenerator]) * @param coroutineScope Scope for managing all sessions, agent jobs, event processing, etc. - * @param clock Clock instance for timestamp generation (defaults to [Clock.System]) - * @param sessionManager Manager for managing agent sessions (defaults to [SessionManager]) * * @see AgentExecutor for implementing agent business logic * @see TaskStorage for persisting tasks @@ -325,10 +320,21 @@ public open class A2AServer( protected val pushSender: PushNotificationSender? = null, protected val idGenerator: IdGenerator = UuidIdGenerator, protected val coroutineScope: CoroutineScope = CoroutineScope(SupervisorJob()), - protected val clock: Clock = Clock.System, ) : RequestHandler { + /** + * Mutex for locking specific tasks by their IDs. + */ + protected val tasksMutex: KeyedMutex = KeyedMutex() + + /** + * Special cancellation key for additional set of task cancellation locks. + */ + protected fun cancelKey(taskId: String): String = "cancel:$taskId" + protected open val sessionManager: SessionManager = SessionManager( coroutineScope = coroutineScope, + cancelKey = ::cancelKey, + tasksMutex = tasksMutex, taskStorage = taskStorage, pushConfigStorage = pushConfigStorage, pushSender = pushSender, @@ -361,29 +367,33 @@ public open class A2AServer( ctx: ServerCallContext ): Flow> = channelFlow { val message = request.data.message - val contextId = message.contextId ?: idGenerator.generateContextId(message) + + if (message.parts.isEmpty()) { + throw A2AInvalidParamsException("Empty message parts are not supported") + } + val taskId = message.taskId ?: idGenerator.generateTaskId(message) - val session = sessionManager.withTaskLock(taskId) { + val session = tasksMutex.withLock(taskId) { + // If there's a currently running session for the same task, wait for it to finish. + sessionManager.getSession(taskId)?.join() + // Check if message links to a task. val task: Task? = message.taskId?.let { taskId -> - // Check if the specified task exists and message context id matches the task context id. + // Check if the specified task exists val task = taskStorage.get(taskId, historyLength = 0, includeArtifacts = false) ?: throw A2ATaskNotFoundException("Task '$taskId' not found") - if (message.contextId != task.contextId) { - throw A2AInvalidParamsException("Message context id '${message.contextId}' doesn't match task context id '${task.contextId}'") - } - task } // Create event processor for the session based on the input data. val eventProcessor = SessionEventProcessor( - contextId = contextId, + contextId = task?.contextId + ?: message.contextId + ?: idGenerator.generateContextId(message), taskId = taskId, taskStorage = taskStorage, - task = task, ) // Create request context based on the request information. @@ -397,18 +407,13 @@ public open class A2AServer( task = task, ) - // Create agent execution session - AgentSession(coroutineScope, eventProcessor) { + LazySession( + coroutineScope = coroutineScope, + eventProcessor = eventProcessor, + ) { agentExecutor.execute(requestContext, eventProcessor) }.also { - try { - // Add to session manager, it will handle monitoring and closing once the session is completed (successfully or not). - sessionManager.addSession(it) - } catch (_: IllegalArgumentException) { - throw A2AUnsupportedOperationException( - "Task '${request.data.message.taskId}' is already running, can't send messages to the task that hasn't yielded control." - ) - } + sessionManager.addSession(it) } } @@ -421,7 +426,8 @@ public open class A2AServer( } // Start the session to execute the agent and wait for it to finish. - session.join() + // Using await here to propagate any exceptions thrown by the agent execution. + session.agentJob.await() } override suspend fun onSendMessage( @@ -432,31 +438,24 @@ public open class A2AServer( // Reusing streaming logic here, because it's essentially the same, only we need some particular event from the stream val eventStream = onSendMessageCommon(request, ctx) - return if (messageConfiguration?.blocking == true) { + val event = if (messageConfiguration?.blocking == true) { // If blocking is requested, attempt to wait for the last event, until the current turn of the agent execution is finished. - val lastEventResponse = eventStream.last() - - when (val eventData = lastEventResponse.data) { - is Message -> Response(data = eventData, id = lastEventResponse.id) - is TaskEvent -> - taskStorage - .get( - eventData.taskId, - historyLength = messageConfiguration.historyLength, - includeArtifacts = true - ) - ?.let { Response(data = it, id = lastEventResponse.id) } - ?: throw A2ATaskNotFoundException("Task '${eventData.taskId}' not found after the agent execution") - } + eventStream.last() } else { - // Else read the first event from the stream, check that it's a proper communication event and return it. - val firstEventResponse = eventStream.first() + eventStream.first() + } - when (val eventData = firstEventResponse.data) { - is Message -> Response(data = eventData, id = firstEventResponse.id) - is Task -> Response(data = eventData, id = firstEventResponse.id) - else -> throw A2AInternalErrorException("Got unexpected event type from the agent '${eventData::class.simpleName}'") - } + return when (val eventData = event.data) { + is Message -> Response(data = eventData, id = event.id) + is TaskEvent -> + taskStorage + .get( + eventData.taskId, + historyLength = messageConfiguration?.historyLength, + includeArtifacts = true + ) + ?.let { Response(data = it, id = event.id) } + ?: throw A2ATaskNotFoundException("Task '${eventData.taskId}' not found after the agent execution") } } @@ -465,7 +464,7 @@ public open class A2AServer( ctx: ServerCallContext ): Flow> = flow { checkStreamingSupport() - emitAll(onSendMessageCommon(request, ctx)) + onSendMessageCommon(request, ctx).collect(this) } override suspend fun onGetTask( @@ -488,58 +487,69 @@ public open class A2AServer( val taskParams = request.data val taskId = taskParams.id - sessionManager.withTaskLock(taskId) { - val session = sessionManager.getSession(taskParams.id) - - val task = taskStorage.get(taskParams.id, historyLength = 0, includeArtifacts = true) - ?: throw A2ATaskNotFoundException("Task '${taskParams.id}' not found") - - // Task is not running, check if it exists in the storage. - if (session == null) { - // Task exists but not running - check if it is already canceled. - if (task.status.state == TaskState.Canceled) { - return Response(data = task, id = request.id) - } - - // If the task is not canceled and in the terminal state, throw. - if (task.status.state.terminal) { + /* + Cancellation uses two lock levels. The first is the standard task lock. + If it’s already held by another request, ignore it because cancellation takes priority. + If it’s not held, acquire it to block new requests while the cancellation is in progress. + */ + val lockAcquired = tasksMutex.tryLock(taskId) + + return try { + /* + The second lock is a per-task cancellation lock. + It’s always taken during cancellation to serialize cancel operations and allow them to proceed even if the + regular task lock is held. It prevents overlapping cancels and delays session teardown so the event processor + isn’t closed immediately after the agent job is canceled. This allows the cancel handler to emit additional + cancellation events through the same processor and session, ensuring that existing subscribers receive all events. + */ + tasksMutex.withLock(cancelKey(taskId)) { + val session = sessionManager.getSession(taskParams.id) + + val task = taskStorage.get(taskParams.id, historyLength = 0, includeArtifacts = true) + ?: throw A2ATaskNotFoundException("Task '${taskParams.id}' not found") + + // Task is not running, check if it's already in a terminal state. + if (session == null && task.status.state.terminal) { throw A2ATaskNotCancelableException("Task '${taskParams.id}' is already in terminal state ${task.status.state}") } - } - val eventProcessor = session?.eventProcessor ?: SessionEventProcessor( - contextId = task.contextId, - taskId = task.id, - taskStorage = taskStorage, - task = task, - ) - - // Create request context based on the request information. - val requestContext = RequestContext( - callContext = ctx, - params = request.data, - taskStorage = ContextTaskStorage(eventProcessor.contextId, taskStorage), - messageStorage = ContextMessageStorage(eventProcessor.contextId, messageStorage), - contextId = eventProcessor.contextId, - taskId = eventProcessor.taskId, - task = task, - ) - - // Attempt to cancel the agent execution and wait until it's finished. - agentExecutor.cancel(requestContext, eventProcessor, session?.agentJob) + val eventProcessor = session?.eventProcessor ?: SessionEventProcessor( + contextId = task.contextId, + taskId = task.id, + taskStorage = taskStorage, + ) + + // Create request context based on the request information. + val requestContext = RequestContext( + callContext = ctx, + params = request.data, + taskStorage = ContextTaskStorage(eventProcessor.contextId, taskStorage), + messageStorage = ContextMessageStorage(eventProcessor.contextId, messageStorage), + contextId = eventProcessor.contextId, + taskId = eventProcessor.taskId, + task = task, + ) + + // Attempt to cancel the agent execution and wait until it's finished. + agentExecutor.cancel(requestContext, eventProcessor, session?.agentJob) + + // Return the final task state. + Response( + data = taskStorage.get(taskParams.id, historyLength = 0, includeArtifacts = true) + ?.also { + if (it.status.state != TaskState.Canceled) { + throw A2ATaskNotCancelableException("Task '${taskParams.id}' was not canceled successfully, current state is ${it.status.state}") + } + } + ?: throw A2ATaskNotFoundException("Task '${taskParams.id}' not found"), + id = request.id, + ) + } + } finally { + if (lockAcquired) { + tasksMutex.unlock(taskId) + } } - - // Return the final task state. - return Response( - data = taskStorage.get(taskParams.id, historyLength = 0, includeArtifacts = true) - ?.also { - if (it.status.state != TaskState.Canceled) { - throw A2ATaskNotCancelableException("Task '${taskParams.id}' was not canceled successfully, current state is ${it.status.state}") - } - } - ?: throw A2ATaskNotFoundException("Task '${taskParams.id}' not found"), - id = request.id, - ) } override fun onResubscribeTask( @@ -549,13 +559,11 @@ public open class A2AServer( checkStreamingSupport() val taskParams = request.data - val session = sessionManager.getSession(taskParams.id) - ?: throw A2AUnsupportedOperationException("Session for task '${taskParams.id}' is not currently running or task does not exist") + val session = sessionManager.getSession(taskParams.id) ?: return@flow - emitAll( - session.events - .map { event -> Response(data = event, id = request.id) } - ) + session.events + .map { event -> Response(data = event, id = request.id) } + .collect(this) } override suspend fun onSetTaskPushNotificationConfig( diff --git a/a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/agent/AgentExecutor.kt b/a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/agent/AgentExecutor.kt index e163178c75..bd6e9c4a69 100644 --- a/a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/agent/AgentExecutor.kt +++ b/a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/agent/AgentExecutor.kt @@ -10,7 +10,7 @@ import ai.koog.a2a.model.TaskIdParams import ai.koog.a2a.model.TaskState import ai.koog.a2a.server.session.RequestContext import ai.koog.a2a.server.session.SessionEventProcessor -import kotlinx.coroutines.Job +import kotlinx.coroutines.Deferred /** * Implementations of this interface contain the core logic of the agent, @@ -98,7 +98,7 @@ public interface AgentExecutor { * Example implementation: * ```kotlin * // Cancel agent execution job, if the agent is currently running, to terminate it. - * agentJob?.cancel() + * agentJob?.cancelAndJoin() * // Send task cancellation event with custom message to event processor * eventProcessor.sendTaskEvent( * TaskStatusUpdateEvent( @@ -122,7 +122,7 @@ public interface AgentExecutor { * * @param context The context containing the necessary information and accessors for executing the agent. * @param eventProcessor The event processor to publish events to. - * @param agentJob Optional [Job] executing the agent logic, if the agent is currently running. + * @param agentJob Optional job executing the agent logic, if the agent is currently running. * @throws Exception if something goes wrong during execution or the cancellation is impossible. Should prefer more * specific exceptions if possible, e.g., [A2ATaskNotCancelableException], [A2AUnsupportedOperationException], etc. * See full list of available A2A exceptions in [ai.koog.a2a.exceptions]. @@ -130,7 +130,7 @@ public interface AgentExecutor { public suspend fun cancel( context: RequestContext, eventProcessor: SessionEventProcessor, - agentJob: Job?, + agentJob: Deferred?, ) { throw A2ATaskNotCancelableException("Cancellation is not supported") } diff --git a/a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/exceptions/Exceptions.kt b/a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/exceptions/Exceptions.kt index ed65db9868..62d54a0a6a 100644 --- a/a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/exceptions/Exceptions.kt +++ b/a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/exceptions/Exceptions.kt @@ -1,8 +1,5 @@ package ai.koog.a2a.server.exceptions -import ai.koog.a2a.server.session.Session -import ai.koog.a2a.server.session.SessionEventProcessor - /** * Indicates an error with task-related operations. */ @@ -14,16 +11,16 @@ public class TaskOperationException(message: String, cause: Throwable? = null) : public class MessageOperationException(message: String, cause: Throwable? = null) : Exception(message, cause) /** - * Indicates a failure in sending an event through the [SessionEventProcessor] because of invalid event. + * Indicates a failure in sending an event because it was invalid. */ public class InvalidEventException(message: String, cause: Throwable? = null) : Exception(message, cause) /** - * An exception that is thrown to indicate errors occurring during push notification operations. + * Indicates errors occurring during push notification operations. */ public class PushNotificationException(message: String, cause: Throwable? = null) : Exception(message, cause) /** - * An exception that is thrown to indicate that a [Session] has been closed. + * Indicates a session is not in the active state. */ -public class SessionClosedException(message: String, cause: Throwable? = null) : Exception(message, cause) +public class SessionNotActiveException(message: String, cause: Throwable? = null) : Exception(message, cause) diff --git a/a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/session/Session.kt b/a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/session/Session.kt index f8ed81dcfc..305f18cf07 100644 --- a/a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/session/Session.kt +++ b/a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/session/Session.kt @@ -3,24 +3,22 @@ package ai.koog.a2a.server.session import ai.koog.a2a.model.Event import kotlinx.coroutines.CoroutineScope import kotlinx.coroutines.CoroutineStart -import kotlinx.coroutines.Job +import kotlinx.coroutines.Deferred +import kotlinx.coroutines.async import kotlinx.coroutines.cancelAndJoin import kotlinx.coroutines.flow.Flow import kotlinx.coroutines.flow.collect -import kotlinx.coroutines.launch /** - * Represents an active agent execution session with lifecycle management. + * Represents a session with lifecycle management. * * @property eventProcessor The session event processor - * @property agentJob The job executing the agent logic - * @property contextId Unique context ID associated with this session - * @property taskId Unique task ID associated with this session + * @property agentJob The execution process associated with this session's execution * @property events A stream of events generated during this session */ public class Session( public val eventProcessor: SessionEventProcessor, - public val agentJob: Job + public val agentJob: Deferred ) { public val contextId: String get() = eventProcessor.contextId public val taskId: String get() = eventProcessor.taskId @@ -34,42 +32,47 @@ public class Session( } /* - * Suspends until the session, i.e., agent job and event stream, complete. + * Suspends until the session, i.e., event stream and agent job, complete. + * Waits for the event stream to finish first, to avoid triggering the agent job prematurely. + * Assumes that by the time event stream is finished, agent job will already be completed or canceled. */ public suspend fun join() { - agentJob.join() events.collect() + agentJob.join() + } + + /** + * [start] and then [join] the session. + */ + public suspend fun startAndJoin() { + start() + join() } /** - * Cancels the agent job, waiting for it to complete, and then closes event processor. + * Cancels the execution process, waiting for it to complete, and then closes event processor. */ - public suspend fun cancel() { + public suspend fun cancelAndJoin() { agentJob.cancelAndJoin() eventProcessor.close() } } /** - * Factory function that creates a new [Session] with lazy-started [agentAction]. + * Creates an instance of [Session] with lazily started [Session.agentJob] * - * @param coroutineScope The scope for launching the agent coroutine + * @param coroutineScope The coroutine scope to use for running the [block] * @param eventProcessor The session event processor - * @param agentAction The agent logic to execute - * @return A new session instance + * @param block The block to be executed */ @Suppress("ktlint:standard:function-naming", "FunctionName") -public fun AgentSession( +public fun LazySession( coroutineScope: CoroutineScope, eventProcessor: SessionEventProcessor, - agentAction: suspend CoroutineScope.() -> Unit + block: suspend CoroutineScope.() -> Unit ): Session { - val agentJob = coroutineScope.launch(start = CoroutineStart.LAZY) { - agentAction() - } - return Session( eventProcessor = eventProcessor, - agentJob = agentJob + agentJob = coroutineScope.async(start = CoroutineStart.LAZY, block = block) ) } diff --git a/a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/session/SessionEventProcessor.kt b/a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/session/SessionEventProcessor.kt index cdc9c3765f..ce9148ec38 100644 --- a/a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/session/SessionEventProcessor.kt +++ b/a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/session/SessionEventProcessor.kt @@ -3,59 +3,53 @@ package ai.koog.a2a.server.session import ai.koog.a2a.model.Event import ai.koog.a2a.model.Message import ai.koog.a2a.model.Task -import ai.koog.a2a.model.TaskArtifactUpdateEvent import ai.koog.a2a.model.TaskEvent -import ai.koog.a2a.model.TaskState import ai.koog.a2a.model.TaskStatusUpdateEvent import ai.koog.a2a.server.exceptions.InvalidEventException -import ai.koog.a2a.server.exceptions.SessionClosedException +import ai.koog.a2a.server.exceptions.SessionNotActiveException import ai.koog.a2a.server.tasks.TaskStorage import kotlinx.coroutines.flow.Flow import kotlinx.coroutines.flow.MutableSharedFlow -import kotlinx.coroutines.flow.MutableStateFlow -import kotlinx.coroutines.flow.emitAll +import kotlinx.coroutines.flow.emptyFlow import kotlinx.coroutines.flow.filterIsInstance import kotlinx.coroutines.flow.flow import kotlinx.coroutines.flow.map import kotlinx.coroutines.flow.takeWhile import kotlinx.coroutines.sync.Mutex import kotlinx.coroutines.sync.withLock +import kotlin.concurrent.atomics.AtomicBoolean +import kotlin.concurrent.atomics.ExperimentalAtomicApi import kotlin.jvm.JvmInline /** * A session processor responsible for handling session events. - * It validates the events, emits them to the subscribers via [events] and updates session state. - * All valid [TaskEvent] events that are sent using [sendTaskEvent] will also be saved to the [taskStorage] provided. + * It validates the events, writes them to [taskStorage] and emits them to the subscribers via [events]. * - * Validation logic attempts to verify that the number, type and order of events comply to what is expected from a proper - * A2A server implementation. + * Validation logic attempts to perform basic verification that events follow what is expected from a proper A2A server implementation. * These are the main rules: * - * - **Session type exclusivity**: A session can only handle either [Message] events or [TaskEvent] events, never both - * - **Context ID validation**: All events must have the same contextId as the session - * - **Single message limit**: Only one [Message] can be sent per session, after which the session becomes terminal - * - **Task initialization order**: For new tasks, the first [TaskEvent] must be of type [Task] to create the task - * - **Task ID consistency**: [TaskEvent] events must have task ids equal to [taskId] provided for this session. - * - **Final event enforcement**: After a [TaskStatusUpdateEvent] with `final=true` is sent, no more events are permitted - * - **Terminal state blocking**: No events can be sent when the task is already in a terminal state - * - **Final flag requirement**: [TaskStatusUpdateEvent]s that set the task to a terminal state must have `final=true` + * - **Session type exclusivity**: A session can only handle either [Message] event or [TaskEvent] events, never both. + * - **Context ID validation**: All events must have the same context id as provided [contextId]. + * - **Single message limit**: Only one [Message] can be sent per session, after which the processor closes. + * - **Task ID consistency**: [TaskEvent] events must have task ids equal to provided [taskId]. + * - **Final event enforcement**: After a [TaskStatusUpdateEvent] with `final=true` is sent, processor closes. + * - **Terminal state closure**: When the event with the terminal state is sent, processor closes. * + * @param taskStorage The storage where task events will be saved. * @property contextId The contextId associated with this session, representing either an existing context * from the incoming request or a newly generated ID that must be used for all events in this session. * @property taskId The taskId associated with this session, representing either an existing task * from the incoming request or a newly generated ID that must be used if creating a new task. * Note: This taskId might not correspond to an actually existing task initially - it serves as the * identifier that will be validated against all [TaskEvent] in this session. - * @param taskStorage The storage for tasks where task events will be saved. - * @param task The initial task associated with the session, if it is a continuation of a previous task session. - * + * @property isOpen Whether the session is open. * @property events A hot flow of events in this session that can be subscribed to. */ +@OptIn(ExperimentalAtomicApi::class) public class SessionEventProcessor( public val contextId: String, public val taskId: String, private val taskStorage: TaskStorage, - private val task: Task? = null, ) { private companion object { private const val SESSION_CLOSED = "Session event processor is closed, can't send events" @@ -64,195 +58,106 @@ public class SessionEventProcessor( private const val INVALID_TASK_ID = "Event taskId must be same as provided taskId" - private const val MESSAGE_SENT = - "Message has already been sent in this session. Sending message is a terminal operation and no more events " + - "are allowed to be sent, the session must terminate ASAP" - - private const val TASK_INITIALIZED = - "Task has already been initialized in this sessions, only TaskEvent's with the same taskId can be sent from now on." - - private const val TASK_EVENT_FINAL_SENT = - "Final TaskEvent has already been sent in this session. Sending final event is a terminal operation " + - "and no more events are allowed to be sent, the session must terminate ASAP" - - private const val TASK_EVENT_TERMINAL_STATE = - "TaskEvent's cannot be sent when the task transitioned to the terminal state." - - private const val TASK_EVENT_FINAL_REQUIRED = - "TaskEvent final parameter is required to be set to 'true' when setting task state to the terminal state" - - private const val TASK_DOES_NOT_EXIST = - "Task associated with the taskId in TaskEvent does not exist yet and the event was not Task. Creating new " + - "task should always start with Task event." + private const val TASK_EVENT_SENT = + "Task has already been initialized in this session, only TaskEvent's with the same taskId can be sent from now on" } + private val _isOpen: AtomicBoolean = AtomicBoolean(true) + public val isOpen: Boolean get() = _isOpen.load() + /** - * Helper interface to handle different session types. + * Tracks whether a task event was sent in this session, meaning we have to reject [Message] events now. */ - private sealed interface SessionType { - object MessageSession : SessionType + private var isTaskEventSent: Boolean = false - class TaskSession( - val taskId: String, - var taskState: TaskState? = null, - var finalEventReceived: Boolean = false, - ) : SessionType - } + private val sessionMutex = Mutex() /** - * Helper interface to send actual events or termination signal to cancel events stream on session closure. + * Helper interface to send actual events or termination signal to close current event stream subscribers on session closure. */ private sealed interface FlowEvent { @JvmInline value class Data(val data: Event) : FlowEvent - object Cancel : FlowEvent + object Close : FlowEvent } - private val isClosed = MutableStateFlow(false) - private val _events = MutableSharedFlow() public val events: Flow get() = flow { - if (!isClosed.value) { - emitAll( - _events - .takeWhile { !isClosed.value } - .filterIsInstance() - .map { it.data } - ) - } + if (isOpen) { + _events + .takeWhile { it !is FlowEvent.Close } + .filterIsInstance() + .map { it.data } + } else { + emptyFlow() + }.collect(this) } - private val sessionMutex = Mutex() - private var sessionType: SessionType? = task?.let { - SessionType.TaskSession( - taskId = it.id, - taskState = it.status.state - ) - } - /** * Sends a [Message] to the session event processor. Validates the message against the session context and updates * the session state accordingly. * * @param message The message to be sent. * @throws [InvalidEventException] for invalid events. - * Check [SessionEventProcessor] docs from info about valid events. + * @see SessionEventProcessor */ public suspend fun sendMessage(message: Message): Unit = sessionMutex.withLock { - if (isClosed.value) { - throw SessionClosedException(SESSION_CLOSED) - } - - if (message.contextId != contextId) { - throw InvalidEventException(INVALID_CONTEXT_ID) - } - - when (sessionType) { - is SessionType.MessageSession -> throw InvalidEventException(MESSAGE_SENT) - - is SessionType.TaskSession -> throw InvalidEventException(TASK_INITIALIZED) + if (_isOpen.load()) { + if (isTaskEventSent) { + throw InvalidEventException(TASK_EVENT_SENT) + } - null -> { - _events.emit(FlowEvent.Data(message)) - sessionType = SessionType.MessageSession + if (message.contextId != this.contextId) { + throw InvalidEventException(INVALID_CONTEXT_ID) } + + _events.emit(FlowEvent.Data(message)) + _isOpen.store(false) + } else { + throw SessionNotActiveException(SESSION_CLOSED) } } /** - * Sends a [TaskEvent] to the session event processor. Validates the event against the session context and updates - * the session state and [taskStorage] accordingly. + * Sends a [TaskEvent] to the session event processor. + * Validates the event against the session context and updates [taskStorage]. * * @param event The event to be sent. * @throws [InvalidEventException] for invalid events. - * Check [SessionEventProcessor] docs from info about valid events. + * @see SessionEventProcessor */ public suspend fun sendTaskEvent(event: TaskEvent): Unit = sessionMutex.withLock { - if (isClosed.value) { - throw SessionClosedException(SESSION_CLOSED) - } - - if (event.contextId != contextId) { - throw InvalidEventException(INVALID_CONTEXT_ID) - } - - if (event.taskId != taskId) { - throw InvalidEventException(INVALID_TASK_ID) - } - - /* - The first set of checks, to get initial task session type if it is allowed here. - */ - val taskSessionType: SessionType.TaskSession = when (sessionType) { - is SessionType.MessageSession -> throw InvalidEventException(MESSAGE_SENT) - - is SessionType.TaskSession -> sessionType as SessionType.TaskSession + if (_isOpen.load()) { + isTaskEventSent = true - null -> { - SessionType.TaskSession( - taskId = event.taskId, - taskState = task?.status?.state, // null - new task - finalEventReceived = false - ).also { - sessionType = it - } + if (event.contextId != this.contextId) { + throw InvalidEventException(INVALID_CONTEXT_ID) } - } - - /* - The second set of checks to check various aspects of the current task and session state and guide the user to emit - only allowed events. - */ - when { - /** - * If the task does not exist yet, the first [TaskEvent] should be only of type Task, to create the task itself - */ - taskSessionType.taskState == null && event !is Task -> - throw InvalidEventException(TASK_DOES_NOT_EXIST) - /** - * If there was already a [TaskStatusUpdateEvent] with [TaskStatusUpdateEvent.final] set to true, no more events are expected - */ - taskSessionType.finalEventReceived -> - throw InvalidEventException(TASK_EVENT_FINAL_SENT) - - /** - * If the task is already in a terminal state, no more events are expected - */ - taskSessionType.taskState?.terminal == true -> - throw InvalidEventException(TASK_EVENT_TERMINAL_STATE) - - /** - * If the event is a [TaskStatusUpdateEvent] attempting to set a task to a terminal state, - * then [TaskStatusUpdateEvent.final] must be set to true - */ - event is TaskStatusUpdateEvent && event.status.state.terminal && !event.final -> - throw InvalidEventException(TASK_EVENT_FINAL_REQUIRED) - } - - // Only if all checks passed, attempt to update and emit the event - taskStorage.update(event) - _events.emit(FlowEvent.Data(event)) - - when (event) { - is TaskStatusUpdateEvent -> taskSessionType.apply { - taskState = event.status.state - finalEventReceived = event.final + if (event.taskId != this.taskId) { + throw InvalidEventException(INVALID_TASK_ID) } - is Task -> taskSessionType.apply { - taskState = event.status.state - } + taskStorage.update(event) + _events.emit(FlowEvent.Data(event)) + + val isFinalEvent = (event is TaskStatusUpdateEvent && (event.status.state.terminal || event.final)) || + (event is Task && event.status.state.terminal) - is TaskArtifactUpdateEvent -> { - // do nothing, condition is left here for clarity + if (isFinalEvent) { + _isOpen.store(false) } + } else { + throw SessionNotActiveException(SESSION_CLOSED) } } + /** + * Closes the session event processor, also closing event stream. + */ public suspend fun close(): Unit = sessionMutex.withLock { - isClosed.value = true - _events.emit(FlowEvent.Cancel) + _isOpen.store(false) + _events.emit(FlowEvent.Close) } } diff --git a/a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/session/SessionManager.kt b/a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/session/SessionManager.kt index 64750a4361..77b2427f35 100644 --- a/a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/session/SessionManager.kt +++ b/a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/session/SessionManager.kt @@ -5,18 +5,15 @@ import ai.koog.a2a.model.TaskEvent import ai.koog.a2a.server.notifications.PushNotificationConfigStorage import ai.koog.a2a.server.notifications.PushNotificationSender import ai.koog.a2a.server.tasks.TaskStorage +import ai.koog.a2a.utils.KeyedMutex import ai.koog.a2a.utils.RWLock +import ai.koog.a2a.utils.withLock import kotlinx.coroutines.CoroutineScope import kotlinx.coroutines.flow.firstOrNull import kotlinx.coroutines.launch -import kotlinx.coroutines.sync.Mutex -import kotlinx.coroutines.sync.withLock -import kotlin.contracts.ExperimentalContracts -import kotlin.contracts.InvocationKind -import kotlin.contracts.contract /** - * Manages a set of active instances of [Session], sends push notifications if configured after each session completes. + * Manages a set of active instances of [LazySession], sends push notifications if configured after each session completes. * Automatically closes and removes the session when agent job is completed (whether successfully or not). * * Additionally, if push notifications are configured, after each task session completes, push notifications are sent with @@ -25,6 +22,7 @@ import kotlin.contracts.contract * Provides the ability to lock a task id. * * @param coroutineScope The scope in which the monitoring jobs will be launched. + * @param tasksMutex The mutex for locking specific task ids. * @param taskStorage The storage for tasks. * @param pushConfigStorage The storage for push notification configurations. * @param pushSender The push notification sender. @@ -32,6 +30,8 @@ import kotlin.contracts.contract @OptIn(InternalA2AApi::class) public class SessionManager( private val coroutineScope: CoroutineScope, + private val tasksMutex: KeyedMutex, + private val cancelKey: (String) -> String, private val taskStorage: TaskStorage, private val pushConfigStorage: PushNotificationConfigStorage? = null, private val pushSender: PushNotificationSender? = null, @@ -43,9 +43,6 @@ public class SessionManager( private val sessions = mutableMapOf() private val sessionsRwLock = RWLock() - private val taskMutexes = mutableMapOf() - private val taskMutexesLock = Mutex() - /** * Adds a session to a set of active sessions. * Handles cleanup by closing and removing the session when it is completed (whether successfully or not). @@ -57,7 +54,7 @@ public class SessionManager( public suspend fun addSession(session: Session) { sessionsRwLock.withWriteLock { check(session.taskId !in sessions) { - "SessionEventProcessor for taskId '${session.taskId}' already exists." + "Session for taskId '${session.taskId}' already runs." } sessions[session.taskId] = session @@ -71,23 +68,29 @@ public class SessionManager( session.agentJob.join() /* - Check and wait if the task lock is free (e.g., there's a cancellation request for this task running now and still publishing some events). - Then remove it from the sessions map. + Check and wait if there's a cancellation request for this task running now and still publishing some events. + Then remove it from the session map. */ - withTaskLock(session.taskId) { + tasksMutex.withLock(cancelKey(session.taskId)) { sessionsRwLock.withWriteLock { - session.cancel() sessions -= session.taskId + session.cancelAndJoin() } } // Send push notifications with the current state of the task, after the session completion, if configured. - if (firstEvent is TaskEvent && pushSender != null && pushConfigStorage != null) { - val task = taskStorage.get(session.taskId, historyLength = 0, includeArtifacts = false) - - if (task != null) { - pushConfigStorage.getAll(session.taskId).forEach { config -> - pushSender.send(config, task) + coroutineScope.launch { + if (firstEvent is TaskEvent && pushSender != null && pushConfigStorage != null) { + val task = taskStorage.get(session.taskId, historyLength = 0, includeArtifacts = false) + + if (task != null) { + pushConfigStorage.getAll(session.taskId).forEach { config -> + try { + pushSender.send(config, task) + } catch (e: Exception) { + // TODO log error + } + } } } } @@ -107,74 +110,4 @@ public class SessionManager( public suspend fun activeSessions(): Int = sessionsRwLock.withReadLock { sessions.size } - - /** - * Acquires a lock for the specified task ID. - * Useful for maintaining concurrency safety in task-related operations. - * - * @param taskId The unique identifier of the task to be locked. - */ - public suspend fun taskLock(taskId: String) { - val mutex = taskMutexesLock.withLock { - taskMutexes.getOrPut(taskId) { Mutex() } - } - mutex.lock() - } - - /** - * Releases the lock for the specified task ID. - * Useful for maintaining concurrency safety in task-related operations. - * - * @param taskId The unique identifier of the task to be unlocked. - * @throws IllegalStateException if the lock for the task cannot be released. - */ - public suspend fun taskUnlock(taskId: String) { - val mutex = taskMutexesLock.withLock { - taskMutexes[taskId] - } ?: throw IllegalStateException("Task '$taskId' was never locked") - - if (!mutex.isLocked) { - throw IllegalStateException("Task '$taskId' is not currently locked") - } - - mutex.unlock() - - // Clean up unused mutexes - taskMutexesLock.withLock { - if (!mutex.isLocked && taskMutexes[taskId] === mutex) { - taskMutexes.remove(taskId) - } - } - } - - /** - * Returns true if the task ID is locked, false otherwise. - */ - public suspend fun isTaskLocked(taskId: String): Boolean { - return taskMutexesLock.withLock { - taskMutexes[taskId]?.isLocked == true - } - } -} - -/** - * Executes the given block of code while holding a lock for the specified task ID. - * Useful for maintaining concurrency safety in task-related operations. - * - * @param taskId The ID of the task to be locked. - * @param action The block of code to be executed. - * @return The result of [action] - */ -@OptIn(ExperimentalContracts::class) -public suspend inline fun SessionManager.withTaskLock(taskId: String, action: suspend () -> T): T { - contract { - callsInPlace(action, InvocationKind.EXACTLY_ONCE) - } - - taskLock(taskId) - return try { - action() - } finally { - taskUnlock(taskId) - } } diff --git a/a2a/a2a-server/src/commonTest/kotlin/ai/koog/a2a/server/session/SessionEventProcessorTest.kt b/a2a/a2a-server/src/commonTest/kotlin/ai/koog/a2a/server/session/SessionEventProcessorTest.kt index f6421e4b76..462b158f65 100644 --- a/a2a/a2a-server/src/commonTest/kotlin/ai/koog/a2a/server/session/SessionEventProcessorTest.kt +++ b/a2a/a2a-server/src/commonTest/kotlin/ai/koog/a2a/server/session/SessionEventProcessorTest.kt @@ -11,7 +11,7 @@ import ai.koog.a2a.model.TaskStatus import ai.koog.a2a.model.TaskStatusUpdateEvent import ai.koog.a2a.model.TextPart import ai.koog.a2a.server.exceptions.InvalidEventException -import ai.koog.a2a.server.exceptions.SessionClosedException +import ai.koog.a2a.server.exceptions.SessionNotActiveException import ai.koog.a2a.server.tasks.InMemoryTaskStorage import kotlinx.coroutines.flow.lastOrNull import kotlinx.coroutines.launch @@ -22,6 +22,7 @@ import kotlin.test.BeforeTest import kotlin.test.Test import kotlin.test.assertEquals import kotlin.test.assertFailsWith +import kotlin.test.assertFalse import kotlin.test.assertNull import kotlin.time.Duration.Companion.seconds @@ -66,12 +67,10 @@ class SessionEventProcessorTest { private fun createProcessor( contextId: String, taskId: String, - task: Task? = null - ) = SessionEventProcessor( + ): SessionEventProcessor = SessionEventProcessor( contextId = contextId, taskId = taskId, taskStorage = taskStorage, - task = task ) @Test @@ -94,11 +93,10 @@ class SessionEventProcessorTest { processor.sendMessage(message1) - assertFailsWith { + assertFailsWith { processor.sendMessage(message2) } - - processor.close() + assertFalse(processor.isOpen) } @Test @@ -109,11 +107,10 @@ class SessionEventProcessorTest { processor.sendMessage(message) - assertFailsWith { + assertFailsWith { processor.sendTaskEvent(task) } - - processor.close() + assertFalse(processor.isOpen) } // Task session tests @@ -154,7 +151,7 @@ class SessionEventProcessorTest { taskStorage.update(existingTask) // Create processor with existing task - val processor = createProcessor(contextId, taskId, existingTask) + val processor = createProcessor(contextId, taskId) val statusUpdate = TaskStatusUpdateEvent( taskId = taskId, @@ -183,23 +180,6 @@ class SessionEventProcessorTest { processor.close() } - @Test - fun task_testSendNonTaskEventForNewTaskFails() = runTest(timeout = TEST_TIMEOUT) { - val processor = createProcessor(contextId, taskId) - val statusUpdate = TaskStatusUpdateEvent( - taskId = taskId, - contextId = contextId, - status = TaskStatus(state = TaskState.Working), - final = false - ) - - assertFailsWith { - processor.sendTaskEvent(statusUpdate) - } - - processor.close() - } - @Test fun task_testSendEventAfterFinalEventFails() = runTest(timeout = TEST_TIMEOUT) { val processor = createProcessor(contextId, taskId) @@ -223,56 +203,10 @@ class SessionEventProcessorTest { processor.sendTaskEvent(task) processor.sendTaskEvent(finalStatusUpdate) - assertFailsWith { + assertFailsWith { processor.sendTaskEvent(anotherEvent) } - - processor.close() - } - - @Test - fun task_testSendEventWhenTaskInTerminalStateFails() = runTest(timeout = TEST_TIMEOUT) { - // Create task in terminal state - val completedTask = createTask(taskId, contextId, TaskState.Completed) - taskStorage.update(completedTask) - - val processor = createProcessor(contextId, taskId) - - val artifactEvent = TaskArtifactUpdateEvent( - taskId = taskId, - contextId = contextId, - artifact = Artifact( - artifactId = "artifact-1", - parts = listOf(TextPart("content")) - ), - append = false - ) - - assertFailsWith { - processor.sendTaskEvent(artifactEvent) - } - - processor.close() - } - - @Test - fun task_testTerminalStatusUpdateWithoutFinalFlagFails() = runTest(timeout = TEST_TIMEOUT) { - val processor = createProcessor(contextId, taskId) - val task = createTask(taskId, contextId) - val statusUpdate = TaskStatusUpdateEvent( - taskId = taskId, - contextId = contextId, - status = TaskStatus(state = TaskState.Failed), - final = false // This should be true for terminal state - ) - - processor.sendTaskEvent(task) - - assertFailsWith { - processor.sendTaskEvent(statusUpdate) - } - - processor.close() + assertFalse(processor.isOpen) } @Test @@ -407,11 +341,12 @@ class SessionEventProcessorTest { // Close processor and then attempt to send more events processor.close() - assertFailsWith("Should not be possible to send events to closed session") { + assertFailsWith("Should not be possible to send events to closed session") { processor.sendMessage(message2) } assertNull(processor.events.lastOrNull(), "Events stream should be empty after closing") + assertFalse(processor.isOpen) } @Test @@ -432,10 +367,11 @@ class SessionEventProcessorTest { final = false ) - assertFailsWith("Should not be possible to send events to closed session") { + assertFailsWith("Should not be possible to send events to closed session") { processor.sendTaskEvent(workingUpdate) } assertNull(processor.events.lastOrNull(), "Events stream should be empty after closing") + assertFalse(processor.isOpen) } } diff --git a/a2a/a2a-server/src/commonTest/kotlin/ai/koog/a2a/server/session/SessionManagerTest.kt b/a2a/a2a-server/src/commonTest/kotlin/ai/koog/a2a/server/session/SessionManagerTest.kt index 4f04ca3f40..326ea114ae 100644 --- a/a2a/a2a-server/src/commonTest/kotlin/ai/koog/a2a/server/session/SessionManagerTest.kt +++ b/a2a/a2a-server/src/commonTest/kotlin/ai/koog/a2a/server/session/SessionManagerTest.kt @@ -11,20 +11,19 @@ import ai.koog.a2a.model.TextPart import ai.koog.a2a.server.notifications.InMemoryPushNotificationConfigStorage import ai.koog.a2a.server.notifications.PushNotificationSender import ai.koog.a2a.server.tasks.InMemoryTaskStorage +import ai.koog.a2a.utils.KeyedMutex import kotlinx.coroutines.CoroutineScope +import kotlinx.coroutines.currentCoroutineContext import kotlinx.coroutines.delay +import kotlinx.coroutines.job import kotlinx.coroutines.joinAll -import kotlinx.coroutines.launch import kotlinx.coroutines.test.runTest import kotlinx.coroutines.yield import kotlinx.datetime.Instant import kotlin.test.BeforeTest import kotlin.test.Test import kotlin.test.assertEquals -import kotlin.test.assertFailsWith -import kotlin.test.assertFalse import kotlin.test.assertNull -import kotlin.test.assertTrue import kotlin.time.Duration.Companion.seconds class SessionManagerTest { @@ -74,25 +73,25 @@ class SessionManagerTest { contextId = contextId, status = TaskStatus( state = state, - timestamp = Instant.Companion.parse("2023-01-01T10:00:00Z") + timestamp = Instant.parse("2023-01-01T10:00:00Z") ) ) private fun createProcessor( contextId: String, taskId: String, - task: Task? = null ) = SessionEventProcessor( contextId = contextId, taskId = taskId, taskStorage = taskStorage, - task = task ) private fun createManager( coroutineScope: CoroutineScope, ) = SessionManager( coroutineScope = coroutineScope, + cancelKey = { "cancel:$it" }, + tasksMutex = KeyedMutex(), taskStorage = taskStorage, pushConfigStorage = pushConfigStorage, pushSender = pushSender, @@ -100,10 +99,7 @@ class SessionManagerTest { @Test fun testSessionManagerCreation() = runTest(timeout = TEST_TIMEOUT) { - val sessionManager = SessionManager( - coroutineScope = this, - taskStorage = taskStorage - ) + val sessionManager = createManager(this) assertEquals(0, sessionManager.activeSessions()) assertNull(sessionManager.getSession("any-task-id")) @@ -116,7 +112,7 @@ class SessionManagerTest { val message = createMessage("msg-1", contextId, "Hello") - val session = AgentSession( + val session = LazySession( coroutineScope = this, eventProcessor = eventProcessor ) { @@ -125,7 +121,7 @@ class SessionManagerTest { // Start session and wait for completion sessionManager.addSession(session) - session.join() + session.startAndJoin() // Let the session manager process it yield() @@ -139,7 +135,7 @@ class SessionManagerTest { val sessionManager = createManager(this) val eventProcessor = createProcessor(contextId, taskId) - val session = AgentSession( + val session = LazySession( coroutineScope = this, eventProcessor = eventProcessor ) { @@ -166,7 +162,7 @@ class SessionManagerTest { assertEquals(session, sessionManager.getSession(taskId)) - session.join() + session.startAndJoin() // Let the session manager process it yield() @@ -183,7 +179,7 @@ class SessionManagerTest { val eventProcessor1 = createProcessor("context-1", "task-1") val eventProcessor2 = createProcessor("context-2", "task-2") - val session1 = AgentSession( + val session1 = LazySession( coroutineScope = this, eventProcessor = eventProcessor1 ) { @@ -202,7 +198,7 @@ class SessionManagerTest { eventProcessor1.sendTaskEvent(statusUpdate) } - val session2 = AgentSession( + val session2 = LazySession( coroutineScope = this, eventProcessor = eventProcessor2 ) { @@ -256,7 +252,7 @@ class SessionManagerTest { val task = createTask("task-1", contextId) - val session = AgentSession( + val session = LazySession( coroutineScope = this, eventProcessor = eventProcessor ) { @@ -272,10 +268,10 @@ class SessionManagerTest { } sessionManager.addSession(session) - session.join() + session.startAndJoin() - // Let the session manager process it - yield() + // Let all coroutines finish, so that the push notifications background job is definitely completed. + currentCoroutineContext().job.children.toList().joinAll() // Verify push notification was sent assertEquals(1, pushSender.sentNotifications.size) @@ -283,129 +279,4 @@ class SessionManagerTest { assertEquals(config, sentConfig) assertEquals(TaskState.Completed, sentTask.status.state) } - - @Test - fun testTaskLockMultipleTasks() = runTest { - val sessionManager = createManager(this) - - val taskId1 = "test-task-1" - val taskId2 = "test-task-2" - - // Lock both tasks - sessionManager.taskLock(taskId1) - sessionManager.taskLock(taskId2) - - assertTrue(sessionManager.isTaskLocked(taskId1)) - assertTrue(sessionManager.isTaskLocked(taskId2)) - - // Unlock first task - sessionManager.taskUnlock(taskId1) - assertFalse(sessionManager.isTaskLocked(taskId1)) - assertTrue(sessionManager.isTaskLocked(taskId2)) - - // Unlock second task - sessionManager.taskUnlock(taskId2) - assertFalse(sessionManager.isTaskLocked(taskId2)) - } - - @Test - fun testConcurrentTaskLocking() = runTest { - val sessionManager = createManager(this) - val taskId = "concurrent-task" - val results = mutableListOf() - - // First coroutine locks the task - val job1 = launch { - sessionManager.taskLock(taskId) - results.add("job1-locked") - delay(100) // Hold the lock for some time - results.add("job1-working") - sessionManager.taskUnlock(taskId) - results.add("job1-unlocked") - } - - // Second coroutine tries to lock the same task - val job2 = launch { - delay(50) // Start after job1 has locked - results.add("job2-attempting-lock") - sessionManager.taskLock(taskId) // Should wait for job1 to unlock - results.add("job2-locked") - sessionManager.taskUnlock(taskId) - results.add("job2-unlocked") - } - - joinAll(job1, job2) - - // Verify the order of execution - assertEquals( - listOf( - "job1-locked", - "job2-attempting-lock", - "job1-working", - "job1-unlocked", - "job2-locked", - "job2-unlocked" - ), - results - ) - } - - @Test - fun testUnlockNeverLockedTaskThrowsException() = runTest { - val sessionManager = createManager(this) - val taskId = "never-locked-task" - - val exception = assertFailsWith { - sessionManager.taskUnlock(taskId) - } - - assertEquals("Task '$taskId' was never locked", exception.message) - } - - @Test - fun testUnlockAlreadyUnlockedTaskThrowsException() = runTest { - val sessionManager = createManager(this) - val taskId = "already-unlocked-task" - - // Lock and unlock the task - sessionManager.taskLock(taskId) - sessionManager.taskUnlock(taskId) - - // Try to unlock again - val exception = assertFailsWith { - sessionManager.taskUnlock(taskId) - } - - assertEquals("Task '$taskId' was never locked", exception.message) - } - - @Test - fun testSameLockMultipleTimes() = runTest { - val sessionManager = createManager(this) - val taskId = "same-lock-task" - - // First lock - sessionManager.taskLock(taskId) - assertTrue(sessionManager.isTaskLocked(taskId)) - - // Trying to lock the same task again should suspend indefinitely - // We'll test this with a timeout - val job = launch { - sessionManager.taskLock(taskId) // This should suspend - } - - delay(100) // Give some time for the second lock attempt - assertTrue(job.isActive) // Job should still be waiting - - // Unlock the first lock - sessionManager.taskUnlock(taskId) - - // Now the second lock should proceed - job.join() - assertTrue(sessionManager.isTaskLocked(taskId)) - - // Unlock the second lock - sessionManager.taskUnlock(taskId) - assertFalse(sessionManager.isTaskLocked(taskId)) - } } diff --git a/a2a/a2a-server/src/jvmTest/kotlin/ai/koog/a2a/server/A2AServerJsonRpcIntegrationTest.kt b/a2a/a2a-server/src/jvmTest/kotlin/ai/koog/a2a/server/A2AServerJsonRpcIntegrationTest.kt index aaad141b91..4a93328e12 100644 --- a/a2a/a2a-server/src/jvmTest/kotlin/ai/koog/a2a/server/A2AServerJsonRpcIntegrationTest.kt +++ b/a2a/a2a-server/src/jvmTest/kotlin/ai/koog/a2a/server/A2AServerJsonRpcIntegrationTest.kt @@ -6,35 +6,64 @@ import ai.koog.a2a.consts.A2AConsts import ai.koog.a2a.model.AgentCapabilities import ai.koog.a2a.model.AgentCard import ai.koog.a2a.model.AgentSkill +import ai.koog.a2a.model.Message +import ai.koog.a2a.model.MessageSendConfiguration +import ai.koog.a2a.model.MessageSendParams +import ai.koog.a2a.model.Role +import ai.koog.a2a.model.Task +import ai.koog.a2a.model.TaskIdParams +import ai.koog.a2a.model.TaskState +import ai.koog.a2a.model.TaskStatusUpdateEvent +import ai.koog.a2a.model.TextPart import ai.koog.a2a.model.TransportProtocol import ai.koog.a2a.server.notifications.InMemoryPushNotificationConfigStorage import ai.koog.a2a.test.BaseA2AProtocolTest +import ai.koog.a2a.transport.Request import ai.koog.a2a.transport.client.jsonrpc.http.HttpJSONRPCClientTransport import ai.koog.a2a.transport.server.jsonrpc.http.HttpJSONRPCServerTransport +import io.kotest.inspectors.shouldForAll +import io.kotest.inspectors.shouldForAtLeastOne +import io.kotest.matchers.nulls.shouldNotBeNull +import io.kotest.matchers.should +import io.kotest.matchers.shouldBe +import io.kotest.matchers.string.shouldStartWith +import io.kotest.matchers.types.shouldBeInstanceOf import io.ktor.client.HttpClient import io.ktor.client.engine.cio.CIO import io.ktor.client.plugins.logging.LogLevel import io.ktor.client.plugins.logging.Logging import io.ktor.server.netty.Netty +import kotlinx.coroutines.Dispatchers +import kotlinx.coroutines.delay +import kotlinx.coroutines.flow.toList +import kotlinx.coroutines.joinAll +import kotlinx.coroutines.launch import kotlinx.coroutines.runBlocking +import kotlinx.coroutines.test.runTest +import kotlinx.coroutines.withContext import org.junit.jupiter.api.AfterAll import org.junit.jupiter.api.BeforeAll import org.junit.jupiter.api.TestInstance +import java.net.ServerSocket import kotlin.test.BeforeTest +import kotlin.test.Test +import kotlin.time.Duration.Companion.seconds +import kotlin.uuid.ExperimentalUuidApi +import kotlin.uuid.Uuid /** * Integration test class for testing the JSON-RPC HTTP communication in the A2A server context. * This class ensures the proper functioning and correctness of the A2A protocol over HTTP * using the JSON-RPC standard. */ +@OptIn(ExperimentalUuidApi::class) @TestInstance(TestInstance.Lifecycle.PER_CLASS) class A2AServerJsonRpcIntegrationTest : BaseA2AProtocolTest() { + override val testTimeout = 10.seconds - companion object { - private const val TEST_PORT = 9999 - private const val TEST_PATH = "/a2a" - private const val SERVER_URL = "http://localhost:$TEST_PORT$TEST_PATH" - } + private var testPort: Int? = null + private val testPath = "/a2a" + private lateinit var serverUrl: String private lateinit var serverTransport: HttpJSONRPCServerTransport private lateinit var clientTransport: HttpJSONRPCClientTransport @@ -44,6 +73,10 @@ class A2AServerJsonRpcIntegrationTest : BaseA2AProtocolTest() { @BeforeAll fun setup(): Unit = runBlocking { + // Discover and take any free port + testPort = ServerSocket(0).use { it.localPort } + serverUrl = "http://localhost:$testPort$testPath" + // Create agent cards val agentCard = createAgentCard() val agentCardExtended = createExtendedAgentCard() @@ -65,8 +98,8 @@ class A2AServerJsonRpcIntegrationTest : BaseA2AProtocolTest() { // Start server serverTransport.start( engineFactory = Netty, - port = TEST_PORT, - path = TEST_PATH, + port = testPort!!, + path = testPath, wait = false, agentCard = agentCard, agentCardPath = A2AConsts.AGENT_CARD_WELL_KNOWN_PATH, @@ -79,12 +112,12 @@ class A2AServerJsonRpcIntegrationTest : BaseA2AProtocolTest() { } } - clientTransport = HttpJSONRPCClientTransport(SERVER_URL, httpClient) + clientTransport = HttpJSONRPCClientTransport(serverUrl, httpClient) client = A2AClient( transport = clientTransport, agentCardResolver = UrlAgentCardResolver( - baseUrl = SERVER_URL, + baseUrl = serverUrl, path = A2AConsts.AGENT_CARD_WELL_KNOWN_PATH ) ) @@ -184,4 +217,159 @@ class A2AServerJsonRpcIntegrationTest : BaseA2AProtocolTest() { supportsAuthenticatedExtendedCard = true, signatures = null ) + + /** + * Extended test that wouldn't work with Python A2A SDK server, because their implementation has some problems. + * It doesn't send events emitted in the `cancel` method in AgentExecutor to the subscribers of message/stream or tasks/resubscribe. + * But our server implementation should handle it properly. + */ + @Test + fun `test cancel task cancellation events received`() = runTest(timeout = testTimeout) { + // Need real time for this test + withContext(Dispatchers.Default) { + val createTaskRequest = Request( + data = MessageSendParams( + message = Message( + messageId = Uuid.random().toString(), + role = Role.User, + parts = listOf( + TextPart("do long-running task"), + ), + contextId = "test-context", + ), + ), + ) + + val taskId = (client.sendMessage(createTaskRequest).data as Task).id + + joinAll( + launch { + val resubscribeTaskRequest = Request( + data = TaskIdParams( + id = taskId, + ) + ) + + val events = client + .resubscribeTask(resubscribeTaskRequest) + .toList() + .map { it.data } + + // All the same task and context + events.shouldForAll { + it.shouldBeInstanceOf { + it.taskId shouldBe taskId + it.contextId shouldBe "test-context" + } + } + + // Has events from `execute` - task is working + events.shouldForAtLeastOne { + it.shouldBeInstanceOf { + it.status.state shouldBe TaskState.Working + it.status.message shouldNotBeNull { + role shouldBe Role.Agent + + parts.shouldForAll { + it.shouldBeInstanceOf { + it.text shouldStartWith "Still working" + } + } + } + } + } + + // Has events from `cancel` - task is canceled + events.shouldForAtLeastOne { + it.shouldBeInstanceOf { + it.status.state shouldBe TaskState.Canceled + it.status.message shouldNotBeNull { + role shouldBe Role.Agent + parts shouldBe listOf(TextPart("Task canceled")) + } + } + } + }, + launch { + // Let the task run for a while + delay(400) + + val cancelTaskRequest = Request( + data = TaskIdParams( + id = taskId, + ) + ) + + val response = client.cancelTask(cancelTaskRequest) + response.data should { + it.id shouldBe taskId + it.contextId shouldBe "test-context" + it.status should { + it.state shouldBe TaskState.Canceled + it.message shouldNotBeNull { + role shouldBe Role.Agent + parts shouldBe listOf(TextPart("Task canceled")) + } + } + } + } + ) + } + } + + /** + * Another test that doesn't work with Python A2A SDK server because of its implementation problems. + * It's taken from TCK. Follow-up messages to the running task should be supported. + * In case the task is still running, request should wait for a chance to be processed when the task is done. + */ + @Test + fun `test task send follow-up message`() = runTest(timeout = testTimeout) { + fun createRequest( + taskId: String?, + blocking: Boolean, + ) = Request( + data = MessageSendParams( + message = Message( + messageId = Uuid.random().toString(), + role = Role.User, + parts = listOf( + TextPart("do long-running task"), + ), + taskId = taskId, + contextId = "test-context" + ), + configuration = MessageSendConfiguration( + blocking = blocking + ) + ) + ) + + // Create a long-running task and return without waiting + val initialRequest = createRequest(taskId = null, blocking = false) + val initialResponse = client.sendMessage(initialRequest) + + val taskId = initialResponse.data.shouldBeInstanceOf().taskId + + // Immediately send a follow-up message to the same task and wait for the response + val followupRequest = createRequest(taskId = taskId, blocking = true) + val followupResponse = client.sendMessage(followupRequest) + + followupResponse.data.shouldBeInstanceOf { + it.taskId shouldBe taskId + it.contextId shouldBe "test-context" + + it.status should { + it.state shouldBe TaskState.Working + it.message shouldNotBeNull { + role shouldBe Role.Agent + + parts.shouldForAll { + it.shouldBeInstanceOf { + it.text shouldStartWith "Still working" + } + } + } + } + } + } } diff --git a/a2a/a2a-server/src/jvmTest/kotlin/ai/koog/a2a/server/TestAgentExecutor.kt b/a2a/a2a-server/src/jvmTest/kotlin/ai/koog/a2a/server/TestAgentExecutor.kt index 97843c2748..398d2afe53 100644 --- a/a2a/a2a-server/src/jvmTest/kotlin/ai/koog/a2a/server/TestAgentExecutor.kt +++ b/a2a/a2a-server/src/jvmTest/kotlin/ai/koog/a2a/server/TestAgentExecutor.kt @@ -1,3 +1,5 @@ +@file:OptIn(ExperimentalUuidApi::class) + package ai.koog.a2a.server import ai.koog.a2a.model.Message @@ -12,9 +14,12 @@ import ai.koog.a2a.model.TextPart import ai.koog.a2a.server.agent.AgentExecutor import ai.koog.a2a.server.session.RequestContext import ai.koog.a2a.server.session.SessionEventProcessor -import kotlinx.coroutines.Job +import kotlinx.coroutines.Deferred +import kotlinx.coroutines.cancelAndJoin import kotlinx.coroutines.delay import kotlinx.datetime.Clock +import kotlin.uuid.ExperimentalUuidApi +import kotlin.uuid.Uuid private suspend fun sayHello( context: RequestContext, @@ -22,6 +27,7 @@ private suspend fun sayHello( ) { eventProcessor.sendMessage( Message( + messageId = Uuid.random().toString(), role = Role.Agent, parts = listOf(TextPart("Hello World")), contextId = context.contextId, @@ -55,6 +61,7 @@ private suspend fun doTask( status = TaskStatus( state = TaskState.Working, message = Message( + messageId = Uuid.random().toString(), role = Role.Agent, parts = listOf(TextPart("Working on task")), contextId = context.contextId, @@ -74,6 +81,7 @@ private suspend fun doTask( status = TaskStatus( state = TaskState.Completed, message = Message( + messageId = Uuid.random().toString(), role = Role.Agent, parts = listOf(TextPart("Task completed")), contextId = context.contextId, @@ -130,6 +138,7 @@ private suspend fun doLongRunningTask( status = TaskStatus( state = TaskState.Working, message = Message( + messageId = Uuid.random().toString(), role = Role.Agent, parts = listOf(TextPart("Still working $it")), contextId = context.contextId, @@ -171,6 +180,7 @@ class TestAgentExecutor : AgentExecutor { else -> { eventProcessor.sendMessage( Message( + messageId = Uuid.random().toString(), role = Role.Agent, parts = listOf(TextPart("Sorry, I don't understand you")), contextId = context.contextId @@ -183,9 +193,9 @@ class TestAgentExecutor : AgentExecutor { override suspend fun cancel( context: RequestContext, eventProcessor: SessionEventProcessor, - agentJob: Job? + agentJob: Deferred? ) { - agentJob?.cancel() + agentJob?.cancelAndJoin() eventProcessor.sendTaskEvent( TaskStatusUpdateEvent( @@ -194,6 +204,7 @@ class TestAgentExecutor : AgentExecutor { status = TaskStatus( state = TaskState.Canceled, message = Message( + messageId = Uuid.random().toString(), role = Role.Agent, parts = listOf(TextPart("Task canceled")), contextId = context.contextId, diff --git a/a2a/a2a-server/src/jvmTest/resources/logback.xml b/a2a/a2a-server/src/jvmTest/resources/logback.xml new file mode 100644 index 0000000000..24a99c370b --- /dev/null +++ b/a2a/a2a-server/src/jvmTest/resources/logback.xml @@ -0,0 +1,11 @@ + + + + %d{HH:mm:ss.SSS} [%thread] %-5level %logger{36} - %msg%n + + + + + + + diff --git a/a2a/a2a-test/build.gradle.kts b/a2a/a2a-test/build.gradle.kts index 56670dc29f..69bfd6d21c 100644 --- a/a2a/a2a-test/build.gradle.kts +++ b/a2a/a2a-test/build.gradle.kts @@ -11,10 +11,10 @@ kotlin { api(project(":a2a:a2a-client")) api(kotlin("test")) api(kotlin("test-annotations-common")) - api(libs.kotest.assertions) api(libs.kotlinx.coroutines.core) api(libs.kotlinx.coroutines.test) api(libs.kotlinx.serialization.json) + implementation(libs.kotest.assertions) implementation(libs.kotlinx.coroutines.test) } } diff --git a/a2a/a2a-test/src/commonMain/kotlin/ai/koog/a2a/test/BaseA2AProtocolTest.kt b/a2a/a2a-test/src/commonMain/kotlin/ai/koog/a2a/test/BaseA2AProtocolTest.kt index 9e3883470d..d20014b792 100644 --- a/a2a/a2a-test/src/commonMain/kotlin/ai/koog/a2a/test/BaseA2AProtocolTest.kt +++ b/a2a/a2a-test/src/commonMain/kotlin/ai/koog/a2a/test/BaseA2AProtocolTest.kt @@ -33,6 +33,9 @@ import io.kotest.matchers.types.shouldBeInstanceOf import kotlinx.coroutines.flow.toList import kotlinx.coroutines.test.runTest import kotlin.test.Test +import kotlin.time.Duration +import kotlin.uuid.ExperimentalUuidApi +import kotlin.uuid.Uuid /** * Abstract base class containing transport-agnostic A2A protocol compliance tests. @@ -42,8 +45,10 @@ import kotlin.test.Test * * @property client The A2A client instance to test against. Should be connected and ready to use. */ +@OptIn(ExperimentalUuidApi::class) @Suppress("FunctionName") abstract class BaseA2AProtocolTest { + protected abstract val testTimeout: Duration /** * The A2A client instance to test. Must be connected and ready to use. @@ -51,7 +56,7 @@ abstract class BaseA2AProtocolTest { protected abstract var client: A2AClient @Test - fun `test get agent card`() = runTest { + fun `test get agent card`() = runTest(timeout = testTimeout) { val agentCard = client.getAgentCard() // Assert on the full AgentCard structure @@ -96,7 +101,7 @@ abstract class BaseA2AProtocolTest { } @Test - fun `test get authenticated extended agent card`() = runTest { + fun `test get authenticated extended agent card`() = runTest(timeout = testTimeout) { val request = Request(data = null) val response = client.getAuthenticatedExtendedAgentCard(request) @@ -153,10 +158,11 @@ abstract class BaseA2AProtocolTest { } @Test - fun `test send message`() = runTest { + fun `test send message`() = runTest(timeout = testTimeout) { val request = Request( data = MessageSendParams( message = Message( + messageId = Uuid.random().toString(), role = Role.User, parts = listOf( TextPart("hello world"), @@ -180,10 +186,11 @@ abstract class BaseA2AProtocolTest { } @Test - fun `test send message streaming`() = runTest { + fun `test send message streaming`() = runTest(timeout = testTimeout) { val createTaskRequest = Request( data = MessageSendParams( message = Message( + messageId = Uuid.random().toString(), role = Role.User, parts = listOf( TextPart("do task"), @@ -241,10 +248,11 @@ abstract class BaseA2AProtocolTest { } @Test - fun `test get task`() = runTest { + fun `test get task`() = runTest(timeout = testTimeout) { val createTaskRequest = Request( data = MessageSendParams( message = Message( + messageId = Uuid.random().toString(), role = Role.User, parts = listOf( TextPart("do task"), @@ -275,10 +283,11 @@ abstract class BaseA2AProtocolTest { } @Test - fun `test cancel task`() = runTest { + fun `test cancel task`() = runTest(timeout = testTimeout) { val createTaskRequest = Request( data = MessageSendParams( message = Message( + messageId = Uuid.random().toString(), role = Role.User, parts = listOf( TextPart("do cancelable task"), @@ -312,10 +321,11 @@ abstract class BaseA2AProtocolTest { } @Test - fun `test resubscribe task`() = runTest { + fun `test resubscribe task`() = runTest(timeout = testTimeout) { val createTaskRequest = Request( data = MessageSendParams( message = Message( + messageId = Uuid.random().toString(), role = Role.User, parts = listOf( TextPart("do long-running task"), @@ -365,10 +375,11 @@ abstract class BaseA2AProtocolTest { } @Test - fun `test push notification configs`() = runTest { + fun `test push notification configs`() = runTest(timeout = testTimeout) { val createTaskRequest = Request( data = MessageSendParams( message = Message( + messageId = Uuid.random().toString(), role = Role.User, parts = listOf( TextPart("do long-running task"), diff --git a/a2a/a2a-transport/a2a-transport-client-jsonrpc-http/build.gradle.kts b/a2a/a2a-transport/a2a-transport-client-jsonrpc-http/build.gradle.kts index be7918cfff..0dfa2d09db 100644 --- a/a2a/a2a-transport/a2a-transport-client-jsonrpc-http/build.gradle.kts +++ b/a2a/a2a-transport/a2a-transport-client-jsonrpc-http/build.gradle.kts @@ -35,7 +35,7 @@ kotlin { implementation(kotlin("test-junit5")) implementation(libs.mokksy.a2a) implementation(libs.ktor.client.cio) - runtimeOnly(libs.slf4j.simple) + runtimeOnly(libs.logback.classic) } } diff --git a/a2a/a2a-transport/a2a-transport-client-jsonrpc-http/src/commonMain/kotlin/ai/koog/a2a/transport/client/jsonrpc/http/HttpJSONRPCClientTransport.kt b/a2a/a2a-transport/a2a-transport-client-jsonrpc-http/src/commonMain/kotlin/ai/koog/a2a/transport/client/jsonrpc/http/HttpJSONRPCClientTransport.kt index f009adb00a..11ca3edc76 100644 --- a/a2a/a2a-transport/a2a-transport-client-jsonrpc-http/src/commonMain/kotlin/ai/koog/a2a/transport/client/jsonrpc/http/HttpJSONRPCClientTransport.kt +++ b/a2a/a2a-transport/a2a-transport-client-jsonrpc-http/src/commonMain/kotlin/ai/koog/a2a/transport/client/jsonrpc/http/HttpJSONRPCClientTransport.kt @@ -20,6 +20,7 @@ import io.ktor.http.contentType import io.ktor.serialization.kotlinx.json.json import kotlinx.coroutines.flow.Flow import kotlinx.coroutines.flow.flow +import kotlinx.coroutines.flow.map /** * Implementation of a JSON-RPC client transport using HTTP as the underlying communication protocol. @@ -83,13 +84,12 @@ public class HttpJSONRPCClientTransport( setBody(request) } ) { - incoming.collect { event -> - requireNotNull(event.data) { "SSE data must not be null" } - .let { data -> - val response = JSONRPCJson.decodeFromString(data) - emit(response) - } - } + incoming + .map { event -> + requireNotNull(event.data) { "SSE data must not be null" } + .let { data -> JSONRPCJson.decodeFromString(data) } + } + .collect(this@flow) } } diff --git a/a2a/a2a-transport/a2a-transport-client-jsonrpc-http/src/commonTest/kotlin/ai/koog/a2a/transport/client/jsonrpc/http/HttpJSONRPCClientTransportTest.kt b/a2a/a2a-transport/a2a-transport-client-jsonrpc-http/src/commonTest/kotlin/ai/koog/a2a/transport/client/jsonrpc/http/HttpJSONRPCClientTransportTest.kt index f274142806..9a033d8c3a 100644 --- a/a2a/a2a-transport/a2a-transport-client-jsonrpc-http/src/commonTest/kotlin/ai/koog/a2a/transport/client/jsonrpc/http/HttpJSONRPCClientTransportTest.kt +++ b/a2a/a2a-transport/a2a-transport-client-jsonrpc-http/src/commonTest/kotlin/ai/koog/a2a/transport/client/jsonrpc/http/HttpJSONRPCClientTransportTest.kt @@ -28,6 +28,7 @@ import ai.koog.a2a.transport.jsonrpc.model.JSONRPCErrorResponse import ai.koog.a2a.transport.jsonrpc.model.JSONRPCJson import ai.koog.a2a.transport.jsonrpc.model.JSONRPCRequest import ai.koog.a2a.transport.jsonrpc.model.JSONRPCSuccessResponse +import ai.koog.a2a.transport.jsonrpc.model.JSONRPC_VERSION import io.ktor.client.HttpClient import io.ktor.client.engine.mock.MockEngine import io.ktor.client.engine.mock.respond @@ -43,7 +44,10 @@ import kotlin.test.Ignore import kotlin.test.Test import kotlin.test.assertEquals import kotlin.test.fail +import kotlin.uuid.ExperimentalUuidApi +import kotlin.uuid.Uuid +@OptIn(ExperimentalUuidApi::class) class HttpJSONRPCClientTransportTest { private val json = JSONRPCJson @@ -67,7 +71,8 @@ class HttpJSONRPCClientTransportTest { val jsonRpcResponse = JSONRPCSuccessResponse( id = expectedResponse.id, - result = json.encodeToJsonElement(expectedResponse.data) + result = json.encodeToJsonElement(expectedResponse.data), + jsonrpc = JSONRPC_VERSION, ) respond( @@ -130,6 +135,7 @@ class HttpJSONRPCClientTransportTest { val id = RequestId.StringId("test-2") val testMessage = Message( + messageId = Uuid.random().toString(), role = Role.User, parts = listOf(TextPart("Hello, agent!")), taskId = "task-123" @@ -190,12 +196,14 @@ class HttpJSONRPCClientTransportTest { status = TaskStatus( state = TaskState.Working, message = Message( + messageId = Uuid.random().toString(), role = Role.Agent, parts = listOf(TextPart("Working on your request...")) ) ), history = listOf( Message( + messageId = Uuid.random().toString(), role = Role.User, parts = listOf(TextPart("Hello, agent!")), taskId = "task-123" @@ -231,6 +239,7 @@ class HttpJSONRPCClientTransportTest { status = TaskStatus( state = TaskState.Canceled, message = Message( + messageId = Uuid.random().toString(), role = Role.Agent, parts = listOf(TextPart("Task has been canceled.")) ) @@ -390,6 +399,7 @@ class HttpJSONRPCClientTransportTest { val id = RequestId.StringId("test-error-1") val testMessage = Message( + messageId = Uuid.random().toString(), role = Role.User, parts = listOf(TextPart("Hello, agent!")), taskId = "invalid-task-id" @@ -421,7 +431,8 @@ class HttpJSONRPCClientTransportTest { code = A2AErrorCodes.INVALID_PARAMS, message = "Invalid method parameters", data = json.encodeToJsonElement("The message parameters are invalid") - ) + ), + jsonrpc = JSONRPC_VERSION, ) respond( diff --git a/a2a/a2a-transport/a2a-transport-core-jsonrpc/build.gradle.kts b/a2a/a2a-transport/a2a-transport-core-jsonrpc/build.gradle.kts index 591006e83b..dbc55803c5 100644 --- a/a2a/a2a-transport/a2a-transport-core-jsonrpc/build.gradle.kts +++ b/a2a/a2a-transport/a2a-transport-core-jsonrpc/build.gradle.kts @@ -16,6 +16,8 @@ kotlin { api(libs.kotlinx.serialization.json) api(libs.kotlinx.coroutines.core) + + implementation(libs.oshai.kotlin.logging) } } @@ -28,6 +30,7 @@ kotlin { jvmTest { dependencies { implementation(kotlin("test-junit5")) + runtimeOnly(libs.logback.classic) } } diff --git a/a2a/a2a-transport/a2a-transport-core-jsonrpc/src/commonMain/kotlin/ai/koog/a2a/transport/jsonrpc/JSONRPCClientTransport.kt b/a2a/a2a-transport/a2a-transport-core-jsonrpc/src/commonMain/kotlin/ai/koog/a2a/transport/jsonrpc/JSONRPCClientTransport.kt index 9d9f330d37..13ffc238d0 100644 --- a/a2a/a2a-transport/a2a-transport-core-jsonrpc/src/commonMain/kotlin/ai/koog/a2a/transport/jsonrpc/JSONRPCClientTransport.kt +++ b/a2a/a2a-transport/a2a-transport-core-jsonrpc/src/commonMain/kotlin/ai/koog/a2a/transport/jsonrpc/JSONRPCClientTransport.kt @@ -14,6 +14,7 @@ import ai.koog.a2a.model.TaskQueryParams import ai.koog.a2a.transport.ClientCallContext import ai.koog.a2a.transport.ClientTransport import ai.koog.a2a.transport.Request +import ai.koog.a2a.transport.RequestId import ai.koog.a2a.transport.Response import ai.koog.a2a.transport.jsonrpc.model.JSONRPCError import ai.koog.a2a.transport.jsonrpc.model.JSONRPCErrorResponse @@ -21,8 +22,10 @@ import ai.koog.a2a.transport.jsonrpc.model.JSONRPCJson import ai.koog.a2a.transport.jsonrpc.model.JSONRPCRequest import ai.koog.a2a.transport.jsonrpc.model.JSONRPCResponse import ai.koog.a2a.transport.jsonrpc.model.JSONRPCSuccessResponse +import ai.koog.a2a.transport.jsonrpc.model.JSONRPC_VERSION import kotlinx.coroutines.flow.Flow import kotlinx.coroutines.flow.map +import kotlinx.coroutines.flow.onCompletion import kotlinx.serialization.json.decodeFromJsonElement import kotlinx.serialization.json.encodeToJsonElement @@ -54,7 +57,8 @@ public abstract class JSONRPCClientTransport : ClientTransport { return JSONRPCRequest( id = id, method = method.value, - params = JSONRPCJson.encodeToJsonElement(data) + params = JSONRPCJson.encodeToJsonElement(data), + jsonrpc = JSONRPC_VERSION, ) } @@ -71,13 +75,13 @@ public abstract class JSONRPCClientTransport : ClientTransport { ) is JSONRPCErrorResponse -> { - throw error.toA2AException() + throw error.toA2AException(id) } } } - protected fun JSONRPCError.toA2AException(): A2AException { - return createA2AException(message, code) + protected fun JSONRPCError.toA2AException(id: RequestId?): A2AException { + return createA2AException(message, code, id) } /** @@ -105,7 +109,14 @@ public abstract class JSONRPCClientTransport : ClientTransport { val jsonrpcRequest = request.toJSONRPCRequest(method) val jsonrpcResponse = requestStreaming(jsonrpcRequest, ctx) - return jsonrpcResponse.map { it.toResponse() } + return jsonrpcResponse + .map { it.toResponse() } + .onCompletion { thr -> + // Do not let wrap A2A exceptions, propagate them directly + if (thr?.cause is A2AException) { + throw thr.cause!! + } + } } override suspend fun getAuthenticatedExtendedAgentCard( diff --git a/a2a/a2a-transport/a2a-transport-core-jsonrpc/src/commonMain/kotlin/ai/koog/a2a/transport/jsonrpc/JSONRPCServerTransport.kt b/a2a/a2a-transport/a2a-transport-core-jsonrpc/src/commonMain/kotlin/ai/koog/a2a/transport/jsonrpc/JSONRPCServerTransport.kt index c3b1ab6b83..ae015b6026 100644 --- a/a2a/a2a-transport/a2a-transport-core-jsonrpc/src/commonMain/kotlin/ai/koog/a2a/transport/jsonrpc/JSONRPCServerTransport.kt +++ b/a2a/a2a-transport/a2a-transport-core-jsonrpc/src/commonMain/kotlin/ai/koog/a2a/transport/jsonrpc/JSONRPCServerTransport.kt @@ -4,7 +4,9 @@ import ai.koog.a2a.annotations.InternalA2AApi import ai.koog.a2a.exceptions.A2AException import ai.koog.a2a.exceptions.A2AInternalErrorException import ai.koog.a2a.exceptions.A2AInvalidParamsException +import ai.koog.a2a.exceptions.A2AInvalidRequestException import ai.koog.a2a.exceptions.A2AMethodNotFoundException +import ai.koog.a2a.exceptions.A2AParseException import ai.koog.a2a.transport.Request import ai.koog.a2a.transport.RequestId import ai.koog.a2a.transport.Response @@ -16,28 +18,88 @@ import ai.koog.a2a.transport.jsonrpc.model.JSONRPCJson import ai.koog.a2a.transport.jsonrpc.model.JSONRPCRequest import ai.koog.a2a.transport.jsonrpc.model.JSONRPCResponse import ai.koog.a2a.transport.jsonrpc.model.JSONRPCSuccessResponse +import ai.koog.a2a.transport.jsonrpc.model.JSONRPC_VERSION import ai.koog.a2a.utils.runCatchingCancellable +import io.github.oshai.kotlinlogging.KotlinLogging +import kotlinx.coroutines.CancellationException import kotlinx.coroutines.flow.Flow import kotlinx.coroutines.flow.catch import kotlinx.coroutines.flow.flow import kotlinx.coroutines.flow.map import kotlinx.serialization.SerializationException +import kotlinx.serialization.json.JsonNull +import kotlinx.serialization.json.JsonObject +import kotlinx.serialization.json.JsonPrimitive import kotlinx.serialization.json.decodeFromJsonElement import kotlinx.serialization.json.encodeToJsonElement +import kotlinx.serialization.json.jsonPrimitive /** * Abstract transport implementation for JSON-RPC-based server communication. * Handles receiving JSON-RPC requests, processing them, and sending responses. */ public abstract class JSONRPCServerTransport : ServerTransport { + private companion object { + private val logger = KotlinLogging.logger {} + } + /** - * Parses [A2AMethod] from the given [JSONRPCRequest]. - * - * @throws A2AMethodNotFoundException if method is not found. + * Manually parse [raw] string to build a [JSONRPCRequest] while throwing exceptions that A2A TCK excepts, according + * to A2A specification. */ - protected fun parseA2AMethod(request: JSONRPCRequest): A2AMethod { - return A2AMethod.entries.find { it.value == request.method } - ?: throw A2AMethodNotFoundException("Method not found: ${request.method}") + protected fun parseJSONRPCRequest(raw: String): Pair { + val jsonBody = try { + JSONRPCJson.decodeFromString(raw) + } catch (e: SerializationException) { + throw A2AParseException("Cannot parse request body to JSON:\n${e.message}") + } + + // According to A2A TCK, need to parse id early to reply with provided id in error messages + val id = jsonBody["id"]?.let { + try { + JSONRPCJson.decodeFromJsonElement(it) + } catch (e: SerializationException) { + throw A2AInvalidRequestException("Cannot parse request id to JSON-RPC id:\n${e.message}") + } + } + + val a2aMethod = (jsonBody["method"] as? JsonPrimitive) + ?.content + ?.let { + A2AMethod.entries.find { m -> m.value == it } + ?: throw A2AMethodNotFoundException("Method not found: $it", id) + } + ?: throw A2AInvalidRequestException("No method parameter", id) + + val params = jsonBody["params"] + ?.let { + try { + JSONRPCJson + .decodeFromJsonElement(it) + .also { + // According to A2A TCK, empty parameter names are not allowed + if (it.keys.any { it.isEmpty() }) { + throw A2AInvalidParamsException("Empty parameter names are not allowed", id) + } + } + } catch (e: SerializationException) { + throw A2AInvalidParamsException("Cannot parse request params to JSON:\n${e.message}", id) + } + } + + val jsonrpc = jsonBody["jsonrpc"] + ?.jsonPrimitive?.content + ?.takeIf { it == JSONRPC_VERSION } + ?: throw A2AInvalidRequestException("Unsupported JSON-RPC version", id) + + val jsonrpcBody = JSONRPCRequest( + id = id ?: throw A2AInvalidRequestException("No id parameter"), + method = a2aMethod.value, + params = params ?: JsonNull, + jsonrpc = jsonrpc, + ) + + return jsonrpcBody to a2aMethod } /** @@ -129,7 +191,8 @@ public abstract class JSONRPCServerTransport : ServerTransport { protected inline fun Response.toJSONRPCSuccessResponse(): JSONRPCSuccessResponse { return JSONRPCSuccessResponse( id = id, - result = JSONRPCJson.encodeToJsonElement(data) + result = JSONRPCJson.encodeToJsonElement(data), + jsonrpc = JSONRPC_VERSION, ) } @@ -139,13 +202,19 @@ public abstract class JSONRPCServerTransport : ServerTransport { protected fun Throwable.toJSONRPCErrorResponse(requestId: RequestId? = null): JSONRPCErrorResponse { val a2aException: A2AException = when (this) { is A2AException -> this - is Exception -> A2AInternalErrorException("Internal error: ${this.message}") + is CancellationException -> throw this + is Exception -> { + logger.warn(this) { "Non-A2A exception was detected when responding to request [requestId=$requestId]" } + A2AInternalErrorException("Internal error: ${this.message}") + } + else -> throw this // Non-exception throwable shouldn't be handled, rethrowing it } return JSONRPCErrorResponse( - id = requestId, - error = a2aException.toJSONRPCError() + id = requestId ?: a2aException.requestId, // if there's no requestId, use the one from the exception + error = a2aException.toJSONRPCError(), + jsonrpc = JSONRPC_VERSION, ) } diff --git a/a2a/a2a-transport/a2a-transport-core-jsonrpc/src/commonMain/kotlin/ai/koog/a2a/transport/jsonrpc/model/Messages.kt b/a2a/a2a-transport/a2a-transport-core-jsonrpc/src/commonMain/kotlin/ai/koog/a2a/transport/jsonrpc/model/Messages.kt index 8238c2c565..0a87cfbc8a 100644 --- a/a2a/a2a-transport/a2a-transport-core-jsonrpc/src/commonMain/kotlin/ai/koog/a2a/transport/jsonrpc/model/Messages.kt +++ b/a2a/a2a-transport/a2a-transport-core-jsonrpc/src/commonMain/kotlin/ai/koog/a2a/transport/jsonrpc/model/Messages.kt @@ -3,7 +3,6 @@ package ai.koog.a2a.transport.jsonrpc.model import ai.koog.a2a.transport.RequestId -import kotlinx.serialization.EncodeDefault import kotlinx.serialization.Serializable import kotlinx.serialization.json.JsonElement import kotlinx.serialization.json.JsonNull @@ -26,24 +25,21 @@ public data class JSONRPCRequest( public val id: RequestId, val method: String, val params: JsonElement = JsonNull, - @EncodeDefault - override val jsonrpc: String = JSONRPC_VERSION, + override val jsonrpc: String, ) : JSONRPCMessage @Serializable public data class JSONRPCNotification( val method: String, val params: JsonElement = JsonNull, - @EncodeDefault - override val jsonrpc: String = JSONRPC_VERSION, + override val jsonrpc: String, ) : JSONRPCMessage @Serializable public data class JSONRPCSuccessResponse( public val id: RequestId, public val result: JsonElement = JsonNull, - @EncodeDefault - override val jsonrpc: String = JSONRPC_VERSION, + override val jsonrpc: String, ) : JSONRPCResponse @Serializable @@ -57,6 +53,5 @@ public data class JSONRPCError( public data class JSONRPCErrorResponse( public val id: RequestId?, public val error: JSONRPCError, - @EncodeDefault - override val jsonrpc: String = JSONRPC_VERSION, + override val jsonrpc: String, ) : JSONRPCResponse diff --git a/a2a/a2a-transport/a2a-transport-core-jsonrpc/src/commonTest/kotlin/ai/koog/a2a/transport/jsonrpc/model/JsonRpcSerializationTest.kt b/a2a/a2a-transport/a2a-transport-core-jsonrpc/src/commonTest/kotlin/ai/koog/a2a/transport/jsonrpc/model/JsonRpcSerializationTest.kt index 268432d441..87b084e48d 100644 --- a/a2a/a2a-transport/a2a-transport-core-jsonrpc/src/commonTest/kotlin/ai/koog/a2a/transport/jsonrpc/model/JsonRpcSerializationTest.kt +++ b/a2a/a2a-transport/a2a-transport-core-jsonrpc/src/commonTest/kotlin/ai/koog/a2a/transport/jsonrpc/model/JsonRpcSerializationTest.kt @@ -24,6 +24,7 @@ class JsonRpcSerializationTest { val request: JSONRPCMessage = JSONRPCRequest( id = RequestId.NumberId(42), method = "add", + jsonrpc = JSONRPC_VERSION, ) //language=JSON @@ -40,7 +41,8 @@ class JsonRpcSerializationTest { fun testJSONRPCNotification() { val request: JSONRPCMessage = JSONRPCNotification( method = "update", - params = JsonPrimitive("notification-params") + params = JsonPrimitive("notification-params"), + jsonrpc = JSONRPC_VERSION, ) //language=JSON @@ -57,6 +59,7 @@ class JsonRpcSerializationTest { fun testJSONRPCNotificationWithoutParams() { val request: JSONRPCMessage = JSONRPCNotification( method = "notify", + jsonrpc = JSONRPC_VERSION, ) //language=JSON @@ -73,7 +76,8 @@ class JsonRpcSerializationTest { fun testJSONRPCSuccessResponse() { val response: JSONRPCMessage = JSONRPCSuccessResponse( id = RequestId.NumberId(99), - result = JsonPrimitive(100) + result = JsonPrimitive(100), + jsonrpc = JSONRPC_VERSION, ) //language=JSON @@ -90,7 +94,8 @@ class JsonRpcSerializationTest { fun testJSONRPCErrorResponse() { val response: JSONRPCMessage = JSONRPCErrorResponse( id = RequestId.NumberId(123), - error = JSONRPCError(code = -32602, message = "Invalid params") + error = JSONRPCError(code = -32602, message = "Invalid params"), + jsonrpc = JSONRPC_VERSION, ) //language=JSON @@ -107,7 +112,8 @@ class JsonRpcSerializationTest { fun testJSONRPCErrorResponseWithoutId() { val response: JSONRPCMessage = JSONRPCErrorResponse( id = null, - error = JSONRPCError(code = -32700, message = "Parse error") + error = JSONRPCError(code = -32700, message = "Parse error"), + jsonrpc = JSONRPC_VERSION, ) //language=JSON diff --git a/a2a/a2a-transport/a2a-transport-server-jsonrpc-http/src/commonMain/kotlin/ai/koog/a2a/transport/server/jsonrpc/http/HttpJSONRPCServerTransport.kt b/a2a/a2a-transport/a2a-transport-server-jsonrpc-http/src/commonMain/kotlin/ai/koog/a2a/transport/server/jsonrpc/http/HttpJSONRPCServerTransport.kt index 9db317abf9..176f0c65bc 100644 --- a/a2a/a2a-transport/a2a-transport-server-jsonrpc-http/src/commonMain/kotlin/ai/koog/a2a/transport/server/jsonrpc/http/HttpJSONRPCServerTransport.kt +++ b/a2a/a2a-transport/a2a-transport-server-jsonrpc-http/src/commonMain/kotlin/ai/koog/a2a/transport/server/jsonrpc/http/HttpJSONRPCServerTransport.kt @@ -2,8 +2,6 @@ package ai.koog.a2a.transport.server.jsonrpc.http import ai.koog.a2a.annotations.InternalA2AApi import ai.koog.a2a.consts.A2AConsts -import ai.koog.a2a.exceptions.A2AInvalidRequestException -import ai.koog.a2a.exceptions.A2AParseException import ai.koog.a2a.model.AgentCard import ai.koog.a2a.transport.RequestHandler import ai.koog.a2a.transport.ServerCallContext @@ -37,8 +35,6 @@ import io.ktor.sse.ServerSentEvent import io.ktor.util.toMap import kotlinx.coroutines.sync.Mutex import kotlinx.coroutines.sync.withLock -import kotlinx.serialization.SerializationException -import kotlinx.serialization.json.decodeFromJsonElement /** * Implements A2A JSON-RPC server transport over HTTP using Ktor server @@ -197,12 +193,10 @@ public class HttpJSONRPCServerTransport( // Handle incoming JSON-RPC requests, both regular and streaming post { runCatchingCancellable { - val request: JSONRPCRequest = call.receiveJSONRPCRequest() - val ctx: ServerCallContext = call.toServerCallContext() + val (request, a2aMethod) = parseJSONRPCRequest(call.receiveText()) + val ctx = call.toServerCallContext() runCatchingCancellable { - val a2aMethod = parseA2AMethod(request) - if (a2aMethod.streaming) { handleRequestStreaming(request, ctx) } else { @@ -265,24 +259,6 @@ public class HttpJSONRPCServerTransport( call.respond(SSEServerContent(call, handle)) } - /** - * Converts raw request body to [JSONRPCRequest], following A2A specification for error handling. - */ - private suspend fun ApplicationCall.receiveJSONRPCRequest(): JSONRPCRequest { - val jsonBody = try { - val rawBody = receiveText() - JSONRPCJson.parseToJsonElement(rawBody) - } catch (e: SerializationException) { - throw A2AParseException("Cannot parse request body to JSON:\n${e.message}") - } - - return try { - JSONRPCJson.decodeFromJsonElement(jsonBody) - } catch (e: SerializationException) { - throw A2AInvalidRequestException("Cannot parse request params to JSON-RPC request:\n${e.message}") - } - } - private fun ApplicationCall.toServerCallContext(): ServerCallContext { return ServerCallContext( headers = request.headers.toMap() diff --git a/a2a/a2a-transport/a2a-transport-server-jsonrpc-http/src/jvmTest/kotlin/ai/koog/a2a/transport/server/jsonrpc/http/HttpJSONRPCServerTransportTest.kt b/a2a/a2a-transport/a2a-transport-server-jsonrpc-http/src/jvmTest/kotlin/ai/koog/a2a/transport/server/jsonrpc/http/HttpJSONRPCServerTransportTest.kt index fe3dd62cff..b41739fbff 100644 --- a/a2a/a2a-transport/a2a-transport-server-jsonrpc-http/src/jvmTest/kotlin/ai/koog/a2a/transport/server/jsonrpc/http/HttpJSONRPCServerTransportTest.kt +++ b/a2a/a2a-transport/a2a-transport-server-jsonrpc-http/src/jvmTest/kotlin/ai/koog/a2a/transport/server/jsonrpc/http/HttpJSONRPCServerTransportTest.kt @@ -28,6 +28,7 @@ import ai.koog.a2a.transport.jsonrpc.model.JSONRPCErrorResponse import ai.koog.a2a.transport.jsonrpc.model.JSONRPCJson import ai.koog.a2a.transport.jsonrpc.model.JSONRPCRequest import ai.koog.a2a.transport.jsonrpc.model.JSONRPCSuccessResponse +import ai.koog.a2a.transport.jsonrpc.model.JSONRPC_VERSION import io.ktor.client.plugins.sse.sse import io.ktor.client.request.post import io.ktor.client.request.setBody @@ -251,7 +252,8 @@ class HttpJSONRPCServerTransportTest { val jsonRpcRequest = JSONRPCRequest( id = request.id, method = method.value, - params = json.encodeToJsonElement(request.data) + params = json.encodeToJsonElement(request.data), + jsonrpc = JSONRPC_VERSION, ) val response = client.post("/a2a") { @@ -293,7 +295,8 @@ class HttpJSONRPCServerTransportTest { val jsonRpcRequest = JSONRPCRequest( id = request.id, method = method.value, - params = json.encodeToJsonElement(request.data) + params = json.encodeToJsonElement(request.data), + jsonrpc = JSONRPC_VERSION, ) val jsonrpcResponses = buildList { @@ -596,7 +599,8 @@ class HttpJSONRPCServerTransportTest { val jsonRpcRequest = JSONRPCRequest( id = requestId, method = "unknown.method", - params = JsonNull + params = JsonNull, + jsonrpc = JSONRPC_VERSION, ) val response = client.post("/a2a") { diff --git a/a2a/test-tck/.gitignore b/a2a/test-tck/.gitignore new file mode 100644 index 0000000000..b44f8ddb8c --- /dev/null +++ b/a2a/test-tck/.gitignore @@ -0,0 +1,2 @@ +# A2A Testing Kit, should be cloned locally by using setup_tck.sh +/a2a-tck diff --git a/a2a/test-tck/README.md b/a2a/test-tck/README.md new file mode 100644 index 0000000000..b0b863c4e8 --- /dev/null +++ b/a2a/test-tck/README.md @@ -0,0 +1,33 @@ +# A2A Testing Kit Integration + +This directory contains tooling to validate the A2A Kotlin SDK against the +official [A2A protocol specification](https://a2a-protocol.org/latest/specification/) using the A2A Testing Kit (TCK). + +## Contents + +- **`a2a-test-server-tck/`**: Sample A2A server implementation built with Koog SDK for TCK validation +- **`a2a-tck/`**: Official A2A Testing Kit repository (gitignored, should be cloned by `setup_tck.sh`) +- **`setup_tck.sh`**: Clone and setup the A2A Testing Kit +- **`run_sut.sh`**: Run the Kotlin test server (System Under Test) +- **`run_tck.sh`**: Execute TCK tests against the running server + +## Quick start + +1. **Setup the Testing Kit:** + ```bash + ./setup_tck.sh + ``` + +2. **Run the test server:** + ```bash + ./run_sut.sh + ``` + +3. **In another terminal, run the TCK tests:** + ```bash + ./run_tck.sh --sut-url http://localhost:9999/a2a --category all --report + ``` + +## More information + +For more information, see the [A2A Testing Kit repo](https://a2a-protocol.org/latest/tck/). diff --git a/a2a/test-tck/a2a-test-server-tck/build.gradle.kts b/a2a/test-tck/a2a-test-server-tck/build.gradle.kts new file mode 100644 index 0000000000..bd974a53e9 --- /dev/null +++ b/a2a/test-tck/a2a-test-server-tck/build.gradle.kts @@ -0,0 +1,22 @@ +plugins { + id("ai.kotlin.jvm") + alias(libs.plugins.kotlin.serialization) + application +} + +application { + mainClass = "ai.koog.a2a.test.tck.MainKt" +} + +dependencies { + implementation(project(":a2a:a2a-server")) + implementation(project(":a2a:a2a-client")) + implementation(project(":a2a:a2a-transport:a2a-transport-client-jsonrpc-http")) + implementation(project(":a2a:a2a-transport:a2a-transport-server-jsonrpc-http")) + + implementation(libs.ktor.client.cio) + implementation(libs.ktor.client.logging) + implementation(libs.ktor.server.netty) + implementation(libs.oshai.kotlin.logging) + runtimeOnly(libs.logback.classic) +} diff --git a/a2a/test-tck/a2a-test-server-tck/src/main/kotlin/ai/koog/a2a/test/tck/Main.kt b/a2a/test-tck/a2a-test-server-tck/src/main/kotlin/ai/koog/a2a/test/tck/Main.kt new file mode 100644 index 0000000000..df4cbb42d5 --- /dev/null +++ b/a2a/test-tck/a2a-test-server-tck/src/main/kotlin/ai/koog/a2a/test/tck/Main.kt @@ -0,0 +1,114 @@ +package ai.koog.a2a.test.tck + +import ai.koog.a2a.consts.A2AConsts +import ai.koog.a2a.model.AgentCapabilities +import ai.koog.a2a.model.AgentCard +import ai.koog.a2a.model.AgentInterface +import ai.koog.a2a.model.AgentSkill +import ai.koog.a2a.model.AuthorizationCodeOAuthFlow +import ai.koog.a2a.model.HTTPAuthSecurityScheme +import ai.koog.a2a.model.OAuth2SecurityScheme +import ai.koog.a2a.model.OAuthFlows +import ai.koog.a2a.model.TransportProtocol +import ai.koog.a2a.server.A2AServer +import ai.koog.a2a.server.notifications.InMemoryPushNotificationConfigStorage +import ai.koog.a2a.transport.server.jsonrpc.http.HttpJSONRPCServerTransport +import io.github.oshai.kotlinlogging.KotlinLogging +import io.ktor.server.netty.Netty + +private val logger = KotlinLogging.logger {} + +suspend fun main() { + logger.info { "Starting TCK A2A Agent on http://localhost:9999" } + + // Define security schemes + val httpBearerScheme = HTTPAuthSecurityScheme( + scheme = "bearer", + description = "HTTP Bearer token authentication" + ) + + val oauth2Scheme = OAuth2SecurityScheme( + flows = OAuthFlows( + authorizationCode = AuthorizationCodeOAuthFlow( + authorizationUrl = "https://auth.example.com/oauth/authorize", + tokenUrl = "https://auth.example.com/oauth/token", + scopes = mapOf( + "read" to "Read access", + "write" to "Write access" + ) + ) + ), + description = "OAuth 2.0 authentication" + ) + + val securitySchemes = mapOf( + "bearerAuth" to httpBearerScheme, + "oauth2" to oauth2Scheme + ) + + // Create agent card with capabilities and security + val agentCard = AgentCard( + protocolVersion = "0.3.0", + name = "TCK A2A Agent", + description = "A complete A2A agent implementation designed specifically for testing with the A2A Technology Compatibility Kit (TCK)", + version = "1.0.0", + url = "http://localhost:9999/a2a", + preferredTransport = TransportProtocol.JSONRPC, + additionalInterfaces = listOf( + AgentInterface( + url = "http://localhost:9999/a2a", + transport = TransportProtocol.JSONRPC, + ) + ), + capabilities = AgentCapabilities( + streaming = true, + ), + defaultInputModes = listOf("text"), + defaultOutputModes = listOf("text"), + skills = listOf( + AgentSkill( + id = "tck_agent", + name = "TCK Agent", + description = "A complete A2A agent implementation designed for TCK testing", + examples = listOf("hi", "hello world", "how are you", "goodbye"), + tags = listOf("tck", "testing", "core", "complete") + ) + ), + securitySchemes = securitySchemes, + security = listOf( + mapOf("bearerAuth" to emptyList()), + mapOf("oauth2" to listOf("read", "write")) + ), + supportsAuthenticatedExtendedCard = false + ) + + // Create extended agent card (same as basic for testing purposes) + val agentCardExtended = agentCard.copy( + name = "TCK A2A Agent - Extended Edition", + description = "The full-featured A2A agent for authenticated users." + ) + + // Create agent executor + val agentExecutor = TckAgentExecutor() + + // Create A2A server + val a2aServer = A2AServer( + agentExecutor = agentExecutor, + agentCard = agentCard, + agentCardExtended = agentCardExtended, + pushConfigStorage = InMemoryPushNotificationConfigStorage() + ) + + // Create and start server transport + val serverTransport = HttpJSONRPCServerTransport(a2aServer) + + logger.info { "Authentication tests will document SDK gaps with expected failures" } + serverTransport.start( + engineFactory = Netty, + port = 9999, + path = "/a2a", + wait = true, // Block until server stops + agentCard = agentCard, + agentCardPath = A2AConsts.AGENT_CARD_WELL_KNOWN_PATH + ) +} diff --git a/a2a/test-tck/a2a-test-server-tck/src/main/kotlin/ai/koog/a2a/test/tck/TckAgentExecutor.kt b/a2a/test-tck/a2a-test-server-tck/src/main/kotlin/ai/koog/a2a/test/tck/TckAgentExecutor.kt new file mode 100644 index 0000000000..9dd6398680 --- /dev/null +++ b/a2a/test-tck/a2a-test-server-tck/src/main/kotlin/ai/koog/a2a/test/tck/TckAgentExecutor.kt @@ -0,0 +1,229 @@ +package ai.koog.a2a.test.tck + +import ai.koog.a2a.model.Artifact +import ai.koog.a2a.model.Message +import ai.koog.a2a.model.MessageSendParams +import ai.koog.a2a.model.Role +import ai.koog.a2a.model.Task +import ai.koog.a2a.model.TaskArtifactUpdateEvent +import ai.koog.a2a.model.TaskIdParams +import ai.koog.a2a.model.TaskState +import ai.koog.a2a.model.TaskStatus +import ai.koog.a2a.model.TaskStatusUpdateEvent +import ai.koog.a2a.model.TextPart +import ai.koog.a2a.server.agent.AgentExecutor +import ai.koog.a2a.server.session.RequestContext +import ai.koog.a2a.server.session.SessionEventProcessor +import io.ktor.utils.io.CancellationException +import kotlinx.coroutines.Deferred +import kotlinx.coroutines.cancelAndJoin +import kotlinx.coroutines.delay +import kotlinx.datetime.Clock +import kotlin.uuid.ExperimentalUuidApi +import kotlin.uuid.Uuid + +@OptIn(ExperimentalUuidApi::class) +class TckAgentExecutor : AgentExecutor { + override suspend fun execute( + context: RequestContext, + eventProcessor: SessionEventProcessor + ) { + val userMessage = context.params.message + val userInput = userMessage.parts.filterIsInstance() + .joinToString(" ") { it.text } + .lowercase() + + if (userInput.isBlank()) { + eventProcessor.sendMessage( + Message( + messageId = Uuid.random().toString(), + role = Role.Agent, + parts = listOf(TextPart("Hello! Please provide a message for me to respond to.")), + contextId = context.contextId + ) + ) + return + } + + val taskId = context.taskId + + try { + if (context.task == null) { + processNewTask(context, eventProcessor, userMessage, userInput) + } else { + processExistingTask(context, eventProcessor, userMessage, userInput) + } + } catch (e: CancellationException) { + // Propagate cancellation exception + throw e + } catch (e: Exception) { + // Handle errors by marking task as failed + val errorMessage = "Error processing request: ${e.message}" + val failedStatus = TaskStatusUpdateEvent( + contextId = context.contextId, + taskId = taskId, + status = TaskStatus( + state = TaskState.Failed, + message = Message( + messageId = Uuid.random().toString(), + role = Role.Agent, + parts = listOf(TextPart(errorMessage)), + contextId = context.contextId, + taskId = taskId + ), + timestamp = Clock.System.now() + ), + final = true + ) + eventProcessor.sendTaskEvent(failedStatus) + } + } + + override suspend fun cancel( + context: RequestContext, + eventProcessor: SessionEventProcessor, + agentJob: Deferred? + ) { + val taskId = context.taskId + // Cancel the coroutine job if provided + agentJob?.cancelAndJoin() + + // Send cancellation event + val canceledStatus = TaskStatusUpdateEvent( + contextId = context.contextId, + taskId = taskId, + status = TaskStatus( + state = TaskState.Canceled, + timestamp = Clock.System.now() + ), + final = true + ) + eventProcessor.sendTaskEvent(canceledStatus) + } + + private suspend fun processNewTask( + context: RequestContext, + eventProcessor: SessionEventProcessor, + userMessage: Message, + userInput: String, + ) { + val task = Task( + id = context.taskId, + contextId = context.contextId, + status = TaskStatus( + state = TaskState.Submitted, + timestamp = Clock.System.now() + ), + history = listOf(userMessage) + ) + + // Send initial task event + eventProcessor.sendTaskEvent(task) + + // Short delay to allow tests to see submitted state + delay(200) + + // Update to working state + val workingStatus = TaskStatusUpdateEvent( + contextId = context.contextId, + taskId = task.id, + status = TaskStatus( + state = TaskState.Working, + timestamp = Clock.System.now() + ), + final = false + ) + eventProcessor.sendTaskEvent(workingStatus) + + // Short delay for working state + delay(200) + + // Process the request + val result = processUserInput(userInput) + delay(200) // Brief processing delay + + // Create artifact with result + val artifact = Artifact( + artifactId = "response", + parts = listOf(TextPart(result)), + description = "Agent response to user message." + ) + + val artifactEvent = TaskArtifactUpdateEvent( + taskId = task.id, + contextId = context.contextId, + artifact = artifact, + append = false + ) + eventProcessor.sendTaskEvent(artifactEvent) + + // Mark task as completed + val completedStatus = TaskStatusUpdateEvent( + contextId = context.contextId, + taskId = task.id, + status = TaskStatus( + state = TaskState.InputRequired, + message = Message( + messageId = Uuid.random().toString(), + role = Role.Agent, + parts = listOf(TextPart(result)), + contextId = context.contextId, + taskId = task.id + ), + timestamp = Clock.System.now() + ), + final = true + ) + eventProcessor.sendTaskEvent(completedStatus) + } + + private suspend fun processExistingTask( + context: RequestContext, + eventProcessor: SessionEventProcessor, + userMessage: Message, + userInput: String, + ) { + eventProcessor.sendTaskEvent( + TaskStatusUpdateEvent( + taskId = context.taskId, + contextId = context.contextId, + status = TaskStatus( + state = TaskState.Working, + timestamp = Clock.System.now(), + message = userMessage + ), + final = false + ) + ) + + delay(100) + + eventProcessor.sendTaskEvent( + TaskStatusUpdateEvent( + taskId = context.taskId, + contextId = context.contextId, + status = TaskStatus( + state = TaskState.Completed, + timestamp = Clock.System.now(), + message = Message( + messageId = Uuid.random().toString(), + role = Role.Agent, + parts = listOf(TextPart("Task completed successfully!")), + contextId = context.contextId, + taskId = context.taskId + ) + ), + final = true + ) + ) + } + + private fun processUserInput(input: String): String { + return when { + "hello" in input || "hi" in input -> "Hello World! Nice to meet you!" + "how are you" in input -> "I'm doing great! Thanks for asking. How can I help you today?" + "goodbye" in input || "bye" in input -> "Goodbye! Have a wonderful day!" + else -> "Hello World! You said: '$input'. Thanks for your message!" + } + } +} diff --git a/a2a/test-tck/a2a-test-server-tck/src/main/resources/logback.xml b/a2a/test-tck/a2a-test-server-tck/src/main/resources/logback.xml new file mode 100644 index 0000000000..24a99c370b --- /dev/null +++ b/a2a/test-tck/a2a-test-server-tck/src/main/resources/logback.xml @@ -0,0 +1,11 @@ + + + + %d{HH:mm:ss.SSS} [%thread] %-5level %logger{36} - %msg%n + + + + + + + diff --git a/a2a/test-tck/run_sut.sh b/a2a/test-tck/run_sut.sh new file mode 100755 index 0000000000..86701ec414 --- /dev/null +++ b/a2a/test-tck/run_sut.sh @@ -0,0 +1,11 @@ +#!/usr/bin/env bash +set -euo pipefail + +# Always operate relative to this script's directory +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +ROOT_DIR="$(cd "${SCRIPT_DIR}/../.." && pwd)" + +cd "${ROOT_DIR}" + +# Run the SUT via Gradle; delegate any extra CLI args to Gradle +exec ./gradlew :a2a:test-tck:a2a-test-server-tck:run "$@" diff --git a/a2a/test-tck/run_tck.sh b/a2a/test-tck/run_tck.sh new file mode 100755 index 0000000000..e860921c9f --- /dev/null +++ b/a2a/test-tck/run_tck.sh @@ -0,0 +1,17 @@ +#!/usr/bin/env bash +set -euo pipefail + +# Always operate relative to this script's directory +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +REPO_DIR="${SCRIPT_DIR}/a2a-tck" + +if [ ! -d "${REPO_DIR}" ]; then + echo "[run_tck] Expected directory not found: ${REPO_DIR}" + echo "[run_tck] Please run setup_tck.sh first to clone and prepare A2A testing kit project." + exit 1 +fi + +cd "${REPO_DIR}" + +# Delegate all CLI parameters to the underlying script +exec uv run ./run_tck.py "$@" diff --git a/a2a/test-tck/setup_tck.sh b/a2a/test-tck/setup_tck.sh new file mode 100755 index 0000000000..611b1c30e2 --- /dev/null +++ b/a2a/test-tck/setup_tck.sh @@ -0,0 +1,26 @@ +#!/usr/bin/env bash + +set -euo pipefail + +# Resolve the directory where this script resides to always operate relative to it +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +REPO_DIR="${SCRIPT_DIR}/a2a-tck" +REPO_URL="https://github.com/a2aproject/a2a-tck.git" + +# 1) Check if a2a-tck already exists +if [ -d "${REPO_DIR}" ]; then + echo "[setup_tck] 'a2a-tck' directory already exists at: ${REPO_DIR}" + echo "[setup_tck] Will skip cloning but still run 'uv sync'." +else + echo "[setup_tck] 'a2a-tck' directory not found in ${SCRIPT_DIR}." + echo "[setup_tck] Cloning repository into: ${REPO_DIR}" + git clone "${REPO_URL}" "${REPO_DIR}" --depth=1 +fi + +# 2) Always run uv sync in the repo directory +echo "[setup_tck] Running 'uv sync' in ${REPO_DIR}..." +( + cd "${REPO_DIR}" + uv sync --all-packages --all-extras +) +echo "[setup_tck] Done." diff --git a/koog-agents/build.gradle.kts b/koog-agents/build.gradle.kts index 54f589d063..e060385f96 100644 --- a/koog-agents/build.gradle.kts +++ b/koog-agents/build.gradle.kts @@ -11,7 +11,6 @@ val excluded = setOf( ":agents:agents-test", ":agents:agents-ext", ":agents:agents-features:agents-features-sql", // Optional SQL persistence provider - ":a2a:a2a-test", // Testing utilities for A2A protocol compliance ":agents:agents-mcp-server", ":integration-tests", ":test-utils", @@ -28,6 +27,8 @@ val excluded = setOf( ":a2a:a2a-transport:a2a-transport-core-rest", ":a2a:a2a-transport:a2a-transport-server-rest", ":a2a:a2a-transport:a2a-transport-client-rest", + ":a2a:a2a-test", + ":a2a:test-tck:a2a-test-server-tck", project.path, // the current project should not depend on itself ) diff --git a/settings.gradle.kts b/settings.gradle.kts index de8bdfe3a6..248c910e30 100755 --- a/settings.gradle.kts +++ b/settings.gradle.kts @@ -71,6 +71,7 @@ include(":a2a:a2a-transport:a2a-transport-client-jsonrpc-http") include(":a2a:a2a-transport:a2a-transport-core-rest") include(":a2a:a2a-transport:a2a-transport-server-rest") include(":a2a:a2a-transport:a2a-transport-client-rest") +include(":a2a:test-tck:a2a-test-server-tck") include(":koog-spring-boot-starter") From 10d5b1b85f129d3a4ce5f0a6f9e8fcdc4704226a Mon Sep 17 00:00:00 2001 From: Andrey Bragin Date: Mon, 29 Sep 2025 22:56:07 +0200 Subject: [PATCH 34/52] [a2a] Add CORS configuration --- .../a2a-transport-server-jsonrpc-http/build.gradle.kts | 1 + .../server/jsonrpc/http/HttpJSONRPCServerTransport.kt | 6 ++++++ gradle/libs.versions.toml | 1 + 3 files changed, 8 insertions(+) diff --git a/a2a/a2a-transport/a2a-transport-server-jsonrpc-http/build.gradle.kts b/a2a/a2a-transport/a2a-transport-server-jsonrpc-http/build.gradle.kts index de343e79f1..4033523c28 100644 --- a/a2a/a2a-transport/a2a-transport-server-jsonrpc-http/build.gradle.kts +++ b/a2a/a2a-transport/a2a-transport-server-jsonrpc-http/build.gradle.kts @@ -19,6 +19,7 @@ kotlin { implementation(libs.ktor.serialization.kotlinx.json) implementation(libs.ktor.server.content.negotiation) implementation(libs.ktor.server.sse) + implementation(libs.ktor.server.cors) } } diff --git a/a2a/a2a-transport/a2a-transport-server-jsonrpc-http/src/commonMain/kotlin/ai/koog/a2a/transport/server/jsonrpc/http/HttpJSONRPCServerTransport.kt b/a2a/a2a-transport/a2a-transport-server-jsonrpc-http/src/commonMain/kotlin/ai/koog/a2a/transport/server/jsonrpc/http/HttpJSONRPCServerTransport.kt index 176f0c65bc..c32bf736c9 100644 --- a/a2a/a2a-transport/a2a-transport-server-jsonrpc-http/src/commonMain/kotlin/ai/koog/a2a/transport/server/jsonrpc/http/HttpJSONRPCServerTransport.kt +++ b/a2a/a2a-transport/a2a-transport-server-jsonrpc-http/src/commonMain/kotlin/ai/koog/a2a/transport/server/jsonrpc/http/HttpJSONRPCServerTransport.kt @@ -19,6 +19,7 @@ import io.ktor.server.engine.ApplicationEngineFactory import io.ktor.server.engine.EmbeddedServer import io.ktor.server.engine.embeddedServer import io.ktor.server.plugins.contentnegotiation.ContentNegotiation +import io.ktor.server.plugins.cors.routing.CORS import io.ktor.server.request.receiveText import io.ktor.server.response.header import io.ktor.server.response.respond @@ -130,6 +131,11 @@ public class HttpJSONRPCServerTransport( json(JSONRPCJson) } + install(CORS) { + anyHost() + allowNonSimpleContentTypes = true + } + transportRoutes(this, path) if (agentCard != null) { diff --git a/gradle/libs.versions.toml b/gradle/libs.versions.toml index 68c59d1e5d..ca669a3f40 100644 --- a/gradle/libs.versions.toml +++ b/gradle/libs.versions.toml @@ -71,6 +71,7 @@ ktor-server-cio = { module = "io.ktor:ktor-server-cio", version.ref = "ktor3" } ktor-server-netty = { module = "io.ktor:ktor-server-netty-jvm", version.ref = "ktor3" } ktor-server-sse = { module = "io.ktor:ktor-server-sse", version.ref = "ktor3" } ktor-server-content-negotiation = { module = "io.ktor:ktor-server-content-negotiation", version.ref = "ktor3" } +ktor-server-cors = { module = "io.ktor:ktor-server-cors", version.ref = "ktor3" } ktor-server-test-host = { module = "io.ktor:ktor-server-test-host", version.ref = "ktor3" } lettuce-core = { module = "io.lettuce:lettuce-core", version.ref = "lettuce" } logback-classic = { module = "ch.qos.logback:logback-classic", version.ref = "logback" } From b1d5e2b84a4b215edda69253e98d8d908bc4566f Mon Sep 17 00:00:00 2001 From: Andrey Bragin Date: Tue, 30 Sep 2025 00:29:31 +0200 Subject: [PATCH 35/52] [a2a] Add A2A server and A2A client agent features --- .../ai/koog/a2a/transport/ClientTransport.kt | 4 +- .../agents-features-a2a-client/Module.md | 12 + .../build.gradle.kts | 40 +++ .../a2a/client/feature/A2AAgentClient.kt | 111 +++++++ .../a2a/client/feature/A2AAgentClientNodes.kt | 293 ++++++++++++++++++ .../agents-features-a2a-core/Module.md | 7 + .../agents-features-a2a-core/build.gradle.kts | 39 +++ .../ai/koog/agents/a2a/core/Converters.kt | 161 ++++++++++ .../ai/koog/agents/a2a/core/ConvertersTest.kt | 196 ++++++++++++ .../agents-features-a2a-server/Module.md | 17 + .../build.gradle.kts | 40 +++ .../a2a/server/feature/A2AAgentServer.kt | 112 +++++++ .../a2a/server/feature/A2AAgentServerNodes.kt | 195 ++++++++++++ koog-agents/build.gradle.kts | 4 + settings.gradle.kts | 3 + 15 files changed, 1233 insertions(+), 1 deletion(-) create mode 100644 agents/agents-features/agents-features-a2a-client/Module.md create mode 100644 agents/agents-features/agents-features-a2a-client/build.gradle.kts create mode 100644 agents/agents-features/agents-features-a2a-client/src/commonMain/kotlin/ai/koog/agents/a2a/client/feature/A2AAgentClient.kt create mode 100644 agents/agents-features/agents-features-a2a-client/src/commonMain/kotlin/ai/koog/agents/a2a/client/feature/A2AAgentClientNodes.kt create mode 100644 agents/agents-features/agents-features-a2a-core/Module.md create mode 100644 agents/agents-features/agents-features-a2a-core/build.gradle.kts create mode 100644 agents/agents-features/agents-features-a2a-core/src/commonMain/kotlin/ai/koog/agents/a2a/core/Converters.kt create mode 100644 agents/agents-features/agents-features-a2a-core/src/commonTest/kotlin/ai/koog/agents/a2a/core/ConvertersTest.kt create mode 100644 agents/agents-features/agents-features-a2a-server/Module.md create mode 100644 agents/agents-features/agents-features-a2a-server/build.gradle.kts create mode 100644 agents/agents-features/agents-features-a2a-server/src/commonMain/kotlin/ai/koog/agents/a2a/server/feature/A2AAgentServer.kt create mode 100644 agents/agents-features/agents-features-a2a-server/src/commonMain/kotlin/ai/koog/agents/a2a/server/feature/A2AAgentServerNodes.kt diff --git a/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/transport/ClientTransport.kt b/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/transport/ClientTransport.kt index 6d1b5d1a89..ca1821c373 100644 --- a/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/transport/ClientTransport.kt +++ b/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/transport/ClientTransport.kt @@ -11,6 +11,7 @@ import ai.koog.a2a.model.TaskPushNotificationConfig import ai.koog.a2a.model.TaskPushNotificationConfigParams import ai.koog.a2a.model.TaskQueryParams import kotlinx.coroutines.flow.Flow +import kotlinx.serialization.Serializable import kotlinx.serialization.SerializationException /** @@ -131,7 +132,8 @@ public interface ClientTransport : AutoCloseable { * * @property additionalHeaders Additional call-specific headers associated with the call. */ -public class ClientCallContext( +@Serializable +public data class ClientCallContext( public val additionalHeaders: Map> = emptyMap(), ) { @Suppress("MissingKDocForPublicAPI") diff --git a/agents/agents-features/agents-features-a2a-client/Module.md b/agents/agents-features/agents-features-a2a-client/Module.md new file mode 100644 index 0000000000..985dc62807 --- /dev/null +++ b/agents/agents-features/agents-features-a2a-client/Module.md @@ -0,0 +1,12 @@ +# Module agents-features-a2a-client + +Agent feature for enabling A2A client capabilities within Koog agents. + +## Overview + +This module provides an agent feature that enables Koog agents to act as A2A clients, allowing them to communicate with remote A2A-enabled agents. When installed, agents gain access to a registry of A2A clients and pre-built nodes for sending messages, managing tasks, and subscribing to events from other A2A agents. + +## Key Components + +- **`A2AAgentClient`**: Feature providing access to registered A2A clients from agent nodes +- **Agent Nodes**: Pre-built nodes for common operations (messaging, task management, push notifications) diff --git a/agents/agents-features/agents-features-a2a-client/build.gradle.kts b/agents/agents-features/agents-features-a2a-client/build.gradle.kts new file mode 100644 index 0000000000..2af1909021 --- /dev/null +++ b/agents/agents-features/agents-features-a2a-client/build.gradle.kts @@ -0,0 +1,40 @@ +import ai.koog.gradle.publish.maven.Publishing.publishToMaven + +group = rootProject.group +version = rootProject.version + +plugins { + id("ai.kotlin.multiplatform") + alias(libs.plugins.kotlin.serialization) +} + +kotlin { + sourceSets { + commonMain { + dependencies { + api(project(":agents:agents-core")) + api(project(":a2a:a2a-client")) + api(project(":agents:agents-features:agents-features-a2a-core")) + + api(libs.kotlinx.serialization.json) + } + } + + commonTest { + dependencies { + implementation(kotlin("test")) + implementation(libs.kotlinx.coroutines.test) + } + } + + jvmTest { + dependencies { + implementation(kotlin("test-junit5")) + } + } + } + + explicitApi() +} + +publishToMaven() diff --git a/agents/agents-features/agents-features-a2a-client/src/commonMain/kotlin/ai/koog/agents/a2a/client/feature/A2AAgentClient.kt b/agents/agents-features/agents-features-a2a-client/src/commonMain/kotlin/ai/koog/agents/a2a/client/feature/A2AAgentClient.kt new file mode 100644 index 0000000000..d2af12c29c --- /dev/null +++ b/agents/agents-features/agents-features-a2a-client/src/commonMain/kotlin/ai/koog/agents/a2a/client/feature/A2AAgentClient.kt @@ -0,0 +1,111 @@ +package ai.koog.agents.a2a.client.feature + +import ai.koog.a2a.client.A2AClient +import ai.koog.agents.core.agent.context.AIAgentContext +import ai.koog.agents.core.agent.entity.AIAgentStorageKey +import ai.koog.agents.core.agent.entity.createStorageKey +import ai.koog.agents.core.feature.AIAgentGraphFeature +import ai.koog.agents.core.feature.AIAgentGraphPipeline +import ai.koog.agents.core.feature.AIAgentNonGraphFeature +import ai.koog.agents.core.feature.AIAgentNonGraphPipeline +import ai.koog.agents.core.feature.config.FeatureConfig +import kotlin.contracts.ExperimentalContracts +import kotlin.contracts.InvocationKind +import kotlin.contracts.contract + +/** + * Agent feature that enables A2A client mode by providing access to registered A2A clients + * from within agent strategies. + * + * This feature allows agents to communicate with other A2A-enabled agents by making + * [A2AClient] instances available to agent nodes. When installed, it provides access to + * a registry of A2A clients, allowing agent nodes to: + * - Send messages to remote A2A agents + * - Retrieve agent cards and capabilities + * - Manage tasks on remote agents + * - Subscribe to task events and streaming responses + * - Configure push notifications + * + * The feature provides convenience nodes for common A2A client + * operations like sending messages, retrieving tasks, and managing subscriptions. + * + * @property a2aClients Map of A2A clients keyed by agent ID + * + * @see ai.koog.a2a.client.A2AClient + */ +public class A2AAgentClient( + public val a2aClients: Map +) { + /** + * Configuration for the [A2AAgentClient] feature. + */ + public class Config : FeatureConfig() { + /** + * Map of [A2AClient] instances keyed by agent ID for accessing remote A2A agents. + */ + public val a2aClients: Map = mapOf() + } + + public companion object Feature : + AIAgentGraphFeature, + AIAgentNonGraphFeature { + + override val key: AIAgentStorageKey = + createStorageKey("agents-features-a2a-client") + + override fun createInitialConfig(): Config = Config() + + override fun install( + config: Config, + pipeline: AIAgentGraphPipeline + ) { + pipeline.interceptContextAgentFeature(this) { _ -> + A2AAgentClient(config.a2aClients) + } + } + + override fun install( + config: Config, + pipeline: AIAgentNonGraphPipeline + ) { + pipeline.interceptContextAgentFeature(this) { + A2AAgentClient(config.a2aClients) + } + } + } +} + +/** + * Retrieves the [A2AAgentClient] feature from the agent context. + * + * @return The installed A2AAgentClient feature + * @throws IllegalStateException if the feature is not installed + */ +public fun AIAgentContext.a2aAgentClient(): A2AAgentClient = featureOrThrow(A2AAgentClient.Feature) + +/** + * Executes an action with the [A2AAgentClient] feature as the receiver. + * This is a convenience function that retrieves the feature and provides it as the receiver for the action block. + * + * @param action The action to execute with A2AAgentClient as receiver + * @return The result of the action + * @throws IllegalStateException if the feature is not installed + */ +@OptIn(ExperimentalContracts::class) +public inline fun AIAgentContext.withA2AAgentClient(action: A2AAgentClient.() -> T): T { + contract { + callsInPlace(action, InvocationKind.AT_MOST_ONCE) + } + + return a2aAgentClient().action() +} + +/** + * Retrieves an A2A client by agent ID or throws if not found. + * + * @param agentId The identifier of the A2A agent to retrieve + * @return The A2AClient instance for the specified agent ID + * @throws NoSuchElementException if no client is registered with the given agent ID + */ +public fun A2AAgentClient.a2aClientOrThrow(agentId: String): A2AClient = + a2aClients[agentId] ?: throw NoSuchElementException("A2A agent with id $agentId not found in the current agent context. Make sure to register it in the A2AAgentClient feature.") diff --git a/agents/agents-features/agents-features-a2a-client/src/commonMain/kotlin/ai/koog/agents/a2a/client/feature/A2AAgentClientNodes.kt b/agents/agents-features/agents-features-a2a-client/src/commonMain/kotlin/ai/koog/agents/a2a/client/feature/A2AAgentClientNodes.kt new file mode 100644 index 0000000000..c40a0bf065 --- /dev/null +++ b/agents/agents-features/agents-features-a2a-client/src/commonMain/kotlin/ai/koog/agents/a2a/client/feature/A2AAgentClientNodes.kt @@ -0,0 +1,293 @@ +package ai.koog.agents.a2a.client.feature + +import ai.koog.a2a.model.AgentCard +import ai.koog.a2a.model.CommunicationEvent +import ai.koog.a2a.model.Event +import ai.koog.a2a.model.MessageSendParams +import ai.koog.a2a.model.Task +import ai.koog.a2a.model.TaskIdParams +import ai.koog.a2a.model.TaskPushNotificationConfig +import ai.koog.a2a.model.TaskPushNotificationConfigParams +import ai.koog.a2a.model.TaskQueryParams +import ai.koog.a2a.transport.ClientCallContext +import ai.koog.a2a.transport.Request +import ai.koog.a2a.transport.Response +import ai.koog.agents.core.dsl.builder.AIAgentBuilderDslMarker +import ai.koog.agents.core.dsl.builder.AIAgentNodeDelegate +import ai.koog.agents.core.dsl.builder.AIAgentSubgraphBuilderBase +import kotlinx.coroutines.flow.Flow +import kotlinx.serialization.Serializable + +/** + * Request parameters for A2A client operations that require call context and typed parameters. + * + * @property agentId The identifier of the A2A agent to send the request to, with which it is registered in [A2AAgentClient]. + * @property callContext The [io.ktor.server.application.CallContext] + * @property params The typed parameters for the specific A2A operation + */ +@Serializable +public data class A2AClientRequest( + val agentId: String, + val callContext: ClientCallContext, + val params: T +) + +/** + * Information about a registered A2A agent client. + * + * @property agentId The identifier with which the agent is registered in [A2AAgentClient] + * @property agentCard The cached agent card for this agent + */ +public data class A2AClientAgentInfo( + val agentId: String, + val agentCard: AgentCard, +) + +/** + * Creates a node that retrieves information about all A2A agents registered in [A2AAgentClient]. + * + * @param name Optional node name for debugging and tracing + * @return A node that returns a list of all registered agents with their cached agent cards + */ +@AIAgentBuilderDslMarker +public fun AIAgentSubgraphBuilderBase<*, *>.nodeA2AClientGetAllAgents( + name: String? = null, +): AIAgentNodeDelegate> = + node(name) { + withA2AAgentClient { + a2aClients.map { (agentId, a2aClient) -> + A2AClientAgentInfo( + agentId = agentId, + agentCard = a2aClient.cachedAgentCard() + ) + } + } + } + +/** + * Creates a node that retrieves an agent card from an A2A server. + * Input is an agent id with which [ai.koog.a2a.client.A2AClient] is registered in [A2AAgentClient]. + * + * @see ai.koog.a2a.client.A2AClient.getAgentCard + */ +@AIAgentBuilderDslMarker +public fun AIAgentSubgraphBuilderBase<*, *>.nodeA2AClientGetAgentCard( + name: String? = null, +): AIAgentNodeDelegate = + node(name) { agentId -> + withA2AAgentClient { + a2aClientOrThrow(agentId).getAgentCard() + } + } + +/** + * Creates a node that retrieves the cached agent card without making a network call. + * Input is an agent id with which [ai.koog.a2a.client.A2AClient] is registered in [A2AAgentClient]. + * + * @see ai.koog.a2a.client.A2AClient.cachedAgentCard + */ +@AIAgentBuilderDslMarker +public fun AIAgentSubgraphBuilderBase<*, *>.nodeA2AClientCachedAgentCard( + name: String? = null, +): AIAgentNodeDelegate = + node(name) { agentId -> + withA2AAgentClient { + a2aClientOrThrow(agentId).cachedAgentCard() + } + } + +/** + * Creates a node that retrieves an authenticated extended agent card. + * + * @see ai.koog.a2a.client.A2AClient.getAuthenticatedExtendedAgentCard + */ +@AIAgentBuilderDslMarker +public fun AIAgentSubgraphBuilderBase<*, *>.nodeA2AClientGetAuthenticatedExtendedAgentCard( + name: String? = null, +): AIAgentNodeDelegate, Response> = + node(name) { request -> + withA2AAgentClient { + a2aClientOrThrow(request.agentId) + .getAuthenticatedExtendedAgentCard( + request = Request(data = request.params), + ctx = request.callContext, + ) + } + } + +/** + * Creates a node that sends a message to an A2A agent. + * + * @see ai.koog.a2a.client.A2AClient.sendMessage + */ +@AIAgentBuilderDslMarker +public fun AIAgentSubgraphBuilderBase<*, *>.nodeA2AClientSendMessage( + name: String? = null, +): AIAgentNodeDelegate, CommunicationEvent> = + node(name) { request -> + withA2AAgentClient { + a2aClientOrThrow(request.agentId) + .sendMessage( + request = Request(data = request.params), + ctx = request.callContext, + ) + .data + } + } + +/** + * Creates a node that sends a message to an A2A agent with streaming response. + * + * @see ai.koog.a2a.client.A2AClient.sendMessageStreaming + */ +@AIAgentBuilderDslMarker +public fun AIAgentSubgraphBuilderBase<*, *>.nodeA2AClientSendMessageStreaming( + name: String? = null, +): AIAgentNodeDelegate, Flow>> = + node(name) { request -> + withA2AAgentClient { + a2aClientOrThrow(request.agentId) + .sendMessageStreaming( + request = Request(data = request.params), + ctx = request.callContext, + ) + } + } + +/** + * Creates a node that retrieves a task by ID from an A2A agent. + * + * @see ai.koog.a2a.client.A2AClient.getTask + */ +@AIAgentBuilderDslMarker +public fun AIAgentSubgraphBuilderBase<*, *>.nodeA2AClientGetTask( + name: String? = null, +): AIAgentNodeDelegate, Task> = + node(name) { request -> + withA2AAgentClient { + a2aClientOrThrow(request.agentId) + .getTask( + request = Request(data = request.params), + ctx = request.callContext, + ) + .data + } + } + +/** + * Creates a node that cancels a task on an A2A agent. + * + * @see ai.koog.a2a.client.A2AClient.cancelTask + */ +@AIAgentBuilderDslMarker +public fun AIAgentSubgraphBuilderBase<*, *>.nodeA2AClientCancelTask( + name: String? = null, +): AIAgentNodeDelegate, Task> = + node(name) { request -> + withA2AAgentClient { + a2aClientOrThrow(request.agentId) + .cancelTask( + request = Request(data = request.params), + ctx = request.callContext, + ) + .data + } + } + +/** + * Creates a node that resubscribes to task events from an A2A agent. + * + * @see ai.koog.a2a.client.A2AClient.resubscribeTask + */ +@AIAgentBuilderDslMarker +public fun AIAgentSubgraphBuilderBase<*, *>.nodeA2AClientResubscribeTask( + name: String? = null, +): AIAgentNodeDelegate, Flow>> = + node(name) { request -> + withA2AAgentClient { + a2aClientOrThrow(request.agentId) + .resubscribeTask( + request = Request(data = request.params), + ctx = request.callContext, + ) + } + } + +/** + * Creates a node that sets push notification configuration for a task. + * + * @see ai.koog.a2a.client.A2AClient.setTaskPushNotificationConfig + */ +@AIAgentBuilderDslMarker +public fun AIAgentSubgraphBuilderBase<*, *>.nodeA2AClientSetTaskPushNotificationConfig( + name: String? = null, +): AIAgentNodeDelegate, TaskPushNotificationConfig> = + node(name) { request -> + withA2AAgentClient { + a2aClientOrThrow(request.agentId) + .setTaskPushNotificationConfig( + request = Request(data = request.params), + ctx = request.callContext, + ) + .data + } + } + +/** + * Creates a node that retrieves push notification configuration for a task. + * + * @see ai.koog.a2a.client.A2AClient.getTaskPushNotificationConfig + */ +@AIAgentBuilderDslMarker +public fun AIAgentSubgraphBuilderBase<*, *>.nodeA2AClientGetTaskPushNotificationConfig( + name: String? = null, +): AIAgentNodeDelegate, TaskPushNotificationConfig> = + node(name) { request -> + withA2AAgentClient { + a2aClientOrThrow(request.agentId) + .getTaskPushNotificationConfig( + request = Request(data = request.params), + ctx = request.callContext, + ) + .data + } + } + +/** + * Creates a node that lists all push notification configurations for a task. + * + * @see ai.koog.a2a.client.A2AClient.listTaskPushNotificationConfig + */ +@AIAgentBuilderDslMarker +public fun AIAgentSubgraphBuilderBase<*, *>.nodeA2AClientListTaskPushNotificationConfig( + name: String? = null, +): AIAgentNodeDelegate, List> = + node(name) { request -> + withA2AAgentClient { + a2aClientOrThrow(request.agentId) + .listTaskPushNotificationConfig( + request = Request(data = request.params), + ctx = request.callContext, + ) + .data + } + } + +/** + * Creates a node that deletes push notification configuration for a task. + * + * @see ai.koog.a2a.client.A2AClient.deleteTaskPushNotificationConfig + */ +@AIAgentBuilderDslMarker +public fun AIAgentSubgraphBuilderBase<*, *>.nodeA2AClientDeleteTaskPushNotificationConfig( + name: String? = null, +): AIAgentNodeDelegate, Unit> = + node(name) { request -> + withA2AAgentClient { + a2aClientOrThrow(request.agentId) + .deleteTaskPushNotificationConfig( + request = Request(data = request.params), + ctx = request.callContext, + ) + } + } diff --git a/agents/agents-features/agents-features-a2a-core/Module.md b/agents/agents-features/agents-features-a2a-core/Module.md new file mode 100644 index 0000000000..b40aaa5b5b --- /dev/null +++ b/agents/agents-features/agents-features-a2a-core/Module.md @@ -0,0 +1,7 @@ +# Module agents-features-a2a-core + +Core utilities and type converters for A2A (Agent-to-Agent) integration with Koog agents. + +## Overview + +This module provides type converters that bridge between A2A's message format and Koog's internal message representation, enabling seamless communication between A2A-enabled agents and Koog's agent system. diff --git a/agents/agents-features/agents-features-a2a-core/build.gradle.kts b/agents/agents-features/agents-features-a2a-core/build.gradle.kts new file mode 100644 index 0000000000..a2edcb5c55 --- /dev/null +++ b/agents/agents-features/agents-features-a2a-core/build.gradle.kts @@ -0,0 +1,39 @@ +import ai.koog.gradle.publish.maven.Publishing.publishToMaven + +group = rootProject.group +version = rootProject.version + +plugins { + id("ai.kotlin.multiplatform") + alias(libs.plugins.kotlin.serialization) +} + +kotlin { + sourceSets { + commonMain { + dependencies { + api(project(":prompt:prompt-model")) + api(project(":a2a:a2a-core")) + + api(libs.kotlinx.serialization.json) + } + } + + commonTest { + dependencies { + implementation(kotlin("test")) + implementation(libs.kotlinx.coroutines.test) + } + } + + jvmTest { + dependencies { + implementation(kotlin("test-junit5")) + } + } + } + + explicitApi() +} + +publishToMaven() diff --git a/agents/agents-features/agents-features-a2a-core/src/commonMain/kotlin/ai/koog/agents/a2a/core/Converters.kt b/agents/agents-features/agents-features-a2a-core/src/commonMain/kotlin/ai/koog/agents/a2a/core/Converters.kt new file mode 100644 index 0000000000..a3f8eef5c1 --- /dev/null +++ b/agents/agents-features/agents-features-a2a-core/src/commonMain/kotlin/ai/koog/agents/a2a/core/Converters.kt @@ -0,0 +1,161 @@ +package ai.koog.agents.a2a.core + +import ai.koog.a2a.model.DataPart +import ai.koog.a2a.model.FilePart +import ai.koog.a2a.model.FileWithBytes +import ai.koog.a2a.model.FileWithUri +import ai.koog.a2a.model.Part +import ai.koog.a2a.model.Role +import ai.koog.a2a.model.TextPart +import ai.koog.prompt.message.Attachment +import ai.koog.prompt.message.AttachmentContent +import ai.koog.prompt.message.Message +import ai.koog.prompt.message.RequestMetaInfo +import ai.koog.prompt.message.ResponseMetaInfo +import kotlinx.datetime.Clock +import kotlinx.serialization.json.Json +import kotlinx.serialization.json.JsonObject +import kotlin.uuid.ExperimentalUuidApi +import kotlin.uuid.Uuid + +/** + * Alias to A2A message type, to avoid clashing with Koog's message type. + * @see [ai.koog.a2a.model.Message] + */ +public typealias A2AMessage = ai.koog.a2a.model.Message + +private val json = Json { + prettyPrint = true +} + +/** + * Converts [A2AMessage] to Koog's [Message]. + * + * @param clock The clock to use for the timestamp. Defaults to [Clock.System]. + */ +public fun A2AMessage.toKoogMessage( + clock: Clock = Clock.System, +): Message { + val content = StringBuilder() + val attachments = mutableListOf() + + // Put ids information in the text content, since Koog doesn't have special fields for them. + contextId?.let { + content.appendLine("Context ID: $it") + } + + taskId?.let { + content.appendLine("Task ID: $it") + } + + referenceTaskIds?.forEach { + content.appendLine("Reference Task ID: $it") + } + + parts.forEach { part -> + when (part) { + is TextPart -> content.appendLine(part.text) + // Koog doesn't support structured data as a separate type, just append it to the content. + is DataPart -> content.appendLine(json.encodeToString(part.data)) + is FilePart -> { + val file = part.file + + val attachment = Attachment.File( + // do not have that information separately in A2A + format = "", + // if no mime type is provided, assume it's arbitrary binary data + mimeType = file.mimeType ?: "application/octet-stream", + fileName = file.name, + content = when (file) { + is FileWithBytes -> AttachmentContent.Binary.Base64(file.bytes) + is FileWithUri -> AttachmentContent.URL(file.uri) + } + ) + + attachments.add(attachment) + } + } + } + + return when (role) { + Role.User -> Message.User( + content = content.toString(), + metaInfo = RequestMetaInfo( + timestamp = clock.now() + ), + attachments = attachments.toList(), + ) + + Role.Agent -> Message.Assistant( + content = content.toString(), + metaInfo = ResponseMetaInfo( + timestamp = clock.now() + ), + attachments = attachments, + ) + } +} + +/** + * Converts Koog's [Message] to [A2AMessage]. + * + * @see ai.koog.a2a.model.Message + */ +@OptIn(ExperimentalUuidApi::class) +public fun Message.toA2AMessage( + messageId: String = Uuid.random().toString(), + contextId: String? = null, + taskId: String? = null, + referenceTaskIds: List? = null, + metadata: JsonObject? = null, + extensions: List? = null, +): A2AMessage { + val role = when (this) { + is Message.User -> Role.User + is Message.Assistant -> Role.Agent + else -> throw IllegalArgumentException("A2A can't handle this Koog message type: $this") + } + + val parts = mutableListOf() + + // Add content + parts.add(TextPart(content)) + + // Add attachments + attachments.forEach { attachment -> + val file = when (val content = attachment.content) { + // Plain text files are not supported, convert them to binary files. + is AttachmentContent.PlainText -> FileWithBytes( + bytes = AttachmentContent.Binary.Bytes(content.text.encodeToByteArray()) + .asBase64(), + name = attachment.fileName, + mimeType = attachment.mimeType, + ) + + is AttachmentContent.Binary -> FileWithBytes( + bytes = content.asBase64(), + name = attachment.fileName, + mimeType = attachment.mimeType, + ) + + is AttachmentContent.URL -> FileWithUri( + uri = content.url, + name = attachment.fileName, + mimeType = attachment.mimeType, + ) + } + + parts.add(FilePart(file)) + } + + return A2AMessage( + messageId = messageId, + role = role, + parts = parts, + extensions = extensions, + taskId = taskId, + referenceTaskIds = referenceTaskIds, + contextId = contextId, + metadata = metadata + ) +} diff --git a/agents/agents-features/agents-features-a2a-core/src/commonTest/kotlin/ai/koog/agents/a2a/core/ConvertersTest.kt b/agents/agents-features/agents-features-a2a-core/src/commonTest/kotlin/ai/koog/agents/a2a/core/ConvertersTest.kt new file mode 100644 index 0000000000..9dd6c59a69 --- /dev/null +++ b/agents/agents-features/agents-features-a2a-core/src/commonTest/kotlin/ai/koog/agents/a2a/core/ConvertersTest.kt @@ -0,0 +1,196 @@ +package ai.koog.agents.a2a.core + +import ai.koog.a2a.model.DataPart +import ai.koog.a2a.model.FilePart +import ai.koog.a2a.model.FileWithBytes +import ai.koog.a2a.model.FileWithUri +import ai.koog.a2a.model.Role +import ai.koog.a2a.model.TextPart +import ai.koog.prompt.message.Attachment +import ai.koog.prompt.message.AttachmentContent +import ai.koog.prompt.message.Message +import ai.koog.prompt.message.RequestMetaInfo +import ai.koog.prompt.message.ResponseMetaInfo +import kotlinx.datetime.Clock +import kotlinx.datetime.Instant +import kotlinx.serialization.json.Json +import kotlinx.serialization.json.buildJsonObject +import kotlinx.serialization.json.put +import kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.assertFailsWith + +class ConvertersTest { + + private val fixedInstant: Instant = Instant.parse("2024-01-01T00:00:00Z") + private val fixedClock: Clock = object : Clock { + override fun now(): Instant = fixedInstant + } + + private val prettyJson = Json { prettyPrint = true } + + @Test + fun testA2AtoKoog_User_withTextDataAndFiles_fullObjectEquality() { + val json = buildJsonObject { put("k", "v") } + val bytesBase64 = "YmFzZTY0" // arbitrary base64 string + + val a2a = A2AMessage( + messageId = "m1", + role = Role.User, + parts = listOf( + TextPart("Hello"), + DataPart(json), + FilePart(FileWithBytes(bytes = bytesBase64, name = "file.bin", mimeType = null)), + FilePart(FileWithUri(uri = "https://example.com/doc.txt", name = "doc.txt", mimeType = "text/plain")), + ), + contextId = "ctx-123", + taskId = "task-1", + referenceTaskIds = listOf("ref-1", "ref-2"), + extensions = listOf("ext:a"), + ) + + val actual: Message = a2a.toKoogMessage(clock = fixedClock) + + val expectedContent = buildString { + appendLine("Context ID: ctx-123") + appendLine("Task ID: task-1") + appendLine("Reference Task ID: ref-1") + appendLine("Reference Task ID: ref-2") + appendLine("Hello") + appendLine(prettyJson.encodeToString(json)) + } + val expectedAttachments = listOf( + Attachment.File( + format = "", + mimeType = "application/octet-stream", + fileName = "file.bin", + content = AttachmentContent.Binary.Base64(bytesBase64) + ), + Attachment.File( + format = "", + mimeType = "text/plain", + fileName = "doc.txt", + content = AttachmentContent.URL("https://example.com/doc.txt") + ) + ) + val expected: Message = Message.User( + content = expectedContent, + metaInfo = RequestMetaInfo(timestamp = fixedInstant), + attachments = expectedAttachments + ) + + assertEquals(expected, actual) + } + + @Test + fun testA2AtoKoog_Agent_fullObjectEquality() { + val a2a = A2AMessage( + messageId = "m2", + role = Role.Agent, + parts = listOf(TextPart("Agent says hi")), + ) + + val actual = a2a.toKoogMessage(clock = fixedClock) + + val expected = Message.Assistant( + content = buildString { appendLine("Agent says hi") }, + metaInfo = ResponseMetaInfo(timestamp = fixedInstant), + attachments = emptyList() + ) + + assertEquals(expected, actual) + } + + @Test + fun testKoogToA2A_User_withPlainTextBinaryAndUrlAttachments_fullObjectEquality() { + val plain = Attachment.File( + content = AttachmentContent.PlainText("abc"), + format = "txt", + mimeType = "text/plain", + fileName = "note.txt", + ) + val bytes = byteArrayOf(1, 2, 3) + val bin = Attachment.File( + content = AttachmentContent.Binary.Bytes(bytes), + format = "bin", + mimeType = "application/octet-stream", + fileName = "bytes.bin", + ) + val url = Attachment.File( + content = AttachmentContent.URL("https://example.com/a.png"), + format = "png", + mimeType = "image/png", + fileName = "a.png", + ) + + val koog: Message = Message.User( + content = "Hi", + metaInfo = RequestMetaInfo(timestamp = fixedInstant), + attachments = listOf(plain, bin, url) + ) + + val actual = koog.toA2AMessage( + messageId = "mid", + contextId = "ctx", + taskId = "task", + referenceTaskIds = listOf("r1"), + ) + + val expectedPlainBase64 = AttachmentContent.Binary.Bytes("abc".encodeToByteArray()).asBase64() + val expectedBinBase64 = AttachmentContent.Binary.Bytes(bytes).asBase64() + val expected = A2AMessage( + messageId = "mid", + role = Role.User, + parts = listOf( + TextPart("Hi"), + FilePart(FileWithBytes(bytes = expectedPlainBase64, name = "note.txt", mimeType = "text/plain")), + FilePart( + FileWithBytes( + bytes = expectedBinBase64, + name = "bytes.bin", + mimeType = "application/octet-stream" + ) + ), + FilePart(FileWithUri(uri = "https://example.com/a.png", name = "a.png", mimeType = "image/png")), + ), + extensions = null, + taskId = "task", + referenceTaskIds = listOf("r1"), + contextId = "ctx", + metadata = null, + ) + + assertEquals(expected, actual) + } + + @Test + fun testKoogToA2A_Assistant_fullObjectEquality() { + val koog: Message = Message.Assistant( + content = "Answer", + metaInfo = ResponseMetaInfo(timestamp = fixedInstant), + ) + val actual = koog.toA2AMessage(messageId = "m3") + val expected = A2AMessage( + messageId = "m3", + role = Role.Agent, + parts = listOf(TextPart("Answer")), + extensions = null, + taskId = null, + referenceTaskIds = null, + contextId = null, + metadata = null, + ) + assertEquals(expected, actual) + } + + @Test + fun testKoogToA2A_unsupportedKoogMessageThrows() { + val sys: Message = Message.System( + content = "system", + metaInfo = RequestMetaInfo(timestamp = fixedInstant) + ) + assertFailsWith { + sys.toA2AMessage(messageId = "m4") + } + } +} diff --git a/agents/agents-features/agents-features-a2a-server/Module.md b/agents/agents-features/agents-features-a2a-server/Module.md new file mode 100644 index 0000000000..5e97a68301 --- /dev/null +++ b/agents/agents-features/agents-features-a2a-server/Module.md @@ -0,0 +1,17 @@ +# Module agents-features-a2a-server + +Agent feature for enabling A2A server capabilities within Koog agents. + +## Overview + +This module provides an agent feature that enables Koog agents to act as A2A servers, allowing them to receive and process requests from A2A clients. When installed, agents can access the A2A request context and event processor, enabling them to handle incoming messages, manage tasks, interact with storage, and send events back to clients. + +## Key Components + +- **`A2AAgentServer`**: Feature providing access to A2A request context and event processor +- **Agent Nodes**: Pre-built nodes for messaging, task events, and storage operations + +## Related Modules + +- `agents-features-a2a-client`: Client-side A2A agent feature +- `agents-features-a2a-core`: Core A2A utilities and converters diff --git a/agents/agents-features/agents-features-a2a-server/build.gradle.kts b/agents/agents-features/agents-features-a2a-server/build.gradle.kts new file mode 100644 index 0000000000..28fa55136e --- /dev/null +++ b/agents/agents-features/agents-features-a2a-server/build.gradle.kts @@ -0,0 +1,40 @@ +import ai.koog.gradle.publish.maven.Publishing.publishToMaven + +group = rootProject.group +version = rootProject.version + +plugins { + id("ai.kotlin.multiplatform") + alias(libs.plugins.kotlin.serialization) +} + +kotlin { + sourceSets { + commonMain { + dependencies { + api(project(":agents:agents-core")) + api(project(":a2a:a2a-server")) + api(project(":agents:agents-features:agents-features-a2a-core")) + + api(libs.kotlinx.serialization.json) + } + } + + commonTest { + dependencies { + implementation(kotlin("test")) + implementation(libs.kotlinx.coroutines.test) + } + } + + jvmTest { + dependencies { + implementation(kotlin("test-junit5")) + } + } + } + + explicitApi() +} + +publishToMaven() diff --git a/agents/agents-features/agents-features-a2a-server/src/commonMain/kotlin/ai/koog/agents/a2a/server/feature/A2AAgentServer.kt b/agents/agents-features/agents-features-a2a-server/src/commonMain/kotlin/ai/koog/agents/a2a/server/feature/A2AAgentServer.kt new file mode 100644 index 0000000000..ad1a757887 --- /dev/null +++ b/agents/agents-features/agents-features-a2a-server/src/commonMain/kotlin/ai/koog/agents/a2a/server/feature/A2AAgentServer.kt @@ -0,0 +1,112 @@ +package ai.koog.agents.a2a.server.feature + +import ai.koog.a2a.model.MessageSendParams +import ai.koog.a2a.server.session.RequestContext +import ai.koog.a2a.server.session.SessionEventProcessor +import ai.koog.agents.core.agent.context.AIAgentContext +import ai.koog.agents.core.agent.entity.AIAgentStorageKey +import ai.koog.agents.core.agent.entity.createStorageKey +import ai.koog.agents.core.feature.AIAgentGraphFeature +import ai.koog.agents.core.feature.AIAgentGraphPipeline +import ai.koog.agents.core.feature.AIAgentNonGraphFeature +import ai.koog.agents.core.feature.AIAgentNonGraphPipeline +import ai.koog.agents.core.feature.config.FeatureConfig +import kotlin.contracts.ExperimentalContracts +import kotlin.contracts.InvocationKind +import kotlin.contracts.contract + +/** + * Agent feature that enables A2A server mode by providing access to the request context and event processor + * from within agent strategies. + * + * This feature is designed to be used within agents that are hosted as A2A servers via [ai.koog.a2a.server.agent.AgentExecutor]. + * When installed, it makes the [context] and [eventProcessor] available to agent nodes, allowing them to: + * - Access incoming A2A messages and task information + * - Send outgoing messages and task events back to the client + * - Interact with message and task storage + * - Manage the A2A session lifecycle + * + * The feature provides convenience nodes for common A2A operations like + * sending messages, updating task status, and accessing storage. + * + * @property context The A2A [RequestContext] from [ai.koog.a2a.server.agent.AgentExecutor.execute] + * @property eventProcessor The A2A [SessionEventProcessor] from [ai.koog.a2a.server.agent.AgentExecutor.execute] + * + * @see ai.koog.a2a.server.agent.AgentExecutor + * @see ai.koog.a2a.server.session.RequestContext + * @see ai.koog.a2a.server.session.SessionEventProcessor + */ +public class A2AAgentServer( + public val context: RequestContext, + public val eventProcessor: SessionEventProcessor +) { + /** + * Configuration for the [A2AAgentServer] feature. + */ + public class Config : FeatureConfig() { + /** + * The A2A [RequestContext] from [ai.koog.a2a.server.agent.AgentExecutor.execute] + * @see RequestContext + */ + public lateinit var context: RequestContext + + /** + * The A2A [SessionEventProcessor] from [ai.koog.a2a.server.agent.AgentExecutor.execute] + * @see SessionEventProcessor + */ + public lateinit var eventProcessor: SessionEventProcessor + } + + public companion object Feature : + AIAgentGraphFeature, + AIAgentNonGraphFeature { + + override val key: AIAgentStorageKey = + createStorageKey("agents-features-a2a-server") + + override fun createInitialConfig(): Config = Config() + + override fun install( + config: Config, + pipeline: AIAgentGraphPipeline + ) { + pipeline.interceptContextAgentFeature(this) { _ -> + A2AAgentServer(config.context, config.eventProcessor) + } + } + + override fun install( + config: Config, + pipeline: AIAgentNonGraphPipeline + ) { + pipeline.interceptContextAgentFeature(this) { + A2AAgentServer(config.context, config.eventProcessor) + } + } + } +} + +/** + * Retrieves the [A2AAgentServer] feature from the agent context. + * + * @return The installed A2AAgentExecutor feature + * @throws IllegalStateException if the feature is not installed + */ +public fun AIAgentContext.a2aAgentServer(): A2AAgentServer = featureOrThrow(A2AAgentServer.Feature) + +/** + * Executes an action with the [A2AAgentServer] feature as the receiver. + * This is a convenience function that retrieves the feature and provides it as the receiver for the action block. + * + * @param action The action to execute with A2AAgentExecutor as receiver + * @return The result of the action + * @throws IllegalStateException if the feature is not installed + */ +@OptIn(ExperimentalContracts::class) +public inline fun AIAgentContext.withA2AAgentServer(action: A2AAgentServer.() -> T): T { + contract { + callsInPlace(action, InvocationKind.AT_MOST_ONCE) + } + + return a2aAgentServer().action() +} diff --git a/agents/agents-features/agents-features-a2a-server/src/commonMain/kotlin/ai/koog/agents/a2a/server/feature/A2AAgentServerNodes.kt b/agents/agents-features/agents-features-a2a-server/src/commonMain/kotlin/ai/koog/agents/a2a/server/feature/A2AAgentServerNodes.kt new file mode 100644 index 0000000000..de2a5e5086 --- /dev/null +++ b/agents/agents-features/agents-features-a2a-server/src/commonMain/kotlin/ai/koog/agents/a2a/server/feature/A2AAgentServerNodes.kt @@ -0,0 +1,195 @@ +package ai.koog.agents.a2a.server.feature + +import ai.koog.a2a.model.Task +import ai.koog.a2a.model.TaskEvent +import ai.koog.agents.a2a.core.A2AMessage +import ai.koog.agents.core.dsl.builder.AIAgentBuilderDslMarker +import ai.koog.agents.core.dsl.builder.AIAgentNodeDelegate +import ai.koog.agents.core.dsl.builder.AIAgentSubgraphBuilderBase +import kotlinx.serialization.Serializable + +/** + * Creates a node that sends an A2A message back to the client. + * + * @param name Optional node name for debugging and tracing + * @param saveToStorage If true, also saves the message to storage before sending + * @return A node that sends the message and passes it through unchanged + * @see ai.koog.a2a.server.session.SessionEventProcessor.sendMessage + * @see ai.koog.a2a.server.messages.MessageStorage.save + */ +@AIAgentBuilderDslMarker +public fun AIAgentSubgraphBuilderBase<*, *>.nodeA2ARespondMessage( + name: String? = null, + saveToStorage: Boolean = false, +): AIAgentNodeDelegate = + node(name) { message -> + withA2AAgentServer { + eventProcessor.sendMessage(message) + if (saveToStorage) { + context.messageStorage.save(message) + } + } + + message + } + +/** + * Creates a node that sends a task event (status update, creation, etc.) to the client. + * + * @param name Optional node name for debugging and tracing + * @return A node that sends the task event and passes it through unchanged + * @see ai.koog.a2a.server.session.SessionEventProcessor.sendTaskEvent + */ +@AIAgentBuilderDslMarker +public fun AIAgentSubgraphBuilderBase<*, *>.nodeA2ARespondTaskEvent( + name: String? = null, +): AIAgentNodeDelegate = + node(name) { event -> + withA2AAgentServer { + eventProcessor.sendTaskEvent(event) + } + + event + } + +/** + * Creates a node that saves a message to storage without sending it to the client. + * + * @param name Optional node name for debugging and tracing + * @return A node that saves the message and passes it through unchanged + * @see ai.koog.a2a.server.messages.MessageStorage.save + */ +@AIAgentBuilderDslMarker +public fun AIAgentSubgraphBuilderBase<*, *>.nodeA2AMessageStorageSave( + name: String? = null, +): AIAgentNodeDelegate = + node(name) { event -> + withA2AAgentServer { + context.messageStorage.save(event) + } + + event + } + +/** + * Creates a node that loads all messages from storage for the current context. + * + * @param name Optional node name for debugging and tracing + * @return A node that returns the list of all stored messages + * @see ai.koog.a2a.server.messages.MessageStorage.getByContext + */ +@AIAgentBuilderDslMarker +public fun AIAgentSubgraphBuilderBase<*, *>.nodeA2AMessageStorageLoad( + name: String? = null, +): AIAgentNodeDelegate> = + node(name) { + withA2AAgentServer { + context.messageStorage.getAll() + } + } + +/** + * Creates a node that replaces all messages in storage for the current context. + * + * @param name Optional node name for debugging and tracing + * @return A node that replaces all stored messages with the provided list + * @see ai.koog.a2a.server.messages.MessageStorage.replaceByContext + */ +@AIAgentBuilderDslMarker +public fun AIAgentSubgraphBuilderBase<*, *>.nodeA2AMessageStorageReplace( + name: String? = null, +): AIAgentNodeDelegate, Unit> = + node(name) { + withA2AAgentServer { + context.messageStorage.replaceAll(it) + } + } + +/** + * Creates a node that loads all messages from storage for the current context. + * + * This is an alias for [nodeA2AMessageStorageLoad]. + * + * @param name Optional node name for debugging and tracing + * @return A node that returns the list of all stored messages + * @see ai.koog.a2a.server.messages.MessageStorage.getByContext + */ +@AIAgentBuilderDslMarker +public fun AIAgentSubgraphBuilderBase<*, *>.nodeA2AMessagesContextLoad( + name: String? = null, +): AIAgentNodeDelegate> = + node(name) { + withA2AAgentServer { + context.messageStorage.getAll() + } + } + +/** + * Parameters for retrieving a single task from storage. + * + * @property taskId The unique task identifier + * @property historyLength Maximum number of messages to include in conversation history. Set to `null` for all messages + * @property includeArtifacts Whether to include task artifacts in the response + */ +@Serializable +public data class A2ATaskGetParams( + val taskId: String, + val historyLength: Int? = 0, + val includeArtifacts: Boolean = false, +) + +/** + * Creates a node that retrieves a single task by ID from storage. + * + * @param name Optional node name for debugging and tracing + * @return A node that returns the task or null if not found + * @see ai.koog.a2a.server.tasks.TaskStorage.get + */ +@AIAgentBuilderDslMarker +public fun AIAgentSubgraphBuilderBase<*, *>.nodeA2ATaskGet( + name: String? = null, +): AIAgentNodeDelegate = + node(name) { params -> + withA2AAgentServer { + context.taskStorage.get( + taskId = params.taskId, + historyLength = params.historyLength, + includeArtifacts = params.includeArtifacts + ) + } + } + +/** + * Parameters for retrieving multiple tasks from storage. + * + * @property taskIds List of task identifiers to retrieve + * @property historyLength Maximum number of messages to include in conversation history. Set to `null` for all messages + * @property includeArtifacts Whether to include task artifacts in the response + */ +@Serializable +public data class A2ATaskGetAllParams( + val taskIds: List, + val historyLength: Int? = 0, + val includeArtifacts: Boolean = false, +) + +/** + * Creates a node that retrieves multiple tasks by their IDs from storage. + * + * @param name Optional node name for debugging and tracing + * @return A node that returns the list of found tasks (may be fewer than requested) + * @see ai.koog.a2a.server.tasks.TaskStorage.getAll + */ +@AIAgentBuilderDslMarker +public fun AIAgentSubgraphBuilderBase<*, *>.nodeA2ATaskGetAll( + name: String? = null, +): AIAgentNodeDelegate> = + node(name) { params -> + withA2AAgentServer { + context.taskStorage.getAll( + taskIds = params.taskIds, + historyLength = params.historyLength, + includeArtifacts = params.includeArtifacts + ) + } + } diff --git a/koog-agents/build.gradle.kts b/koog-agents/build.gradle.kts index e060385f96..fb861d712a 100644 --- a/koog-agents/build.gradle.kts +++ b/koog-agents/build.gradle.kts @@ -30,6 +30,10 @@ val excluded = setOf( ":a2a:a2a-test", ":a2a:test-tck:a2a-test-server-tck", + ":agents:agents-features:agents-features-a2a-core", + ":agents:agents-features:agents-features-a2a-server", + ":agents:agents-features:agents-features-a2a-client", + project.path, // the current project should not depend on itself ) diff --git a/settings.gradle.kts b/settings.gradle.kts index 248c910e30..563ed90936 100755 --- a/settings.gradle.kts +++ b/settings.gradle.kts @@ -17,6 +17,9 @@ include(":agents:agents-features:agents-features-sql") include(":agents:agents-features:agents-features-trace") include(":agents:agents-features:agents-features-tokenizer") include(":agents:agents-features:agents-features-snapshot") +include(":agents:agents-features:agents-features-a2a-core") +include(":agents:agents-features:agents-features-a2a-server") +include(":agents:agents-features:agents-features-a2a-client") include(":agents:agents-mcp") include(":agents:agents-mcp-server") From 489b7d34c1197a7f5d8c01458eaf6c238799f427 Mon Sep 17 00:00:00 2001 From: Andrey Bragin Date: Tue, 30 Sep 2025 15:02:49 +0200 Subject: [PATCH 36/52] [a2a] Fix compilation issues and build configuration --- a2a/a2a-client/build.gradle.kts | 27 ++- .../client/A2AClientJsonRpcIntegrationTest.kt | 7 +- .../server/A2AServerJsonRpcIntegrationTest.kt | 12 +- .../ai/koog/a2a/test/BaseA2AProtocolTest.kt | 0 .../agents-features-a2a-core/build.gradle.kts | 1 + .../ai/koog/agents/a2a/core/Converters.kt | 161 ------------------ .../koog/agents/a2a/core/MessageConverters.kt | 110 ++++++++++++ .../ai/koog/agents/a2a/core/PartConverters.kt | 114 +++++++++++++ ...ertersTest.kt => MessageConvertersTest.kt} | 36 ++-- 9 files changed, 287 insertions(+), 181 deletions(-) rename a2a/a2a-test/src/{commonMain => jvmMain}/kotlin/ai/koog/a2a/test/BaseA2AProtocolTest.kt (100%) delete mode 100644 agents/agents-features/agents-features-a2a-core/src/commonMain/kotlin/ai/koog/agents/a2a/core/Converters.kt create mode 100644 agents/agents-features/agents-features-a2a-core/src/commonMain/kotlin/ai/koog/agents/a2a/core/MessageConverters.kt create mode 100644 agents/agents-features/agents-features-a2a-core/src/commonMain/kotlin/ai/koog/agents/a2a/core/PartConverters.kt rename agents/agents-features/agents-features-a2a-core/src/commonTest/kotlin/ai/koog/agents/a2a/core/{ConvertersTest.kt => MessageConvertersTest.kt} (87%) diff --git a/a2a/a2a-client/build.gradle.kts b/a2a/a2a-client/build.gradle.kts index 4c0a68f204..efc7cd1a31 100644 --- a/a2a/a2a-client/build.gradle.kts +++ b/a2a/a2a-client/build.gradle.kts @@ -1,4 +1,7 @@ import ai.koog.gradle.publish.maven.Publishing.publishToMaven +import org.gradle.internal.os.OperatingSystem +import org.gradle.kotlin.dsl.support.serviceOf +import java.io.ByteArrayOutputStream group = rootProject.group version = rootProject.version @@ -59,8 +62,30 @@ tasks.register("dockerBuildTestPythonA2AServer") { description = "Build Python A2A test server image" workingDir = file("../test-python-a2a-server") commandLine = listOf("docker", "build", "-t", "test-python-a2a-server", ".") -} + onlyIf { + // do not attempt to check for docker on windows + if (OperatingSystem.current().isWindows) { + return@onlyIf false + } + + try { + val buffer = ByteArrayOutputStream() + + serviceOf().exec { + commandLine = listOf("docker", "--version") + standardOutput = buffer + errorOutput = buffer + } + + true + } catch (_: Exception) { + logger.warn("Docker not available. Skipping task 'dockerBuildTestPythonA2AServer'") + + false + } + } +} tasks.named("jvmTest") { dependsOn("dockerBuildTestPythonA2AServer") } diff --git a/a2a/a2a-client/src/jvmTest/kotlin/ai/koog/a2a/client/A2AClientJsonRpcIntegrationTest.kt b/a2a/a2a-client/src/jvmTest/kotlin/ai/koog/a2a/client/A2AClientJsonRpcIntegrationTest.kt index cb95bc4618..c0a850a882 100644 --- a/a2a/a2a-client/src/jvmTest/kotlin/ai/koog/a2a/client/A2AClientJsonRpcIntegrationTest.kt +++ b/a2a/a2a-client/src/jvmTest/kotlin/ai/koog/a2a/client/A2AClientJsonRpcIntegrationTest.kt @@ -9,11 +9,13 @@ import kotlinx.coroutines.test.runTest import org.junit.jupiter.api.AfterAll import org.junit.jupiter.api.BeforeAll import org.junit.jupiter.api.TestInstance +import org.junit.jupiter.api.condition.EnabledOnOs +import org.junit.jupiter.api.condition.OS import org.testcontainers.containers.GenericContainer import org.testcontainers.containers.wait.strategy.Wait import org.testcontainers.junit.jupiter.Container import org.testcontainers.junit.jupiter.Testcontainers -import kotlin.time.Duration.Companion.seconds +import kotlin.time.Duration.Companion.minutes /** * Integration test class for testing the JSON-RPC HTTP communication in the A2A client context. @@ -22,6 +24,7 @@ import kotlin.time.Duration.Companion.seconds */ @TestInstance(TestInstance.Lifecycle.PER_CLASS) @Testcontainers +@EnabledOnOs(OS.LINUX) class A2AClientJsonRpcIntegrationTest : BaseA2AProtocolTest() { companion object { @Container @@ -31,7 +34,7 @@ class A2AClientJsonRpcIntegrationTest : BaseA2AProtocolTest() { .waitingFor(Wait.forListeningPort()) } - override val testTimeout = 10.seconds + override val testTimeout = 2.minutes private val httpClient = HttpClient { install(Logging) { diff --git a/a2a/a2a-server/src/jvmTest/kotlin/ai/koog/a2a/server/A2AServerJsonRpcIntegrationTest.kt b/a2a/a2a-server/src/jvmTest/kotlin/ai/koog/a2a/server/A2AServerJsonRpcIntegrationTest.kt index 4a93328e12..ebd57917ec 100644 --- a/a2a/a2a-server/src/jvmTest/kotlin/ai/koog/a2a/server/A2AServerJsonRpcIntegrationTest.kt +++ b/a2a/a2a-server/src/jvmTest/kotlin/ai/koog/a2a/server/A2AServerJsonRpcIntegrationTest.kt @@ -30,6 +30,7 @@ import io.kotest.matchers.string.shouldStartWith import io.kotest.matchers.types.shouldBeInstanceOf import io.ktor.client.HttpClient import io.ktor.client.engine.cio.CIO +import io.ktor.client.plugins.HttpTimeout import io.ktor.client.plugins.logging.LogLevel import io.ktor.client.plugins.logging.Logging import io.ktor.server.netty.Netty @@ -47,7 +48,7 @@ import org.junit.jupiter.api.TestInstance import java.net.ServerSocket import kotlin.test.BeforeTest import kotlin.test.Test -import kotlin.time.Duration.Companion.seconds +import kotlin.time.Duration.Companion.minutes import kotlin.uuid.ExperimentalUuidApi import kotlin.uuid.Uuid @@ -59,7 +60,7 @@ import kotlin.uuid.Uuid @OptIn(ExperimentalUuidApi::class) @TestInstance(TestInstance.Lifecycle.PER_CLASS) class A2AServerJsonRpcIntegrationTest : BaseA2AProtocolTest() { - override val testTimeout = 10.seconds + override val testTimeout = 2.minutes private var testPort: Int? = null private val testPath = "/a2a" @@ -110,6 +111,10 @@ class A2AServerJsonRpcIntegrationTest : BaseA2AProtocolTest() { install(Logging) { level = LogLevel.ALL } + + install(HttpTimeout) { + requestTimeoutMillis = testTimeout.inWholeMilliseconds + } } clientTransport = HttpJSONRPCClientTransport(serverUrl, httpClient) @@ -118,7 +123,8 @@ class A2AServerJsonRpcIntegrationTest : BaseA2AProtocolTest() { transport = clientTransport, agentCardResolver = UrlAgentCardResolver( baseUrl = serverUrl, - path = A2AConsts.AGENT_CARD_WELL_KNOWN_PATH + path = A2AConsts.AGENT_CARD_WELL_KNOWN_PATH, + baseHttpClient = httpClient, ) ) } diff --git a/a2a/a2a-test/src/commonMain/kotlin/ai/koog/a2a/test/BaseA2AProtocolTest.kt b/a2a/a2a-test/src/jvmMain/kotlin/ai/koog/a2a/test/BaseA2AProtocolTest.kt similarity index 100% rename from a2a/a2a-test/src/commonMain/kotlin/ai/koog/a2a/test/BaseA2AProtocolTest.kt rename to a2a/a2a-test/src/jvmMain/kotlin/ai/koog/a2a/test/BaseA2AProtocolTest.kt diff --git a/agents/agents-features/agents-features-a2a-core/build.gradle.kts b/agents/agents-features/agents-features-a2a-core/build.gradle.kts index a2edcb5c55..4de791de25 100644 --- a/agents/agents-features/agents-features-a2a-core/build.gradle.kts +++ b/agents/agents-features/agents-features-a2a-core/build.gradle.kts @@ -13,6 +13,7 @@ kotlin { commonMain { dependencies { api(project(":prompt:prompt-model")) + api(project(":prompt:prompt-xml")) api(project(":a2a:a2a-core")) api(libs.kotlinx.serialization.json) diff --git a/agents/agents-features/agents-features-a2a-core/src/commonMain/kotlin/ai/koog/agents/a2a/core/Converters.kt b/agents/agents-features/agents-features-a2a-core/src/commonMain/kotlin/ai/koog/agents/a2a/core/Converters.kt deleted file mode 100644 index a3f8eef5c1..0000000000 --- a/agents/agents-features/agents-features-a2a-core/src/commonMain/kotlin/ai/koog/agents/a2a/core/Converters.kt +++ /dev/null @@ -1,161 +0,0 @@ -package ai.koog.agents.a2a.core - -import ai.koog.a2a.model.DataPart -import ai.koog.a2a.model.FilePart -import ai.koog.a2a.model.FileWithBytes -import ai.koog.a2a.model.FileWithUri -import ai.koog.a2a.model.Part -import ai.koog.a2a.model.Role -import ai.koog.a2a.model.TextPart -import ai.koog.prompt.message.Attachment -import ai.koog.prompt.message.AttachmentContent -import ai.koog.prompt.message.Message -import ai.koog.prompt.message.RequestMetaInfo -import ai.koog.prompt.message.ResponseMetaInfo -import kotlinx.datetime.Clock -import kotlinx.serialization.json.Json -import kotlinx.serialization.json.JsonObject -import kotlin.uuid.ExperimentalUuidApi -import kotlin.uuid.Uuid - -/** - * Alias to A2A message type, to avoid clashing with Koog's message type. - * @see [ai.koog.a2a.model.Message] - */ -public typealias A2AMessage = ai.koog.a2a.model.Message - -private val json = Json { - prettyPrint = true -} - -/** - * Converts [A2AMessage] to Koog's [Message]. - * - * @param clock The clock to use for the timestamp. Defaults to [Clock.System]. - */ -public fun A2AMessage.toKoogMessage( - clock: Clock = Clock.System, -): Message { - val content = StringBuilder() - val attachments = mutableListOf() - - // Put ids information in the text content, since Koog doesn't have special fields for them. - contextId?.let { - content.appendLine("Context ID: $it") - } - - taskId?.let { - content.appendLine("Task ID: $it") - } - - referenceTaskIds?.forEach { - content.appendLine("Reference Task ID: $it") - } - - parts.forEach { part -> - when (part) { - is TextPart -> content.appendLine(part.text) - // Koog doesn't support structured data as a separate type, just append it to the content. - is DataPart -> content.appendLine(json.encodeToString(part.data)) - is FilePart -> { - val file = part.file - - val attachment = Attachment.File( - // do not have that information separately in A2A - format = "", - // if no mime type is provided, assume it's arbitrary binary data - mimeType = file.mimeType ?: "application/octet-stream", - fileName = file.name, - content = when (file) { - is FileWithBytes -> AttachmentContent.Binary.Base64(file.bytes) - is FileWithUri -> AttachmentContent.URL(file.uri) - } - ) - - attachments.add(attachment) - } - } - } - - return when (role) { - Role.User -> Message.User( - content = content.toString(), - metaInfo = RequestMetaInfo( - timestamp = clock.now() - ), - attachments = attachments.toList(), - ) - - Role.Agent -> Message.Assistant( - content = content.toString(), - metaInfo = ResponseMetaInfo( - timestamp = clock.now() - ), - attachments = attachments, - ) - } -} - -/** - * Converts Koog's [Message] to [A2AMessage]. - * - * @see ai.koog.a2a.model.Message - */ -@OptIn(ExperimentalUuidApi::class) -public fun Message.toA2AMessage( - messageId: String = Uuid.random().toString(), - contextId: String? = null, - taskId: String? = null, - referenceTaskIds: List? = null, - metadata: JsonObject? = null, - extensions: List? = null, -): A2AMessage { - val role = when (this) { - is Message.User -> Role.User - is Message.Assistant -> Role.Agent - else -> throw IllegalArgumentException("A2A can't handle this Koog message type: $this") - } - - val parts = mutableListOf() - - // Add content - parts.add(TextPart(content)) - - // Add attachments - attachments.forEach { attachment -> - val file = when (val content = attachment.content) { - // Plain text files are not supported, convert them to binary files. - is AttachmentContent.PlainText -> FileWithBytes( - bytes = AttachmentContent.Binary.Bytes(content.text.encodeToByteArray()) - .asBase64(), - name = attachment.fileName, - mimeType = attachment.mimeType, - ) - - is AttachmentContent.Binary -> FileWithBytes( - bytes = content.asBase64(), - name = attachment.fileName, - mimeType = attachment.mimeType, - ) - - is AttachmentContent.URL -> FileWithUri( - uri = content.url, - name = attachment.fileName, - mimeType = attachment.mimeType, - ) - } - - parts.add(FilePart(file)) - } - - return A2AMessage( - messageId = messageId, - role = role, - parts = parts, - extensions = extensions, - taskId = taskId, - referenceTaskIds = referenceTaskIds, - contextId = contextId, - metadata = metadata - ) -} diff --git a/agents/agents-features/agents-features-a2a-core/src/commonMain/kotlin/ai/koog/agents/a2a/core/MessageConverters.kt b/agents/agents-features/agents-features-a2a-core/src/commonMain/kotlin/ai/koog/agents/a2a/core/MessageConverters.kt new file mode 100644 index 0000000000..89260eafe0 --- /dev/null +++ b/agents/agents-features/agents-features-a2a-core/src/commonMain/kotlin/ai/koog/agents/a2a/core/MessageConverters.kt @@ -0,0 +1,110 @@ +package ai.koog.agents.a2a.core + +import ai.koog.a2a.model.Role +import ai.koog.prompt.message.Message +import ai.koog.prompt.message.RequestMetaInfo +import ai.koog.prompt.message.ResponseMetaInfo +import ai.koog.prompt.xml.xml +import kotlinx.datetime.Clock +import kotlinx.serialization.json.JsonObject +import kotlin.uuid.ExperimentalUuidApi +import kotlin.uuid.Uuid + +/** + * Alias to A2A message type, to avoid clashing with Koog's message type. + * @see [ai.koog.a2a.model.Message] + */ +public typealias A2AMessage = ai.koog.a2a.model.Message + +/** + * Converts [A2AMessage] to Koog's [Message]. + * + * @param clock The clock to use for the timestamp. Defaults to [Clock.System]. + */ +public fun A2AMessage.toKoogMessage( + clock: Clock = Clock.System, +): Message { + // Convert to the actual message content and attachments. + val (messageContent, attachments) = parts.map { it.toKoogPart() }.toContentWithAttachments() + + val content = xml { + tag("message_content") { + +messageContent + } + + // Put ids information in the text content, since Koog doesn't have special fields for them. + tag("a2a_message_metadata") { + contextId?.let { + tag("context_id") { +it } + } + + taskId?.let { + tag("task_id") { +it } + } + + referenceTaskIds + ?.takeIf { it.isNotEmpty() } + ?.let { referenceIds -> + tag("reference_task_ids") { + referenceIds.forEach { + tag("id") { +it } + } + } + } + } + } + + return when (role) { + Role.User -> Message.User( + content = content.toString(), + metaInfo = RequestMetaInfo( + timestamp = clock.now() + ), + attachments = attachments.toList(), + ) + + Role.Agent -> Message.Assistant( + content = content.toString(), + metaInfo = ResponseMetaInfo( + timestamp = clock.now() + ), + attachments = attachments, + ) + } +} + +/** + * Converts Koog's [Message] to [A2AMessage]. + * + * @see ai.koog.a2a.model.Message + */ +@OptIn(ExperimentalUuidApi::class) +public fun Message.toA2AMessage( + messageId: String = Uuid.random().toString(), + contextId: String? = null, + taskId: String? = null, + referenceTaskIds: List? = null, + metadata: JsonObject? = null, + extensions: List? = null, +): A2AMessage { + val role = when (this) { + is Message.User -> Role.User + is Message.Assistant -> Role.Agent + else -> throw IllegalArgumentException("A2A can't handle this Koog message type: $this") + } + + // Add parts + val parts = (listOf(KoogContentPart(content)) + attachments.map { KoogAttachmentPart(it) }) + .map { it.toA2APart() } + + return A2AMessage( + messageId = messageId, + role = role, + parts = parts, + extensions = extensions, + taskId = taskId, + referenceTaskIds = referenceTaskIds, + contextId = contextId, + metadata = metadata + ) +} diff --git a/agents/agents-features/agents-features-a2a-core/src/commonMain/kotlin/ai/koog/agents/a2a/core/PartConverters.kt b/agents/agents-features/agents-features-a2a-core/src/commonMain/kotlin/ai/koog/agents/a2a/core/PartConverters.kt new file mode 100644 index 0000000000..6401e14fe6 --- /dev/null +++ b/agents/agents-features/agents-features-a2a-core/src/commonMain/kotlin/ai/koog/agents/a2a/core/PartConverters.kt @@ -0,0 +1,114 @@ +package ai.koog.agents.a2a.core + +import ai.koog.a2a.model.DataPart +import ai.koog.a2a.model.FilePart +import ai.koog.a2a.model.FileWithBytes +import ai.koog.a2a.model.FileWithUri +import ai.koog.a2a.model.Part +import ai.koog.a2a.model.TextPart +import ai.koog.prompt.message.Attachment +import ai.koog.prompt.message.AttachmentContent +import kotlinx.serialization.Serializable +import kotlinx.serialization.json.Json + +/** + * Koog doesn't have proper support for message parts yet, but A2A operates with parts. + * This is a helper mapping structure, to map A2A parts either to part of the textual content or to the attachment. + */ +@Serializable +public sealed interface KoogPart + +/** + * Text content part representing part of the [ai.koog.prompt.message.Message.content] + */ +@Serializable +public data class KoogContentPart(val content: String) : KoogPart + +/** + * Attachment part representing part of the [ai.koog.prompt.message.Message.Assistant.attachments] or + * [ai.koog.prompt.message.Message.User.attachments] + */ +@Serializable +public data class KoogAttachmentPart(val attachment: Attachment) : KoogPart + +private val json = Json { + prettyPrint = true +} + +/** + * Converts A2A [Part] to Koog [KoogPart]. + */ +public fun Part.toKoogPart(): KoogPart = when (this) { + is TextPart -> KoogContentPart(this.text) + // Koog doesn't support structured data as a separate type, treat it as a content part. + + is DataPart -> KoogContentPart(json.encodeToString(this.data)) + + is FilePart -> { + val file = this.file // to enable smart cast + + val attachment = Attachment.File( + // do not have that information separately in A2A + format = "", + // if no mime type is provided, assume it's arbitrary binary data + mimeType = file.mimeType ?: "application/octet-stream", + fileName = file.name, + content = when (file) { + is FileWithBytes -> AttachmentContent.Binary.Base64(file.bytes) + is FileWithUri -> AttachmentContent.URL(file.uri) + } + ) + + KoogAttachmentPart(attachment) + } +} + +/** + * Converts Koog [KoogPart] to A2A [Part]. + */ +public fun KoogPart.toA2APart(): Part = when (this) { + is KoogContentPart -> TextPart(this.content) + + is KoogAttachmentPart -> { + val file = when (val content = attachment.content) { + // Plain text files are not supported, convert them to binary files. + is AttachmentContent.PlainText -> FileWithBytes( + bytes = AttachmentContent.Binary.Bytes(content.text.encodeToByteArray()) + .asBase64(), + name = attachment.fileName, + mimeType = attachment.mimeType, + ) + + is AttachmentContent.Binary -> FileWithBytes( + bytes = content.asBase64(), + name = attachment.fileName, + mimeType = attachment.mimeType, + ) + + is AttachmentContent.URL -> FileWithUri( + uri = content.url, + name = attachment.fileName, + mimeType = attachment.mimeType, + ) + } + + FilePart(file) + } +} + +/** + * Helper method to convert an iterable of [KoogPart] to the pair of the text content and attachments. + */ +public fun Iterable.toContentWithAttachments(): Pair> { + val content = StringBuilder() + val attachments = mutableListOf() + + forEach { part -> + when (part) { + is KoogContentPart -> content.appendLine(part.content) + is KoogAttachmentPart -> attachments.add(part.attachment) + } + } + + return content.toString().trim() to attachments +} diff --git a/agents/agents-features/agents-features-a2a-core/src/commonTest/kotlin/ai/koog/agents/a2a/core/ConvertersTest.kt b/agents/agents-features/agents-features-a2a-core/src/commonTest/kotlin/ai/koog/agents/a2a/core/MessageConvertersTest.kt similarity index 87% rename from agents/agents-features/agents-features-a2a-core/src/commonTest/kotlin/ai/koog/agents/a2a/core/ConvertersTest.kt rename to agents/agents-features/agents-features-a2a-core/src/commonTest/kotlin/ai/koog/agents/a2a/core/MessageConvertersTest.kt index 9dd6c59a69..36d9823543 100644 --- a/agents/agents-features/agents-features-a2a-core/src/commonTest/kotlin/ai/koog/agents/a2a/core/ConvertersTest.kt +++ b/agents/agents-features/agents-features-a2a-core/src/commonTest/kotlin/ai/koog/agents/a2a/core/MessageConvertersTest.kt @@ -11,6 +11,7 @@ import ai.koog.prompt.message.AttachmentContent import ai.koog.prompt.message.Message import ai.koog.prompt.message.RequestMetaInfo import ai.koog.prompt.message.ResponseMetaInfo +import ai.koog.prompt.xml.xml import kotlinx.datetime.Clock import kotlinx.datetime.Instant import kotlinx.serialization.json.Json @@ -20,8 +21,7 @@ import kotlin.test.Test import kotlin.test.assertEquals import kotlin.test.assertFailsWith -class ConvertersTest { - +class MessageConvertersTest { private val fixedInstant: Instant = Instant.parse("2024-01-01T00:00:00Z") private val fixedClock: Clock = object : Clock { override fun now(): Instant = fixedInstant @@ -30,7 +30,7 @@ class ConvertersTest { private val prettyJson = Json { prettyPrint = true } @Test - fun testA2AtoKoog_User_withTextDataAndFiles_fullObjectEquality() { + fun testA2AtoKoog_User_withTextDataAndFiles() { val json = buildJsonObject { put("k", "v") } val bytesBase64 = "YmFzZTY0" // arbitrary base64 string @@ -51,13 +51,18 @@ class ConvertersTest { val actual: Message = a2a.toKoogMessage(clock = fixedClock) - val expectedContent = buildString { - appendLine("Context ID: ctx-123") - appendLine("Task ID: task-1") - appendLine("Reference Task ID: ref-1") - appendLine("Reference Task ID: ref-2") - appendLine("Hello") - appendLine(prettyJson.encodeToString(json)) + val expectedContent = xml { + tag("message_content") { + +("Hello\n" + prettyJson.encodeToString(json)) + } + tag("a2a_message_metadata") { + tag("context_id") { +"ctx-123" } + tag("task_id") { +"task-1" } + tag("reference_task_ids") { + tag("id") { +"ref-1" } + tag("id") { +"ref-2" } + } + } } val expectedAttachments = listOf( Attachment.File( @@ -83,7 +88,7 @@ class ConvertersTest { } @Test - fun testA2AtoKoog_Agent_fullObjectEquality() { + fun testA2AtoKoog_Agent() { val a2a = A2AMessage( messageId = "m2", role = Role.Agent, @@ -93,7 +98,10 @@ class ConvertersTest { val actual = a2a.toKoogMessage(clock = fixedClock) val expected = Message.Assistant( - content = buildString { appendLine("Agent says hi") }, + content = xml { + tag("message_content") { +"Agent says hi" } + tag("a2a_message_metadata") {} + }, metaInfo = ResponseMetaInfo(timestamp = fixedInstant), attachments = emptyList() ) @@ -102,7 +110,7 @@ class ConvertersTest { } @Test - fun testKoogToA2A_User_withPlainTextBinaryAndUrlAttachments_fullObjectEquality() { + fun testKoogToA2A_User_withPlainTextBinaryAndUrlAttachments() { val plain = Attachment.File( content = AttachmentContent.PlainText("abc"), format = "txt", @@ -164,7 +172,7 @@ class ConvertersTest { } @Test - fun testKoogToA2A_Assistant_fullObjectEquality() { + fun testKoogToA2A_Assistant() { val koog: Message = Message.Assistant( content = "Answer", metaInfo = ResponseMetaInfo(timestamp = fixedInstant), From b6371beda1d1aad6cd925f864f524030e9b90061 Mon Sep 17 00:00:00 2001 From: Andrey Bragin Date: Tue, 30 Sep 2025 19:30:48 +0200 Subject: [PATCH 37/52] [a2a] Add module documentation and metadata field support --- a2a/a2a-client/Module.md | 4 + a2a/a2a-core/Module.md | 4 + a2a/a2a-server/Module.md | 4 + a2a/a2a-test/Module.md | 4 + .../Module.md | 3 + .../a2a-transport-client-rest/Module.md | 3 + .../a2a-transport-core-jsonrpc/Module.md | 3 + .../a2a-transport-core-rest/Module.md | 3 + .../Module.md | 3 + .../a2a-transport-server-rest/Module.md | 3 + a2a/test-tck/a2a-test-server-tck/Module.md | 3 + .../agents-features-a2a-core/build.gradle.kts | 1 - .../agents/a2a/core/MessageA2AMetadata.kt | 49 ++++++++++++ .../koog/agents/a2a/core/MessageConverters.kt | 76 ++++++++---------- .../ai/koog/agents/a2a/core/PartConverters.kt | 7 +- .../ai/koog/agents/a2a/core/Serialization.kt | 7 ++ .../agents/a2a/core/MessageConvertersTest.kt | 79 +++++++++++++------ .../kotlin/ai/koog/prompt/message/Message.kt | 25 +++++- 18 files changed, 200 insertions(+), 81 deletions(-) create mode 100644 a2a/a2a-client/Module.md create mode 100644 a2a/a2a-core/Module.md create mode 100644 a2a/a2a-server/Module.md create mode 100644 a2a/a2a-test/Module.md create mode 100644 a2a/a2a-transport/a2a-transport-client-jsonrpc-http/Module.md create mode 100644 a2a/a2a-transport/a2a-transport-client-rest/Module.md create mode 100644 a2a/a2a-transport/a2a-transport-core-jsonrpc/Module.md create mode 100644 a2a/a2a-transport/a2a-transport-core-rest/Module.md create mode 100644 a2a/a2a-transport/a2a-transport-server-jsonrpc-http/Module.md create mode 100644 a2a/a2a-transport/a2a-transport-server-rest/Module.md create mode 100644 a2a/test-tck/a2a-test-server-tck/Module.md create mode 100644 agents/agents-features/agents-features-a2a-core/src/commonMain/kotlin/ai/koog/agents/a2a/core/MessageA2AMetadata.kt create mode 100644 agents/agents-features/agents-features-a2a-core/src/commonMain/kotlin/ai/koog/agents/a2a/core/Serialization.kt diff --git a/a2a/a2a-client/Module.md b/a2a/a2a-client/Module.md new file mode 100644 index 0000000000..f123824d72 --- /dev/null +++ b/a2a/a2a-client/Module.md @@ -0,0 +1,4 @@ +# Module a2a-client + +High-level client library for communicating with A2A protocol servers. Provides the A2AClient wrapper with capability validation, +request/response handling, and AgentCard resolution. Built on top of a2a-core abstractions with Ktor HTTP client support. diff --git a/a2a/a2a-core/Module.md b/a2a/a2a-core/Module.md new file mode 100644 index 0000000000..f574bc331b --- /dev/null +++ b/a2a/a2a-core/Module.md @@ -0,0 +1,4 @@ +# Module a2a-core + +Core abstractions and data models for the A2A (Agent-to-Agent) protocol implementation. This module provides the foundational +types and interfaces for agent communication, including protocol message structures, transport layer abstractions, and domain-specific exceptions. diff --git a/a2a/a2a-server/Module.md b/a2a/a2a-server/Module.md new file mode 100644 index 0000000000..8c7536be5b --- /dev/null +++ b/a2a/a2a-server/Module.md @@ -0,0 +1,4 @@ +# Module a2a-server + +Server-side implementation for hosting A2A protocol agents. Provides the A2AServer class, AgentExecutor interface for implementing agent logic, +session management, and storage abstractions with for tasks, sessions, and push subscriptions. diff --git a/a2a/a2a-test/Module.md b/a2a/a2a-test/Module.md new file mode 100644 index 0000000000..c1b5c0d71e --- /dev/null +++ b/a2a/a2a-test/Module.md @@ -0,0 +1,4 @@ +# Module a2a-test + +Testing utilities and helpers for A2A protocol implementations. Provides common test fixtures, utilities, and integration +test support for both client and server components. diff --git a/a2a/a2a-transport/a2a-transport-client-jsonrpc-http/Module.md b/a2a/a2a-transport/a2a-transport-client-jsonrpc-http/Module.md new file mode 100644 index 0000000000..adaaf5c980 --- /dev/null +++ b/a2a/a2a-transport/a2a-transport-client-jsonrpc-http/Module.md @@ -0,0 +1,3 @@ +# Module a2a-transport-client-jsonrpc-http + +HTTP-based JSON-RPC client transport implementation for A2A protocol. Implements the ClientTransport interface using Ktor HTTP client to communicate with JSON-RPC servers over HTTP. diff --git a/a2a/a2a-transport/a2a-transport-client-rest/Module.md b/a2a/a2a-transport/a2a-transport-client-rest/Module.md new file mode 100644 index 0000000000..9f93b94c62 --- /dev/null +++ b/a2a/a2a-transport/a2a-transport-client-rest/Module.md @@ -0,0 +1,3 @@ +# Module a2a-transport-client-rest + +HTTP+JSON/REST client transport implementation for A2A protocol. Implements the ClientTransport interface using Ktor HTTP client to communicate with REST API servers. diff --git a/a2a/a2a-transport/a2a-transport-core-jsonrpc/Module.md b/a2a/a2a-transport/a2a-transport-core-jsonrpc/Module.md new file mode 100644 index 0000000000..a9edb95769 --- /dev/null +++ b/a2a/a2a-transport/a2a-transport-core-jsonrpc/Module.md @@ -0,0 +1,3 @@ +# Module a2a-transport-core-jsonrpc + +Core JSON-RPC protocol implementation for A2A communication. Provides base classes and utilities for implementing JSON-RPC based transport layers, including request/response handling and error mapping. diff --git a/a2a/a2a-transport/a2a-transport-core-rest/Module.md b/a2a/a2a-transport/a2a-transport-core-rest/Module.md new file mode 100644 index 0000000000..f81a55a540 --- /dev/null +++ b/a2a/a2a-transport/a2a-transport-core-rest/Module.md @@ -0,0 +1,3 @@ +# Module a2a-transport-core-rest + +Core HTTP+JSON/REST protocol implementation for A2A communication. Provides base classes and utilities for implementing REST-based transport layers, including request routing and response serialization. diff --git a/a2a/a2a-transport/a2a-transport-server-jsonrpc-http/Module.md b/a2a/a2a-transport/a2a-transport-server-jsonrpc-http/Module.md new file mode 100644 index 0000000000..39c8ed9260 --- /dev/null +++ b/a2a/a2a-transport/a2a-transport-server-jsonrpc-http/Module.md @@ -0,0 +1,3 @@ +# Module a2a-transport-server-jsonrpc-http + +HTTP-based JSON-RPC server transport implementation for A2A protocol. Implements the ServerTransport interface using Ktor HTTP server to expose JSON-RPC endpoints for agent communication. diff --git a/a2a/a2a-transport/a2a-transport-server-rest/Module.md b/a2a/a2a-transport/a2a-transport-server-rest/Module.md new file mode 100644 index 0000000000..51df759724 --- /dev/null +++ b/a2a/a2a-transport/a2a-transport-server-rest/Module.md @@ -0,0 +1,3 @@ +# Module a2a-transport-server-rest + +HTTP+JSON/REST server transport implementation for A2A protocol. Implements the ServerTransport interface using Ktor HTTP server to expose RESTful endpoints for agent communication. diff --git a/a2a/test-tck/a2a-test-server-tck/Module.md b/a2a/test-tck/a2a-test-server-tck/Module.md new file mode 100644 index 0000000000..dc73dc300d --- /dev/null +++ b/a2a/test-tck/a2a-test-server-tck/Module.md @@ -0,0 +1,3 @@ +# Module a2a-test-server-tck + +Technology Compatibility Kit (TCK) for A2A server implementations. Provides a comprehensive test suite to verify compliance with the A2A protocol specification, ensuring interoperability across different implementations. diff --git a/agents/agents-features/agents-features-a2a-core/build.gradle.kts b/agents/agents-features/agents-features-a2a-core/build.gradle.kts index 4de791de25..a2edcb5c55 100644 --- a/agents/agents-features/agents-features-a2a-core/build.gradle.kts +++ b/agents/agents-features/agents-features-a2a-core/build.gradle.kts @@ -13,7 +13,6 @@ kotlin { commonMain { dependencies { api(project(":prompt:prompt-model")) - api(project(":prompt:prompt-xml")) api(project(":a2a:a2a-core")) api(libs.kotlinx.serialization.json) diff --git a/agents/agents-features/agents-features-a2a-core/src/commonMain/kotlin/ai/koog/agents/a2a/core/MessageA2AMetadata.kt b/agents/agents-features/agents-features-a2a-core/src/commonMain/kotlin/ai/koog/agents/a2a/core/MessageA2AMetadata.kt new file mode 100644 index 0000000000..e3220aaa69 --- /dev/null +++ b/agents/agents-features/agents-features-a2a-core/src/commonMain/kotlin/ai/koog/agents/a2a/core/MessageA2AMetadata.kt @@ -0,0 +1,49 @@ +package ai.koog.agents.a2a.core + +import ai.koog.prompt.message.Message +import ai.koog.prompt.message.MessageMetaInfo +import kotlinx.serialization.Serializable +import kotlinx.serialization.json.JsonObject +import kotlinx.serialization.json.decodeFromJsonElement +import kotlinx.serialization.json.encodeToJsonElement + +/** + * Key used to store [MessageA2AMetadata] in [ai.koog.prompt.message.MessageMetaInfo.metadata] + */ +public const val MESSAGE_A2A_METADATA_KEY: String = "a2a_metadata" + +/** + * Represents additional A2A-related message metadata stored in [ai.koog.prompt.message.MessageMetaInfo.metadata] + * For more info on each field, check [ai.koog.a2a.model.Message] + */ +@Serializable +public data class MessageA2AMetadata( + val messageId: String, + val contextId: String?, + val taskId: String?, + val referenceTaskIds: List?, + val metadata: JsonObject?, + val extensions: List?, +) + +/** + * Retrieves [MessageA2AMetadata] from [MessageMetaInfo.metadata], if [MESSAGE_A2A_METADATA_KEY] is present. + */ +public fun MessageMetaInfo.getA2AMetadata(): MessageA2AMetadata? { + return metadata + ?.get(MESSAGE_A2A_METADATA_KEY) + ?.let { a2aMetadata -> + A2AFeatureJson.decodeFromJsonElement(a2aMetadata) + } +} + +/** + * Updates provided [JsonObject], overwriting [MESSAGE_A2A_METADATA_KEY] with [a2aMetadata]. + */ +public fun JsonObject.withA2AMetadata(a2aMetadata: MessageA2AMetadata): JsonObject { + return toMutableMap() + .apply { + put(MESSAGE_A2A_METADATA_KEY, A2AFeatureJson.encodeToJsonElement(a2aMetadata)) + } + .let { JsonObject(it) } +} diff --git a/agents/agents-features/agents-features-a2a-core/src/commonMain/kotlin/ai/koog/agents/a2a/core/MessageConverters.kt b/agents/agents-features/agents-features-a2a-core/src/commonMain/kotlin/ai/koog/agents/a2a/core/MessageConverters.kt index 89260eafe0..acbf239d28 100644 --- a/agents/agents-features/agents-features-a2a-core/src/commonMain/kotlin/ai/koog/agents/a2a/core/MessageConverters.kt +++ b/agents/agents-features/agents-features-a2a-core/src/commonMain/kotlin/ai/koog/agents/a2a/core/MessageConverters.kt @@ -4,7 +4,6 @@ import ai.koog.a2a.model.Role import ai.koog.prompt.message.Message import ai.koog.prompt.message.RequestMetaInfo import ai.koog.prompt.message.ResponseMetaInfo -import ai.koog.prompt.xml.xml import kotlinx.datetime.Clock import kotlinx.serialization.json.JsonObject import kotlin.uuid.ExperimentalUuidApi @@ -18,6 +17,8 @@ public typealias A2AMessage = ai.koog.a2a.model.Message /** * Converts [A2AMessage] to Koog's [Message]. + * Returned message will contain [MessageA2AMetadata] at [MESSAGE_A2A_METADATA_KEY] in [ai.koog.prompt.message.MessageMetaInfo.metadata], + * which can be retrieved with helper method [getA2AMetadata]. * * @param clock The clock to use for the timestamp. Defaults to [Clock.System]. */ @@ -25,48 +26,35 @@ public fun A2AMessage.toKoogMessage( clock: Clock = Clock.System, ): Message { // Convert to the actual message content and attachments. - val (messageContent, attachments) = parts.map { it.toKoogPart() }.toContentWithAttachments() + val (content, attachments) = parts.map { it.toKoogPart() }.toContentWithAttachments() - val content = xml { - tag("message_content") { - +messageContent - } - - // Put ids information in the text content, since Koog doesn't have special fields for them. - tag("a2a_message_metadata") { - contextId?.let { - tag("context_id") { +it } - } - - taskId?.let { - tag("task_id") { +it } - } - - referenceTaskIds - ?.takeIf { it.isNotEmpty() } - ?.let { referenceIds -> - tag("reference_task_ids") { - referenceIds.forEach { - tag("id") { +it } - } - } - } - } - } + // Create metadata + val metadata = JsonObject(emptyMap()).withA2AMetadata( + MessageA2AMetadata( + messageId = messageId, + contextId = contextId, + taskId = taskId, + referenceTaskIds = referenceTaskIds, + metadata = metadata, + extensions = extensions, + ) + ) return when (role) { Role.User -> Message.User( - content = content.toString(), + content = content, metaInfo = RequestMetaInfo( - timestamp = clock.now() + timestamp = clock.now(), + metadata = metadata, ), attachments = attachments.toList(), ) Role.Agent -> Message.Assistant( - content = content.toString(), + content = content, metaInfo = ResponseMetaInfo( - timestamp = clock.now() + timestamp = clock.now(), + metadata = metadata, ), attachments = attachments, ) @@ -75,18 +63,18 @@ public fun A2AMessage.toKoogMessage( /** * Converts Koog's [Message] to [A2AMessage]. + * To fill A2A-specific fields, it will attempt to read [MessageA2AMetadata] from [ai.koog.prompt.message.MessageMetaInfo.metadata], + * but it also can be overridden with [a2aMetadata] * + * @param a2aMetadata The A2A-specific metadata to override exiting in this [Message]. * @see ai.koog.a2a.model.Message */ @OptIn(ExperimentalUuidApi::class) public fun Message.toA2AMessage( - messageId: String = Uuid.random().toString(), - contextId: String? = null, - taskId: String? = null, - referenceTaskIds: List? = null, - metadata: JsonObject? = null, - extensions: List? = null, + a2aMetadata: MessageA2AMetadata? = null, ): A2AMessage { + val actualMetadata = a2aMetadata ?: metaInfo.getA2AMetadata() + val role = when (this) { is Message.User -> Role.User is Message.Assistant -> Role.Agent @@ -98,13 +86,13 @@ public fun Message.toA2AMessage( .map { it.toA2APart() } return A2AMessage( - messageId = messageId, + messageId = actualMetadata?.messageId ?: Uuid.random().toString(), role = role, parts = parts, - extensions = extensions, - taskId = taskId, - referenceTaskIds = referenceTaskIds, - contextId = contextId, - metadata = metadata + extensions = actualMetadata?.extensions, + taskId = actualMetadata?.taskId, + referenceTaskIds = actualMetadata?.referenceTaskIds, + contextId = actualMetadata?.contextId, + metadata = actualMetadata?.metadata ) } diff --git a/agents/agents-features/agents-features-a2a-core/src/commonMain/kotlin/ai/koog/agents/a2a/core/PartConverters.kt b/agents/agents-features/agents-features-a2a-core/src/commonMain/kotlin/ai/koog/agents/a2a/core/PartConverters.kt index 6401e14fe6..47e932bd5c 100644 --- a/agents/agents-features/agents-features-a2a-core/src/commonMain/kotlin/ai/koog/agents/a2a/core/PartConverters.kt +++ b/agents/agents-features/agents-features-a2a-core/src/commonMain/kotlin/ai/koog/agents/a2a/core/PartConverters.kt @@ -9,7 +9,6 @@ import ai.koog.a2a.model.TextPart import ai.koog.prompt.message.Attachment import ai.koog.prompt.message.AttachmentContent import kotlinx.serialization.Serializable -import kotlinx.serialization.json.Json /** * Koog doesn't have proper support for message parts yet, but A2A operates with parts. @@ -31,10 +30,6 @@ public data class KoogContentPart(val content: String) : KoogPart @Serializable public data class KoogAttachmentPart(val attachment: Attachment) : KoogPart -private val json = Json { - prettyPrint = true -} - /** * Converts A2A [Part] to Koog [KoogPart]. */ @@ -42,7 +37,7 @@ public fun Part.toKoogPart(): KoogPart = when (this) { is TextPart -> KoogContentPart(this.text) // Koog doesn't support structured data as a separate type, treat it as a content part. - is DataPart -> KoogContentPart(json.encodeToString(this.data)) + is DataPart -> KoogContentPart(A2AFeatureJson.encodeToString(this.data)) is FilePart -> { val file = this.file // to enable smart cast diff --git a/agents/agents-features/agents-features-a2a-core/src/commonMain/kotlin/ai/koog/agents/a2a/core/Serialization.kt b/agents/agents-features/agents-features-a2a-core/src/commonMain/kotlin/ai/koog/agents/a2a/core/Serialization.kt new file mode 100644 index 0000000000..96084a22e1 --- /dev/null +++ b/agents/agents-features/agents-features-a2a-core/src/commonMain/kotlin/ai/koog/agents/a2a/core/Serialization.kt @@ -0,0 +1,7 @@ +package ai.koog.agents.a2a.core + +import kotlinx.serialization.json.Json + +internal val A2AFeatureJson = Json { + prettyPrint = true +} diff --git a/agents/agents-features/agents-features-a2a-core/src/commonTest/kotlin/ai/koog/agents/a2a/core/MessageConvertersTest.kt b/agents/agents-features/agents-features-a2a-core/src/commonTest/kotlin/ai/koog/agents/a2a/core/MessageConvertersTest.kt index 36d9823543..682325a881 100644 --- a/agents/agents-features/agents-features-a2a-core/src/commonTest/kotlin/ai/koog/agents/a2a/core/MessageConvertersTest.kt +++ b/agents/agents-features/agents-features-a2a-core/src/commonTest/kotlin/ai/koog/agents/a2a/core/MessageConvertersTest.kt @@ -11,11 +11,12 @@ import ai.koog.prompt.message.AttachmentContent import ai.koog.prompt.message.Message import ai.koog.prompt.message.RequestMetaInfo import ai.koog.prompt.message.ResponseMetaInfo -import ai.koog.prompt.xml.xml import kotlinx.datetime.Clock import kotlinx.datetime.Instant import kotlinx.serialization.json.Json +import kotlinx.serialization.json.JsonObject import kotlinx.serialization.json.buildJsonObject +import kotlinx.serialization.json.encodeToJsonElement import kotlinx.serialization.json.put import kotlin.test.Test import kotlin.test.assertEquals @@ -51,19 +52,7 @@ class MessageConvertersTest { val actual: Message = a2a.toKoogMessage(clock = fixedClock) - val expectedContent = xml { - tag("message_content") { - +("Hello\n" + prettyJson.encodeToString(json)) - } - tag("a2a_message_metadata") { - tag("context_id") { +"ctx-123" } - tag("task_id") { +"task-1" } - tag("reference_task_ids") { - tag("id") { +"ref-1" } - tag("id") { +"ref-2" } - } - } - } + val expectedContent = "Hello\n" + prettyJson.encodeToString(json) val expectedAttachments = listOf( Attachment.File( format = "", @@ -78,9 +67,23 @@ class MessageConvertersTest { content = AttachmentContent.URL("https://example.com/doc.txt") ) ) + val expectedMetadata = JsonObject( + mapOf( + MESSAGE_A2A_METADATA_KEY to Json.encodeToJsonElement( + MessageA2AMetadata( + messageId = "m1", + contextId = "ctx-123", + taskId = "task-1", + referenceTaskIds = listOf("ref-1", "ref-2"), + metadata = null, + extensions = listOf("ext:a"), + ) + ) + ) + ) val expected: Message = Message.User( content = expectedContent, - metaInfo = RequestMetaInfo(timestamp = fixedInstant), + metaInfo = RequestMetaInfo(timestamp = fixedInstant, metadata = expectedMetadata), attachments = expectedAttachments ) @@ -97,12 +100,23 @@ class MessageConvertersTest { val actual = a2a.toKoogMessage(clock = fixedClock) + val expectedMetadata = JsonObject( + mapOf( + MESSAGE_A2A_METADATA_KEY to Json.encodeToJsonElement( + MessageA2AMetadata( + messageId = "m2", + contextId = null, + taskId = null, + referenceTaskIds = null, + metadata = null, + extensions = null, + ) + ) + ) + ) val expected = Message.Assistant( - content = xml { - tag("message_content") { +"Agent says hi" } - tag("a2a_message_metadata") {} - }, - metaInfo = ResponseMetaInfo(timestamp = fixedInstant), + content = "Agent says hi", + metaInfo = ResponseMetaInfo(timestamp = fixedInstant, metadata = expectedMetadata), attachments = emptyList() ) @@ -138,10 +152,14 @@ class MessageConvertersTest { ) val actual = koog.toA2AMessage( - messageId = "mid", - contextId = "ctx", - taskId = "task", - referenceTaskIds = listOf("r1"), + a2aMetadata = MessageA2AMetadata( + messageId = "mid", + contextId = "ctx", + taskId = "task", + referenceTaskIds = listOf("r1"), + metadata = null, + extensions = null, + ) ) val expectedPlainBase64 = AttachmentContent.Binary.Bytes("abc".encodeToByteArray()).asBase64() @@ -177,7 +195,16 @@ class MessageConvertersTest { content = "Answer", metaInfo = ResponseMetaInfo(timestamp = fixedInstant), ) - val actual = koog.toA2AMessage(messageId = "m3") + val actual = koog.toA2AMessage( + a2aMetadata = MessageA2AMetadata( + messageId = "m3", + contextId = null, + taskId = null, + referenceTaskIds = null, + metadata = null, + extensions = null, + ) + ) val expected = A2AMessage( messageId = "m3", role = Role.Agent, @@ -198,7 +225,7 @@ class MessageConvertersTest { metaInfo = RequestMetaInfo(timestamp = fixedInstant) ) assertFailsWith { - sys.toA2AMessage(messageId = "m4") + sys.toA2AMessage() } } } diff --git a/prompt/prompt-model/src/commonMain/kotlin/ai/koog/prompt/message/Message.kt b/prompt/prompt-model/src/commonMain/kotlin/ai/koog/prompt/message/Message.kt index ce944055d9..18f133c9f5 100644 --- a/prompt/prompt-model/src/commonMain/kotlin/ai/koog/prompt/message/Message.kt +++ b/prompt/prompt-model/src/commonMain/kotlin/ai/koog/prompt/message/Message.kt @@ -230,6 +230,12 @@ public sealed interface MessageMetaInfo { * to the current system time if not explicitly set. */ public val timestamp: Instant + + /** + * Free-form information associated with a message. + * Can be used to store custom metadata that doesn't fit into the standard fields. + */ + public val metadata: JsonObject? } /** @@ -243,7 +249,8 @@ public sealed interface MessageMetaInfo { */ @Serializable public data class RequestMetaInfo( - override val timestamp: Instant + override val timestamp: Instant, + override val metadata: JsonObject? = null ) : MessageMetaInfo { /** * Companion object for `RequestMetaInfo` that provides factory methods and utilities related to creating instances. @@ -291,7 +298,12 @@ public data class ResponseMetaInfo( public val totalTokensCount: Int? = null, public val inputTokensCount: Int? = null, public val outputTokensCount: Int? = null, + @Deprecated( + "additionalInfo is deprecated, use metadata instead", + ReplaceWith("metadata") + ) public val additionalInfo: Map = emptyMap(), + override val metadata: JsonObject? = null, ) : MessageMetaInfo { /** * Companion object for the ResponseMetaInfo class. @@ -302,7 +314,11 @@ public data class ResponseMetaInfo( * Creates a ResponseMetadata instance with a timestamp from the provided clock. * * @param clock The clock to use for generating the timestamp. - * @param tokensCount The number of tokens used in the response, or null if not available. + * @param totalTokensCount The total number of tokens involved in the response, including both input and output tokens. + * @param inputTokensCount The number of tokens used in the input. + * @param outputTokensCount The number of tokens generated in the output. + * @param additionalInfo Deprecated: use [metadata] instead. Additional metadata as a map of string keys to string values. + * @param metadata Additional metadata as a JSON object. * @return A new ResponseMetadata instance with the timestamp from the provided clock. */ public fun create( @@ -310,9 +326,10 @@ public data class ResponseMetaInfo( totalTokensCount: Int? = null, inputTokensCount: Int? = null, outputTokensCount: Int? = null, - additionalInfo: Map = emptyMap() + additionalInfo: Map = emptyMap(), + metadata: JsonObject? = null, ): ResponseMetaInfo = - ResponseMetaInfo(clock.now(), totalTokensCount, inputTokensCount, outputTokensCount, additionalInfo) + ResponseMetaInfo(clock.now(), totalTokensCount, inputTokensCount, outputTokensCount, additionalInfo, metadata) /** * An empty instance of the [ResponseMetaInfo] with the timestamp set to a distant past. From 98643aaf6db46109bb54de5f663e39da73be02f3 Mon Sep 17 00:00:00 2001 From: Andrey Bragin Date: Tue, 30 Sep 2025 21:36:12 +0200 Subject: [PATCH 38/52] [a2a] Add interactive joke agent example --- .../client/A2AClientJsonRpcIntegrationTest.kt | 5 +- .../server/A2AServerJsonRpcIntegrationTest.kt | 3 + .../agents/a2a/core/MessageA2AMetadata.kt | 11 +-- examples/simple-examples/CLAUDE.md | 92 +++++++++++++++++++ examples/simple-examples/build.gradle.kts | 17 ++++ .../ai/koog/agents/example/a2a/.gitkeep | 0 .../ai/koog/agents/example/a2a/joke/Client.kt | 84 +++++++++++++++++ .../example/a2a/joke/JokeAgentExecutor.kt | 75 +++++++++++++++ .../ai/koog/agents/example/a2a/joke/Server.kt | 79 ++++++++++++++++ 9 files changed, 359 insertions(+), 7 deletions(-) create mode 100644 examples/simple-examples/CLAUDE.md create mode 100644 examples/simple-examples/src/main/kotlin/ai/koog/agents/example/a2a/.gitkeep create mode 100644 examples/simple-examples/src/main/kotlin/ai/koog/agents/example/a2a/joke/Client.kt create mode 100644 examples/simple-examples/src/main/kotlin/ai/koog/agents/example/a2a/joke/JokeAgentExecutor.kt create mode 100644 examples/simple-examples/src/main/kotlin/ai/koog/agents/example/a2a/joke/Server.kt diff --git a/a2a/a2a-client/src/jvmTest/kotlin/ai/koog/a2a/client/A2AClientJsonRpcIntegrationTest.kt b/a2a/a2a-client/src/jvmTest/kotlin/ai/koog/a2a/client/A2AClientJsonRpcIntegrationTest.kt index c0a850a882..68ee128fd2 100644 --- a/a2a/a2a-client/src/jvmTest/kotlin/ai/koog/a2a/client/A2AClientJsonRpcIntegrationTest.kt +++ b/a2a/a2a-client/src/jvmTest/kotlin/ai/koog/a2a/client/A2AClientJsonRpcIntegrationTest.kt @@ -11,6 +11,8 @@ import org.junit.jupiter.api.BeforeAll import org.junit.jupiter.api.TestInstance import org.junit.jupiter.api.condition.EnabledOnOs import org.junit.jupiter.api.condition.OS +import org.junit.jupiter.api.parallel.Execution +import org.junit.jupiter.api.parallel.ExecutionMode import org.testcontainers.containers.GenericContainer import org.testcontainers.containers.wait.strategy.Wait import org.testcontainers.junit.jupiter.Container @@ -25,6 +27,7 @@ import kotlin.time.Duration.Companion.minutes @TestInstance(TestInstance.Lifecycle.PER_CLASS) @Testcontainers @EnabledOnOs(OS.LINUX) +@Execution(ExecutionMode.SAME_THREAD, reason = "Working with the same instance of test server.") class A2AClientJsonRpcIntegrationTest : BaseA2AProtocolTest() { companion object { @Container @@ -34,7 +37,7 @@ class A2AClientJsonRpcIntegrationTest : BaseA2AProtocolTest() { .waitingFor(Wait.forListeningPort()) } - override val testTimeout = 2.minutes + override val testTimeout = 1.minutes private val httpClient = HttpClient { install(Logging) { diff --git a/a2a/a2a-server/src/jvmTest/kotlin/ai/koog/a2a/server/A2AServerJsonRpcIntegrationTest.kt b/a2a/a2a-server/src/jvmTest/kotlin/ai/koog/a2a/server/A2AServerJsonRpcIntegrationTest.kt index ebd57917ec..4c0c82617b 100644 --- a/a2a/a2a-server/src/jvmTest/kotlin/ai/koog/a2a/server/A2AServerJsonRpcIntegrationTest.kt +++ b/a2a/a2a-server/src/jvmTest/kotlin/ai/koog/a2a/server/A2AServerJsonRpcIntegrationTest.kt @@ -45,6 +45,8 @@ import kotlinx.coroutines.withContext import org.junit.jupiter.api.AfterAll import org.junit.jupiter.api.BeforeAll import org.junit.jupiter.api.TestInstance +import org.junit.jupiter.api.parallel.Execution +import org.junit.jupiter.api.parallel.ExecutionMode import java.net.ServerSocket import kotlin.test.BeforeTest import kotlin.test.Test @@ -59,6 +61,7 @@ import kotlin.uuid.Uuid */ @OptIn(ExperimentalUuidApi::class) @TestInstance(TestInstance.Lifecycle.PER_CLASS) +@Execution(ExecutionMode.SAME_THREAD, reason = "Working with the same instance of test server.") class A2AServerJsonRpcIntegrationTest : BaseA2AProtocolTest() { override val testTimeout = 2.minutes diff --git a/agents/agents-features/agents-features-a2a-core/src/commonMain/kotlin/ai/koog/agents/a2a/core/MessageA2AMetadata.kt b/agents/agents-features/agents-features-a2a-core/src/commonMain/kotlin/ai/koog/agents/a2a/core/MessageA2AMetadata.kt index e3220aaa69..877e6ee2d4 100644 --- a/agents/agents-features/agents-features-a2a-core/src/commonMain/kotlin/ai/koog/agents/a2a/core/MessageA2AMetadata.kt +++ b/agents/agents-features/agents-features-a2a-core/src/commonMain/kotlin/ai/koog/agents/a2a/core/MessageA2AMetadata.kt @@ -1,6 +1,5 @@ package ai.koog.agents.a2a.core -import ai.koog.prompt.message.Message import ai.koog.prompt.message.MessageMetaInfo import kotlinx.serialization.Serializable import kotlinx.serialization.json.JsonObject @@ -19,11 +18,11 @@ public const val MESSAGE_A2A_METADATA_KEY: String = "a2a_metadata" @Serializable public data class MessageA2AMetadata( val messageId: String, - val contextId: String?, - val taskId: String?, - val referenceTaskIds: List?, - val metadata: JsonObject?, - val extensions: List?, + val contextId: String? = null, + val taskId: String? = null, + val referenceTaskIds: List? = null, + val metadata: JsonObject? = null, + val extensions: List? = null, ) /** diff --git a/examples/simple-examples/CLAUDE.md b/examples/simple-examples/CLAUDE.md new file mode 100644 index 0000000000..d4d5a482c9 --- /dev/null +++ b/examples/simple-examples/CLAUDE.md @@ -0,0 +1,92 @@ +## Project Overview + +This is the **simple-examples** project within the Koog Framework repository. It contains executable examples demonstrating various AI agent patterns and capabilities, from basic calculator agents to advanced features like persistence, streaming, and agent-to-agent communication. + +## Environment Setup + +Examples require API keys for LLM providers. Configure them via: + +## Running Examples + +Each example has a dedicated Gradle task: + +```bash +# Run any example +./gradlew runExampleCalculator +./gradlew runExampleStreamingWithTools +./gradlew runExampleRoutingViaGraph + +# Build the project +./gradlew assemble + +# Run tests +./gradlew test +``` + +All available tasks follow the pattern `runExample*`. See `build.gradle.kts` for the complete list. + +## Project Architecture + +### Composite Build Setup + +This project uses Gradle composite build to depend on the parent Koog framework (`settings.gradle.kts:`). Changes in the main framework are immediately available without publishing. + +### Key Dependencies + +- `ai.koog:koog-agents` - Meta-module with core agent dependencies +- `ai.koog:koog-ktor` - Ktor server integration +- `ai.koog:agents-features-sql` - SQL persistence feature +- `ai.koog:agents-features-a2a-server` - Agent-to-agent server +- `ai.koog:agents-features-a2a-client` - Agent-to-agent client +- `ai.koog:agents-test` - Testing utilities (test scope) + +### Example Structure + +Examples are organized by capability: + +- **Core patterns**: `calculator/`, `guesser/`, `tone/` - Basic agent patterns with tools and strategies +- **Agent-to-agent (A2A)**: `a2a/` - Inter-agent communication examples +- **Advanced features**: `memory/`, `snapshot/`, `moderation/` - Memory, persistence, content filtering +- **External integration**: `websearch/`, `attachments/`, `client/` - External API integration +- **Structured output**: `structuredoutput/` - Schema-based output formatting and streaming +- **Banking routing**: `banking/` - Complex multi-agent routing patterns + +### Common Patterns + +1. **Tool Definition**: Use `@Tool` and `@LLMDescription` annotations on methods in classes extending `ToolSet` +2. **Strategy Creation**: Define agent behavior as state machine graphs using `strategy { }` DSL +3. **Agent Setup**: Combine `PromptExecutor`, strategy, `AIAgentConfig`, and `ToolRegistry` +4. **Event Handling**: Use `handleEvents { }` to observe tool calls, errors, and completion + +Example from `calculator/Calculator.kt`: +```kotlin +val toolRegistry = ToolRegistry { + tool(AskUser) + tools(CalculatorTools().asTools()) +} + +val agent = AIAgent( + promptExecutor = executor, + strategy = CalculatorStrategy.strategy, + agentConfig = agentConfig, + toolRegistry = toolRegistry +) { + handleEvents { /* ... */ } +} +``` + +## Common Development Tasks + +```bash +# Build without running tests +./gradlew assemble + +# Run a specific example +./gradlew runExampleCalculator + +# Run tests +./gradlew test + +# Clean build artifacts +./gradlew clean +``` diff --git a/examples/simple-examples/build.gradle.kts b/examples/simple-examples/build.gradle.kts index cc73f941c8..3e09e94640 100644 --- a/examples/simple-examples/build.gradle.kts +++ b/examples/simple-examples/build.gradle.kts @@ -17,6 +17,15 @@ dependencies { implementation("ai.koog:koog-ktor") //noinspection UseTomlInstead implementation("ai.koog:agents-features-sql") + //noinspection UseTomlInstead + implementation("ai.koog:agents-features-a2a-server") + //noinspection UseTomlInstead + implementation("ai.koog:agents-features-a2a-client") + //noinspection UseTomlInstead + implementation("ai.koog:a2a-transport-server-jsonrpc-http") + //noinspection UseTomlInstead + implementation("ai.koog:a2a-transport-client-jsonrpc-http") + //noinspection UseTomlInstead testImplementation("ai.koog:agents-test") @@ -106,3 +115,11 @@ registerRunExampleTask("runExampleFilePersistentAgent", "ai.koog.agents.example. registerRunExampleTask("runExampleSQLPersistentAgent", "ai.koog.agents.example.snapshot.sql.SQLPersistentAgentExample") registerRunExampleTask("runExampleWebSearchAgent", "ai.koog.agents.example.websearch.WebSearchAgentKt") registerRunExampleTask("runExampleStreamingWithTools", "ai.koog.agents.example.streaming.StreamingAgentWithToolsKt") + +/* + A2A examples +*/ + +// joke generation +registerRunExampleTask("runExampleJokeAgentServer", "ai.koog.agents.example.a2a.joke.ServerKt") +registerRunExampleTask("runExampleJokeAgentClient", "ai.koog.agents.example.a2a.joke.ClientKt") diff --git a/examples/simple-examples/src/main/kotlin/ai/koog/agents/example/a2a/.gitkeep b/examples/simple-examples/src/main/kotlin/ai/koog/agents/example/a2a/.gitkeep new file mode 100644 index 0000000000..e69de29bb2 diff --git a/examples/simple-examples/src/main/kotlin/ai/koog/agents/example/a2a/joke/Client.kt b/examples/simple-examples/src/main/kotlin/ai/koog/agents/example/a2a/joke/Client.kt new file mode 100644 index 0000000000..702fb50f49 --- /dev/null +++ b/examples/simple-examples/src/main/kotlin/ai/koog/agents/example/a2a/joke/Client.kt @@ -0,0 +1,84 @@ +@file:OptIn(ExperimentalUuidApi::class) + +package ai.koog.agents.example.a2a.joke + +import ai.koog.a2a.client.A2AClient +import ai.koog.a2a.client.UrlAgentCardResolver +import ai.koog.a2a.model.Message +import ai.koog.a2a.model.MessageSendParams +import ai.koog.a2a.model.Role +import ai.koog.a2a.model.TextPart +import ai.koog.a2a.transport.Request +import ai.koog.a2a.transport.client.jsonrpc.http.HttpJSONRPCClientTransport +import ai.koog.agents.a2a.core.toKoogMessage +import kotlin.uuid.ExperimentalUuidApi +import kotlin.uuid.Uuid + +private const val BRIGHT_CYAN = "\u001B[1;36m" +private const val YELLOW = "\u001B[33m" +private const val BRIGHT_MAGENTA = "\u001B[1;35m" +private const val RED = "\u001B[31m" +private const val RESET = "\u001B[0m" + +@OptIn(ExperimentalUuidApi::class) +suspend fun main() { + println() + println("${YELLOW}Starting Joke Generator A2A Client$RESET\n") + + // Set up the HTTP JSON-RPC transport + val transport = HttpJSONRPCClientTransport( + url = "http://localhost:9998$JOKE_GENERATOR_AGENT_PATH" + ) + + // Set up the agent card resolver + val agentCardResolver = UrlAgentCardResolver( + baseUrl = "http://localhost:9998", + path = JOKE_GENERATOR_AGENT_CARD_PATH + ) + + // Create the A2A client + val client = A2AClient( + transport = transport, + agentCardResolver = agentCardResolver + ) + + // Connect and fetch agent card + client.connect() + val agentCard = client.cachedAgentCard() + println("${YELLOW}Connected to agent:$RESET\n${agentCard.name} (${agentCard.description})\n") + + // Read context ID + println("${BRIGHT_CYAN}Context ID (which chat to start/continue):$RESET") + val contextId = readln() + println() + + // Start chat loop + while (true) { + println("${BRIGHT_CYAN}Request (/q to quit):$RESET") + val request = readln() + println() + + if (request == "/q") { + break + } + + val message = Message( + messageId = Uuid.random().toString(), + role = Role.User, + parts = listOf(TextPart(request)), + contextId = contextId + ) + + val response = client.sendMessage( + Request(MessageSendParams(message = message)) + ) + + val reply = (response.data as Message).toKoogMessage().content + println("${BRIGHT_MAGENTA}Agent response:${RESET}\n$reply\n") + } + + println("${RED}Conversation complete!$RESET") + + // Clean up + transport.close() +} diff --git a/examples/simple-examples/src/main/kotlin/ai/koog/agents/example/a2a/joke/JokeAgentExecutor.kt b/examples/simple-examples/src/main/kotlin/ai/koog/agents/example/a2a/joke/JokeAgentExecutor.kt new file mode 100644 index 0000000000..5e64cdf72f --- /dev/null +++ b/examples/simple-examples/src/main/kotlin/ai/koog/agents/example/a2a/joke/JokeAgentExecutor.kt @@ -0,0 +1,75 @@ +package ai.koog.agents.example.a2a.joke + +import ai.koog.a2a.exceptions.A2AUnsupportedOperationException +import ai.koog.a2a.model.MessageSendParams +import ai.koog.a2a.server.agent.AgentExecutor +import ai.koog.a2a.server.session.RequestContext +import ai.koog.a2a.server.session.SessionEventProcessor +import ai.koog.agents.a2a.core.MessageA2AMetadata +import ai.koog.agents.a2a.core.toA2AMessage +import ai.koog.agents.a2a.core.toKoogMessage +import ai.koog.agents.example.ApiKeyService +import ai.koog.prompt.dsl.prompt +import ai.koog.prompt.executor.clients.anthropic.AnthropicLLMClient +import ai.koog.prompt.executor.clients.anthropic.AnthropicModels +import ai.koog.prompt.executor.clients.google.GoogleLLMClient +import ai.koog.prompt.executor.clients.openai.OpenAILLMClient +import ai.koog.prompt.executor.llms.MultiLLMPromptExecutor +import ai.koog.prompt.llm.LLMProvider +import ai.koog.prompt.message.Message +import kotlin.uuid.ExperimentalUuidApi +import kotlin.uuid.Uuid + +class JokeAgentExecutor : AgentExecutor { + private val promptExecutor = MultiLLMPromptExecutor( + LLMProvider.OpenAI to OpenAILLMClient(ApiKeyService.openAIApiKey), + LLMProvider.Anthropic to AnthropicLLMClient(ApiKeyService.anthropicApiKey), + LLMProvider.Google to GoogleLLMClient(ApiKeyService.googleApiKey), + ) + + @OptIn(ExperimentalUuidApi::class) + override suspend fun execute( + context: RequestContext, + eventProcessor: SessionEventProcessor + ) { + val userMessage = context.params.message + + if (context.task != null || !userMessage.referenceTaskIds.isNullOrEmpty()) { + throw A2AUnsupportedOperationException("This agent doesn't support tasks") + } + + // Save incoming message to the current context + context.messageStorage.save(userMessage) + + // Load all messages from the current context + val contextMessages = context.messageStorage.getAll().map { it.toKoogMessage() } + + val prompt = prompt("joke-generation") { + system { + +"You are an assistant helping user to generate jokes" + } + + // Append current message context + messages(contextMessages) + } + + // Get a response from the LLM + val responseMessage = promptExecutor.execute(prompt, AnthropicModels.Opus_4_1) + .single() + .let { message -> + message as? Message.Assistant ?: throw IllegalStateException("Unexpected message type: $message") + } + .toA2AMessage( + a2aMetadata = MessageA2AMetadata( + messageId = Uuid.random().toString(), + contextId = context.contextId, + ) + ) + + // Save the response to the current context + context.messageStorage.save(responseMessage) + + // Reply with message + eventProcessor.sendMessage(responseMessage) + } +} diff --git a/examples/simple-examples/src/main/kotlin/ai/koog/agents/example/a2a/joke/Server.kt b/examples/simple-examples/src/main/kotlin/ai/koog/agents/example/a2a/joke/Server.kt new file mode 100644 index 0000000000..a4eeeb476a --- /dev/null +++ b/examples/simple-examples/src/main/kotlin/ai/koog/agents/example/a2a/joke/Server.kt @@ -0,0 +1,79 @@ +package ai.koog.agents.example.a2a.joke + +import ai.koog.a2a.model.AgentCapabilities +import ai.koog.a2a.model.AgentCard +import ai.koog.a2a.model.AgentInterface +import ai.koog.a2a.model.AgentSkill +import ai.koog.a2a.model.TransportProtocol +import ai.koog.a2a.server.A2AServer +import ai.koog.a2a.transport.server.jsonrpc.http.HttpJSONRPCServerTransport +import io.github.oshai.kotlinlogging.KotlinLogging +import io.ktor.server.cio.CIO + +private val logger = KotlinLogging.logger {} + +const val JOKE_GENERATOR_AGENT_PATH = "/joke-generator-agent" +const val JOKE_GENERATOR_AGENT_CARD_PATH = "$JOKE_GENERATOR_AGENT_PATH/agent-card.json" + +suspend fun main() { + logger.info { "Starting Joke A2A Agent on http://localhost:9998" } + + // Create agent card with capabilities + val agentCard = AgentCard( + protocolVersion = "0.3.0", + name = "Joke Generator 3000", + description = "A helpful AI agent that generates jokes based on user requests", + version = "1.0.0", + url = "http://localhost:9998$JOKE_GENERATOR_AGENT_PATH", + preferredTransport = TransportProtocol.JSONRPC, + additionalInterfaces = listOf( + AgentInterface( + url = "http://localhost:9998$JOKE_GENERATOR_AGENT_PATH", + transport = TransportProtocol.JSONRPC, + ) + ), + capabilities = AgentCapabilities( + streaming = false, + pushNotifications = false, + stateTransitionHistory = false, + ), + defaultInputModes = listOf("text"), + defaultOutputModes = listOf("text"), + skills = listOf( + AgentSkill( + id = "joke_generation", + name = "Joke Generation", + description = "Generates humorous jokes on various topics", + examples = listOf( + "Tell me a joke", + "Generate a funny joke about programming", + "Make me laugh with a dad joke" + ), + tags = listOf("humor", "jokes", "entertainment") + ) + ), + supportsAuthenticatedExtendedCard = false + ) + + // Create agent executor + val agentExecutor = JokeAgentExecutor() + + // Create A2A server + val a2aServer = A2AServer( + agentExecutor = agentExecutor, + agentCard = agentCard, + ) + + // Create and start server transport + val serverTransport = HttpJSONRPCServerTransport(a2aServer) + + logger.info { "Joke Generator Agent ready at http://localhost:9998/$JOKE_GENERATOR_AGENT_PATH" } + serverTransport.start( + engineFactory = CIO, + port = 9998, + path = JOKE_GENERATOR_AGENT_PATH, + wait = true, // Block until server stops + agentCard = agentCard, + agentCardPath = JOKE_GENERATOR_AGENT_CARD_PATH + ) +} From b7cd3b9faea48ac9762cda3803aee5cda99d0e1d Mon Sep 17 00:00:00 2001 From: Andrey Bragin Date: Wed, 1 Oct 2025 16:30:23 +0200 Subject: [PATCH 39/52] [agents] Fix EventHandlerTest --- .../eventHandler/feature/EventHandlerTest.kt | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/agents/agents-features/agents-features-event-handler/src/jvmTest/kotlin/ai/koog/agents/features/eventHandler/feature/EventHandlerTest.kt b/agents/agents-features/agents-features-event-handler/src/jvmTest/kotlin/ai/koog/agents/features/eventHandler/feature/EventHandlerTest.kt index f41a6ea754..d4443ea9ef 100644 --- a/agents/agents-features/agents-features-event-handler/src/jvmTest/kotlin/ai/koog/agents/features/eventHandler/feature/EventHandlerTest.kt +++ b/agents/agents-features/agents-features-event-handler/src/jvmTest/kotlin/ai/koog/agents/features/eventHandler/feature/EventHandlerTest.kt @@ -134,7 +134,7 @@ class EventHandlerTest { "role: ${Message.Role.User}, message: $testLLMResponse" + "}], temperature: $temperature, model: ${model.eventString}, tools: [], responses: [role: ${Message.Role.Assistant}, message: Default test response])", "OnNodeExecutionCompleted (run id: $runId, node: test LLM call, input: $testLLMResponse, output: " + - "Assistant(content=Default test response, metaInfo=ResponseMetaInfo(timestamp=$ts, totalTokensCount=null, inputTokensCount=null, outputTokensCount=null, additionalInfo={}), attachments=[], finishReason=null))", + "Assistant(content=Default test response, metaInfo=ResponseMetaInfo(timestamp=$ts, totalTokensCount=null, inputTokensCount=null, outputTokensCount=null, additionalInfo={}, metadata=null), attachments=[], finishReason=null))", "OnNodeExecutionStarting (run id: $runId, node: __finish__, input: $agentResult)", "OnNodeExecutionCompleted (run id: $runId, node: __finish__, input: $agentResult, output: $agentResult)", "OnStrategyCompleted (run id: $runId, strategy: $strategyName, result: $agentResult)", @@ -224,13 +224,13 @@ class EventHandlerTest { "role: ${Message.Role.User}, message: $userPrompt" + "}], temperature: $temperature, model: ${model.eventString}, tools: [${dummyTool.name}], responses: [role: ${Message.Role.Tool}, message: {\"dummy\":\"test\"}])", "OnNodeExecutionCompleted (run id: $runId, node: test-llm-call, input: $userPrompt, output: " + - "Call(id=null, tool=${dummyTool.name}, content={\"dummy\":\"test\"}, metaInfo=ResponseMetaInfo(timestamp=2023-01-01T00:00:00Z, totalTokensCount=null, inputTokensCount=null, outputTokensCount=null, additionalInfo={})))", + "Call(id=null, tool=${dummyTool.name}, content={\"dummy\":\"test\"}, metaInfo=ResponseMetaInfo(timestamp=2023-01-01T00:00:00Z, totalTokensCount=null, inputTokensCount=null, outputTokensCount=null, additionalInfo={}, metadata=null)))", "OnNodeExecutionStarting (run id: $runId, node: test-tool-call, input: " + - "Call(id=null, tool=${dummyTool.name}, content={\"dummy\":\"test\"}, metaInfo=ResponseMetaInfo(timestamp=2023-01-01T00:00:00Z, totalTokensCount=null, inputTokensCount=null, outputTokensCount=null, additionalInfo={})))", + "Call(id=null, tool=${dummyTool.name}, content={\"dummy\":\"test\"}, metaInfo=ResponseMetaInfo(timestamp=2023-01-01T00:00:00Z, totalTokensCount=null, inputTokensCount=null, outputTokensCount=null, additionalInfo={}, metadata=null)))", "OnToolCallStarting (run id: $runId, tool: ${dummyTool.name}, args: Args(dummy=test))", "OnToolCallCompleted (run id: $runId, tool: ${dummyTool.name}, args: Args(dummy=test), result: ${dummyTool.result})", "OnNodeExecutionCompleted (run id: $runId, node: test-tool-call, input: " + - "Call(id=null, tool=${dummyTool.name}, content={\"dummy\":\"test\"}, metaInfo=ResponseMetaInfo(timestamp=2023-01-01T00:00:00Z, totalTokensCount=null, inputTokensCount=null, outputTokensCount=null, additionalInfo={})), output: ReceivedToolResult(id=null, tool=${dummyTool.name}, content=${dummyTool.result}, result=${dummyTool.result}))", + "Call(id=null, tool=${dummyTool.name}, content={\"dummy\":\"test\"}, metaInfo=ResponseMetaInfo(timestamp=2023-01-01T00:00:00Z, totalTokensCount=null, inputTokensCount=null, outputTokensCount=null, additionalInfo={}, metadata=null)), output: ReceivedToolResult(id=null, tool=${dummyTool.name}, content=${dummyTool.result}, result=${dummyTool.result}))", "OnNodeExecutionStarting (run id: $runId, node: test-node-llm-send-tool-result, input: ReceivedToolResult(id=null, tool=${dummyTool.name}, content=${dummyTool.result}, result=${dummyTool.result}))", "OnLLMCallStarting (run id: $runId, prompt: id: $promptId, messages: [{" + "role: ${Message.Role.System}, message: $systemPrompt, " + @@ -249,7 +249,7 @@ class EventHandlerTest { "role: ${Message.Role.Tool}, message: ${dummyTool.result}" + "}], temperature: $temperature, model: openai:gpt-4o, tools: [${dummyTool.name}], responses: [role: ${Message.Role.Assistant}, message: Return test result])", "OnNodeExecutionCompleted (run id: $runId, node: test-node-llm-send-tool-result, input: " + - "ReceivedToolResult(id=null, tool=${dummyTool.name}, content=${dummyTool.result}, result=${dummyTool.result}), output: Assistant(content=Return test result, metaInfo=ResponseMetaInfo(timestamp=2023-01-01T00:00:00Z, totalTokensCount=null, inputTokensCount=null, outputTokensCount=null, additionalInfo={}), attachments=[], finishReason=null))", + "ReceivedToolResult(id=null, tool=${dummyTool.name}, content=${dummyTool.result}, result=${dummyTool.result}), output: Assistant(content=Return test result, metaInfo=ResponseMetaInfo(timestamp=2023-01-01T00:00:00Z, totalTokensCount=null, inputTokensCount=null, outputTokensCount=null, additionalInfo={}, metadata=null), attachments=[], finishReason=null))", "OnNodeExecutionStarting (run id: $runId, node: __finish__, input: $mockResponse)", "OnNodeExecutionCompleted (run id: $runId, node: __finish__, input: $mockResponse, output: $mockResponse)", "OnStrategyCompleted (run id: $runId, strategy: $strategyName, result: $mockResponse)", @@ -327,7 +327,7 @@ class EventHandlerTest { "role: ${Message.Role.User}, message: $testLLMResponse" + "}], temperature: $temperature, model: ${model.eventString}, tools: [${toolRegistry.tools.joinToString { it.name }}], responses: [role: ${Message.Role.Assistant}, message: Default test response])", "OnNodeExecutionCompleted (run id: $runId, node: test LLM call, input: $testLLMResponse, output: " + - "Assistant(content=Default test response, metaInfo=ResponseMetaInfo(timestamp=2023-01-01T00:00:00Z, totalTokensCount=null, inputTokensCount=null, outputTokensCount=null, additionalInfo={}), attachments=[], finishReason=null))", + "Assistant(content=Default test response, metaInfo=ResponseMetaInfo(timestamp=2023-01-01T00:00:00Z, totalTokensCount=null, inputTokensCount=null, outputTokensCount=null, additionalInfo={}, metadata=null), attachments=[], finishReason=null))", "OnNodeExecutionStarting (run id: $runId, node: test LLM call with tools, input: $llmCallWithToolsResponse)", "OnLLMCallStarting (run id: $runId, prompt: id: $promptId, messages: [{" + "role: ${Message.Role.System}, message: $systemPrompt, " + @@ -346,7 +346,7 @@ class EventHandlerTest { "role: ${Message.Role.User}, message: $llmCallWithToolsResponse" + "}], temperature: $temperature, model: openai:gpt-4o, tools: [${toolRegistry.tools.joinToString { it.name }}], responses: [role: ${Message.Role.Assistant}, message: Default test response])", "OnNodeExecutionCompleted (run id: $runId, node: test LLM call with tools, input: $llmCallWithToolsResponse, output: " + - "Assistant(content=Default test response, metaInfo=ResponseMetaInfo(timestamp=2023-01-01T00:00:00Z, totalTokensCount=null, inputTokensCount=null, outputTokensCount=null, additionalInfo={}), attachments=[], finishReason=null))", + "Assistant(content=Default test response, metaInfo=ResponseMetaInfo(timestamp=2023-01-01T00:00:00Z, totalTokensCount=null, inputTokensCount=null, outputTokensCount=null, additionalInfo={}, metadata=null), attachments=[], finishReason=null))", "OnNodeExecutionStarting (run id: $runId, node: __finish__, input: $agentResult)", "OnNodeExecutionCompleted (run id: $runId, node: __finish__, input: $agentResult, output: $agentResult)", "OnStrategyCompleted (run id: $runId, strategy: $strategyName, result: $agentResult)", From 9c4799fdf518f1e71766567aef54c92e6cd30d7d Mon Sep 17 00:00:00 2001 From: Andrey Bragin Date: Wed, 1 Oct 2025 01:18:29 +0200 Subject: [PATCH 40/52] [agents] Update simple joke a2a example --- examples/simple-examples/build.gradle.kts | 4 ++-- .../agents/example/a2a/{joke => simplejoke}/Client.kt | 2 +- .../agents/example/a2a/{joke => simplejoke}/Server.kt | 4 ++-- .../SimpleJokeAgentExecutor.kt} | 9 ++++++--- 4 files changed, 11 insertions(+), 8 deletions(-) rename examples/simple-examples/src/main/kotlin/ai/koog/agents/example/a2a/{joke => simplejoke}/Client.kt (98%) rename examples/simple-examples/src/main/kotlin/ai/koog/agents/example/a2a/{joke => simplejoke}/Server.kt (96%) rename examples/simple-examples/src/main/kotlin/ai/koog/agents/example/a2a/{joke/JokeAgentExecutor.kt => simplejoke/SimpleJokeAgentExecutor.kt} (92%) diff --git a/examples/simple-examples/build.gradle.kts b/examples/simple-examples/build.gradle.kts index 3e09e94640..5d66734641 100644 --- a/examples/simple-examples/build.gradle.kts +++ b/examples/simple-examples/build.gradle.kts @@ -121,5 +121,5 @@ registerRunExampleTask("runExampleStreamingWithTools", "ai.koog.agents.example.s */ // joke generation -registerRunExampleTask("runExampleJokeAgentServer", "ai.koog.agents.example.a2a.joke.ServerKt") -registerRunExampleTask("runExampleJokeAgentClient", "ai.koog.agents.example.a2a.joke.ClientKt") +registerRunExampleTask("runExampleSimpleJokeAgentServer", "ai.koog.agents.example.a2a.simplejoke.ServerKt") +registerRunExampleTask("runExampleSimpleJokeAgentClient", "ai.koog.agents.example.a2a.simplejoke.ClientKt") diff --git a/examples/simple-examples/src/main/kotlin/ai/koog/agents/example/a2a/joke/Client.kt b/examples/simple-examples/src/main/kotlin/ai/koog/agents/example/a2a/simplejoke/Client.kt similarity index 98% rename from examples/simple-examples/src/main/kotlin/ai/koog/agents/example/a2a/joke/Client.kt rename to examples/simple-examples/src/main/kotlin/ai/koog/agents/example/a2a/simplejoke/Client.kt index 702fb50f49..41e4b2022e 100644 --- a/examples/simple-examples/src/main/kotlin/ai/koog/agents/example/a2a/joke/Client.kt +++ b/examples/simple-examples/src/main/kotlin/ai/koog/agents/example/a2a/simplejoke/Client.kt @@ -1,6 +1,6 @@ @file:OptIn(ExperimentalUuidApi::class) -package ai.koog.agents.example.a2a.joke +package ai.koog.agents.example.a2a.simplejoke import ai.koog.a2a.client.A2AClient import ai.koog.a2a.client.UrlAgentCardResolver diff --git a/examples/simple-examples/src/main/kotlin/ai/koog/agents/example/a2a/joke/Server.kt b/examples/simple-examples/src/main/kotlin/ai/koog/agents/example/a2a/simplejoke/Server.kt similarity index 96% rename from examples/simple-examples/src/main/kotlin/ai/koog/agents/example/a2a/joke/Server.kt rename to examples/simple-examples/src/main/kotlin/ai/koog/agents/example/a2a/simplejoke/Server.kt index a4eeeb476a..24a7450703 100644 --- a/examples/simple-examples/src/main/kotlin/ai/koog/agents/example/a2a/joke/Server.kt +++ b/examples/simple-examples/src/main/kotlin/ai/koog/agents/example/a2a/simplejoke/Server.kt @@ -1,4 +1,4 @@ -package ai.koog.agents.example.a2a.joke +package ai.koog.agents.example.a2a.simplejoke import ai.koog.a2a.model.AgentCapabilities import ai.koog.a2a.model.AgentCard @@ -56,7 +56,7 @@ suspend fun main() { ) // Create agent executor - val agentExecutor = JokeAgentExecutor() + val agentExecutor = SimpleJokeAgentExecutor() // Create A2A server val a2aServer = A2AServer( diff --git a/examples/simple-examples/src/main/kotlin/ai/koog/agents/example/a2a/joke/JokeAgentExecutor.kt b/examples/simple-examples/src/main/kotlin/ai/koog/agents/example/a2a/simplejoke/SimpleJokeAgentExecutor.kt similarity index 92% rename from examples/simple-examples/src/main/kotlin/ai/koog/agents/example/a2a/joke/JokeAgentExecutor.kt rename to examples/simple-examples/src/main/kotlin/ai/koog/agents/example/a2a/simplejoke/SimpleJokeAgentExecutor.kt index 5e64cdf72f..30212ff03f 100644 --- a/examples/simple-examples/src/main/kotlin/ai/koog/agents/example/a2a/joke/JokeAgentExecutor.kt +++ b/examples/simple-examples/src/main/kotlin/ai/koog/agents/example/a2a/simplejoke/SimpleJokeAgentExecutor.kt @@ -1,4 +1,4 @@ -package ai.koog.agents.example.a2a.joke +package ai.koog.agents.example.a2a.simplejoke import ai.koog.a2a.exceptions.A2AUnsupportedOperationException import ai.koog.a2a.model.MessageSendParams @@ -20,7 +20,10 @@ import ai.koog.prompt.message.Message import kotlin.uuid.ExperimentalUuidApi import kotlin.uuid.Uuid -class JokeAgentExecutor : AgentExecutor { +/** + * This is a simple example of an agent executor that wraps LLM calls using prompt executor to generate jokes. + */ +class SimpleJokeAgentExecutor : AgentExecutor { private val promptExecutor = MultiLLMPromptExecutor( LLMProvider.OpenAI to OpenAILLMClient(ApiKeyService.openAIApiKey), LLMProvider.Anthropic to AnthropicLLMClient(ApiKeyService.anthropicApiKey), @@ -54,7 +57,7 @@ class JokeAgentExecutor : AgentExecutor { } // Get a response from the LLM - val responseMessage = promptExecutor.execute(prompt, AnthropicModels.Opus_4_1) + val responseMessage = promptExecutor.execute(prompt, AnthropicModels.Sonnet_4) .single() .let { message -> message as? Message.Assistant ?: throw IllegalStateException("Unexpected message type: $message") From d9ed9204f251493814c41106be4f7993cf109da5 Mon Sep 17 00:00:00 2001 From: Andrey Bragin Date: Wed, 1 Oct 2025 03:47:15 +0200 Subject: [PATCH 41/52] [a2a] Add stress tests, fix race conditions, and refine testing approach --- a2a/a2a-client/build.gradle.kts | 1 + .../kotlin/ai/koog/a2a/client/A2AClient.kt | 15 +- .../client/A2AClientJsonRpcIntegrationTest.kt | 43 +++- .../kotlin/ai/koog/a2a/model/AgentCard.kt | 4 +- .../kotlin/ai/koog/a2a/server/A2AServer.kt | 42 +++- .../ai/koog/a2a/server/session/Session.kt | 20 +- .../server/session/SessionEventProcessor.kt | 29 ++- .../koog/a2a/server/session/SessionManager.kt | 20 +- .../kotlin/ai/koog/a2a/server/.gitkeep | 0 .../A2AServerJsonRpcIntegrationTest.kt | 216 ++++-------------- .../jsonrpc/BaseA2AServerJsonRpcTest.kt | 184 +++++++++++++++ .../StressA2AServerJsonRpcIntegrationTest.kt | 47 ++++ .../ai/koog/a2a/test/BaseA2AProtocolTest.kt | 25 +- .../agents-features-sql/build.gradle.kts | 1 + .../H2PersistencyStorageProviderTest.kt | 6 +- .../MySQLPersistencyStorageProviderTest.kt | 6 +- .../PostgresPersistencyStorageProviderTest.kt | 6 +- test-utils/build.gradle.kts | 1 + .../test/utils/DockerAvailableCondition.kt | 20 ++ 19 files changed, 440 insertions(+), 246 deletions(-) create mode 100644 a2a/a2a-server/src/jvmTest/kotlin/ai/koog/a2a/server/.gitkeep rename a2a/a2a-server/src/jvmTest/kotlin/ai/koog/a2a/server/{ => jsonrpc}/A2AServerJsonRpcIntegrationTest.kt (54%) create mode 100644 a2a/a2a-server/src/jvmTest/kotlin/ai/koog/a2a/server/jsonrpc/BaseA2AServerJsonRpcTest.kt create mode 100644 a2a/a2a-server/src/jvmTest/kotlin/ai/koog/a2a/server/jsonrpc/StressA2AServerJsonRpcIntegrationTest.kt create mode 100644 test-utils/src/jvmMain/kotlin/ai/koog/test/utils/DockerAvailableCondition.kt diff --git a/a2a/a2a-client/build.gradle.kts b/a2a/a2a-client/build.gradle.kts index efc7cd1a31..a0696d6b25 100644 --- a/a2a/a2a-client/build.gradle.kts +++ b/a2a/a2a-client/build.gradle.kts @@ -37,6 +37,7 @@ kotlin { implementation(kotlin("test-junit5")) implementation(project(":a2a:a2a-test")) implementation(project(":a2a:a2a-transport:a2a-transport-client-jsonrpc-http")) + implementation(project(":test-utils")) implementation(libs.ktor.client.cio) implementation(libs.ktor.client.logging) diff --git a/a2a/a2a-client/src/commonMain/kotlin/ai/koog/a2a/client/A2AClient.kt b/a2a/a2a-client/src/commonMain/kotlin/ai/koog/a2a/client/A2AClient.kt index ce991559b8..7e4da3a6eb 100644 --- a/a2a/a2a-client/src/commonMain/kotlin/ai/koog/a2a/client/A2AClient.kt +++ b/a2a/a2a-client/src/commonMain/kotlin/ai/koog/a2a/client/A2AClient.kt @@ -15,17 +15,18 @@ import ai.koog.a2a.transport.ClientTransport import ai.koog.a2a.transport.Request import ai.koog.a2a.transport.Response import kotlinx.coroutines.flow.Flow -import kotlin.concurrent.Volatile +import kotlin.concurrent.atomics.AtomicReference +import kotlin.concurrent.atomics.ExperimentalAtomicApi /** * A2A client responsible for sending requests to A2A server. */ +@OptIn(ExperimentalAtomicApi::class) public open class A2AClient( private val transport: ClientTransport, private val agentCardResolver: AgentCardResolver, ) { - @Volatile - protected var agentCard: AgentCard? = null + protected var agentCard: AtomicReference = AtomicReference(null) /** * Performs initialization logic. @@ -41,7 +42,7 @@ public open class A2AClient( */ public open suspend fun getAgentCard(): AgentCard { return agentCardResolver.resolve().also { - agentCard = it + agentCard.exchange(it) } } @@ -51,7 +52,7 @@ public open class A2AClient( * @throws [IllegalStateException] if it's not initialized */ public open fun cachedAgentCard(): AgentCard { - return checkNotNull(agentCard) { "Agent card is not initialized." } + return checkNotNull(agentCard.load()) { "Agent card is not initialized." } } /** @@ -64,12 +65,12 @@ public open class A2AClient( request: Request, ctx: ClientCallContext = ClientCallContext.Default ): Response { - check(getAgentCard().supportsAuthenticatedExtendedCard == true) { + check(cachedAgentCard().supportsAuthenticatedExtendedCard == true) { "Agent card reports that authenticated extended agent card is not supported." } return transport.getAuthenticatedExtendedAgentCard(request, ctx).also { - agentCard = it.data + agentCard.exchange(it.data) } } diff --git a/a2a/a2a-client/src/jvmTest/kotlin/ai/koog/a2a/client/A2AClientJsonRpcIntegrationTest.kt b/a2a/a2a-client/src/jvmTest/kotlin/ai/koog/a2a/client/A2AClientJsonRpcIntegrationTest.kt index 68ee128fd2..2764d58f4a 100644 --- a/a2a/a2a-client/src/jvmTest/kotlin/ai/koog/a2a/client/A2AClientJsonRpcIntegrationTest.kt +++ b/a2a/a2a-client/src/jvmTest/kotlin/ai/koog/a2a/client/A2AClientJsonRpcIntegrationTest.kt @@ -2,6 +2,7 @@ package ai.koog.a2a.client import ai.koog.a2a.test.BaseA2AProtocolTest import ai.koog.a2a.transport.client.jsonrpc.http.HttpJSONRPCClientTransport +import ai.koog.test.utils.DockerAvailableCondition import io.ktor.client.HttpClient import io.ktor.client.plugins.logging.LogLevel import io.ktor.client.plugins.logging.Logging @@ -9,15 +10,15 @@ import kotlinx.coroutines.test.runTest import org.junit.jupiter.api.AfterAll import org.junit.jupiter.api.BeforeAll import org.junit.jupiter.api.TestInstance -import org.junit.jupiter.api.condition.EnabledOnOs -import org.junit.jupiter.api.condition.OS +import org.junit.jupiter.api.extension.ExtendWith import org.junit.jupiter.api.parallel.Execution import org.junit.jupiter.api.parallel.ExecutionMode import org.testcontainers.containers.GenericContainer import org.testcontainers.containers.wait.strategy.Wait import org.testcontainers.junit.jupiter.Container import org.testcontainers.junit.jupiter.Testcontainers -import kotlin.time.Duration.Companion.minutes +import kotlin.test.Test +import kotlin.time.Duration.Companion.seconds /** * Integration test class for testing the JSON-RPC HTTP communication in the A2A client context. @@ -26,7 +27,7 @@ import kotlin.time.Duration.Companion.minutes */ @TestInstance(TestInstance.Lifecycle.PER_CLASS) @Testcontainers -@EnabledOnOs(OS.LINUX) +@ExtendWith(DockerAvailableCondition::class) @Execution(ExecutionMode.SAME_THREAD, reason = "Working with the same instance of test server.") class A2AClientJsonRpcIntegrationTest : BaseA2AProtocolTest() { companion object { @@ -37,7 +38,7 @@ class A2AClientJsonRpcIntegrationTest : BaseA2AProtocolTest() { .waitingFor(Wait.forListeningPort()) } - override val testTimeout = 1.minutes + override val testTimeout = 10.seconds private val httpClient = HttpClient { install(Logging) { @@ -74,4 +75,36 @@ class A2AClientJsonRpcIntegrationTest : BaseA2AProtocolTest() { fun tearDown() = runTest { transport.close() } + + @Test + override fun `test get agent card`() = + super.`test get agent card`() + + @Test + override fun `test get authenticated extended agent card`() = + super.`test get authenticated extended agent card`() + + @Test + override fun `test send message`() = + super.`test send message`() + + @Test + override fun `test send message streaming`() = + super.`test send message streaming`() + + @Test + override fun `test get task`() = + super.`test get task`() + + @Test + override fun `test cancel task`() = + super.`test cancel task`() + + @Test + override fun `test resubscribe task`() = + super.`test resubscribe task`() + + @Test + override fun `test push notification configs`() = + super.`test push notification configs`() } diff --git a/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/model/AgentCard.kt b/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/model/AgentCard.kt index 0f36f87973..18972d7a90 100644 --- a/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/model/AgentCard.kt +++ b/a2a/a2a-core/src/commonMain/kotlin/ai/koog/a2a/model/AgentCard.kt @@ -110,7 +110,9 @@ public data class AgentCard( @JvmInline @Serializable public value class TransportProtocol(public val value: String) { - @Suppress("MissingKDocForPublicAPI") + /** + * List of known transport protocols. + */ public companion object { /** * JSON-RPC protocol. diff --git a/a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/A2AServer.kt b/a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/A2AServer.kt index f045bcad05..ea7e1b0265 100644 --- a/a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/A2AServer.kt +++ b/a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/A2AServer.kt @@ -39,16 +39,20 @@ import ai.koog.a2a.transport.Response import ai.koog.a2a.transport.ServerCallContext import ai.koog.a2a.utils.KeyedMutex import ai.koog.a2a.utils.withLock +import io.github.oshai.kotlinlogging.KotlinLogging import kotlinx.coroutines.CancellationException +import kotlinx.coroutines.CompletableJob import kotlinx.coroutines.CoroutineScope +import kotlinx.coroutines.Job import kotlinx.coroutines.SupervisorJob import kotlinx.coroutines.cancel import kotlinx.coroutines.flow.Flow import kotlinx.coroutines.flow.channelFlow -import kotlinx.coroutines.flow.first +import kotlinx.coroutines.flow.firstOrNull import kotlinx.coroutines.flow.flow -import kotlinx.coroutines.flow.last +import kotlinx.coroutines.flow.lastOrNull import kotlinx.coroutines.flow.map +import kotlinx.coroutines.flow.onStart import kotlinx.coroutines.launch /** @@ -321,6 +325,10 @@ public open class A2AServer( protected val idGenerator: IdGenerator = UuidIdGenerator, protected val coroutineScope: CoroutineScope = CoroutineScope(SupervisorJob()), ) : RequestHandler { + private companion object { + private val logger = KotlinLogging.logger {} + } + /** * Mutex for locking specific tasks by their IDs. */ @@ -374,7 +382,7 @@ public open class A2AServer( val taskId = message.taskId ?: idGenerator.generateTaskId(message) - val session = tasksMutex.withLock(taskId) { + val (session, monitoringStarted) = tasksMutex.withLock(taskId) { // If there's a currently running session for the same task, wait for it to finish. sessionManager.getSession(taskId)?.join() @@ -412,22 +420,36 @@ public open class A2AServer( eventProcessor = eventProcessor, ) { agentExecutor.execute(requestContext, eventProcessor) - }.also { - sessionManager.addSession(it) + }.let { + it to sessionManager.addSession(it) } } + // Signal that event collection is setup + val collectionStarted: CompletableJob = Job() + // Subscribe to events stream and start emitting them. launch { session.events + .onStart { + collectionStarted.complete() + } .collect { event -> send(Response(data = event, id = request.id)) } } - // Start the session to execute the agent and wait for it to finish. - // Using await here to propagate any exceptions thrown by the agent execution. + // Ensure event collection is setup to stream events in response. + collectionStarted.join() + // Ensure monitoring is ready to monitor the session. + monitoringStarted.join() + + /* + Start the session to execute the agent and wait for it to finish. + Using await here to propagate any exceptions thrown by the agent execution. + */ session.agentJob.await() + session.join() } override suspend fun onSendMessage( @@ -440,10 +462,10 @@ public open class A2AServer( val event = if (messageConfiguration?.blocking == true) { // If blocking is requested, attempt to wait for the last event, until the current turn of the agent execution is finished. - eventStream.last() + eventStream.lastOrNull() } else { - eventStream.first() - } + eventStream.firstOrNull() + } ?: throw IllegalStateException("Can't get response from the agent: event stream is empty") return when (val eventData = event.data) { is Message -> Response(data = eventData, id = event.id) diff --git a/a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/session/Session.kt b/a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/session/Session.kt index 305f18cf07..9dd8451f5e 100644 --- a/a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/session/Session.kt +++ b/a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/session/Session.kt @@ -14,15 +14,25 @@ import kotlinx.coroutines.flow.collect * * @property eventProcessor The session event processor * @property agentJob The execution process associated with this session's execution - * @property events A stream of events generated during this session */ public class Session( public val eventProcessor: SessionEventProcessor, public val agentJob: Deferred ) { - public val contextId: String get() = eventProcessor.contextId - public val taskId: String get() = eventProcessor.taskId - public val events: Flow get() = eventProcessor.events + /** + * Context ID associated with this session. + */ + public val contextId: String = eventProcessor.contextId + + /** + * Task ID associated with this session. + */ + public val taskId: String = eventProcessor.taskId + + /** + * A stream of events associated with this session. + */ + public val events: Flow = eventProcessor.events /** * Starts the [agentJob], if it hasn't already been started. @@ -31,7 +41,7 @@ public class Session( agentJob.start() } - /* + /** * Suspends until the session, i.e., event stream and agent job, complete. * Waits for the event stream to finish first, to avoid triggering the agent job prematurely. * Assumes that by the time event stream is finished, agent job will already be completed or canceled. diff --git a/a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/session/SessionEventProcessor.kt b/a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/session/SessionEventProcessor.kt index ce9148ec38..2e4fc7dbe3 100644 --- a/a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/session/SessionEventProcessor.kt +++ b/a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/session/SessionEventProcessor.kt @@ -10,10 +10,9 @@ import ai.koog.a2a.server.exceptions.SessionNotActiveException import ai.koog.a2a.server.tasks.TaskStorage import kotlinx.coroutines.flow.Flow import kotlinx.coroutines.flow.MutableSharedFlow -import kotlinx.coroutines.flow.emptyFlow import kotlinx.coroutines.flow.filterIsInstance -import kotlinx.coroutines.flow.flow import kotlinx.coroutines.flow.map +import kotlinx.coroutines.flow.onSubscription import kotlinx.coroutines.flow.takeWhile import kotlinx.coroutines.sync.Mutex import kotlinx.coroutines.sync.withLock @@ -42,8 +41,6 @@ import kotlin.jvm.JvmInline * from the incoming request or a newly generated ID that must be used if creating a new task. * Note: This taskId might not correspond to an actually existing task initially - it serves as the * identifier that will be validated against all [TaskEvent] in this session. - * @property isOpen Whether the session is open. - * @property events A hot flow of events in this session that can be subscribed to. */ @OptIn(ExperimentalAtomicApi::class) public class SessionEventProcessor( @@ -63,6 +60,10 @@ public class SessionEventProcessor( } private val _isOpen: AtomicBoolean = AtomicBoolean(true) + + /** + * Whether the session is open. + */ public val isOpen: Boolean get() = _isOpen.load() /** @@ -82,17 +83,15 @@ public class SessionEventProcessor( } private val _events = MutableSharedFlow() - public val events: Flow - get() = flow { - if (isOpen) { - _events - .takeWhile { it !is FlowEvent.Close } - .filterIsInstance() - .map { it.data } - } else { - emptyFlow() - }.collect(this) - } + + /** + * A hot flow of events in this session that can be subscribed to. + */ + public val events: Flow = _events + .onSubscription { if (!_isOpen.load()) emit(FlowEvent.Close) } + .takeWhile { it !is FlowEvent.Close } + .filterIsInstance() + .map { it.data } /** * Sends a [Message] to the session event processor. Validates the message against the session context and updates diff --git a/a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/session/SessionManager.kt b/a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/session/SessionManager.kt index 77b2427f35..53b1550c26 100644 --- a/a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/session/SessionManager.kt +++ b/a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/session/SessionManager.kt @@ -8,8 +8,12 @@ import ai.koog.a2a.server.tasks.TaskStorage import ai.koog.a2a.utils.KeyedMutex import ai.koog.a2a.utils.RWLock import ai.koog.a2a.utils.withLock +import io.github.oshai.kotlinlogging.KotlinLogging +import kotlinx.coroutines.CompletableJob import kotlinx.coroutines.CoroutineScope +import kotlinx.coroutines.Job import kotlinx.coroutines.flow.firstOrNull +import kotlinx.coroutines.flow.onStart import kotlinx.coroutines.launch /** @@ -36,6 +40,9 @@ public class SessionManager( private val pushConfigStorage: PushNotificationConfigStorage? = null, private val pushSender: PushNotificationSender? = null, ) { + private companion object { + private val logger = KotlinLogging.logger {} + } /** * Map of task id to session. All sessions have task id associated with them, even if the task won't be created. @@ -49,9 +56,11 @@ public class SessionManager( * Sends push notifications if configured after each session completes. * * @param session The session to add. + * @return A [CompletableJob] indicating when the monitoring coroutine is started and ready to monitor the session. + * It is crucial to start agent execution only after this job completes, to ensure the monitoring won't skip any events. * @throws IllegalArgumentException if a session for the same task id already exists. */ - public suspend fun addSession(session: Session) { + public suspend fun addSession(session: Session): CompletableJob { sessionsRwLock.withWriteLock { check(session.taskId !in sessions) { "Session for taskId '${session.taskId}' already runs." @@ -60,9 +69,14 @@ public class SessionManager( sessions[session.taskId] = session } + // Signal to indicate the monitoring is started. + val monitoringStarted = Job() + // Monitor for agent job completion to send push notifications and remove session from the map. coroutineScope.launch { - val firstEvent = session.events.firstOrNull() + val firstEvent = session.events + .onStart { monitoringStarted.complete() } + .firstOrNull() // Wait for the agent job to finish session.agentJob.join() @@ -95,6 +109,8 @@ public class SessionManager( } } } + + return monitoringStarted } /** diff --git a/a2a/a2a-server/src/jvmTest/kotlin/ai/koog/a2a/server/.gitkeep b/a2a/a2a-server/src/jvmTest/kotlin/ai/koog/a2a/server/.gitkeep new file mode 100644 index 0000000000..e69de29bb2 diff --git a/a2a/a2a-server/src/jvmTest/kotlin/ai/koog/a2a/server/A2AServerJsonRpcIntegrationTest.kt b/a2a/a2a-server/src/jvmTest/kotlin/ai/koog/a2a/server/jsonrpc/A2AServerJsonRpcIntegrationTest.kt similarity index 54% rename from a2a/a2a-server/src/jvmTest/kotlin/ai/koog/a2a/server/A2AServerJsonRpcIntegrationTest.kt rename to a2a/a2a-server/src/jvmTest/kotlin/ai/koog/a2a/server/jsonrpc/A2AServerJsonRpcIntegrationTest.kt index 4c0c82617b..fb67bc6abf 100644 --- a/a2a/a2a-server/src/jvmTest/kotlin/ai/koog/a2a/server/A2AServerJsonRpcIntegrationTest.kt +++ b/a2a/a2a-server/src/jvmTest/kotlin/ai/koog/a2a/server/jsonrpc/A2AServerJsonRpcIntegrationTest.kt @@ -1,11 +1,5 @@ -package ai.koog.a2a.server +package ai.koog.a2a.server.jsonrpc -import ai.koog.a2a.client.A2AClient -import ai.koog.a2a.client.UrlAgentCardResolver -import ai.koog.a2a.consts.A2AConsts -import ai.koog.a2a.model.AgentCapabilities -import ai.koog.a2a.model.AgentCard -import ai.koog.a2a.model.AgentSkill import ai.koog.a2a.model.Message import ai.koog.a2a.model.MessageSendConfiguration import ai.koog.a2a.model.MessageSendParams @@ -15,12 +9,7 @@ import ai.koog.a2a.model.TaskIdParams import ai.koog.a2a.model.TaskState import ai.koog.a2a.model.TaskStatusUpdateEvent import ai.koog.a2a.model.TextPart -import ai.koog.a2a.model.TransportProtocol -import ai.koog.a2a.server.notifications.InMemoryPushNotificationConfigStorage -import ai.koog.a2a.test.BaseA2AProtocolTest import ai.koog.a2a.transport.Request -import ai.koog.a2a.transport.client.jsonrpc.http.HttpJSONRPCClientTransport -import ai.koog.a2a.transport.server.jsonrpc.http.HttpJSONRPCServerTransport import io.kotest.inspectors.shouldForAll import io.kotest.inspectors.shouldForAtLeastOne import io.kotest.matchers.nulls.shouldNotBeNull @@ -28,18 +17,11 @@ import io.kotest.matchers.should import io.kotest.matchers.shouldBe import io.kotest.matchers.string.shouldStartWith import io.kotest.matchers.types.shouldBeInstanceOf -import io.ktor.client.HttpClient -import io.ktor.client.engine.cio.CIO -import io.ktor.client.plugins.HttpTimeout -import io.ktor.client.plugins.logging.LogLevel -import io.ktor.client.plugins.logging.Logging -import io.ktor.server.netty.Netty import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.delay import kotlinx.coroutines.flow.toList import kotlinx.coroutines.joinAll import kotlinx.coroutines.launch -import kotlinx.coroutines.runBlocking import kotlinx.coroutines.test.runTest import kotlinx.coroutines.withContext import org.junit.jupiter.api.AfterAll @@ -47,10 +29,9 @@ import org.junit.jupiter.api.BeforeAll import org.junit.jupiter.api.TestInstance import org.junit.jupiter.api.parallel.Execution import org.junit.jupiter.api.parallel.ExecutionMode -import java.net.ServerSocket import kotlin.test.BeforeTest import kotlin.test.Test -import kotlin.time.Duration.Companion.minutes +import kotlin.time.Duration.Companion.seconds import kotlin.uuid.ExperimentalUuidApi import kotlin.uuid.Uuid @@ -62,170 +43,55 @@ import kotlin.uuid.Uuid @OptIn(ExperimentalUuidApi::class) @TestInstance(TestInstance.Lifecycle.PER_CLASS) @Execution(ExecutionMode.SAME_THREAD, reason = "Working with the same instance of test server.") -class A2AServerJsonRpcIntegrationTest : BaseA2AProtocolTest() { - override val testTimeout = 2.minutes - - private var testPort: Int? = null - private val testPath = "/a2a" - private lateinit var serverUrl: String - - private lateinit var serverTransport: HttpJSONRPCServerTransport - private lateinit var clientTransport: HttpJSONRPCClientTransport - private lateinit var httpClient: HttpClient - - override lateinit var client: A2AClient +class A2AServerJsonRpcIntegrationTest : BaseA2AServerJsonRpcTest() { + override val testTimeout = 10.seconds @BeforeAll - fun setup(): Unit = runBlocking { - // Discover and take any free port - testPort = ServerSocket(0).use { it.localPort } - serverUrl = "http://localhost:$testPort$testPath" - - // Create agent cards - val agentCard = createAgentCard() - val agentCardExtended = createExtendedAgentCard() - - // Create test agent executor - val testAgentExecutor = TestAgentExecutor() - - // Create A2A server - val a2aServer = A2AServer( - agentExecutor = testAgentExecutor, - agentCard = agentCard, - agentCardExtended = agentCardExtended, - pushConfigStorage = InMemoryPushNotificationConfigStorage() - ) + override fun setup() { + super.setup() + } - // Create server transport - serverTransport = HttpJSONRPCServerTransport(a2aServer) + @BeforeTest + override fun initClient() { + super.initClient() + } - // Start server - serverTransport.start( - engineFactory = Netty, - port = testPort!!, - path = testPath, - wait = false, - agentCard = agentCard, - agentCardPath = A2AConsts.AGENT_CARD_WELL_KNOWN_PATH, - ) + @AfterAll + override fun tearDown() { + super.tearDown() + } - // Create client transport - httpClient = HttpClient(CIO) { - install(Logging) { - level = LogLevel.ALL - } + @Test + override fun `test get agent card`() = + super.`test get agent card`() - install(HttpTimeout) { - requestTimeoutMillis = testTimeout.inWholeMilliseconds - } - } + @Test + override fun `test get authenticated extended agent card`() = + super.`test get authenticated extended agent card`() - clientTransport = HttpJSONRPCClientTransport(serverUrl, httpClient) + @Test + override fun `test send message`() = + super.`test send message`() - client = A2AClient( - transport = clientTransport, - agentCardResolver = UrlAgentCardResolver( - baseUrl = serverUrl, - path = A2AConsts.AGENT_CARD_WELL_KNOWN_PATH, - baseHttpClient = httpClient, - ) - ) - } + @Test + override fun `test send message streaming`() = + super.`test send message streaming`() - @BeforeTest - fun initClient(): Unit = runBlocking { - client.connect() - } + @Test + override fun `test get task`() = + super.`test get task`() - @AfterAll - fun tearDown(): Unit = runBlocking { - clientTransport.close() - serverTransport.stop() - } + @Test + override fun `test cancel task`() = + super.`test cancel task`() - private fun createAgentCard(): AgentCard = AgentCard( - protocolVersion = "0.3.0", - name = "Hello World Agent", - description = "Just a hello world agent", - url = "http://localhost:9999/", - preferredTransport = TransportProtocol.JSONRPC, - additionalInterfaces = null, - iconUrl = null, - provider = null, - version = "1.0.0", - documentationUrl = null, - capabilities = AgentCapabilities( - streaming = true, - pushNotifications = true, - stateTransitionHistory = null, - extensions = null - ), - securitySchemes = null, - security = null, - defaultInputModes = listOf("text"), - defaultOutputModes = listOf("text"), - skills = listOf( - AgentSkill( - id = "hello_world", - name = "Returns hello world", - description = "just returns hello world", - tags = listOf("hello world"), - examples = listOf("hi", "hello world"), - inputModes = null, - outputModes = null, - security = null - ) - ), - supportsAuthenticatedExtendedCard = true, - signatures = null - ) + @Test + override fun `test resubscribe task`() = + super.`test resubscribe task`() - private fun createExtendedAgentCard(): AgentCard = AgentCard( - protocolVersion = "0.3.0", - name = "Hello World Agent - Extended Edition", - description = "The full-featured hello world agent for authenticated users.", - url = "http://localhost:9999/", - preferredTransport = TransportProtocol.JSONRPC, - additionalInterfaces = null, - iconUrl = null, - provider = null, - version = "1.0.1", - documentationUrl = null, - capabilities = AgentCapabilities( - streaming = true, - pushNotifications = true, - stateTransitionHistory = null, - extensions = null - ), - securitySchemes = null, - security = null, - defaultInputModes = listOf("text"), - defaultOutputModes = listOf("text"), - skills = listOf( - AgentSkill( - id = "hello_world", - name = "Returns hello world", - description = "just returns hello world", - tags = listOf("hello world"), - examples = listOf("hi", "hello world"), - inputModes = null, - outputModes = null, - security = null - ), - AgentSkill( - id = "super_hello_world", - name = "Returns a SUPER Hello World", - description = "A more enthusiastic greeting, only for authenticated users.", - tags = listOf("hello world", "super", "extended"), - examples = listOf("super hi", "give me a super hello"), - inputModes = null, - outputModes = null, - security = null - ) - ), - supportsAuthenticatedExtendedCard = true, - signatures = null - ) + @Test + override fun `test push notification configs`() = + super.`test push notification configs`() /** * Extended test that wouldn't work with Python A2A SDK server, because their implementation has some problems. @@ -239,7 +105,7 @@ class A2AServerJsonRpcIntegrationTest : BaseA2AProtocolTest() { val createTaskRequest = Request( data = MessageSendParams( message = Message( - messageId = Uuid.random().toString(), + messageId = Uuid.Companion.random().toString(), role = Role.User, parts = listOf( TextPart("do long-running task"), @@ -339,7 +205,7 @@ class A2AServerJsonRpcIntegrationTest : BaseA2AProtocolTest() { ) = Request( data = MessageSendParams( message = Message( - messageId = Uuid.random().toString(), + messageId = Uuid.Companion.random().toString(), role = Role.User, parts = listOf( TextPart("do long-running task"), diff --git a/a2a/a2a-server/src/jvmTest/kotlin/ai/koog/a2a/server/jsonrpc/BaseA2AServerJsonRpcTest.kt b/a2a/a2a-server/src/jvmTest/kotlin/ai/koog/a2a/server/jsonrpc/BaseA2AServerJsonRpcTest.kt new file mode 100644 index 0000000000..d023683889 --- /dev/null +++ b/a2a/a2a-server/src/jvmTest/kotlin/ai/koog/a2a/server/jsonrpc/BaseA2AServerJsonRpcTest.kt @@ -0,0 +1,184 @@ +package ai.koog.a2a.server.jsonrpc + +import ai.koog.a2a.client.A2AClient +import ai.koog.a2a.client.UrlAgentCardResolver +import ai.koog.a2a.consts.A2AConsts +import ai.koog.a2a.model.AgentCapabilities +import ai.koog.a2a.model.AgentCard +import ai.koog.a2a.model.AgentSkill +import ai.koog.a2a.model.TransportProtocol +import ai.koog.a2a.server.A2AServer +import ai.koog.a2a.server.TestAgentExecutor +import ai.koog.a2a.server.notifications.InMemoryPushNotificationConfigStorage +import ai.koog.a2a.test.BaseA2AProtocolTest +import ai.koog.a2a.transport.client.jsonrpc.http.HttpJSONRPCClientTransport +import ai.koog.a2a.transport.server.jsonrpc.http.HttpJSONRPCServerTransport +import io.ktor.client.HttpClient +import io.ktor.client.engine.cio.CIO +import io.ktor.client.plugins.HttpTimeout +import io.ktor.client.plugins.logging.LogLevel +import io.ktor.client.plugins.logging.Logging +import io.ktor.server.netty.Netty +import kotlinx.coroutines.runBlocking +import java.net.ServerSocket + +abstract class BaseA2AServerJsonRpcTest : BaseA2AProtocolTest() { + protected var testPort: Int? = null + protected val testPath = "/a2a" + protected lateinit var serverUrl: String + + protected lateinit var serverTransport: HttpJSONRPCServerTransport + protected lateinit var clientTransport: HttpJSONRPCClientTransport + protected lateinit var httpClient: HttpClient + + override lateinit var client: A2AClient + + open fun setup(): Unit = runBlocking { + // Discover and take any free port + testPort = ServerSocket(0).use { it.localPort } + serverUrl = "http://localhost:$testPort$testPath" + + // Create agent cards + val agentCard = createAgentCard() + val agentCardExtended = createExtendedAgentCard() + + // Create test agent executor + val testAgentExecutor = TestAgentExecutor() + + // Create A2A server + val a2aServer = A2AServer( + agentExecutor = testAgentExecutor, + agentCard = agentCard, + agentCardExtended = agentCardExtended, + pushConfigStorage = InMemoryPushNotificationConfigStorage() + ) + + // Create server transport + serverTransport = HttpJSONRPCServerTransport(a2aServer) + + // Start server + serverTransport.start( + engineFactory = Netty, + port = testPort!!, + path = testPath, + wait = false, + agentCard = agentCard, + agentCardPath = A2AConsts.AGENT_CARD_WELL_KNOWN_PATH, + ) + + // Create client transport + httpClient = HttpClient(CIO) { + install(Logging) { + level = LogLevel.ALL + } + + install(HttpTimeout) { + requestTimeoutMillis = testTimeout.inWholeMilliseconds + } + } + + clientTransport = HttpJSONRPCClientTransport(serverUrl, httpClient) + + client = A2AClient( + transport = clientTransport, + agentCardResolver = UrlAgentCardResolver( + baseUrl = serverUrl, + path = A2AConsts.AGENT_CARD_WELL_KNOWN_PATH, + baseHttpClient = httpClient, + ) + ) + } + + open fun initClient(): Unit = runBlocking { + client.connect() + } + + open fun tearDown() = runBlocking { + clientTransport.close() + serverTransport.stop() + } + + private fun createAgentCard(): AgentCard = AgentCard( + protocolVersion = "0.3.0", + name = "Hello World Agent", + description = "Just a hello world agent", + url = "http://localhost:9999/", + preferredTransport = TransportProtocol.JSONRPC, + additionalInterfaces = null, + iconUrl = null, + provider = null, + version = "1.0.0", + documentationUrl = null, + capabilities = AgentCapabilities( + streaming = true, + pushNotifications = true, + stateTransitionHistory = null, + extensions = null + ), + securitySchemes = null, + security = null, + defaultInputModes = listOf("text"), + defaultOutputModes = listOf("text"), + skills = listOf( + AgentSkill( + id = "hello_world", + name = "Returns hello world", + description = "just returns hello world", + tags = listOf("hello world"), + examples = listOf("hi", "hello world"), + inputModes = null, + outputModes = null, + security = null + ) + ), + supportsAuthenticatedExtendedCard = true, + signatures = null + ) + + private fun createExtendedAgentCard(): AgentCard = AgentCard( + protocolVersion = "0.3.0", + name = "Hello World Agent - Extended Edition", + description = "The full-featured hello world agent for authenticated users.", + url = "http://localhost:9999/", + preferredTransport = TransportProtocol.Companion.JSONRPC, + additionalInterfaces = null, + iconUrl = null, + provider = null, + version = "1.0.1", + documentationUrl = null, + capabilities = AgentCapabilities( + streaming = true, + pushNotifications = true, + stateTransitionHistory = null, + extensions = null + ), + securitySchemes = null, + security = null, + defaultInputModes = listOf("text"), + defaultOutputModes = listOf("text"), + skills = listOf( + AgentSkill( + id = "hello_world", + name = "Returns hello world", + description = "just returns hello world", + tags = listOf("hello world"), + examples = listOf("hi", "hello world"), + inputModes = null, + outputModes = null, + security = null + ), + AgentSkill( + id = "super_hello_world", + name = "Returns a SUPER Hello World", + description = "A more enthusiastic greeting, only for authenticated users.", + tags = listOf("hello world", "super", "extended"), + examples = listOf("super hi", "give me a super hello"), + inputModes = null, + outputModes = null, + security = null + ) + ), + supportsAuthenticatedExtendedCard = true, + signatures = null + ) +} diff --git a/a2a/a2a-server/src/jvmTest/kotlin/ai/koog/a2a/server/jsonrpc/StressA2AServerJsonRpcIntegrationTest.kt b/a2a/a2a-server/src/jvmTest/kotlin/ai/koog/a2a/server/jsonrpc/StressA2AServerJsonRpcIntegrationTest.kt new file mode 100644 index 0000000000..dc9b1ea2c1 --- /dev/null +++ b/a2a/a2a-server/src/jvmTest/kotlin/ai/koog/a2a/server/jsonrpc/StressA2AServerJsonRpcIntegrationTest.kt @@ -0,0 +1,47 @@ +package ai.koog.a2a.server.jsonrpc + +import org.junit.jupiter.api.AfterAll +import org.junit.jupiter.api.BeforeAll +import org.junit.jupiter.api.Disabled +import org.junit.jupiter.api.RepeatedTest +import org.junit.jupiter.api.TestInstance +import org.junit.jupiter.api.parallel.Execution +import org.junit.jupiter.api.parallel.ExecutionMode +import kotlin.test.BeforeTest +import kotlin.time.Duration.Companion.seconds + +/** + * Stress-testing event-stream related requests to check that events are processed correctly under a high load. + * Also more samples help with finding some flaky behavior, e.g. race conditions. + */ +@TestInstance(TestInstance.Lifecycle.PER_CLASS) +@Execution(ExecutionMode.SAME_THREAD, reason = "Working with the same instance of test server.") +@Disabled("TODO add stress tests in heavy tests only") // TODO add in heavy tests +class StressA2AServerJsonRpcIntegrationTest : BaseA2AServerJsonRpcTest() { + override val testTimeout = 10.seconds + + @BeforeAll + override fun setup() { + super.setup() + } + + @BeforeTest + override fun initClient() { + super.initClient() + } + + @AfterAll + override fun tearDown() { + super.tearDown() + } + + @RepeatedTest(300, name = "{currentRepetition}/{totalRepetitions}") + fun `stress test cancel task`() = super.`test cancel task`() + + @RepeatedTest(300, name = "{currentRepetition}/{totalRepetitions}") + fun `stress test send message`() = super.`test send message`() + + // Long test, lower repetitions + @RepeatedTest(10, name = "{currentRepetition}/{totalRepetitions}") + fun `stress test resubscribe task`() = super.`test resubscribe task`() +} diff --git a/a2a/a2a-test/src/jvmMain/kotlin/ai/koog/a2a/test/BaseA2AProtocolTest.kt b/a2a/a2a-test/src/jvmMain/kotlin/ai/koog/a2a/test/BaseA2AProtocolTest.kt index d20014b792..ec04b6769f 100644 --- a/a2a/a2a-test/src/jvmMain/kotlin/ai/koog/a2a/test/BaseA2AProtocolTest.kt +++ b/a2a/a2a-test/src/jvmMain/kotlin/ai/koog/a2a/test/BaseA2AProtocolTest.kt @@ -32,7 +32,6 @@ import io.kotest.matchers.string.shouldStartWith import io.kotest.matchers.types.shouldBeInstanceOf import kotlinx.coroutines.flow.toList import kotlinx.coroutines.test.runTest -import kotlin.test.Test import kotlin.time.Duration import kotlin.uuid.ExperimentalUuidApi import kotlin.uuid.Uuid @@ -55,8 +54,7 @@ abstract class BaseA2AProtocolTest { */ protected abstract var client: A2AClient - @Test - fun `test get agent card`() = runTest(timeout = testTimeout) { + open fun `test get agent card`() = runTest(timeout = testTimeout) { val agentCard = client.getAgentCard() // Assert on the full AgentCard structure @@ -100,8 +98,7 @@ abstract class BaseA2AProtocolTest { agentCard shouldBe expectedAgentCard } - @Test - fun `test get authenticated extended agent card`() = runTest(timeout = testTimeout) { + open fun `test get authenticated extended agent card`() = runTest(timeout = testTimeout) { val request = Request(data = null) val response = client.getAuthenticatedExtendedAgentCard(request) @@ -157,8 +154,7 @@ abstract class BaseA2AProtocolTest { response.data shouldBe expectedExtendedAgentCard } - @Test - fun `test send message`() = runTest(timeout = testTimeout) { + open fun `test send message`() = runTest(timeout = testTimeout) { val request = Request( data = MessageSendParams( message = Message( @@ -185,8 +181,7 @@ abstract class BaseA2AProtocolTest { } } - @Test - fun `test send message streaming`() = runTest(timeout = testTimeout) { + open fun `test send message streaming`() = runTest(timeout = testTimeout) { val createTaskRequest = Request( data = MessageSendParams( message = Message( @@ -247,8 +242,7 @@ abstract class BaseA2AProtocolTest { } } - @Test - fun `test get task`() = runTest(timeout = testTimeout) { + open fun `test get task`() = runTest(timeout = testTimeout) { val createTaskRequest = Request( data = MessageSendParams( message = Message( @@ -282,8 +276,7 @@ abstract class BaseA2AProtocolTest { } } - @Test - fun `test cancel task`() = runTest(timeout = testTimeout) { + open fun `test cancel task`() = runTest(timeout = testTimeout) { val createTaskRequest = Request( data = MessageSendParams( message = Message( @@ -320,8 +313,7 @@ abstract class BaseA2AProtocolTest { } } - @Test - fun `test resubscribe task`() = runTest(timeout = testTimeout) { + open fun `test resubscribe task`() = runTest(timeout = testTimeout) { val createTaskRequest = Request( data = MessageSendParams( message = Message( @@ -374,8 +366,7 @@ abstract class BaseA2AProtocolTest { } } - @Test - fun `test push notification configs`() = runTest(timeout = testTimeout) { + open fun `test push notification configs`() = runTest(timeout = testTimeout) { val createTaskRequest = Request( data = MessageSendParams( message = Message( diff --git a/agents/agents-features/agents-features-sql/build.gradle.kts b/agents/agents-features/agents-features-sql/build.gradle.kts index 3fe859c67e..98e739ae36 100644 --- a/agents/agents-features/agents-features-sql/build.gradle.kts +++ b/agents/agents-features/agents-features-sql/build.gradle.kts @@ -47,6 +47,7 @@ kotlin { dependencies { implementation(kotlin("test-junit5")) implementation(project(":agents:agents-test")) + implementation(project(":test-utils")) implementation(libs.mockk) implementation(libs.testcontainers) implementation(libs.testcontainers.postgresql) diff --git a/agents/agents-features/agents-features-sql/src/jvmTest/kotlin/ai/koog/agents/features/sql/providers/H2PersistencyStorageProviderTest.kt b/agents/agents-features/agents-features-sql/src/jvmTest/kotlin/ai/koog/agents/features/sql/providers/H2PersistencyStorageProviderTest.kt index e6777eed2f..15cdaa40bb 100644 --- a/agents/agents-features/agents-features-sql/src/jvmTest/kotlin/ai/koog/agents/features/sql/providers/H2PersistencyStorageProviderTest.kt +++ b/agents/agents-features/agents-features-sql/src/jvmTest/kotlin/ai/koog/agents/features/sql/providers/H2PersistencyStorageProviderTest.kt @@ -4,6 +4,7 @@ import ai.koog.agents.snapshot.feature.AgentCheckpointData import ai.koog.prompt.message.Message import ai.koog.prompt.message.RequestMetaInfo import ai.koog.prompt.message.ResponseMetaInfo +import ai.koog.test.utils.DockerAvailableCondition import kotlinx.coroutines.delay import kotlinx.coroutines.runBlocking import kotlinx.datetime.Clock @@ -11,14 +12,13 @@ import kotlinx.serialization.json.JsonPrimitive import org.junit.jupiter.api.Test import org.junit.jupiter.api.TestInstance import org.junit.jupiter.api.TestInstance.Lifecycle -import org.junit.jupiter.api.condition.EnabledOnOs -import org.junit.jupiter.api.condition.OS +import org.junit.jupiter.api.extension.ExtendWith import kotlin.test.assertEquals import kotlin.test.assertNotNull import kotlin.test.assertNull @TestInstance(Lifecycle.PER_CLASS) -@EnabledOnOs(OS.LINUX) +@ExtendWith(DockerAvailableCondition::class) class H2PersistencyStorageProviderTest { private fun provider(ttlSeconds: Long? = null): H2PersistencyStorageProvider { diff --git a/agents/agents-features/agents-features-sql/src/jvmTest/kotlin/ai/koog/agents/features/sql/providers/MySQLPersistencyStorageProviderTest.kt b/agents/agents-features/agents-features-sql/src/jvmTest/kotlin/ai/koog/agents/features/sql/providers/MySQLPersistencyStorageProviderTest.kt index 893da5a1d6..d3c90fe1ab 100644 --- a/agents/agents-features/agents-features-sql/src/jvmTest/kotlin/ai/koog/agents/features/sql/providers/MySQLPersistencyStorageProviderTest.kt +++ b/agents/agents-features/agents-features-sql/src/jvmTest/kotlin/ai/koog/agents/features/sql/providers/MySQLPersistencyStorageProviderTest.kt @@ -4,6 +4,7 @@ import ai.koog.agents.snapshot.feature.AgentCheckpointData import ai.koog.prompt.message.Message import ai.koog.prompt.message.RequestMetaInfo import ai.koog.prompt.message.ResponseMetaInfo +import ai.koog.test.utils.DockerAvailableCondition import kotlinx.coroutines.delay import kotlinx.coroutines.runBlocking import kotlinx.datetime.Clock @@ -14,8 +15,7 @@ import org.junit.jupiter.api.BeforeAll import org.junit.jupiter.api.Test import org.junit.jupiter.api.TestInstance import org.junit.jupiter.api.TestInstance.Lifecycle -import org.junit.jupiter.api.condition.EnabledOnOs -import org.junit.jupiter.api.condition.OS +import org.junit.jupiter.api.extension.ExtendWith import org.testcontainers.containers.MySQLContainer import org.testcontainers.utility.DockerImageName import kotlin.test.assertEquals @@ -23,7 +23,7 @@ import kotlin.test.assertNotNull import kotlin.test.assertNull @TestInstance(Lifecycle.PER_CLASS) -@EnabledOnOs(OS.LINUX) +@ExtendWith(DockerAvailableCondition::class) class MySQLPersistencyStorageProviderTest { private lateinit var mysql: MySQLContainer<*> diff --git a/agents/agents-features/agents-features-sql/src/jvmTest/kotlin/ai/koog/agents/features/sql/providers/PostgresPersistencyStorageProviderTest.kt b/agents/agents-features/agents-features-sql/src/jvmTest/kotlin/ai/koog/agents/features/sql/providers/PostgresPersistencyStorageProviderTest.kt index f189f84d9b..d93e849c66 100644 --- a/agents/agents-features/agents-features-sql/src/jvmTest/kotlin/ai/koog/agents/features/sql/providers/PostgresPersistencyStorageProviderTest.kt +++ b/agents/agents-features/agents-features-sql/src/jvmTest/kotlin/ai/koog/agents/features/sql/providers/PostgresPersistencyStorageProviderTest.kt @@ -4,6 +4,7 @@ import ai.koog.agents.snapshot.feature.AgentCheckpointData import ai.koog.prompt.message.Message import ai.koog.prompt.message.RequestMetaInfo import ai.koog.prompt.message.ResponseMetaInfo +import ai.koog.test.utils.DockerAvailableCondition import kotlinx.coroutines.runBlocking import kotlinx.datetime.Clock import kotlinx.serialization.json.JsonPrimitive @@ -13,8 +14,7 @@ import org.junit.jupiter.api.BeforeAll import org.junit.jupiter.api.Test import org.junit.jupiter.api.TestInstance import org.junit.jupiter.api.TestInstance.Lifecycle -import org.junit.jupiter.api.condition.EnabledOnOs -import org.junit.jupiter.api.condition.OS +import org.junit.jupiter.api.extension.ExtendWith import org.testcontainers.containers.PostgreSQLContainer import org.testcontainers.utility.DockerImageName import kotlin.test.assertEquals @@ -22,7 +22,7 @@ import kotlin.test.assertNotNull import kotlin.test.assertNull @TestInstance(Lifecycle.PER_CLASS) -@EnabledOnOs(OS.LINUX) +@ExtendWith(DockerAvailableCondition::class) class PostgresPersistencyStorageProviderTest { private lateinit var postgres: PostgreSQLContainer<*> diff --git a/test-utils/build.gradle.kts b/test-utils/build.gradle.kts index 688ab4e94b..79ea57e818 100644 --- a/test-utils/build.gradle.kts +++ b/test-utils/build.gradle.kts @@ -22,6 +22,7 @@ kotlin { dependencies { api(kotlin("test-junit5")) api(libs.junit.jupiter.params) + api(libs.testcontainers) runtimeOnly(libs.slf4j.simple) } } diff --git a/test-utils/src/jvmMain/kotlin/ai/koog/test/utils/DockerAvailableCondition.kt b/test-utils/src/jvmMain/kotlin/ai/koog/test/utils/DockerAvailableCondition.kt new file mode 100644 index 0000000000..35fe37e977 --- /dev/null +++ b/test-utils/src/jvmMain/kotlin/ai/koog/test/utils/DockerAvailableCondition.kt @@ -0,0 +1,20 @@ +package ai.koog.test.utils + +import org.junit.jupiter.api.extension.ConditionEvaluationResult +import org.junit.jupiter.api.extension.ExecutionCondition +import org.junit.jupiter.api.extension.ExtensionContext +import org.testcontainers.DockerClientFactory + +/** + * Helper test condition method to skip test suite if Docker is not available. + */ +public class DockerAvailableCondition : ExecutionCondition { + override fun evaluateExecutionCondition(context: ExtensionContext): ConditionEvaluationResult { + return try { + DockerClientFactory.instance().client() + ConditionEvaluationResult.enabled("Docker is available") + } catch (e: Exception) { + ConditionEvaluationResult.disabled("Docker is not available, skipping this test") + } + } +} From 64dcd92ac0a78f219cc5387314d6ae6eab5acb78 Mon Sep 17 00:00:00 2001 From: Inna Teteniuk Date: Thu, 2 Oct 2025 09:40:06 +0200 Subject: [PATCH 42/52] Rename "persistency" to "persistence" in documentation (#896) ## Motivation and Context ## Breaking Changes --- #### Type of the changes - [ ] New feature (non-breaking change which adds functionality) - [ ] Bug fix (non-breaking change which fixes an issue) - [ ] Breaking change (fix or feature that would cause existing functionality to change) - [ ] Documentation update - [ ] Tests improvement - [ ] Refactoring #### Checklist - [ ] The pull request has a description of the proposed change - [ ] I read the [Contributing Guidelines](https://github.com/JetBrains/koog/blob/main/CONTRIBUTING.md) before opening the pull request - [ ] The pull request uses **`develop`** as the base branch - [ ] Tests for the changes have been added - [ ] All new and existing tests passed ##### Additional steps for pull requests adding a new feature - [ ] An issue describing the proposed change exists - [ ] The pull request includes a link to the issue - [ ] The change was discussed and approved in the issue - [ ] Docs have been added / updated --- ...nt-persistency.md => agent-persistence.md} | 48 +++++++++---------- docs/mkdocs.yml | 6 +-- 2 files changed, 27 insertions(+), 27 deletions(-) rename docs/docs/{agent-persistency.md => agent-persistence.md} (89%) diff --git a/docs/docs/agent-persistency.md b/docs/docs/agent-persistence.md similarity index 89% rename from docs/docs/agent-persistency.md rename to docs/docs/agent-persistence.md index d50b75b4f1..79f31c03ca 100644 --- a/docs/docs/agent-persistency.md +++ b/docs/docs/agent-persistence.md @@ -1,11 +1,11 @@ -# Agent Persistency +# Agent Persistence -Agent Persistency is a feature that provides checkpoint functionality for AI agents in the Koog framework. +Agent persistence is a feature that provides checkpoint functionality for AI agents in the Koog framework. It lets you save and restore the state of an agent at specific points during execution, enabling capabilities such as: -- Resuming agent execution from a specific point -- Rolling back to previous states -- Persisting agent state across sessions +- Resuming agent execution from a specific point. +- Rolling back to previous states. +- Persisting agent state across sessions. ## Key concepts @@ -22,7 +22,7 @@ Checkpoints are identified by unique IDs and are associated with a specific agen ## Prerequisites -The Agent Persistency feature requires that all nodes in your agent's strategy have unique names. +The Agent Persistence feature requires that all nodes in your agent's strategy have unique names. This is enforced when the feature is installed: + Make sure to set unique names for nodes in your graph. ## Installation -To use the Agent Persistency feature, add it to your agent's configuration: +To use the Agent Persistence feature, add it to your agent's configuration: + ## Configuration options -The Agent Persistency feature has two main configuration options: +The Agent Persistence feature has two main configuration options: - **Storage provider**: the provider used to save and retrieve checkpoints. - **Continuous persistence**: automatic creation of checkpoints after each node is run. @@ -105,7 +105,7 @@ install(Persistency) { } ``` - + The framework includes the following built-in providers: @@ -144,7 +144,7 @@ install(Persistency) { } ``` - + When activated, the agent will automatically create a checkpoint after each node is executed, allowing for fine-grained recovery. @@ -180,7 +180,7 @@ suspend fun example(context: AIAgentContext) { } ``` - + ### Restoring from a checkpoint @@ -201,7 +201,7 @@ suspend fun example(context: AIAgentContext, checkpointId: String) { } ``` - + #### Rolling back all side-effects produced by tools @@ -222,7 +222,7 @@ And now you would like to roll back to a checkpoint. Restoring the agent's state be sufficient to achieve the exact state of the world before the checkpoint. You should also restore the side-effects produced by your tool calls. In our example, this would mean removing `Maria` and `Daniel` from the database. -With Koog Persistency you can achieve that by providing a `RollbackToolRegistry` to `Persistency` feature config: +With Koog Persistence you can achieve that by providing a `RollbackToolRegistry` to `Persistency` feature config: + ### Using extension functions -The Agent Persistency feature provides convenient extension functions for working with checkpoints: +The Agent Persistence feature provides convenient extension functions for working with checkpoints: @@ -291,7 +291,7 @@ suspend fun example(context: AIAgentContext) { } } ``` - + ## Advanced usage @@ -325,9 +325,9 @@ class MyCustomStorageProvider : PersistencyStorageProvider { } ``` - + -To use your custom provider in the feature configuration, set it as the storage when configuring the Agent Persistency +To use your custom provider in the feature configuration, set it as the storage when configuring the Agent Persistence feature in your agent. + ### Setting execution points @@ -395,6 +395,6 @@ fun example(context: AIAgentContext) { ``` - + This allows for more fine-grained control over the agent's state beyond just restoring from checkpoints. diff --git a/docs/mkdocs.yml b/docs/mkdocs.yml index 7593012fa0..91fe5a986d 100644 --- a/docs/mkdocs.yml +++ b/docs/mkdocs.yml @@ -6,8 +6,8 @@ nav: - Key concepts: key-concepts.md - Getting started: - Single-run agents: single-run-agents.md - - Complex workflow agents: complex-workflow-agents.md - Functional agents: functional-agents.md + - Complex workflow agents: complex-workflow-agents.md - Prompt API: prompt-api.md - Tools: - Overview: tools-overview.md @@ -42,7 +42,7 @@ nav: - Overview: features-overview.md - Tracing: tracing.md - Memory: agent-memory.md - - Agent Persistency: agent-persistency.md + - Agent Persistence: agent-persistence.md - OpenTelemetry: - Overview: opentelemetry-support.md - Langfuse Exporter: opentelemetry-langfuse-exporter.md @@ -142,7 +142,7 @@ plugins: - features-overview.md: This page provides a short overview of features as a way to extend and enhance the functionality of AI agents. - tracing.md: This page includes details about the Tracing feature, which provides comprehensive tracing capabilities for AI agents. The page includes configuration and initialization details, examples and quickstart, details about error handling and FAQ and troubleshooting. - agent-memory.md: This page provides details about the AgentMemory feature which lets AI agents store, retrieve, and use information across conversations. The page includes configuration and initialization details, examples and quickstart, best practices, error handling and FAQ and troubleshooting. - - agent-persistency.md: This page describes Agent Persistency, which is a feature that lets you save and restore the state of an agent at specific points during execution. The page includes key concepts, prerequisites, installation instructions, configuration, and basic and advanced usage guides. + - agent-persistence.md: This page describes Agent Persistence, which is a feature that lets you save and restore the state of an agent at specific points during execution. The page includes key concepts, prerequisites, installation instructions, configuration, and basic and advanced usage guides. OpenTelemetry: - opentelemetry-support.md: This page provides details about the support for OpenTelemetry with the Koog agentic framework for tracing and monitoring AI agents. The page includes details about installation and configuration, span types and attributes, and common exporters. - opentelemetry-langfuse-exporter.md: This page provides information about Koog's built-in support for exporting agent traces to Langfuse, a platform for observability and analytics of AI applications. The page provides details about LangFuse integration configuration and an example of its use. From 1952377569be5cfb9022c772a099e6505fbcf36a Mon Sep 17 00:00:00 2001 From: Andrey Bragin Date: Thu, 2 Oct 2025 03:22:42 +0200 Subject: [PATCH 43/52] [simple-examples] Add advanced A2A joke writer agent --- .../a2a/server/feature/A2AAgentServerNodes.kt | 19 - examples/simple-examples/build.gradle.kts | 6 +- .../ai/koog/agents/example/a2a/.gitkeep | 0 .../agents/example/a2a/advancedjoke/Client.kt | 154 +++++++ .../advancedjoke/JokeWriterAgentExecutor.kt | 406 ++++++++++++++++++ .../agents/example/a2a/advancedjoke/Server.kt | 79 ++++ 6 files changed, 644 insertions(+), 20 deletions(-) delete mode 100644 examples/simple-examples/src/main/kotlin/ai/koog/agents/example/a2a/.gitkeep create mode 100644 examples/simple-examples/src/main/kotlin/ai/koog/agents/example/a2a/advancedjoke/Client.kt create mode 100644 examples/simple-examples/src/main/kotlin/ai/koog/agents/example/a2a/advancedjoke/JokeWriterAgentExecutor.kt create mode 100644 examples/simple-examples/src/main/kotlin/ai/koog/agents/example/a2a/advancedjoke/Server.kt diff --git a/agents/agents-features/agents-features-a2a-server/src/commonMain/kotlin/ai/koog/agents/a2a/server/feature/A2AAgentServerNodes.kt b/agents/agents-features/agents-features-a2a-server/src/commonMain/kotlin/ai/koog/agents/a2a/server/feature/A2AAgentServerNodes.kt index de2a5e5086..4dff75c52c 100644 --- a/agents/agents-features/agents-features-a2a-server/src/commonMain/kotlin/ai/koog/agents/a2a/server/feature/A2AAgentServerNodes.kt +++ b/agents/agents-features/agents-features-a2a-server/src/commonMain/kotlin/ai/koog/agents/a2a/server/feature/A2AAgentServerNodes.kt @@ -105,25 +105,6 @@ public fun AIAgentSubgraphBuilderBase<*, *>.nodeA2AMessageStorageReplace( } } -/** - * Creates a node that loads all messages from storage for the current context. - * - * This is an alias for [nodeA2AMessageStorageLoad]. - * - * @param name Optional node name for debugging and tracing - * @return A node that returns the list of all stored messages - * @see ai.koog.a2a.server.messages.MessageStorage.getByContext - */ -@AIAgentBuilderDslMarker -public fun AIAgentSubgraphBuilderBase<*, *>.nodeA2AMessagesContextLoad( - name: String? = null, -): AIAgentNodeDelegate> = - node(name) { - withA2AAgentServer { - context.messageStorage.getAll() - } - } - /** * Parameters for retrieving a single task from storage. * diff --git a/examples/simple-examples/build.gradle.kts b/examples/simple-examples/build.gradle.kts index 5d66734641..6e5e2d54ae 100644 --- a/examples/simple-examples/build.gradle.kts +++ b/examples/simple-examples/build.gradle.kts @@ -120,6 +120,10 @@ registerRunExampleTask("runExampleStreamingWithTools", "ai.koog.agents.example.s A2A examples */ -// joke generation +// Simple joke generation registerRunExampleTask("runExampleSimpleJokeAgentServer", "ai.koog.agents.example.a2a.simplejoke.ServerKt") registerRunExampleTask("runExampleSimpleJokeAgentClient", "ai.koog.agents.example.a2a.simplejoke.ClientKt") + +// Advanced joke generation +registerRunExampleTask("runExampleAdvancedJokeAgentServer", "ai.koog.agents.example.a2a.advancedjoke.ServerKt") +registerRunExampleTask("runExampleAdvancedJokeAgentClient", "ai.koog.agents.example.a2a.advancedjoke.ClientKt") diff --git a/examples/simple-examples/src/main/kotlin/ai/koog/agents/example/a2a/.gitkeep b/examples/simple-examples/src/main/kotlin/ai/koog/agents/example/a2a/.gitkeep deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/examples/simple-examples/src/main/kotlin/ai/koog/agents/example/a2a/advancedjoke/Client.kt b/examples/simple-examples/src/main/kotlin/ai/koog/agents/example/a2a/advancedjoke/Client.kt new file mode 100644 index 0000000000..8c6464ed7c --- /dev/null +++ b/examples/simple-examples/src/main/kotlin/ai/koog/agents/example/a2a/advancedjoke/Client.kt @@ -0,0 +1,154 @@ +@file:OptIn(ExperimentalUuidApi::class) + +package ai.koog.agents.example.a2a.advancedjoke + +import ai.koog.a2a.client.A2AClient +import ai.koog.a2a.client.UrlAgentCardResolver +import ai.koog.a2a.model.Artifact +import ai.koog.a2a.model.Message +import ai.koog.a2a.model.MessageSendParams +import ai.koog.a2a.model.Role +import ai.koog.a2a.model.Task +import ai.koog.a2a.model.TaskArtifactUpdateEvent +import ai.koog.a2a.model.TaskState +import ai.koog.a2a.model.TaskStatusUpdateEvent +import ai.koog.a2a.model.TextPart +import ai.koog.a2a.transport.Request +import ai.koog.a2a.transport.client.jsonrpc.http.HttpJSONRPCClientTransport +import kotlinx.serialization.encodeToString +import kotlinx.serialization.json.Json +import kotlin.uuid.ExperimentalUuidApi +import kotlin.uuid.Uuid + +private const val CYAN = "\u001B[36m" +private const val YELLOW = "\u001B[33m" +private const val MAGENTA = "\u001B[35m" +private const val GREEN = "\u001B[32m" +private const val RED = "\u001B[31m" +private const val BLUE = "\u001B[34m" +private const val RESET = "\u001B[0m" + +private val json = Json { prettyPrint = true } + +@OptIn(ExperimentalUuidApi::class) +suspend fun main() { + println("\n${YELLOW}Starting Advanced Joke Generator A2A Client$RESET\n") + + val transport = HttpJSONRPCClientTransport(url = "http://localhost:9999$ADVANCED_JOKE_AGENT_PATH") + val agentCardResolver = UrlAgentCardResolver(baseUrl = "http://localhost:9999", path = ADVANCED_JOKE_AGENT_CARD_PATH) + val client = A2AClient(transport = transport, agentCardResolver = agentCardResolver) + + client.connect() + val agentCard = client.cachedAgentCard() + println("${YELLOW}Connected: ${agentCard.name}$RESET\n") + + if (agentCard.capabilities.streaming != true) { + println("${RED}Error: Streaming not supported$RESET") + transport.close() + return + } + + println("${CYAN}Context ID:$RESET") + val contextId = readln() + println() + + var currentTaskId: String? = null + val artifacts = mutableMapOf() + + while (true) { + println("${CYAN}Request (/q to quit):$RESET") + val request = readln() + println() + + if (request == "/q") break + + val message = Message( + messageId = Uuid.random().toString(), + role = Role.User, + parts = listOf(TextPart(request)), + contextId = contextId, + taskId = currentTaskId + ) + + try { + client.sendMessageStreaming(Request(MessageSendParams(message = message))).collect { response -> + val event = response.data + println("${BLUE}[${event.kind}]$RESET") + println("${json.encodeToString(event)}\n") + + when (event) { + is Task -> { + currentTaskId = event.id + event.artifacts?.forEach { artifacts[it.artifactId] = it } + } + + is Message -> { + val textContent = event.parts.filterIsInstance().joinToString("\n") { it.text } + if (textContent.isNotBlank()) { + println("${MAGENTA}Message:$RESET\n$textContent\n") + } + } + + is TaskStatusUpdateEvent -> { + when (event.status.state) { + TaskState.InputRequired -> { + val question = event.status.message?.parts + ?.filterIsInstance() + ?.joinToString("\n") { it.text } + if (!question.isNullOrBlank()) { + println("${MAGENTA}Question:$RESET\n$question\n") + } + } + + TaskState.Completed -> { + if (artifacts.isNotEmpty()) { + println("${GREEN}=== Artifacts ===$RESET") + artifacts.values.forEach { artifact -> + val content = artifact.parts.filterIsInstance() + .joinToString("\n") { it.text } + if (content.isNotBlank()) { + println("${GREEN}[${artifact.artifactId}]$RESET\n$content\n") + } + } + } + if (event.final) { + currentTaskId = null + artifacts.clear() + } + } + + TaskState.Failed, TaskState.Canceled, TaskState.Rejected -> { + if (event.final) { + currentTaskId = null + artifacts.clear() + } + } + + else -> {} + } + } + + is TaskArtifactUpdateEvent -> { + if (event.append == true) { + val existing = artifacts[event.artifact.artifactId] + if (existing != null) { + artifacts[event.artifact.artifactId] = existing.copy( + parts = existing.parts + event.artifact.parts + ) + } else { + artifacts[event.artifact.artifactId] = event.artifact + } + } else { + artifacts[event.artifact.artifactId] = event.artifact + } + } + } + } + } catch (e: Exception) { + println("${RED}Error: ${e.message}$RESET\n") + } + } + + println("${YELLOW}Done$RESET") + transport.close() +} diff --git a/examples/simple-examples/src/main/kotlin/ai/koog/agents/example/a2a/advancedjoke/JokeWriterAgentExecutor.kt b/examples/simple-examples/src/main/kotlin/ai/koog/agents/example/a2a/advancedjoke/JokeWriterAgentExecutor.kt new file mode 100644 index 0000000000..d74cf4ce38 --- /dev/null +++ b/examples/simple-examples/src/main/kotlin/ai/koog/agents/example/a2a/advancedjoke/JokeWriterAgentExecutor.kt @@ -0,0 +1,406 @@ +package ai.koog.agents.example.a2a.advancedjoke + +import ai.koog.a2a.exceptions.A2AUnsupportedOperationException +import ai.koog.a2a.model.Artifact +import ai.koog.a2a.model.MessageSendParams +import ai.koog.a2a.model.Role +import ai.koog.a2a.model.Task +import ai.koog.a2a.model.TaskArtifactUpdateEvent +import ai.koog.a2a.model.TaskState +import ai.koog.a2a.model.TaskStatus +import ai.koog.a2a.model.TaskStatusUpdateEvent +import ai.koog.a2a.model.TextPart +import ai.koog.a2a.server.agent.AgentExecutor +import ai.koog.a2a.server.session.RequestContext +import ai.koog.a2a.server.session.SessionEventProcessor +import ai.koog.agents.a2a.core.A2AMessage +import ai.koog.agents.a2a.core.toKoogMessage +import ai.koog.agents.a2a.server.feature.A2AAgentServer +import ai.koog.agents.a2a.server.feature.withA2AAgentServer +import ai.koog.agents.core.agent.GraphAIAgent +import ai.koog.agents.core.agent.config.AIAgentConfig +import ai.koog.agents.core.agent.context.agentInput +import ai.koog.agents.core.dsl.builder.forwardTo +import ai.koog.agents.core.dsl.builder.strategy +import ai.koog.agents.core.dsl.extension.nodeLLMRequestStructured +import ai.koog.agents.core.dsl.extension.onIsInstance +import ai.koog.agents.core.tools.ToolRegistry +import ai.koog.agents.core.tools.annotations.LLMDescription +import ai.koog.agents.example.ApiKeyService +import ai.koog.prompt.dsl.prompt +import ai.koog.prompt.executor.clients.anthropic.AnthropicLLMClient +import ai.koog.prompt.executor.clients.google.GoogleLLMClient +import ai.koog.prompt.executor.clients.google.GoogleModels +import ai.koog.prompt.executor.clients.openai.OpenAILLMClient +import ai.koog.prompt.executor.llms.MultiLLMPromptExecutor +import ai.koog.prompt.executor.model.PromptExecutor +import ai.koog.prompt.llm.LLMProvider +import ai.koog.prompt.message.Message +import ai.koog.prompt.text.text +import ai.koog.prompt.xml.xml +import kotlinx.datetime.Clock +import kotlinx.serialization.SerialName +import kotlinx.serialization.Serializable +import kotlin.reflect.typeOf +import kotlin.uuid.ExperimentalUuidApi +import kotlin.uuid.Uuid + +/** + * An advanced A2A agent that demonstrates: + * - Task-based conversation flow with state management + * - Interactive clarification questions (InputRequired state) + * - Structured output via sealed interfaces + * - Artifact delivery for final results + */ +class JokeWriterAgentExecutor : AgentExecutor { + private val promptExecutor = MultiLLMPromptExecutor( + LLMProvider.OpenAI to OpenAILLMClient(ApiKeyService.openAIApiKey), + LLMProvider.Anthropic to AnthropicLLMClient(ApiKeyService.anthropicApiKey), + LLMProvider.Google to GoogleLLMClient(ApiKeyService.googleApiKey), + ) + + @OptIn(ExperimentalUuidApi::class) + override suspend fun execute( + context: RequestContext, + eventProcessor: SessionEventProcessor + ) { + val agent = jokeWriterAgent(promptExecutor, context, eventProcessor) + agent.run(context.params.message) + } +} + +private fun jokeWriterAgent( + promptExecutor: PromptExecutor, + context: RequestContext, + eventProcessor: SessionEventProcessor +): GraphAIAgent { + val agentConfig = AIAgentConfig( + prompt = prompt("joke-generation") { + system { + +"You are a very funny sarcastic assistant. You must help users generate funny jokes." + +"When asked for something else, sarcastically decline the request because you can only assist with jokes." + } + }, + model = GoogleModels.Gemini2_5Flash, + maxAgentIterations = 20 + ) + + return GraphAIAgent( + inputType = typeOf(), + outputType = typeOf(), + promptExecutor = promptExecutor, + strategy = jokeWriterStrategy(), + agentConfig = agentConfig, + toolRegistry = ToolRegistry.EMPTY, + ) { + install(A2AAgentServer) { + this.context = context + this.eventProcessor = eventProcessor + } + } +} + +@OptIn(ExperimentalUuidApi::class) +private fun jokeWriterStrategy() = strategy("joke-writer") { + // Node: Load conversation history from message storage + val setupMessageContext by node { userInput -> + if (!userInput.referenceTaskIds.isNullOrEmpty()) { + throw A2AUnsupportedOperationException("This agent doesn't understand task references in referenceTaskIds yet.") + } + + // Load current context messages + val contextMessages: List = withA2AAgentServer { + context.messageStorage.getAll() + } + + // Update the prompt with the current context messages + llm.writeSession { + updatePrompt { + messages(contextMessages.map { it.toKoogMessage() }) + } + } + + userInput + } + + // Node: Load existing task (if continuing) or prepare for new task creation + val setupTaskContext by node { userInput -> + // Check if the message continues the task that already exists + val currentTask: Task? = withA2AAgentServer { + context.task?.id?.let { id -> + // Load task with full conversation history to continue working on it + context.taskStorage.get(id, historyLength = null) + } + } + + currentTask?.let { task -> + val currentTaskMessages = (task.history.orEmpty() + listOfNotNull(task.status.message) + userInput) + .map { it.toKoogMessage() } + + llm.writeSession { + updatePrompt { + user { + +"There's an ongoing task, the next messages contain conversation history for this task" + } + + messages(currentTaskMessages) + } + } + } + + /* + If task exists then the message belongs to the task, send event to update the task. + Otherwise, put it in general message storage for the current context. + */ + withA2AAgentServer { + if (currentTask != null) { + val updateEvent = TaskStatusUpdateEvent( + taskId = currentTask.id, + contextId = currentTask.contextId, + status = TaskStatus( + state = TaskState.Working, + message = userInput, + timestamp = Clock.System.now(), + ), + final = false, + ) + + eventProcessor.sendTaskEvent(updateEvent) + } else { + context.messageStorage.save(userInput) + } + } + + currentTask + } + + // Node: Ask LLM to classify if this is a joke request or something else + val classifyNewRequest by nodeLLMRequestStructured() + + // Node: Send a polite decline message if the request is not about jokes + val respondFallbackMessage by node { classification -> + withA2AAgentServer { + val message = A2AMessage( + messageId = Uuid.random().toString(), + role = Role.Agent, + parts = listOf( + TextPart(classification.response) + ), + contextId = context.contextId, + taskId = context.taskId, + ) + + // Store reply in message storage to preserve context + context.messageStorage.save(message) + // Reply with message + eventProcessor.sendMessage(message) + } + } + + // Node: Create a new task for the joke request + val createTask by node { + val userInput = agentInput() + + withA2AAgentServer { + val task = Task( + id = context.taskId, + contextId = context.contextId, + status = TaskStatus( + state = TaskState.Submitted, + message = userInput, + timestamp = Clock.System.now(), + ), + + ) + + eventProcessor.sendTaskEvent(task) + } + } + + // Node: Ask LLM to classify joke details (or request clarification) + val classifyJokeRequest by nodeLLMRequestStructured() + + // Node: Generate the actual joke based on classified parameters + val generateJoke by node { request -> + llm.writeSession { + updatePrompt { + user { + +text { + +"Generate a joke based on the following user request:" + xml { + tag("subject") { + +request.subject + } + tag("targetAudience") { + +request.targetAudience + } + tag("isSwearingAllowed") { + +request.isSwearingAllowed.toString() + } + } + } + } + } + + val message = requestLLMWithoutTools() + message as? Message.Assistant ?: throw IllegalStateException("Unexpected message type: $message") + } + } + + // Node: Send InputRequired event to ask the user for more information + val askMoreInfo by node { clarification -> + withA2AAgentServer { + val taskUpdate = TaskStatusUpdateEvent( + taskId = context.taskId, + contextId = context.contextId, + status = TaskStatus( + state = TaskState.InputRequired, + message = A2AMessage( + role = Role.User, + parts = listOf( + TextPart(clarification.question) + ), + messageId = Uuid.random().toString(), + taskId = context.taskId, + contextId = context.contextId + ), + timestamp = Clock.System.now(), + ), + final = true, + ) + + eventProcessor.sendTaskEvent(taskUpdate) + } + } + + // Node: Send the joke as an artifact and mark task as completed + val respondWithJoke by node { jokeMessage -> + withA2AAgentServer { + val artifactUpdate = TaskArtifactUpdateEvent( + taskId = context.taskId, + contextId = context.contextId, + artifact = Artifact( + artifactId = "joke", + parts = listOf( + TextPart(jokeMessage.content) + ) + ), + ) + + eventProcessor.sendTaskEvent(artifactUpdate) + + val taskStatusUpdate = TaskStatusUpdateEvent( + taskId = context.taskId, + contextId = context.contextId, + status = TaskStatus( + state = TaskState.Completed, + ), + final = true, + ) + + eventProcessor.sendTaskEvent(taskStatusUpdate) + } + } + + // --- Graph Flow Definition --- + + // Always start by loading context and checking for existing tasks + nodeStart then setupMessageContext then setupTaskContext + + // If no task exists, classify whether this is a joke request + edge( + setupTaskContext forwardTo classifyNewRequest + onCondition { task -> task == null } + transformed { agentInput().content() } + ) + // If task exists, continue processing the joke request + edge( + setupTaskContext forwardTo classifyJokeRequest + onCondition { task -> task != null } + transformed { agentInput().content() } + ) + + // New request classification: If not a joke request, decline politely + edge( + classifyNewRequest forwardTo respondFallbackMessage + transformed { it.getOrThrow().structure } + onCondition { !it.isJokeRequest } + ) + // New request classification: If joke request, create a task + edge( + classifyNewRequest forwardTo createTask + transformed { it.getOrThrow().structure } + onCondition { it.isJokeRequest } + ) + + edge(respondFallbackMessage forwardTo nodeFinish) + + // After creating task, classify the joke details + edge( + createTask forwardTo classifyJokeRequest + transformed { agentInput().content() } + ) + + // Joke classification: Ask for clarification if needed + edge( + classifyJokeRequest forwardTo askMoreInfo + transformed { it.getOrThrow().structure } + onIsInstance JokeRequestClassification.NeedsClarification::class + ) + // Joke classification: Generate joke if we have all details + edge( + classifyJokeRequest forwardTo generateJoke + transformed { it.getOrThrow().structure } + onIsInstance JokeRequestClassification.Ready::class + ) + + // After asking for info, wait for user response (finish this iteration) + edge(askMoreInfo forwardTo nodeFinish) + + // After generating joke, send it as an artifact + edge(generateJoke forwardTo respondWithJoke) + edge(respondWithJoke forwardTo nodeFinish) +} + +private fun A2AMessage.content(): String { + return parts.filterIsInstance().joinToString(separator = "\n") { it.text } +} + +// --- Structured Output Models --- + +@Serializable +@LLMDescription("Initial incoming user message classification, to determine if this is a joke request or not.") +private data class UserRequestClassification( + @property:LLMDescription("Whether the incoming message is a joke request or not") + val isJokeRequest: Boolean, + @property:LLMDescription( + "In case the message is not a joke request, polite reply to the user that the agent cannot assist." + + "Default is empty" + ) + val response: String = "", +) + +@LLMDescription("The classification of the joke request") +@Serializable +@SerialName("JokeRequestClassification") +private sealed interface JokeRequestClassification { + @Serializable + @SerialName("NeedsClarification") + @LLMDescription("The joke request needs clarification") + data class NeedsClarification( + @property:LLMDescription("The question that needs clarification") + val question: String + ) : JokeRequestClassification + + @LLMDescription("The joke request is ready to be processed") + @Serializable + @SerialName("Ready") + data class Ready( + @property:LLMDescription("The joke subject") + val subject: String, + @property:LLMDescription("The joke target audience") + val targetAudience: String, + @property:LLMDescription("Whether the swearing is allowed in the joke") + val isSwearingAllowed: Boolean, + ) : JokeRequestClassification +} diff --git a/examples/simple-examples/src/main/kotlin/ai/koog/agents/example/a2a/advancedjoke/Server.kt b/examples/simple-examples/src/main/kotlin/ai/koog/agents/example/a2a/advancedjoke/Server.kt new file mode 100644 index 0000000000..6e2edcc8f8 --- /dev/null +++ b/examples/simple-examples/src/main/kotlin/ai/koog/agents/example/a2a/advancedjoke/Server.kt @@ -0,0 +1,79 @@ +package ai.koog.agents.example.a2a.advancedjoke + +import ai.koog.a2a.model.AgentCapabilities +import ai.koog.a2a.model.AgentCard +import ai.koog.a2a.model.AgentInterface +import ai.koog.a2a.model.AgentSkill +import ai.koog.a2a.model.TransportProtocol +import ai.koog.a2a.server.A2AServer +import ai.koog.a2a.transport.server.jsonrpc.http.HttpJSONRPCServerTransport +import io.github.oshai.kotlinlogging.KotlinLogging +import io.ktor.server.cio.CIO + +private val logger = KotlinLogging.logger {} + +const val ADVANCED_JOKE_AGENT_PATH = "/advanced-joke-agent" +const val ADVANCED_JOKE_AGENT_CARD_PATH = "$ADVANCED_JOKE_AGENT_PATH/agent-card.json" + +suspend fun main() { + logger.info { "Starting Advanced Joke A2A Agent on http://localhost:9999" } + + // Create agent card with capabilities - this agent supports streaming and tasks + val agentCard = AgentCard( + protocolVersion = "0.3.0", + name = "Advanced Joke Generator", + description = "A sophisticated AI agent that generates jokes with clarifying questions and structured task flow", + version = "1.0.0", + url = "http://localhost:9999$ADVANCED_JOKE_AGENT_PATH", + preferredTransport = TransportProtocol.JSONRPC, + additionalInterfaces = listOf( + AgentInterface( + url = "http://localhost:9999$ADVANCED_JOKE_AGENT_PATH", + transport = TransportProtocol.JSONRPC, + ) + ), + capabilities = AgentCapabilities( + streaming = true, // Supports streaming responses + pushNotifications = false, + stateTransitionHistory = false, + ), + defaultInputModes = listOf("text"), + defaultOutputModes = listOf("text"), + skills = listOf( + AgentSkill( + id = "advanced_joke_generation", + name = "Advanced Joke Generation", + description = "Generates humorous jokes with interactive clarification and customization options", + examples = listOf( + "Tell me a joke about programming", + "Generate a funny joke for teenagers", + "Make me laugh with a dad joke about cats" + ), + tags = listOf("humor", "jokes", "entertainment", "interactive") + ) + ), + supportsAuthenticatedExtendedCard = false + ) + + // Create agent executor + val agentExecutor = JokeWriterAgentExecutor() + + // Create A2A server + val a2aServer = A2AServer( + agentExecutor = agentExecutor, + agentCard = agentCard, + ) + + // Create and start server transport + val serverTransport = HttpJSONRPCServerTransport(a2aServer) + + logger.info { "Advanced Joke Generator Agent ready at http://localhost:9999/$ADVANCED_JOKE_AGENT_PATH" } + serverTransport.start( + engineFactory = CIO, + port = 9999, + path = ADVANCED_JOKE_AGENT_PATH, + wait = true, // Block until server stops + agentCard = agentCard, + agentCardPath = ADVANCED_JOKE_AGENT_CARD_PATH + ) +} From 929cf45d2b312dc04ab1d10f0c94a856d1dad70e Mon Sep 17 00:00:00 2001 From: Andrey Bragin Date: Thu, 2 Oct 2025 04:27:40 +0200 Subject: [PATCH 44/52] [simple-examples] Update README with A2A --- examples/simple-examples/README.md | 9 ++++ .../ai/koog/agents/example/a2a/README.md | 44 +++++++++++++++++++ 2 files changed, 53 insertions(+) create mode 100644 examples/simple-examples/src/main/kotlin/ai/koog/agents/example/a2a/README.md diff --git a/examples/simple-examples/README.md b/examples/simple-examples/README.md index e8d52be82d..b33f48a184 100644 --- a/examples/simple-examples/README.md +++ b/examples/simple-examples/README.md @@ -85,6 +85,15 @@ Welcome to the **Koog Framework Simple Examples** collection! This project showc | **Bedrock Agent** | AI agents using AWS Bedrock integration | `runExampleBedrockAgent` | [📓 BedrockAgent.ipynb](../notebooks/BedrockAgent.ipynb) | | **Web Search** | Agent with web search capabilities | `runExampleWebSearchAgent` | - | +### Agent-to-Agent (A2A) + +Examples demonstrating the A2A protocol for inter-agent communication. See the [A2A README](src/main/kotlin/ai/koog/agents/example/a2a/README.md) for details. + +| Example | Description | Files | +|------------------------|------------------------------------------------------------|---------------------------------------| +| **Simple Joke Agent** | Basic A2A agent with message-based joke generation | `simplejoke/` (Server + Client) | +| **Advanced Joke Agent** | Task-based agent with clarification flow and artifacts | `advancedjoke/` (Server + Client) | + ### Advanced Patterns | Feature | Description | Gradle Task | Notebook | diff --git a/examples/simple-examples/src/main/kotlin/ai/koog/agents/example/a2a/README.md b/examples/simple-examples/src/main/kotlin/ai/koog/agents/example/a2a/README.md new file mode 100644 index 0000000000..cde6a1dd6e --- /dev/null +++ b/examples/simple-examples/src/main/kotlin/ai/koog/agents/example/a2a/README.md @@ -0,0 +1,44 @@ +# Agent-to-Agent (A2A) Examples + +Examples demonstrating the A2A protocol for inter-agent communication with standardized message/task workflows, streaming responses, and artifact delivery. + +## Examples + +### Simple Joke Agent (`simplejoke/`) + +Basic message-based communication without tasks. + +**Run:** +```bash +# Terminal 1: Start server (port 9998) +./gradlew runExampleSimpleJokeServer + +# Terminal 2: Run client +./gradlew runExampleSimpleJokeClient +``` + +### Advanced Joke Agent (`advancedjoke/`) + +Task-based workflow with: +- Interactive clarification (InputRequired state) +- Artifact delivery for results +- Graph-based agent strategy with documented nodes/edges +- Streaming response events + +**Run:** +```bash +# Terminal 1: Start server (port 9999) +./gradlew runExampleAdvancedJokeServer + +# Terminal 2: Run client +./gradlew runExampleAdvancedJokeClient +``` + +## Key Patterns + +**Simple Agent:** `sendMessage()` → single response +**Advanced Agent:** `sendMessageStreaming()` → Flow of events (Task, TaskStatusUpdateEvent, TaskArtifactUpdateEvent) + +**Task States:** Submitted → Working → InputRequired (optional) → Completed + +See code comments in `JokeWriterAgentExecutor.kt` for detailed flow documentation. From 9444d49e35b2c0332467911b2b1c1efdf00ac44a Mon Sep 17 00:00:00 2001 From: Andrey Bragin Date: Thu, 2 Oct 2025 09:39:27 +0200 Subject: [PATCH 45/52] [a2a] Fix event collection race condition --- .../kotlin/ai/koog/a2a/server/A2AServer.kt | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/A2AServer.kt b/a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/A2AServer.kt index ea7e1b0265..75bc2d49e9 100644 --- a/a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/A2AServer.kt +++ b/a2a/a2a-server/src/commonMain/kotlin/ai/koog/a2a/server/A2AServer.kt @@ -425,22 +425,26 @@ public open class A2AServer( } } - // Signal that event collection is setup - val collectionStarted: CompletableJob = Job() + // Signal that event collection is started + val eventCollectinStarted: CompletableJob = Job() + // Signal that all events have been collected + val eventCollectionFinished: CompletableJob = Job() // Subscribe to events stream and start emitting them. launch { session.events .onStart { - collectionStarted.complete() + eventCollectinStarted.complete() } .collect { event -> send(Response(data = event, id = request.id)) } + + eventCollectionFinished.complete() } // Ensure event collection is setup to stream events in response. - collectionStarted.join() + eventCollectinStarted.join() // Ensure monitoring is ready to monitor the session. monitoringStarted.join() @@ -449,7 +453,8 @@ public open class A2AServer( Using await here to propagate any exceptions thrown by the agent execution. */ session.agentJob.await() - session.join() + // Make sure all events have been collected and sent + eventCollectionFinished.join() } override suspend fun onSendMessage( From ace223710fd9b0fd9d73e794ce64a28ff2cf63ac Mon Sep 17 00:00:00 2001 From: Andrey Bragin Date: Thu, 2 Oct 2025 09:46:44 +0200 Subject: [PATCH 46/52] Remove emtpy a2a rest transport modules, fix publishing --- .../kotlin/ai/koog/a2a/test/Stub.kt | 7 ++++ .../a2a-transport-client-rest/Module.md | 3 -- .../build.gradle.kts | 42 ------------------- .../koog/a2a/transport/client/rest/.gitkeep | 0 .../a2a-transport-core-rest/Module.md | 3 -- .../a2a-transport-core-rest/build.gradle.kts | 42 ------------------- .../ai/koog/a2a/transport/rest/.gitkeep | 0 .../a2a-transport-server-rest/Module.md | 3 -- .../build.gradle.kts | 42 ------------------- .../koog/a2a/transport/server/rest/.gitkeep | 0 koog-agents/build.gradle.kts | 3 -- settings.gradle.kts | 3 -- 12 files changed, 7 insertions(+), 141 deletions(-) create mode 100644 a2a/a2a-test/src/commonMain/kotlin/ai/koog/a2a/test/Stub.kt delete mode 100644 a2a/a2a-transport/a2a-transport-client-rest/Module.md delete mode 100644 a2a/a2a-transport/a2a-transport-client-rest/build.gradle.kts delete mode 100644 a2a/a2a-transport/a2a-transport-client-rest/src/commonMain/kotlin/ai/koog/a2a/transport/client/rest/.gitkeep delete mode 100644 a2a/a2a-transport/a2a-transport-core-rest/Module.md delete mode 100644 a2a/a2a-transport/a2a-transport-core-rest/build.gradle.kts delete mode 100644 a2a/a2a-transport/a2a-transport-core-rest/src/commonMain/kotlin/ai/koog/a2a/transport/rest/.gitkeep delete mode 100644 a2a/a2a-transport/a2a-transport-server-rest/Module.md delete mode 100644 a2a/a2a-transport/a2a-transport-server-rest/build.gradle.kts delete mode 100644 a2a/a2a-transport/a2a-transport-server-rest/src/commonMain/kotlin/ai/koog/a2a/transport/server/rest/.gitkeep diff --git a/a2a/a2a-test/src/commonMain/kotlin/ai/koog/a2a/test/Stub.kt b/a2a/a2a-test/src/commonMain/kotlin/ai/koog/a2a/test/Stub.kt new file mode 100644 index 0000000000..60504608b9 --- /dev/null +++ b/a2a/a2a-test/src/commonMain/kotlin/ai/koog/a2a/test/Stub.kt @@ -0,0 +1,7 @@ +package ai.koog.a2a.test + +/** + * This class is required for publishing iOS target when there's no commonMain set. + */ +@Suppress("unused") +private class Stub diff --git a/a2a/a2a-transport/a2a-transport-client-rest/Module.md b/a2a/a2a-transport/a2a-transport-client-rest/Module.md deleted file mode 100644 index 9f93b94c62..0000000000 --- a/a2a/a2a-transport/a2a-transport-client-rest/Module.md +++ /dev/null @@ -1,3 +0,0 @@ -# Module a2a-transport-client-rest - -HTTP+JSON/REST client transport implementation for A2A protocol. Implements the ClientTransport interface using Ktor HTTP client to communicate with REST API servers. diff --git a/a2a/a2a-transport/a2a-transport-client-rest/build.gradle.kts b/a2a/a2a-transport/a2a-transport-client-rest/build.gradle.kts deleted file mode 100644 index a124396355..0000000000 --- a/a2a/a2a-transport/a2a-transport-client-rest/build.gradle.kts +++ /dev/null @@ -1,42 +0,0 @@ -import ai.koog.gradle.publish.maven.Publishing.publishToMaven - -group = rootProject.group -version = rootProject.version - -plugins { - id("ai.kotlin.multiplatform") - alias(libs.plugins.kotlin.serialization) -} - -kotlin { - sourceSets { - commonMain { - dependencies { - api(libs.kotlinx.serialization.json) - api(libs.kotlinx.coroutines.core) - } - } - - commonTest { - dependencies { - implementation(kotlin("test")) - } - } - - jvmTest { - dependencies { - implementation(kotlin("test-junit5")) - } - } - - jsTest { - dependencies { - implementation(kotlin("test-js")) - } - } - } - - explicitApi() -} - -publishToMaven() diff --git a/a2a/a2a-transport/a2a-transport-client-rest/src/commonMain/kotlin/ai/koog/a2a/transport/client/rest/.gitkeep b/a2a/a2a-transport/a2a-transport-client-rest/src/commonMain/kotlin/ai/koog/a2a/transport/client/rest/.gitkeep deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/a2a/a2a-transport/a2a-transport-core-rest/Module.md b/a2a/a2a-transport/a2a-transport-core-rest/Module.md deleted file mode 100644 index f81a55a540..0000000000 --- a/a2a/a2a-transport/a2a-transport-core-rest/Module.md +++ /dev/null @@ -1,3 +0,0 @@ -# Module a2a-transport-core-rest - -Core HTTP+JSON/REST protocol implementation for A2A communication. Provides base classes and utilities for implementing REST-based transport layers, including request routing and response serialization. diff --git a/a2a/a2a-transport/a2a-transport-core-rest/build.gradle.kts b/a2a/a2a-transport/a2a-transport-core-rest/build.gradle.kts deleted file mode 100644 index a124396355..0000000000 --- a/a2a/a2a-transport/a2a-transport-core-rest/build.gradle.kts +++ /dev/null @@ -1,42 +0,0 @@ -import ai.koog.gradle.publish.maven.Publishing.publishToMaven - -group = rootProject.group -version = rootProject.version - -plugins { - id("ai.kotlin.multiplatform") - alias(libs.plugins.kotlin.serialization) -} - -kotlin { - sourceSets { - commonMain { - dependencies { - api(libs.kotlinx.serialization.json) - api(libs.kotlinx.coroutines.core) - } - } - - commonTest { - dependencies { - implementation(kotlin("test")) - } - } - - jvmTest { - dependencies { - implementation(kotlin("test-junit5")) - } - } - - jsTest { - dependencies { - implementation(kotlin("test-js")) - } - } - } - - explicitApi() -} - -publishToMaven() diff --git a/a2a/a2a-transport/a2a-transport-core-rest/src/commonMain/kotlin/ai/koog/a2a/transport/rest/.gitkeep b/a2a/a2a-transport/a2a-transport-core-rest/src/commonMain/kotlin/ai/koog/a2a/transport/rest/.gitkeep deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/a2a/a2a-transport/a2a-transport-server-rest/Module.md b/a2a/a2a-transport/a2a-transport-server-rest/Module.md deleted file mode 100644 index 51df759724..0000000000 --- a/a2a/a2a-transport/a2a-transport-server-rest/Module.md +++ /dev/null @@ -1,3 +0,0 @@ -# Module a2a-transport-server-rest - -HTTP+JSON/REST server transport implementation for A2A protocol. Implements the ServerTransport interface using Ktor HTTP server to expose RESTful endpoints for agent communication. diff --git a/a2a/a2a-transport/a2a-transport-server-rest/build.gradle.kts b/a2a/a2a-transport/a2a-transport-server-rest/build.gradle.kts deleted file mode 100644 index a124396355..0000000000 --- a/a2a/a2a-transport/a2a-transport-server-rest/build.gradle.kts +++ /dev/null @@ -1,42 +0,0 @@ -import ai.koog.gradle.publish.maven.Publishing.publishToMaven - -group = rootProject.group -version = rootProject.version - -plugins { - id("ai.kotlin.multiplatform") - alias(libs.plugins.kotlin.serialization) -} - -kotlin { - sourceSets { - commonMain { - dependencies { - api(libs.kotlinx.serialization.json) - api(libs.kotlinx.coroutines.core) - } - } - - commonTest { - dependencies { - implementation(kotlin("test")) - } - } - - jvmTest { - dependencies { - implementation(kotlin("test-junit5")) - } - } - - jsTest { - dependencies { - implementation(kotlin("test-js")) - } - } - } - - explicitApi() -} - -publishToMaven() diff --git a/a2a/a2a-transport/a2a-transport-server-rest/src/commonMain/kotlin/ai/koog/a2a/transport/server/rest/.gitkeep b/a2a/a2a-transport/a2a-transport-server-rest/src/commonMain/kotlin/ai/koog/a2a/transport/server/rest/.gitkeep deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/koog-agents/build.gradle.kts b/koog-agents/build.gradle.kts index fb861d712a..01ccf62732 100644 --- a/koog-agents/build.gradle.kts +++ b/koog-agents/build.gradle.kts @@ -24,9 +24,6 @@ val excluded = setOf( ":a2a:a2a-transport:a2a-transport-core-jsonrpc", ":a2a:a2a-transport:a2a-transport-server-jsonrpc-http", ":a2a:a2a-transport:a2a-transport-client-jsonrpc-http", - ":a2a:a2a-transport:a2a-transport-core-rest", - ":a2a:a2a-transport:a2a-transport-server-rest", - ":a2a:a2a-transport:a2a-transport-client-rest", ":a2a:a2a-test", ":a2a:test-tck:a2a-test-server-tck", diff --git a/settings.gradle.kts b/settings.gradle.kts index 563ed90936..9e1fb87b82 100755 --- a/settings.gradle.kts +++ b/settings.gradle.kts @@ -71,9 +71,6 @@ include(":a2a:a2a-test") include(":a2a:a2a-transport:a2a-transport-core-jsonrpc") include(":a2a:a2a-transport:a2a-transport-server-jsonrpc-http") include(":a2a:a2a-transport:a2a-transport-client-jsonrpc-http") -include(":a2a:a2a-transport:a2a-transport-core-rest") -include(":a2a:a2a-transport:a2a-transport-server-rest") -include(":a2a:a2a-transport:a2a-transport-client-rest") include(":a2a:test-tck:a2a-test-server-tck") include(":koog-spring-boot-starter") From 05b847252c48ec8093f63d21389735a1c95b224e Mon Sep 17 00:00:00 2001 From: Denis Domanskii Date: Thu, 2 Oct 2025 10:55:03 +0200 Subject: [PATCH 47/52] KG-227 Add a basic code-agent example (#808) A new code-agent example is added. --- .../code-agent/step-01-basic-agent/Module.md | 3 + .../code-agent/step-01-basic-agent/README.md | 0 .../step-01-basic-agent/build.gradle.kts | 22 ++ .../step-01-basic-agent/gradle.properties | 8 + .../gradle/libs.versions.toml | 15 ++ .../gradle/wrapper/gradle-wrapper.jar | Bin 0 -> 43764 bytes .../gradle/wrapper/gradle-wrapper.properties | 7 + .../code-agent/step-01-basic-agent/gradlew | 251 ++++++++++++++++++ .../step-01-basic-agent/gradlew.bat | 94 +++++++ .../step-01-basic-agent/settings.gradle.kts | 19 ++ .../src/main/kotlin/Main.kt | 50 ++++ .../src/main/resources/logback.xml | 11 + 12 files changed, 480 insertions(+) create mode 100644 examples/code-agent/step-01-basic-agent/Module.md create mode 100644 examples/code-agent/step-01-basic-agent/README.md create mode 100644 examples/code-agent/step-01-basic-agent/build.gradle.kts create mode 100644 examples/code-agent/step-01-basic-agent/gradle.properties create mode 100644 examples/code-agent/step-01-basic-agent/gradle/libs.versions.toml create mode 100644 examples/code-agent/step-01-basic-agent/gradle/wrapper/gradle-wrapper.jar create mode 100644 examples/code-agent/step-01-basic-agent/gradle/wrapper/gradle-wrapper.properties create mode 100755 examples/code-agent/step-01-basic-agent/gradlew create mode 100644 examples/code-agent/step-01-basic-agent/gradlew.bat create mode 100644 examples/code-agent/step-01-basic-agent/settings.gradle.kts create mode 100644 examples/code-agent/step-01-basic-agent/src/main/kotlin/Main.kt create mode 100644 examples/code-agent/step-01-basic-agent/src/main/resources/logback.xml diff --git a/examples/code-agent/step-01-basic-agent/Module.md b/examples/code-agent/step-01-basic-agent/Module.md new file mode 100644 index 0000000000..0962664427 --- /dev/null +++ b/examples/code-agent/step-01-basic-agent/Module.md @@ -0,0 +1,3 @@ +# Module Code Agent (Step 1) + +Code Agent example. diff --git a/examples/code-agent/step-01-basic-agent/README.md b/examples/code-agent/step-01-basic-agent/README.md new file mode 100644 index 0000000000..e69de29bb2 diff --git a/examples/code-agent/step-01-basic-agent/build.gradle.kts b/examples/code-agent/step-01-basic-agent/build.gradle.kts new file mode 100644 index 0000000000..d3a6c96a8f --- /dev/null +++ b/examples/code-agent/step-01-basic-agent/build.gradle.kts @@ -0,0 +1,22 @@ +plugins { + alias(libs.plugins.kotlin.jvm) + alias(libs.plugins.shadow) + application +} + +application.mainClass.set("ai.koog.agents.examples.codeagent.step01.MainKt") + +dependencies { + implementation("ai.koog:koog-agents") + implementation(libs.kotlinx.coroutines.core) + implementation(libs.logback.classic) +} + +tasks.test { + useJUnitPlatform() +} + +tasks.shadowJar { + archiveBaseName.set("code-agent") + mergeServiceFiles() +} diff --git a/examples/code-agent/step-01-basic-agent/gradle.properties b/examples/code-agent/step-01-basic-agent/gradle.properties new file mode 100644 index 0000000000..e360d6a58e --- /dev/null +++ b/examples/code-agent/step-01-basic-agent/gradle.properties @@ -0,0 +1,8 @@ +#Kotlin +kotlin.code.style=official +kotlin.daemon.jvmargs=-Xmx4096M + +#Gradle +org.gradle.jvmargs=-Xmx4096M -Dfile.encoding=UTF-8 +org.gradle.parallel=true +org.gradle.caching=true diff --git a/examples/code-agent/step-01-basic-agent/gradle/libs.versions.toml b/examples/code-agent/step-01-basic-agent/gradle/libs.versions.toml new file mode 100644 index 0000000000..9ba9996847 --- /dev/null +++ b/examples/code-agent/step-01-basic-agent/gradle/libs.versions.toml @@ -0,0 +1,15 @@ +[versions] +kotlin = "2.2.20" +kotlinx-coroutines = "1.10.2" +kotlinx-serialization = "1.8.1" +logback = "1.5.13" +shadow = "9.1.0" + +[libraries] +kotlinx-coroutines-core = { module = "org.jetbrains.kotlinx:kotlinx-coroutines-core", version.ref = "kotlinx-coroutines" } +logback-classic = { module = "ch.qos.logback:logback-classic", version.ref = "logback" } + +[plugins] +kotlin-jvm = { id = "org.jetbrains.kotlin.jvm", version.ref = "kotlin" } +kotlin-serialization = { id = "org.jetbrains.kotlin.plugin.serialization", version.ref = "kotlin" } +shadow = { id = "com.gradleup.shadow", version.ref = "shadow" } diff --git a/examples/code-agent/step-01-basic-agent/gradle/wrapper/gradle-wrapper.jar b/examples/code-agent/step-01-basic-agent/gradle/wrapper/gradle-wrapper.jar new file mode 100644 index 0000000000000000000000000000000000000000..1b33c55baabb587c669f562ae36f953de2481846 GIT binary patch literal 43764 zcma&OWmKeVvL#I6?i3D%6z=Zs?ofE*?rw#G$eqJB ziT4y8-Y@s9rkH0Tz>ll(^xkcTl)CY?rS&9VNd66Yc)g^6)JcWaY(5$5gt z8gr3SBXUTN;~cBgz&})qX%#!Fxom2Yau_`&8)+6aSN7YY+pS410rRUU*>J}qL0TnJ zRxt*7QeUqTh8j)Q&iavh<}L+$Jqz))<`IfKussVk%%Ah-Ti?Eo0hQH!rK%K=#EAw0 zwq@@~XNUXRnv8$;zv<6rCRJ6fPD^hfrh;0K?n z=p!u^3xOgWZ%f3+?+>H)9+w^$Tn1e;?UpVMJb!!;f)`6f&4|8mr+g)^@x>_rvnL0< zvD0Hu_N>$(Li7|Jgu0mRh&MV+<}`~Wi*+avM01E)Jtg=)-vViQKax!GeDc!xv$^mL z{#OVBA$U{(Zr8~Xm|cP@odkHC*1R8z6hcLY#N@3E-A8XEvpt066+3t9L_6Zg6j@9Q zj$$%~yO-OS6PUVrM2s)(T4#6=JpI_@Uz+!6=GdyVU?`!F=d;8#ZB@(5g7$A0(`eqY z8_i@3w$0*es5mrSjhW*qzrl!_LQWs4?VfLmo1Sd@Ztt53+etwzAT^8ow_*7Jp`Y|l z*UgSEwvxq+FYO!O*aLf-PinZYne7Ib6ny3u>MjQz=((r3NTEeU4=-i0LBq3H-VJH< z^>1RE3_JwrclUn9vb7HcGUaFRA0QHcnE;6)hnkp%lY1UII#WPAv?-;c?YH}LWB8Nl z{sx-@Z;QxWh9fX8SxLZk8;kMFlGD3Jc^QZVL4nO)1I$zQwvwM&_!kW+LMf&lApv#< zur|EyC|U@5OQuph$TC_ZU`{!vJp`13e9alaR0Dbn5ikLFH7>eIz4QbV|C=%7)F=qo z_>M&5N)d)7G(A%c>}UCrW!Ql_6_A{?R7&CL`;!KOb3 z8Z=$YkV-IF;c7zs{3-WDEFJzuakFbd*4LWd<_kBE8~BFcv}js_2OowRNzWCtCQ6&k z{&~Me92$m*@e0ANcWKuz)?YjB*VoSTx??-3Cc0l2U!X^;Bv@m87eKHukAljrD54R+ zE;@_w4NPe1>3`i5Qy*3^E9x#VB6?}v=~qIprrrd5|DFkg;v5ixo0IsBmik8=Y;zv2 z%Bcf%NE$a44bk^`i4VwDLTbX=q@j9;JWT9JncQ!+Y%2&HHk@1~*L8-{ZpY?(-a9J-1~<1ltr9i~D9`P{XTIFWA6IG8c4;6bFw*lzU-{+?b&%OcIoCiw00n>A1ra zFPE$y@>ebbZlf(sN_iWBzQKDV zmmaLX#zK!@ZdvCANfwV}9@2O&w)!5gSgQzHdk2Q`jG6KD7S+1R5&F)j6QTD^=hq&7 zHUW+r^da^%V(h(wonR(j?BOiC!;y=%nJvz?*aW&5E87qq;2z`EI(f zBJNNSMFF9U{sR-af5{IY&AtoGcoG)Iq-S^v{7+t0>7N(KRoPj;+2N5;9o_nxIGjJ@ z7bYQK)bX)vEhy~VL%N6g^NE@D5VtV+Q8U2%{ji_=6+i^G%xeskEhH>Sqr194PJ$fB zu1y^){?9Vkg(FY2h)3ZHrw0Z<@;(gd_dtF#6y_;Iwi{yX$?asr?0N0_B*CifEi7<6 zq`?OdQjCYbhVcg+7MSgIM|pJRu~`g?g3x?Tl+V}#$It`iD1j+!x+!;wS0+2e>#g?Z z*EA^k7W{jO1r^K~cD#5pamp+o@8&yw6;%b|uiT?{Wa=4+9<}aXWUuL#ZwN1a;lQod zW{pxWCYGXdEq9qAmvAB904}?97=re$>!I%wxPV#|f#@A*Y=qa%zHlDv^yWbR03%V0 zprLP+b(#fBqxI%FiF*-n8HtH6$8f(P6!H3V^ysgd8de-N(@|K!A< z^qP}jp(RaM9kQ(^K(U8O84?D)aU(g?1S8iWwe)gqpHCaFlJxb*ilr{KTnu4_@5{K- z)n=CCeCrPHO0WHz)dDtkbZfUfVBd?53}K>C5*-wC4hpDN8cGk3lu-ypq+EYpb_2H; z%vP4@&+c2p;thaTs$dc^1CDGlPG@A;yGR5@$UEqk6p58qpw#7lc<+W(WR;(vr(D>W z#(K$vE#uBkT=*q&uaZwzz=P5mjiee6>!lV?c}QIX%ZdkO1dHg>Fa#xcGT6~}1*2m9 zkc7l3ItD6Ie~o_aFjI$Ri=C!8uF4!Ky7iG9QTrxVbsQroi|r)SAon#*B*{}TB-?=@ z8~jJs;_R2iDd!$+n$%X6FO&PYS{YhDAS+U2o4su9x~1+U3z7YN5o0qUK&|g^klZ6X zj_vrM5SUTnz5`*}Hyts9ADwLu#x_L=nv$Z0`HqN`Zo=V>OQI)fh01n~*a%01%cx%0 z4LTFVjmW+ipVQv5rYcn3;d2o4qunWUY!p+?s~X~(ost@WR@r@EuDOSs8*MT4fiP>! zkfo^!PWJJ1MHgKS2D_hc?Bs?isSDO61>ebl$U*9*QY(b=i&rp3@3GV@z>KzcZOxip z^dzA~44;R~cnhWz7s$$v?_8y-k!DZys}Q?4IkSyR!)C0j$(Gm|t#e3|QAOFaV2}36 z?dPNY;@I=FaCwylc_;~kXlZsk$_eLkNb~TIl8QQ`mmH&$*zwwR8zHU*sId)rxHu*K z;yZWa8UmCwju%aSNLwD5fBl^b0Ux1%q8YR*uG`53Mi<`5uA^Dc6Ync)J3N7;zQ*75)hf%a@{$H+%S?SGT)ks60)?6j$ zspl|4Ad6@%-r1t*$tT(en!gIXTUDcsj?28ZEzz)dH)SV3bZ+pjMaW0oc~rOPZP@g! zb9E+ndeVO_Ib9c_>{)`01^`ZS198 z)(t=+{Azi11$eu%aU7jbwuQrO`vLOixuh~%4z@mKr_Oc;F%Uq01fA)^W&y+g16e?rkLhTxV!EqC%2}sx_1u7IBq|}Be&7WI z4I<;1-9tJsI&pQIhj>FPkQV9{(m!wYYV@i5h?A0#BN2wqlEwNDIq06|^2oYVa7<~h zI_OLan0Do*4R5P=a3H9`s5*>xU}_PSztg`+2mv)|3nIy=5#Z$%+@tZnr> zLcTI!Mxa`PY7%{;KW~!=;*t)R_sl<^b>eNO@w#fEt(tPMg_jpJpW$q_DoUlkY|uo> z0-1{ouA#;t%spf*7VjkK&$QrvwUERKt^Sdo)5@?qAP)>}Y!h4(JQ!7{wIdkA+|)bv z&8hBwoX4v|+fie}iTslaBX^i*TjwO}f{V)8*!dMmRPi%XAWc8<_IqK1jUsApk)+~R zNFTCD-h>M5Y{qTQ&0#j@I@tmXGj%rzhTW5%Bkh&sSc=$Fv;M@1y!zvYG5P2(2|(&W zlcbR1{--rJ&s!rB{G-sX5^PaM@3EqWVz_y9cwLR9xMig&9gq(voeI)W&{d6j1jh&< zARXi&APWE1FQWh7eoZjuP z;vdgX>zep^{{2%hem;e*gDJhK1Hj12nBLIJoL<=0+8SVEBx7!4Ea+hBY;A1gBwvY<)tj~T=H`^?3>zeWWm|LAwo*S4Z%bDVUe z6r)CH1H!(>OH#MXFJ2V(U(qxD{4Px2`8qfFLG+=a;B^~Te_Z!r3RO%Oc#ZAHKQxV5 zRYXxZ9T2A%NVJIu5Pu7!Mj>t%YDO$T@M=RR(~mi%sv(YXVl`yMLD;+WZ{vG9(@P#e zMo}ZiK^7^h6TV%cG+;jhJ0s>h&VERs=tuZz^Tlu~%d{ZHtq6hX$V9h)Bw|jVCMudd zwZ5l7In8NT)qEPGF$VSKg&fb0%R2RnUnqa){)V(X(s0U zkCdVZe6wy{+_WhZh3qLp245Y2RR$@g-!9PjJ&4~0cFSHMUn=>dapv)hy}|y91ZWTV zCh=z*!S3_?`$&-eZ6xIXUq8RGl9oK0BJw*TdU6A`LJqX9eS3X@F)g$jLkBWFscPhR zpCv8#KeAc^y>>Y$k^=r|K(DTC}T$0#jQBOwB#@`P6~*IuW_8JxCG}J4va{ zsZzt}tt+cv7=l&CEuVtjD6G2~_Meh%p4RGuY?hSt?(sreO_F}8r7Kp$qQdvCdZnDQ zxzc*qchE*E2=WK)^oRNa>Ttj`fpvF-JZ5tu5>X1xw)J@1!IqWjq)ESBG?J|ez`-Tc zi5a}GZx|w-h%5lNDE_3ho0hEXMoaofo#Z;$8|2;EDF&*L+e$u}K=u?pb;dv$SXeQM zD-~7P0i_`Wk$#YP$=hw3UVU+=^@Kuy$>6?~gIXx636jh{PHly_a2xNYe1l60`|y!7 z(u%;ILuW0DDJ)2%y`Zc~hOALnj1~txJtcdD#o4BCT68+8gZe`=^te6H_egxY#nZH&P*)hgYaoJ^qtmpeea`35Fw)cy!w@c#v6E29co8&D9CTCl%^GV|X;SpneSXzV~LXyRn-@K0Df z{tK-nDWA!q38M1~`xUIt_(MO^R(yNY#9@es9RQbY@Ia*xHhD&=k^T+ zJi@j2I|WcgW=PuAc>hs`(&CvgjL2a9Rx zCbZyUpi8NWUOi@S%t+Su4|r&UoU|ze9SVe7p@f1GBkrjkkq)T}X%Qo1g!SQ{O{P?m z-OfGyyWta+UCXH+-+(D^%kw#A1-U;?9129at7MeCCzC{DNgO zeSqsV>W^NIfTO~4({c}KUiuoH8A*J!Cb0*sp*w-Bg@YfBIPZFH!M}C=S=S7PLLcIG zs7K77g~W)~^|+mx9onzMm0qh(f~OsDTzVmRtz=aZTllgR zGUn~_5hw_k&rll<4G=G+`^Xlnw;jNYDJz@bE?|r866F2hA9v0-8=JO3g}IHB#b`hy zA42a0>{0L7CcabSD+F7?pGbS1KMvT{@1_@k!_+Ki|5~EMGt7T%u=79F)8xEiL5!EJ zzuxQ`NBliCoJMJdwu|);zRCD<5Sf?Y>U$trQ-;xj6!s5&w=9E7)%pZ+1Nh&8nCCwM zv5>Ket%I?cxr3vVva`YeR?dGxbG@pi{H#8@kFEf0Jq6~K4>kt26*bxv=P&jyE#e$| zDJB_~imk^-z|o!2njF2hL*|7sHCnzluhJjwLQGDmC)Y9 zr9ZN`s)uCd^XDvn)VirMgW~qfn1~SaN^7vcX#K1G`==UGaDVVx$0BQnubhX|{e z^i0}>k-;BP#Szk{cFjO{2x~LjK{^Upqd&<+03_iMLp0$!6_$@TbX>8U-f*-w-ew1?`CtD_0y_Lo|PfKi52p?`5$Jzx0E8`M0 zNIb?#!K$mM4X%`Ry_yhG5k@*+n4||2!~*+&pYLh~{`~o(W|o64^NrjP?-1Lgu?iK^ zTX6u3?#$?R?N!{599vg>G8RGHw)Hx&=|g4599y}mXNpM{EPKKXB&+m?==R3GsIq?G zL5fH={=zawB(sMlDBJ+{dgb)Vx3pu>L=mDV0{r1Qs{0Pn%TpopH{m(By4;{FBvi{I z$}x!Iw~MJOL~&)p93SDIfP3x%ROjg}X{Sme#hiJ&Yk&a;iR}V|n%PriZBY8SX2*;6 z4hdb^&h;Xz%)BDACY5AUsV!($lib4>11UmcgXKWpzRL8r2Srl*9Y(1uBQsY&hO&uv znDNff0tpHlLISam?o(lOp#CmFdH<6HmA0{UwfU#Y{8M+7od8b8|B|7ZYR9f<#+V|ZSaCQvI$~es~g(Pv{2&m_rKSB2QQ zMvT}$?Ll>V+!9Xh5^iy3?UG;dF-zh~RL#++roOCsW^cZ&({6q|?Jt6`?S8=16Y{oH zp50I7r1AC1(#{b`Aq5cw>ypNggHKM9vBx!W$eYIzD!4KbLsZGr2o8>g<@inmS3*>J zx8oG((8f!ei|M@JZB`p7+n<Q}?>h249<`7xJ?u}_n;Gq(&km#1ULN87CeTO~FY zS_Ty}0TgQhV zOh3T7{{x&LSYGQfKR1PDIkP!WnfC1$l+fs@Di+d4O=eVKeF~2fq#1<8hEvpwuqcaH z4A8u~r^gnY3u6}zj*RHjk{AHhrrDqaj?|6GaVJbV%o-nATw}ASFr!f`Oz|u_QPkR# z0mDudY1dZRlk@TyQ?%Eti=$_WNFtLpSx9=S^be{wXINp%MU?a`F66LNU<c;0&ngifmP9i;bj6&hdGMW^Kf8e6ZDXbQD&$QAAMo;OQ)G zW(qlHh;}!ZP)JKEjm$VZjTs@hk&4{?@+NADuYrr!R^cJzU{kGc1yB?;7mIyAWwhbeA_l_lw-iDVi7wcFurf5 z#Uw)A@a9fOf{D}AWE%<`s1L_AwpZ?F!Vac$LYkp<#A!!`XKaDC{A%)~K#5z6>Hv@V zBEqF(D5?@6r3Pwj$^krpPDCjB+UOszqUS;b2n>&iAFcw<*im2(b3|5u6SK!n9Sg4I z0KLcwA6{Mq?p%t>aW0W!PQ>iUeYvNjdKYqII!CE7SsS&Rj)eIw-K4jtI?II+0IdGq z2WT|L3RL?;GtGgt1LWfI4Ka`9dbZXc$TMJ~8#Juv@K^1RJN@yzdLS8$AJ(>g!U9`# zx}qr7JWlU+&m)VG*Se;rGisutS%!6yybi%B`bv|9rjS(xOUIvbNz5qtvC$_JYY+c& za*3*2$RUH8p%pSq>48xR)4qsp!Q7BEiJ*`^>^6INRbC@>+2q9?x(h0bpc>GaNFi$K zPH$6!#(~{8@0QZk=)QnM#I=bDx5vTvjm$f4K}%*s+((H2>tUTf==$wqyoI`oxI7>C z&>5fe)Yg)SmT)eA(|j@JYR1M%KixxC-Eceknf-;N=jJTwKvk#@|J^&5H0c+%KxHUI z6dQbwwVx3p?X<_VRVb2fStH?HH zFR@Mp=qX%#L3XL)+$PXKV|o|#DpHAoqvj6uQKe@M-mnhCSou7Dj4YuO6^*V`m)1lf z;)@e%1!Qg$10w8uEmz{ENb$^%u}B;J7sDd zump}onoD#!l=agcBR)iG!3AF0-63%@`K9G(CzKrm$VJ{v7^O9Ps7Zej|3m= zVXlR&yW6=Y%mD30G@|tf=yC7-#L!16Q=dq&@beWgaIL40k0n% z)QHrp2Jck#evLMM1RGt3WvQ936ZC9vEje0nFMfvmOHVI+&okB_K|l-;|4vW;qk>n~ z+|kk8#`K?x`q>`(f6A${wfw9Cx(^)~tX7<#TpxR#zYG2P+FY~mG{tnEkv~d6oUQA+ z&hNTL=~Y@rF`v-RZlts$nb$3(OL1&@Y11hhL9+zUb6)SP!;CD)^GUtUpCHBE`j1te zAGud@miCVFLk$fjsrcpjsadP__yj9iEZUW{Ll7PPi<$R;m1o!&Xdl~R_v0;oDX2z^!&8}zNGA}iYG|k zmehMd1%?R)u6R#<)B)1oe9TgYH5-CqUT8N7K-A-dm3hbm_W21p%8)H{O)xUlBVb+iUR}-v5dFaCyfSd zC6Bd7=N4A@+Bna=!-l|*_(nWGDpoyU>nH=}IOrLfS+-d40&(Wo*dDB9nQiA2Tse$R z;uq{`X7LLzP)%Y9aHa4YQ%H?htkWd3Owv&UYbr5NUDAH^<l@Z0Cx%`N+B*i!!1u>D8%;Qt1$ zE5O0{-`9gdDxZ!`0m}ywH!;c{oBfL-(BH<&SQ~smbcobU!j49O^f4&IIYh~f+hK*M zZwTp%{ZSAhMFj1qFaOA+3)p^gnXH^=)`NTYgTu!CLpEV2NF=~-`(}7p^Eof=@VUbd z_9U|8qF7Rueg&$qpSSkN%%%DpbV?8E8ivu@ensI0toJ7Eas^jyFReQ1JeY9plb^{m z&eQO)qPLZQ6O;FTr*aJq=$cMN)QlQO@G&%z?BKUs1&I^`lq>=QLODwa`(mFGC`0H< zOlc*|N?B5&!U6BuJvkL?s1&nsi$*5cCv7^j_*l&$-sBmRS85UIrE--7eD8Gr3^+o? zqG-Yl4S&E;>H>k^a0GdUI(|n1`ws@)1%sq2XBdK`mqrNq_b4N{#VpouCXLzNvjoFv zo9wMQ6l0+FT+?%N(ka*;%m~(?338bu32v26!{r)|w8J`EL|t$}TA4q_FJRX5 zCPa{hc_I(7TGE#@rO-(!$1H3N-C0{R$J=yPCXCtGk{4>=*B56JdXU9cQVwB`6~cQZ zf^qK21x_d>X%dT!!)CJQ3mlHA@ z{Prkgfs6=Tz%63$6Zr8CO0Ak3A)Cv#@BVKr&aiKG7RYxY$Yx>Bj#3gJk*~Ps-jc1l z;4nltQwwT4@Z)}Pb!3xM?+EW0qEKA)sqzw~!C6wd^{03-9aGf3Jmt=}w-*!yXupLf z;)>-7uvWN4Unn8b4kfIza-X=x*e4n5pU`HtgpFFd))s$C@#d>aUl3helLom+RYb&g zI7A9GXLRZPl}iQS*d$Azxg-VgcUr*lpLnbPKUV{QI|bsG{8bLG<%CF( zMoS4pRDtLVYOWG^@ox^h8xL~afW_9DcE#^1eEC1SVSb1BfDi^@g?#f6e%v~Aw>@w- zIY0k+2lGWNV|aA*e#`U3=+oBDmGeInfcL)>*!w|*;mWiKNG6wP6AW4-4imN!W)!hE zA02~S1*@Q`fD*+qX@f3!2yJX&6FsEfPditB%TWo3=HA;T3o2IrjS@9SSxv%{{7&4_ zdS#r4OU41~GYMiib#z#O;zohNbhJknrPPZS6sN$%HB=jUnlCO_w5Gw5EeE@KV>soy z2EZ?Y|4RQDDjt5y!WBlZ(8M)|HP<0YyG|D%RqD+K#e7-##o3IZxS^wQ5{Kbzb6h(i z#(wZ|^ei>8`%ta*!2tJzwMv+IFHLF`zTU8E^Mu!R*45_=ccqI};Zbyxw@U%a#2}%f zF>q?SrUa_a4H9l+uW8JHh2Oob>NyUwG=QH~-^ZebU*R@67DcXdz2{HVB4#@edz?B< z5!rQH3O0>A&ylROO%G^fimV*LX7>!%re{_Sm6N>S{+GW1LCnGImHRoF@csnFzn@P0 zM=jld0z%oz;j=>c7mMwzq$B^2mae7NiG}%>(wtmsDXkWk{?BeMpTrIt3Mizq?vRsf zi_WjNp+61uV(%gEU-Vf0;>~vcDhe(dzWdaf#4mH3o^v{0EWhj?E?$5v02sV@xL0l4 zX0_IMFtQ44PfWBbPYN#}qxa%=J%dlR{O!KyZvk^g5s?sTNycWYPJ^FK(nl3k?z-5t z39#hKrdO7V(@!TU)LAPY&ngnZ1MzLEeEiZznn7e-jLCy8LO zu^7_#z*%I-BjS#Pg-;zKWWqX-+Ly$T!4`vTe5ZOV0j?TJVA*2?*=82^GVlZIuH%9s zXiV&(T(QGHHah=s&7e|6y?g+XxZGmK55`wGV>@1U)Th&=JTgJq>4mI&Av2C z)w+kRoj_dA!;SfTfkgMPO>7Dw6&1*Hi1q?54Yng`JO&q->^CX21^PrU^JU#CJ_qhV zSG>afB%>2fx<~g8p=P8Yzxqc}s@>>{g7}F!;lCXvF#RV)^fyYb_)iKVCz1xEq=fJ| z0a7DMCK*FuP=NM*5h;*D`R4y$6cpW-E&-i{v`x=Jbk_xSn@2T3q!3HoAOB`@5Vg6) z{PW|@9o!e;v1jZ2{=Uw6S6o{g82x6g=k!)cFSC*oemHaVjg?VpEmtUuD2_J^A~$4* z3O7HsbA6wxw{TP5Kk)(Vm?gKo+_}11vbo{Tp_5x79P~#F)ahQXT)tSH5;;14?s)On zel1J>1x>+7;g1Iz2FRpnYz;sD0wG9Q!vuzE9yKi3@4a9Nh1!GGN?hA)!mZEnnHh&i zf?#ZEN2sFbf~kV;>K3UNj1&vFhc^sxgj8FCL4v>EOYL?2uuT`0eDH}R zmtUJMxVrV5H{L53hu3#qaWLUa#5zY?f5ozIn|PkMWNP%n zWB5!B0LZB0kLw$k39=!akkE9Q>F4j+q434jB4VmslQ;$ zKiO#FZ`p|dKS716jpcvR{QJkSNfDVhr2%~eHrW;fU45>>snr*S8Vik-5eN5k*c2Mp zyxvX&_cFbB6lODXznHHT|rsURe2!swomtrqc~w5 zymTM8!w`1{04CBprR!_F{5LB+2_SOuZN{b*!J~1ZiPpP-M;);!ce!rOPDLtgR@Ie1 zPreuqm4!H)hYePcW1WZ0Fyaqe%l}F~Orr)~+;mkS&pOhP5Ebb`cnUt!X_QhP4_4p( z8YKQCDKGIy>?WIFm3-}Br2-N`T&FOi?t)$hjphB9wOhBXU#Hb+zm&We_-O)s(wc`2 z8?VsvU;J>Ju7n}uUb3s1yPx_F*|FlAi=Ge=-kN?1;`~6szP%$3B0|8Sqp%ebM)F8v zADFrbeT0cgE>M0DMV@_Ze*GHM>q}wWMzt|GYC%}r{OXRG3Ij&<+nx9;4jE${Fj_r* z`{z1AW_6Myd)i6e0E-h&m{{CvzH=Xg!&(bLYgRMO_YVd8JU7W+7MuGWNE=4@OvP9+ zxi^vqS@5%+#gf*Z@RVyU9N1sO-(rY$24LGsg1>w>s6ST^@)|D9>cT50maXLUD{Fzf zt~tp{OSTEKg3ZSQyQQ5r51){%=?xlZ54*t1;Ow)zLe3i?8tD8YyY^k%M)e`V*r+vL zPqUf&m)U+zxps+NprxMHF{QSxv}>lE{JZETNk1&F+R~bp{_T$dbXL2UGnB|hgh*p4h$clt#6;NO~>zuyY@C-MD@)JCc5XrYOt`wW7! z_ti2hhZBMJNbn0O-uTxl_b6Hm313^fG@e;RrhIUK9@# z+DHGv_Ow$%S8D%RB}`doJjJy*aOa5mGHVHz0e0>>O_%+^56?IkA5eN+L1BVCp4~m=1eeL zb;#G!#^5G%6Mw}r1KnaKsLvJB%HZL)!3OxT{k$Yo-XrJ?|7{s4!H+S2o?N|^Z z)+?IE9H7h~Vxn5hTis^3wHYuOU84+bWd)cUKuHapq=&}WV#OxHpLab`NpwHm8LmOo zjri+!k;7j_?FP##CpM+pOVx*0wExEex z@`#)K<-ZrGyArK;a%Km`^+We|eT+#MygHOT6lXBmz`8|lyZOwL1+b+?Z$0OhMEp3R z&J=iRERpv~TC=p2-BYLC*?4 zxvPs9V@g=JT0>zky5Poj=fW_M!c)Xxz1<=&_ZcL=LMZJqlnO1P^xwGGW*Z+yTBvbV z-IFe6;(k1@$1;tS>{%pXZ_7w+i?N4A2=TXnGf=YhePg8bH8M|Lk-->+w8Y+FjZ;L=wSGwxfA`gqSn)f(XNuSm>6Y z@|#e-)I(PQ^G@N`%|_DZSb4_pkaEF0!-nqY+t#pyA>{9^*I-zw4SYA1_z2Bs$XGUZbGA;VeMo%CezHK0lO={L%G)dI-+8w?r9iexdoB{?l zbJ}C?huIhWXBVs7oo{!$lOTlvCLZ_KN1N+XJGuG$rh<^eUQIqcI7^pmqhBSaOKNRq zrx~w^?9C?*&rNwP_SPYmo;J-#!G|{`$JZK7DxsM3N^8iR4vvn>E4MU&Oe1DKJvLc~ zCT>KLZ1;t@My zRj_2hI^61T&LIz)S!+AQIV23n1>ng+LUvzv;xu!4;wpqb#EZz;F)BLUzT;8UA1x*6vJ zicB!3Mj03s*kGV{g`fpC?V^s(=JG-k1EMHbkdP4P*1^8p_TqO|;!Zr%GuP$8KLxuf z=pv*H;kzd;P|2`JmBt~h6|GxdU~@weK5O=X&5~w$HpfO}@l-T7@vTCxVOwCkoPQv8 z@aV_)I5HQtfs7^X=C03zYmH4m0S!V@JINm6#(JmZRHBD?T!m^DdiZJrhKpBcur2u1 zf9e4%k$$vcFopK5!CC`;ww(CKL~}mlxK_Pv!cOsFgVkNIghA2Au@)t6;Y3*2gK=5d z?|@1a)-(sQ%uFOmJ7v2iG&l&m^u&^6DJM#XzCrF%r>{2XKyxLD2rgWBD;i(!e4InDQBDg==^z;AzT2z~OmV0!?Z z0S9pX$+E;w3WN;v&NYT=+G8hf=6w0E1$0AOr61}eOvE8W1jX%>&Mjo7&!ulawgzLH zbcb+IF(s^3aj12WSi#pzIpijJJzkP?JzRawnxmNDSUR#7!29vHULCE<3Aa#be}ie~d|!V+ z%l~s9Odo$G&fH!t!+`rUT0T9DulF!Yq&BfQWFZV1L9D($r4H(}Gnf6k3^wa7g5|Ws zj7%d`!3(0bb55yhC6@Q{?H|2os{_F%o=;-h{@Yyyn*V7?{s%Grvpe!H^kl6tF4Zf5 z{Jv1~yZ*iIWL_9C*8pBMQArfJJ0d9Df6Kl#wa}7Xa#Ef_5B7=X}DzbQXVPfCwTO@9+@;A^Ti6il_C>g?A-GFwA0#U;t4;wOm-4oS})h z5&on>NAu67O?YCQr%7XIzY%LS4bha9*e*4bU4{lGCUmO2UQ2U)QOqClLo61Kx~3dI zmV3*(P6F_Tr-oP%x!0kTnnT?Ep5j;_IQ^pTRp=e8dmJtI4YgWd0}+b2=ATkOhgpXe z;jmw+FBLE}UIs4!&HflFr4)vMFOJ19W4f2^W(=2)F%TAL)+=F>IE$=e=@j-*bFLSg z)wf|uFQu+!=N-UzSef62u0-C8Zc7 zo6@F)c+nZA{H|+~7i$DCU0pL{0Ye|fKLuV^w!0Y^tT$isu%i1Iw&N|tX3kwFKJN(M zXS`k9js66o$r)x?TWL}Kxl`wUDUpwFx(w4Yk%49;$sgVvT~n8AgfG~HUcDt1TRo^s zdla@6heJB@JV z!vK;BUMznhzGK6PVtj0)GB=zTv6)Q9Yt@l#fv7>wKovLobMV-+(8)NJmyF8R zcB|_K7=FJGGn^X@JdFaat0uhKjp3>k#^&xE_}6NYNG?kgTp>2Iu?ElUjt4~E-?`Du z?mDCS9wbuS%fU?5BU@Ijx>1HG*N?gIP+<~xE4u=>H`8o((cS5M6@_OK%jSjFHirQK zN9@~NXFx*jS{<|bgSpC|SAnA@I)+GB=2W|JJChLI_mx+-J(mSJ!b)uUom6nH0#2^(L@JBlV#t zLl?j54s`Y3vE^c_3^Hl0TGu*tw_n?@HyO@ZrENxA+^!)OvUX28gDSF*xFtQzM$A+O zCG=n#6~r|3zt=8%GuG} z<#VCZ%2?3Q(Ad#Y7GMJ~{U3>E{5e@z6+rgZLX{Cxk^p-7dip^d29;2N1_mm4QkASo z-L`GWWPCq$uCo;X_BmGIpJFBlhl<8~EG{vOD1o|X$aB9KPhWO_cKiU*$HWEgtf=fn zsO%9bp~D2c@?*K9jVN@_vhR03>M_8h!_~%aN!Cnr?s-!;U3SVfmhRwk11A^8Ns`@KeE}+ zN$H}a1U6E;*j5&~Og!xHdfK5M<~xka)x-0N)K_&e7AjMz`toDzasH+^1bZlC!n()crk9kg@$(Y{wdKvbuUd04N^8}t1iOgsKF zGa%%XWx@WoVaNC1!|&{5ZbkopFre-Lu(LCE5HWZBoE#W@er9W<>R=^oYxBvypN#x3 zq#LC8&q)GFP=5^-bpHj?LW=)-g+3_)Ylps!3^YQ{9~O9&K)xgy zMkCWaApU-MI~e^cV{Je75Qr7eF%&_H)BvfyKL=gIA>;OSq(y z052BFz3E(Prg~09>|_Z@!qj}@;8yxnw+#Ej0?Rk<y}4ghbD569B{9hSFr*^ygZ zr6j7P#gtZh6tMk6?4V$*Jgz+#&ug;yOr>=qdI#9U&^am2qoh4Jy}H2%a|#Fs{E(5r z%!ijh;VuGA6)W)cJZx+;9Bp1LMUzN~x_8lQ#D3+sL{be-Jyeo@@dv7XguJ&S5vrH` z>QxOMWn7N-T!D@1(@4>ZlL^y5>m#0!HKovs12GRav4z!>p(1~xok8+_{| z#Ae4{9#NLh#Vj2&JuIn5$d6t@__`o}umFo(n0QxUtd2GKCyE+erwXY?`cm*h&^9*8 zJ+8x6fRZI-e$CRygofIQN^dWysCxgkyr{(_oBwwSRxZora1(%(aC!5BTtj^+YuevI zx?)H#(xlALUp6QJ!=l9N__$cxBZ5p&7;qD3PsXRFVd<({Kh+mShFWJNpy`N@ab7?9 zv5=klvCJ4bx|-pvOO2-+G)6O?$&)ncA#Urze2rlBfp#htudhx-NeRnJ@u%^_bfw4o z4|{b8SkPV3b>Wera1W(+N@p9H>dc6{cnkh-sgr?e%(YkWvK+0YXVwk0=d`)}*47*B z5JGkEdVix!w7-<%r0JF~`ZMMPe;f0EQHuYHxya`puazyph*ZSb1mJAt^k4549BfS; zK7~T&lRb=W{s&t`DJ$B}s-eH1&&-wEOH1KWsKn0a(ZI+G!v&W4A*cl>qAvUv6pbUR z#(f#EKV8~hk&8oayBz4vaswc(?qw1vn`yC zZQDl2PCB-&Uu@g9ZQHhO+v(W0bNig{-k0;;`+wM@#@J)8r?qOYs#&vUna8ILxN7S{ zp1s41KnR8miQJtJtOr|+qk}wrLt+N*z#5o`TmD1)E&QD(Vh&pjZJ_J*0!8dy_ z>^=@v=J)C`x&gjqAYu`}t^S=DFCtc0MkBU2zf|69?xW`Ck~(6zLD)gSE{7n~6w8j_ zoH&~$ED2k5-yRa0!r8fMRy z;QjBYUaUnpd}mf%iVFPR%Dg9!d>g`01m~>2s))`W|5!kc+_&Y>wD@@C9%>-lE`WB0 zOIf%FVD^cj#2hCkFgi-fgzIfOi+ya)MZK@IZhHT5FVEaSbv-oDDs0W)pA0&^nM0TW zmgJmd7b1R7b0a`UwWJYZXp4AJPteYLH>@M|xZFKwm!t3D3&q~av?i)WvAKHE{RqpD{{%OhYkK?47}+}` zrR2(Iv9bhVa;cDzJ%6ntcSbx7v7J@Y4x&+eWSKZ*eR7_=CVIUSB$^lfYe@g+p|LD{ zPSpQmxx@b$%d!05|H}WzBT4_cq?@~dvy<7s&QWtieJ9)hd4)$SZz}#H2UTi$CkFWW|I)v_-NjuH!VypONC=1`A=rm_jfzQ8Fu~1r8i{q-+S_j$ z#u^t&Xnfi5tZtl@^!fUJhx@~Cg0*vXMK}D{>|$#T*+mj(J_@c{jXBF|rm4-8%Z2o! z2z0o(4%8KljCm^>6HDK!{jI7p+RAPcty_~GZ~R_+=+UzZ0qzOwD=;YeZt*?3%UGdr z`c|BPE;yUbnyARUl&XWSNJ<+uRt%!xPF&K;(l$^JcA_CMH6)FZt{>6ah$|(9$2fc~ z=CD00uHM{qv;{Zk9FR0~u|3|Eiqv9?z2#^GqylT5>6JNZwKqKBzzQpKU2_pmtD;CT zi%Ktau!Y2Tldfu&b0UgmF(SSBID)15*r08eoUe#bT_K-G4VecJL2Pa=6D1K6({zj6 za(2Z{r!FY5W^y{qZ}08+h9f>EKd&PN90f}Sc0ejf%kB4+f#T8Q1=Pj=~#pi$U zp#5rMR%W25>k?<$;$x72pkLibu1N|jX4cWjD3q^Pk3js!uK6h7!dlvw24crL|MZs_ zb%Y%?Fyp0bY0HkG^XyS76Ts*|Giw{31LR~+WU5NejqfPr73Rp!xQ1mLgq@mdWncLy z%8}|nzS4P&`^;zAR-&nm5f;D-%yNQPwq4N7&yULM8bkttkD)hVU>h>t47`{8?n2&4 zjEfL}UEagLUYwdx0sB2QXGeRmL?sZ%J!XM`$@ODc2!y|2#7hys=b$LrGbvvjx`Iqi z&RDDm3YBrlKhl`O@%%&rhLWZ*ABFz2nHu7k~3@e4)kO3%$=?GEFUcCF=6-1n!x^vmu+Ai*amgXH+Rknl6U>#9w;A} zn2xanZSDu`4%%x}+~FG{Wbi1jo@wqBc5(5Xl~d0KW(^Iu(U3>WB@-(&vn_PJt9{1`e9Iic@+{VPc`vP776L*viP{wYB2Iff8hB%E3|o zGMOu)tJX!`qJ}ZPzq7>=`*9TmETN7xwU;^AmFZ-ckZjV5B2T09pYliaqGFY|X#E-8 z20b>y?(r-Fn5*WZ-GsK}4WM>@TTqsxvSYWL6>18q8Q`~JO1{vLND2wg@58OaU!EvT z1|o+f1mVXz2EKAbL!Q=QWQKDZpV|jznuJ}@-)1&cdo z^&~b4Mx{*1gurlH;Vhk5g_cM&6LOHS2 zRkLfO#HabR1JD4Vc2t828dCUG#DL}f5QDSBg?o)IYYi@_xVwR2w_ntlpAW0NWk$F1 z$If?*lP&Ka1oWfl!)1c3fl`g*lMW3JOn#)R1+tfwrs`aiFUgz3;XIJ>{QFxLCkK30 zNS-)#DON3yb!7LBHQJ$)4y%TN82DC2-9tOIqzhZ27@WY^<6}vXCWcR5iN{LN8{0u9 zNXayqD=G|e?O^*ms*4P?G%o@J1tN9_76e}E#66mr89%W_&w4n66~R;X_vWD(oArwj z4CpY`)_mH2FvDuxgT+akffhX0b_slJJ*?Jn3O3~moqu2Fs1oL*>7m=oVek2bnprnW zixkaIFU%+3XhNA@@9hyhFwqsH2bM|`P?G>i<-gy>NflhrN{$9?LZ1ynSE_Mj0rADF zhOz4FnK}wpLmQuV zgO4_Oz9GBu_NN>cPLA=`SP^$gxAnj;WjJnBi%Q1zg`*^cG;Q)#3Gv@c^j6L{arv>- zAW%8WrSAVY1sj$=umcAf#ZgC8UGZGoamK}hR7j6}i8#np8ruUlvgQ$j+AQglFsQQq zOjyHf22pxh9+h#n$21&$h?2uq0>C9P?P=Juw0|;oE~c$H{#RGfa>| zj)Iv&uOnaf@foiBJ}_;zyPHcZt1U~nOcNB{)og8Btv+;f@PIT*xz$x!G?u0Di$lo7 zOugtQ$Wx|C($fyJTZE1JvR~i7LP{ zbdIwqYghQAJi9p}V&$=*2Azev$6K@pyblphgpv8^9bN!?V}{BkC!o#bl&AP!3DAjM zmWFsvn2fKWCfjcAQmE+=c3Y7j@#7|{;;0f~PIodmq*;W9Fiak|gil6$w3%b_Pr6K_ zJEG@&!J%DgBZJDCMn^7mk`JV0&l07Bt`1ymM|;a)MOWz*bh2#d{i?SDe9IcHs7 zjCrnyQ*Y5GzIt}>`bD91o#~5H?4_nckAgotN{2%!?wsSl|LVmJht$uhGa+HiH>;av z8c?mcMYM7;mvWr6noUR{)gE!=i7cZUY7e;HXa221KkRoc2UB>s$Y(k%NzTSEr>W(u z<(4mcc)4rB_&bPzX*1?*ra%VF}P1nwiP5cykJ&W{!OTlz&Td0pOkVp+wc z@k=-Hg=()hNg=Q!Ub%`BONH{ z_=ZFgetj@)NvppAK2>8r!KAgi>#%*7;O-o9MOOfQjV-n@BX6;Xw;I`%HBkk20v`qoVd0)}L6_49y1IhR z_OS}+eto}OPVRn*?UHC{eGyFU7JkPz!+gX4P>?h3QOwGS63fv4D1*no^6PveUeE5% zlehjv_3_^j^C({a2&RSoVlOn71D8WwMu9@Nb@=E_>1R*ve3`#TF(NA0?d9IR_tm=P zOP-x;gS*vtyE1Cm zG0L?2nRUFj#aLr-R1fX*$sXhad)~xdA*=hF3zPZhha<2O$Ps+F07w*3#MTe?)T8|A!P!v+a|ot{|^$q(TX`35O{WI0RbU zCj?hgOv=Z)xV?F`@HKI11IKtT^ocP78cqHU!YS@cHI@{fPD?YXL)?sD~9thOAv4JM|K8OlQhPXgnevF=F7GKD2#sZW*d za}ma31wLm81IZxX(W#A9mBvLZr|PoLnP>S4BhpK8{YV_}C|p<)4#yO{#ISbco92^3 zv&kCE(q9Wi;9%7>>PQ!zSkM%qqqLZW7O`VXvcj;WcJ`2~v?ZTYB@$Q&^CTfvy?1r^ z;Cdi+PTtmQwHX_7Kz?r#1>D zS5lWU(Mw_$B&`ZPmqxpIvK<~fbXq?x20k1~9az-Q!uR78mCgRj*eQ>zh3c$W}>^+w^dIr-u{@s30J=)1zF8?Wn|H`GS<=>Om|DjzC{}Jt?{!fSJe*@$H zg>wFnlT)k#T?LslW zu$^7Uy~$SQ21cE?3Ijl+bLfuH^U5P^$@~*UY#|_`uvAIe(+wD2eF}z_y!pvomuVO; zS^9fbdv)pcm-B@CW|Upm<7s|0+$@@<&*>$a{aW+oJ%f+VMO<#wa)7n|JL5egEgoBv zl$BY(NQjE0#*nv=!kMnp&{2Le#30b)Ql2e!VkPLK*+{jv77H7)xG7&=aPHL7LK9ER z5lfHxBI5O{-3S?GU4X6$yVk>lFn;ApnwZybdC-GAvaznGW-lScIls-P?Km2mF>%B2 zkcrXTk+__hj-3f48U%|jX9*|Ps41U_cd>2QW81Lz9}%`mTDIhE)jYI$q$ma7Y-`>% z8=u+Oftgcj%~TU}3nP8&h7k+}$D-CCgS~wtWvM|UU77r^pUw3YCV80Ou*+bH0!mf0 zxzUq4ed6y>oYFz7+l18PGGzhB^pqSt)si=9M>~0(Bx9*5r~W7sa#w+_1TSj3Jn9mW zMuG9BxN=}4645Cpa#SVKjFst;9UUY@O<|wpnZk$kE+to^4!?0@?Cwr3(>!NjYbu?x z1!U-?0_O?k!NdM^-rIQ8p)%?M+2xkhltt*|l=%z2WFJhme7*2xD~@zk#`dQR$6Lmd zb3LOD4fdt$Cq>?1<%&Y^wTWX=eHQ49Xl_lFUA(YQYHGHhd}@!VpYHHm=(1-O=yfK#kKe|2Xc*9}?BDFN zD7FJM-AjVi)T~OG)hpSWqH>vlb41V#^G2B_EvYlWhDB{Z;Q9-0)ja(O+By`31=biA zG&Fs#5!%_mHi|E4Nm$;vVQ!*>=_F;ZC=1DTPB#CICS5fL2T3XmzyHu?bI;m7D4@#; ztr~;dGYwb?m^VebuULtS4lkC_7>KCS)F@)0OdxZIFZp@FM_pHnJes8YOvwB|++#G( z&dm*OP^cz95Wi15vh`Q+yB>R{8zqEhz5of>Po$9LNE{xS<)lg2*roP*sQ}3r3t<}; zPbDl{lk{pox~2(XY5=qg0z!W-x^PJ`VVtz$git7?)!h>`91&&hESZy1KCJ2nS^yMH z!=Q$eTyRi68rKxdDsdt+%J_&lapa{ds^HV9Ngp^YDvtq&-Xp}60B_w@Ma>_1TTC;^ zpbe!#gH}#fFLkNo#|`jcn?5LeUYto%==XBk6Ik0kc4$6Z+L3x^4=M6OI1=z5u#M%0 z0E`kevJEpJjvvN>+g`?gtnbo$@p4VumliZV3Z%CfXXB&wPS^5C+7of2tyVkMwNWBiTE2 z8CdPu3i{*vR-I(NY5syRR}I1TJOV@DJy-Xmvxn^IInF>Tx2e)eE9jVSz69$6T`M9-&om!T+I znia!ZWJRB28o_srWlAxtz4VVft8)cYloIoVF=pL zugnk@vFLXQ_^7;%hn9x;Vq?lzg7%CQR^c#S)Oc-8d=q_!2ZVH764V z!wDKSgP}BrVV6SfCLZnYe-7f;igDs9t+K*rbMAKsp9L$Kh<6Z;e7;xxced zn=FGY<}CUz31a2G}$Q(`_r~75PzM4l_({Hg&b@d8&jC}B?2<+ed`f#qMEWi z`gm!STV9E4sLaQX+sp5Nu9*;9g12naf5?=P9p@H@f}dxYprH+3ju)uDFt^V{G0APn zS;16Dk{*fm6&BCg#2vo?7cbkkI4R`S9SSEJ=#KBk3rl69SxnCnS#{*$!^T9UUmO#&XXKjHKBqLdt^3yVvu8yn|{ zZ#%1CP)8t-PAz(+_g?xyq;C2<9<5Yy<~C74Iw(y>uUL$+$mp(DRcCWbCKiGCZw@?_ zdomfp+C5xt;j5L@VfhF*xvZdXwA5pcdsG>G<8II-|1dhAgzS&KArcb0BD4ZZ#WfiEY{hkCq5%z9@f|!EwTm;UEjKJsUo696V>h zy##eXYX}GUu%t{Gql8vVZKkNhQeQ4C%n|RmxL4ee5$cgwlU+?V7a?(jI#&3wid+Kz5+x^G!bb#$q>QpR#BZ}Xo5UW^ zD&I`;?(a}Oys7-`I^|AkN?{XLZNa{@27Dv^s4pGowuyhHuXc zuctKG2x0{WCvg_sGN^n9myJ}&FXyGmUQnW7fR$=bj$AHR88-q$D!*8MNB{YvTTEyS zn22f@WMdvg5~o_2wkjItJN@?mDZ9UUlat2zCh(zVE=dGi$rjXF7&}*sxac^%HFD`Y zTM5D3u5x**{bW!68DL1A!s&$2XG@ytB~dX-?BF9U@XZABO`a|LM1X3HWCllgl0+uL z04S*PX$%|^WAq%jkzp~%9HyYIF{Ym?k)j3nMwPZ=hlCg9!G+t>tf0o|J2%t1 ztC+`((dUplgm3`+0JN~}&FRRJ3?l*>Y&TfjS>!ShS`*MwO{WIbAZR#<%M|4c4^dY8 z{Rh;-!qhY=dz5JthbWoovLY~jNaw>%tS4gHVlt5epV8ekXm#==Po$)}mh^u*cE>q7*kvX&gq)(AHoItMYH6^s6f(deNw%}1=7O~bTHSj1rm2|Cq+3M z93djjdomWCTCYu!3Slx2bZVy#CWDozNedIHbqa|otsUl+ut?>a;}OqPfQA05Yim_2 zs@^BjPoFHOYNc6VbNaR5QZfSMh2S*`BGwcHMM(1@w{-4jVqE8Eu0Bi%d!E*^Rj?cR z7qgxkINXZR)K^=fh{pc0DCKtrydVbVILI>@Y0!Jm>x-xM!gu%dehm?cC6ok_msDVA*J#{75%4IZt}X|tIVPReZS#aCvuHkZxc zHVMtUhT(wp09+w9j9eRqz~LtuSNi2rQx_QgQ(}jBt7NqyT&ma61ldD(s9x%@q~PQl zp6N*?=N$BtvjQ_xIT{+vhb1>{pM0Arde0!X-y))A4znDrVx8yrP3B1(7bKPE5jR@5 zwpzwT4cu~_qUG#zYMZ_!2Tkl9zP>M%cy>9Y(@&VoB84#%>amTAH{(hL4cDYt!^{8L z645F>BWO6QaFJ-{C-i|-d%j7#&7)$X7pv#%9J6da#9FB5KyDhkA+~)G0^87!^}AP>XaCSScr;kL;Z%RSPD2CgoJ;gpYT5&6NUK$86$T?jRH=w8nI9Z534O?5fk{kd z`(-t$8W|#$3>xoMfXvV^-A(Q~$8SKDE^!T;J+rQXP71XZ(kCCbP%bAQ1|%$%Ov9_a zyC`QP3uPvFoBqr_+$HenHklqyIr>PU_Fk5$2C+0eYy^~7U&(!B&&P2%7#mBUhM!z> z_B$Ko?{Pf6?)gpYs~N*y%-3!1>o-4;@1Zz9VQHh)j5U1aL-Hyu@1d?X;jtDBNk*vMXPn@ z+u@wxHN*{uHR!*g*4Xo&w;5A+=Pf9w#PeZ^x@UD?iQ&${K2c}UQgLRik-rKM#Y5rdDphdcNTF~cCX&9ViRP}`>L)QA4zNXeG)KXFzSDa6 zd^St;inY6J_i=5mcGTx4_^Ys`M3l%Q==f>{8S1LEHn{y(kbxn5g1ezt4CELqy)~TV6{;VW>O9?5^ ztcoxHRa0jQY7>wwHWcxA-BCwzsP>63Kt&3fy*n#Cha687CQurXaRQnf5wc9o8v7Rw zNwGr2fac;Wr-Ldehn7tF^(-gPJwPt@VR1f;AmKgxN&YPL;j=0^xKM{!wuU|^mh3NE zy35quf}MeL!PU;|{OW_x$TBothLylT-J>_x6p}B_jW1L>k)ps6n%7Rh z96mPkJIM0QFNYUM2H}YF5bs%@Chs6#pEnloQhEl?J-)es!(SoJpEPoMTdgA14-#mC zghayD-DJWtUu`TD8?4mR)w5E`^EHbsz2EjH5aQLYRcF{l7_Q5?CEEvzDo(zjh|BKg z3aJl_n#j&eFHsUw4~lxqnr!6NL*se)6H=A+T1e3xUJGQrd}oSPwSy5+$tt{2t5J5@(lFxl43amsARG74iyNC}uuS zd2$=(r6RdamdGx^eatX@F2D8?U23tDpR+Os?0Gq2&^dF+$9wiWf?=mDWfjo4LfRwL zI#SRV9iSz>XCSgEj!cW&9H-njJopYiYuq|2w<5R2!nZ27DyvU4UDrHpoNQZiGPkp@ z1$h4H46Zn~eqdj$pWrv;*t!rTYTfZ1_bdkZmVVIRC21YeU$iS-*XMNK`#p8Z_DJx| zk3Jssf^XP7v0X?MWFO{rACltn$^~q(M9rMYoVxG$15N;nP)A98k^m3CJx8>6}NrUd@wp-E#$Q0uUDQT5GoiK_R{ z<{`g;8s>UFLpbga#DAf%qbfi`WN1J@6IA~R!YBT}qp%V-j!ybkR{uY0X|x)gmzE0J z&)=eHPjBxJvrZSOmt|)hC+kIMI;qgOnuL3mbNR0g^<%|>9x7>{}>a2qYSZAGPt4it?8 zNcLc!Gy0>$jaU?}ZWxK78hbhzE+etM`67*-*x4DN>1_&{@5t7_c*n(qz>&K{Y?10s zXsw2&nQev#SUSd|D8w7ZD2>E<%g^; zV{yE_O}gq?Q|zL|jdqB^zcx7vo(^})QW?QKacx$yR zhG|XH|8$vDZNIfuxr-sYFR{^csEI*IM#_gd;9*C+SysUFejP0{{z7@P?1+&_o6=7V|EJLQun^XEMS)w(=@eMi5&bbH*a0f;iC~2J74V2DZIlLUHD&>mlug5+v z6xBN~8-ovZylyH&gG#ptYsNlT?-tzOh%V#Y33zlsJ{AIju`CjIgf$@gr8}JugRq^c zAVQ3;&uGaVlVw}SUSWnTkH_6DISN&k2QLMBe9YU=sA+WiX@z)FoSYX`^k@B!j;ZeC zf&**P?HQG6Rk98hZ*ozn6iS-dG}V>jQhb3?4NJB*2F?6N7Nd;EOOo;xR7acylLaLy z9)^lykX39d@8@I~iEVar4jmjjLWhR0d=EB@%I;FZM$rykBNN~jf>#WbH4U{MqhhF6 zU??@fSO~4EbU4MaeQ_UXQcFyO*Rae|VAPLYMJEU`Q_Q_%s2*>$#S^)&7er+&`9L=1 z4q4ao07Z2Vsa%(nP!kJ590YmvrWg+YrgXYs_lv&B5EcoD`%uL79WyYA$0>>qi6ov7 z%`ia~J^_l{p39EY zv>>b}Qs8vxsu&WcXEt8B#FD%L%ZpcVtY!rqVTHe;$p9rbb5O{^rFMB>auLn-^;s+-&P1#h~mf~YLg$8M9 zZ4#87;e-Y6x6QO<{McUzhy(%*6| z)`D~A(TJ$>+0H+mct(jfgL4x%^oC^T#u(bL)`E2tBI#V1kSikAWmOOYrO~#-cc_8! zCe|@1&mN2{*ceeiBldHCdrURk4>V}79_*TVP3aCyV*5n@jiNbOm+~EQ_}1#->_tI@ zqXv+jj2#8xJtW508rzFrYcJxoek@iW6SR@1%a%Bux&;>25%`j3UI`0DaUr7l79`B1 zqqUARhW1^h6=)6?;@v>xrZNM;t}{yY3P@|L}ey@gG( z9r{}WoYN(9TW&dE2dEJIXkyHA4&pU6ki=rx&l2{DLGbVmg4%3Dlfvn!GB>EVaY_%3+Df{fBiqJV>~Xf8A0aqUjgpa} zoF8YXO&^_x*Ej}nw-$-F@(ddB>%RWoPUj?p8U{t0=n>gAI83y<9Ce@Q#3&(soJ{64 z37@Vij1}5fmzAuIUnXX`EYe;!H-yTVTmhAy;y8VZeB#vD{vw9~P#DiFiKQ|kWwGFZ z=jK;JX*A;Jr{#x?n8XUOLS;C%f|zj-7vXtlf_DtP7bpurBeX%Hjwr z4lI-2TdFpzkjgiv!8Vfv`=SP+s=^i3+N~1ELNWUbH|ytVu>EyPN_3(4TM^QE1swRo zoV7Y_g)a>28+hZG0e7g%@2^s>pzR4^fzR-El}ARTmtu!zjZLuX%>#OoU3}|rFjJg} zQ2TmaygxJ#sbHVyiA5KE+yH0LREWr%^C*yR|@gM$nK2P zo}M}PV0v))uJh&33N>#aU376@ZH79u(Yw`EQ2hM3SJs9f99+cO6_pNW$j$L-CtAfe zYfM)ccwD!P%LiBk!eCD?fHCGvgMQ%Q2oT_gmf?OY=A>&PaZQOq4eT=lwbaf}33LCH zFD|)lu{K7$8n9gX#w4~URjZxWm@wlH%oL#G|I~Fb-v^0L0TWu+`B+ZG!yII)w05DU z>GO?n(TN+B=>HdxVDSlIH76pta$_LhbBg;eZ`M7OGcqt||qi zogS72W1IN%=)5JCyOHWoFP7pOFK0L*OAh=i%&VW&4^LF@R;+K)t^S!96?}^+5QBIs zjJNTCh)?)4k^H^g1&jc>gysM`y^8Rm3qsvkr$9AeWwYpa$b22=yAd1t<*{ zaowSEFP+{y?Ob}8&cwfqoy4Pb9IA~VnM3u!trIK$&&0Op#Ql4j>(EW?UNUv#*iH1$ z^j>+W{afcd`{e&`-A{g}{JnIzYib)!T56IT@YEs{4|`sMpW3c8@UCoIJv`XsAw!XC z34|Il$LpW}CIHFC5e*)}00I5{%OL*WZRGzC0?_}-9{#ue?-ug^ zLE|uv-~6xnSs_2_&CN9{9vyc!Xgtn36_g^wI0C4s0s^;8+p?|mm;Odt3`2ZjwtK;l zfd6j)*Fr#53>C6Y8(N5?$H0ma;BCF3HCjUs7rpb2Kf*x3Xcj#O8mvs#&33i+McX zQpBxD8!O{5Y8D&0*QjD=Yhl9%M0)&_vk}bmN_Ud^BPN;H=U^bn&(csl-pkA+GyY0Z zKV7sU_4n;}uR78ouo8O%g*V;79KY?3d>k6%gpcmQsKk&@Vkw9yna_3asGt`0Hmj59 z%0yiF*`jXhByBI9QsD=+>big5{)BGe&+U2gAARGe3ID)xrid~QN_{I>k}@tzL!Md_ z&=7>TWciblF@EMC3t4-WX{?!m!G6$M$1S?NzF*2KHMP3Go4=#ZHkeIv{eEd;s-yD# z_jU^Ba06TZqvV|Yd;Z_sN%$X=!T+&?#p+OQIHS%!LO`Hx0q_Y0MyGYFNoM{W;&@0@ zLM^!X4KhdtsET5G<0+|q0oqVXMW~-7LW9Bg}=E$YtNh1#1D^6Mz(V9?2g~I1( zoz9Cz=8Hw98zVLwC2AQvp@pBeKyidn6Xu0-1SY1((^Hu*-!HxFUPs)yJ+i`^BC>PC zjwd0mygOVK#d2pRC9LxqGc6;Ui>f{YW9Bvb>33bp^NcnZoH~w9(lM5@JiIlfa-6|k ziy31UoMN%fvQfhi8^T+=yrP{QEyb-jK~>$A4SZT-N56NYEbpvO&yUme&pWKs3^94D zH{oXnUTb3T@H+RgzML*lejx`WAyw*?K7B-I(VJx($2!NXYm%3`=F~TbLv3H<{>D?A zJo-FDYdSA-(Y%;4KUP2SpHKAIcv9-ld(UEJE7=TKp|Gryn;72?0LHqAN^fk6%8PCW z{g_-t)G5uCIf0I`*F0ZNl)Z>))MaLMpXgqWgj-y;R+@A+AzDjsTqw2Mo9ULKA3c70 z!7SOkMtZb+MStH>9MnvNV0G;pwSW9HgP+`tg}e{ij0H6Zt5zJ7iw`hEnvye!XbA@!~#%vIkzowCOvq5I5@$3wtc*w2R$7!$*?}vg4;eDyJ_1=ixJuEp3pUS27W?qq(P^8$_lU!mRChT}ctvZz4p!X^ zOSp|JOAi~f?UkwH#9k{0smZ7-#=lK6X3OFEMl7%)WIcHb=#ZN$L=aD`#DZKOG4p4r zwlQ~XDZ`R-RbF&hZZhu3(67kggsM-F4Y_tI^PH8PMJRcs7NS9ogF+?bZB*fcpJ z=LTM4W=N9yepVvTj&Hu~0?*vR1HgtEvf8w%Q;U0^`2@e8{SwgX5d(cQ|1(!|i$km! zvY03MK}j`sff;*-%mN~ST>xU$6Bu?*Hm%l@0dk;j@%>}jsgDcQ)Hn*UfuThz9(ww_ zasV`rSrp_^bp-0sx>i35FzJwA!d6cZ5#5#nr@GcPEjNnFHIrtUYm1^Z$;{d&{hQV9 z6EfFHaIS}46p^5I-D_EcwwzUUuO}mqRh&T7r9sfw`)G^Q%oHxEs~+XoM?8e*{-&!7 z7$m$lg9t9KP9282eke608^Q2E%H-xm|oJ8=*SyEo} z@&;TQ3K)jgspgKHyGiKVMCz>xmC=H5Fy3!=TP)-R3|&1S-B)!6q50wfLHKM@7Bq6E z44CY%G;GY>tC`~yh!qv~YdXw! zSkquvYNs6k1r7>Eza?Vkkxo6XRS$W7EzL&A`o>=$HXgBp{L(i^$}t`NcnAxzbH8Ht z2!;`bhKIh`f1hIFcI5bHI=ueKdzmB9)!z$s-BT4ItyY|NaA_+o=jO%MU5as9 zc2)aLP>N%u>wlaXTK!p)r?+~)L+0eCGb5{8WIk7K52$nufnQ+m8YF+GQc&{^(zh-$ z#wyWV*Zh@d!b(WwXqvfhQX)^aoHTBkc;4ossV3&Ut*k>AI|m+{#kh4B!`3*<)EJVj zwrxK>99v^k4&Y&`Awm>|exo}NvewV%E+@vOc>5>%H#BK9uaE2$vje zWYM5fKuOTtn96B_2~~!xJPIcXF>E_;yO8AwpJ4)V`Hht#wbO3Ung~@c%%=FX4)q+9 z99#>VC2!4l`~0WHs9FI$Nz+abUq# zz`Of97})Su=^rGp2S$)7N3rQCj#0%2YO<R&p>$<#lgXcUj=4H_{oAYiT3 z44*xDn-$wEzRw7#@6aD)EGO$0{!C5Z^7#yl1o;k0PhN=aVUQu~eTQ^Xy{z8Ow6tk83 z4{5xe%(hx)%nD&|e*6sTWH`4W&U!Jae#U4TnICheJmsw{l|CH?UA{a6?2GNgpZLyzU2UlFu1ZVwlALmh_DOs03J^Cjh1im`E3?9&zvNmg(MuMw&0^Lu$(#CJ*q6DjlKsY-RMJ^8yIY|{SQZ*9~CH|u9L z`R78^r=EbbR*_>5?-)I+$6i}G)%mN(`!X72KaV(MNUP7Nv3MS9S|Pe!%N2AeOt5zG zVJ;jI4HZ$W->Ai_4X+`9c(~m=@ek*m`ZQbv3ryI-AD#AH=`x$~WeW~M{Js57(K7(v ze5`};LG|%C_tmd>bkufMWmAo&B+DT9ZV~h(4jg0>^aeAqL`PEUzJJtI8W1M!bQWpv zvN(d}E1@nlYa!L!!A*RN!(Q3F%J?5PvQ0udu?q-T)j3JKV~NL>KRb~w-lWc685uS6 z=S#aR&B8Sc8>cGJ!!--?kwsJTUUm`Jk?7`H z7PrO~xgBrSW2_tTlCq1LH8*!o?pj?qxy8}(=r_;G18POrFh#;buWR0qU24+XUaVZ0 z?(sXcr@-YqvkCmHr{U2oPogHL{r#3r49TeR<{SJX1pcUqyWPrkYz^X8#QW~?F)R5i z>p^!i<;qM8Nf{-fd6!_&V*e_9qP6q(s<--&1Ttj01j0w>bXY7y1W*%Auu&p|XSOH=)V7Bd4fUKh&T1)@cvqhuD-d=?w}O zjI%i(f|thk0Go*!d7D%0^ztBfE*V=(ZIN84f5HU}T9?ulmEYzT5usi=DeuI*d|;M~ zp_=Cx^!4k#=m_qSPBr5EK~E?3J{dWWPH&oCcNepYVqL?nh4D5ynfWip$m*YlZ8r^Z zuFEUL-nW!3qjRCLIWPT0x)FDL7>Yt7@8dA?R2kF@WE>ysMY+)lTsgNM#3VbXVGL}F z1O(>q>2a+_`6r5Xv$NZAnp=Kgnr3)cL(^=8ypEeOf3q8(HGe@7Tt59;yFl||w|mnO zHDxg2G3z8=(6wjj9kbcEY@Z0iOd7Gq5GiPS5% z*sF1J<#daxDV2Z8H>wxOF<;yKzMeTaSOp_|XkS9Sfn6Mpe9UBi1cSTieGG5$O;ZLIIJ60Y>SN4vC?=yE_CWlo(EEE$e4j?z&^FM%kNmRtlbEL^dPPgvs9sbK5fGw*r@ z+!EU@u$T8!nZh?Fdf_qk$VuHk^yVw`h`_#KoS*N%epIIOfQUy_&V}VWDGp3tplMbf z5Se1sJUC$7N0F1-9jdV2mmGK{-}fu|Nv;12jDy0<-kf^AmkDnu6j~TPWOgy1MT68|D z=4=50jVbUKdKaQgD`eWGr3I&^<6uhkjz$YwItY8%Yp9{z4-{6g{73<_b*@XJ4Nm3-3z z?BW3{aY_ccRjb@W1)i5nLg|7BnWS!B`_Uo9CWaE`Ij327QH?i)9A}4Ug4wmxVVa^b z-4+m%-wwOl7cKH7+=x&nrCrbEC)Q$fpg&V83#uEH;C=GNMz`ps@^RxK%T*8%OPnC` z{WO~J%nxYJ`x|N%?&i7?;{_8t^jM&=50HlaOQj8fS}_`moH$c;vI<|cruPFnpT8yU zS%rPOCUSd5Zdb(zwk`hqwTQn)*&n)uYsP*F_(~xEWq}C= zv30kFmZFwJZ@ELVX3?$dXQh|icO7UrL*_5G=I^xXjImz`ZPp>?g#tf(ej~KaIU0algsG!IS09;>?MvqGg#c{i+}qY|{P8W~O%#>|gFd z<1dr$-oxyRGN17yZo1OwLnzwYs0|;IS_nymNB0IlSzPQ%-r`?T=;_XQ^~&#}b|AB} zkNbN5uB?-sUB-T5QLlg%Uk3)uHB;>VIzGe9_J9 zaeISkQm!v(9d(0ML^b9fR^sfHFlH?7Mvddt37OuR{|O0{uv)(&-6<87W4 zyO>s!=cPgP3O&7xxU5DlIPw_o3O>6o6Qb?JWs3qw#p3sBc3g$?Dx zi(6D+DYgV;GrUis-CL%Qe{nvZnwaVXmbhH(|GFh|Q)k=1uvA$I@1DXI7bKlQ@8D6P zS?(*?><>)G49q0wr;NajpxP4W2G)kHl6^=Z>hrNEI4Mwd_$O6$1dXF;Q#hE(-eeW6 zz03GJF%Wl?HO=_ztv5*zRlcU~{+{k%#N59mgm~eK>P!QZ6E?#Cu^2)+K8m@ySvZ*5 z|HDT}BkF@3!l(0%75G=1u2hETXEj!^1Z$!)!lyGXlWD!_vqGE$Z)#cUVBqlORW>0^ zDjyVTxwKHKG|0}j-`;!R-p>}qQfBl(?($7pP<+Y8QE#M8SCDq~k<+>Q^Zf@cT_WdX3~BSe z+|KK|7OL5Hm5(NFP~j>Ct3*$wi0n0!xl=(C61`q&cec@mFlH(sy%+RH<=s)8aAPN`SfJdkAQjdv82G5iRdv8 zh{9wHUZaniSEpslXl^_ODh}mypC?b*9FzLjb~H@3DFSe;D(A-K3t3eOTB(m~I6C;(-lKAvit(70k`%@+O*Ztdz;}|_TS~B?Tpmi=QKC^m_ z2YpEaT3iiz*;T~ap1yiA)a`dKMwu`^UhIUeltNQ1Yjo=q@bI@&3zH?rVUg=IxLy-ni zyxDu%-Fr{H6owTjZU2O5>nDb=q&Jz_TjeSq%!2m40x&U6w~GQ({quPL73IsJS;f`$ zsuhioqCBj(gJ>2hoo)Gou7(WP*pX)f=Y=!=k!&1K?EYY%jJ~X&DnK{^saPQK<1BJ z_A`_{%ZozcB(3w$z^To^6d|XuT@=X~wtW!+{4ID@N{AB~J6AL5vuY>JwvWCNFKsKh zd}@>q@_WV#QZ&UJ0#?X(pXR!oyXOEG3rqzHbCzGLONDb042i$})fM@XF)uSP(DHUc z^&{|$*xe{cs?Gp8=B%RY3L7#$ve$?TWh>MZdxF1zH1v}1z+$Ov#G7?%D)bBCyDe*% zSeKSpETC2V1){II>@UwJi>4uBN+iAx+82E~gb|Cr&8E^i&)A!uv-g?jzH99wU}8+# z$nh>yvb;TwZmS@7LrvuCu_d0-WxFNI&C7%sWuTL%YU!l|I1{|->=dlOeHOCtUO#zkS3ESO8LHV4hTdQL5EdV zuWD33fFPH}HPrW^s$Qn1Xgp&AT6<-He{{4%eIu3rN=iK|9mURdKXfB&Q?qGok%!cs ze53UP{Z!TO-Y@q2;;k2avA3`lm4OoN4@S*k=UA)7H;qZ`d8`XaYFCv?Ba+uGW@r5v z&&{nf(24WSBOhc7!qF^@0cz;XcUynNaj6w2349;s!K{KVqs5yS{ z7VubS`2OzT^5#1~6Tt^RTvt9-J|D2F>y~>2;jeF>g`hx5l%B3H=aLExQihuYngzlnBTYOTHJQMzl>kwqN5JYs)Ej zblA@ntkUS~xi+}y6|(81helS}Q~&VB37qyV|S3Y=><^1wh%msQM?fz z<58MX(=|PSUKCF#)dbhR%D&xgCD?$aR0qen+wpp6 zst}vX18!Be96TD??j1HsHTUx(a&@F?=gT`Q$oJFFyrh^;zgz!(NlAHGn0cJy@us=w zNhC#l5G;H}+>49Nsh12=ZPO2r*2OBQe5kpb&1?*PIBFitK8}FUfb~S-#hKfF0o#&d z#3aPkB$9scYku&kA6{0xHnBV#&Wei5J>5T-XX-gUXEPo+9b7WL=*XESc(3BshL`aj zXp}QIp*40}oWJt*l043e8_5;H5PI5c)U&IEw5dF(4zjX0y_lk9 zAp@!mK>WUqHo)-jop=DoK>&no>kAD=^qIE7qis&_*4~ z6q^EF$D@R~3_xseCG>Ikb6Gfofb$g|75PPyyZN&tiRxqovo_k zO|HA|sgy#B<32gyU9x^&)H$1jvw@qp+1b(eGAb)O%O!&pyX@^nQd^9BQ4{(F8<}|A zhF&)xusQhtoXOOhic=8#Xtt5&slLia3c*a?dIeczyTbC#>FTfiLST57nc3@Y#v_Eg#VUv zT8cKH#f3=1PNj!Oroz_MAR*pow%Y0*6YCYmUy^7`^r|j23Q~^*TW#cU7CHf0eAD_0 zEWEVddxFgQ7=!nEBQ|ibaScslvhuUk^*%b#QUNrEB{3PG@uTxNwW}Bs4$nS9wc(~O zG7Iq>aMsYkcr!9#A;HNsJrwTDYkK8ikdj{M;N$sN6BqJ<8~z>T20{J8Z2rRUuH7~3 z=tgS`AgxbBOMg87UT4Lwge`*Y=01Dvk>)^{Iu+n6fuVX4%}>?3czOGR$0 zpp*wp>bsFFSV`V;r_m+TZns$ZprIi`OUMhe^cLE$2O+pP3nP!YB$ry}2THx2QJs3< za1;>d-AggCarrQ>&Z!d@;mW+!q6eXhb&`GbzUDSxpl8AJ#Cm#tuc)_xh(2NV=5XMs zrf_ozRYO$NkC=pKFX5OH8v1>0i9Z$ec`~Mf+_jQ68spn(CJwclDhEEkH2Qw;${J$clv__nUjn5jA0wCLEnu1j;v!0vB>Ri6m9`;R{JMS%^)4FC zU0Z44+u$I$w=Bj|iu4DT5h~sS`C*zbmX?@-crY}E+hy>}2~C0Nn(EKk@5^qO4@l@! z6O0lr%tzGC`D^)8xU3FnMZVm0kX1sBWhaQyzVoXFWwr%Ny?=2M{5s#5i7fTu3gEkG zc{(Pr$v=;`Y#&`y*J}#M9ux>0?xu!`$9cUKm#Bdd_&S#LPTS?ZPV6zN6>W6JTS~-LfjL{mB=b(KMk3 z2HjBSlJeyUVqDd=Mt!=hpYsvby2GL&3~zm;0{^nZJq+4vb?5HH4wufvr}IX42sHeK zm@x?HN$8TsTavXs)tLDFJtY9b)y~Tl@7z4^I8oUQq4JckH@~CVQ;FoK(+e0XAM>1O z(ei}h?)JQp>)d=6ng-BZF1Z5hsAKW@mXq+hU?r8I(*%`tnIIOXw7V6ZK(T9RFJJe@ zZS!aC+p)Gf2Ujc=a6hx4!A1Th%YH!Lb^xpI!Eu` zmJO{9rw){B1Ql18d%F%da+Tbu1()?o(zT7StYqK6_w`e+fjXq5L^y(0 z09QA6H4oFj59c2wR~{~>jUoDzDdKz}5#onYPJRwa`SUO)Pd4)?(ENBaFVLJr6Kvz= zhTtXqbx09C1z~~iZt;g^9_2nCZ{};-b4dQJbv8HsWHXPVg^@(*!@xycp#R?a|L!+` zY5w))JWV`Gls(=}shH0#r*;~>_+-P5Qc978+QUd>J%`fyn{*TsiG-dWMiJXNgwBaT zJ=wgYFt+1ACW)XwtNx)Q9tA2LPoB&DkL16P)ERWQlY4%Y`-5aM9mZ{eKPUgI!~J3Z zkMd5A_p&v?V-o-6TUa8BndiX?ooviev(DKw=*bBVOW|=zps9=Yl|-R5@yJe*BPzN}a0mUsLn{4LfjB_oxpv(mwq# zSY*%E{iB)sNvWfzg-B!R!|+x(Q|b@>{-~cFvdDHA{F2sFGA5QGiIWy#3?P2JIpPKg6ncI^)dvqe`_|N=8 '} + case $link in #( + /*) app_path=$link ;; #( + *) app_path=$APP_HOME$link ;; + esac +done + +# This is normally unused +# shellcheck disable=SC2034 +APP_BASE_NAME=${0##*/} +# Discard cd standard output in case $CDPATH is set (https://github.com/gradle/gradle/issues/25036) +APP_HOME=$( cd -P "${APP_HOME:-./}" > /dev/null && printf '%s\n' "$PWD" ) || exit + +# Use the maximum available, or set MAX_FD != -1 to use that value. +MAX_FD=maximum + +warn () { + echo "$*" +} >&2 + +die () { + echo + echo "$*" + echo + exit 1 +} >&2 + +# OS specific support (must be 'true' or 'false'). +cygwin=false +msys=false +darwin=false +nonstop=false +case "$( uname )" in #( + CYGWIN* ) cygwin=true ;; #( + Darwin* ) darwin=true ;; #( + MSYS* | MINGW* ) msys=true ;; #( + NONSTOP* ) nonstop=true ;; +esac + +CLASSPATH="\\\"\\\"" + + +# Determine the Java command to use to start the JVM. +if [ -n "$JAVA_HOME" ] ; then + if [ -x "$JAVA_HOME/jre/sh/java" ] ; then + # IBM's JDK on AIX uses strange locations for the executables + JAVACMD=$JAVA_HOME/jre/sh/java + else + JAVACMD=$JAVA_HOME/bin/java + fi + if [ ! -x "$JAVACMD" ] ; then + die "ERROR: JAVA_HOME is set to an invalid directory: $JAVA_HOME + +Please set the JAVA_HOME variable in your environment to match the +location of your Java installation." + fi +else + JAVACMD=java + if ! command -v java >/dev/null 2>&1 + then + die "ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH. + +Please set the JAVA_HOME variable in your environment to match the +location of your Java installation." + fi +fi + +# Increase the maximum file descriptors if we can. +if ! "$cygwin" && ! "$darwin" && ! "$nonstop" ; then + case $MAX_FD in #( + max*) + # In POSIX sh, ulimit -H is undefined. That's why the result is checked to see if it worked. + # shellcheck disable=SC2039,SC3045 + MAX_FD=$( ulimit -H -n ) || + warn "Could not query maximum file descriptor limit" + esac + case $MAX_FD in #( + '' | soft) :;; #( + *) + # In POSIX sh, ulimit -n is undefined. That's why the result is checked to see if it worked. + # shellcheck disable=SC2039,SC3045 + ulimit -n "$MAX_FD" || + warn "Could not set maximum file descriptor limit to $MAX_FD" + esac +fi + +# Collect all arguments for the java command, stacking in reverse order: +# * args from the command line +# * the main class name +# * -classpath +# * -D...appname settings +# * --module-path (only if needed) +# * DEFAULT_JVM_OPTS, JAVA_OPTS, and GRADLE_OPTS environment variables. + +# For Cygwin or MSYS, switch paths to Windows format before running java +if "$cygwin" || "$msys" ; then + APP_HOME=$( cygpath --path --mixed "$APP_HOME" ) + CLASSPATH=$( cygpath --path --mixed "$CLASSPATH" ) + + JAVACMD=$( cygpath --unix "$JAVACMD" ) + + # Now convert the arguments - kludge to limit ourselves to /bin/sh + for arg do + if + case $arg in #( + -*) false ;; # don't mess with options #( + /?*) t=${arg#/} t=/${t%%/*} # looks like a POSIX filepath + [ -e "$t" ] ;; #( + *) false ;; + esac + then + arg=$( cygpath --path --ignore --mixed "$arg" ) + fi + # Roll the args list around exactly as many times as the number of + # args, so each arg winds up back in the position where it started, but + # possibly modified. + # + # NB: a `for` loop captures its iteration list before it begins, so + # changing the positional parameters here affects neither the number of + # iterations, nor the values presented in `arg`. + shift # remove old arg + set -- "$@" "$arg" # push replacement arg + done +fi + + +# Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script. +DEFAULT_JVM_OPTS='"-Xmx64m" "-Xms64m"' + +# Collect all arguments for the java command: +# * DEFAULT_JVM_OPTS, JAVA_OPTS, and optsEnvironmentVar are not allowed to contain shell fragments, +# and any embedded shellness will be escaped. +# * For example: A user cannot expect ${Hostname} to be expanded, as it is an environment variable and will be +# treated as '${Hostname}' itself on the command line. + +set -- \ + "-Dorg.gradle.appname=$APP_BASE_NAME" \ + -classpath "$CLASSPATH" \ + -jar "$APP_HOME/gradle/wrapper/gradle-wrapper.jar" \ + "$@" + +# Stop when "xargs" is not available. +if ! command -v xargs >/dev/null 2>&1 +then + die "xargs is not available" +fi + +# Use "xargs" to parse quoted args. +# +# With -n1 it outputs one arg per line, with the quotes and backslashes removed. +# +# In Bash we could simply go: +# +# readarray ARGS < <( xargs -n1 <<<"$var" ) && +# set -- "${ARGS[@]}" "$@" +# +# but POSIX shell has neither arrays nor command substitution, so instead we +# post-process each arg (as a line of input to sed) to backslash-escape any +# character that might be a shell metacharacter, then use eval to reverse +# that process (while maintaining the separation between arguments), and wrap +# the whole thing up as a single "set" statement. +# +# This will of course break if any of these variables contains a newline or +# an unmatched quote. +# + +eval "set -- $( + printf '%s\n' "$DEFAULT_JVM_OPTS $JAVA_OPTS $GRADLE_OPTS" | + xargs -n1 | + sed ' s~[^-[:alnum:]+,./:=@_]~\\&~g; ' | + tr '\n' ' ' + )" '"$@"' + +exec "$JAVACMD" "$@" diff --git a/examples/code-agent/step-01-basic-agent/gradlew.bat b/examples/code-agent/step-01-basic-agent/gradlew.bat new file mode 100644 index 0000000000..db3a6ac207 --- /dev/null +++ b/examples/code-agent/step-01-basic-agent/gradlew.bat @@ -0,0 +1,94 @@ +@rem +@rem Copyright 2015 the original author or authors. +@rem +@rem Licensed under the Apache License, Version 2.0 (the "License"); +@rem you may not use this file except in compliance with the License. +@rem You may obtain a copy of the License at +@rem +@rem https://www.apache.org/licenses/LICENSE-2.0 +@rem +@rem Unless required by applicable law or agreed to in writing, software +@rem distributed under the License is distributed on an "AS IS" BASIS, +@rem WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +@rem See the License for the specific language governing permissions and +@rem limitations under the License. +@rem +@rem SPDX-License-Identifier: Apache-2.0 +@rem + +@if "%DEBUG%"=="" @echo off +@rem ########################################################################## +@rem +@rem Gradle startup script for Windows +@rem +@rem ########################################################################## + +@rem Set local scope for the variables with windows NT shell +if "%OS%"=="Windows_NT" setlocal + +set DIRNAME=%~dp0 +if "%DIRNAME%"=="" set DIRNAME=. +@rem This is normally unused +set APP_BASE_NAME=%~n0 +set APP_HOME=%DIRNAME% + +@rem Resolve any "." and ".." in APP_HOME to make it shorter. +for %%i in ("%APP_HOME%") do set APP_HOME=%%~fi + +@rem Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script. +set DEFAULT_JVM_OPTS="-Xmx64m" "-Xms64m" + +@rem Find java.exe +if defined JAVA_HOME goto findJavaFromJavaHome + +set JAVA_EXE=java.exe +%JAVA_EXE% -version >NUL 2>&1 +if %ERRORLEVEL% equ 0 goto execute + +echo. 1>&2 +echo ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH. 1>&2 +echo. 1>&2 +echo Please set the JAVA_HOME variable in your environment to match the 1>&2 +echo location of your Java installation. 1>&2 + +goto fail + +:findJavaFromJavaHome +set JAVA_HOME=%JAVA_HOME:"=% +set JAVA_EXE=%JAVA_HOME%/bin/java.exe + +if exist "%JAVA_EXE%" goto execute + +echo. 1>&2 +echo ERROR: JAVA_HOME is set to an invalid directory: %JAVA_HOME% 1>&2 +echo. 1>&2 +echo Please set the JAVA_HOME variable in your environment to match the 1>&2 +echo location of your Java installation. 1>&2 + +goto fail + +:execute +@rem Setup the command line + +set CLASSPATH= + + +@rem Execute Gradle +"%JAVA_EXE%" %DEFAULT_JVM_OPTS% %JAVA_OPTS% %GRADLE_OPTS% "-Dorg.gradle.appname=%APP_BASE_NAME%" -classpath "%CLASSPATH%" -jar "%APP_HOME%\gradle\wrapper\gradle-wrapper.jar" %* + +:end +@rem End local scope for the variables with windows NT shell +if %ERRORLEVEL% equ 0 goto mainEnd + +:fail +rem Set variable GRADLE_EXIT_CONSOLE if you need the _script_ return code instead of +rem the _cmd.exe /c_ return code! +set EXIT_CODE=%ERRORLEVEL% +if %EXIT_CODE% equ 0 set EXIT_CODE=1 +if not ""=="%GRADLE_EXIT_CONSOLE%" exit %EXIT_CODE% +exit /b %EXIT_CODE% + +:mainEnd +if "%OS%"=="Windows_NT" endlocal + +:omega diff --git a/examples/code-agent/step-01-basic-agent/settings.gradle.kts b/examples/code-agent/step-01-basic-agent/settings.gradle.kts new file mode 100644 index 0000000000..7174be00b7 --- /dev/null +++ b/examples/code-agent/step-01-basic-agent/settings.gradle.kts @@ -0,0 +1,19 @@ +rootProject.name = "step-01-basic-agent" + +pluginManagement { + repositories { + gradlePluginPortal() + } +} + +dependencyResolutionManagement { + @Suppress("UnstableApiUsage") + repositories { + mavenCentral() + google() + } +} + +includeBuild("../../../.") { + name = "koog" +} diff --git a/examples/code-agent/step-01-basic-agent/src/main/kotlin/Main.kt b/examples/code-agent/step-01-basic-agent/src/main/kotlin/Main.kt new file mode 100644 index 0000000000..3467df8eba --- /dev/null +++ b/examples/code-agent/step-01-basic-agent/src/main/kotlin/Main.kt @@ -0,0 +1,50 @@ +package ai.koog.agents.examples.codeagent.step01 + +import ai.koog.agents.core.agent.AIAgent +import ai.koog.agents.core.agent.singleRunStrategy +import ai.koog.agents.core.tools.ToolRegistry +import ai.koog.agents.ext.tool.file.EditFileTool +import ai.koog.agents.ext.tool.file.ListDirectoryTool +import ai.koog.agents.ext.tool.file.ReadFileTool +import ai.koog.agents.ext.tool.file.WriteFileTool +import ai.koog.agents.features.eventHandler.feature.handleEvents +import ai.koog.prompt.executor.clients.openai.OpenAIModels +import ai.koog.prompt.executor.llms.all.simpleOpenAIExecutor +import ai.koog.rag.base.files.JVMFileSystemProvider +import kotlinx.coroutines.runBlocking + +val agent = AIAgent( + promptExecutor = simpleOpenAIExecutor(System.getenv("OPENAI_API_KEY")), + strategy = singleRunStrategy(), + systemPrompt = """ + You are a highly skilled programmer tasked with updating the provided codebase according to the given task. + Your goal is to deliver production-ready code changes that integrate seamlessly with the existing codebase and solve given task. + """.trimIndent(), + llmModel = OpenAIModels.Chat.GPT5, + toolRegistry = ToolRegistry { + tool(ListDirectoryTool(JVMFileSystemProvider.ReadOnly)) + tool(ReadFileTool(JVMFileSystemProvider.ReadOnly)) + tool(WriteFileTool(JVMFileSystemProvider.ReadWrite)) + tool(EditFileTool(JVMFileSystemProvider.ReadWrite)) + }, + maxIterations = 100 +) { + handleEvents { + onToolCallStarting { ctx -> + println("Tool called: ${ctx.tool.name}") + } + } +} + +fun main(args: Array) = runBlocking { + if (args.size < 2) { + println("Error: Please provide the project absolute path and a task as arguments") + println("Usage: ") + return@runBlocking + } + + val (path, task) = args + val input = "Project path: $path\n\n$task" + val result = agent.run(input) + println(result) +} diff --git a/examples/code-agent/step-01-basic-agent/src/main/resources/logback.xml b/examples/code-agent/step-01-basic-agent/src/main/resources/logback.xml new file mode 100644 index 0000000000..d624456018 --- /dev/null +++ b/examples/code-agent/step-01-basic-agent/src/main/resources/logback.xml @@ -0,0 +1,11 @@ + + + + %d{HH:mm:ss.SSS} [%thread] %-5level %logger{36} - %msg%n + + + + + + + From f4230cedd8c0ebc424361ebbd97b584d34b1241d Mon Sep 17 00:00:00 2001 From: Inna Teteniuk Date: Thu, 2 Oct 2025 12:33:02 +0200 Subject: [PATCH 48/52] Update Overview documentation (#912) - Added information about functional agents. - JetBrains is now explicitly mentioned in the framework description. --- ## Motivation and Context ## Breaking Changes --- #### Type of the changes - [ ] New feature (non-breaking change which adds functionality) - [ ] Bug fix (non-breaking change which fixes an issue) - [ ] Breaking change (fix or feature that would cause existing functionality to change) - [x] Documentation update - [ ] Tests improvement - [ ] Refactoring #### Checklist - [ ] The pull request has a description of the proposed change - [ ] I read the [Contributing Guidelines](https://github.com/JetBrains/koog/blob/main/CONTRIBUTING.md) before opening the pull request - [ ] The pull request uses **`develop`** as the base branch - [ ] Tests for the changes have been added - [ ] All new and existing tests passed ##### Additional steps for pull requests adding a new feature - [ ] An issue describing the proposed change exists - [ ] The pull request includes a link to the issue - [ ] The change was discussed and approved in the issue - [ ] Docs have been added / updated --- docs/docs/index.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/docs/docs/index.md b/docs/docs/index.md index 3762feb4c5..5b0fb7a647 100644 --- a/docs/docs/index.md +++ b/docs/docs/index.md @@ -1,12 +1,13 @@ # Overview -Koog is a Kotlin-based framework designed to build and run AI agents entirely in idiomatic Kotlin. +Koog is an open-source JetBrains framework designed to build and run AI agents entirely in idiomatic Kotlin. It lets you create agents that can interact with tools, handle complex workflows, and communicate with users. The framework supports the following types of agents: * Single-run agents with minimal configuration that process a single input and provide a response. An agent of this type operates within a single cycle of tool-calling to complete its task and provide a response. +* Functional agents with lightweight, customizable logic defined by a lambda function to handle user input, interact with an LLM, call tools, and produce a final output. * Complex workflow agents with advanced capabilities that support custom strategies and configurations. ## Key features From 54e3a8069eca1bb6277852bb6ecf95dce79468e0 Mon Sep 17 00:00:00 2001 From: Briliantov Vadim Date: Thu, 2 Oct 2025 14:26:25 +0200 Subject: [PATCH 49/52] Rename Persistency to Persistence everywhere, provide replacements and deprecations (#910) Rename Persistency to Persistence everywhere, provide replacements and deprecations #### Type of the changes - [ ] New feature (non-breaking change which adds functionality) - [ ] Bug fix (non-breaking change which fixes an issue) - [ ] Breaking change (fix or feature that would cause existing functionality to change) - [ ] Documentation update - [ ] Tests improvement - [x] Refactoring #### Checklist - [ ] The pull request has a description of the proposed change - [ ] I read the [Contributing Guidelines](https://github.com/JetBrains/koog/blob/main/CONTRIBUTING.md) before opening the pull request - [ ] The pull request uses **`develop`** as the base branch - [ ] Tests for the changes have been added - [ ] All new and existing tests passed ##### Additional steps for pull requests adding a new feature - [ ] An issue describing the proposed change exists - [ ] The pull request includes a link to the issue - [ ] The change was discussed and approved in the issue - [ ] Docs have been added / updated --- CHANGELOG.md | 4 +- .../agents-features-snapshot/Module.md | 12 +-- .../snapshot/feature/AgentCheckpointData.kt | 8 +- .../agents/snapshot/feature/Persistency.kt | 83 +++++++++-------- .../feature/PersistencyFeatureConfig.kt | 31 ++++--- .../InMemoryPersistencyStorageProvider.kt | 4 +- .../providers/NoPersistencyStorageProvider.kt | 4 +- ...ersistencyUtils.kt => PersistenceUtils.kt} | 13 ++- .../providers/PersistencyStorageProvider.kt | 11 ++- .../file/FilePersistencyStorageProvider.kt | 21 +++-- .../file/JVMFilePersistencyStorageProvider.kt | 21 +++-- .../kotlin/CheckpointSerializationTest.kt | 8 +- .../src/jvmTest/kotlin/CheckpointsTests.kt | 38 ++++---- .../kotlin/NodeUniquenessCheckpointTest.kt | 8 +- .../kotlin/PersistencyRestoreStrategyTests.kt | 20 ++--- .../kotlin/PersistencyRunsTwiceTest.kt | 18 ++-- .../kotlin/SimpleGraphCheckpointTest.kt | 16 ++-- .../jvmTest/kotlin/SubgraphCheckpointsTest.kt | 20 ++--- .../kotlin/SubgraphSetExecutionPointTest.kt | 28 +++--- .../src/jvmTest/kotlin/TestStrategies.kt | 14 +-- .../FileAgentCheckpointStorageProviderTest.kt | 4 +- .../providers/file/FileCheckpointsTests.kt | 20 ++--- .../agents-features-sql/Module.md | 12 +-- .../SQLPersistencyStorageProvider.kt | 8 +- .../ExposedPersistencyStorageProvider.kt | 22 ++--- .../providers/H2PersistencyStorageProvider.kt | 18 ++-- .../MySQLPersistencyStorageProvider.kt | 12 +-- .../PostgresPersistencyStorageProvider.kt | 10 +-- .../H2PersistencyStorageProviderTest.kt | 6 +- .../MySQLPersistencyStorageProviderTest.kt | 6 +- .../PostgresPersistencyStorageProviderTest.kt | 6 +- .../providers/SQLPersistenceProvidersTest.kt | 16 ++-- docs/docs/agent-persistence.md | 90 +++++++++---------- docs/docs/features-overview.md | 4 +- .../example/snapshot/CheckpointExample.kt | 14 +-- .../snapshot/FilePersistentAgentExample.kt | 14 +-- .../example/snapshot/SnapshotExample.kt | 8 +- .../example/snapshot/SnapshotStrategy.kt | 6 +- .../agents/example/snapshot/sql/README.md | 10 +-- .../snapshot/sql/SQLPersistentAgentExample.kt | 16 ++-- .../tests/agent/AIAgentIntegrationTest.kt | 36 ++++---- 41 files changed, 387 insertions(+), 333 deletions(-) rename agents/agents-features/agents-features-snapshot/src/commonMain/kotlin/ai/koog/agents/snapshot/providers/{PersistencyUtils.kt => PersistenceUtils.kt} (77%) diff --git a/CHANGELOG.md b/CHANGELOG.md index 3c36a97231..e2bc628ebe 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -81,7 +81,7 @@ Fixed iOS target publication ## Major Features -- **Agent Persistency and Checkpoints**: Save and restore agent state to local disk, memory, or easily integrate with +- **Agent Persistence and Checkpoints**: Save and restore agent state to local disk, memory, or easily integrate with any cloud storages or databases. Agents can now roll back to any prior state on demand or automatically restore from the latest checkpoint (#305) - **Vector Document Storage**: Store embeddings and documents in persistent storage for retrieval-augmented generation ( @@ -147,7 +147,7 @@ Fixed iOS target publication - Langfuse Tracing example - Moderation example: Moderating iterative joke-generation conversation - Parallel Nodes Execution example: Generating jokes using 3 different LLMs in parallel, and choosing the funniest one -- Snapshot and Persistency example: Taking agent snapshots and restoring its state example +- Snapshot and Persistence example: Taking agent snapshots and restoring its state example # 0.2.1 diff --git a/agents/agents-features/agents-features-snapshot/Module.md b/agents/agents-features/agents-features-snapshot/Module.md index 296feabfe2..3aa13ecaba 100644 --- a/agents/agents-features/agents-features-snapshot/Module.md +++ b/agents/agents-features/agents-features-snapshot/Module.md @@ -27,18 +27,18 @@ dependencies { } ``` -Then, install the Persistency feature when creating your agent: +Then, install the Persistence feature when creating your agent: ```kotlin val agent = AIAgent( // other configuration parameters ) { - install(Persistency) { + install(Persistence) { // Configure the storage provider - storage = InMemoryPersistencyStorageProvider("agent-persistence-id") + storage = InMemoryPersistenceStorageProvider("agent-persistence-id") // Optional: enable automatic checkpoint creation after each node - enableAutomaticPersistency = true + enableAutomaticPersistence = true // Use `RollbackStrategy.Default` if you want to checkpoint the whole state machine and continue from the same node in your strategy graph, // or `RollbackStrategy.MessageHistoryOnly` if you only want to checkpoint messages @@ -64,9 +64,9 @@ val agent = AIAgent( llmModel = OllamaModels.Meta.LLAMA_3_2, strategy = singleRunStrategy(ToolCalls.SEQUENTIAL), ) { - install(Persistency) { + install(Persistence) { storage = snapshotProvider - enableAutomaticPersistency = true + enableAutomaticPersistence = true } } diff --git a/agents/agents-features/agents-features-snapshot/src/commonMain/kotlin/ai/koog/agents/snapshot/feature/AgentCheckpointData.kt b/agents/agents-features/agents-features-snapshot/src/commonMain/kotlin/ai/koog/agents/snapshot/feature/AgentCheckpointData.kt index cadfb1ee78..95785109ab 100644 --- a/agents/agents-features/agents-features-snapshot/src/commonMain/kotlin/ai/koog/agents/snapshot/feature/AgentCheckpointData.kt +++ b/agents/agents-features/agents-features-snapshot/src/commonMain/kotlin/ai/koog/agents/snapshot/feature/AgentCheckpointData.kt @@ -6,7 +6,7 @@ import ai.koog.agents.core.agent.context.AIAgentContext import ai.koog.agents.core.agent.context.AgentContextData import ai.koog.agents.core.agent.context.RollbackStrategy import ai.koog.agents.core.annotation.InternalAgentsApi -import ai.koog.agents.snapshot.providers.PersistencyUtils +import ai.koog.agents.snapshot.providers.PersistenceUtils import ai.koog.prompt.message.Message import kotlinx.datetime.Instant import kotlinx.serialization.Serializable @@ -47,10 +47,10 @@ public fun tombstoneCheckpoint(time: Instant): AgentCheckpointData { return AgentCheckpointData( checkpointId = Uuid.random().toString(), createdAt = time, - nodeId = PersistencyUtils.TOMBSTONE_CHECKPOINT_NAME, + nodeId = PersistenceUtils.TOMBSTONE_CHECKPOINT_NAME, lastInput = JsonNull, messageHistory = emptyList(), - properties = mapOf(PersistencyUtils.TOMBSTONE_CHECKPOINT_NAME to JsonPrimitive(true)) + properties = mapOf(PersistenceUtils.TOMBSTONE_CHECKPOINT_NAME to JsonPrimitive(true)) ) } @@ -85,4 +85,4 @@ public fun AgentCheckpointData.toAgentContextData( * and the value is a JSON primitive set to `true`, otherwise `false`. */ public fun AgentCheckpointData.isTombstone(): Boolean = - properties?.get(PersistencyUtils.TOMBSTONE_CHECKPOINT_NAME) == JsonPrimitive(true) + properties?.get(PersistenceUtils.TOMBSTONE_CHECKPOINT_NAME) == JsonPrimitive(true) diff --git a/agents/agents-features/agents-features-snapshot/src/commonMain/kotlin/ai/koog/agents/snapshot/feature/Persistency.kt b/agents/agents-features/agents-features-snapshot/src/commonMain/kotlin/ai/koog/agents/snapshot/feature/Persistency.kt index 062f329ac5..57b9512013 100644 --- a/agents/agents-features/agents-features-snapshot/src/commonMain/kotlin/ai/koog/agents/snapshot/feature/Persistency.kt +++ b/agents/agents-features/agents-features-snapshot/src/commonMain/kotlin/ai/koog/agents/snapshot/feature/Persistency.kt @@ -16,7 +16,7 @@ import ai.koog.agents.core.feature.AIAgentGraphPipeline import ai.koog.agents.core.feature.InterceptContext import ai.koog.agents.core.tools.DirectToolCallsEnabler import ai.koog.agents.core.tools.annotations.InternalAgentToolsApi -import ai.koog.agents.snapshot.providers.PersistencyStorageProvider +import ai.koog.agents.snapshot.providers.PersistenceStorageProvider import ai.koog.prompt.message.Message import io.github.oshai.kotlinlogging.KotlinLogging import kotlinx.datetime.Clock @@ -30,6 +30,15 @@ import kotlin.time.ExperimentalTime import kotlin.uuid.ExperimentalUuidApi import kotlin.uuid.Uuid +@Deprecated( + "`Persistency` has been renamed to `Persistence`", + replaceWith = ReplaceWith( + expression = "Persistence", + "ai.koog.agents.snapshot.feature.Persistence" + ) +) +public typealias Persistency = Persistence + /** * A feature that provides checkpoint functionality for AI agents. * @@ -40,14 +49,14 @@ import kotlin.uuid.Uuid * - Persisting agent state across sessions * * The feature can be configured to automatically create checkpoints after each node execution - * using the [PersistencyFeatureConfig.enableAutomaticPersistency] option. + * using the [PersistenceFeatureConfig.enableAutomaticPersistence] option. * - * @property persistencyStorageProvider The provider responsible for storing and retrieving checkpoints + * @property persistenceStorageProvider The provider responsible for storing and retrieving checkpoints * @property currentNodeId The ID of the node currently being executed */ @OptIn(ExperimentalUuidApi::class, ExperimentalTime::class, InternalAgentsApi::class) -public class Persistency( - private val persistencyStorageProvider: PersistencyStorageProvider, +public class Persistence( + private val persistenceStorageProvider: PersistenceStorageProvider, internal val clock: Clock = Clock.System, ) { /** @@ -68,7 +77,7 @@ public class Persistency( * A registry for managing rollback tools within the persistence system. * * The `rollbackToolRegistry` plays a key role in supporting the rollback mechanism in the - * persistency operations, allowing seamless state restoration for tools **with side-effects** to specified or latest + * persistence operations, allowing seamless state restoration for tools **with side-effects** to specified or latest * checkpoints as needed. * */ @@ -91,7 +100,7 @@ public class Persistency( /** * Feature companion object that implements [AIAgentFeature] for the checkpoint functionality. */ - public companion object Feature : AIAgentGraphFeature { + public companion object Feature : AIAgentGraphFeature { private val logger = KotlinLogging.logger { } private val json = Json { @@ -101,14 +110,14 @@ public class Persistency( /** * The storage key used to identify this feature in the agent's feature registry. */ - override val key: AIAgentStorageKey = AIAgentStorageKey("agents-features-snapshot") + override val key: AIAgentStorageKey = AIAgentStorageKey("agents-features-snapshot") /** * Creates the default configuration for this feature. * - * @return A new instance of [PersistencyFeatureConfig] with default settings + * @return A new instance of [PersistenceFeatureConfig] with default settings */ - override fun createInitialConfig(): PersistencyFeatureConfig = PersistencyFeatureConfig() + override fun createInitialConfig(): PersistenceFeatureConfig = PersistenceFeatureConfig() /** * Installs the checkpoint feature into the agent pipeline. @@ -122,10 +131,10 @@ public class Persistency( * @param pipeline The agent pipeline to install the feature into */ override fun install( - config: PersistencyFeatureConfig, + config: PersistenceFeatureConfig, pipeline: AIAgentGraphPipeline ) { - val featureImpl = Persistency(config.storage) + val featureImpl = Persistence(config.storage) featureImpl.rollbackStrategy = config.rollbackStrategy featureImpl.rollbackToolRegistry = config.rollbackToolRegistry val interceptContext = InterceptContext(this, featureImpl) @@ -155,7 +164,7 @@ public class Persistency( return@interceptNodeExecutionCompleted } - if (config.enableAutomaticPersistency) { + if (config.enableAutomaticPersistence) { createCheckpoint( agentContext = eventCtx.context, nodeId = eventCtx.node.id, @@ -170,7 +179,7 @@ public class Persistency( } pipeline.interceptStrategyCompleted(interceptContext) { ctx -> - if (config.enableAutomaticPersistency && config.rollbackStrategy == RollbackStrategy.Default) { + if (config.enableAutomaticPersistence && config.rollbackStrategy == RollbackStrategy.Default) { ctx.feature.createTombstoneCheckpoint(ctx.feature.clock.now()) } } @@ -253,7 +262,7 @@ public class Persistency( * @param checkpointData The checkpoint data to save */ public suspend fun saveCheckpoint(checkpointData: AgentCheckpointData) { - persistencyStorageProvider.saveCheckpoint(checkpointData) + persistenceStorageProvider.saveCheckpoint(checkpointData) } /** @@ -262,7 +271,7 @@ public class Persistency( * @return The latest checkpoint data, or null if no checkpoint exists */ public suspend fun getLatestCheckpoint(): AgentCheckpointData? = - persistencyStorageProvider.getLatestCheckpoint() + persistenceStorageProvider.getLatestCheckpoint() /** * Retrieves a specific checkpoint by ID for the specified agent. @@ -271,7 +280,7 @@ public class Persistency( * @return The checkpoint data with the specified ID, or null if not found */ public suspend fun getCheckpointById(checkpointId: String): AgentCheckpointData? = - persistencyStorageProvider.getCheckpoints().firstOrNull { it.checkpointId == checkpointId } + persistenceStorageProvider.getCheckpoints().firstOrNull { it.checkpointId == checkpointId } /** * Sets the execution point of an agent to a specific state. @@ -381,50 +390,50 @@ public class Persistency( /** * Extension function to access the checkpoint feature from an agent context. * - * @return The [Persistency] feature instance for this agent + * @return The [Persistence] feature instance for this agent * @throws IllegalStateException if the checkpoint feature is not installed */ -public fun AIAgentContext.persistency(): Persistency = agent.persistency() +public fun AIAgentContext.persistence(): Persistence = agent.persistence() /** - * Retrieves the persistency feature for the AI agent. + * Retrieves the persistence feature for the AI agent. * - * @return The persistency feature associated with the AI agent. - * @throws IllegalStateException if the persistency feature is not available. + * @return The persistence feature associated with the AI agent. + * @throws IllegalStateException if the persistence feature is not available. */ -public fun AIAgent<*, *>.persistency(): Persistency = featureOrThrow(Persistency.Feature) +public fun AIAgent<*, *>.persistence(): Persistence = featureOrThrow(Persistence.Feature) /** - * Executes the provided action within the context of the AI agent's persistency layer. + * Executes the provided action within the context of the AI agent's persistence layer. * - * This function enhances agents with persistent state management capabilities by leveraging the `Persistency` component - * within the current `AIAgentContext`. The supplied action is executed with the persistency layer, enabling operations + * This function enhances agents with persistent state management capabilities by leveraging the `Persistence` component + * within the current `AIAgentContext`. The supplied action is executed with the persistence layer, enabling operations * that require consistent and reliable state management across the lifecycle of the agent. * - * @param action A suspendable lambda function that receives the `Persistency` instance and the current `AIAgentContext` - * as its parameters. This allows custom logic that interacts with the persistency layer to be executed. + * @param action A suspendable lambda function that receives the `Persistence` instance and the current `AIAgentContext` + * as its parameters. This allows custom logic that interacts with the persistence layer to be executed. * @return A result of type [T] produced by the execution of the provided action. */ -public suspend fun AIAgentContext.withPersistency( - action: suspend Persistency.(AIAgentContext) -> T -): T = this.persistency().action(this) +public suspend fun AIAgentContext.withPersistence( + action: suspend Persistence.(AIAgentContext) -> T +): T = this.persistence().action(this) /** - * Executes the provided action within the context of the agent's persistency layer if the agent is in a running state. + * Executes the provided action within the context of the agent's persistence layer if the agent is in a running state. * - * This function allows interaction with the persistency mechanism associated with the agent, ensuring that + * This function allows interaction with the persistence mechanism associated with the agent, ensuring that * the operation is carried out in the correct execution context. * - * @param action A suspending function defining operations to perform using the agent's persistency mechanism + * @param action A suspending function defining operations to perform using the agent's persistence mechanism * and the current agent context. * @return The result of the execution of the provided action. * @throws IllegalStateException If the agent is not in a running state when this function is called. */ @OptIn(InternalAgentsApi::class) -public suspend fun AIAgent<*, *>.withPersistency( - action: suspend Persistency.(AIAgentContext) -> T +public suspend fun AIAgent<*, *>.withPersistence( + action: suspend Persistence.(AIAgentContext) -> T ): T = when (val state = getState()) { - is Running<*> -> this.persistency().action(state.rootContext) + is Running<*> -> this.persistence().action(state.rootContext) else -> throw IllegalStateException("Agent is not running. Current agents's state: $state") } diff --git a/agents/agents-features/agents-features-snapshot/src/commonMain/kotlin/ai/koog/agents/snapshot/feature/PersistencyFeatureConfig.kt b/agents/agents-features/agents-features-snapshot/src/commonMain/kotlin/ai/koog/agents/snapshot/feature/PersistencyFeatureConfig.kt index cd0d5ec068..e2a8d4f81e 100644 --- a/agents/agents-features/agents-features-snapshot/src/commonMain/kotlin/ai/koog/agents/snapshot/feature/PersistencyFeatureConfig.kt +++ b/agents/agents-features/agents-features-snapshot/src/commonMain/kotlin/ai/koog/agents/snapshot/feature/PersistencyFeatureConfig.kt @@ -3,35 +3,44 @@ package ai.koog.agents.snapshot.feature import ai.koog.agents.core.agent.context.RollbackStrategy import ai.koog.agents.core.feature.config.FeatureConfig import ai.koog.agents.snapshot.feature.RollbackToolRegistry -import ai.koog.agents.snapshot.providers.NoPersistencyStorageProvider -import ai.koog.agents.snapshot.providers.PersistencyStorageProvider +import ai.koog.agents.snapshot.providers.NoPersistenceStorageProvider +import ai.koog.agents.snapshot.providers.PersistenceStorageProvider + +@Deprecated( + "`PersistencyFeatureConfig` has been renamed to `PersistenceFeatureConfig`", + replaceWith = ReplaceWith( + expression = "PersistenceFeatureConfig", + "ai.koog.agents.snapshot.feature.PersistenceFeatureConfig" + ) +) +public typealias PersistencyFeatureConfig = PersistenceFeatureConfig /** * Configuration class for the Snapshot feature. */ -public class PersistencyFeatureConfig : FeatureConfig() { +public class PersistenceFeatureConfig : FeatureConfig() { /** * Defines the storage mechanism for persisting snapshots in the feature. - * This property accepts implementations of [PersistencyStorageProvider], + * This property accepts implementations of [PersistenceStorageProvider], * which manage how snapshots are stored and retrieved. * - * By default, the storage is set to [NoPersistencyStorageProvider], a no-op + * By default, the storage is set to [NoPersistenceStorageProvider], a no-op * implementation that does not persist any data. To enable actual state - * persistence, assign a custom implementation of [PersistencyStorageProvider] + * persistence, assign a custom implementation of [PersistenceStorageProvider] * to this property. */ - public var storage: PersistencyStorageProvider = NoPersistencyStorageProvider() + public var storage: PersistenceStorageProvider = NoPersistenceStorageProvider() /** * Controls whether the feature's state should be automatically persisted. * When enabled, changes to the checkpoint are saved after each node execution through the assigned - * [PersistencyStorageProvider], ensuring the state can be restored later. + * [PersistenceStorageProvider], ensuring the state can be restored later. * - * Set this property to `true` to turn on automatic state persistency, + * Set this property to `true` to turn on automatic state persistence, * or `false` to disable it. */ - public var enableAutomaticPersistency: Boolean = false + public var enableAutomaticPersistence: Boolean = false /** * Determines the strategy to be used for rolling back the agent's state to a previously saved checkpoint. @@ -46,7 +55,7 @@ public class PersistencyFeatureConfig : FeatureConfig() { /** * Registry for rollback tools used when rolling back to checkpoints. - * Configure it during Persistency installation. Do not mutate later in withPersistency. + * Configure it during Persistence installation. Do not mutate later in withPersistence. */ public var rollbackToolRegistry: RollbackToolRegistry = RollbackToolRegistry.EMPTY } diff --git a/agents/agents-features/agents-features-snapshot/src/commonMain/kotlin/ai/koog/agents/snapshot/providers/InMemoryPersistencyStorageProvider.kt b/agents/agents-features/agents-features-snapshot/src/commonMain/kotlin/ai/koog/agents/snapshot/providers/InMemoryPersistencyStorageProvider.kt index 10dbd0b468..7c6b8c5ae5 100644 --- a/agents/agents-features/agents-features-snapshot/src/commonMain/kotlin/ai/koog/agents/snapshot/providers/InMemoryPersistencyStorageProvider.kt +++ b/agents/agents-features/agents-features-snapshot/src/commonMain/kotlin/ai/koog/agents/snapshot/providers/InMemoryPersistencyStorageProvider.kt @@ -5,10 +5,10 @@ import kotlinx.coroutines.sync.Mutex import kotlinx.coroutines.sync.withLock /** - * In-memory implementation of [PersistencyStorageProvider]. + * In-memory implementation of [PersistenceStorageProvider]. * This provider stores snapshots in a mutable map. */ -public class InMemoryPersistencyStorageProvider(private val persistenceId: String) : PersistencyStorageProvider { +public class InMemoryPersistenceStorageProvider(private val persistenceId: String) : PersistenceStorageProvider { private val mutex = Mutex() private val snapshotMap = mutableMapOf>() diff --git a/agents/agents-features/agents-features-snapshot/src/commonMain/kotlin/ai/koog/agents/snapshot/providers/NoPersistencyStorageProvider.kt b/agents/agents-features/agents-features-snapshot/src/commonMain/kotlin/ai/koog/agents/snapshot/providers/NoPersistencyStorageProvider.kt index 8187c302ac..08670abf2a 100644 --- a/agents/agents-features/agents-features-snapshot/src/commonMain/kotlin/ai/koog/agents/snapshot/providers/NoPersistencyStorageProvider.kt +++ b/agents/agents-features/agents-features-snapshot/src/commonMain/kotlin/ai/koog/agents/snapshot/providers/NoPersistencyStorageProvider.kt @@ -4,9 +4,9 @@ import ai.koog.agents.snapshot.feature.AgentCheckpointData import io.github.oshai.kotlinlogging.KotlinLogging /** - * No-op implementation of [PersistencyStorageProvider]. + * No-op implementation of [PersistenceStorageProvider]. */ -public class NoPersistencyStorageProvider : PersistencyStorageProvider { +public class NoPersistenceStorageProvider : PersistenceStorageProvider { private val logger = KotlinLogging.logger { } override suspend fun getCheckpoints(): List { diff --git a/agents/agents-features/agents-features-snapshot/src/commonMain/kotlin/ai/koog/agents/snapshot/providers/PersistencyUtils.kt b/agents/agents-features/agents-features-snapshot/src/commonMain/kotlin/ai/koog/agents/snapshot/providers/PersistenceUtils.kt similarity index 77% rename from agents/agents-features/agents-features-snapshot/src/commonMain/kotlin/ai/koog/agents/snapshot/providers/PersistencyUtils.kt rename to agents/agents-features/agents-features-snapshot/src/commonMain/kotlin/ai/koog/agents/snapshot/providers/PersistenceUtils.kt index ecb6c60bdc..087b4c1642 100644 --- a/agents/agents-features/agents-features-snapshot/src/commonMain/kotlin/ai/koog/agents/snapshot/providers/PersistencyUtils.kt +++ b/agents/agents-features/agents-features-snapshot/src/commonMain/kotlin/ai/koog/agents/snapshot/providers/PersistenceUtils.kt @@ -2,10 +2,19 @@ package ai.koog.agents.snapshot.providers import kotlinx.serialization.json.Json +@Deprecated( + "`PersistencyUtils` has been renamed to `PersistenceUtils`", + replaceWith = ReplaceWith( + expression = "PersistenceUtils", + "ai.koog.agents.snapshot.providers.PersistenceUtils" + ) +) +public typealias PersistencyUtils = PersistenceUtils + /** - * Utility object containing configurations and utilities for handling persistency-related operations. + * Utility object containing configurations and utilities for handling persistence-related operations. */ -public object PersistencyUtils { +public object PersistenceUtils { /** * A preconfigured JSON instance for handling serialization and deserialization of checkpoint data. * diff --git a/agents/agents-features/agents-features-snapshot/src/commonMain/kotlin/ai/koog/agents/snapshot/providers/PersistencyStorageProvider.kt b/agents/agents-features/agents-features-snapshot/src/commonMain/kotlin/ai/koog/agents/snapshot/providers/PersistencyStorageProvider.kt index 3b0c873eee..c91ef7f527 100644 --- a/agents/agents-features/agents-features-snapshot/src/commonMain/kotlin/ai/koog/agents/snapshot/providers/PersistencyStorageProvider.kt +++ b/agents/agents-features/agents-features-snapshot/src/commonMain/kotlin/ai/koog/agents/snapshot/providers/PersistencyStorageProvider.kt @@ -4,7 +4,16 @@ package ai.koog.agents.snapshot.providers import ai.koog.agents.snapshot.feature.AgentCheckpointData -public interface PersistencyStorageProvider { +@Deprecated( + "`PersistencyStorageProvider` has been renamed to `PersistenceStorageProvider`", + replaceWith = ReplaceWith( + expression = "PersistenceStorageProvider", + "ai.koog.agents.snapshot.feature.PersistenceStorageProvider" + ) +) +public typealias PersistencyStorageProvider = PersistenceStorageProvider + +public interface PersistenceStorageProvider { public suspend fun getCheckpoints(): List public suspend fun saveCheckpoint(agentCheckpointData: AgentCheckpointData) public suspend fun getLatestCheckpoint(): AgentCheckpointData? diff --git a/agents/agents-features/agents-features-snapshot/src/commonMain/kotlin/ai/koog/agents/snapshot/providers/file/FilePersistencyStorageProvider.kt b/agents/agents-features/agents-features-snapshot/src/commonMain/kotlin/ai/koog/agents/snapshot/providers/file/FilePersistencyStorageProvider.kt index acdc2114d7..a802a47327 100644 --- a/agents/agents-features/agents-features-snapshot/src/commonMain/kotlin/ai/koog/agents/snapshot/providers/file/FilePersistencyStorageProvider.kt +++ b/agents/agents-features/agents-features-snapshot/src/commonMain/kotlin/ai/koog/agents/snapshot/providers/file/FilePersistencyStorageProvider.kt @@ -1,16 +1,25 @@ package ai.koog.agents.snapshot.providers.file import ai.koog.agents.snapshot.feature.AgentCheckpointData -import ai.koog.agents.snapshot.providers.PersistencyStorageProvider -import ai.koog.agents.snapshot.providers.PersistencyUtils +import ai.koog.agents.snapshot.providers.PersistenceStorageProvider +import ai.koog.agents.snapshot.providers.PersistenceUtils import ai.koog.rag.base.files.FileSystemProvider import ai.koog.rag.base.files.createDirectory import ai.koog.rag.base.files.readText import ai.koog.rag.base.files.writeText import kotlinx.serialization.json.Json +@Deprecated( + "`FilePersistencyStorageProvider` has been renamed to `FilePersistenceStorageProvider`", + replaceWith = ReplaceWith( + expression = "FilePersistenceStorageProvider", + "ai.koog.agents.snapshot.providers.file.FilePersistenceStorageProvider" + ) +) +public typealias FilePersistencyStorageProvider = FilePersistenceStorageProvider + /** - * A file-based implementation of [PersistencyStorageProvider] that stores agent checkpoints in a file system. + * A file-based implementation of [PersistenceStorageProvider] that stores agent checkpoints in a file system. * * This implementation organizes checkpoints by agent ID and uses JSON serialization for storing and retrieving * checkpoint data. It relies on [FileSystemProvider.ReadWrite] for file system operations. @@ -19,12 +28,12 @@ import kotlinx.serialization.json.Json * @param fs A file system provider enabling read and write operations for file storage. * @param root Root file path where the checkpoint storage will organize data. */ -public open class FilePersistencyStorageProvider( +public open class FilePersistenceStorageProvider( private val persistenceId: String, private val fs: FileSystemProvider.ReadWrite, private val root: Path, - private val json: Json = PersistencyUtils.defaultCheckpointJson -) : PersistencyStorageProvider { + private val json: Json = PersistenceUtils.defaultCheckpointJson +) : PersistenceStorageProvider { /** * Directory where agent checkpoints are stored diff --git a/agents/agents-features/agents-features-snapshot/src/jvmMain/kotlin/ai/koog/agents/snapshot/providers/file/JVMFilePersistencyStorageProvider.kt b/agents/agents-features/agents-features-snapshot/src/jvmMain/kotlin/ai/koog/agents/snapshot/providers/file/JVMFilePersistencyStorageProvider.kt index 2f8e96d365..c0cdd08f8d 100644 --- a/agents/agents-features/agents-features-snapshot/src/jvmMain/kotlin/ai/koog/agents/snapshot/providers/file/JVMFilePersistencyStorageProvider.kt +++ b/agents/agents-features/agents-features-snapshot/src/jvmMain/kotlin/ai/koog/agents/snapshot/providers/file/JVMFilePersistencyStorageProvider.kt @@ -1,12 +1,21 @@ package ai.koog.agents.snapshot.providers.file -import ai.koog.agents.snapshot.providers.PersistencyUtils +import ai.koog.agents.snapshot.providers.PersistenceUtils import ai.koog.rag.base.files.JVMFileSystemProvider import kotlinx.serialization.json.Json import java.nio.file.Path +@Deprecated( + "`JVMFilePersistencyStorageProvider` has been renamed to `JVMFilePersistenceStorageProvider`", + replaceWith = ReplaceWith( + expression = "JVMFilePersistenceStorageProvider", + "ai.koog.agents.snapshot.providers.file.JVMFilePersistenceStorageProvider" + ) +) +public typealias JVMFilePersistencyStorageProvider = JVMFilePersistenceStorageProvider + /** - * A JVM-specific implementation of [FilePersistencyStorageProvider] for managing agent checkpoints + * A JVM-specific implementation of [FilePersistenceStorageProvider] for managing agent checkpoints * in a file system. * * This class utilizes JVM's [Path] for file system operations and [JVMFileSystemProvider.ReadWrite] @@ -16,14 +25,14 @@ import java.nio.file.Path * Use this class to persistently store and retrieve agent checkpoints to and from a file-based system * in JVM environments. * - * @constructor Initializes the [JVMFilePersistencyStorageProvider] with a specified root directory [root]. + * @constructor Initializes the [JVMFilePersistenceStorageProvider] with a specified root directory [root]. * @param root The root directory where all agent checkpoints will be stored. */ -public class JVMFilePersistencyStorageProvider( +public class JVMFilePersistenceStorageProvider( root: Path, persistenceId: String, - json: Json = PersistencyUtils.defaultCheckpointJson -) : FilePersistencyStorageProvider( + json: Json = PersistenceUtils.defaultCheckpointJson +) : FilePersistenceStorageProvider( fs = JVMFileSystemProvider.ReadWrite, root = root, persistenceId = persistenceId, diff --git a/agents/agents-features/agents-features-snapshot/src/jvmTest/kotlin/CheckpointSerializationTest.kt b/agents/agents-features/agents-features-snapshot/src/jvmTest/kotlin/CheckpointSerializationTest.kt index dc1f478335..76c035fab9 100644 --- a/agents/agents-features/agents-features-snapshot/src/jvmTest/kotlin/CheckpointSerializationTest.kt +++ b/agents/agents-features/agents-features-snapshot/src/jvmTest/kotlin/CheckpointSerializationTest.kt @@ -1,6 +1,6 @@ import ai.koog.agents.snapshot.feature.AgentCheckpointData import ai.koog.agents.snapshot.feature.tombstoneCheckpoint -import ai.koog.agents.snapshot.providers.PersistencyUtils +import ai.koog.agents.snapshot.providers.PersistenceUtils import ai.koog.prompt.message.Message import ai.koog.prompt.message.RequestMetaInfo import ai.koog.prompt.message.ResponseMetaInfo @@ -35,7 +35,7 @@ class CheckpointSerializationTest { properties = null ) - val json = PersistencyUtils.defaultCheckpointJson + val json = PersistenceUtils.defaultCheckpointJson val serialized = json.encodeToString(AgentCheckpointData.serializer(), checkpoint) // properties should be omitted due to explicitNulls = false @@ -93,7 +93,7 @@ class CheckpointSerializationTest { properties = properties ) - val json = PersistencyUtils.defaultCheckpointJson + val json = PersistenceUtils.defaultCheckpointJson val serialized = json.encodeToString(AgentCheckpointData.serializer(), checkpoint) val restored = json.decodeFromString(AgentCheckpointData.serializer(), serialized) @@ -104,7 +104,7 @@ class CheckpointSerializationTest { @Test fun `serialize and deserialize tombstone checkpoint`() { val checkpoint = tombstoneCheckpoint(Clock.System.now()) - val json = PersistencyUtils.defaultCheckpointJson + val json = PersistenceUtils.defaultCheckpointJson val serialized = json.encodeToString(AgentCheckpointData.serializer(), checkpoint) val restored = json.decodeFromString(AgentCheckpointData.serializer(), serialized) diff --git a/agents/agents-features/agents-features-snapshot/src/jvmTest/kotlin/CheckpointsTests.kt b/agents/agents-features/agents-features-snapshot/src/jvmTest/kotlin/CheckpointsTests.kt index 033fc9d49a..1ce94d6fe8 100644 --- a/agents/agents-features/agents-features-snapshot/src/jvmTest/kotlin/CheckpointsTests.kt +++ b/agents/agents-features/agents-features-snapshot/src/jvmTest/kotlin/CheckpointsTests.kt @@ -11,10 +11,10 @@ import ai.koog.agents.core.tools.Tool import ai.koog.agents.core.tools.ToolRegistry import ai.koog.agents.ext.tool.SayToUser import ai.koog.agents.snapshot.feature.AgentCheckpointData -import ai.koog.agents.snapshot.feature.Persistency +import ai.koog.agents.snapshot.feature.Persistence import ai.koog.agents.snapshot.feature.RollbackToolRegistry -import ai.koog.agents.snapshot.feature.withPersistency -import ai.koog.agents.snapshot.providers.InMemoryPersistencyStorageProvider +import ai.koog.agents.snapshot.feature.withPersistence +import ai.koog.agents.snapshot.providers.InMemoryPersistenceStorageProvider import ai.koog.agents.testing.tools.getMockExecutor import ai.koog.prompt.dsl.prompt import ai.koog.prompt.llm.OllamaModels @@ -69,7 +69,7 @@ class CheckpointsTests { } val checkpoint by node { input -> println("checkpoint save") - withPersistency { ctx -> + withPersistence { ctx -> createCheckpoint( ctx, currentNodeId ?: error("currentNodeId not set"), @@ -88,7 +88,7 @@ class CheckpointsTests { println("checkpoint load") if (!loaded) { loaded = true - withPersistency { ctx -> + withPersistence { ctx -> rollbackToCheckpoint("cpt-100500", ctx) } } @@ -104,8 +104,8 @@ class CheckpointsTests { agentConfig = agentConfig, toolRegistry = toolRegistry ) { - install(Persistency) { - storage = InMemoryPersistencyStorageProvider("testAgentId") + install(Persistence) { + storage = InMemoryPersistenceStorageProvider("testAgentId") } } @@ -124,8 +124,8 @@ class CheckpointsTests { agentConfig = agentConfig, toolRegistry = toolRegistry ) { - install(Persistency) { - storage = InMemoryPersistencyStorageProvider("testAgentId") + install(Persistence) { + storage = InMemoryPersistenceStorageProvider("testAgentId") } } @@ -209,7 +209,7 @@ class CheckpointsTests { // Node that creates a checkpoint val saveCheckpoint by node { input -> - withPersistency { ctx -> + withPersistence { ctx -> createCheckpoint( ctx, currentNodeId ?: error("currentNodeId not set"), @@ -261,8 +261,8 @@ class CheckpointsTests { agentConfig = agentConfig, toolRegistry = toolRegistry ) { - install(Persistency) { - storage = InMemoryPersistencyStorageProvider("testAgentId") + install(Persistence) { + storage = InMemoryPersistenceStorageProvider("testAgentId") } } @@ -294,8 +294,8 @@ class CheckpointsTests { agentConfig = agentConfig, toolRegistry = localToolRegistry ) { - install(Persistency) { - storage = InMemoryPersistencyStorageProvider("agent-tools-rollback-1") + install(Persistence) { + storage = InMemoryPersistenceStorageProvider("agent-tools-rollback-1") rollbackToolRegistry = RollbackToolRegistry { registerRollback(WriteKVTool, DeleteKVTool) } @@ -322,7 +322,7 @@ class CheckpointsTests { assertContains(databaseMap, "user-2") assertContains(databaseMap, "user-3") - agent.withPersistency { ctx -> + agent.withPersistence { ctx -> println("ctx outside: $this") println("ctx outside [hash]: ${this.hashCode()}") rollbackToCheckpoint("ckpt-1", ctx) @@ -347,7 +347,7 @@ class CheckpointsTests { @Test fun testRestoreFromSingleCheckpoint() = runTest { - val checkpointStorageProvider = InMemoryPersistencyStorageProvider("testAgentId") + val checkpointStorageProvider = InMemoryPersistenceStorageProvider("testAgentId") val time = Clock.System.now() val agentId = "testAgentId" @@ -371,7 +371,7 @@ class CheckpointsTests { toolRegistry = toolRegistry, id = agentId ) { - install(Persistency) { + install(Persistence) { storage = checkpointStorageProvider } } @@ -388,7 +388,7 @@ class CheckpointsTests { @Test fun testRestoreFromLatestCheckpoint() = runTest { - val checkpointStorageProvider = InMemoryPersistencyStorageProvider("testAgentId") + val checkpointStorageProvider = InMemoryPersistenceStorageProvider("testAgentId") val time = Clock.System.now() val agentId = "testAgentId" @@ -424,7 +424,7 @@ class CheckpointsTests { toolRegistry = toolRegistry, id = agentId ) { - install(Persistency) { + install(Persistence) { storage = checkpointStorageProvider } } diff --git a/agents/agents-features/agents-features-snapshot/src/jvmTest/kotlin/NodeUniquenessCheckpointTest.kt b/agents/agents-features/agents-features-snapshot/src/jvmTest/kotlin/NodeUniquenessCheckpointTest.kt index f0bbef7b18..a1cd44756e 100644 --- a/agents/agents-features/agents-features-snapshot/src/jvmTest/kotlin/NodeUniquenessCheckpointTest.kt +++ b/agents/agents-features/agents-features-snapshot/src/jvmTest/kotlin/NodeUniquenessCheckpointTest.kt @@ -7,8 +7,8 @@ import ai.koog.agents.core.dsl.builder.forwardTo import ai.koog.agents.core.dsl.builder.strategy import ai.koog.agents.core.tools.ToolRegistry import ai.koog.agents.ext.tool.SayToUser -import ai.koog.agents.snapshot.feature.Persistency -import ai.koog.agents.snapshot.providers.InMemoryPersistencyStorageProvider +import ai.koog.agents.snapshot.feature.Persistence +import ai.koog.agents.snapshot.providers.InMemoryPersistenceStorageProvider import ai.koog.agents.testing.tools.getMockExecutor import ai.koog.prompt.dsl.prompt import ai.koog.prompt.executor.model.PromptExecutor @@ -104,8 +104,8 @@ class NodeUniquenessCheckpointTest { toolRegistry = toolRegistry ) { // Install the AgentCheckpoint feature - install(Persistency) { - storage = InMemoryPersistencyStorageProvider("testAgentId") + install(Persistence) { + storage = InMemoryPersistenceStorageProvider("testAgentId") } } diff --git a/agents/agents-features/agents-features-snapshot/src/jvmTest/kotlin/PersistencyRestoreStrategyTests.kt b/agents/agents-features/agents-features-snapshot/src/jvmTest/kotlin/PersistencyRestoreStrategyTests.kt index 8b81618e2b..beb716361f 100644 --- a/agents/agents-features/agents-features-snapshot/src/jvmTest/kotlin/PersistencyRestoreStrategyTests.kt +++ b/agents/agents-features/agents-features-snapshot/src/jvmTest/kotlin/PersistencyRestoreStrategyTests.kt @@ -3,8 +3,8 @@ import ai.koog.agents.core.agent.AIAgentService import ai.koog.agents.core.agent.config.AIAgentConfig import ai.koog.agents.core.agent.context.RollbackStrategy import ai.koog.agents.snapshot.feature.AgentCheckpointData -import ai.koog.agents.snapshot.feature.Persistency -import ai.koog.agents.snapshot.providers.InMemoryPersistencyStorageProvider +import ai.koog.agents.snapshot.feature.Persistence +import ai.koog.agents.snapshot.providers.InMemoryPersistenceStorageProvider import ai.koog.agents.testing.tools.getMockExecutor import ai.koog.prompt.dsl.prompt import ai.koog.prompt.llm.OllamaModels @@ -16,10 +16,10 @@ import kotlinx.serialization.json.JsonPrimitive import org.junit.jupiter.api.Assertions.assertEquals import org.junit.jupiter.api.Test -class PersistencyRestoreStrategyTests { +class PersistenceRestoreStrategyTests { @Test fun `rollback Default resumes from checkpoint node`() = runTest { - val provider = InMemoryPersistencyStorageProvider("persistency-restore-default") + val provider = InMemoryPersistenceStorageProvider("persistence-restore-default") val checkpoint = AgentCheckpointData( checkpointId = "chk-1", @@ -40,10 +40,10 @@ class PersistencyRestoreStrategyTests { maxAgentIterations = 10 ), ) { - install(Persistency) { + install(Persistence) { storage = provider - // We only need restore on start; automatic persistency doesn't matter here - enableAutomaticPersistency = true + // We only need restore on start; automatic persistence doesn't matter here + enableAutomaticPersistence = true rollbackStrategy = RollbackStrategy.Default } } @@ -59,7 +59,7 @@ class PersistencyRestoreStrategyTests { @Test fun `rollback MessageHistoryOnly starts from beginning`() = runTest { - val provider = InMemoryPersistencyStorageProvider("persistency-restore-history-only") + val provider = InMemoryPersistenceStorageProvider("persistence-restore-history-only") val agentService = AIAgentService( promptExecutor = getMockExecutor { }, @@ -70,9 +70,9 @@ class PersistencyRestoreStrategyTests { maxAgentIterations = 10 ), ) { - install(Persistency) { + install(Persistence) { storage = provider - enableAutomaticPersistency = true + enableAutomaticPersistence = true rollbackStrategy = RollbackStrategy.MessageHistoryOnly } } diff --git a/agents/agents-features/agents-features-snapshot/src/jvmTest/kotlin/PersistencyRunsTwiceTest.kt b/agents/agents-features/agents-features-snapshot/src/jvmTest/kotlin/PersistencyRunsTwiceTest.kt index 2f791514f8..6c71b0e209 100644 --- a/agents/agents-features/agents-features-snapshot/src/jvmTest/kotlin/PersistencyRunsTwiceTest.kt +++ b/agents/agents-features/agents-features-snapshot/src/jvmTest/kotlin/PersistencyRunsTwiceTest.kt @@ -1,8 +1,8 @@ import ai.koog.agents.core.agent.AIAgentService import ai.koog.agents.core.agent.config.AIAgentConfig -import ai.koog.agents.snapshot.feature.Persistency +import ai.koog.agents.snapshot.feature.Persistence import ai.koog.agents.snapshot.feature.isTombstone -import ai.koog.agents.snapshot.providers.InMemoryPersistencyStorageProvider +import ai.koog.agents.snapshot.providers.InMemoryPersistenceStorageProvider import ai.koog.agents.testing.tools.getMockExecutor import ai.koog.prompt.dsl.prompt import ai.koog.prompt.llm.OllamaModels @@ -12,12 +12,12 @@ import kotlinx.coroutines.test.runTest import org.awaitility.kotlin.await import org.junit.jupiter.api.Test -class PersistencyRunsTwiceTest { +class PersistenceRunsTwiceTest { @Test fun `agent runs to end and on second run starts from beginning again`() = runTest { // Arrange - val provider = InMemoryPersistencyStorageProvider("persistency-test-agent") + val provider = InMemoryPersistenceStorageProvider("persistence-test-agent") val testCollector = TestAgentLogsCollector() @@ -32,9 +32,9 @@ class PersistencyRunsTwiceTest { maxAgentIterations = 10 ), ) { - install(Persistency) { + install(Persistence) { storage = provider - enableAutomaticPersistency = true + enableAutomaticPersistence = true } } @@ -76,7 +76,7 @@ class PersistencyRunsTwiceTest { @Test fun `agent fails on the first run and second run running successfully`() = runTest { - val provider = InMemoryPersistencyStorageProvider("persistency-test-agent") + val provider = InMemoryPersistenceStorageProvider("persistence-test-agent") val testCollector = TestAgentLogsCollector() @@ -91,9 +91,9 @@ class PersistencyRunsTwiceTest { maxAgentIterations = 10 ), ) { - install(Persistency) { + install(Persistence) { storage = provider - enableAutomaticPersistency = true + enableAutomaticPersistence = true } } diff --git a/agents/agents-features/agents-features-snapshot/src/jvmTest/kotlin/SimpleGraphCheckpointTest.kt b/agents/agents-features/agents-features-snapshot/src/jvmTest/kotlin/SimpleGraphCheckpointTest.kt index 30287da353..4de2dfce04 100644 --- a/agents/agents-features/agents-features-snapshot/src/jvmTest/kotlin/SimpleGraphCheckpointTest.kt +++ b/agents/agents-features/agents-features-snapshot/src/jvmTest/kotlin/SimpleGraphCheckpointTest.kt @@ -2,8 +2,8 @@ import ai.koog.agents.core.agent.AIAgent import ai.koog.agents.core.agent.config.AIAgentConfig import ai.koog.agents.core.tools.ToolRegistry import ai.koog.agents.ext.tool.SayToUser -import ai.koog.agents.snapshot.feature.Persistency -import ai.koog.agents.snapshot.providers.InMemoryPersistencyStorageProvider +import ai.koog.agents.snapshot.feature.Persistence +import ai.koog.agents.snapshot.providers.InMemoryPersistenceStorageProvider import ai.koog.agents.testing.tools.getMockExecutor import ai.koog.prompt.dsl.prompt import ai.koog.prompt.executor.model.PromptExecutor @@ -51,8 +51,8 @@ class SimpleGraphCheckpointTest { agentConfig = agentConfig, toolRegistry = toolRegistry ) { - install(Persistency) { - storage = InMemoryPersistencyStorageProvider("testAgentId") + install(Persistence) { + storage = InMemoryPersistenceStorageProvider("testAgentId") } } @@ -78,7 +78,7 @@ class SimpleGraphCheckpointTest { @Test fun `test agent creates and saves checkpoints`() = runTest { // Create a snapshot provider to store checkpoints - val checkpointStorageProvider = InMemoryPersistencyStorageProvider("testAgentId") + val checkpointStorageProvider = InMemoryPersistenceStorageProvider("testAgentId") // Create a mock executor for testing val mockExecutor: PromptExecutor = getMockExecutor { @@ -106,7 +106,7 @@ class SimpleGraphCheckpointTest { agentConfig = agentConfig, toolRegistry = toolRegistry ) { - install(Persistency) { + install(Persistence) { storage = checkpointStorageProvider } } @@ -122,7 +122,7 @@ class SimpleGraphCheckpointTest { @Test fun test_checkpoint_persists_history() = runTest { - val checkpointStorageProvider = InMemoryPersistencyStorageProvider("testAgentId") + val checkpointStorageProvider = InMemoryPersistenceStorageProvider("testAgentId") val mockExecutor: PromptExecutor = getMockExecutor { // No specific mock responses needed for this test @@ -148,7 +148,7 @@ class SimpleGraphCheckpointTest { agentConfig = agentConfig, toolRegistry = toolRegistry ) { - install(Persistency) { + install(Persistence) { storage = checkpointStorageProvider } } diff --git a/agents/agents-features/agents-features-snapshot/src/jvmTest/kotlin/SubgraphCheckpointsTest.kt b/agents/agents-features/agents-features-snapshot/src/jvmTest/kotlin/SubgraphCheckpointsTest.kt index 0c3a5da999..54dbf84b75 100644 --- a/agents/agents-features/agents-features-snapshot/src/jvmTest/kotlin/SubgraphCheckpointsTest.kt +++ b/agents/agents-features/agents-features-snapshot/src/jvmTest/kotlin/SubgraphCheckpointsTest.kt @@ -2,8 +2,8 @@ import ai.koog.agents.core.agent.AIAgent import ai.koog.agents.core.agent.config.AIAgentConfig import ai.koog.agents.core.tools.ToolRegistry import ai.koog.agents.ext.tool.SayToUser -import ai.koog.agents.snapshot.feature.Persistency -import ai.koog.agents.snapshot.providers.InMemoryPersistencyStorageProvider +import ai.koog.agents.snapshot.feature.Persistence +import ai.koog.agents.snapshot.providers.InMemoryPersistenceStorageProvider import ai.koog.agents.testing.tools.getMockExecutor import ai.koog.prompt.dsl.prompt import ai.koog.prompt.llm.OllamaModels @@ -36,8 +36,8 @@ class SubgraphCheckpointsTest { agentConfig = agentConfig, toolRegistry = toolRegistry ) { - install(Persistency) { - storage = InMemoryPersistencyStorageProvider("testAgentId") + install(Persistence) { + storage = InMemoryPersistenceStorageProvider("testAgentId") } } @@ -63,8 +63,8 @@ class SubgraphCheckpointsTest { agentConfig = agentConfig, toolRegistry = toolRegistry ) { - install(Persistency) { - storage = InMemoryPersistencyStorageProvider("testAgentId") + install(Persistence) { + storage = InMemoryPersistenceStorageProvider("testAgentId") } } @@ -90,8 +90,8 @@ class SubgraphCheckpointsTest { agentConfig = agentConfig, toolRegistry = toolRegistry ) { - install(Persistency) { - storage = InMemoryPersistencyStorageProvider("testAgentId") + install(Persistence) { + storage = InMemoryPersistenceStorageProvider("testAgentId") } } @@ -118,8 +118,8 @@ class SubgraphCheckpointsTest { agentConfig = agentConfig, toolRegistry = toolRegistry ) { - install(Persistency) { - storage = InMemoryPersistencyStorageProvider("testAgentId") + install(Persistence) { + storage = InMemoryPersistenceStorageProvider("testAgentId") } } diff --git a/agents/agents-features/agents-features-snapshot/src/jvmTest/kotlin/SubgraphSetExecutionPointTest.kt b/agents/agents-features/agents-features-snapshot/src/jvmTest/kotlin/SubgraphSetExecutionPointTest.kt index 0a723550ad..c8eb5b6017 100644 --- a/agents/agents-features/agents-features-snapshot/src/jvmTest/kotlin/SubgraphSetExecutionPointTest.kt +++ b/agents/agents-features/agents-features-snapshot/src/jvmTest/kotlin/SubgraphSetExecutionPointTest.kt @@ -2,8 +2,8 @@ import ai.koog.agents.core.agent.AIAgent import ai.koog.agents.core.agent.config.AIAgentConfig import ai.koog.agents.core.tools.ToolRegistry import ai.koog.agents.ext.tool.SayToUser -import ai.koog.agents.snapshot.feature.Persistency -import ai.koog.agents.snapshot.providers.InMemoryPersistencyStorageProvider +import ai.koog.agents.snapshot.feature.Persistence +import ai.koog.agents.snapshot.providers.InMemoryPersistenceStorageProvider import ai.koog.agents.testing.tools.getMockExecutor import ai.koog.prompt.dsl.prompt import ai.koog.prompt.llm.OllamaModels @@ -32,8 +32,8 @@ class SubgraphSetExecutionPointTest { agentConfig = agentConfig, toolRegistry = toolRegistry ) { - install(Persistency) { - storage = InMemoryPersistencyStorageProvider("testAgentId") + install(Persistence) { + storage = InMemoryPersistenceStorageProvider("testAgentId") } } @@ -56,8 +56,8 @@ class SubgraphSetExecutionPointTest { agentConfig = agentConfig, toolRegistry = toolRegistry ) { - install(Persistency) { - storage = InMemoryPersistencyStorageProvider("testAgentId") + install(Persistence) { + storage = InMemoryPersistenceStorageProvider("testAgentId") } } @@ -84,8 +84,8 @@ class SubgraphSetExecutionPointTest { agentConfig = agentConfig, toolRegistry = toolRegistry ) { - install(Persistency) { - storage = InMemoryPersistencyStorageProvider("testAgentId") + install(Persistence) { + storage = InMemoryPersistenceStorageProvider("testAgentId") } } @@ -109,8 +109,8 @@ class SubgraphSetExecutionPointTest { agentConfig = agentConfig, toolRegistry = toolRegistry ) { - install(Persistency) { - storage = InMemoryPersistencyStorageProvider("testAgentId") + install(Persistence) { + storage = InMemoryPersistenceStorageProvider("testAgentId") } } @@ -136,8 +136,8 @@ class SubgraphSetExecutionPointTest { agentConfig = agentConfig, toolRegistry = toolRegistry ) { - install(Persistency) { - storage = InMemoryPersistencyStorageProvider("testAgentId") + install(Persistence) { + storage = InMemoryPersistenceStorageProvider("testAgentId") } } @@ -163,8 +163,8 @@ class SubgraphSetExecutionPointTest { agentConfig = agentConfig, toolRegistry = toolRegistry ) { - install(Persistency) { - storage = InMemoryPersistencyStorageProvider("testAgentId") + install(Persistence) { + storage = InMemoryPersistenceStorageProvider("testAgentId") } } diff --git a/agents/agents-features/agents-features-snapshot/src/jvmTest/kotlin/TestStrategies.kt b/agents/agents-features/agents-features-snapshot/src/jvmTest/kotlin/TestStrategies.kt index 9393ce0401..53f3ac6f84 100644 --- a/agents/agents-features/agents-features-snapshot/src/jvmTest/kotlin/TestStrategies.kt +++ b/agents/agents-features/agents-features-snapshot/src/jvmTest/kotlin/TestStrategies.kt @@ -2,7 +2,7 @@ import ai.koog.agents.core.dsl.builder.AIAgentNodeDelegate import ai.koog.agents.core.dsl.builder.AIAgentSubgraphBuilderBase import ai.koog.agents.core.dsl.builder.forwardTo import ai.koog.agents.core.dsl.builder.strategy -import ai.koog.agents.snapshot.feature.withPersistency +import ai.koog.agents.snapshot.feature.withPersistence import kotlinx.serialization.json.JsonPrimitive import kotlin.reflect.typeOf @@ -103,10 +103,10 @@ private fun AIAgentSubgraphBuilderBase<*, *>.teleportOnceNode( ): AIAgentNodeDelegate = node(name) { if (!teleportState.teleported) { teleportState.teleported = true - withPersistency { ctx -> + withPersistence { ctx -> val history = llm.readSession { this.prompt.messages } setExecutionPoint(ctx, teleportToId, history, JsonPrimitive("$it\nTeleported")) - return@withPersistency "Teleported" + return@withPersistence "Teleported" } } else { // If we've already teleported, just return the input @@ -131,7 +131,7 @@ private fun AIAgentSubgraphBuilderBase<*, *>.nodeForSecondTry( private fun AIAgentSubgraphBuilderBase<*, *>.createCheckpointNode(name: String? = null, checkpointId: String) = node(name) { val input = it - withPersistency { ctx -> + withPersistence { ctx -> createCheckpoint(ctx, name!!, input, typeOf(), checkpointId) llm.writeSession { updatePrompt { @@ -157,7 +157,7 @@ private fun AIAgentSubgraphBuilderBase<*, *>.nodeRollbackToCheckpoint( return@node "Skipping rollback" } - withPersistency { + withPersistence { val checkpoint = rollbackToCheckpoint(checkpointId, it)!! teleportState.teleported = true llm.writeSession { @@ -174,7 +174,7 @@ private fun AIAgentSubgraphBuilderBase<*, *>.nodeCreateCheckpoint( name: String? = null, ): AIAgentNodeDelegate = node(name) { val input = it - withPersistency { ctx -> + withPersistence { ctx -> val checkpoint = createCheckpoint( ctx, currentNodeId ?: error("currentNodeId not set"), @@ -185,7 +185,7 @@ private fun AIAgentSubgraphBuilderBase<*, *>.nodeCreateCheckpoint( saveCheckpoint(checkpoint ?: error("Checkpoint creation failed")) - return@withPersistency "$input\nSnapshot created" + return@withPersistence "$input\nSnapshot created" } } diff --git a/agents/agents-features/agents-features-snapshot/src/jvmTest/kotlin/ai/koog/agents/snapshot/providers/file/FileAgentCheckpointStorageProviderTest.kt b/agents/agents-features/agents-features-snapshot/src/jvmTest/kotlin/ai/koog/agents/snapshot/providers/file/FileAgentCheckpointStorageProviderTest.kt index b9f5318e1b..3c5f87d3e3 100644 --- a/agents/agents-features/agents-features-snapshot/src/jvmTest/kotlin/ai/koog/agents/snapshot/providers/file/FileAgentCheckpointStorageProviderTest.kt +++ b/agents/agents-features/agents-features-snapshot/src/jvmTest/kotlin/ai/koog/agents/snapshot/providers/file/FileAgentCheckpointStorageProviderTest.kt @@ -17,12 +17,12 @@ import kotlin.test.assertTrue class FileAgentCheckpointStorageProviderTest { private lateinit var tempDir: java.nio.file.Path - private lateinit var provider: JVMFilePersistencyStorageProvider + private lateinit var provider: JVMFilePersistenceStorageProvider @BeforeTest fun setup() { tempDir = Files.createTempDirectory("checkpoint-test") - provider = JVMFilePersistencyStorageProvider(tempDir, "testAgentId") + provider = JVMFilePersistenceStorageProvider(tempDir, "testAgentId") } @AfterTest diff --git a/agents/agents-features/agents-features-snapshot/src/jvmTest/kotlin/ai/koog/agents/snapshot/providers/file/FileCheckpointsTests.kt b/agents/agents-features/agents-features-snapshot/src/jvmTest/kotlin/ai/koog/agents/snapshot/providers/file/FileCheckpointsTests.kt index 99cfa8a1b5..7c7ef48b3d 100644 --- a/agents/agents-features/agents-features-snapshot/src/jvmTest/kotlin/ai/koog/agents/snapshot/providers/file/FileCheckpointsTests.kt +++ b/agents/agents-features/agents-features-snapshot/src/jvmTest/kotlin/ai/koog/agents/snapshot/providers/file/FileCheckpointsTests.kt @@ -3,9 +3,9 @@ import ai.koog.agents.core.agent.config.AIAgentConfig import ai.koog.agents.core.tools.ToolRegistry import ai.koog.agents.ext.tool.SayToUser import ai.koog.agents.snapshot.feature.AgentCheckpointData -import ai.koog.agents.snapshot.feature.Persistency +import ai.koog.agents.snapshot.feature.Persistence import ai.koog.agents.snapshot.feature.isTombstone -import ai.koog.agents.snapshot.providers.file.JVMFilePersistencyStorageProvider +import ai.koog.agents.snapshot.providers.file.JVMFilePersistenceStorageProvider import ai.koog.agents.testing.tools.getMockExecutor import ai.koog.prompt.dsl.prompt import ai.koog.prompt.llm.OllamaModels @@ -32,7 +32,7 @@ import kotlin.time.Duration.Companion.seconds */ class FileCheckpointsTests { private lateinit var tempDir: Path - private lateinit var provider: JVMFilePersistencyStorageProvider + private lateinit var provider: JVMFilePersistenceStorageProvider val systemPrompt = "You are a test agent." val agentConfig = AIAgentConfig( @@ -49,7 +49,7 @@ class FileCheckpointsTests { @BeforeTest fun setup() { tempDir = Files.createTempDirectory("agent-checkpoint-test") - provider = JVMFilePersistencyStorageProvider(tempDir, "testAgentId") + provider = JVMFilePersistenceStorageProvider(tempDir, "testAgentId") } @AfterTest @@ -70,7 +70,7 @@ class FileCheckpointsTests { toolRegistry = toolRegistry, id = agentId ) { - install(Persistency) { + install(Persistence) { storage = provider } } @@ -99,7 +99,7 @@ class FileCheckpointsTests { agentConfig = agentConfig, toolRegistry = toolRegistry ) { - install(Persistency) { + install(Persistence) { storage = provider } } @@ -138,7 +138,7 @@ class FileCheckpointsTests { toolRegistry = toolRegistry, id = agentId ) { - install(Persistency) { + install(Persistence) { storage = provider } } @@ -190,7 +190,7 @@ class FileCheckpointsTests { toolRegistry = toolRegistry, id = agentId ) { - install(Persistency) { + install(Persistence) { storage = provider } } @@ -216,10 +216,10 @@ class FileCheckpointsTests { toolRegistry = toolRegistry, id = agentId ) { - install(Persistency) { + install(Persistence) { storage = provider - enableAutomaticPersistency = true + enableAutomaticPersistence = true } } diff --git a/agents/agents-features/agents-features-sql/Module.md b/agents/agents-features/agents-features-sql/Module.md index 350d7ab79e..e96acfa55b 100644 --- a/agents/agents-features/agents-features-sql/Module.md +++ b/agents/agents-features/agents-features-sql/Module.md @@ -19,19 +19,19 @@ Provides SQL-based persistence providers for agent checkpoints using JetBrains E ## Providers -### ExposedPersistencyStorageProvider +### ExposedPersistenceStorageProvider Base provider using Exposed ORM with configurable cleanup behavior. -### PostgresPersistencyStorageProvider +### PostgresPersistenceStorageProvider Production-ready provider with JSONB support and HikariCP pooling. -### MySQLPersistencyStorageProvider +### MySQLPersistenceStorageProvider Enterprise provider with JSON column support (MySQL 5.7+). -### H2PersistencyStorageProvider +### H2PersistenceStorageProvider Perfect for testing and embedded applications. -### SQLitePersistencyStorageProvider +### SQLitePersistenceStorageProvider Zero-configuration provider for desktop and mobile applications. -All providers implement `AutoCloseable` for proper resource management and support configurable TTL cleanup. \ No newline at end of file +All providers implement `AutoCloseable` for proper resource management and support configurable TTL cleanup. diff --git a/agents/agents-features/agents-features-sql/src/commonMain/kotlin/ai/koog/agents/features/sql/providers/SQLPersistencyStorageProvider.kt b/agents/agents-features/agents-features-sql/src/commonMain/kotlin/ai/koog/agents/features/sql/providers/SQLPersistencyStorageProvider.kt index a5dadc1630..b8eb3e3db6 100644 --- a/agents/agents-features/agents-features-sql/src/commonMain/kotlin/ai/koog/agents/features/sql/providers/SQLPersistencyStorageProvider.kt +++ b/agents/agents-features/agents-features-sql/src/commonMain/kotlin/ai/koog/agents/features/sql/providers/SQLPersistencyStorageProvider.kt @@ -1,10 +1,10 @@ package ai.koog.agents.features.sql.providers -import ai.koog.agents.snapshot.providers.PersistencyStorageProvider +import ai.koog.agents.snapshot.providers.PersistenceStorageProvider import kotlinx.datetime.Instant /** - * Abstract base class for SQL-based implementations of [PersistencyStorageProvider]. + * Abstract base class for SQL-based implementations of [PersistenceStorageProvider]. * * This provider offers a generic SQL abstraction for persisting agent checkpoints * to relational databases. Concrete implementations should handle specific SQL @@ -32,12 +32,12 @@ import kotlinx.datetime.Instant * @param tableName Name of the table to store checkpoints (default: "agent_checkpoints") * @param ttlSeconds Optional TTL for checkpoint entries in seconds (null = no expiration) */ -public abstract class SQLPersistencyStorageProvider( +public abstract class SQLPersistenceStorageProvider( protected val persistenceId: String, protected val tableName: String = "agent_checkpoints", protected val ttlSeconds: Long? = null, protected val migrator: SQLPersistenceSchemaMigrator -) : PersistencyStorageProvider { +) : PersistenceStorageProvider { /** * Initializes the database schema if it doesn't exist. diff --git a/agents/agents-features/agents-features-sql/src/jvmMain/kotlin/ai/koog/agents/features/sql/providers/ExposedPersistencyStorageProvider.kt b/agents/agents-features/agents-features-sql/src/jvmMain/kotlin/ai/koog/agents/features/sql/providers/ExposedPersistencyStorageProvider.kt index e4041e3af7..d1d5c38315 100644 --- a/agents/agents-features/agents-features-sql/src/jvmMain/kotlin/ai/koog/agents/features/sql/providers/ExposedPersistencyStorageProvider.kt +++ b/agents/agents-features/agents-features-sql/src/jvmMain/kotlin/ai/koog/agents/features/sql/providers/ExposedPersistencyStorageProvider.kt @@ -1,7 +1,7 @@ package ai.koog.agents.features.sql.providers import ai.koog.agents.snapshot.feature.AgentCheckpointData -import ai.koog.agents.snapshot.providers.PersistencyUtils +import ai.koog.agents.snapshot.providers.PersistenceUtils import kotlinx.datetime.Clock import kotlinx.serialization.json.Json import org.jetbrains.exposed.sql.Column @@ -73,7 +73,7 @@ public open class CheckpointsTable(tableName: String) : Table(tableName) { } /** - * An abstract Exposed-based implementation of [SQLPersistencyStorageProvider] for managing + * An abstract Exposed-based implementation of [SQLPersistenceStorageProvider] for managing * agent checkpoints in SQL databases using JetBrains Exposed ORM. * * This class provides a generic SQL implementation that works with any database supported @@ -113,14 +113,14 @@ public open class CheckpointsTable(tableName: String) : Table(tableName) { * @param ttlSeconds Optional TTL for checkpoint entries in seconds (null = no expiration) */ @Suppress("MissingKDocForPublicAPI") -public abstract class ExposedPersistencyStorageProvider( +public abstract class ExposedPersistenceStorageProvider( persistenceId: String, protected val database: Database, tableName: String = "agent_checkpoints", ttlSeconds: Long? = null, migrator: SQLPersistenceSchemaMigrator, - private val json: Json = PersistencyUtils.defaultCheckpointJson -) : SQLPersistencyStorageProvider( + private val json: Json = PersistenceUtils.defaultCheckpointJson +) : SQLPersistenceStorageProvider( persistenceId = persistenceId, tableName = tableName, ttlSeconds = ttlSeconds, @@ -184,7 +184,7 @@ public abstract class ExposedPersistencyStorageProvider( checkpointsTable .select(checkpointsTable.checkpointJson) .where { - checkpointsTable.persistenceId eq this@ExposedPersistencyStorageProvider.persistenceId + checkpointsTable.persistenceId eq this@ExposedPersistenceStorageProvider.persistenceId } .orderBy(checkpointsTable.createdAt to SortOrder.ASC) .mapNotNull { row -> @@ -202,7 +202,7 @@ public abstract class ExposedPersistencyStorageProvider( transaction { // Use upsert for idempotent saves checkpointsTable.upsert { - it[checkpointsTable.persistenceId] = this@ExposedPersistencyStorageProvider.persistenceId + it[checkpointsTable.persistenceId] = this@ExposedPersistenceStorageProvider.persistenceId it[checkpointsTable.checkpointId] = agentCheckpointData.checkpointId it[checkpointsTable.createdAt] = agentCheckpointData.createdAt.toEpochMilliseconds() it[checkpointsTable.checkpointJson] = checkpointJson @@ -216,7 +216,7 @@ public abstract class ExposedPersistencyStorageProvider( checkpointsTable .select(checkpointsTable.checkpointJson) .where { - checkpointsTable.persistenceId eq this@ExposedPersistencyStorageProvider.persistenceId + checkpointsTable.persistenceId eq this@ExposedPersistenceStorageProvider.persistenceId } .orderBy(checkpointsTable.createdAt to SortOrder.DESC) .limit(1) @@ -231,7 +231,7 @@ public abstract class ExposedPersistencyStorageProvider( override suspend fun deleteCheckpoint(checkpointId: String) { transaction { checkpointsTable.deleteWhere { - (checkpointsTable.persistenceId eq this@ExposedPersistencyStorageProvider.persistenceId) and + (checkpointsTable.persistenceId eq this@ExposedPersistenceStorageProvider.persistenceId) and (checkpointsTable.checkpointId eq checkpointId) } } @@ -240,7 +240,7 @@ public abstract class ExposedPersistencyStorageProvider( override suspend fun deleteAllCheckpoints() { transaction { checkpointsTable.deleteWhere { - checkpointsTable.persistenceId eq this@ExposedPersistencyStorageProvider.persistenceId + checkpointsTable.persistenceId eq this@ExposedPersistenceStorageProvider.persistenceId } } } @@ -248,7 +248,7 @@ public abstract class ExposedPersistencyStorageProvider( override suspend fun getCheckpointCount(): Long { return transaction { checkpointsTable.selectAll().where { - checkpointsTable.persistenceId eq this@ExposedPersistencyStorageProvider.persistenceId + checkpointsTable.persistenceId eq this@ExposedPersistenceStorageProvider.persistenceId }.count() } } diff --git a/agents/agents-features/agents-features-sql/src/jvmMain/kotlin/ai/koog/agents/features/sql/providers/H2PersistencyStorageProvider.kt b/agents/agents-features/agents-features-sql/src/jvmMain/kotlin/ai/koog/agents/features/sql/providers/H2PersistencyStorageProvider.kt index 978715dc55..30833681e5 100644 --- a/agents/agents-features/agents-features-sql/src/jvmMain/kotlin/ai/koog/agents/features/sql/providers/H2PersistencyStorageProvider.kt +++ b/agents/agents-features/agents-features-sql/src/jvmMain/kotlin/ai/koog/agents/features/sql/providers/H2PersistencyStorageProvider.kt @@ -1,6 +1,6 @@ package ai.koog.agents.features.sql.providers -import ai.koog.agents.snapshot.providers.PersistencyUtils +import ai.koog.agents.snapshot.providers.PersistenceUtils import kotlinx.coroutines.Dispatchers import kotlinx.serialization.json.Json import org.jetbrains.exposed.sql.Database @@ -8,17 +8,17 @@ import org.jetbrains.exposed.sql.transactions.experimental.newSuspendedTransacti import org.jetbrains.exposed.sql.transactions.transaction /** - * H2 Database-specific implementation of [ExposedPersistencyStorageProvider] for managing + * H2 Database-specific implementation of [ExposedPersistenceStorageProvider] for managing * agent checkpoints in H2 databases. */ -public class H2PersistencyStorageProvider( +public class H2PersistenceStorageProvider( persistenceId: String, database: Database, tableName: String = "agent_checkpoints", ttlSeconds: Long? = null, migrator: SQLPersistenceSchemaMigrator = H2PersistenceSchemaMigrator(database, tableName), - json: Json = PersistencyUtils.defaultCheckpointJson -) : ExposedPersistencyStorageProvider(persistenceId, database, tableName, ttlSeconds, migrator, json) { + json: Json = PersistenceUtils.defaultCheckpointJson +) : ExposedPersistenceStorageProvider(persistenceId, database, tableName, ttlSeconds, migrator, json) { public override suspend fun transaction(block: suspend () -> T): T = newSuspendedTransaction(Dispatchers.IO, database) { @@ -44,7 +44,7 @@ public class H2PersistencyStorageProvider( options: String = "DB_CLOSE_DELAY=-1", tableName: String = "agent_checkpoints", ttlSeconds: Long? = null - ): H2PersistencyStorageProvider = H2PersistencyStorageProvider( + ): H2PersistenceStorageProvider = H2PersistenceStorageProvider( persistenceId = persistenceId, database = Database.connect("jdbc:h2:mem:$databaseName;$options"), tableName = tableName, @@ -68,7 +68,7 @@ public class H2PersistencyStorageProvider( options: String = "", tableName: String = "agent_checkpoints", ttlSeconds: Long? = null - ): H2PersistencyStorageProvider = H2PersistencyStorageProvider( + ): H2PersistenceStorageProvider = H2PersistenceStorageProvider( persistenceId = persistenceId, database = Database.connect( if (options.isNotEmpty()) { @@ -95,8 +95,8 @@ public class H2PersistencyStorageProvider( databasePath: String = "mem:test", tableName: String = "agent_checkpoints", ttlSeconds: Long? = null - ): H2PersistencyStorageProvider { - return H2PersistencyStorageProvider( + ): H2PersistenceStorageProvider { + return H2PersistenceStorageProvider( persistenceId = persistenceId, database = Database.connect("jdbc:h2:$databasePath;MODE=PostgreSQL;DATABASE_TO_LOWER=TRUE"), tableName = tableName, diff --git a/agents/agents-features/agents-features-sql/src/jvmMain/kotlin/ai/koog/agents/features/sql/providers/MySQLPersistencyStorageProvider.kt b/agents/agents-features/agents-features-sql/src/jvmMain/kotlin/ai/koog/agents/features/sql/providers/MySQLPersistencyStorageProvider.kt index 3077fe86fa..9a39443b7b 100644 --- a/agents/agents-features/agents-features-sql/src/jvmMain/kotlin/ai/koog/agents/features/sql/providers/MySQLPersistencyStorageProvider.kt +++ b/agents/agents-features/agents-features-sql/src/jvmMain/kotlin/ai/koog/agents/features/sql/providers/MySQLPersistencyStorageProvider.kt @@ -1,6 +1,6 @@ package ai.koog.agents.features.sql.providers -import ai.koog.agents.snapshot.providers.PersistencyUtils +import ai.koog.agents.snapshot.providers.PersistenceUtils import kotlinx.coroutines.Dispatchers import kotlinx.serialization.json.Json import org.jetbrains.exposed.sql.Database @@ -8,7 +8,7 @@ import org.jetbrains.exposed.sql.transactions.experimental.newSuspendedTransacti import org.jetbrains.exposed.sql.transactions.transaction /** - * MySQL-specific implementation of [ExposedPersistencyStorageProvider] for managing + * MySQL-specific implementation of [ExposedPersistenceStorageProvider] for managing * agent checkpoints in MySQL databases. * * This provider is optimized for MySQL 5.7+ and MariaDB 10.2+, leveraging their @@ -22,7 +22,7 @@ import org.jetbrains.exposed.sql.transactions.transaction * * ## Example Usage: * ```kotlin - * val provider = MySQLPersistencyStorageProvider( + * val provider = MySQLPersistenceStorageProvider( * persistenceId = "my-agent", * database = Database.connect( * url = "jdbc:mysql://localhost:3306/mydb?useSSL=false&serverTimezone=UTC", @@ -36,14 +36,14 @@ import org.jetbrains.exposed.sql.transactions.transaction * * @constructor Initializes the MySQL persistence provider with an Exposed Database instance. */ -public class MySQLPersistencyStorageProvider( +public class MySQLPersistenceStorageProvider( persistenceId: String, database: Database, tableName: String = "agent_checkpoints", ttlSeconds: Long? = null, migrator: SQLPersistenceSchemaMigrator = MySqlPersistenceSchemaMigrator(database, tableName), - json: Json = PersistencyUtils.defaultCheckpointJson -) : ExposedPersistencyStorageProvider(persistenceId, database, tableName, ttlSeconds, migrator, json) { + json: Json = PersistenceUtils.defaultCheckpointJson +) : ExposedPersistenceStorageProvider(persistenceId, database, tableName, ttlSeconds, migrator, json) { override suspend fun transaction(block: suspend () -> T): T = newSuspendedTransaction(Dispatchers.IO, database) { diff --git a/agents/agents-features/agents-features-sql/src/jvmMain/kotlin/ai/koog/agents/features/sql/providers/PostgresPersistencyStorageProvider.kt b/agents/agents-features/agents-features-sql/src/jvmMain/kotlin/ai/koog/agents/features/sql/providers/PostgresPersistencyStorageProvider.kt index e4e662ee34..e743e78176 100644 --- a/agents/agents-features/agents-features-sql/src/jvmMain/kotlin/ai/koog/agents/features/sql/providers/PostgresPersistencyStorageProvider.kt +++ b/agents/agents-features/agents-features-sql/src/jvmMain/kotlin/ai/koog/agents/features/sql/providers/PostgresPersistencyStorageProvider.kt @@ -1,6 +1,6 @@ package ai.koog.agents.features.sql.providers -import ai.koog.agents.snapshot.providers.PersistencyUtils +import ai.koog.agents.snapshot.providers.PersistenceUtils import kotlinx.coroutines.Dispatchers import kotlinx.serialization.json.Json import org.jetbrains.exposed.sql.Database @@ -8,19 +8,19 @@ import org.jetbrains.exposed.sql.transactions.experimental.newSuspendedTransacti import org.jetbrains.exposed.sql.transactions.transaction /** - * PostgreSQL-specific implementation of [ExposedPersistencyStorageProvider] for managing + * PostgreSQL-specific implementation of [ExposedPersistenceStorageProvider] for managing * agent checkpoints in PostgreSQL databases. * * @constructor Initializes the PostgreSQL persistence provider with connection details. */ -public class PostgresPersistencyStorageProvider( +public class PostgresPersistenceStorageProvider( persistenceId: String, database: Database, tableName: String = "agent_checkpoints", ttlSeconds: Long? = null, migrator: SQLPersistenceSchemaMigrator = PostgresPersistenceSchemaMigrator(database, tableName), - json: Json = PersistencyUtils.defaultCheckpointJson -) : ExposedPersistencyStorageProvider(persistenceId, database, tableName, ttlSeconds, migrator, json) { + json: Json = PersistenceUtils.defaultCheckpointJson +) : ExposedPersistenceStorageProvider(persistenceId, database, tableName, ttlSeconds, migrator, json) { override suspend fun transaction(block: suspend () -> T): T = newSuspendedTransaction(Dispatchers.IO, database) { block() } diff --git a/agents/agents-features/agents-features-sql/src/jvmTest/kotlin/ai/koog/agents/features/sql/providers/H2PersistencyStorageProviderTest.kt b/agents/agents-features/agents-features-sql/src/jvmTest/kotlin/ai/koog/agents/features/sql/providers/H2PersistencyStorageProviderTest.kt index 15cdaa40bb..10c24d578e 100644 --- a/agents/agents-features/agents-features-sql/src/jvmTest/kotlin/ai/koog/agents/features/sql/providers/H2PersistencyStorageProviderTest.kt +++ b/agents/agents-features/agents-features-sql/src/jvmTest/kotlin/ai/koog/agents/features/sql/providers/H2PersistencyStorageProviderTest.kt @@ -19,10 +19,10 @@ import kotlin.test.assertNull @TestInstance(Lifecycle.PER_CLASS) @ExtendWith(DockerAvailableCondition::class) -class H2PersistencyStorageProviderTest { +class H2PersistenceStorageProviderTest { - private fun provider(ttlSeconds: Long? = null): H2PersistencyStorageProvider { - return H2PersistencyStorageProvider.inMemory( + private fun provider(ttlSeconds: Long? = null): H2PersistenceStorageProvider { + return H2PersistenceStorageProvider.inMemory( persistenceId = "h2-agent", databaseName = "h2_test_db", tableName = "agent_checkpoints_test", diff --git a/agents/agents-features/agents-features-sql/src/jvmTest/kotlin/ai/koog/agents/features/sql/providers/MySQLPersistencyStorageProviderTest.kt b/agents/agents-features/agents-features-sql/src/jvmTest/kotlin/ai/koog/agents/features/sql/providers/MySQLPersistencyStorageProviderTest.kt index d3c90fe1ab..8b8afeb4c2 100644 --- a/agents/agents-features/agents-features-sql/src/jvmTest/kotlin/ai/koog/agents/features/sql/providers/MySQLPersistencyStorageProviderTest.kt +++ b/agents/agents-features/agents-features-sql/src/jvmTest/kotlin/ai/koog/agents/features/sql/providers/MySQLPersistencyStorageProviderTest.kt @@ -24,7 +24,7 @@ import kotlin.test.assertNull @TestInstance(Lifecycle.PER_CLASS) @ExtendWith(DockerAvailableCondition::class) -class MySQLPersistencyStorageProviderTest { +class MySQLPersistenceStorageProviderTest { private lateinit var mysql: MySQLContainer<*> @@ -42,14 +42,14 @@ class MySQLPersistencyStorageProviderTest { mysql.stop() } - private fun provider(ttlSeconds: Long? = null): MySQLPersistencyStorageProvider { + private fun provider(ttlSeconds: Long? = null): MySQLPersistenceStorageProvider { val db: Database = Database.connect( url = mysql.jdbcUrl + "?useSSL=false&allowPublicKeyRetrieval=true&serverTimezone=UTC", driver = "com.mysql.cj.jdbc.Driver", user = mysql.username, password = mysql.password ) - return MySQLPersistencyStorageProvider( + return MySQLPersistenceStorageProvider( persistenceId = "mysql-agent", database = db, tableName = "agent_checkpoints_test", diff --git a/agents/agents-features/agents-features-sql/src/jvmTest/kotlin/ai/koog/agents/features/sql/providers/PostgresPersistencyStorageProviderTest.kt b/agents/agents-features/agents-features-sql/src/jvmTest/kotlin/ai/koog/agents/features/sql/providers/PostgresPersistencyStorageProviderTest.kt index d93e849c66..f9ab8099ba 100644 --- a/agents/agents-features/agents-features-sql/src/jvmTest/kotlin/ai/koog/agents/features/sql/providers/PostgresPersistencyStorageProviderTest.kt +++ b/agents/agents-features/agents-features-sql/src/jvmTest/kotlin/ai/koog/agents/features/sql/providers/PostgresPersistencyStorageProviderTest.kt @@ -23,7 +23,7 @@ import kotlin.test.assertNull @TestInstance(Lifecycle.PER_CLASS) @ExtendWith(DockerAvailableCondition::class) -class PostgresPersistencyStorageProviderTest { +class PostgresPersistenceStorageProviderTest { private lateinit var postgres: PostgreSQLContainer<*> @@ -41,14 +41,14 @@ class PostgresPersistencyStorageProviderTest { postgres.stop() } - private fun provider(ttlSeconds: Long? = null): PostgresPersistencyStorageProvider { + private fun provider(ttlSeconds: Long? = null): PostgresPersistenceStorageProvider { val db: Database = Database.connect( url = postgres.jdbcUrl, driver = "org.postgresql.Driver", user = postgres.username, password = postgres.password ) - return PostgresPersistencyStorageProvider( + return PostgresPersistenceStorageProvider( persistenceId = "pg-agent", database = db, tableName = "agent_checkpoints_test", diff --git a/agents/agents-features/agents-features-sql/src/jvmTest/kotlin/ai/koog/agents/features/sql/providers/SQLPersistenceProvidersTest.kt b/agents/agents-features/agents-features-sql/src/jvmTest/kotlin/ai/koog/agents/features/sql/providers/SQLPersistenceProvidersTest.kt index 1a505d8a9f..29c13467f6 100644 --- a/agents/agents-features/agents-features-sql/src/jvmTest/kotlin/ai/koog/agents/features/sql/providers/SQLPersistenceProvidersTest.kt +++ b/agents/agents-features/agents-features-sql/src/jvmTest/kotlin/ai/koog/agents/features/sql/providers/SQLPersistenceProvidersTest.kt @@ -24,7 +24,7 @@ class SQLPersistenceProvidersTest { @Test fun `test save and retrieve checkpoint`() = runBlocking { - val provider = H2PersistencyStorageProvider.inMemory( + val provider = H2PersistenceStorageProvider.inMemory( persistenceId = "test-agent", databaseName = "test_db" ) @@ -45,7 +45,7 @@ class SQLPersistenceProvidersTest { @Test fun `test multiple checkpoints and ordering`() = runBlocking { - val provider = H2PersistencyStorageProvider.inMemory( + val provider = H2PersistenceStorageProvider.inMemory( persistenceId = "test-agent", databaseName = "test_ordering" ) @@ -70,11 +70,11 @@ class SQLPersistenceProvidersTest { @Test fun `test persistence ID isolation`() = runBlocking { - val provider1 = H2PersistencyStorageProvider.inMemory( + val provider1 = H2PersistenceStorageProvider.inMemory( persistenceId = "agent-1", databaseName = "shared_db" ) - val provider2 = H2PersistencyStorageProvider.inMemory( + val provider2 = H2PersistenceStorageProvider.inMemory( persistenceId = "agent-2", databaseName = "shared_db" // Same database ) @@ -98,7 +98,7 @@ class SQLPersistenceProvidersTest { @Test fun `test TTL expiration`() = runBlocking { - val provider = H2PersistencyStorageProvider.inMemory( + val provider = H2PersistenceStorageProvider.inMemory( persistenceId = "ttl-test", databaseName = "ttl_db", ttlSeconds = 1 // 1 second TTL @@ -122,11 +122,11 @@ class SQLPersistenceProvidersTest { @Test fun `verify all providers can be instantiated`() { // H2 - assertNotNull(H2PersistencyStorageProvider.inMemory("test", "test_db")) + assertNotNull(H2PersistenceStorageProvider.inMemory("test", "test_db")) // PostgreSQL assertNotNull( - PostgresPersistencyStorageProvider( + PostgresPersistenceStorageProvider( persistenceId = "test", database = Database.connect( url = "jdbc:postgresql://localhost:5432/test", @@ -139,7 +139,7 @@ class SQLPersistenceProvidersTest { // MySQL assertNotNull( - MySQLPersistencyStorageProvider( + MySQLPersistenceStorageProvider( persistenceId = "test", database = Database.connect( url = "jdbc:mysql://localhost:3306/test", diff --git a/docs/docs/agent-persistence.md b/docs/docs/agent-persistence.md index 79f31c03ca..6638b4c2d3 100644 --- a/docs/docs/agent-persistence.md +++ b/docs/docs/agent-persistence.md @@ -1,11 +1,11 @@ # Agent Persistence -Agent persistence is a feature that provides checkpoint functionality for AI agents in the Koog framework. +Agent Persistence is a feature that provides checkpoint functionality for AI agents in the Koog framework. It lets you save and restore the state of an agent at specific points during execution, enabling capabilities such as: -- Resuming agent execution from a specific point. -- Rolling back to previous states. -- Persisting agent state across sessions. +- Resuming agent execution from a specific point +- Rolling back to previous states +- Persisting agent state across sessions ## Key concepts @@ -48,8 +48,8 @@ To use the Agent Persistence feature, add it to your agent's configuration: ```kotlin -install(Persistency) { - storage = InMemoryPersistencyStorageProvider("in-memory-storage") +install(Persistence) { + storage = InMemoryPersistenceStorageProvider("in-memory-storage") } ``` @@ -110,11 +110,11 @@ install(Persistency) { The framework includes the following built-in providers: -- `InMemoryPersistencyStorageProvider`: stores checkpoints in memory (lost when the application restarts). -- `FilePersistencyStorageProvider`: persists checkpoints to the file system. -- `NoPersistencyStorageProvider`: a no-op implementation that does not store checkpoints. This is the default provider. +- `InMemoryPersistenceStorageProvider`: stores checkpoints in memory (lost when the application restarts). +- `FilePersistenceStorageProvider`: persists checkpoints to the file system. +- `NoPersistenceStorageProvider`: a no-op implementation that does not store checkpoints. This is the default provider. -You can also implement custom storage providers by implementing the `PersistencyStorageProvider` interface. +You can also implement custom storage providers by implementing the `PersistenceStorageProvider` interface. For more information, see [Custom storage providers](#custom-storage-providers). ### Continuous persistence @@ -124,8 +124,8 @@ To activate continuous persistence, use the code below: ```kotlin -install(Persistency) { - enableAutomaticPersistency = true +install(Persistence) { + enableAutomaticPersistence = true } ``` @@ -157,7 +157,7 @@ To learn how to create a checkpoint at a specific point in your agent's executio ```kotlin suspend fun example(context: AIAgentContext, checkpointId: String) { // Roll back to a specific checkpoint - context.persistency().rollbackToCheckpoint(checkpointId, context) + context.persistence().rollbackToCheckpoint(checkpointId, context) // Or roll back to the latest checkpoint - context.persistency().rollbackToLatestCheckpoint(context) + context.persistence().rollbackToLatestCheckpoint(context) } ``` @@ -222,12 +222,12 @@ And now you would like to roll back to a checkpoint. Restoring the agent's state be sufficient to achieve the exact state of the world before the checkpoint. You should also restore the side-effects produced by your tool calls. In our example, this would mean removing `Maria` and `Daniel` from the database. -With Koog Persistence you can achieve that by providing a `RollbackToolRegistry` to `Persistency` feature config: +With Koog Persistence you can achieve that by providing a `RollbackToolRegistry` to `Persistence` feature config: ```kotlin -install(Persistency) { - enableAutomaticPersistency = true +install(Persistence) { + enableAutomaticPersistence = true rollbackToolRegistry = RollbackToolRegistry { // For every `createUser` tool call there will be a `removeUser` invocation in the reverse order // when rolling back to the desired execution point. @@ -269,17 +269,17 @@ The Agent Persistence feature provides convenient extension functions for workin import ai.koog.agents.core.agent.context.AIAgentContext import ai.koog.agents.example.exampleAgentPersistence05.inputData import ai.koog.agents.example.exampleAgentPersistence05.inputType -import ai.koog.agents.snapshot.feature.persistency -import ai.koog.agents.snapshot.feature.withPersistency +import ai.koog.agents.snapshot.feature.persistence +import ai.koog.agents.snapshot.feature.withPersistence --> ```kotlin suspend fun example(context: AIAgentContext) { // Access the checkpoint feature - val checkpointFeature = context.persistency() + val checkpointFeature = context.persistence() // Or perform an action with the checkpoint feature - context.withPersistency { ctx -> + context.withPersistence { ctx -> // 'this' is the checkpoint feature createCheckpoint( agentContext = ctx, @@ -297,11 +297,11 @@ suspend fun example(context: AIAgentContext) { ### Custom storage providers -You can implement custom storage providers by implementing the `PersistencyStorageProvider` interface: +You can implement custom storage providers by implementing the `PersistenceStorageProvider` interface: ```kotlin -class MyCustomStorageProvider : PersistencyStorageProvider { +class MyCustomStorageProvider : PersistenceStorageProvider { override suspend fun getCheckpoints(agentId: String): List { // Implementation } @@ -333,12 +333,12 @@ feature in your agent. ```kotlin -install(Persistency) { +install(Persistence) { storage = MyCustomStorageProvider() } ``` @@ -375,7 +375,7 @@ For advanced control, you can directly set the execution point of an agent: ## Motivation and Context ## Breaking Changes --- #### Type of the changes - [ ] New feature (non-breaking change which adds functionality) - [ ] Bug fix (non-breaking change which fixes an issue) - [ ] Breaking change (fix or feature that would cause existing functionality to change) - [ ] Documentation update - [ ] Tests improvement - [x] Refactoring #### Checklist - [ ] The pull request has a description of the proposed change - [ ] I read the [Contributing Guidelines](https://github.com/JetBrains/koog/blob/main/CONTRIBUTING.md) before opening the pull request - [ ] The pull request uses **`develop`** as the base branch - [ ] Tests for the changes have been added - [ ] All new and existing tests passed ##### Additional steps for pull requests adding a new feature - [ ] An issue describing the proposed change exists - [ ] The pull request includes a link to the issue - [ ] The change was discussed and approved in the issue - [ ] Docs have been added / updated --------- Co-authored-by: Konstantin Pavlov <1517853+kpavlov@users.noreply.github.com> Co-authored-by: Vadim Briliantov --- .../core/agent/context/AIAgentLLMContext.kt | 2 +- .../agents/core/feature/AIAgentPipeline.kt | 3 +- .../handler/strategy/StrategyEventContext.kt | 1 + .../agents-features-snapshot/Module.md | 2 +- .../agents-features-snapshot/README.md | 2 +- .../agents/snapshot/feature/Persistency.kt | 24 +++++----- .../InMemoryPersistencyStorageProvider.kt | 13 +++--- .../providers/NoPersistencyStorageProvider.kt | 5 ++- .../providers/PersistencyStorageProvider.kt | 6 +-- .../file/FilePersistencyStorageProvider.kt | 21 +++++---- .../file/JVMFilePersistencyStorageProvider.kt | 2 - .../src/jvmTest/kotlin/CheckpointsTests.kt | 18 ++++---- .../kotlin/NodeUniquenessCheckpointTest.kt | 2 +- .../kotlin/PersistencyRestoreStrategyTests.kt | 9 ++-- .../kotlin/PersistencyRunsTwiceTest.kt | 15 ++++--- .../kotlin/SimpleGraphCheckpointTest.kt | 10 ++--- .../jvmTest/kotlin/SubgraphCheckpointsTest.kt | 8 ++-- .../kotlin/SubgraphSetExecutionPointTest.kt | 12 ++--- .../src/jvmTest/kotlin/TestStrategies.kt | 2 +- .../FileAgentCheckpointStorageProviderTest.kt | 15 ++++--- .../providers/file/FileCheckpointsTests.kt | 12 ++--- .../SQLPersistencyStorageProvider.kt | 10 ++--- .../ExposedPersistencyStorageProvider.kt | 27 +++++------- .../providers/H2PersistencyStorageProvider.kt | 12 +---- .../MySQLPersistencyStorageProvider.kt | 3 +- .../PostgresPersistencyStorageProvider.kt | 3 +- .../H2PersistencyStorageProviderTest.kt | 41 ++++++++--------- .../MySQLPersistencyStorageProviderTest.kt | 41 ++++++++--------- .../PostgresPersistencyStorageProviderTest.kt | 41 ++++++++--------- .../providers/SQLPersistenceProvidersTest.kt | 44 +++++++++---------- docs/docs/agent-persistence.md | 12 ++--- examples/simple-examples/gradle.properties | 2 + .../example/snapshot/CheckpointExample.kt | 11 +++-- .../snapshot/FilePersistentAgentExample.kt | 6 +-- .../example/snapshot/SnapshotExample.kt | 2 +- .../snapshot/sql/SQLPersistentAgentExample.kt | 29 ++++++------ .../tests/agent/AIAgentIntegrationTest.kt | 15 +++---- 37 files changed, 237 insertions(+), 246 deletions(-) diff --git a/agents/agents-core/src/commonMain/kotlin/ai/koog/agents/core/agent/context/AIAgentLLMContext.kt b/agents/agents-core/src/commonMain/kotlin/ai/koog/agents/core/agent/context/AIAgentLLMContext.kt index 3215e57931..07f3501c56 100644 --- a/agents/agents-core/src/commonMain/kotlin/ai/koog/agents/core/agent/context/AIAgentLLMContext.kt +++ b/agents/agents-core/src/commonMain/kotlin/ai/koog/agents/core/agent/context/AIAgentLLMContext.kt @@ -50,7 +50,7 @@ public annotation class DetachedPromptExecutorAPI */ public class AIAgentLLMContext( tools: List, - public val toolRegistry: ToolRegistry = ToolRegistry.Companion.EMPTY, + public val toolRegistry: ToolRegistry = ToolRegistry.EMPTY, prompt: Prompt, model: LLModel, @property:DetachedPromptExecutorAPI diff --git a/agents/agents-core/src/commonMain/kotlin/ai/koog/agents/core/feature/AIAgentPipeline.kt b/agents/agents-core/src/commonMain/kotlin/ai/koog/agents/core/feature/AIAgentPipeline.kt index 365e6d8d93..e12770e18c 100644 --- a/agents/agents-core/src/commonMain/kotlin/ai/koog/agents/core/feature/AIAgentPipeline.kt +++ b/agents/agents-core/src/commonMain/kotlin/ai/koog/agents/core/feature/AIAgentPipeline.kt @@ -313,7 +313,8 @@ public abstract class AIAgentPipeline(public val clock: Clock) { strategy = strategy, feature = handler.feature, result = result, - resultType = resultType + resultType = resultType, + agentId = context.agentId ) handler.handleStrategyCompletedUnsafe(eventContext) } diff --git a/agents/agents-core/src/commonMain/kotlin/ai/koog/agents/core/feature/handler/strategy/StrategyEventContext.kt b/agents/agents-core/src/commonMain/kotlin/ai/koog/agents/core/feature/handler/strategy/StrategyEventContext.kt index 92dbeb19b8..a9b557925f 100644 --- a/agents/agents-core/src/commonMain/kotlin/ai/koog/agents/core/feature/handler/strategy/StrategyEventContext.kt +++ b/agents/agents-core/src/commonMain/kotlin/ai/koog/agents/core/feature/handler/strategy/StrategyEventContext.kt @@ -46,6 +46,7 @@ public class StrategyCompletedContext( public val feature: TFeature, public val result: Any?, public val resultType: KType, + public val agentId: String ) : StrategyEventContext { override val eventType: AgentLifecycleEventType = AgentLifecycleEventType.StrategyCompleted } diff --git a/agents/agents-features/agents-features-snapshot/Module.md b/agents/agents-features/agents-features-snapshot/Module.md index 3aa13ecaba..e8ed8d8637 100644 --- a/agents/agents-features/agents-features-snapshot/Module.md +++ b/agents/agents-features/agents-features-snapshot/Module.md @@ -35,7 +35,7 @@ val agent = AIAgent( ) { install(Persistence) { // Configure the storage provider - storage = InMemoryPersistenceStorageProvider("agent-persistence-id") + storage = InMemoryPersistenceStorageProvider() // Optional: enable automatic checkpoint creation after each node enableAutomaticPersistence = true diff --git a/agents/agents-features/agents-features-snapshot/README.md b/agents/agents-features/agents-features-snapshot/README.md index f06662292c..5a3db5d998 100644 --- a/agents/agents-features/agents-features-snapshot/README.md +++ b/agents/agents-features/agents-features-snapshot/README.md @@ -146,7 +146,7 @@ class MyCustomStorageProvider : AgentCheckpointStorageProvider { // Implementation } - override suspend fun saveCheckpoint(agentCheckpointData: AgentCheckpointData) { + override suspend fun saveCheckpoint(agentId: String, agentCheckpointData: AgentCheckpointData) { // Implementation } diff --git a/agents/agents-features/agents-features-snapshot/src/commonMain/kotlin/ai/koog/agents/snapshot/feature/Persistency.kt b/agents/agents-features/agents-features-snapshot/src/commonMain/kotlin/ai/koog/agents/snapshot/feature/Persistency.kt index 57b9512013..73b3b237bf 100644 --- a/agents/agents-features/agents-features-snapshot/src/commonMain/kotlin/ai/koog/agents/snapshot/feature/Persistency.kt +++ b/agents/agents-features/agents-features-snapshot/src/commonMain/kotlin/ai/koog/agents/snapshot/feature/Persistency.kt @@ -180,7 +180,7 @@ public class Persistence( pipeline.interceptStrategyCompleted(interceptContext) { ctx -> if (config.enableAutomaticPersistence && config.rollbackStrategy == RollbackStrategy.Default) { - ctx.feature.createTombstoneCheckpoint(ctx.feature.clock.now()) + ctx.feature.createTombstoneCheckpoint(ctx.agentId, ctx.feature.clock.now()) } } } @@ -228,7 +228,7 @@ public class Persistence( ) } - saveCheckpoint(checkpoint) + saveCheckpoint(agentContext.agentId, checkpoint) return checkpoint } @@ -242,9 +242,9 @@ public class Persistence( * @return The created tombstone checkpoint data. */ @InternalAgentsApi - public suspend fun createTombstoneCheckpoint(time: Instant): AgentCheckpointData { + public suspend fun createTombstoneCheckpoint(agentId: String, time: Instant): AgentCheckpointData { val checkpoint = tombstoneCheckpoint(time) - saveCheckpoint(checkpoint) + saveCheckpoint(agentId, checkpoint) return checkpoint } @@ -261,8 +261,8 @@ public class Persistence( * * @param checkpointData The checkpoint data to save */ - public suspend fun saveCheckpoint(checkpointData: AgentCheckpointData) { - persistenceStorageProvider.saveCheckpoint(checkpointData) + public suspend fun saveCheckpoint(agentId: String, checkpointData: AgentCheckpointData) { + persistenceStorageProvider.saveCheckpoint(agentId, checkpointData) } /** @@ -270,8 +270,8 @@ public class Persistence( * * @return The latest checkpoint data, or null if no checkpoint exists */ - public suspend fun getLatestCheckpoint(): AgentCheckpointData? = - persistenceStorageProvider.getLatestCheckpoint() + public suspend fun getLatestCheckpoint(agentId: String): AgentCheckpointData? = + persistenceStorageProvider.getLatestCheckpoint(agentId) /** * Retrieves a specific checkpoint by ID for the specified agent. @@ -279,8 +279,8 @@ public class Persistence( * @param checkpointId The ID of the checkpoint to retrieve * @return The checkpoint data with the specified ID, or null if not found */ - public suspend fun getCheckpointById(checkpointId: String): AgentCheckpointData? = - persistenceStorageProvider.getCheckpoints().firstOrNull { it.checkpointId == checkpointId } + public suspend fun getCheckpointById(agentId: String, checkpointId: String): AgentCheckpointData? = + persistenceStorageProvider.getCheckpoints(agentId).firstOrNull { it.checkpointId == checkpointId } /** * Sets the execution point of an agent to a specific state. @@ -320,7 +320,7 @@ public class Persistence( checkpointId: String, agentContext: AIAgentContext ): AgentCheckpointData? { - val checkpoint: AgentCheckpointData? = getCheckpointById(checkpointId) + val checkpoint: AgentCheckpointData? = getCheckpointById(agentContext.agentId, checkpointId) if (checkpoint != null) { agentContext.store( checkpoint.toAgentContextData(rollbackStrategy) { context -> @@ -377,7 +377,7 @@ public class Persistence( public suspend fun rollbackToLatestCheckpoint( agentContext: AIAgentContext ): AgentCheckpointData? { - val checkpoint: AgentCheckpointData? = getLatestCheckpoint() + val checkpoint: AgentCheckpointData? = getLatestCheckpoint(agentContext.agentId) if (checkpoint?.isTombstone() ?: true) { return null } diff --git a/agents/agents-features/agents-features-snapshot/src/commonMain/kotlin/ai/koog/agents/snapshot/providers/InMemoryPersistencyStorageProvider.kt b/agents/agents-features/agents-features-snapshot/src/commonMain/kotlin/ai/koog/agents/snapshot/providers/InMemoryPersistencyStorageProvider.kt index 7c6b8c5ae5..40095ad466 100644 --- a/agents/agents-features/agents-features-snapshot/src/commonMain/kotlin/ai/koog/agents/snapshot/providers/InMemoryPersistencyStorageProvider.kt +++ b/agents/agents-features/agents-features-snapshot/src/commonMain/kotlin/ai/koog/agents/snapshot/providers/InMemoryPersistencyStorageProvider.kt @@ -8,26 +8,25 @@ import kotlinx.coroutines.sync.withLock * In-memory implementation of [PersistenceStorageProvider]. * This provider stores snapshots in a mutable map. */ -public class InMemoryPersistenceStorageProvider(private val persistenceId: String) : PersistenceStorageProvider { +public class InMemoryPersistenceStorageProvider() : PersistenceStorageProvider { private val mutex = Mutex() private val snapshotMap = mutableMapOf>() - override suspend fun getCheckpoints(): List { + override suspend fun getCheckpoints(agentId: String): List { mutex.withLock { - return snapshotMap[persistenceId] ?: emptyList() + return snapshotMap[agentId] ?: emptyList() } } - override suspend fun saveCheckpoint(agentCheckpointData: AgentCheckpointData) { + override suspend fun saveCheckpoint(agentId: String, agentCheckpointData: AgentCheckpointData) { mutex.withLock { - val agentId = persistenceId snapshotMap[agentId] = (snapshotMap[agentId] ?: emptyList()) + agentCheckpointData } } - override suspend fun getLatestCheckpoint(): AgentCheckpointData? { + override suspend fun getLatestCheckpoint(agentId: String): AgentCheckpointData? { mutex.withLock { - return snapshotMap[persistenceId]?.maxBy { it.createdAt } + return snapshotMap[agentId]?.maxBy { it.createdAt } } } } diff --git a/agents/agents-features/agents-features-snapshot/src/commonMain/kotlin/ai/koog/agents/snapshot/providers/NoPersistencyStorageProvider.kt b/agents/agents-features/agents-features-snapshot/src/commonMain/kotlin/ai/koog/agents/snapshot/providers/NoPersistencyStorageProvider.kt index 08670abf2a..898f350564 100644 --- a/agents/agents-features/agents-features-snapshot/src/commonMain/kotlin/ai/koog/agents/snapshot/providers/NoPersistencyStorageProvider.kt +++ b/agents/agents-features/agents-features-snapshot/src/commonMain/kotlin/ai/koog/agents/snapshot/providers/NoPersistencyStorageProvider.kt @@ -9,17 +9,18 @@ import io.github.oshai.kotlinlogging.KotlinLogging public class NoPersistenceStorageProvider : PersistenceStorageProvider { private val logger = KotlinLogging.logger { } - override suspend fun getCheckpoints(): List { + override suspend fun getCheckpoints(agentId: String): List { return emptyList() } override suspend fun saveCheckpoint( + agentId: String, agentCheckpointData: AgentCheckpointData ) { logger.info { "Snapshot feature is not enabled in the agent. Snapshot will not be saved: $agentCheckpointData" } } - override suspend fun getLatestCheckpoint(): AgentCheckpointData? { + override suspend fun getLatestCheckpoint(agentId: String): AgentCheckpointData? { return null } } diff --git a/agents/agents-features/agents-features-snapshot/src/commonMain/kotlin/ai/koog/agents/snapshot/providers/PersistencyStorageProvider.kt b/agents/agents-features/agents-features-snapshot/src/commonMain/kotlin/ai/koog/agents/snapshot/providers/PersistencyStorageProvider.kt index c91ef7f527..64d8815c46 100644 --- a/agents/agents-features/agents-features-snapshot/src/commonMain/kotlin/ai/koog/agents/snapshot/providers/PersistencyStorageProvider.kt +++ b/agents/agents-features/agents-features-snapshot/src/commonMain/kotlin/ai/koog/agents/snapshot/providers/PersistencyStorageProvider.kt @@ -14,7 +14,7 @@ import ai.koog.agents.snapshot.feature.AgentCheckpointData public typealias PersistencyStorageProvider = PersistenceStorageProvider public interface PersistenceStorageProvider { - public suspend fun getCheckpoints(): List - public suspend fun saveCheckpoint(agentCheckpointData: AgentCheckpointData) - public suspend fun getLatestCheckpoint(): AgentCheckpointData? + public suspend fun getCheckpoints(agentId: String): List + public suspend fun saveCheckpoint(agentId: String, agentCheckpointData: AgentCheckpointData) + public suspend fun getLatestCheckpoint(agentId: String): AgentCheckpointData? } diff --git a/agents/agents-features/agents-features-snapshot/src/commonMain/kotlin/ai/koog/agents/snapshot/providers/file/FilePersistencyStorageProvider.kt b/agents/agents-features/agents-features-snapshot/src/commonMain/kotlin/ai/koog/agents/snapshot/providers/file/FilePersistencyStorageProvider.kt index a802a47327..a8c0d099ea 100644 --- a/agents/agents-features/agents-features-snapshot/src/commonMain/kotlin/ai/koog/agents/snapshot/providers/file/FilePersistencyStorageProvider.kt +++ b/agents/agents-features/agents-features-snapshot/src/commonMain/kotlin/ai/koog/agents/snapshot/providers/file/FilePersistencyStorageProvider.kt @@ -29,7 +29,6 @@ public typealias FilePersistencyStorageProvider = FilePersistenceStoragePr * @param root Root file path where the checkpoint storage will organize data. */ public open class FilePersistenceStorageProvider( - private val persistenceId: String, private val fs: FileSystemProvider.ReadWrite, private val root: Path, private val json: Json = PersistenceUtils.defaultCheckpointJson @@ -49,9 +48,9 @@ public open class FilePersistenceStorageProvider( /** * Directory for a specific agent's checkpoints */ - private suspend fun agentCheckpointsDir(): Path { + private suspend fun agentCheckpointsDir(agentId: String): Path { val checkpointsDir = checkpointsDir() - val agentDir = fs.joinPath(checkpointsDir, persistenceId) + val agentDir = fs.joinPath(checkpointsDir, agentId) if (!fs.exists(agentDir)) { fs.createDirectory(agentDir) } @@ -61,13 +60,13 @@ public open class FilePersistenceStorageProvider( /** * Get the path to a specific checkpoint file */ - private suspend fun checkpointPath(checkpointId: String): Path { - val agentDir = agentCheckpointsDir() + private suspend fun checkpointPath(agentId: String, checkpointId: String): Path { + val agentDir = agentCheckpointsDir(agentId) return fs.joinPath(agentDir, checkpointId) } - override suspend fun getCheckpoints(): List { - val agentDir = agentCheckpointsDir() + override suspend fun getCheckpoints(agentId: String): List { + val agentDir = agentCheckpointsDir(agentId) if (!fs.exists(agentDir)) { return emptyList() @@ -83,14 +82,14 @@ public open class FilePersistenceStorageProvider( } } - override suspend fun saveCheckpoint(agentCheckpointData: AgentCheckpointData) { - val checkpointPath = checkpointPath(agentCheckpointData.checkpointId) + override suspend fun saveCheckpoint(agentId: String, agentCheckpointData: AgentCheckpointData) { + val checkpointPath = checkpointPath(agentId, agentCheckpointData.checkpointId) val serialized = json.encodeToString(AgentCheckpointData.serializer(), agentCheckpointData) fs.writeText(checkpointPath, serialized) } - override suspend fun getLatestCheckpoint(): AgentCheckpointData? { - return getCheckpoints() + override suspend fun getLatestCheckpoint(agentId: String): AgentCheckpointData? { + return getCheckpoints(agentId) .maxByOrNull { it.createdAt } } } diff --git a/agents/agents-features/agents-features-snapshot/src/jvmMain/kotlin/ai/koog/agents/snapshot/providers/file/JVMFilePersistencyStorageProvider.kt b/agents/agents-features/agents-features-snapshot/src/jvmMain/kotlin/ai/koog/agents/snapshot/providers/file/JVMFilePersistencyStorageProvider.kt index c0cdd08f8d..bd00b4815c 100644 --- a/agents/agents-features/agents-features-snapshot/src/jvmMain/kotlin/ai/koog/agents/snapshot/providers/file/JVMFilePersistencyStorageProvider.kt +++ b/agents/agents-features/agents-features-snapshot/src/jvmMain/kotlin/ai/koog/agents/snapshot/providers/file/JVMFilePersistencyStorageProvider.kt @@ -30,11 +30,9 @@ public typealias JVMFilePersistencyStorageProvider = JVMFilePersistenceStoragePr */ public class JVMFilePersistenceStorageProvider( root: Path, - persistenceId: String, json: Json = PersistenceUtils.defaultCheckpointJson ) : FilePersistenceStorageProvider( fs = JVMFileSystemProvider.ReadWrite, root = root, - persistenceId = persistenceId, json = json ) diff --git a/agents/agents-features/agents-features-snapshot/src/jvmTest/kotlin/CheckpointsTests.kt b/agents/agents-features/agents-features-snapshot/src/jvmTest/kotlin/CheckpointsTests.kt index 1ce94d6fe8..8c2c7a7e00 100644 --- a/agents/agents-features/agents-features-snapshot/src/jvmTest/kotlin/CheckpointsTests.kt +++ b/agents/agents-features/agents-features-snapshot/src/jvmTest/kotlin/CheckpointsTests.kt @@ -105,7 +105,7 @@ class CheckpointsTests { toolRegistry = toolRegistry ) { install(Persistence) { - storage = InMemoryPersistenceStorageProvider("testAgentId") + storage = InMemoryPersistenceStorageProvider() } } @@ -125,7 +125,7 @@ class CheckpointsTests { toolRegistry = toolRegistry ) { install(Persistence) { - storage = InMemoryPersistenceStorageProvider("testAgentId") + storage = InMemoryPersistenceStorageProvider() } } @@ -262,7 +262,7 @@ class CheckpointsTests { toolRegistry = toolRegistry ) { install(Persistence) { - storage = InMemoryPersistenceStorageProvider("testAgentId") + storage = InMemoryPersistenceStorageProvider() } } @@ -295,7 +295,7 @@ class CheckpointsTests { toolRegistry = localToolRegistry ) { install(Persistence) { - storage = InMemoryPersistenceStorageProvider("agent-tools-rollback-1") + storage = InMemoryPersistenceStorageProvider() rollbackToolRegistry = RollbackToolRegistry { registerRollback(WriteKVTool, DeleteKVTool) } @@ -347,7 +347,7 @@ class CheckpointsTests { @Test fun testRestoreFromSingleCheckpoint() = runTest { - val checkpointStorageProvider = InMemoryPersistenceStorageProvider("testAgentId") + val checkpointStorageProvider = InMemoryPersistenceStorageProvider() val time = Clock.System.now() val agentId = "testAgentId" @@ -362,7 +362,7 @@ class CheckpointsTests { ) ) - checkpointStorageProvider.saveCheckpoint(testCheckpoint) + checkpointStorageProvider.saveCheckpoint(agentId, testCheckpoint) val agent = AIAgent( promptExecutor = getMockExecutor { }, @@ -388,7 +388,7 @@ class CheckpointsTests { @Test fun testRestoreFromLatestCheckpoint() = runTest { - val checkpointStorageProvider = InMemoryPersistenceStorageProvider("testAgentId") + val checkpointStorageProvider = InMemoryPersistenceStorageProvider() val time = Clock.System.now() val agentId = "testAgentId" @@ -414,8 +414,8 @@ class CheckpointsTests { ) ) - checkpointStorageProvider.saveCheckpoint(testCheckpoint) - checkpointStorageProvider.saveCheckpoint(testCheckpoint2) + checkpointStorageProvider.saveCheckpoint(agentId, testCheckpoint) + checkpointStorageProvider.saveCheckpoint(agentId, testCheckpoint2) val agent = AIAgent( promptExecutor = getMockExecutor { }, diff --git a/agents/agents-features/agents-features-snapshot/src/jvmTest/kotlin/NodeUniquenessCheckpointTest.kt b/agents/agents-features/agents-features-snapshot/src/jvmTest/kotlin/NodeUniquenessCheckpointTest.kt index a1cd44756e..8bd3803cd1 100644 --- a/agents/agents-features/agents-features-snapshot/src/jvmTest/kotlin/NodeUniquenessCheckpointTest.kt +++ b/agents/agents-features/agents-features-snapshot/src/jvmTest/kotlin/NodeUniquenessCheckpointTest.kt @@ -105,7 +105,7 @@ class NodeUniquenessCheckpointTest { ) { // Install the AgentCheckpoint feature install(Persistence) { - storage = InMemoryPersistenceStorageProvider("testAgentId") + storage = InMemoryPersistenceStorageProvider() } } diff --git a/agents/agents-features/agents-features-snapshot/src/jvmTest/kotlin/PersistencyRestoreStrategyTests.kt b/agents/agents-features/agents-features-snapshot/src/jvmTest/kotlin/PersistencyRestoreStrategyTests.kt index beb716361f..2dde64cdbe 100644 --- a/agents/agents-features/agents-features-snapshot/src/jvmTest/kotlin/PersistencyRestoreStrategyTests.kt +++ b/agents/agents-features/agents-features-snapshot/src/jvmTest/kotlin/PersistencyRestoreStrategyTests.kt @@ -19,7 +19,9 @@ import org.junit.jupiter.api.Test class PersistenceRestoreStrategyTests { @Test fun `rollback Default resumes from checkpoint node`() = runTest { - val provider = InMemoryPersistenceStorageProvider("persistence-restore-default") + val provider = InMemoryPersistenceStorageProvider() + + val agentId = "persistency-restore-default" val checkpoint = AgentCheckpointData( checkpointId = "chk-1", @@ -29,7 +31,7 @@ class PersistenceRestoreStrategyTests { messageHistory = listOf(Message.Assistant("History Before", ResponseMetaInfo(Clock.System.now()))), ) - provider.saveCheckpoint(checkpoint) + provider.saveCheckpoint(agentId, checkpoint) val agent = AIAgent( promptExecutor = getMockExecutor { }, @@ -39,6 +41,7 @@ class PersistenceRestoreStrategyTests { model = OllamaModels.Meta.LLAMA_3_2, maxAgentIterations = 10 ), + id = agentId ) { install(Persistence) { storage = provider @@ -59,7 +62,7 @@ class PersistenceRestoreStrategyTests { @Test fun `rollback MessageHistoryOnly starts from beginning`() = runTest { - val provider = InMemoryPersistenceStorageProvider("persistence-restore-history-only") + val provider = InMemoryPersistenceStorageProvider() val agentService = AIAgentService( promptExecutor = getMockExecutor { }, diff --git a/agents/agents-features/agents-features-snapshot/src/jvmTest/kotlin/PersistencyRunsTwiceTest.kt b/agents/agents-features/agents-features-snapshot/src/jvmTest/kotlin/PersistencyRunsTwiceTest.kt index 6c71b0e209..1c5a120399 100644 --- a/agents/agents-features/agents-features-snapshot/src/jvmTest/kotlin/PersistencyRunsTwiceTest.kt +++ b/agents/agents-features/agents-features-snapshot/src/jvmTest/kotlin/PersistencyRunsTwiceTest.kt @@ -17,7 +17,7 @@ class PersistenceRunsTwiceTest { @Test fun `agent runs to end and on second run starts from beginning again`() = runTest { // Arrange - val provider = InMemoryPersistenceStorageProvider("persistence-test-agent") + val provider = InMemoryPersistenceStorageProvider() val testCollector = TestAgentLogsCollector() @@ -39,6 +39,7 @@ class PersistenceRunsTwiceTest { } val firstAgent = agentService.createAgent(id = "SAME_ID") + val agentId1 = "SAME_ID" // Act: first run firstAgent.run("Start the test") @@ -53,11 +54,11 @@ class PersistenceRunsTwiceTest { await.until { runBlocking { - provider.getLatestCheckpoint()?.isTombstone() == true + provider.getLatestCheckpoint(agentId1)?.isTombstone() == true } } - val firstCheckpoint = provider.getLatestCheckpoint() + val firstCheckpoint = provider.getLatestCheckpoint(agentId1) val secondAgent = agentService.createAgent(id = "SAME_ID") @@ -67,7 +68,7 @@ class PersistenceRunsTwiceTest { // And still ends with a tombstone as the latest checkpoint await.until { runBlocking { - val latest2 = provider.getLatestCheckpoint() + val latest2 = provider.getLatestCheckpoint(agentId1) latest2?.isTombstone() == true latest2 != firstCheckpoint } @@ -76,7 +77,7 @@ class PersistenceRunsTwiceTest { @Test fun `agent fails on the first run and second run running successfully`() = runTest { - val provider = InMemoryPersistenceStorageProvider("persistence-test-agent") + val provider = InMemoryPersistenceStorageProvider() val testCollector = TestAgentLogsCollector() @@ -112,7 +113,7 @@ class PersistenceRunsTwiceTest { await.until { runBlocking { - provider.getCheckpoints().size == 2 + provider.getCheckpoints(agentId).size == 2 } } @@ -131,7 +132,7 @@ class PersistenceRunsTwiceTest { await.until { runBlocking { - provider.getCheckpoints().filter { !it.isTombstone() }.size == 4 + provider.getCheckpoints(agentId).filter { !it.isTombstone() }.size == 4 } } } diff --git a/agents/agents-features/agents-features-snapshot/src/jvmTest/kotlin/SimpleGraphCheckpointTest.kt b/agents/agents-features/agents-features-snapshot/src/jvmTest/kotlin/SimpleGraphCheckpointTest.kt index 4de2dfce04..70e09df945 100644 --- a/agents/agents-features/agents-features-snapshot/src/jvmTest/kotlin/SimpleGraphCheckpointTest.kt +++ b/agents/agents-features/agents-features-snapshot/src/jvmTest/kotlin/SimpleGraphCheckpointTest.kt @@ -52,7 +52,7 @@ class SimpleGraphCheckpointTest { toolRegistry = toolRegistry ) { install(Persistence) { - storage = InMemoryPersistenceStorageProvider("testAgentId") + storage = InMemoryPersistenceStorageProvider() } } @@ -78,7 +78,7 @@ class SimpleGraphCheckpointTest { @Test fun `test agent creates and saves checkpoints`() = runTest { // Create a snapshot provider to store checkpoints - val checkpointStorageProvider = InMemoryPersistenceStorageProvider("testAgentId") + val checkpointStorageProvider = InMemoryPersistenceStorageProvider() // Create a mock executor for testing val mockExecutor: PromptExecutor = getMockExecutor { @@ -115,14 +115,14 @@ class SimpleGraphCheckpointTest { agent.run("Start the test") // Verify that a checkpoint was created and saved - val checkpoint = checkpointStorageProvider.getCheckpoints().firstOrNull() + val checkpoint = checkpointStorageProvider.getCheckpoints(agent.id).firstOrNull() assertNotNull(checkpoint, "No checkpoint was created") assertEquals("checkpointNode", checkpoint?.nodeId, "Checkpoint has incorrect node ID") } @Test fun test_checkpoint_persists_history() = runTest { - val checkpointStorageProvider = InMemoryPersistenceStorageProvider("testAgentId") + val checkpointStorageProvider = InMemoryPersistenceStorageProvider() val mockExecutor: PromptExecutor = getMockExecutor { // No specific mock responses needed for this test @@ -157,7 +157,7 @@ class SimpleGraphCheckpointTest { agent.run("Start the test") // Verify that a checkpoint was created and saved - val checkpoint = checkpointStorageProvider.getCheckpoints().firstOrNull() + val checkpoint = checkpointStorageProvider.getCheckpoints(agent.id).firstOrNull() if (checkpoint == null) { error("checkpoint is null") } diff --git a/agents/agents-features/agents-features-snapshot/src/jvmTest/kotlin/SubgraphCheckpointsTest.kt b/agents/agents-features/agents-features-snapshot/src/jvmTest/kotlin/SubgraphCheckpointsTest.kt index 54dbf84b75..173ac34671 100644 --- a/agents/agents-features/agents-features-snapshot/src/jvmTest/kotlin/SubgraphCheckpointsTest.kt +++ b/agents/agents-features/agents-features-snapshot/src/jvmTest/kotlin/SubgraphCheckpointsTest.kt @@ -37,7 +37,7 @@ class SubgraphCheckpointsTest { toolRegistry = toolRegistry ) { install(Persistence) { - storage = InMemoryPersistenceStorageProvider("testAgentId") + storage = InMemoryPersistenceStorageProvider() } } @@ -64,7 +64,7 @@ class SubgraphCheckpointsTest { toolRegistry = toolRegistry ) { install(Persistence) { - storage = InMemoryPersistenceStorageProvider("testAgentId") + storage = InMemoryPersistenceStorageProvider() } } @@ -91,7 +91,7 @@ class SubgraphCheckpointsTest { toolRegistry = toolRegistry ) { install(Persistence) { - storage = InMemoryPersistenceStorageProvider("testAgentId") + storage = InMemoryPersistenceStorageProvider() } } @@ -119,7 +119,7 @@ class SubgraphCheckpointsTest { toolRegistry = toolRegistry ) { install(Persistence) { - storage = InMemoryPersistenceStorageProvider("testAgentId") + storage = InMemoryPersistenceStorageProvider() } } diff --git a/agents/agents-features/agents-features-snapshot/src/jvmTest/kotlin/SubgraphSetExecutionPointTest.kt b/agents/agents-features/agents-features-snapshot/src/jvmTest/kotlin/SubgraphSetExecutionPointTest.kt index c8eb5b6017..5ffdb66376 100644 --- a/agents/agents-features/agents-features-snapshot/src/jvmTest/kotlin/SubgraphSetExecutionPointTest.kt +++ b/agents/agents-features/agents-features-snapshot/src/jvmTest/kotlin/SubgraphSetExecutionPointTest.kt @@ -33,7 +33,7 @@ class SubgraphSetExecutionPointTest { toolRegistry = toolRegistry ) { install(Persistence) { - storage = InMemoryPersistenceStorageProvider("testAgentId") + storage = InMemoryPersistenceStorageProvider() } } @@ -57,7 +57,7 @@ class SubgraphSetExecutionPointTest { toolRegistry = toolRegistry ) { install(Persistence) { - storage = InMemoryPersistenceStorageProvider("testAgentId") + storage = InMemoryPersistenceStorageProvider() } } @@ -85,7 +85,7 @@ class SubgraphSetExecutionPointTest { toolRegistry = toolRegistry ) { install(Persistence) { - storage = InMemoryPersistenceStorageProvider("testAgentId") + storage = InMemoryPersistenceStorageProvider() } } @@ -110,7 +110,7 @@ class SubgraphSetExecutionPointTest { toolRegistry = toolRegistry ) { install(Persistence) { - storage = InMemoryPersistenceStorageProvider("testAgentId") + storage = InMemoryPersistenceStorageProvider() } } @@ -137,7 +137,7 @@ class SubgraphSetExecutionPointTest { toolRegistry = toolRegistry ) { install(Persistence) { - storage = InMemoryPersistenceStorageProvider("testAgentId") + storage = InMemoryPersistenceStorageProvider() } } @@ -164,7 +164,7 @@ class SubgraphSetExecutionPointTest { toolRegistry = toolRegistry ) { install(Persistence) { - storage = InMemoryPersistenceStorageProvider("testAgentId") + storage = InMemoryPersistenceStorageProvider() } } diff --git a/agents/agents-features/agents-features-snapshot/src/jvmTest/kotlin/TestStrategies.kt b/agents/agents-features/agents-features-snapshot/src/jvmTest/kotlin/TestStrategies.kt index 53f3ac6f84..c28aff1ebc 100644 --- a/agents/agents-features/agents-features-snapshot/src/jvmTest/kotlin/TestStrategies.kt +++ b/agents/agents-features/agents-features-snapshot/src/jvmTest/kotlin/TestStrategies.kt @@ -183,7 +183,7 @@ private fun AIAgentSubgraphBuilderBase<*, *>.nodeCreateCheckpoint( "snapshot-id" ) - saveCheckpoint(checkpoint ?: error("Checkpoint creation failed")) + saveCheckpoint(ctx.agentId, checkpoint ?: error("Checkpoint creation failed")) return@withPersistence "$input\nSnapshot created" } diff --git a/agents/agents-features/agents-features-snapshot/src/jvmTest/kotlin/ai/koog/agents/snapshot/providers/file/FileAgentCheckpointStorageProviderTest.kt b/agents/agents-features/agents-features-snapshot/src/jvmTest/kotlin/ai/koog/agents/snapshot/providers/file/FileAgentCheckpointStorageProviderTest.kt index 3c5f87d3e3..9574974308 100644 --- a/agents/agents-features/agents-features-snapshot/src/jvmTest/kotlin/ai/koog/agents/snapshot/providers/file/FileAgentCheckpointStorageProviderTest.kt +++ b/agents/agents-features/agents-features-snapshot/src/jvmTest/kotlin/ai/koog/agents/snapshot/providers/file/FileAgentCheckpointStorageProviderTest.kt @@ -22,7 +22,7 @@ class FileAgentCheckpointStorageProviderTest { @BeforeTest fun setup() { tempDir = Files.createTempDirectory("checkpoint-test") - provider = JVMFilePersistenceStorageProvider(tempDir, "testAgentId") + provider = JVMFilePersistenceStorageProvider(tempDir) } @AfterTest @@ -54,11 +54,12 @@ class FileAgentCheckpointStorageProviderTest { messageHistory = messageHistory ) + val agentId = "testAgentId" // Save the checkpoint - provider.saveCheckpoint(checkpoint) + provider.saveCheckpoint(agentId, checkpoint) // Retrieve all checkpoints for the agent - val checkpoints = provider.getCheckpoints() + val checkpoints = provider.getCheckpoints(agentId) assertEquals(1, checkpoints.size, "Should have one checkpoint") // Verify the retrieved checkpoint @@ -80,7 +81,7 @@ class FileAgentCheckpointStorageProviderTest { assertEquals(originalAssistantMsg.content, retrievedAssistantMsg.content) // Test getLatestCheckpoint - val latestCheckpoint = provider.getLatestCheckpoint() + val latestCheckpoint = provider.getLatestCheckpoint(agentId) assertNotNull(latestCheckpoint, "Latest checkpoint should not be null") assertEquals(checkpointId, latestCheckpoint.checkpointId) @@ -96,15 +97,15 @@ class FileAgentCheckpointStorageProviderTest { ) // Save the later checkpoint - provider.saveCheckpoint(laterCheckpoint) + provider.saveCheckpoint(agentId, laterCheckpoint) // Verify that getLatestCheckpoint returns the later checkpoint - val newLatestCheckpoint = provider.getLatestCheckpoint() + val newLatestCheckpoint = provider.getLatestCheckpoint(agentId) assertNotNull(newLatestCheckpoint, "New latest checkpoint should not be null") assertEquals(laterCheckpointId, newLatestCheckpoint.checkpointId) // Verify that getCheckpoints returns both checkpoints - val allCheckpoints = provider.getCheckpoints() + val allCheckpoints = provider.getCheckpoints(agentId) assertEquals(2, allCheckpoints.size, "Should have two checkpoints") assertTrue(allCheckpoints.any { it.checkpointId == checkpointId }) assertTrue(allCheckpoints.any { it.checkpointId == laterCheckpointId }) diff --git a/agents/agents-features/agents-features-snapshot/src/jvmTest/kotlin/ai/koog/agents/snapshot/providers/file/FileCheckpointsTests.kt b/agents/agents-features/agents-features-snapshot/src/jvmTest/kotlin/ai/koog/agents/snapshot/providers/file/FileCheckpointsTests.kt index 7c7ef48b3d..fc531a529d 100644 --- a/agents/agents-features/agents-features-snapshot/src/jvmTest/kotlin/ai/koog/agents/snapshot/providers/file/FileCheckpointsTests.kt +++ b/agents/agents-features/agents-features-snapshot/src/jvmTest/kotlin/ai/koog/agents/snapshot/providers/file/FileCheckpointsTests.kt @@ -49,7 +49,7 @@ class FileCheckpointsTests { @BeforeTest fun setup() { tempDir = Files.createTempDirectory("agent-checkpoint-test") - provider = JVMFilePersistenceStorageProvider(tempDir, "testAgentId") + provider = JVMFilePersistenceStorageProvider(tempDir) } @AfterTest @@ -86,7 +86,7 @@ class FileCheckpointsTests { ) // Verify that the checkpoint was saved to the file system - val checkpoints = provider.getCheckpoints().filter { !it.isTombstone() } + val checkpoints = provider.getCheckpoints(agentId).filter { !it.isTombstone() } assertEquals(1, checkpoints.size, "Should have one checkpoint") assertEquals("checkpointId", checkpoints.first().checkpointId) } @@ -129,7 +129,7 @@ class FileCheckpointsTests { ) ) - provider.saveCheckpoint(testCheckpoint) + provider.saveCheckpoint(agentId, testCheckpoint) val agent = AIAgent( promptExecutor = getMockExecutor { }, @@ -180,8 +180,8 @@ class FileCheckpointsTests { ) ) - provider.saveCheckpoint(testCheckpoint) - provider.saveCheckpoint(testCheckpoint2) + provider.saveCheckpoint(agentId, testCheckpoint) + provider.saveCheckpoint(agentId, testCheckpoint2) val agent = AIAgent( promptExecutor = getMockExecutor { }, @@ -226,7 +226,7 @@ class FileCheckpointsTests { agent.run("Start the test") // Verify that checkpoints were automatically created - val checkpoints = provider.getCheckpoints() + val checkpoints = provider.getCheckpoints(agentId) assertTrue(checkpoints.isNotEmpty(), "Should have automatically created checkpoints") } } diff --git a/agents/agents-features/agents-features-sql/src/commonMain/kotlin/ai/koog/agents/features/sql/providers/SQLPersistencyStorageProvider.kt b/agents/agents-features/agents-features-sql/src/commonMain/kotlin/ai/koog/agents/features/sql/providers/SQLPersistencyStorageProvider.kt index b8eb3e3db6..1d2cb862c5 100644 --- a/agents/agents-features/agents-features-sql/src/commonMain/kotlin/ai/koog/agents/features/sql/providers/SQLPersistencyStorageProvider.kt +++ b/agents/agents-features/agents-features-sql/src/commonMain/kotlin/ai/koog/agents/features/sql/providers/SQLPersistencyStorageProvider.kt @@ -28,12 +28,10 @@ import kotlinx.datetime.Instant * Implementations must ensure thread-safe database access, typically through connection pooling. * * @constructor Initializes the SQL persistence provider. - * @param persistenceId Unique identifier for this agent's persistence data * @param tableName Name of the table to store checkpoints (default: "agent_checkpoints") * @param ttlSeconds Optional TTL for checkpoint entries in seconds (null = no expiration) */ public abstract class SQLPersistenceStorageProvider( - protected val persistenceId: String, protected val tableName: String = "agent_checkpoints", protected val ttlSeconds: Long? = null, protected val migrator: SQLPersistenceSchemaMigrator @@ -71,15 +69,15 @@ public abstract class SQLPersistenceStorageProvider( /** * Deletes a specific checkpoint by ID */ - public abstract suspend fun deleteCheckpoint(checkpointId: String) + public abstract suspend fun deleteCheckpoint(agentId: String, checkpointId: String) /** - * Deletes all checkpoints for this persistence ID + * Deletes all checkpoints for this agent ID */ - public abstract suspend fun deleteAllCheckpoints() + public abstract suspend fun deleteAllCheckpoints(agentId: String) /** * Gets the total number of checkpoints stored */ - public abstract suspend fun getCheckpointCount(): Long + public abstract suspend fun getCheckpointCount(agentId: String): Long } diff --git a/agents/agents-features/agents-features-sql/src/jvmMain/kotlin/ai/koog/agents/features/sql/providers/ExposedPersistencyStorageProvider.kt b/agents/agents-features/agents-features-sql/src/jvmMain/kotlin/ai/koog/agents/features/sql/providers/ExposedPersistencyStorageProvider.kt index d1d5c38315..c80134c0e1 100644 --- a/agents/agents-features/agents-features-sql/src/jvmMain/kotlin/ai/koog/agents/features/sql/providers/ExposedPersistencyStorageProvider.kt +++ b/agents/agents-features/agents-features-sql/src/jvmMain/kotlin/ai/koog/agents/features/sql/providers/ExposedPersistencyStorageProvider.kt @@ -107,21 +107,18 @@ public open class CheckpointsTable(tableName: String) : Table(tableName) { * - When TTL is not configured (ttlSeconds = null), no TTL processing occurs * * @constructor Initializes the Exposed persistence provider. - * @param persistenceId Unique identifier for this agent's persistence data * @param database The Exposed Database instance to use * @param tableName Name of the table to store checkpoints (default: "agent_checkpoints") * @param ttlSeconds Optional TTL for checkpoint entries in seconds (null = no expiration) */ @Suppress("MissingKDocForPublicAPI") public abstract class ExposedPersistenceStorageProvider( - persistenceId: String, protected val database: Database, tableName: String = "agent_checkpoints", ttlSeconds: Long? = null, migrator: SQLPersistenceSchemaMigrator, private val json: Json = PersistenceUtils.defaultCheckpointJson ) : SQLPersistenceStorageProvider( - persistenceId = persistenceId, tableName = tableName, ttlSeconds = ttlSeconds, migrator @@ -179,12 +176,12 @@ public abstract class ExposedPersistenceStorageProvider( } } - override suspend fun getCheckpoints(): List { + override suspend fun getCheckpoints(agentId: String): List { return transaction { checkpointsTable .select(checkpointsTable.checkpointJson) .where { - checkpointsTable.persistenceId eq this@ExposedPersistenceStorageProvider.persistenceId + checkpointsTable.persistenceId eq agentId } .orderBy(checkpointsTable.createdAt to SortOrder.ASC) .mapNotNull { row -> @@ -195,14 +192,14 @@ public abstract class ExposedPersistenceStorageProvider( } } - override suspend fun saveCheckpoint(agentCheckpointData: AgentCheckpointData) { + override suspend fun saveCheckpoint(agentId: String, agentCheckpointData: AgentCheckpointData) { val checkpointJson = json.encodeToString(agentCheckpointData) val ttlTimestamp = calculateTtlTimestamp(agentCheckpointData.createdAt) transaction { // Use upsert for idempotent saves checkpointsTable.upsert { - it[checkpointsTable.persistenceId] = this@ExposedPersistenceStorageProvider.persistenceId + it[checkpointsTable.persistenceId] = agentId it[checkpointsTable.checkpointId] = agentCheckpointData.checkpointId it[checkpointsTable.createdAt] = agentCheckpointData.createdAt.toEpochMilliseconds() it[checkpointsTable.checkpointJson] = checkpointJson @@ -211,12 +208,12 @@ public abstract class ExposedPersistenceStorageProvider( } } - override suspend fun getLatestCheckpoint(): AgentCheckpointData? { + override suspend fun getLatestCheckpoint(agentId: String): AgentCheckpointData? { return transaction { checkpointsTable .select(checkpointsTable.checkpointJson) .where { - checkpointsTable.persistenceId eq this@ExposedPersistenceStorageProvider.persistenceId + checkpointsTable.persistenceId eq agentId } .orderBy(checkpointsTable.createdAt to SortOrder.DESC) .limit(1) @@ -228,27 +225,27 @@ public abstract class ExposedPersistenceStorageProvider( } } - override suspend fun deleteCheckpoint(checkpointId: String) { + override suspend fun deleteCheckpoint(agentId: String, checkpointId: String) { transaction { checkpointsTable.deleteWhere { - (checkpointsTable.persistenceId eq this@ExposedPersistenceStorageProvider.persistenceId) and + (checkpointsTable.persistenceId eq agentId) and (checkpointsTable.checkpointId eq checkpointId) } } } - override suspend fun deleteAllCheckpoints() { + override suspend fun deleteAllCheckpoints(agentId: String) { transaction { checkpointsTable.deleteWhere { - checkpointsTable.persistenceId eq this@ExposedPersistenceStorageProvider.persistenceId + checkpointsTable.persistenceId eq agentId } } } - override suspend fun getCheckpointCount(): Long { + override suspend fun getCheckpointCount(agentId: String): Long { return transaction { checkpointsTable.selectAll().where { - checkpointsTable.persistenceId eq this@ExposedPersistenceStorageProvider.persistenceId + checkpointsTable.persistenceId eq agentId }.count() } } diff --git a/agents/agents-features/agents-features-sql/src/jvmMain/kotlin/ai/koog/agents/features/sql/providers/H2PersistencyStorageProvider.kt b/agents/agents-features/agents-features-sql/src/jvmMain/kotlin/ai/koog/agents/features/sql/providers/H2PersistencyStorageProvider.kt index 30833681e5..647ac9b9ec 100644 --- a/agents/agents-features/agents-features-sql/src/jvmMain/kotlin/ai/koog/agents/features/sql/providers/H2PersistencyStorageProvider.kt +++ b/agents/agents-features/agents-features-sql/src/jvmMain/kotlin/ai/koog/agents/features/sql/providers/H2PersistencyStorageProvider.kt @@ -12,13 +12,12 @@ import org.jetbrains.exposed.sql.transactions.transaction * agent checkpoints in H2 databases. */ public class H2PersistenceStorageProvider( - persistenceId: String, database: Database, tableName: String = "agent_checkpoints", ttlSeconds: Long? = null, migrator: SQLPersistenceSchemaMigrator = H2PersistenceSchemaMigrator(database, tableName), json: Json = PersistenceUtils.defaultCheckpointJson -) : ExposedPersistenceStorageProvider(persistenceId, database, tableName, ttlSeconds, migrator, json) { +) : ExposedPersistenceStorageProvider(database, tableName, ttlSeconds, migrator, json) { public override suspend fun transaction(block: suspend () -> T): T = newSuspendedTransaction(Dispatchers.IO, database) { @@ -32,20 +31,17 @@ public class H2PersistenceStorageProvider( * Data is lost when the JVM shuts down. * Perfect for testing and temporary caching. * - * @param persistenceId Unique identifier for this agent's persistence data * @param databaseName Name of the in-memory database * @param options Additional H2 options (e.g., "DB_CLOSE_DELAY=-1") * @param tableName Name of the table to store checkpoints * @param ttlSeconds Optional TTL for checkpoint entries in seconds */ public fun inMemory( - persistenceId: String, databaseName: String = "test", options: String = "DB_CLOSE_DELAY=-1", tableName: String = "agent_checkpoints", ttlSeconds: Long? = null ): H2PersistenceStorageProvider = H2PersistenceStorageProvider( - persistenceId = persistenceId, database = Database.connect("jdbc:h2:mem:$databaseName;$options"), tableName = tableName, ttlSeconds = ttlSeconds @@ -56,20 +52,17 @@ public class H2PersistenceStorageProvider( * Data is persisted to a file on disk. * Good balance between performance and persistence. * - * @param persistenceId Unique identifier for this agent's persistence data * @param filePath Path to the database file (without .mv.db extension) * @param options Additional H2 options * @param tableName Name of the table to store checkpoints * @param ttlSeconds Optional TTL for checkpoint entries in seconds */ public fun fileBased( - persistenceId: String, filePath: String, options: String = "", tableName: String = "agent_checkpoints", ttlSeconds: Long? = null ): H2PersistenceStorageProvider = H2PersistenceStorageProvider( - persistenceId = persistenceId, database = Database.connect( if (options.isNotEmpty()) { "jdbc:h2:file:$filePath;$options" @@ -85,19 +78,16 @@ public class H2PersistenceStorageProvider( * Creates an H2 provider with PostgreSQL compatibility mode. * Useful when migrating from PostgreSQL or for compatibility testing. * - * @param persistenceId Unique identifier for this agent's persistence data * @param databasePath Path to database (memory or file) * @param tableName Name of the table to store checkpoints * @param ttlSeconds Optional TTL for checkpoint entries in seconds */ public fun postgresCompatible( - persistenceId: String, databasePath: String = "mem:test", tableName: String = "agent_checkpoints", ttlSeconds: Long? = null ): H2PersistenceStorageProvider { return H2PersistenceStorageProvider( - persistenceId = persistenceId, database = Database.connect("jdbc:h2:$databasePath;MODE=PostgreSQL;DATABASE_TO_LOWER=TRUE"), tableName = tableName, ttlSeconds = ttlSeconds diff --git a/agents/agents-features/agents-features-sql/src/jvmMain/kotlin/ai/koog/agents/features/sql/providers/MySQLPersistencyStorageProvider.kt b/agents/agents-features/agents-features-sql/src/jvmMain/kotlin/ai/koog/agents/features/sql/providers/MySQLPersistencyStorageProvider.kt index 9a39443b7b..fe076a92e0 100644 --- a/agents/agents-features/agents-features-sql/src/jvmMain/kotlin/ai/koog/agents/features/sql/providers/MySQLPersistencyStorageProvider.kt +++ b/agents/agents-features/agents-features-sql/src/jvmMain/kotlin/ai/koog/agents/features/sql/providers/MySQLPersistencyStorageProvider.kt @@ -37,13 +37,12 @@ import org.jetbrains.exposed.sql.transactions.transaction * @constructor Initializes the MySQL persistence provider with an Exposed Database instance. */ public class MySQLPersistenceStorageProvider( - persistenceId: String, database: Database, tableName: String = "agent_checkpoints", ttlSeconds: Long? = null, migrator: SQLPersistenceSchemaMigrator = MySqlPersistenceSchemaMigrator(database, tableName), json: Json = PersistenceUtils.defaultCheckpointJson -) : ExposedPersistenceStorageProvider(persistenceId, database, tableName, ttlSeconds, migrator, json) { +) : ExposedPersistenceStorageProvider(database, tableName, ttlSeconds, migrator, json) { override suspend fun transaction(block: suspend () -> T): T = newSuspendedTransaction(Dispatchers.IO, database) { diff --git a/agents/agents-features/agents-features-sql/src/jvmMain/kotlin/ai/koog/agents/features/sql/providers/PostgresPersistencyStorageProvider.kt b/agents/agents-features/agents-features-sql/src/jvmMain/kotlin/ai/koog/agents/features/sql/providers/PostgresPersistencyStorageProvider.kt index e743e78176..40190957c6 100644 --- a/agents/agents-features/agents-features-sql/src/jvmMain/kotlin/ai/koog/agents/features/sql/providers/PostgresPersistencyStorageProvider.kt +++ b/agents/agents-features/agents-features-sql/src/jvmMain/kotlin/ai/koog/agents/features/sql/providers/PostgresPersistencyStorageProvider.kt @@ -14,13 +14,12 @@ import org.jetbrains.exposed.sql.transactions.transaction * @constructor Initializes the PostgreSQL persistence provider with connection details. */ public class PostgresPersistenceStorageProvider( - persistenceId: String, database: Database, tableName: String = "agent_checkpoints", ttlSeconds: Long? = null, migrator: SQLPersistenceSchemaMigrator = PostgresPersistenceSchemaMigrator(database, tableName), json: Json = PersistenceUtils.defaultCheckpointJson -) : ExposedPersistenceStorageProvider(persistenceId, database, tableName, ttlSeconds, migrator, json) { +) : ExposedPersistenceStorageProvider(database, tableName, ttlSeconds, migrator, json) { override suspend fun transaction(block: suspend () -> T): T = newSuspendedTransaction(Dispatchers.IO, database) { block() } diff --git a/agents/agents-features/agents-features-sql/src/jvmTest/kotlin/ai/koog/agents/features/sql/providers/H2PersistencyStorageProviderTest.kt b/agents/agents-features/agents-features-sql/src/jvmTest/kotlin/ai/koog/agents/features/sql/providers/H2PersistencyStorageProviderTest.kt index 10c24d578e..1b22167eb1 100644 --- a/agents/agents-features/agents-features-sql/src/jvmTest/kotlin/ai/koog/agents/features/sql/providers/H2PersistencyStorageProviderTest.kt +++ b/agents/agents-features/agents-features-sql/src/jvmTest/kotlin/ai/koog/agents/features/sql/providers/H2PersistencyStorageProviderTest.kt @@ -21,9 +21,10 @@ import kotlin.test.assertNull @ExtendWith(DockerAvailableCondition::class) class H2PersistenceStorageProviderTest { + private val agentId = "h2-agent" + private fun provider(ttlSeconds: Long? = null): H2PersistenceStorageProvider { return H2PersistenceStorageProvider.inMemory( - persistenceId = "h2-agent", databaseName = "h2_test_db", tableName = "agent_checkpoints_test", ttlSeconds = ttlSeconds @@ -36,38 +37,38 @@ class H2PersistenceStorageProviderTest { p.migrate() // empty - assertNull(p.getLatestCheckpoint()) - assertEquals(0, p.getCheckpointCount()) + assertNull(p.getLatestCheckpoint(agentId)) + assertEquals(0, p.getCheckpointCount(agentId)) // save val cp1 = createTestCheckpoint("cp-1") - p.saveCheckpoint(cp1) + p.saveCheckpoint(agentId, cp1) // read - val latest1 = p.getLatestCheckpoint() + val latest1 = p.getLatestCheckpoint(agentId) assertNotNull(latest1) assertEquals("cp-1", latest1.checkpointId) - assertEquals(1, p.getCheckpoints().size) - assertEquals(1, p.getCheckpointCount()) + assertEquals(1, p.getCheckpoints(agentId).size) + assertEquals(1, p.getCheckpointCount(agentId)) // upsert same id should be idempotent (no duplicates due PK) - p.saveCheckpoint(cp1) - assertEquals(1, p.getCheckpoints().size) + p.saveCheckpoint(agentId, cp1) + assertEquals(1, p.getCheckpoints(agentId).size) // insert second val cp2 = createTestCheckpoint("cp-2") - p.saveCheckpoint(cp2) - val all = p.getCheckpoints() + p.saveCheckpoint(agentId, cp2) + val all = p.getCheckpoints(agentId) assertEquals(listOf("cp-1", "cp-2"), all.map { it.checkpointId }) - assertEquals("cp-2", p.getLatestCheckpoint()!!.checkpointId) + assertEquals("cp-2", p.getLatestCheckpoint(agentId)!!.checkpointId) // delete single - p.deleteCheckpoint("cp-1") - assertEquals(listOf("cp-2"), p.getCheckpoints().map { it.checkpointId }) + p.deleteCheckpoint(agentId, "cp-1") + assertEquals(listOf("cp-2"), p.getCheckpoints(agentId).map { it.checkpointId }) // delete all - p.deleteAllCheckpoints() - assertEquals(0, p.getCheckpointCount()) + p.deleteAllCheckpoints(agentId) + assertEquals(0, p.getCheckpointCount(agentId)) } @Test @@ -75,16 +76,16 @@ class H2PersistenceStorageProviderTest { val p = provider(ttlSeconds = 1) p.migrate() - p.saveCheckpoint(createTestCheckpoint("will-expire")) - assertEquals(1, p.getCheckpointCount()) + p.saveCheckpoint(agentId, createTestCheckpoint("will-expire")) + assertEquals(1, p.getCheckpointCount(agentId)) // Wait slightly over 1s to ensure ttl passes delay(1100) // Force cleanup directly to avoid interval throttling p.cleanupExpired() - assertEquals(0, p.getCheckpointCount()) - assertNull(p.getLatestCheckpoint()) + assertEquals(0, p.getCheckpointCount(agentId)) + assertNull(p.getLatestCheckpoint(agentId)) } private fun createTestCheckpoint(id: String): AgentCheckpointData { diff --git a/agents/agents-features/agents-features-sql/src/jvmTest/kotlin/ai/koog/agents/features/sql/providers/MySQLPersistencyStorageProviderTest.kt b/agents/agents-features/agents-features-sql/src/jvmTest/kotlin/ai/koog/agents/features/sql/providers/MySQLPersistencyStorageProviderTest.kt index 8b8afeb4c2..b38623760c 100644 --- a/agents/agents-features/agents-features-sql/src/jvmTest/kotlin/ai/koog/agents/features/sql/providers/MySQLPersistencyStorageProviderTest.kt +++ b/agents/agents-features/agents-features-sql/src/jvmTest/kotlin/ai/koog/agents/features/sql/providers/MySQLPersistencyStorageProviderTest.kt @@ -26,6 +26,8 @@ import kotlin.test.assertNull @ExtendWith(DockerAvailableCondition::class) class MySQLPersistenceStorageProviderTest { + private val agentId = "mysql-agent" + private lateinit var mysql: MySQLContainer<*> @BeforeAll @@ -50,7 +52,6 @@ class MySQLPersistenceStorageProviderTest { password = mysql.password ) return MySQLPersistenceStorageProvider( - persistenceId = "mysql-agent", database = db, tableName = "agent_checkpoints_test", ttlSeconds = ttlSeconds @@ -63,38 +64,38 @@ class MySQLPersistenceStorageProviderTest { p.migrate() // empty - assertNull(p.getLatestCheckpoint()) - assertEquals(0, p.getCheckpointCount()) + assertNull(p.getLatestCheckpoint(agentId)) + assertEquals(0, p.getCheckpointCount(agentId)) // save val cp1 = createTestCheckpoint("cp-1") - p.saveCheckpoint(cp1) + p.saveCheckpoint(agentId, cp1) // read - val latest1 = p.getLatestCheckpoint() + val latest1 = p.getLatestCheckpoint(agentId) assertNotNull(latest1) assertEquals("cp-1", latest1.checkpointId) - assertEquals(1, p.getCheckpoints().size) - assertEquals(1, p.getCheckpointCount()) + assertEquals(1, p.getCheckpoints(agentId).size) + assertEquals(1, p.getCheckpointCount(agentId)) // upsert same id should be idempotent (no duplicates due PK) - p.saveCheckpoint(cp1) - assertEquals(1, p.getCheckpoints().size) + p.saveCheckpoint(agentId, cp1) + assertEquals(1, p.getCheckpoints(agentId).size) // insert second val cp2 = createTestCheckpoint("cp-2") - p.saveCheckpoint(cp2) - val all = p.getCheckpoints() + p.saveCheckpoint(agentId, cp2) + val all = p.getCheckpoints(agentId) assertEquals(listOf("cp-1", "cp-2"), all.map { it.checkpointId }) - assertEquals("cp-2", p.getLatestCheckpoint()!!.checkpointId) + assertEquals("cp-2", p.getLatestCheckpoint(agentId)!!.checkpointId) // delete single - p.deleteCheckpoint("cp-1") - assertEquals(listOf("cp-2"), p.getCheckpoints().map { it.checkpointId }) + p.deleteCheckpoint(agentId, "cp-1") + assertEquals(listOf("cp-2"), p.getCheckpoints(agentId).map { it.checkpointId }) // delete all - p.deleteAllCheckpoints() - assertEquals(0, p.getCheckpointCount()) + p.deleteAllCheckpoints(agentId) + assertEquals(0, p.getCheckpointCount(agentId)) } @Test @@ -102,16 +103,16 @@ class MySQLPersistenceStorageProviderTest { val p = provider(ttlSeconds = 1) p.migrate() - p.saveCheckpoint(createTestCheckpoint("will-expire")) - assertEquals(1, p.getCheckpointCount()) + p.saveCheckpoint(agentId, createTestCheckpoint("will-expire")) + assertEquals(1, p.getCheckpointCount(agentId)) // Sleep slightly over 1s to ensure ttl passes delay(1100) // force cleanup p.cleanupExpired() - assertEquals(0, p.getCheckpointCount()) - assertNull(p.getLatestCheckpoint()) + assertEquals(0, p.getCheckpointCount(agentId)) + assertNull(p.getLatestCheckpoint(agentId)) } private fun createTestCheckpoint(id: String): AgentCheckpointData { diff --git a/agents/agents-features/agents-features-sql/src/jvmTest/kotlin/ai/koog/agents/features/sql/providers/PostgresPersistencyStorageProviderTest.kt b/agents/agents-features/agents-features-sql/src/jvmTest/kotlin/ai/koog/agents/features/sql/providers/PostgresPersistencyStorageProviderTest.kt index f9ab8099ba..a9f58713f6 100644 --- a/agents/agents-features/agents-features-sql/src/jvmTest/kotlin/ai/koog/agents/features/sql/providers/PostgresPersistencyStorageProviderTest.kt +++ b/agents/agents-features/agents-features-sql/src/jvmTest/kotlin/ai/koog/agents/features/sql/providers/PostgresPersistencyStorageProviderTest.kt @@ -25,6 +25,8 @@ import kotlin.test.assertNull @ExtendWith(DockerAvailableCondition::class) class PostgresPersistenceStorageProviderTest { + private val agentId = "pg-agent" + private lateinit var postgres: PostgreSQLContainer<*> @BeforeAll @@ -49,7 +51,6 @@ class PostgresPersistenceStorageProviderTest { password = postgres.password ) return PostgresPersistenceStorageProvider( - persistenceId = "pg-agent", database = db, tableName = "agent_checkpoints_test", ttlSeconds = ttlSeconds @@ -62,38 +63,38 @@ class PostgresPersistenceStorageProviderTest { p.migrate() // empty - assertNull(p.getLatestCheckpoint()) - assertEquals(0, p.getCheckpointCount()) + assertNull(p.getLatestCheckpoint(agentId)) + assertEquals(0, p.getCheckpointCount(agentId)) // save val cp1 = createTestCheckpoint("cp-1") - p.saveCheckpoint(cp1) + p.saveCheckpoint(agentId, cp1) // read - val latest1 = p.getLatestCheckpoint() + val latest1 = p.getLatestCheckpoint(agentId) assertNotNull(latest1) assertEquals("cp-1", latest1.checkpointId) - assertEquals(1, p.getCheckpoints().size) - assertEquals(1, p.getCheckpointCount()) + assertEquals(1, p.getCheckpoints(agentId).size) + assertEquals(1, p.getCheckpointCount(agentId)) // upsert same id should be idempotent (no duplicates due PK) - p.saveCheckpoint(cp1) - assertEquals(1, p.getCheckpoints().size) + p.saveCheckpoint(agentId, cp1) + assertEquals(1, p.getCheckpoints(agentId).size) // insert second val cp2 = createTestCheckpoint("cp-2") - p.saveCheckpoint(cp2) - val all = p.getCheckpoints() + p.saveCheckpoint(agentId, cp2) + val all = p.getCheckpoints(agentId) assertEquals(listOf("cp-1", "cp-2"), all.map { it.checkpointId }) - assertEquals("cp-2", p.getLatestCheckpoint()!!.checkpointId) + assertEquals("cp-2", p.getLatestCheckpoint(agentId)!!.checkpointId) // delete single - p.deleteCheckpoint("cp-1") - assertEquals(listOf("cp-2"), p.getCheckpoints().map { it.checkpointId }) + p.deleteCheckpoint(agentId, "cp-1") + assertEquals(listOf("cp-2"), p.getCheckpoints(agentId).map { it.checkpointId }) // delete all - p.deleteAllCheckpoints() - assertEquals(0, p.getCheckpointCount()) + p.deleteAllCheckpoints(agentId) + assertEquals(0, p.getCheckpointCount(agentId)) } @Test @@ -101,16 +102,16 @@ class PostgresPersistenceStorageProviderTest { val p = provider(ttlSeconds = 1) p.migrate() - p.saveCheckpoint(createTestCheckpoint("will-expire")) - assertEquals(1, p.getCheckpointCount()) + p.saveCheckpoint(agentId, createTestCheckpoint("will-expire")) + assertEquals(1, p.getCheckpointCount(agentId)) // Force cleanup by calling cleanupExpired directly to avoid time-based throttle // Sleep slightly over 1s to ensure ttl passes kotlinx.coroutines.delay(1100) p.cleanupExpired() - assertEquals(0, p.getCheckpointCount()) - assertNull(p.getLatestCheckpoint()) + assertEquals(0, p.getCheckpointCount(agentId)) + assertNull(p.getLatestCheckpoint(agentId)) } private fun createTestCheckpoint(id: String): AgentCheckpointData { diff --git a/agents/agents-features/agents-features-sql/src/jvmTest/kotlin/ai/koog/agents/features/sql/providers/SQLPersistenceProvidersTest.kt b/agents/agents-features/agents-features-sql/src/jvmTest/kotlin/ai/koog/agents/features/sql/providers/SQLPersistenceProvidersTest.kt index 29c13467f6..8c753b454c 100644 --- a/agents/agents-features/agents-features-sql/src/jvmTest/kotlin/ai/koog/agents/features/sql/providers/SQLPersistenceProvidersTest.kt +++ b/agents/agents-features/agents-features-sql/src/jvmTest/kotlin/ai/koog/agents/features/sql/providers/SQLPersistenceProvidersTest.kt @@ -24,8 +24,8 @@ class SQLPersistenceProvidersTest { @Test fun `test save and retrieve checkpoint`() = runBlocking { + val agentId = "test-agent" val provider = H2PersistenceStorageProvider.inMemory( - persistenceId = "test-agent", databaseName = "test_db" ) @@ -33,10 +33,10 @@ class SQLPersistenceProvidersTest { // Create and save checkpoint val checkpoint = createTestCheckpoint("test-1") - provider.saveCheckpoint(checkpoint) + provider.saveCheckpoint(agentId, checkpoint) // Retrieve and verify - val retrieved = provider.getLatestCheckpoint() + val retrieved = provider.getLatestCheckpoint(agentId) assertNotNull(retrieved) assertEquals(checkpoint.checkpointId, retrieved.checkpointId) assertEquals(checkpoint.nodeId, retrieved.nodeId) @@ -45,37 +45,38 @@ class SQLPersistenceProvidersTest { @Test fun `test multiple checkpoints and ordering`() = runBlocking { + val agentId = "test-agent" val provider = H2PersistenceStorageProvider.inMemory( - persistenceId = "test-agent", databaseName = "test_ordering" ) provider.migrate() // Save multiple checkpoints - provider.saveCheckpoint(createTestCheckpoint("checkpoint-1")) - provider.saveCheckpoint(createTestCheckpoint("checkpoint-2")) - provider.saveCheckpoint(createTestCheckpoint("checkpoint-3")) + provider.saveCheckpoint(agentId, createTestCheckpoint("checkpoint-1")) + provider.saveCheckpoint(agentId, createTestCheckpoint("checkpoint-2")) + provider.saveCheckpoint(agentId, createTestCheckpoint("checkpoint-3")) // Verify count and ordering - val allCheckpoints = provider.getCheckpoints() + val allCheckpoints = provider.getCheckpoints(agentId) assertEquals(3, allCheckpoints.size) assertEquals("checkpoint-1", allCheckpoints[0].checkpointId) assertEquals("checkpoint-3", allCheckpoints[2].checkpointId) // Verify latest - val latest = provider.getLatestCheckpoint() + val latest = provider.getLatestCheckpoint(agentId) assertEquals("checkpoint-3", latest?.checkpointId) } @Test fun `test persistence ID isolation`() = runBlocking { + val agentId = "test-agent" + val agentId2 = "test-agent2" + val provider1 = H2PersistenceStorageProvider.inMemory( - persistenceId = "agent-1", databaseName = "shared_db" ) val provider2 = H2PersistenceStorageProvider.inMemory( - persistenceId = "agent-2", databaseName = "shared_db" // Same database ) @@ -83,12 +84,12 @@ class SQLPersistenceProvidersTest { provider2.migrate() // Save to different agents - provider1.saveCheckpoint(createTestCheckpoint("agent1-data")) - provider2.saveCheckpoint(createTestCheckpoint("agent2-data")) + provider1.saveCheckpoint(agentId, createTestCheckpoint("agent1-data")) + provider2.saveCheckpoint(agentId2, createTestCheckpoint("agent2-data")) // Verify isolation - val agent1Checkpoints = provider1.getCheckpoints() - val agent2Checkpoints = provider2.getCheckpoints() + val agent1Checkpoints = provider1.getCheckpoints(agentId) + val agent2Checkpoints = provider2.getCheckpoints(agentId2) assertEquals(1, agent1Checkpoints.size) assertEquals(1, agent2Checkpoints.size) @@ -98,8 +99,9 @@ class SQLPersistenceProvidersTest { @Test fun `test TTL expiration`() = runBlocking { + val agentId = "test-agent" + val provider = H2PersistenceStorageProvider.inMemory( - persistenceId = "ttl-test", databaseName = "ttl_db", ttlSeconds = 1 // 1 second TTL ) @@ -107,16 +109,16 @@ class SQLPersistenceProvidersTest { provider.migrate() // Save checkpoint - provider.saveCheckpoint(createTestCheckpoint("expire-soon")) - assertEquals(1, provider.getCheckpointCount()) + provider.saveCheckpoint(agentId, createTestCheckpoint("expire-soon")) + assertEquals(1, provider.getCheckpointCount(agentId)) // Wait for expiration delay(1500) provider.conditionalCleanup() // Should be cleaned up on next operation - val afterExpiry = provider.getLatestCheckpoint() + val afterExpiry = provider.getLatestCheckpoint(agentId) assertNull(afterExpiry) - assertEquals(0, provider.getCheckpointCount()) + assertEquals(0, provider.getCheckpointCount(agentId)) } @Test @@ -127,7 +129,6 @@ class SQLPersistenceProvidersTest { // PostgreSQL assertNotNull( PostgresPersistenceStorageProvider( - persistenceId = "test", database = Database.connect( url = "jdbc:postgresql://localhost:5432/test", driver = "org.postgresql.Driver", @@ -140,7 +141,6 @@ class SQLPersistenceProvidersTest { // MySQL assertNotNull( MySQLPersistenceStorageProvider( - persistenceId = "test", database = Database.connect( url = "jdbc:mysql://localhost:3306/test", driver = "com.mysql.cj.jdbc.Driver", diff --git a/docs/docs/agent-persistence.md b/docs/docs/agent-persistence.md index 6638b4c2d3..184b2a5bd2 100644 --- a/docs/docs/agent-persistence.md +++ b/docs/docs/agent-persistence.md @@ -63,7 +63,7 @@ val agent = AIAgent( ) { install(Persistence) { // Use in-memory storage for snapshots - storage = InMemoryPersistenceStorageProvider("in-memory-storage") + storage = InMemoryPersistenceStorageProvider() // Enable automatic persistence enableAutomaticPersistence = true } @@ -101,7 +101,7 @@ val agent = AIAgent( ```kotlin install(Persistence) { - storage = InMemoryPersistenceStorageProvider("in-memory-storage") + storage = InMemoryPersistenceStorageProvider() } ``` @@ -315,7 +315,7 @@ class MyCustomStorageProvider : PersistenceStorageProvider { // Implementation } - override suspend fun saveCheckpoint(agentCheckpointData: AgentCheckpointData) { + override suspend fun saveCheckpoint(agentId: String, agentCheckpointData: AgentCheckpointData) { // Implementation } @@ -339,15 +339,15 @@ import ai.koog.prompt.executor.llms.all.simpleOllamaAIExecutor import ai.koog.prompt.llm.OllamaModels class MyCustomStorageProvider : PersistenceStorageProvider { - override suspend fun getCheckpoints(): List { + override suspend fun getCheckpoints(agentId: String): List { TODO("Not yet implemented") } - override suspend fun saveCheckpoint(agentCheckpointData: AgentCheckpointData) { + override suspend fun saveCheckpoint(agentId: String, agentCheckpointData: AgentCheckpointData) { TODO("Not yet implemented") } - override suspend fun getLatestCheckpoint(): AgentCheckpointData? { + override suspend fun getLatestCheckpoint(agentId: String): AgentCheckpointData? { TODO("Not yet implemented") } } diff --git a/examples/simple-examples/gradle.properties b/examples/simple-examples/gradle.properties index e360d6a58e..ab47dc19ae 100644 --- a/examples/simple-examples/gradle.properties +++ b/examples/simple-examples/gradle.properties @@ -1,8 +1,10 @@ #Kotlin kotlin.code.style=official kotlin.daemon.jvmargs=-Xmx4096M +kotlin.native.ignoreDisabledTargets=true #Gradle org.gradle.jvmargs=-Xmx4096M -Dfile.encoding=UTF-8 org.gradle.parallel=true org.gradle.caching=true + diff --git a/examples/simple-examples/src/main/kotlin/ai/koog/agents/example/snapshot/CheckpointExample.kt b/examples/simple-examples/src/main/kotlin/ai/koog/agents/example/snapshot/CheckpointExample.kt index 46e43882d2..81191c4f9a 100644 --- a/examples/simple-examples/src/main/kotlin/ai/koog/agents/example/snapshot/CheckpointExample.kt +++ b/examples/simple-examples/src/main/kotlin/ai/koog/agents/example/snapshot/CheckpointExample.kt @@ -27,12 +27,11 @@ fun main() = runBlocking { tools(CalculatorTools().asTools()) } - val persistenceId = "snapshot-agent-example" + val agentId = "agent.1" - val snapshotProvider = InMemoryPersistenceStorageProvider( - persistenceId = persistenceId - ) + val snapshotProvider = InMemoryPersistenceStorageProvider() val agent = AIAgent( + id = agentId, promptExecutor = executor, llmModel = OllamaModels.Meta.LLAMA_3_2, strategy = singleRunStrategy(ToolCalls.SEQUENTIAL), @@ -61,17 +60,17 @@ fun main() = runBlocking { } } - val checkpoints = snapshotProvider.getCheckpoints() + val checkpoints = snapshotProvider.getCheckpoints(agentId) println("Snapshot provider state after first run: $checkpoints") val agent2 = AIAgent( + id = agent.id, promptExecutor = executor, llmModel = OllamaModels.Meta.LLAMA_3_2, toolRegistry = correctToolRegistry, strategy = singleRunStrategy(ToolCalls.SEQUENTIAL), systemPrompt = "You are a calculator. Use tools to calculate asked to result.", temperature = 0.0, - id = agent.id ) { install(Persistence) { storage = snapshotProvider diff --git a/examples/simple-examples/src/main/kotlin/ai/koog/agents/example/snapshot/FilePersistentAgentExample.kt b/examples/simple-examples/src/main/kotlin/ai/koog/agents/example/snapshot/FilePersistentAgentExample.kt index bc1c749e92..ec6821bd74 100644 --- a/examples/simple-examples/src/main/kotlin/ai/koog/agents/example/snapshot/FilePersistentAgentExample.kt +++ b/examples/simple-examples/src/main/kotlin/ai/koog/agents/example/snapshot/FilePersistentAgentExample.kt @@ -38,7 +38,7 @@ fun main() = runBlocking { println("Checkpoint directory: $checkpointDir") // Create the file-based checkpoint provider - val provider = JVMFilePersistenceStorageProvider(checkpointDir, "persistent-agent-example") + val provider = JVMFilePersistenceStorageProvider(checkpointDir) // Create a unique agent ID to identify this agent's checkpoints val agentId = "persistent-agent-example" @@ -79,7 +79,7 @@ fun main() = runBlocking { println("Agent result: $result") // Retrieve all checkpoints created during the agent's execution - val checkpoints = provider.getCheckpoints() + val checkpoints = provider.getCheckpoints(agentId) println("\nRetrieved ${checkpoints.size} checkpoints for agent $agentId") // Print checkpoint details @@ -120,7 +120,7 @@ fun main() = runBlocking { println("Restored agent result: $restoredResult") // Get the latest checkpoint after the second run - val latestCheckpoint = provider.getLatestCheckpoint() + val latestCheckpoint = provider.getLatestCheckpoint(agentId) println("\nLatest checkpoint after restoration:") println(" ID: ${latestCheckpoint?.checkpointId}") println(" Created at: ${latestCheckpoint?.createdAt}") diff --git a/examples/simple-examples/src/main/kotlin/ai/koog/agents/example/snapshot/SnapshotExample.kt b/examples/simple-examples/src/main/kotlin/ai/koog/agents/example/snapshot/SnapshotExample.kt index 6e66ef3f4f..9dd907c684 100644 --- a/examples/simple-examples/src/main/kotlin/ai/koog/agents/example/snapshot/SnapshotExample.kt +++ b/examples/simple-examples/src/main/kotlin/ai/koog/agents/example/snapshot/SnapshotExample.kt @@ -37,7 +37,7 @@ fun main() = runBlocking { maxAgentIterations = 50 ) - val snapshotProvider = InMemoryPersistenceStorageProvider("persistent-agent-example") + val snapshotProvider = InMemoryPersistenceStorageProvider() val agent = AIAgent( promptExecutor = executor, strategy = SnapshotStrategy.strategy, diff --git a/examples/simple-examples/src/main/kotlin/ai/koog/agents/example/snapshot/sql/SQLPersistentAgentExample.kt b/examples/simple-examples/src/main/kotlin/ai/koog/agents/example/snapshot/sql/SQLPersistentAgentExample.kt index 8d96d9b4e7..5dd319d3d5 100644 --- a/examples/simple-examples/src/main/kotlin/ai/koog/agents/example/snapshot/sql/SQLPersistentAgentExample.kt +++ b/examples/simple-examples/src/main/kotlin/ai/koog/agents/example/snapshot/sql/SQLPersistentAgentExample.kt @@ -42,9 +42,9 @@ object SQLPersistentAgentExample { private suspend fun postgresqlExample() { println("PostgreSQL Persistence Example") println("------------------------------") + val agentId = "postgres-agent" val provider = PostgresPersistenceStorageProvider( - persistenceId = "postgres-agent", database = Database.connect( url = "jdbc:postgresql://localhost:5432/agents", driver = "org.postgresql.Driver", @@ -59,11 +59,11 @@ object SQLPersistentAgentExample { // Create and save checkpoint val checkpoint = createSampleCheckpoint("postgres-checkpoint-1") - provider.saveCheckpoint(checkpoint) + provider.saveCheckpoint(agentId = agentId, agentCheckpointData = checkpoint) println("Saved checkpoint: ${checkpoint.checkpointId}") // Retrieve checkpoint - val retrieved = provider.getLatestCheckpoint() + val retrieved = provider.getLatestCheckpoint(agentId) println("Retrieved latest checkpoint: ${retrieved?.checkpointId}") } @@ -73,9 +73,9 @@ object SQLPersistentAgentExample { private suspend fun mysqlExample() { println("MySQL Persistence Example") println("-------------------------") + val agentId = "postgres-agent" val provider = MySQLPersistenceStorageProvider( - persistenceId = "mysql-agent", database = Database.connect( url = "jdbc:mysql://localhost:3306/agents?useSSL=false&serverTimezone=UTC", driver = "com.mysql.cj.jdbc.Driver", @@ -96,16 +96,16 @@ object SQLPersistentAgentExample { ) checkpoints.forEach { checkpoint -> - provider.saveCheckpoint(checkpoint) + provider.saveCheckpoint(agentId, checkpoint) println("Saved: ${checkpoint.checkpointId}") } // Get all checkpoints - val allCheckpoints = provider.getCheckpoints() + val allCheckpoints = provider.getCheckpoints(agentId) println("\nTotal checkpoints: ${allCheckpoints.size}") // Get checkpoint count - val count = provider.getCheckpointCount() + val count = provider.getCheckpointCount(agentId) println("Checkpoint count: $count") } @@ -115,37 +115,38 @@ object SQLPersistentAgentExample { private suspend fun h2Example() { println("H2 Database Persistence Examples") println("--------------------------------") - + val agentId = "h2-test-agent" // Example 1: In-memory database (for testing) println("\n1. In-Memory H2:") val inMemoryProvider = H2PersistenceStorageProvider.inMemory( - persistenceId = "h2-test-agent", databaseName = "test_agents" ) inMemoryProvider.migrate() val testCheckpoint = createSampleCheckpoint("h2-memory-checkpoint") - inMemoryProvider.saveCheckpoint(testCheckpoint) + inMemoryProvider.saveCheckpoint(agentId, testCheckpoint) println(" Saved to in-memory: ${testCheckpoint.checkpointId}") + val h2AgentId = "h2-file-agent" + // Example 2: File-based database (for persistence) println("\n2. File-Based H2:") val fileProvider = H2PersistenceStorageProvider.fileBased( - persistenceId = "h2-file-agent", filePath = "./data/h2/agent_checkpoints", ttlSeconds = 86400 // 24 hours ) fileProvider.migrate() val fileCheckpoint = createSampleCheckpoint("h2-file-checkpoint") - fileProvider.saveCheckpoint(fileCheckpoint) + fileProvider.saveCheckpoint(h2AgentId, fileCheckpoint) println(" Saved to file: ${fileCheckpoint.checkpointId}") // Example 3: PostgreSQL compatibility mode println("\n3. PostgreSQL Compatible Mode:") + val postgresAgentId = "postgres-agent" + val pgCompatProvider = H2PersistenceStorageProvider( - persistenceId = "postgres-agent", database = Database.connect( url = "jdbc:postgresql://localhost:5432/agents", driver = "org.postgresql.Driver", @@ -158,7 +159,7 @@ object SQLPersistentAgentExample { pgCompatProvider.migrate() val pgCheckpoint = createSampleCheckpoint("h2-pgcompat-checkpoint") - pgCompatProvider.saveCheckpoint(pgCheckpoint) + pgCompatProvider.saveCheckpoint(postgresAgentId, pgCheckpoint) println(" Saved with PG compatibility: ${pgCheckpoint.checkpointId}") } diff --git a/integration-tests/src/jvmTest/kotlin/ai/koog/integration/tests/agent/AIAgentIntegrationTest.kt b/integration-tests/src/jvmTest/kotlin/ai/koog/integration/tests/agent/AIAgentIntegrationTest.kt index 9866c45779..571197f632 100644 --- a/integration-tests/src/jvmTest/kotlin/ai/koog/integration/tests/agent/AIAgentIntegrationTest.kt +++ b/integration-tests/src/jvmTest/kotlin/ai/koog/integration/tests/agent/AIAgentIntegrationTest.kt @@ -720,7 +720,7 @@ class AIAgentIntegrationTest { @ParameterizedTest @MethodSource("openAIModels", "anthropicModels", "googleModels", "bedrockModels") fun integration_AgentCreateAndRestoreTest(model: LLModel) = runTest(timeout = 180.seconds) { - val checkpointStorageProvider = InMemoryPersistenceStorageProvider("integration_AgentCreateAndRestoreTest") + val checkpointStorageProvider = InMemoryPersistenceStorageProvider() val sayHello = "Hello World!" val hello = "Hello" val savedMessage = "Saved the state – the agent is ready to work!" @@ -776,7 +776,7 @@ class AIAgentIntegrationTest { agent.run("Start the test") - val checkpoints = checkpointStorageProvider.getCheckpoints() + val checkpoints = checkpointStorageProvider.getCheckpoints(agent.id) assertTrue(checkpoints.isNotEmpty(), "No checkpoints were created") assertEquals(save, checkpoints.first().nodeId, "Checkpoint has incorrect node ID") @@ -808,7 +808,7 @@ class AIAgentIntegrationTest { @ParameterizedTest @MethodSource("openAIModels", "anthropicModels", "googleModels", "bedrockModels") fun integration_AgentCheckpointRollbackTest(model: LLModel) = runTest(timeout = 180.seconds) { - val checkpointStorageProvider = InMemoryPersistenceStorageProvider("integration_AgentCheckpointRollbackTest") + val checkpointStorageProvider = InMemoryPersistenceStorageProvider() val hello = "Hello" val save = "Save" @@ -924,7 +924,7 @@ class AIAgentIntegrationTest { @MethodSource("openAIModels", "anthropicModels", "googleModels", "bedrockModels") fun integration_AgentCheckpointContinuousPersistenceTest(model: LLModel) = runTest(timeout = 180.seconds) { val checkpointStorageProvider = - InMemoryPersistenceStorageProvider("integration_AgentCheckpointContinuousPersistenceTest") + InMemoryPersistenceStorageProvider() val strategyName = "continuous-persistence-strategy" @@ -985,7 +985,7 @@ class AIAgentIntegrationTest { agent.run(testInput) - val checkpoints = checkpointStorageProvider.getCheckpoints() + val checkpoints = checkpointStorageProvider.getCheckpoints(agent.id) assertTrue(checkpoints.size >= 3, notEnoughCheckpointsError) val nodeIds = checkpoints.map { it.nodeId }.toSet() @@ -1012,8 +1012,7 @@ class AIAgentIntegrationTest { val noCheckpointsError = "No checkpoints were created" val incorrectNodeIdError = "Checkpoint has incorrect node ID" - val fileStorageProvider = - JVMFilePersistenceStorageProvider(tempDir, "integration_AgentCheckpointStorageProvidersTest") + val fileStorageProvider = JVMFilePersistenceStorageProvider(tempDir) val simpleStrategy = strategy(strategyName) { val nodeHello by node(hello) { @@ -1057,7 +1056,7 @@ class AIAgentIntegrationTest { agent.run(testInput) - val checkpoints = fileStorageProvider.getCheckpoints().filter { it.nodeId != "tombstone" } + val checkpoints = fileStorageProvider.getCheckpoints(agent.id).filter { it.nodeId != "tombstone" } assertTrue(checkpoints.isNotEmpty(), noCheckpointsError) assertEquals(bye, checkpoints.first().nodeId, incorrectNodeIdError) } From 669c56d63b4c47783ef04a33b9c86c70784889df Mon Sep 17 00:00:00 2001 From: Marko Marinkovic Date: Thu, 2 Oct 2025 15:43:07 +0200 Subject: [PATCH 51/52] A2A docs (#901) Add documentation for A2A support in Koog. --- docs/docs/a2a-client.md | 202 +++++++++++++++++ docs/docs/a2a-koog-integration.md | 319 ++++++++++++++++++++++++++ docs/docs/a2a-protocol-overview.md | 16 ++ docs/docs/a2a-server.md | 351 +++++++++++++++++++++++++++++ docs/mkdocs.yml | 10 + 5 files changed, 898 insertions(+) create mode 100644 docs/docs/a2a-client.md create mode 100644 docs/docs/a2a-koog-integration.md create mode 100644 docs/docs/a2a-protocol-overview.md create mode 100644 docs/docs/a2a-server.md diff --git a/docs/docs/a2a-client.md b/docs/docs/a2a-client.md new file mode 100644 index 0000000000..27df9fcf4e --- /dev/null +++ b/docs/docs/a2a-client.md @@ -0,0 +1,202 @@ +# A2A Client + +The A2A client enables you to communicate with A2A-compliant agents over the network. +It provides a complete implementation of +the [A2A protocol specification](https://a2a-protocol.org/latest/specification/), handling agent discovery, message +exchange, task management, and real-time streaming responses. + +## Overview + +The A2A client acts as a bridge between your application and A2A-compliant agents. +It orchestrates the entire communication lifecycle while maintaining protocol compliance and providing robust session +management. + +## Core components + +### A2AClient + +The main client class implementing the complete A2A protocol. It serves as the central coordinator that: + +- **Manages** connections and agent discovery through pluggable resolvers +- **Orchestrates** message exchange and task operations with automatic protocol compliance +- **Handles** streaming responses and real-time communication when supported by agents +- **Provides** comprehensive error handling and fallback mechanisms for robust applications + +The `A2AClient` accepts two required parameters: + +* `ClientTransport` which handles network communication layer +* `AgentCardResolver` which handles agent discovery and metadata retrieval + +The `A2AClient` interface provides several key methods for interacting with A2A agents: + +* `connect` method - To connect to the agent and retrieve its capabilities, which discovers what the agent can do and + caches the AgentCard +* `sendMessage` method - To send a message to the agent and receive a single response for simple request-response + patterns +* `sendMessageStreaming` method - To send a message with streaming support for real-time responses, which returns a Flow + of events including partial messages and task updates +* `getTask` method - To query the status and details of a specific task +* `cancelTask` method - To cancel a running task if the agent supports cancellation +* `cachedAgentCard` method - To get the cached agent card without making a network request, which returns null if + connect hasn't been called yet + +### ClientTransport + +The `ClientTransport` interface handles the low-level network communication while the A2A client manages the protocol +logic. +It abstracts away transport-specific details, allowing you to use different protocols seamlessly. + +#### HTTP JSON-RPC Transport + +The most common transport for A2A agents: + +```kotlin +val transport = HttpJSONRPCClientTransport( + url = "https://agent.example.com/a2a", // Agent endpoint URL + httpClient = HttpClient(CIO) { // Optional: custom HTTP client + install(ContentNegotiation) { + json() + } + install(HttpTimeout) { + requestTimeoutMillis = 30000 + } + } +) +``` + +### AgentCardResolver + +The `AgentCardResolver` interface retrieves agent metadata and capabilities. +It enables agent discovery from various sources and supports caching strategies for optimal performance. + +#### URL Agent Card Resolver + +Fetch agent cards from HTTP endpoints following A2A conventions: + +```kotlin +val agentCardResolver = UrlAgentCardResolver( + baseUrl = "https://agent.example.com", // Base URL of the agent service + path = "/.well-known/agent-card.json", // Standard agent card location + httpClient = HttpClient(CIO), // Optional: custom HTTP client +) +``` + +## Quickstart + +### 1. Create the Client + +Define the transport and agent card resolver and create the client. + +```kotlin +// HTTP JSON-RPC transport +val transport = HttpJSONRPCClientTransport( + url = "https://agent.example.com/a2a" +) + +// Agent card resolver +val agentCardResolver = UrlAgentCardResolver( + baseUrl = "https://agent.example.com", + path = "/.well-known/agent-card.json" +) + +// Create client +val client = A2AClient(transport, agentCardResolver) +``` + +### 2. Connect and Discover + +Connect to the agent and retrieve its card. +Having agent's card enables you to query its capabilities and perform other operations, for example, check if it +supports streaming. + +```kotlin +// Connect and retrieve agent capabilities +client.connect() +val agentCard = client.cachedAgentCard() + +println("Connected to: ${agentCard.name}") +println("Supports streaming: ${agentCard.capabilities.streaming}") +``` + +### 3. Send Messages + +Send a message to the agent and receive a single response. +The response can be either the message if the agent responded directly, or a task event if the agent is performing a +task. + +```kotlin +val message = Message( + messageId = UUID.randomUUID().toString(), + role = Role.User, + parts = listOf(TextPart("Hello, agent!")), + contextId = "conversation-1" +) + +val request = Request(data = MessageSendParams(message)) +val response = client.sendMessage(request) + +// Handle response +when (val event = response.data) { + is Message -> { + val text = event.parts + .filterIsInstance() + .joinToString { it.text } + print(text) // Stream partial responses + } + is TaskEvent -> { + if (event.final) { + println("\nTask completed") + } + } +} +``` + +### 4. Send Messages Streaming + +The A2A client supports streaming responses for real-time communication. +Instead of receiving a single response, it returns a `Flow` of events including messages and task updates. + +```kotlin +// Check if agent supports streaming +if (client.cachedAgentCard()?.capabilities?.streaming == true) { + client.sendMessageStreaming(request).collect { response -> + when (val event = response.data) { + is Message -> { + val text = event.parts + .filterIsInstance() + .joinToString { it.text } + print(text) // Stream partial responses + } + is TaskStatusUpdateEvent -> { + if (event.final) { + println("\nTask completed") + } + } + } + } +} else { + // Fallback to non-streaming + val response = client.sendMessage(request) + // Handle single response +} +``` + +### 5. Manage Tasks + +A2A Client provides methods to control server tasks by asking for their status and cancelling them. + +```kotlin +// Query task status +val taskRequest = Request(data = TaskQueryParams(taskId = "task-123")) +val taskResponse = client.getTask(taskRequest) +val task = taskResponse.data + +println("Task state: ${task.status.state}") + +// Cancel running task +if (task.status.state == TaskState.Working) { + val cancelRequest = Request(data = TaskIdParams(taskId = "task-123")) + val cancelledTask = client.cancelTask(cancelRequest).data + println("Task cancelled: ${cancelledTask.status.state}") +} +``` diff --git a/docs/docs/a2a-koog-integration.md b/docs/docs/a2a-koog-integration.md new file mode 100644 index 0000000000..89932902f6 --- /dev/null +++ b/docs/docs/a2a-koog-integration.md @@ -0,0 +1,319 @@ +# A2A and Koog Integration + +Koog provides seamless integration with the A2A protocol, allowing you to expose Koog agents as A2A servers and connect +Koog agents to other A2A-compliant agents. + +## Overview + +The integration enables two main patterns: + +1. **Expose Koog agents as A2A servers** - Make your Koog agents discoverable and accessible via the A2A protocol +2. **Connect Koog agents to A2A agents** - Let your Koog agents communicate with other A2A-compliant agents + +## Exposing Koog Agents as A2A Servers + +### Define Koog Agent with A2A feature + +Let's define a Koog agent first. The logic of the agent can vary, but here's an example basic single run agent with +tools. +The agent resaves a message from the user, forwards it to the llm. +If the llm response contains a tool call, the agent executes the tool and forwards the result to the llm. +If the llm response contains an assistant message, the agent sends the assistant message to the user and finishes. + +On input resize, the agent sends a task submitted event to the A2A client with the input message. +On each tool call, the agent sends a task working event to the A2A client with the tool call and result. +On assistant message, the agent sends a task complete event to the A2A client with the assistant message. + +```kotlin +/** + * Create a Koog agent with A2A feature + */ +@OptIn(ExperimentalUuidApi::class) +private fun createAgent( + context: RequestContext, + eventProcessor: SessionEventProcessor, +) = AIAgent( + promptExecutor = MultiLLMPromptExecutor( + LLMProvider.Google to GoogleLLMClient("api-key") + ), + toolRegistry = ToolRegistry { + // Declare tools here + }, + strategy = strategy("test") { + val nodeSetup by node { inputMessage -> + // Convenience function to transform A2A message into Koog message + val input = inputMessage.toKoogMessage() + llm.writeSession { + updatePrompt { + message(input) + } + } + // Send update event to A2A client + withA2AAgentServer { + sendTaskUpdate("Request submitted: ${input.content}", TaskState.Submitted) + } + } + + // Calling llm + val nodeLLMRequest by node { + llm.writeSession { + requestLLM() + } + } + + // Executing tool + val nodeProcessTool by node { toolCall -> + withA2AAgentServer { + sendTaskUpdate("Executing tool: ${toolCall.content}", TaskState.Working) + } + + val toolResult = environment.executeTool(toolCall) + + llm.writeSession { + updatePrompt { + tool { + result(toolResult) + } + } + } + withA2AAgentServer { + sendTaskUpdate("Tool result: ${toolResult.content}", TaskState.Working) + } + } + + // Sending assistant message + val nodeProcessAssistant by node { assistantMessage -> + withA2AAgentServer { + sendTaskUpdate(assistantMessage, TaskState.Completed) + } + } + + edge(nodeStart forwardTo nodeSetup) + edge(nodeSetup forwardTo nodeLLMRequest) + + // If a tool call is returned from llm, forward to the tool processing node and then back to llm + edge(nodeLLMRequest forwardTo nodeProcessTool onToolCall { true }) + edge(nodeProcessTool forwardTo nodeLLMRequest) + + // If an assistant message is returned from llm, forward to the assistant processing node and then to finish + edge(nodeLLMRequest forwardTo nodeProcessAssistant onAssistantMessage { true }) + edge(nodeProcessAssistant forwardTo nodeFinish) + }, + agentConfig = AIAgentConfig( + prompt = prompt("agent") { system("You are a helpful assistant.") }, + model = GoogleModels.Gemini2_5Pro, + maxAgentIterations = 10 + ), +) { + install(A2AAgentServer) { + this.context = context + this.eventProcessor = eventProcessor + } +} + +/** + * Convenience function to send task update event to A2A client + * @param content The message content + * @param state The task state + */ +@OptIn(ExperimentalUuidApi::class) +private suspend fun A2AAgentServer.sendTaskUpdate( + content: String, + state: TaskState, +) { + val message = A2AMessage( + messageId = Uuid.random().toString(), + role = Role.Agent, + parts = listOf( + TextPart(content) + ), + contextId = context.contextId, + taskId = context.taskId, + ) + + val task = Task( + id = context.taskId, + contextId = context.contextId, + status = TaskStatus( + state = state, + message = message, + timestamp = Clock.System.now(), + ) + ) + eventProcessor.sendTaskEvent(task) +} +``` + +## A2AAgentServer Feature Mechanism + +The `A2AAgentServer` is a Koog agent feature that enables seamless integration between Koog agents and the A2A protocol. +The `A2AAgentServer` feature provides access to the `RequestContext` and `SessionEventProcessor` entities, which are used to +communicate with the A2A client inside the Koog agent. + +To install the feature, call the `install` function on the agent and pass the `A2AAgentServer` feature along with the `RequestContext` and `SessionEventProcessor`: +```kotlin +// Install the feature +agent.install(A2AAgentServer) { + this.context = context + this.eventProcessor = eventProcessor +} +``` + +To access these entities from Koog agent strategy, the feature provides a `withA2AAgentServer` function that allows agent nodes to access A2A server capabilities within their execution context. +It retrieves the installed `A2AAgentServer` feature and provides it as the receiver for the action block. + +```kotlin +// Usage within agent nodes +withA2AAgentServer { + // 'this' is now A2AAgentServer instance + sendTaskUpdate("Processing your request...", TaskState.Working) +} +``` + +### Start A2A Server +After running the server Koog agent will be discoverable and accessible via the A2A protocol. + +```kotlin +val agentCard = AgentCard( + name = "Koog Agent", + url = "http://localhost:9999/koog", + description = "Simple universal agent powered by Koog", + version = "1.0.0", + protocolVersion = "0.3.0", + preferredTransport = TransportProtocol.JSONRPC, + capabilities = AgentCapabilities(streaming = true), + defaultInputModes = listOf("text"), + defaultOutputModes = listOf("text"), + skills = listOf( + AgentSkill( + id = "koog", + name = "Koog Agent", + description = "Universal agent powered by Koog. Supports tool calling.", + tags = listOf("chat", "tool"), + ) + ) +) +// Server setup +val server = A2AServer(agentExecutor = KoogAgentExecutor(), agentCard = agentCard) +val transport = HttpJSONRPCServerTransport(server) +transport.start(engineFactory = CIO, port = 8080, path = "/chat", wait = true) +``` + +## Connecting Koog Agents to A2A Agents + +### Create A2A Client and connect to the A2A Server + +```kotlin +val transport = HttpJSONRPCClientTransport(url = "http://localhost:9999/koog") +val agentCardResolver = + UrlAgentCardResolver(baseUrl = "http://localhost:9999", path = "/koog") +val client = A2AClient(transport = transport, agentCardResolver = agentCardResolver) + +val agentId = "koog" +client.connect() +``` + +### Create Koog Agent and add A2A Client to A2AAgentClient Feature +To connect to A2A agent from your Koog Agent, you can use the A2AAgentClient feature, which provides a client API for connecting to A2A agents. +The principle of the client is the same as the server: you install the feature and pass the `A2AAgentClient` feature along with the `RequestContext` and `SessionEventProcessor`. + +```kotlin +val agent = AIAgent( + promptExecutor = MultiLLMPromptExecutor( + LLMProvider.Google to GoogleLLMClient("api-key") + ), + toolRegistry = ToolRegistry { + // declare tools here + }, + strategy = strategy("test") { + + val nodeCheckStreaming by nodeA2AClientGetAgentCard().transform { it.capabilities.streaming } + + val nodeA2ASendMessageStreaming by nodeA2AClientSendMessageStreaming() + val nodeA2ASendMessage by nodeA2AClientSendMessage() + + val nodeProcessStreaming by node>, Unit> { + it.collect { response -> + when (response.data) { + is Task -> { + // Process task + } + + is A2AMessage -> { + // Process message + } + + is TaskStatusUpdateEvent -> { + // Process task status update + } + + is TaskArtifactUpdateEvent -> { + // Process task artifact update + } + } + } + } + + val nodeProcessEvent by node { event -> + when (event) { + is Task -> { + // Process task + } + + is A2AMessage -> { + // Process message + } + } + } + + // If streaming is supported, send a message, process response and finish + edge(nodeStart forwardTo nodeCheckStreaming transformed { agentId }) + edge( + nodeCheckStreaming forwardTo nodeA2ASendMessageStreaming + onCondition { it == true } transformed { buildA2ARequest(agentId) } + ) + edge(nodeA2ASendMessageStreaming forwardTo nodeProcessStreaming) + edge(nodeProcessStreaming forwardTo nodeFinish) + + // If streaming is not supported, send a message, process response and finish + edge( + nodeCheckStreaming forwardTo nodeA2ASendMessage + onCondition { it == false } transformed { buildA2ARequest(agentId) } + ) + edge(nodeA2ASendMessage forwardTo nodeProcessEvent) + edge(nodeProcessEvent forwardTo nodeFinish) + + // If streaming is not supported, send a message, process response and finish + edge(nodeCheckStreaming forwardTo nodeFinish onCondition { it == null } + transformed { println("Failed to get agents card") } + ) + + }, + agentConfig = AIAgentConfig( + prompt = prompt("agent") { system("You are a helpful assistant.") }, + model = GoogleModels.Gemini2_5Pro, + maxAgentIterations = 10 + ), +) { + install(A2AAgentClient) { + this.a2aClients = mapOf(agentId to client) + } +} + + +@OptIn(ExperimentalUuidApi::class) +private fun AIAgentGraphContextBase.buildA2ARequest(agentId: String): A2AClientRequest = + A2AClientRequest( + agentId = agentId, + callContext = ClientCallContext.Default, + params = MessageSendParams( + message = A2AMessage( + messageId = Uuid.random().toString(), + role = Role.User, + parts = listOf( + TextPart(agentInput as String) + ) + ) + ) + ) +``` diff --git a/docs/docs/a2a-protocol-overview.md b/docs/docs/a2a-protocol-overview.md new file mode 100644 index 0000000000..947b25ffae --- /dev/null +++ b/docs/docs/a2a-protocol-overview.md @@ -0,0 +1,16 @@ +# A2A protocol + +This page provides an overview of the A2A (Agent-to-Agent) protocol implementation in the Koog agentic framework. + +## What is the A2A protocol? + +The A2A (Agent-to-Agent) protocol is a standardized communication protocol that enables AI agents to interact with each other and with client applications. +It defines a set of methods, message formats, and behaviors that allow for consistent and interoperable agent communication. +For more information and a detailed specification of the A2A protocol, see the official [A2A Protocol website](https://a2a-protocol.org/latest/). + +## Key A2A components + +Koog provides full implementation of A2A protocol v0.3.0 for both client and server, as well as integration with the Koog agent framework: +- [A2A Server](a2a-server.md) is an agent or agentic system that exposes an endpoint implementing the A2A protocol. It receives requests from clients, processes tasks, and returns results or status updates. It can also be used independently of Koog agents. +- [A2A Client](a2a-client.md) is a client application or agent that initiates communication with an A2A server using the A2A protocol. It can also be used independently of Koog agents. +- [A2A Koog Integration](a2a-koog-integration.md) is a set of classes and utilities that simplify the integration of A2A with Koog Agents. It contains components (A2A features and nodes) for seamless A2A agent connections and communication within the Koog framework. diff --git a/docs/docs/a2a-server.md b/docs/docs/a2a-server.md new file mode 100644 index 0000000000..7e8ad7d7c5 --- /dev/null +++ b/docs/docs/a2a-server.md @@ -0,0 +1,351 @@ +# A2A Server + +The A2A server enables you to expose AI agents through the standardized A2A (Agent-to-Agent) protocol. It provides a complete implementation of the [A2A protocol specification](https://a2a-protocol.org/latest/specification/), handling client requests, executing agent logic, managing complex task lifecycles, and supporting real-time streaming responses. + +## Overview + +The A2A server acts as a bridge between the A2A protocol transport layer and your custom agent logic. +It orchestrates the entire request lifecycle while maintaining protocol compliance and providing robust session management. + +## Core components + +### A2AServer + +The main server class implementing the complete A2A protocol. It serves as the central coordinator that: + +- **Validates** incoming requests against protocol specifications +- **Manages** concurrent sessions and task lifecycles +- **Orchestrates** communication between transport, storage, and business logic layers +- **Handles** all protocol operations: message sending, task querying, cancellation, push notifications + +The `A2AServer` accepts two required parameters: +* `AgentExecutor` which defines business logic implementation of the agent +* `AgentCard` which defines agent capabilities and metadata + +And a number of optional parameters that can be used to customize its storage and transport behavior. + +### AgentExecutor + +The `AgentExecutor` interface is where you implement your agent's core business logic. +It acts as the bridge between the A2A protocol and your specific AI agent capabilities. +To start the execution of your agent, you must implement the `execute` method where define your agent's logic. +To cancel the agent, you must implement the `cancel` method. + +```kotlin +class MyAgentExecutor : AgentExecutor { + override suspend fun execute( + context: RequestContext, + eventProcessor: SessionEventProcessor + ) { + // Agent logic here + } + + override suspend fun cancel( + context: RequestContext, + eventProcessor: SessionEventProcessor, + agentJob: Deferred? + ) { + // Cancel agent here, optional + } +} +``` + +The `RequestContext` provides rich information about the current request, +including the `contextId` and `taskId` of the current session, the `message` sent, and the `params` of the request. + +The `SessionEventProcessor` communicates with clients: +- **`sendMessage(message)`**: Send immediate responses (chat-style interactions) +- **`sendTaskEvent(event)`**: Send task-related updates (long-running operations) + +```kotlin +// For immediate responses (like chatbots) +eventProcessor.sendMessage( + Message( + messageId = generateId(), + role = Role.Agent, + parts = listOf(TextPart("Here's your answer!")), + contextId = context.contextId + ) +) + +// For task-based operations +eventProcessor.sendTaskEvent( + TaskStatusUpdateEvent( + contextId = context.contextId, + taskId = context.taskId, + status = TaskStatus( + state = TaskState.Working, + message = Message(/* progress update */), + timestamp = Clock.System.now() + ), + final = false // More updates to come + ) +) +``` + +### AgentCard + +The `AgentCard` serves as your agent's self-describing manifest. It tells clients what your agent can do, how to communicate with it, and what security requirements it has. + +```kotlin +val agentCard = AgentCard( + // Basic Identity + name = "Advanced Recipe Assistant", + description = "AI agent specialized in cooking advice, recipe generation, and meal planning", + version = "2.1.0", + protocolVersion = "0.3.0", + + // Communication Settings + url = "https://api.example.com/a2a", + preferredTransport = TransportProtocol.JSONRPC, + + // Optional: Multiple transport support + additionalInterfaces = listOf( + AgentInterface("https://api.example.com/a2a", TransportProtocol.JSONRPC), + ), + + // Capabilities Declaration + capabilities = AgentCapabilities( + streaming = true, // Support real-time responses + pushNotifications = true, // Send async notifications + stateTransitionHistory = true // Maintain task history + ), + + // Content Type Support + defaultInputModes = listOf("text/plain", "text/markdown", "image/jpeg"), + defaultOutputModes = listOf("text/plain", "text/markdown", "application/json"), + + // Define available security schemes + securitySchemes = mapOf( + "bearer" to HTTPAuthSecurityScheme( + scheme = "Bearer", + bearerFormat = "JWT", + description = "JWT token authentication" + ), + "api-key" to APIKeySecurityScheme( + `in` = In.Header, + name = "X-API-Key", + description = "API key for service authentication" + ) + ), + + // Specify security requirements (logical OR of requirements) + security = listOf( + mapOf("bearer" to listOf("read", "write")), // Option 1: JWT with read/write scopes + mapOf("api-key" to emptyList()) // Option 2: API key + ), + + // Enable extended card for authenticated users + supportsAuthenticatedExtendedCard = true, + + // Skills/Capabilities + skills = listOf( + AgentSkill( + id = "recipe-generation", + name = "Recipe Generation", + description = "Generate custom recipes based on ingredients, dietary restrictions, and preferences", + tags = listOf("cooking", "recipes", "nutrition"), + examples = listOf( + "Create a vegan pasta recipe with mushrooms", + "I have chicken, rice, and vegetables. What can I make?" + ) + ), + AgentSkill( + id = "meal-planning", + name = "Meal Planning", + description = "Plan weekly meals and generate shopping lists", + tags = listOf("meal-planning", "nutrition", "shopping") + ) + ), + + // Optional: Branding + iconUrl = "https://example.com/agent-icon.png", + documentationUrl = "https://docs.example.com/recipe-agent", + provider = AgentProvider( + organization = "CookingAI Inc.", + url = "https://cookingai.com" + ) +) +``` + +### Transport Layer + +The A2A itself supports multiple transport protocols for communicating with clients. +Currently, Koog provides implementations for JSON-RPC server transport over HTTP. + +#### HTTP JSON-RPC Transport + +```kotlin +val transport = HttpJSONRPCServerTransport(server) +transport.start( + engineFactory = CIO, // Ktor engine (CIO, Netty, Jetty) + port = 8080, // Server port + path = "/a2a", // API endpoint path + wait = true // Block until server stops +) +``` + +### Storage + +The A2A server uses a pluggable storage architecture that separates different types of data. +All storage implementations are optional and default to in-memory variants for development. + +- **TaskStorage**: Task lifecycle management - stores and manages task states, history, and artifacts +- **MessageStorage**: Conversation history - manages message history within conversation contexts +- **PushNotificationConfigStorage**: Webhook management - manages webhook configurations for asynchronous notifications + +## Quickstart + +### 1. Create AgentCard +Define your agent's capabilities and metadata. +```kotlin +val agentCard = AgentCard( + name = "IO Assistant", + description = "AI agent specialized in input modification", + version = "2.1.0", + protocolVersion = "0.3.0", + + // Communication Settings + url = "https://api.example.com/a2a", + preferredTransport = TransportProtocol.JSONRPC, + + // Capabilities Declaration + capabilities = + AgentCapabilities( + streaming = true, // Support real-time responses + pushNotifications = true, // Send async notifications + stateTransitionHistory = true // Maintain task history + ), + + // Content Type Support + defaultInputModes = listOf("text/plain", "text/markdown", "image/jpeg"), + defaultOutputModes = listOf("text/plain", "text/markdown", "application/json"), + + // Skills/Capabilities + skills = listOf( + AgentSkill( + id = "echo", + name = "echo", + description = "Echoes back user messages", + tags = listOf("io"), + ) + ) +) +``` + + +### 2. Create an AgentExecutor +In executor manages implement agent logic, handles incoming requests and sends responses. + +```kotlin +class EchoAgentExecutor : AgentExecutor { + override suspend fun execute( + context: RequestContext, + eventProcessor: SessionEventProcessor + ) { + val userMessage = context.params.message + val userText = userMessage.parts + .filterIsInstance() + .joinToString(" ") { it.text } + + // Echo the user's message back + val response = Message( + messageId = UUID.randomUUID().toString(), + role = Role.Agent, + parts = listOf(TextPart("You said: $userText")), + contextId = context.contextId, + taskId = context.taskId + ) + + eventProcessor.sendMessage(response) + } +} +``` + +### 2. Create the Server +Pass the agent executor and agent card to the server. + +```kotlin +val server = A2AServer( + agentExecutor = EchoAgentExecutor(), + agentCard = agentCard +) +``` + +### 3. Add Transport Layer +Create a transport layer and start the server. +```kotlin +// HTTP JSON-RPC transport +val transport = HttpJSONRPCServerTransport(server) +transport.start( + engineFactory = CIO, + port = 8080, + path = "/agent", + wait = true +) +``` + +## Agent Implementation Patterns + +### Simple Response Agent +If your agent only needs to respond to a single message, you can implement it as a simple agent. +It can be also used if agent execution logic is not complex and time-consuming. + +```kotlin +class SimpleAgentExecutor : AgentExecutor { + override suspend fun execute( + context: RequestContext, + eventProcessor: SessionEventProcessor + ) { + val response = Message( + messageId = UUID.randomUUID().toString(), + role = Role.Agent, + parts = listOf(TextPart("Hello from agent!")), + contextId = context.contextId, + taskId = context.taskId + ) + + eventProcessor.sendMessage(response) + } +} +``` + +### Task-Based Agent +If the execution logic of your agent is complex and requires multiple steps, you can implement it as a task-based agent. +It can be also used if agent execution logic is time-consuming and suspending. +```kotlin +class TaskAgentExecutor : AgentExecutor { + override suspend fun execute( + context: RequestContext, + eventProcessor: SessionEventProcessor + ) { + // Send working status + eventProcessor.sendTaskEvent( + TaskStatusUpdateEvent( + contextId = context.contextId, + taskId = context.taskId, + status = TaskStatus( + state = TaskState.Working, + timestamp = Clock.System.now() + ), + final = false + ) + ) + + // Do work... + + // Send completion + eventProcessor.sendTaskEvent( + TaskStatusUpdateEvent( + contextId = context.contextId, + taskId = context.taskId, + status = TaskStatus( + state = TaskState.Completed, + timestamp = Clock.System.now() + ), + final = true + ) + ) + } +} +``` diff --git a/docs/mkdocs.yml b/docs/mkdocs.yml index 91fe5a986d..3ad90dbe22 100644 --- a/docs/mkdocs.yml +++ b/docs/mkdocs.yml @@ -23,6 +23,11 @@ nav: - Data transfer between nodes: data-transfer-between-nodes.md - History compression: history-compression.md - Model Context Protocol: model-context-protocol.md + - A2A Protocol: + - Overview: a2a-protocol-overview.md + - A2A server implementation: a2a-server.md + - A2A client implementation: a2a-client.md + - A2A and Koog integration: a2a-koog-integration.md - Model capabilities: model-capabilities.md - Content moderation: content-moderation.md - Backend framework integrations: @@ -119,6 +124,11 @@ plugins: - history-compression.md: This page provides details about different implementations and different types of history compression strategies to reduce history size and token usage. Model Context Protocol: - model-context-protocol.md: This page provides details about Koog integration with MCP servers, which allows you to incorporate MCP tools into your Koog agents. + A2A Protocol: + - a2a-protocol-overview.md: This page provides an overview of the A2A (Agent-to-Agent) protocol implementation in the Koog agentic framework. + - a2a-server.md: This page provides details about the A2A server implementation in the Koog agentic framework. The A2A server is responsible for handling requests from A2A clients according to the A2A protocol specification. + - a2a-client.md: This page provides details about the A2A client implementation in the Koog agentic framework. The A2A client is responsible for sending requests to A2A servers according to the A2A protocol specification. + - a2a-koog-integration.md: This page provides details about the integration of A2A with Koog agents. It includes information about how to define your agent as A2A server and connect it to the Koog agent system. Model capabilities: - model-capabilities.md: This page provides an overview of different features or parameters that the models support, referred to as model capabilities. Content moderation: From 6a600781b8de1931b0ac0a685341a670eb32bb5e Mon Sep 17 00:00:00 2001 From: Briliantov Vadim Date: Thu, 2 Oct 2025 15:44:54 +0200 Subject: [PATCH 52/52] 0.5.0 release with change log (#914) 0.5.0 release with change log --- CHANGELOG.md | 195 ++++++++++++++++++++++++++++++++++++++++------- README.md | 6 +- build.gradle.kts | 2 +- 3 files changed, 171 insertions(+), 32 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index e2bc628ebe..0e5fb59525 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,121 @@ +# 0.5.0 + +> Published 2 Oct 2025 + +## Major Features + +- **Full Agent-to-Agent (A2A) Protocol Support**: + - **Multiplatform Kotlin A2A SDK**: Including server and client with JSON-RPC HTTP support. + - **A2A Agent Feature**: seamlessly integrate A2A in your Koog agents +- **Non-Graph API for Strategies**: Introduced non-graph API for creating AI Agent strategies as Kotlin extension + functions with most of Koog's features supported (#560) +- **Agent Persistence and Checkpointing**: + - **Roll back Tool Side-Effects**: Add `RollbackToolRegistry` in the `Persistence` feature in order to roll back + tool calls with side effects when checkpointing. + - **State-Machine Persistence / Message History Switch**: Support switching between full state-machine persistence + and message history persistence (#856) +- **Tool API Improvements**: + - Make `ToolDescriptor` auto-generated for class-based tools (#791) + - Get rid of `ToolArgs` and `ToolResult` limitations for `Tool<*, *>` class (#791) +- **`subgraphWithTask` Simplification**: Get rid of required `finishTool` and support tools as functions in + `subgraphWithTask`, deduce final step automatically by data class (#791) +- **`AIAgentService` Introduced**: Make `AIAgent` state-manageable and single-run explicitly, introduce `AIAgentService` + to manage multiple uniform running agents. +- **New components**: + - Add LLM as a Judge component (#866) + - Tool Calling loop with Structured Output strategy (#829) + +## Improvements + +- Make Koog-based tools exportable via MCP server (KG-388) +- Add `additionalProperties` to LLM clients in order to support custom LLM configurations (#836) +- Allow adjusting context window sizes for Ollama dynamically (#883) +- Refactor streaming api to support tool calls (#747) +- Provide an ability to collect and send a list of nodes and edges out of `AIAgentStrategy` to the client when running + an agent (KG-160) +- Add `excludedProperties` to inline `createJsonStructure` too, update KDocs (#826) +- Refactor binary attachment handling and introduce Base64 serializer (#838) +- In `JsonStructuredData.defaultJson` instance rename class discriminator + from `#type` to `kind` to align with common practices (#772, KG-384) +- Make standard json generator default when creating `JsonStructuredData` + (it was basic before) (#772, KG-384) +- Add default audio configuration and modalities (#817) +- Add `GptAudio` model in OpenAI client (#818) +- Allow re-running of finished agents that have `Persistence` feature installed (#828, KG-193) +- Allow ideomatic node transformations with `.transform { ...}` lambda function (#684) +- Add ability to filter messages for every agent feature (KG-376) +- Add support for trace-level attributes in Langfuse integration (#860, KG-427) +- Keep all system messages when compressing message history of the agent(#857) +- Add support for Anthropic's Sonnet 4.5 model in Anthropic/Bedrock providers (#885) +- Refactored LLM client auto-configuration in Spring Boot integration, to modular provider-specific classes with + improved validation and security (#886) +- Add LLM Streaming agent events (KG-148) + +## Bug Fixes + +- Fix broken Anthropic models support via Amazon Bedrock (#789) +- Make `AIAgentStorageKey` in agent storage actually unique by removing `data` modifier (#825) +- Fix rerun for agents with Persistence (#828, KG-193) +- Update mcp version to `0.7.2` with fix for Android target (#835) +- Do not include an empty system message in Anthropic request (#887, KG-317) +- Use `maxTokens` from params in Google models (#734) +- Fix finishReason nullability (#771) + +## Deprecations + +- Rename agent interceptors in `EventHandler` and related feature events (KG-376) +- Deprecate concurrent unsafe `AIAgent.asTool` in favor of `AIAgentService.createAgentTool` (#873) +- Rename `Persistency` to `Persistence` everywhere (#896) +- Add `agentId` argument to all `Persistence` methods instead of `persistencyId` class field (#904) + +## Examples + +- Add a basic code-agent example (#808, KG-227) +- Add iOS and Web targets for demo-compose-app (#779, #780) + +# 0.4.2 + +> Published 15 Sep 2025 + +## Improvements + +- Make agents‑mcp support KMP targets to run across more platforms (#756). +- Add LLM client retry support to Spring Boot auto‑configuration to improve resilience on transient failures (#748). +- Add Claude Opus 4.1 model support to Anthropic client to unlock latest reasoning capabilities (#730). +- Add Gemini 2.5 Flash Lite model support to Google client to enable lower‑latency, cost‑efficient generations (#769). +- Add Java‑compatible non‑streaming Prompt Executor so Java apps can call Koog without + coroutines ([KG-312](https://youtrack.jetbrains.com/issue/KG-312), #715). +- Support excluding properties in JSON Schema generation to fine‑tune structured outputs (#638). +- Update AWS SDK to latest compatible version for Bedrock integrations. +- Introduce Postgres persistence provider to store agent state and artifacts (#705). +- Update Kotlin to 2.2.10 in dependency configuration for improved performance and language features (#764). +- Refactor executeStreaming to remove suspend for simpler interop and better call sites (#720). +- Add Java‑compatible prompt executor (non‑streaming) wiring and polish across + modules ([KG-312](https://youtrack.jetbrains.com/issue/KG-312), #715). +- Decouple FileSystemEntry from FileSystemProvider to simplify testing and enable alternative providers (#664). + +## Bug Fixes + +- Add missing tool calling support for Bedrock Nova models so agents can invoke functions when using + Nova ([KG-239](https://youtrack.jetbrains.com/issue/KG-239)). +- Add Android target support and migrate Android app to Kotlin Multiplatform to widen KMP + coverage ([KG-315](https://youtrack.jetbrains.com/issue/KG-315), #728, #767). +- Add Spring Boot Java example to jump‑start integration (#739). +- Add Java Spring auto‑config fixes: correct property binding and make Koog starter work out of the box (#698). +- Fix split package issues in OpenAI LLM clients to avoid classpath/load + errors ([KG-305](https://youtrack.jetbrains.com/issue/KG-305), #694). +- Ensure Anthropic tool schemas include the required "type" field in serialized request bodies to prevent validation + errors during tool calling (#582). +- Fix AbstractOpenAILLMClient to correctly handle plain‑text responses in capabilities flow; add integration tests to + prevent regressions (#564). +- Fix GraalVM native image build failure so projects can compile native binaries again (#774). +- Fix usages in OpenAI‑based data model to align with recent API changes (#688). +- Fix SpringBootStarters initialization and improve `RetryingClient` (#894) + +## CI and Build + +- Nightly build configuration and dependency submission workflow added (#695, #737). + # 0.4.1 > Published 28 Aug 2025 @@ -13,42 +131,58 @@ Fixed iOS target publication ## Major Features - **Integration with Observability Tools**: - - **Langfuse Integration**: Span adapters for Langfuse client, including open telemetry and graph visualisation ([KG-217](https://youtrack.jetbrains.com/issue/KG-217), [KG-223](https://youtrack.jetbrains.com/issue/KG-223)) - - **W&B Weave Integration**: Span adapters for W&B Weave open telemetry and observability ([KG-217](https://youtrack.jetbrains.com/issue/KG-217), [KG-218](https://youtrack.jetbrains.com/issue/KG-218)) -- **Ktor Integration**: First-class Ktor support via the "Koog" Ktor plugin to register and run agents in Ktor applications (#422). -- **iOS Target Support**: Multiplatform expanded with native iOS targets, enabling agents to run on Apple platforms (#512). -- **Upgraded Structured Output**: Refactored structured output API to be more flexible and add built-in/native provider support for OpenAI and Google, reducing prompt boilerplate and improving validation (#443). -- **GPT5 and Custom LLM Parameters Support**: Now GPT5 is available together with custom additional LLM parameters for OpenAI-compatible clients (#631, #517) + - **Langfuse Integration**: Span adapters for Langfuse client, including open telemetry and graph + visualisation ([KG-217](https://youtrack.jetbrains.com/issue/KG-217), [KG-223](https://youtrack.jetbrains.com/issue/KG-223)) + - **W&B Weave Integration**: Span adapters for W&B Weave open telemetry and + observability ([KG-217](https://youtrack.jetbrains.com/issue/KG-217), [KG-218](https://youtrack.jetbrains.com/issue/KG-218)) +- **Ktor Integration**: First-class Ktor support via the "Koog" Ktor plugin to register and run agents in Ktor + applications (#422). +- **iOS Target Support**: Multiplatform expanded with native iOS targets, enabling agents to run on Apple platforms ( + #512). +- **Upgraded Structured Output**: Refactored structured output API to be more flexible and add built-in/native provider + support for OpenAI and Google, reducing prompt boilerplate and improving validation (#443). +- **GPT5 and Custom LLM Parameters Support**: Now GPT5 is available together with custom additional LLM parameters for + OpenAI-compatible clients (#631, #517) - **Resilience and Retries**: - - **Retryable LLM Clients**: Introduce retry logic for LLM clients with sensible defaults to reduce transient failures (#592) - - **Retry Anything with LLM Feedback**: Add a feedback mechanism to the retry component (`subgraphWithRetry`) to observe and tune behavior (#459). + - **Retryable LLM Clients**: Introduce retry logic for LLM clients with sensible defaults to reduce transient + failures (#592) + - **Retry Anything with LLM Feedback**: Add a feedback mechanism to the retry component (`subgraphWithRetry`) to + observe and tune behavior (#459). ## Improvements - **OpenTelemetry and Observability**: - - Finish reason and unified attributes for inference/tool/message spans and events; extract event body fields to attributes for better querying ([KG-218](https://youtrack.jetbrains.com/issue/KG-218)). - - Mask sensitive data in events/attributes and introduce a “hidden-by-default” string type to keep secrets safe in logs ([KG-259](https://youtrack.jetbrains.com/issue/KG-259)). - - Include all messages into the inference span and add an index for ChoiceEvent to simplify analysis ([KG-172](https://youtrack.jetbrains.com/issue/KG-172)). - - Add tool arguments to `gen_ai.choice` and `gen_ai.assistant.message` events (#462). - - Allow setting a custom OpenTelemetry SDK instance in Koog ([KG-169](https://youtrack.jetbrains.com/issue/KG-169)). + - Finish reason and unified attributes for inference/tool/message spans and events; extract event body fields to + attributes for better querying ([KG-218](https://youtrack.jetbrains.com/issue/KG-218)). + - Mask sensitive data in events/attributes and introduce a “hidden-by-default” string type to keep secrets safe in + logs ([KG-259](https://youtrack.jetbrains.com/issue/KG-259)). + - Include all messages into the inference span and add an index for ChoiceEvent to simplify + analysis ([KG-172](https://youtrack.jetbrains.com/issue/KG-172)). + - Add tool arguments to `gen_ai.choice` and `gen_ai.assistant.message` events (#462). + - Allow setting a custom OpenTelemetry SDK instance in Koog ([KG-169](https://youtrack.jetbrains.com/issue/KG-169)). - **LLM and Providers**: - - Support Google’s “thinking” mode in generation config to improve reasoning quality (#414). - - Add responses API support for OpenAI (#645) - - AWS Bedrock: support Inference Profiles for simpler, consistent configuration (#506) and accept `AWS_SESSION_TOKEN` (#456). - - Add `maxTokens` as prompt parameters for finer control over generation length (#579). - - Add `contextLength` and `maxOutputTokens` to `LLModel` (#438, [KG-134](https://youtrack.jetbrains.com/issue/KG-134)) + - Support Google’s “thinking” mode in generation config to improve reasoning quality (#414). + - Add responses API support for OpenAI (#645) + - AWS Bedrock: support Inference Profiles for simpler, consistent configuration (#506) and accept + `AWS_SESSION_TOKEN` (#456). + - Add `maxTokens` as prompt parameters for finer control over generation length (#579). + - Add `contextLength` and `maxOutputTokens` to `LLModel` ( + #438, [KG-134](https://youtrack.jetbrains.com/issue/KG-134)) - **Agent Engine**: - - Add AIAgentPipeline interceptors to uniformly handle node errors; propagate `NodeExecutionError` across features ([KG-170](https://youtrack.jetbrains.com/issue/KG-170)). - - Include finish node processing in the pipeline to ensure finalizers run reliably (#598). + - Add AIAgentPipeline interceptors to uniformly handle node errors; propagate `NodeExecutionError` across + features ([KG-170](https://youtrack.jetbrains.com/issue/KG-170)). + - Include finish node processing in the pipeline to ensure finalizers run reliably (#598). - **File Tools and RAG**: - - Reworked FileSystemProvider with API cleanups and better ergonomics; moved blocking/suspendable operations to `Dispatchers.IO` for improved performance and responsiveness (#557, “Move suspendable operations to Dispatchers.IO”). - - Introduce `filterByRoot` helpers and allow custom path filters in `FilteredFileSystemProvider` for safer agent sandboxes (#494, #508). + - Reworked FileSystemProvider with API cleanups and better ergonomics; moved blocking/suspendable operations to + `Dispatchers.IO` for improved performance and responsiveness (#557, “Move suspendable operations to + Dispatchers.IO”). + - Introduce `filterByRoot` helpers and allow custom path filters in `FilteredFileSystemProvider` for safer agent + sandboxes (#494, #508). - Rename `PathFilter` to `TraversalFilter` and make its methods suspendable to support async checks. - Rename `fromAbsoluteString` to `fromAbsolutePathString` for clarity (#567). - Add `ReadFileTool` for reading local file contents where appropriate (#628). - Update kotlin-mcp dependency to v0.6.0 (#523) - ## Bug Fixes - Make `parts` field nullable in Google responses to handle missing content from Gemini models (#652). @@ -56,7 +190,8 @@ Fixed iOS target publication - Fix function calling for `gemini-2.5-flash` models to correctly route tool invocations (#586). - Restore OpenAI `responseFormat` option support in requests (#643). - Correct `o4-mini` vs `gpt-4o-mini` model mix-up in configuration (#573). -- Ensure event body for function calls is valid JSON for telemetry ingestion ([KG-268](https://youtrack.jetbrains.com/issue/KG-268)). +- Ensure event body for function calls is valid JSON for telemetry + ingestion ([KG-268](https://youtrack.jetbrains.com/issue/KG-268)). - Fix duplicated tool names resolution in `AIAgentSubgraphExt` to prevent conflicts (#493). - Fix Azure OpenAI client settings to generate valid endpoint URLs (#478). - Restore `llama3.2:latest` as the default for LLAMA_3_2 to match the provider expectations (#522). @@ -65,14 +200,17 @@ Fixed iOS target publication ## Removals / Breaking Changes -- Remove Google Gemini 1.5 Flash/Pro variants from the catalog ([KG-216](https://youtrack.jetbrains.com/issue/KG-216), #574). +- Remove Google Gemini 1.5 Flash/Pro variants from the catalog ([KG-216](https://youtrack.jetbrains.com/issue/KG-216), + #574). - Drop `execute` extensions for `PromptExecutor` in favor of the unified API (#591). -- File system API cleanup: removed deprecated FSProvider interfaces and methods; `PathFilter` renamed to `TraversalFilter` with suspendable operations; `fromAbsoluteString` renamed to `fromAbsolutePathString`. +- File system API cleanup: removed deprecated FSProvider interfaces and methods; `PathFilter` renamed to + `TraversalFilter` with suspendable operations; `fromAbsoluteString` renamed to `fromAbsolutePathString`. ## Examples - Add a web search agent (from Koog live stream 1) showcasing retrieval + summarization (#575). -- Add a trip planning agent example (from Koog live stream 2) demonstrating tools + planning + composite strategy (#595). +- Add a trip planning agent example (from Koog live stream 2) demonstrating tools + planning + composite strategy ( + #595). - Improve BestJokeAgent sample and fix NumberGuessingAgent example (#503, #445). # 0.3.0 @@ -86,7 +224,8 @@ Fixed iOS target publication the latest checkpoint (#305) - **Vector Document Storage**: Store embeddings and documents in persistent storage for retrieval-augmented generation ( RAG), with in-memory and local file implementations (#272) -- **OpenTelemetry Support**: Native integration with OpenTelemetry for unified tracing logs across AI agents (#369, #401, +- **OpenTelemetry Support**: Native integration with OpenTelemetry for unified tracing logs across AI agents (#369, + #401, #423, #426) - **Content Moderation**: Built-in support for moderating models, enabling AI agents to automatically review and filter outputs for safety and compliance (#395) diff --git a/README.md b/README.md index 009c63129c..845efce6b1 100644 --- a/README.md +++ b/README.md @@ -87,7 +87,7 @@ Currently, the framework supports the JVM, JS, WasmJS and iOS targets. ``` dependencies { - implementation("ai.koog:koog-agents:0.4.2") + implementation("ai.koog:koog-agents:0.5.0") } ``` 2. Make sure that you have `mavenCentral()` in the list of repositories. @@ -97,7 +97,7 @@ Currently, the framework supports the JVM, JS, WasmJS and iOS targets. ``` dependencies { - implementation 'ai.koog:koog-agents:0.4.2' + implementation 'ai.koog:koog-agents:0.5.0' } ``` 2. Make sure that you have `mavenCentral()` in the list of repositories. @@ -109,7 +109,7 @@ Currently, the framework supports the JVM, JS, WasmJS and iOS targets. ai.koog koog-agents-jvm - 0.4.2 + 0.5.0 ``` 2. Make sure that you have `mavenCentral` in the list of repositories. diff --git a/build.gradle.kts b/build.gradle.kts index 0f513eadb9..a4252f26b3 100644 --- a/build.gradle.kts +++ b/build.gradle.kts @@ -18,7 +18,7 @@ group = "ai.koog" version = run { // our version follows the semver specification - val main = "0.4.3" + val main = "0.5.0" val feat = run { val releaseBuild = !System.getenv("BRANCH_KOOG_IS_RELEASING_FROM").isNullOrBlank()