diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java b/mcp-core/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java index b078493ef..508870150 100644 --- a/mcp-core/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java +++ b/mcp-core/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java @@ -4,46 +4,24 @@ package io.modelcontextprotocol.server; -import java.time.Duration; -import java.util.Collections; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import java.util.Optional; -import java.util.Set; -import java.util.UUID; -import java.util.concurrent.ConcurrentHashMap; -import java.util.concurrent.CopyOnWriteArrayList; -import java.util.function.BiFunction; - import io.modelcontextprotocol.json.McpJsonMapper; import io.modelcontextprotocol.json.TypeRef; import io.modelcontextprotocol.json.schema.JsonSchemaValidator; -import io.modelcontextprotocol.spec.DefaultMcpStreamableServerSessionFactory; -import io.modelcontextprotocol.spec.McpClientSession; -import io.modelcontextprotocol.spec.McpError; -import io.modelcontextprotocol.spec.McpSchema; -import io.modelcontextprotocol.spec.McpSchema.CallToolResult; +import io.modelcontextprotocol.spec.*; +import io.modelcontextprotocol.spec.McpSchema.*; import io.modelcontextprotocol.spec.McpSchema.CompleteResult.CompleteCompletion; -import io.modelcontextprotocol.spec.McpSchema.ErrorCodes; -import io.modelcontextprotocol.spec.McpSchema.LoggingLevel; -import io.modelcontextprotocol.spec.McpSchema.PromptReference; -import io.modelcontextprotocol.spec.McpSchema.ResourceReference; -import io.modelcontextprotocol.spec.McpSchema.SetLevelRequest; -import io.modelcontextprotocol.spec.McpSchema.Tool; -import io.modelcontextprotocol.spec.McpServerSession; -import io.modelcontextprotocol.spec.McpServerTransportProvider; -import io.modelcontextprotocol.spec.McpServerTransportProviderBase; -import io.modelcontextprotocol.spec.McpStreamableServerTransportProvider; -import io.modelcontextprotocol.util.Assert; -import io.modelcontextprotocol.util.DefaultMcpUriTemplateManagerFactory; -import io.modelcontextprotocol.util.McpUriTemplateManagerFactory; -import io.modelcontextprotocol.util.Utils; +import io.modelcontextprotocol.util.*; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; +import java.time.Duration; +import java.util.*; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.CopyOnWriteArrayList; +import java.util.function.BiFunction; + import static io.modelcontextprotocol.spec.McpError.RESOURCE_NOT_FOUND; /** @@ -98,6 +76,8 @@ public class McpAsyncServer { private final JsonSchemaValidator jsonSchemaValidator; + private final boolean validateToolInputs; + private final McpSchema.ServerCapabilities serverCapabilities; private final McpSchema.Implementation serverInfo; @@ -129,7 +109,8 @@ public class McpAsyncServer { */ McpAsyncServer(McpServerTransportProvider mcpTransportProvider, McpJsonMapper jsonMapper, McpServerFeatures.Async features, Duration requestTimeout, - McpUriTemplateManagerFactory uriTemplateManagerFactory, JsonSchemaValidator jsonSchemaValidator) { + McpUriTemplateManagerFactory uriTemplateManagerFactory, JsonSchemaValidator jsonSchemaValidator, + boolean validateToolInputs) { this.mcpTransportProvider = mcpTransportProvider; this.jsonMapper = jsonMapper; this.serverInfo = features.serverInfo(); @@ -142,6 +123,7 @@ public class McpAsyncServer { this.completions.putAll(features.completions()); this.uriTemplateManagerFactory = uriTemplateManagerFactory; this.jsonSchemaValidator = jsonSchemaValidator; + this.validateToolInputs = validateToolInputs; Map> requestHandlers = prepareRequestHandlers(); Map notificationHandlers = prepareNotificationHandlers(features); @@ -157,7 +139,8 @@ public class McpAsyncServer { McpAsyncServer(McpStreamableServerTransportProvider mcpTransportProvider, McpJsonMapper jsonMapper, McpServerFeatures.Async features, Duration requestTimeout, - McpUriTemplateManagerFactory uriTemplateManagerFactory, JsonSchemaValidator jsonSchemaValidator) { + McpUriTemplateManagerFactory uriTemplateManagerFactory, JsonSchemaValidator jsonSchemaValidator, + boolean validateToolInputs) { this.mcpTransportProvider = mcpTransportProvider; this.jsonMapper = jsonMapper; this.serverInfo = features.serverInfo(); @@ -170,6 +153,7 @@ public class McpAsyncServer { this.completions.putAll(features.completions()); this.uriTemplateManagerFactory = uriTemplateManagerFactory; this.jsonSchemaValidator = jsonSchemaValidator; + this.validateToolInputs = validateToolInputs; Map> requestHandlers = prepareRequestHandlers(); Map notificationHandlers = prepareNotificationHandlers(features); @@ -543,6 +527,13 @@ private McpRequestHandler toolsCallRequestHandler() { .build()); } + McpSchema.Tool tool = toolSpecification.get().tool(); + CallToolResult validationError = ToolInputValidator.validate(tool, callToolRequest.arguments(), + this.validateToolInputs, this.jsonMapper, this.jsonSchemaValidator); + if (validationError != null) { + return Mono.just(validationError); + } + return toolSpecification.get().callHandler().apply(exchange, callToolRequest); }; } diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/server/McpServer.java b/mcp-core/src/main/java/io/modelcontextprotocol/server/McpServer.java index 360eb607d..04212e1de 100644 --- a/mcp-core/src/main/java/io/modelcontextprotocol/server/McpServer.java +++ b/mcp-core/src/main/java/io/modelcontextprotocol/server/McpServer.java @@ -25,6 +25,7 @@ import io.modelcontextprotocol.util.Assert; import io.modelcontextprotocol.util.DefaultMcpUriTemplateManagerFactory; import io.modelcontextprotocol.util.McpUriTemplateManagerFactory; +import io.modelcontextprotocol.util.ToolInputValidator; import io.modelcontextprotocol.util.ToolNameValidator; import reactor.core.publisher.Mono; @@ -243,7 +244,7 @@ public McpAsyncServer build() { : McpJsonDefaults.getSchemaValidator(); return new McpAsyncServer(transportProvider, jsonMapper == null ? McpJsonDefaults.getMapper() : jsonMapper, - features, requestTimeout, uriTemplateManagerFactory, jsonSchemaValidator); + features, requestTimeout, uriTemplateManagerFactory, jsonSchemaValidator, validateToolInputs); } } @@ -269,7 +270,7 @@ public McpAsyncServer build() { var jsonSchemaValidator = this.jsonSchemaValidator != null ? this.jsonSchemaValidator : McpJsonDefaults.getSchemaValidator(); return new McpAsyncServer(transportProvider, jsonMapper == null ? McpJsonDefaults.getMapper() : jsonMapper, - features, requestTimeout, uriTemplateManagerFactory, jsonSchemaValidator); + features, requestTimeout, uriTemplateManagerFactory, jsonSchemaValidator, validateToolInputs); } } @@ -293,6 +294,8 @@ abstract class AsyncSpecification> { boolean strictToolNameValidation = ToolNameValidator.isStrictByDefault(); + boolean validateToolInputs = ToolInputValidator.isEnabledByDefault(); + /** * The Model Context Protocol (MCP) allows servers to expose tools that can be * invoked by language models. Tools enable models to interact with external @@ -421,6 +424,18 @@ public AsyncSpecification strictToolNameValidation(boolean strict) { return this; } + /** + * Sets whether to validate tool inputs against the tool's input schema. When set, + * this takes priority over the system property + * {@code io.modelcontextprotocol.validateToolInputs}. + * @param validate true to validate inputs and return error on validation failure + * @return This builder instance for method chaining + */ + public AsyncSpecification validateToolInputs(boolean validate) { + this.validateToolInputs = validate; + return this; + } + /** * Sets the server capabilities that will be advertised to clients during * connection initialization. Capabilities define what features the server @@ -818,7 +833,8 @@ public McpSyncServer build() { var asyncServer = new McpAsyncServer(transportProvider, jsonMapper == null ? McpJsonDefaults.getMapper() : jsonMapper, asyncFeatures, requestTimeout, uriTemplateManagerFactory, - jsonSchemaValidator != null ? jsonSchemaValidator : McpJsonDefaults.getSchemaValidator()); + jsonSchemaValidator != null ? jsonSchemaValidator : McpJsonDefaults.getSchemaValidator(), + validateToolInputs); return new McpSyncServer(asyncServer, this.immediateExecution); } @@ -849,7 +865,7 @@ public McpSyncServer build() { : McpJsonDefaults.getSchemaValidator(); var asyncServer = new McpAsyncServer(transportProvider, jsonMapper == null ? McpJsonDefaults.getMapper() : jsonMapper, asyncFeatures, this.requestTimeout, - this.uriTemplateManagerFactory, jsonSchemaValidator); + this.uriTemplateManagerFactory, jsonSchemaValidator, validateToolInputs); return new McpSyncServer(asyncServer, this.immediateExecution); } @@ -872,6 +888,8 @@ abstract class SyncSpecification> { boolean strictToolNameValidation = ToolNameValidator.isStrictByDefault(); + boolean validateToolInputs = ToolInputValidator.isEnabledByDefault(); + /** * The Model Context Protocol (MCP) allows servers to expose tools that can be * invoked by language models. Tools enable models to interact with external @@ -1004,6 +1022,18 @@ public SyncSpecification strictToolNameValidation(boolean strict) { return this; } + /** + * Sets whether to validate tool inputs against the tool's input schema. When set, + * this takes priority over the system property + * {@code io.modelcontextprotocol.validateToolInputs}. + * @param validate true to validate inputs and return error on validation failure + * @return This builder instance for method chaining + */ + public SyncSpecification validateToolInputs(boolean validate) { + this.validateToolInputs = validate; + return this; + } + /** * Sets the server capabilities that will be advertised to clients during * connection initialization. Capabilities define what features the server @@ -1401,6 +1431,8 @@ class StatelessAsyncSpecification { boolean strictToolNameValidation = ToolNameValidator.isStrictByDefault(); + boolean validateToolInputs = ToolInputValidator.isEnabledByDefault(); + /** * The Model Context Protocol (MCP) allows servers to expose tools that can be * invoked by language models. Tools enable models to interact with external @@ -1530,6 +1562,18 @@ public StatelessAsyncSpecification strictToolNameValidation(boolean strict) { return this; } + /** + * Sets whether to validate tool inputs against the tool's input schema. When set, + * this takes priority over the system property + * {@code io.modelcontextprotocol.validateToolInputs}. + * @param validate true to validate inputs and return error on validation failure + * @return This builder instance for method chaining + */ + public StatelessAsyncSpecification validateToolInputs(boolean validate) { + this.validateToolInputs = validate; + return this; + } + /** * Sets the server capabilities that will be advertised to clients during * connection initialization. Capabilities define what features the server @@ -1859,7 +1903,8 @@ public McpStatelessAsyncServer build() { this.resources, this.resourceTemplates, this.prompts, this.completions, this.instructions); return new McpStatelessAsyncServer(transport, jsonMapper == null ? McpJsonDefaults.getMapper() : jsonMapper, features, requestTimeout, uriTemplateManagerFactory, - jsonSchemaValidator != null ? jsonSchemaValidator : McpJsonDefaults.getSchemaValidator()); + jsonSchemaValidator != null ? jsonSchemaValidator : McpJsonDefaults.getSchemaValidator(), + validateToolInputs); } } @@ -1884,6 +1929,8 @@ class StatelessSyncSpecification { boolean strictToolNameValidation = ToolNameValidator.isStrictByDefault(); + boolean validateToolInputs = ToolInputValidator.isEnabledByDefault(); + /** * The Model Context Protocol (MCP) allows servers to expose tools that can be * invoked by language models. Tools enable models to interact with external @@ -2013,6 +2060,18 @@ public StatelessSyncSpecification strictToolNameValidation(boolean strict) { return this; } + /** + * Sets whether to validate tool inputs against the tool's input schema. When set, + * this takes priority over the system property + * {@code io.modelcontextprotocol.validateToolInputs}. + * @param validate true to validate inputs and return error on validation failure + * @return This builder instance for method chaining + */ + public StatelessSyncSpecification validateToolInputs(boolean validate) { + this.validateToolInputs = validate; + return this; + } + /** * Sets the server capabilities that will be advertised to clients during * connection initialization. Capabilities define what features the server @@ -2360,7 +2419,8 @@ public McpStatelessSyncServer build() { var asyncServer = new McpStatelessAsyncServer(transport, jsonMapper == null ? McpJsonDefaults.getMapper() : jsonMapper, asyncFeatures, requestTimeout, uriTemplateManagerFactory, - this.jsonSchemaValidator != null ? this.jsonSchemaValidator : McpJsonDefaults.getSchemaValidator()); + this.jsonSchemaValidator != null ? this.jsonSchemaValidator : McpJsonDefaults.getSchemaValidator(), + validateToolInputs); return new McpStatelessSyncServer(asyncServer, this.immediateExecution); } diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/server/McpStatelessAsyncServer.java b/mcp-core/src/main/java/io/modelcontextprotocol/server/McpStatelessAsyncServer.java index c7a1fd0d7..250df7ab7 100644 --- a/mcp-core/src/main/java/io/modelcontextprotocol/server/McpStatelessAsyncServer.java +++ b/mcp-core/src/main/java/io/modelcontextprotocol/server/McpStatelessAsyncServer.java @@ -21,6 +21,7 @@ import io.modelcontextprotocol.util.Assert; import io.modelcontextprotocol.util.DefaultMcpUriTemplateManagerFactory; import io.modelcontextprotocol.util.McpUriTemplateManagerFactory; +import io.modelcontextprotocol.util.ToolInputValidator; import io.modelcontextprotocol.util.Utils; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -77,9 +78,12 @@ public class McpStatelessAsyncServer { private final JsonSchemaValidator jsonSchemaValidator; + private final boolean validateToolInputs; + McpStatelessAsyncServer(McpStatelessServerTransport mcpTransport, McpJsonMapper jsonMapper, McpStatelessServerFeatures.Async features, Duration requestTimeout, - McpUriTemplateManagerFactory uriTemplateManagerFactory, JsonSchemaValidator jsonSchemaValidator) { + McpUriTemplateManagerFactory uriTemplateManagerFactory, JsonSchemaValidator jsonSchemaValidator, + boolean validateToolInputs) { this.mcpTransportProvider = mcpTransport; this.jsonMapper = jsonMapper; this.serverInfo = features.serverInfo(); @@ -92,6 +96,7 @@ public class McpStatelessAsyncServer { this.completions.putAll(features.completions()); this.uriTemplateManagerFactory = uriTemplateManagerFactory; this.jsonSchemaValidator = jsonSchemaValidator; + this.validateToolInputs = validateToolInputs; Map> requestHandlers = new HashMap<>(); @@ -409,6 +414,13 @@ private McpStatelessRequestHandler toolsCallRequestHandler() { .build()); } + McpSchema.Tool tool = toolSpecification.get().tool(); + CallToolResult validationError = ToolInputValidator.validate(tool, callToolRequest.arguments(), + this.validateToolInputs, this.jsonMapper, this.jsonSchemaValidator); + if (validationError != null) { + return Mono.just(validationError); + } + return toolSpecification.get().callHandler().apply(ctx, callToolRequest); }; } diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/util/ToolInputValidator.java b/mcp-core/src/main/java/io/modelcontextprotocol/util/ToolInputValidator.java new file mode 100644 index 000000000..141705369 --- /dev/null +++ b/mcp-core/src/main/java/io/modelcontextprotocol/util/ToolInputValidator.java @@ -0,0 +1,74 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package io.modelcontextprotocol.util; + +import java.util.List; +import java.util.Map; + +import io.modelcontextprotocol.json.McpJsonMapper; +import io.modelcontextprotocol.json.TypeRef; +import io.modelcontextprotocol.json.schema.JsonSchemaValidator; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpSchema.CallToolResult; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * Validates tool input arguments against JSON schema. + * + * @author Andrei Shakirin + */ +public final class ToolInputValidator { + + private static final Logger logger = LoggerFactory.getLogger(ToolInputValidator.class); + + /** + * System property to disable tool input validation. Set to "false" to disable. + * Default is true (validation enabled). + */ + public static final String VALIDATE_TOOL_INPUTS_PROPERTY = "io.modelcontextprotocol.validateToolInputs"; + + private ToolInputValidator() { + } + + /** + * Returns whether validation is enabled by default based on system property. + * @return true if validation is enabled (default), false if disabled + */ + public static boolean isEnabledByDefault() { + return !"false".equalsIgnoreCase(System.getProperty(VALIDATE_TOOL_INPUTS_PROPERTY)); + } + + /** + * Validates tool arguments against the tool's input schema. + * @param tool the tool definition containing the input schema + * @param arguments the arguments to validate + * @param validateToolInputs whether validation is enabled + * @param jsonMapper the JSON mapper for schema conversion + * @param validator the JSON schema validator (may be null) + * @return CallToolResult with isError=true if validation fails, null if valid or + * validation skipped + */ + public static CallToolResult validate(McpSchema.Tool tool, Map arguments, + boolean validateToolInputs, McpJsonMapper jsonMapper, JsonSchemaValidator validator) { + if (!validateToolInputs || tool.inputSchema() == null || validator == null) { + return null; + } + Map inputSchema = jsonMapper.convertValue(tool.inputSchema(), + new TypeRef>() { + }); + Map args = arguments != null ? arguments : Map.of(); + var validation = validator.validate(inputSchema, args); + if (!validation.valid()) { + logger.warn("Tool '{}' input validation failed: {}", tool.name(), validation.errorMessage()); + return CallToolResult.builder() + .content(List.of(new McpSchema.TextContent(validation.errorMessage()))) + .isError(true) + .build(); + } + return null; + } + +} diff --git a/mcp-core/src/test/java/io/modelcontextprotocol/util/ToolInputValidatorTests.java b/mcp-core/src/test/java/io/modelcontextprotocol/util/ToolInputValidatorTests.java new file mode 100644 index 000000000..87db708f9 --- /dev/null +++ b/mcp-core/src/test/java/io/modelcontextprotocol/util/ToolInputValidatorTests.java @@ -0,0 +1,110 @@ +/* + * Copyright 2026-2026 the original author or authors. + */ + +package io.modelcontextprotocol.util; + +import java.util.List; +import java.util.Map; + +import io.modelcontextprotocol.json.McpJsonMapper; +import io.modelcontextprotocol.json.TypeRef; +import io.modelcontextprotocol.json.schema.JsonSchemaValidator; +import io.modelcontextprotocol.json.schema.JsonSchemaValidator.ValidationResponse; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpSchema.CallToolResult; +import io.modelcontextprotocol.spec.McpSchema.TextContent; +import io.modelcontextprotocol.spec.McpSchema.Tool; +import org.junit.jupiter.api.Test; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +/** + * Tests for {@link ToolInputValidator}. + */ +class ToolInputValidatorTests { + + private final McpJsonMapper jsonMapper = mock(McpJsonMapper.class); + + private final JsonSchemaValidator validator = mock(JsonSchemaValidator.class); + + private final McpSchema.JsonSchema inputSchema = new McpSchema.JsonSchema("object", + Map.of("name", Map.of("type", "string")), List.of("name"), null, null, null); + + private final Tool toolWithSchema = Tool.builder() + .name("test-tool") + .description("Test tool") + .inputSchema(inputSchema) + .build(); + + private final Tool toolWithoutSchema = Tool.builder().name("test-tool").description("Test tool").build(); + + @Test + void validate_whenDisabled_returnsNull() { + CallToolResult result = ToolInputValidator.validate(toolWithSchema, Map.of("name", "test"), false, jsonMapper, + validator); + + assertThat(result).isNull(); + verify(validator, never()).validate(any(), any()); + } + + @Test + void validate_whenNoSchema_returnsNull() { + CallToolResult result = ToolInputValidator.validate(toolWithoutSchema, Map.of("name", "test"), true, jsonMapper, + validator); + + assertThat(result).isNull(); + verify(validator, never()).validate(any(), any()); + } + + @Test + void validate_whenNoValidator_returnsNull() { + CallToolResult result = ToolInputValidator.validate(toolWithSchema, Map.of("name", "test"), true, jsonMapper, + null); + + assertThat(result).isNull(); + } + + @Test + @SuppressWarnings("unchecked") + void validate_withValidInput_returnsNull() { + when(jsonMapper.convertValue(any(), any(TypeRef.class))).thenReturn(Map.of("type", "object")); + when(validator.validate(any(), any())).thenReturn(ValidationResponse.asValid(null)); + + CallToolResult result = ToolInputValidator.validate(toolWithSchema, Map.of("name", "test"), true, jsonMapper, + validator); + + assertThat(result).isNull(); + } + + @Test + @SuppressWarnings("unchecked") + void validate_withInvalidInput_returnsErrorResult() { + when(jsonMapper.convertValue(any(), any(TypeRef.class))).thenReturn(Map.of("type", "object")); + when(validator.validate(any(), any())).thenReturn(ValidationResponse.asInvalid("missing required: 'name'")); + + CallToolResult result = ToolInputValidator.validate(toolWithSchema, Map.of(), true, jsonMapper, validator); + + assertThat(result).isNotNull(); + assertThat(result.isError()).isTrue(); + assertThat(((TextContent) result.content().get(0)).text()).contains("missing required: 'name'"); + } + + @Test + @SuppressWarnings("unchecked") + void validate_withNullArguments_usesEmptyMap() { + when(jsonMapper.convertValue(any(), any(TypeRef.class))).thenReturn(Map.of("type", "object")); + when(validator.validate(any(), any())).thenReturn(ValidationResponse.asValid(null)); + + CallToolResult result = ToolInputValidator.validate(toolWithSchema, null, true, jsonMapper, validator); + + assertThat(result).isNull(); + verify(validator).validate(any(), any()); + } + +} diff --git a/mcp-test/src/test/java/io/modelcontextprotocol/server/ToolInputValidationIntegrationTests.java b/mcp-test/src/test/java/io/modelcontextprotocol/server/ToolInputValidationIntegrationTests.java new file mode 100644 index 000000000..bf349405e --- /dev/null +++ b/mcp-test/src/test/java/io/modelcontextprotocol/server/ToolInputValidationIntegrationTests.java @@ -0,0 +1,270 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package io.modelcontextprotocol.server; + +import java.time.Duration; +import java.util.List; +import java.util.Map; +import java.util.stream.Stream; + +import io.modelcontextprotocol.AbstractMcpClientServerIntegrationTests; +import io.modelcontextprotocol.client.McpClient; +import io.modelcontextprotocol.client.transport.HttpClientStreamableHttpTransport; +import io.modelcontextprotocol.common.McpTransportContext; +import io.modelcontextprotocol.server.transport.HttpServletStreamableServerTransportProvider; +import io.modelcontextprotocol.server.transport.TomcatTestUtil; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpSchema.CallToolRequest; +import io.modelcontextprotocol.spec.McpSchema.CallToolResult; +import io.modelcontextprotocol.spec.McpSchema.TextContent; +import io.modelcontextprotocol.spec.McpSchema.Tool; +import jakarta.servlet.http.HttpServletRequest; +import org.apache.catalina.LifecycleException; +import org.apache.catalina.LifecycleState; +import org.apache.catalina.startup.Tomcat; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Timeout; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; +import reactor.core.publisher.Mono; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * Tests for tool input validation against JSON schema. Validates that input validation + * errors are returned as Tool Execution Errors (isError=true) rather than Protocol + * Errors, per MCP specification. + * + * @author Alireza Khoram + */ +@Timeout(15) +class ToolInputValidationIntegrationTests extends AbstractMcpClientServerIntegrationTests { + + private static final int PORT = TomcatTestUtil.findAvailablePort(); + + private static final String MESSAGE_ENDPOINT = "/mcp/message"; + + private static final String TOOL_NAME = "test-tool"; + + private static final McpSchema.JsonSchema INPUT_SCHEMA = new McpSchema.JsonSchema("object", + Map.of("name", Map.of("type", "string"), "age", Map.of("type", "integer", "minimum", 0)), + List.of("name", "age"), null, null, null); + + private static final McpTransportContextExtractor TEST_CONTEXT_EXTRACTOR = ( + r) -> McpTransportContext.create(Map.of("important", "value")); + + private HttpServletStreamableServerTransportProvider mcpServerTransportProvider; + + private Tomcat tomcat; + + static Stream clientsForTesting() { + return Stream.of(Arguments.of("httpclient")); + } + + static Stream validInputTestCases() { + return Stream.of( + // serverType, validationEnabled, inputArgs, expectedOutput + Arguments.of("sync", true, Map.of("name", "Alice", "age", 30), "Hello Alice, age 30"), + Arguments.of("async", true, Map.of("name", "Bob", "age", 25), "Hello Bob, age 25"), + Arguments.of("sync", false, Map.of("name", "Alice", "age", 30), "Hello Alice, age 30"), + Arguments.of("async", false, Map.of("name", "Bob", "age", 25), "Hello Bob, age 25")); + } + + static Stream invalidInputTestCases() { + return Stream.of( + // serverType, inputArgs, expectedErrorSubstring + Arguments.of("sync", Map.of("name", "Alice"), "age"), // missing required + Arguments.of("async", Map.of("name", "Bob", "age", -10), "minimum")); // invalid + // value + } + + @BeforeEach + public void before() { + mcpServerTransportProvider = HttpServletStreamableServerTransportProvider.builder() + .mcpEndpoint(MESSAGE_ENDPOINT) + .contextExtractor(TEST_CONTEXT_EXTRACTOR) + .build(); + + tomcat = TomcatTestUtil.createTomcatServer("", PORT, mcpServerTransportProvider); + try { + tomcat.start(); + assertThat(tomcat.getServer().getState()).isEqualTo(LifecycleState.STARTED); + } + catch (Exception e) { + throw new RuntimeException("Failed to start Tomcat", e); + } + + clientBuilders + .put("httpclient", + McpClient.sync(HttpClientStreamableHttpTransport.builder("http://localhost:" + PORT) + .endpoint(MESSAGE_ENDPOINT) + .build()).requestTimeout(Duration.ofSeconds(10))); + } + + @Override + protected McpServer.AsyncSpecification prepareAsyncServerBuilder() { + return McpServer.async(this.mcpServerTransportProvider); + } + + @Override + protected McpServer.SyncSpecification prepareSyncServerBuilder() { + return McpServer.sync(this.mcpServerTransportProvider); + } + + @AfterEach + public void after() { + if (mcpServerTransportProvider != null) { + mcpServerTransportProvider.closeGracefully().block(); + } + if (tomcat != null) { + try { + tomcat.stop(); + tomcat.destroy(); + } + catch (LifecycleException e) { + throw new RuntimeException("Failed to stop Tomcat", e); + } + } + } + + @Override + protected void prepareClients(int port, String mcpEndpoint) { + } + + private McpServerFeatures.SyncToolSpecification createSyncTool() { + Tool tool = Tool.builder() + .name(TOOL_NAME) + .description("Test tool with schema") + .inputSchema(INPUT_SCHEMA) + .build(); + + return McpServerFeatures.SyncToolSpecification.builder().tool(tool).callHandler((exchange, request) -> { + String name = (String) request.arguments().get("name"); + Integer age = ((Number) request.arguments().get("age")).intValue(); + return CallToolResult.builder() + .content(List.of(new TextContent("Hello " + name + ", age " + age))) + .isError(false) + .build(); + }).build(); + } + + private McpServerFeatures.AsyncToolSpecification createAsyncTool() { + Tool tool = Tool.builder() + .name(TOOL_NAME) + .description("Test tool with schema") + .inputSchema(INPUT_SCHEMA) + .build(); + + return McpServerFeatures.AsyncToolSpecification.builder().tool(tool).callHandler((exchange, request) -> { + String name = (String) request.arguments().get("name"); + Integer age = ((Number) request.arguments().get("age")).intValue(); + return Mono.just(CallToolResult.builder() + .content(List.of(new TextContent("Hello " + name + ", age " + age))) + .isError(false) + .build()); + }).build(); + } + + @ParameterizedTest(name = "{0} server, validation={1}") + @MethodSource("validInputTestCases") + void validInput_shouldSucceed(String serverType, boolean validationEnabled, Map input, + String expectedOutput) { + var clientBuilder = clientBuilders.get("httpclient"); + Object server = createServer(serverType, validationEnabled); + + try (var client = clientBuilder.clientInfo(new McpSchema.Implementation("test-client", "1.0.0")).build()) { + client.initialize(); + CallToolResult result = client.callTool(new CallToolRequest(TOOL_NAME, input)); + + assertThat(result.isError()).isFalse(); + assertThat(((TextContent) result.content().get(0)).text()).isEqualTo(expectedOutput); + } + finally { + closeServer(server, serverType); + } + } + + @ParameterizedTest(name = "{0} server, input={1}") + @MethodSource("invalidInputTestCases") + void invalidInput_withDefaultValidation_shouldReturnToolError(String serverType, Map input, + String expectedErrorSubstring) { + var clientBuilder = clientBuilders.get("httpclient"); + Object server = createServerWithDefaultValidation(serverType); + + try (var client = clientBuilder.clientInfo(new McpSchema.Implementation("test-client", "1.0.0")).build()) { + client.initialize(); + CallToolResult result = client.callTool(new CallToolRequest(TOOL_NAME, input)); + + assertThat(result.isError()).isTrue(); + String errorMessage = ((TextContent) result.content().get(0)).text(); + assertThat(errorMessage).containsIgnoringCase(expectedErrorSubstring); + } + finally { + closeServer(server, serverType); + } + } + + @ParameterizedTest(name = "{0} server, input={1}") + @MethodSource("invalidInputTestCases") + void invalidInput_withValidationDisabled_shouldSucceed(String serverType, Map input, + String ignored) { + var clientBuilder = clientBuilders.get("httpclient"); + Object server = createServer(serverType, false); + + try (var client = clientBuilder.clientInfo(new McpSchema.Implementation("test-client", "1.0.0")).build()) { + client.initialize(); + // Invalid input should pass through when validation is disabled + // The tool handler will fail, but that's expected - we're testing validation + // is skipped + try { + client.callTool(new CallToolRequest(TOOL_NAME, input)); + } + catch (Exception e) { + // Expected - tool handler fails on invalid input, but validation didn't + // block it + assertThat(e.getMessage()).doesNotContainIgnoringCase("validation"); + } + } + finally { + closeServer(server, serverType); + } + } + + private Object createServerWithDefaultValidation(String serverType) { + if ("sync".equals(serverType)) { + return prepareSyncServerBuilder().serverInfo("test-server", "1.0.0").tools(createSyncTool()).build(); + } + else { + return prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0").tools(createAsyncTool()).build(); + } + } + + private Object createServer(String serverType, boolean validationEnabled) { + if ("sync".equals(serverType)) { + return prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") + .validateToolInputs(validationEnabled) + .tools(createSyncTool()) + .build(); + } + else { + return prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") + .validateToolInputs(validationEnabled) + .tools(createAsyncTool()) + .build(); + } + } + + private void closeServer(Object server, String serverType) { + if ("async".equals(serverType)) { + ((McpAsyncServer) server).closeGracefully().block(); + } + else { + ((McpSyncServer) server).close(); + } + } + +}