Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -2,28 +2,27 @@

package org.coralprotocol.coralserver.agent.registry

import dev.eav.tomlkt.*
import io.ktor.client.request.*
import io.ktor.client.statement.*
import io.ktor.util.*
import kotlinx.coroutines.runBlocking
import kotlinx.serialization.ExperimentalSerializationApi
import kotlinx.serialization.KSerializer
import kotlinx.serialization.SerialName
import kotlinx.serialization.Serializable
import kotlinx.serialization.*
import kotlinx.serialization.builtins.ListSerializer
import kotlinx.serialization.descriptors.PrimitiveKind
import kotlinx.serialization.descriptors.PrimitiveSerialDescriptor
import kotlinx.serialization.descriptors.SerialDescriptor
import kotlinx.serialization.encoding.Decoder
import kotlinx.serialization.encoding.Encoder
import kotlinx.serialization.json.JsonClassDiscriminator
import kotlinx.serialization.json.*
import org.coralprotocol.coralserver.agent.runtime.prototype.DEFAULT_LOOP_FOLLOWUP_PROMPT
import org.coralprotocol.coralserver.agent.runtime.prototype.DEFAULT_LOOP_INITIAL_BASE_PROMPT
import org.coralprotocol.coralserver.agent.runtime.prototype.DEFAULT_SYSTEM_PROMPT
import org.coralprotocol.coralserver.mcp.McpResourceName
import org.koin.core.component.KoinComponent
import java.io.File
import java.nio.charset.Charset
import kotlin.io.encoding.Base64
import kotlin.reflect.full.findAnnotation

/*
NOTE: This list is used in tests, resources/constants/coral-agent.toml must be updated to include any new constants
Expand All @@ -39,6 +38,7 @@ val stringReferenceConstants = buildMap {

@Serializable
@JsonClassDiscriminator("type")
@TomlClassDiscriminator("type")
sealed interface PotentialStringReference {
val base64: Boolean?

Expand Down Expand Up @@ -76,6 +76,24 @@ sealed interface PotentialStringReference {
open class RegistryAgentStringSerializer : KSerializer<String>, KoinComponent {
open val base64Default: Boolean = false

private val stringSerializer = PotentialStringReference.String.serializer()
private val fileSerializer = PotentialStringReference.File.serializer()
private val urlSerializer = PotentialStringReference.Url.serializer()
private val constantSerializer = PotentialStringReference.Constant.serializer()

private val potentialStringSerializerDiscriminator = run {
val tomlDiscriminator = PotentialStringReference::class
.findAnnotation<TomlClassDiscriminator>()?.discriminator
?: "type"

val jsonDiscriminator = PotentialStringReference::class
.findAnnotation<JsonClassDiscriminator>()?.discriminator
?: "type"

require(tomlDiscriminator == jsonDiscriminator)
tomlDiscriminator
}

override val descriptor: SerialDescriptor =
PrimitiveSerialDescriptor("String", PrimitiveKind.STRING)

Expand All @@ -84,47 +102,139 @@ open class RegistryAgentStringSerializer : KSerializer<String>, KoinComponent {
}

override fun deserialize(decoder: Decoder): String {
val context = registryAgentSerializationContext.get()
?: return decoder.decodeString()

return try {
val reference = decoder.decodeSerializableValue(PotentialStringReference.serializer())
val text = when (reference) {
is PotentialStringReference.File -> {
if (!context.enableFileReferences)
throw IllegalStateException("File references are not enabled")

val file = File(reference.path)
if (file.isAbsolute || context.agentFilePath == null) {
file.readText(Charset.forName(reference.encoding))
} else {
context.agentFilePath.toFile().resolve(file).readText(Charset.forName(reference.encoding))
val reference = when (decoder) {
is TomlDecoder -> {
when (val element = decoder.decodeTomlElement()) {
is TomlLiteral if element.type == TomlLiteral.Type.String -> {
PotentialStringReference.String(element.content)
}

is TomlTable -> {
val type = element[potentialStringSerializerDiscriminator]?.asTomlLiteral()?.content
?: throw SerializationException("Missing discriminator \"$potentialStringSerializerDiscriminator\" in string reference")

val element = TomlTable(element.filterKeys { it != potentialStringSerializerDiscriminator })
when (type) {
stringSerializer.descriptor.serialName -> decoder.toml.decodeFromTomlElement(
stringSerializer,
element
)

fileSerializer.descriptor.serialName -> decoder.toml.decodeFromTomlElement(
fileSerializer,
element
)

urlSerializer.descriptor.serialName -> decoder.toml.decodeFromTomlElement(
urlSerializer,
element
)

constantSerializer.descriptor.serialName -> decoder.toml.decodeFromTomlElement(
constantSerializer,
element
)

else -> {
throw SerializationException("Unknown string reference type: $type")
}
}

}

else -> {
throw SerializationException("Unsupported string type: ${element::class.simpleName}")
}
}
}

is JsonDecoder -> {
when (val element = decoder.decodeJsonElement()) {
is JsonPrimitive if element.isString -> {
PotentialStringReference.String(element.content)
}

is PotentialStringReference.String -> reference.value
is PotentialStringReference.Url -> {
if (!context.enableUrlReferences)
throw IllegalStateException("Url references are not enabled")
is JsonObject -> {
val type = element[potentialStringSerializerDiscriminator]?.jsonPrimitive?.content
?: throw SerializationException("Missing discriminator \"$potentialStringSerializerDiscriminator\" in string reference")

val element = JsonObject(element.filterKeys { it != potentialStringSerializerDiscriminator })
when (type) {
stringSerializer.descriptor.serialName -> decoder.json.decodeFromJsonElement(
stringSerializer,
element
)

fileSerializer.descriptor.serialName -> decoder.json.decodeFromJsonElement(
fileSerializer,
element
)

urlSerializer.descriptor.serialName -> decoder.json.decodeFromJsonElement(
urlSerializer,
element
)

constantSerializer.descriptor.serialName -> decoder.json.decodeFromJsonElement(
constantSerializer,
element
)

else -> {
throw SerializationException("Unknown string reference type: $type")
}
}
}

runBlocking {
context.httpClient.get(reference.url).bodyAsText(Charset.forName(reference.encoding))
else -> {
throw SerializationException("Unsupported string type: ${element::class.simpleName}")
}
}
}

else -> throw SerializationException("Unsupported decoder type: ${decoder::class.simpleName}")
}

is PotentialStringReference.Constant -> {
stringReferenceConstants[reference.name] ?: throw IllegalStateException("Constant ${reference.name} not found")
val text = when (reference) {
is PotentialStringReference.File -> {
val context = registryAgentSerializationContext.get()
?: throw SerializationException("File references require a serialization context")

if (!context.enableFileReferences)
throw SerializationException("File references are not enabled")

val file = File(reference.path)
if (file.isAbsolute || context.agentFilePath == null) {
file.readText(Charset.forName(reference.encoding))
} else {
context.agentFilePath.toFile().resolve(file).readText(Charset.forName(reference.encoding))
}
}

val base64 = reference.base64 ?: base64Default
if (base64) {
text.encodeBase64()
} else {
text
is PotentialStringReference.String -> reference.value
is PotentialStringReference.Url -> {
val context = registryAgentSerializationContext.get()
?: throw SerializationException("URL references require a serialization context")

if (!context.enableUrlReferences)
throw SerializationException("Url references are not enabled")

runBlocking {
context.httpClient.get(reference.url).bodyAsText(Charset.forName(reference.encoding))
}
}

is PotentialStringReference.Constant -> {
stringReferenceConstants[reference.name]
?: throw SerializationException("Constant ${reference.name} not found")
}
} catch (_: IllegalArgumentException) {
decoder.decodeString()
}

val base64 = reference.base64 ?: base64Default
return if (base64) {
Base64.encode(text.encodeToByteArray())
} else {
text
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,10 @@ sealed class PrototypeString {
for (part in parts) {
when (part) {
is PrototypeUrlPart.Path -> builder.appendPathSegments(part.value.resolve(agentOptions))
is PrototypeUrlPart.QueryParameter -> builder.parameters.append(part.name, part.value.resolve(agentOptions))
is PrototypeUrlPart.QueryParameter -> builder.parameters.append(
part.name,
part.value.resolve(agentOptions)
)
}
}

Expand Down Expand Up @@ -190,40 +193,60 @@ object PrototypeStringSerializer : KSerializer<PrototypeString> {
}
}

override fun deserialize(decoder: Decoder): PrototypeString {
override fun deserialize(decoder: Decoder
): PrototypeString {
// deserialization should allow inline strings to represent as string literals and should also
// support PotentialStringReference deserialization
return when (decoder) {

// json should only support plain deserialization of discriminated option/inline subtypes
is JsonDecoder -> {
val jsonObject = decoder.decodeJsonElement() as JsonObject

when (val type = jsonObject[prototypeStringDiscriminator]?.jsonPrimitive?.content) {
inlineSerializer.descriptor.serialName -> decoder.json.decodeFromJsonElement(
inlineSerializer,
JsonObject(jsonObject.filterKeys { it != prototypeStringDiscriminator })
)

optionSerializer.descriptor.serialName -> decoder.json.decodeFromJsonElement(
optionSerializer,
JsonObject(jsonObject.filterKeys { it != prototypeStringDiscriminator })
)

composedStringSerializer.descriptor.serialName -> decoder.json.decodeFromJsonElement(
composedStringSerializer,
JsonObject(jsonObject.filterKeys { it != prototypeStringDiscriminator })
)
when (val element = decoder.decodeJsonElement()) {
is JsonPrimitive if element.isString -> {
try {
PrototypeString.Inline(element.jsonPrimitive.content)
} catch (_: IllegalArgumentException) {
throw SerializationException("Unsupported json literal: $element")
}
}

composedUrlSerializer.descriptor.serialName -> decoder.json.decodeFromJsonElement(
composedUrlSerializer,
JsonObject(jsonObject.filterKeys { it != prototypeStringDiscriminator })
)
is JsonObject -> {
when (val type = element[prototypeStringDiscriminator]?.jsonPrimitive?.content) {
inlineSerializer.descriptor.serialName -> decoder.json.decodeFromJsonElement(
inlineSerializer,
JsonObject(element.filterKeys { it != prototypeStringDiscriminator })
)

optionSerializer.descriptor.serialName -> decoder.json.decodeFromJsonElement(
optionSerializer,
JsonObject(element.filterKeys { it != prototypeStringDiscriminator })
)

composedStringSerializer.descriptor.serialName -> decoder.json.decodeFromJsonElement(
composedStringSerializer,
JsonObject(element.filterKeys { it != prototypeStringDiscriminator })
)

composedUrlSerializer.descriptor.serialName -> decoder.json.decodeFromJsonElement(
composedUrlSerializer,
JsonObject(element.filterKeys { it != prototypeStringDiscriminator })
)

else -> {
PrototypeString.Inline(
decoder.json.decodeFromJsonElement(
RegistryAgentStringSerializer(),
element
)
)
}
}
}

else -> throw SerializationException("Unknown type: $type")
else -> {
throw SerializationException("Unsupported json element: $element")
}
}
}

// TOML deserialization should allow inline strings to represent as string literals and should also
// support PotentialStringReference deserialization
is TomlDecoder -> {
val tomlElement = decoder.decodeTomlElement()
try {
Expand Down
Loading
Loading