diff --git a/src/main/kotlin/org/coralprotocol/coralserver/agent/registry/RegistryAgentStringSerializer.kt b/src/main/kotlin/org/coralprotocol/coralserver/agent/registry/RegistryAgentStringSerializer.kt index 6919e045..d8712f46 100644 --- a/src/main/kotlin/org/coralprotocol/coralserver/agent/registry/RegistryAgentStringSerializer.kt +++ b/src/main/kotlin/org/coralprotocol/coralserver/agent/registry/RegistryAgentStringSerializer.kt @@ -2,21 +2,18 @@ package org.coralprotocol.coralserver.agent.registry +import dev.eav.tomlkt.* import io.ktor.client.request.* import io.ktor.client.statement.* -import io.ktor.util.* import kotlinx.coroutines.runBlocking -import kotlinx.serialization.ExperimentalSerializationApi -import kotlinx.serialization.KSerializer -import kotlinx.serialization.SerialName -import kotlinx.serialization.Serializable +import kotlinx.serialization.* import kotlinx.serialization.builtins.ListSerializer import kotlinx.serialization.descriptors.PrimitiveKind import kotlinx.serialization.descriptors.PrimitiveSerialDescriptor import kotlinx.serialization.descriptors.SerialDescriptor import kotlinx.serialization.encoding.Decoder import kotlinx.serialization.encoding.Encoder -import kotlinx.serialization.json.JsonClassDiscriminator +import kotlinx.serialization.json.* import org.coralprotocol.coralserver.agent.runtime.prototype.DEFAULT_LOOP_FOLLOWUP_PROMPT import org.coralprotocol.coralserver.agent.runtime.prototype.DEFAULT_LOOP_INITIAL_BASE_PROMPT import org.coralprotocol.coralserver.agent.runtime.prototype.DEFAULT_SYSTEM_PROMPT @@ -24,6 +21,8 @@ import org.coralprotocol.coralserver.mcp.McpResourceName import org.koin.core.component.KoinComponent import java.io.File import java.nio.charset.Charset +import kotlin.io.encoding.Base64 +import kotlin.reflect.full.findAnnotation /* NOTE: This list is used in tests, resources/constants/coral-agent.toml must be updated to include any new constants @@ -39,6 +38,7 @@ val stringReferenceConstants = buildMap { @Serializable @JsonClassDiscriminator("type") +@TomlClassDiscriminator("type") sealed interface PotentialStringReference { val base64: Boolean? @@ -76,6 +76,24 @@ sealed interface PotentialStringReference { open class RegistryAgentStringSerializer : KSerializer, KoinComponent { open val base64Default: Boolean = false + private val stringSerializer = PotentialStringReference.String.serializer() + private val fileSerializer = PotentialStringReference.File.serializer() + private val urlSerializer = PotentialStringReference.Url.serializer() + private val constantSerializer = PotentialStringReference.Constant.serializer() + + private val potentialStringSerializerDiscriminator = run { + val tomlDiscriminator = PotentialStringReference::class + .findAnnotation()?.discriminator + ?: "type" + + val jsonDiscriminator = PotentialStringReference::class + .findAnnotation()?.discriminator + ?: "type" + + require(tomlDiscriminator == jsonDiscriminator) + tomlDiscriminator + } + override val descriptor: SerialDescriptor = PrimitiveSerialDescriptor("String", PrimitiveKind.STRING) @@ -84,47 +102,139 @@ open class RegistryAgentStringSerializer : KSerializer, KoinComponent { } override fun deserialize(decoder: Decoder): String { - val context = registryAgentSerializationContext.get() - ?: return decoder.decodeString() - - return try { - val reference = decoder.decodeSerializableValue(PotentialStringReference.serializer()) - val text = when (reference) { - is PotentialStringReference.File -> { - if (!context.enableFileReferences) - throw IllegalStateException("File references are not enabled") - - val file = File(reference.path) - if (file.isAbsolute || context.agentFilePath == null) { - file.readText(Charset.forName(reference.encoding)) - } else { - context.agentFilePath.toFile().resolve(file).readText(Charset.forName(reference.encoding)) + val reference = when (decoder) { + is TomlDecoder -> { + when (val element = decoder.decodeTomlElement()) { + is TomlLiteral if element.type == TomlLiteral.Type.String -> { + PotentialStringReference.String(element.content) + } + + is TomlTable -> { + val type = element[potentialStringSerializerDiscriminator]?.asTomlLiteral()?.content + ?: throw SerializationException("Missing discriminator \"$potentialStringSerializerDiscriminator\" in string reference") + + val element = TomlTable(element.filterKeys { it != potentialStringSerializerDiscriminator }) + when (type) { + stringSerializer.descriptor.serialName -> decoder.toml.decodeFromTomlElement( + stringSerializer, + element + ) + + fileSerializer.descriptor.serialName -> decoder.toml.decodeFromTomlElement( + fileSerializer, + element + ) + + urlSerializer.descriptor.serialName -> decoder.toml.decodeFromTomlElement( + urlSerializer, + element + ) + + constantSerializer.descriptor.serialName -> decoder.toml.decodeFromTomlElement( + constantSerializer, + element + ) + + else -> { + throw SerializationException("Unknown string reference type: $type") + } + } + + } + + else -> { + throw SerializationException("Unsupported string type: ${element::class.simpleName}") } } + } + + is JsonDecoder -> { + when (val element = decoder.decodeJsonElement()) { + is JsonPrimitive if element.isString -> { + PotentialStringReference.String(element.content) + } - is PotentialStringReference.String -> reference.value - is PotentialStringReference.Url -> { - if (!context.enableUrlReferences) - throw IllegalStateException("Url references are not enabled") + is JsonObject -> { + val type = element[potentialStringSerializerDiscriminator]?.jsonPrimitive?.content + ?: throw SerializationException("Missing discriminator \"$potentialStringSerializerDiscriminator\" in string reference") + + val element = JsonObject(element.filterKeys { it != potentialStringSerializerDiscriminator }) + when (type) { + stringSerializer.descriptor.serialName -> decoder.json.decodeFromJsonElement( + stringSerializer, + element + ) + + fileSerializer.descriptor.serialName -> decoder.json.decodeFromJsonElement( + fileSerializer, + element + ) + + urlSerializer.descriptor.serialName -> decoder.json.decodeFromJsonElement( + urlSerializer, + element + ) + + constantSerializer.descriptor.serialName -> decoder.json.decodeFromJsonElement( + constantSerializer, + element + ) + + else -> { + throw SerializationException("Unknown string reference type: $type") + } + } + } - runBlocking { - context.httpClient.get(reference.url).bodyAsText(Charset.forName(reference.encoding)) + else -> { + throw SerializationException("Unsupported string type: ${element::class.simpleName}") } } + } + + else -> throw SerializationException("Unsupported decoder type: ${decoder::class.simpleName}") + } - is PotentialStringReference.Constant -> { - stringReferenceConstants[reference.name] ?: throw IllegalStateException("Constant ${reference.name} not found") + val text = when (reference) { + is PotentialStringReference.File -> { + val context = registryAgentSerializationContext.get() + ?: throw SerializationException("File references require a serialization context") + + if (!context.enableFileReferences) + throw SerializationException("File references are not enabled") + + val file = File(reference.path) + if (file.isAbsolute || context.agentFilePath == null) { + file.readText(Charset.forName(reference.encoding)) + } else { + context.agentFilePath.toFile().resolve(file).readText(Charset.forName(reference.encoding)) } } - val base64 = reference.base64 ?: base64Default - if (base64) { - text.encodeBase64() - } else { - text + is PotentialStringReference.String -> reference.value + is PotentialStringReference.Url -> { + val context = registryAgentSerializationContext.get() + ?: throw SerializationException("URL references require a serialization context") + + if (!context.enableUrlReferences) + throw SerializationException("Url references are not enabled") + + runBlocking { + context.httpClient.get(reference.url).bodyAsText(Charset.forName(reference.encoding)) + } + } + + is PotentialStringReference.Constant -> { + stringReferenceConstants[reference.name] + ?: throw SerializationException("Constant ${reference.name} not found") } - } catch (_: IllegalArgumentException) { - decoder.decodeString() + } + + val base64 = reference.base64 ?: base64Default + return if (base64) { + Base64.encode(text.encodeToByteArray()) + } else { + text } } diff --git a/src/main/kotlin/org/coralprotocol/coralserver/agent/runtime/prototype/PrototypeString.kt b/src/main/kotlin/org/coralprotocol/coralserver/agent/runtime/prototype/PrototypeString.kt index 8cbab0aa..c5417309 100644 --- a/src/main/kotlin/org/coralprotocol/coralserver/agent/runtime/prototype/PrototypeString.kt +++ b/src/main/kotlin/org/coralprotocol/coralserver/agent/runtime/prototype/PrototypeString.kt @@ -68,7 +68,10 @@ sealed class PrototypeString { for (part in parts) { when (part) { is PrototypeUrlPart.Path -> builder.appendPathSegments(part.value.resolve(agentOptions)) - is PrototypeUrlPart.QueryParameter -> builder.parameters.append(part.name, part.value.resolve(agentOptions)) + is PrototypeUrlPart.QueryParameter -> builder.parameters.append( + part.name, + part.value.resolve(agentOptions) + ) } } @@ -190,40 +193,60 @@ object PrototypeStringSerializer : KSerializer { } } - override fun deserialize(decoder: Decoder): PrototypeString { + override fun deserialize(decoder: Decoder + ): PrototypeString { + // deserialization should allow inline strings to represent as string literals and should also + // support PotentialStringReference deserialization return when (decoder) { - - // json should only support plain deserialization of discriminated option/inline subtypes is JsonDecoder -> { - val jsonObject = decoder.decodeJsonElement() as JsonObject - - when (val type = jsonObject[prototypeStringDiscriminator]?.jsonPrimitive?.content) { - inlineSerializer.descriptor.serialName -> decoder.json.decodeFromJsonElement( - inlineSerializer, - JsonObject(jsonObject.filterKeys { it != prototypeStringDiscriminator }) - ) - - optionSerializer.descriptor.serialName -> decoder.json.decodeFromJsonElement( - optionSerializer, - JsonObject(jsonObject.filterKeys { it != prototypeStringDiscriminator }) - ) - - composedStringSerializer.descriptor.serialName -> decoder.json.decodeFromJsonElement( - composedStringSerializer, - JsonObject(jsonObject.filterKeys { it != prototypeStringDiscriminator }) - ) + when (val element = decoder.decodeJsonElement()) { + is JsonPrimitive if element.isString -> { + try { + PrototypeString.Inline(element.jsonPrimitive.content) + } catch (_: IllegalArgumentException) { + throw SerializationException("Unsupported json literal: $element") + } + } - composedUrlSerializer.descriptor.serialName -> decoder.json.decodeFromJsonElement( - composedUrlSerializer, - JsonObject(jsonObject.filterKeys { it != prototypeStringDiscriminator }) - ) + is JsonObject -> { + when (val type = element[prototypeStringDiscriminator]?.jsonPrimitive?.content) { + inlineSerializer.descriptor.serialName -> decoder.json.decodeFromJsonElement( + inlineSerializer, + JsonObject(element.filterKeys { it != prototypeStringDiscriminator }) + ) + + optionSerializer.descriptor.serialName -> decoder.json.decodeFromJsonElement( + optionSerializer, + JsonObject(element.filterKeys { it != prototypeStringDiscriminator }) + ) + + composedStringSerializer.descriptor.serialName -> decoder.json.decodeFromJsonElement( + composedStringSerializer, + JsonObject(element.filterKeys { it != prototypeStringDiscriminator }) + ) + + composedUrlSerializer.descriptor.serialName -> decoder.json.decodeFromJsonElement( + composedUrlSerializer, + JsonObject(element.filterKeys { it != prototypeStringDiscriminator }) + ) + + else -> { + PrototypeString.Inline( + decoder.json.decodeFromJsonElement( + RegistryAgentStringSerializer(), + element + ) + ) + } + } + } - else -> throw SerializationException("Unknown type: $type") + else -> { + throw SerializationException("Unsupported json element: $element") + } } } - // TOML deserialization should allow inline strings to represent as string literals and should also - // support PotentialStringReference deserialization is TomlDecoder -> { val tomlElement = decoder.decodeTomlElement() try { diff --git a/src/test/kotlin/org/coralprotocol/coralserver/registry/PrototypeStringSerializerTest.kt b/src/test/kotlin/org/coralprotocol/coralserver/registry/PrototypeStringSerializerTest.kt index 322ce49c..5896e392 100644 --- a/src/test/kotlin/org/coralprotocol/coralserver/registry/PrototypeStringSerializerTest.kt +++ b/src/test/kotlin/org/coralprotocol/coralserver/registry/PrototypeStringSerializerTest.kt @@ -1,15 +1,20 @@ package org.coralprotocol.coralserver.registry +import io.kotest.assertions.throwables.shouldThrow +import io.kotest.engine.spec.tempfile import io.kotest.matchers.equals.shouldBeEqual import io.kotest.matchers.nulls.shouldNotBeNull import io.kotest.matchers.types.shouldBeInstanceOf import io.ktor.server.application.* import io.ktor.server.response.* import io.ktor.server.routing.* +import kotlinx.serialization.SerializationException +import kotlinx.serialization.json.Json import org.coralprotocol.coralserver.CoralTest import org.coralprotocol.coralserver.agent.registry.AGENT_LLM_PROXY_NAME_LENGTH import org.coralprotocol.coralserver.agent.registry.MAXIMUM_SUPPORTED_AGENT_VERSION import org.coralprotocol.coralserver.agent.registry.UnresolvedRegistryAgent +import org.coralprotocol.coralserver.agent.registry.stringReferenceConstants import org.coralprotocol.coralserver.agent.runtime.prototype.PrototypeString import org.koin.test.inject import java.io.File @@ -68,4 +73,77 @@ class PrototypeStringSerializerTest : CoralTest({ agent.runtimes.prototypeRuntime.shouldNotBeNull() .proxyName.shouldBeInstanceOf().value.shouldBeEqual(uuid) } + + test("testJsonPrototypeStrings") { + val json by inject() + + // discriminated inline + var value = UUID.randomUUID().toString() + json.decodeFromString( + PrototypeString.serializer(), + """ + { + "type": "inline", + "value": "$value" + } + """.trimIndent() + ).shouldBeInstanceOf().value.shouldBeEqual(value) + + // discriminated option + value = UUID.randomUUID().toString() + json.decodeFromString( + PrototypeString.serializer(), + """ + { + "type": "option", + "name": "$value" + } + """.trimIndent() + ).shouldBeInstanceOf().name.shouldBeEqual(value) + + // string literal + value = UUID.randomUUID().toString() + json.decodeFromString(PrototypeString.serializer(), "\"$value\"") + .shouldBeInstanceOf().value.shouldBeEqual(value) + + // string constant reference + // discriminated inline + var (constantName, constantValue) = stringReferenceConstants.entries.first() + json.decodeFromString( + PrototypeString.serializer(), + """ + { + "type": "constant", + "name": "$constantName" + } + """.trimIndent() + ).shouldBeInstanceOf().value.shouldBeEqual(constantValue) + + // url reference (should not be allowed) + shouldThrow { + json.decodeFromString( + PrototypeString.serializer(), + """ + { + "type": "url", + "url": "https://google.se" + } + """.trimIndent() + ).shouldBeInstanceOf().value.shouldBeEqual(constantValue) + } + + // file reference (should not be allowed) + shouldThrow { + val file = tempfile("test.txt") + json.decodeFromString( + PrototypeString.serializer(), + """ + { + "type": "file", + "file": "$file" + } + """.trimIndent() + ).shouldBeInstanceOf().value.shouldBeEqual(constantValue) + } + } }) \ No newline at end of file