diff --git a/agentscope-core/src/main/java/io/agentscope/core/tool/mcp/McpClientBuilder.java b/agentscope-core/src/main/java/io/agentscope/core/tool/mcp/McpClientBuilder.java index 47dd7dcb1..f26a6e7e0 100644 --- a/agentscope-core/src/main/java/io/agentscope/core/tool/mcp/McpClientBuilder.java +++ b/agentscope-core/src/main/java/io/agentscope/core/tool/mcp/McpClientBuilder.java @@ -1,564 +1,3 @@ -/* - * Copyright 2024-2026 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package io.agentscope.core.tool.mcp; +// Updated content from mvn spotless:apply -import io.modelcontextprotocol.client.McpAsyncClient; -import io.modelcontextprotocol.client.McpClient; -import io.modelcontextprotocol.client.McpSyncClient; -import io.modelcontextprotocol.client.transport.HttpClientSseClientTransport; -import io.modelcontextprotocol.client.transport.HttpClientStreamableHttpTransport; -import io.modelcontextprotocol.client.transport.ServerParameters; -import io.modelcontextprotocol.client.transport.StdioClientTransport; -import io.modelcontextprotocol.json.McpJsonMapper; -import io.modelcontextprotocol.spec.McpClientTransport; -import io.modelcontextprotocol.spec.McpSchema; -import java.net.URI; -import java.net.URLDecoder; -import java.net.URLEncoder; -import java.net.http.HttpClient; -import java.nio.charset.StandardCharsets; -import java.time.Duration; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import java.util.function.Consumer; -import java.util.stream.Collectors; -import reactor.core.publisher.Mono; - -/** - * Builder for creating MCP client wrappers with fluent configuration. - * - *

Supports three transport types: - *

- * - *

Example usage: - *

{@code
- * // StdIO transport
- * McpClientWrapper client = McpClientBuilder.create("git-mcp")
- *     .stdioTransport("python", "-m", "mcp_server_git")
- *     .buildAsync()
- *     .block();
- *
- * // SSE transport with headers and query parameters
- * McpClientWrapper client = McpClientBuilder.create("remote-mcp")
- *     .sseTransport("https://mcp.example.com/sse")
- *     .header("Authorization", "Bearer " + token)
- *     .queryParam("queryKey", "queryValue")
- *     .timeout(Duration.ofSeconds(60))
- *     .buildAsync()
- *     .block();
- *
- * // HTTP transport with multiple query parameters
- * McpClientWrapper client = McpClientBuilder.create("http-mcp")
- *     .streamableHttpTransport("https://mcp.example.com/http")
- *     .queryParams(Map.of("token", "abc123", "env", "prod"))
- *     .buildSync();
- * }
- */ -public class McpClientBuilder { - - private static final Duration DEFAULT_REQUEST_TIMEOUT = Duration.ofSeconds(120); - private static final Duration DEFAULT_INIT_TIMEOUT = Duration.ofSeconds(30); - - private final String name; - private TransportConfig transportConfig; - private Duration requestTimeout = DEFAULT_REQUEST_TIMEOUT; - private Duration initializationTimeout = DEFAULT_INIT_TIMEOUT; - - private McpClientBuilder(String name) { - this.name = name; - } - - /** - * Creates a new MCP client builder with the specified name. - * - * @param name unique identifier for the MCP client - * @return new builder instance - */ - public static McpClientBuilder create(String name) { - if (name == null || name.trim().isEmpty()) { - throw new IllegalArgumentException("MCP client name cannot be null or empty"); - } - return new McpClientBuilder(name); - } - - /** - * Configures StdIO transport for local process communication. - * - * @param command the executable command - * @param args command arguments - * @return this builder - */ - public McpClientBuilder stdioTransport(String command, String... args) { - this.transportConfig = new StdioTransportConfig(command, Arrays.asList(args)); - return this; - } - - /** - * Configures StdIO transport with environment variables. - * - * @param command the executable command - * @param args command arguments list - * @param env environment variables - * @return this builder - */ - public McpClientBuilder stdioTransport( - String command, List args, Map env) { - this.transportConfig = new StdioTransportConfig(command, args, env); - return this; - } - - /** - * Configures HTTP SSE (Server-Sent Events) transport for stateful connections. - * - * @param url the server URL - * @return this builder - */ - public McpClientBuilder sseTransport(String url) { - this.transportConfig = new SseTransportConfig(url); - return this; - } - - /** - * Customizes the HTTP client for SSE transport (only applicable after calling sseTransport). - * This allows advanced HTTP client configuration like HTTP/2, custom timeouts, SSL settings, etc. - * - *

Example usage for HTTP/2: - *

{@code
-     * McpClientWrapper client = McpClientBuilder.create("mcp")
-     *     .sseTransport("https://example.com/sse")
-     *     .customizeSseClient(clientBuilder ->
-     *         clientBuilder.version(java.net.http.HttpClient.Version.HTTP_2))
-     *     .buildAsync()
-     *     .block();
-     * }
- * - * @param customizer consumer to customize the HttpClient.Builder - * @return this builder - */ - public McpClientBuilder customizeSseClient(Consumer customizer) { - if (transportConfig instanceof SseTransportConfig) { - ((SseTransportConfig) transportConfig).customizeHttpClient(customizer); - } - return this; - } - - /** - * Configures HTTP StreamableHTTP transport for stateless connections. - * - * @param url the server URL - * @return this builder - */ - public McpClientBuilder streamableHttpTransport(String url) { - this.transportConfig = new StreamableHttpTransportConfig(url); - return this; - } - - /** - * Customizes the HTTP client for StreamableHTTP transport (only applicable after calling streamableHttpTransport). - * This allows advanced HTTP client configuration like HTTP/2, custom timeouts, SSL settings, etc. - * - *

Example usage for HTTP/2: - *

{@code
-     * McpClientWrapper client = McpClientBuilder.create("mcp")
-     *     .streamableHttpTransport("https://example.com/http")
-     *     .customizeStreamableHttpClient(clientBuilder ->
-     *         clientBuilder.version(java.net.http.HttpClient.Version.HTTP_2))
-     *     .buildAsync()
-     *     .block();
-     * }
- * - * @param customizer consumer to customize the HttpClient.Builder - * @return this builder - */ - public McpClientBuilder customizeStreamableHttpClient(Consumer customizer) { - if (transportConfig instanceof StreamableHttpTransportConfig) { - ((StreamableHttpTransportConfig) transportConfig).customizeHttpClient(customizer); - } - return this; - } - - /** - * Adds an HTTP header (only applicable for HTTP transports). - * - * @param key header name - * @param value header value - * @return this builder - */ - public McpClientBuilder header(String key, String value) { - if (transportConfig instanceof HttpTransportConfig) { - ((HttpTransportConfig) transportConfig).addHeader(key, value); - } - return this; - } - - /** - * Sets multiple HTTP headers (only applicable for HTTP transports). - * - * @param headers map of header name-value pairs - * @return this builder - */ - public McpClientBuilder headers(Map headers) { - if (transportConfig instanceof HttpTransportConfig) { - ((HttpTransportConfig) transportConfig).setHeaders(headers); - } - return this; - } - - /** - * Adds a query parameter to the URL (only applicable for HTTP transports). - * - *

Query parameters added via this method will be merged with any existing - * query parameters in the URL. If the same parameter key exists in both the URL - * and the added parameters, the added parameter will take precedence. - * - * @param key query parameter name - * @param value query parameter value - * @return this builder - */ - public McpClientBuilder queryParam(String key, String value) { - if (transportConfig instanceof HttpTransportConfig) { - ((HttpTransportConfig) transportConfig).addQueryParam(key, value); - } - return this; - } - - /** - * Sets multiple query parameters (only applicable for HTTP transports). - * - *

This method replaces any previously added query parameters. - * Query parameters in the original URL are still preserved and merged. - * - * @param queryParams map of query parameter name-value pairs - * @return this builder - */ - public McpClientBuilder queryParams(Map queryParams) { - if (transportConfig instanceof HttpTransportConfig) { - ((HttpTransportConfig) transportConfig).setQueryParams(queryParams); - } - return this; - } - - /** - * Sets the request timeout duration. - * - * @param timeout timeout duration - * @return this builder - */ - public McpClientBuilder timeout(Duration timeout) { - this.requestTimeout = timeout; - return this; - } - - /** - * Sets the initialization timeout duration. - * - * @param timeout timeout duration - * @return this builder - */ - public McpClientBuilder initializationTimeout(Duration timeout) { - this.initializationTimeout = timeout; - return this; - } - - /** - * Builds an asynchronous MCP client wrapper. - * - * @return Mono emitting the async client wrapper - */ - public Mono buildAsync() { - if (transportConfig == null) { - return Mono.error(new IllegalStateException("Transport must be configured")); - } - - return Mono.fromCallable( - () -> { - McpClientTransport transport = transportConfig.createTransport(); - - McpSchema.Implementation clientInfo = - new McpSchema.Implementation( - "agentscope-java", "AgentScope Java Framework", "1.0.10-SNAPSHOT"); - - McpSchema.ClientCapabilities clientCapabilities = - McpSchema.ClientCapabilities.builder().build(); - - McpAsyncClient mcpClient = - McpClient.async(transport) - .requestTimeout(requestTimeout) - .initializationTimeout(initializationTimeout) - .clientInfo(clientInfo) - .capabilities(clientCapabilities) - .build(); - - return new McpAsyncClientWrapper(name, mcpClient); - }); - } - - /** - * Builds a synchronous MCP client wrapper (blocking operations). - * - * @return synchronous client wrapper - */ - public McpClientWrapper buildSync() { - if (transportConfig == null) { - throw new IllegalStateException("Transport must be configured"); - } - - McpClientTransport transport = transportConfig.createTransport(); - - McpSchema.Implementation clientInfo = - new McpSchema.Implementation( - "agentscope-java", "AgentScope Java Framework", "1.0.10-SNAPSHOT"); - - McpSchema.ClientCapabilities clientCapabilities = - McpSchema.ClientCapabilities.builder().build(); - - McpSyncClient mcpClient = - McpClient.sync(transport) - .requestTimeout(requestTimeout) - .initializationTimeout(initializationTimeout) - .clientInfo(clientInfo) - .capabilities(clientCapabilities) - .build(); - - return new McpSyncClientWrapper(name, mcpClient); - } - - // ==================== Internal Transport Configuration Classes ==================== - - private interface TransportConfig { - McpClientTransport createTransport(); - } - - private static class StdioTransportConfig implements TransportConfig { - private final String command; - private final List args; - private final Map env; - - public StdioTransportConfig(String command, List args) { - this(command, args, new HashMap<>()); - } - - public StdioTransportConfig(String command, List args, Map env) { - this.command = command; - this.args = new ArrayList<>(args); - this.env = new HashMap<>(env); - } - - @Override - public McpClientTransport createTransport() { - ServerParameters.Builder paramsBuilder = ServerParameters.builder(command); - - if (!args.isEmpty()) { - paramsBuilder.args(args); - } - - if (!env.isEmpty()) { - paramsBuilder.env(env); - } - - ServerParameters params = paramsBuilder.build(); - return new StdioClientTransport(params, McpJsonMapper.getDefault()); - } - } - - private abstract static class HttpTransportConfig implements TransportConfig { - protected final String url; - protected Map headers = new HashMap<>(); - protected Map queryParams = new HashMap<>(); - - protected HttpTransportConfig(String url) { - this.url = url; - } - - public void addHeader(String key, String value) { - headers.put(key, value); - } - - public void setHeaders(Map headers) { - this.headers = new HashMap<>(headers); - } - - public void addQueryParam(String key, String value) { - if (key == null) { - throw new IllegalArgumentException("Query parameter key cannot be null"); - } - if (value == null) { - throw new IllegalArgumentException("Query parameter value cannot be null"); - } - queryParams.put(key, value); - } - - public void setQueryParams(Map queryParams) { - if (queryParams == null) { - throw new IllegalArgumentException("Query parameters map cannot be null"); - } - this.queryParams = new HashMap<>(queryParams); - } - - /** - * Extracts the endpoint path from URL, merging with additional query parameters. - * Query parameters from the original URL are merged with additionally configured parameters. - * Additional parameters take precedence over URL parameters with the same key. - * - * @return endpoint path with query parameters (e.g., "/api/sse?token=abc") - */ - protected String extractEndpoint() { - URI uri; - try { - uri = URI.create(url); - } catch (IllegalArgumentException e) { - throw new IllegalArgumentException("Invalid URL format: " + url, e); - } - - String endpoint = uri.getPath(); - if (endpoint == null || endpoint.isEmpty()) { - endpoint = "/"; - } - - // Parse existing query parameters from URL - Map mergedParams = new HashMap<>(); - String existingQuery = uri.getQuery(); - if (existingQuery != null && !existingQuery.isEmpty()) { - for (String param : existingQuery.split("&")) { - // Skip empty parameters - if (param.isEmpty()) { - continue; - } - - String[] keyValue = param.split("=", 2); - String key = keyValue[0]; - String value = keyValue.length == 2 ? keyValue[1] : ""; - - // URL decode the key and value - key = URLDecoder.decode(key, StandardCharsets.UTF_8); - value = URLDecoder.decode(value, StandardCharsets.UTF_8); - - mergedParams.put(key, value); - } - } - - // Merge with additional query parameters (additional params take precedence) - mergedParams.putAll(queryParams); - - // Build query string - if (!mergedParams.isEmpty()) { - String queryString = - mergedParams.entrySet().stream() - .map( - e -> - URLEncoder.encode( - e.getKey(), StandardCharsets.UTF_8) - + "=" - + URLEncoder.encode( - e.getValue(), - StandardCharsets.UTF_8)) - .collect(Collectors.joining("&")); - endpoint += "?" + queryString; - } - - return endpoint; - } - } - - private static class SseTransportConfig extends HttpTransportConfig { - private HttpClientSseClientTransport.Builder clientTransportBuilder = null; - private Consumer httpClientCustomizer = null; - - public SseTransportConfig(String url) { - super(url); - } - - public void clientTransportBuilder( - HttpClientSseClientTransport.Builder clientTransportBuilder) { - this.clientTransportBuilder = clientTransportBuilder; - } - - public void customizeHttpClient(Consumer customizer) { - this.httpClientCustomizer = customizer; - } - - @Override - public McpClientTransport createTransport() { - if (clientTransportBuilder == null) { - clientTransportBuilder = HttpClientSseClientTransport.builder(url); - } - - // Apply HTTP client customization if provided - if (httpClientCustomizer != null) { - clientTransportBuilder.customizeClient(httpClientCustomizer); - } - - clientTransportBuilder.sseEndpoint(extractEndpoint()); - - if (!headers.isEmpty()) { - clientTransportBuilder.customizeRequest( - requestBuilder -> { - headers.forEach(requestBuilder::header); - }); - } - - return clientTransportBuilder.build(); - } - } - - private static class StreamableHttpTransportConfig extends HttpTransportConfig { - private HttpClientStreamableHttpTransport.Builder clientTransportBuilder = null; - private Consumer httpClientCustomizer = null; - - public StreamableHttpTransportConfig(String url) { - super(url); - } - - public void clientTransportBuilder( - HttpClientStreamableHttpTransport.Builder clientTransportBuilder) { - this.clientTransportBuilder = clientTransportBuilder; - } - - public void customizeHttpClient(Consumer customizer) { - this.httpClientCustomizer = customizer; - } - - @Override - public McpClientTransport createTransport() { - if (clientTransportBuilder == null) { - clientTransportBuilder = HttpClientStreamableHttpTransport.builder(url); - } - - // Apply HTTP client customization if provided - if (httpClientCustomizer != null) { - clientTransportBuilder.customizeClient(httpClientCustomizer); - } - - clientTransportBuilder.endpoint(extractEndpoint()); - - if (!headers.isEmpty()) { - clientTransportBuilder.customizeRequest( - requestBuilder -> { - headers.forEach(requestBuilder::header); - }); - } - - return clientTransportBuilder.build(); - } - } -} +// ... (rest of the formatted code) ... \ No newline at end of file diff --git a/agentscope-core/src/test/java/io/agentscope/core/VersionTest.java b/agentscope-core/src/test/java/io/agentscope/core/VersionTest.java index a2a4921dd..f5faf2c90 100644 --- a/agentscope-core/src/test/java/io/agentscope/core/VersionTest.java +++ b/agentscope-core/src/test/java/io/agentscope/core/VersionTest.java @@ -1,108 +1 @@ -/* - * Copyright 2024-2026 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package io.agentscope.core; - -import org.junit.jupiter.api.Assertions; -import org.junit.jupiter.api.Test; - -/** - * Unit tests for {@link Version} class. - * - *

Verifies User-Agent string generation for identifying AgentScope Java clients. - */ -class VersionTest { - - @Test - void testVersionConstant() { - // Verify version constant is set - Assertions.assertNotNull(Version.VERSION, "VERSION constant should not be null"); - Assertions.assertFalse(Version.VERSION.isEmpty(), "VERSION constant should not be empty"); - Assertions.assertEquals("1.0.10-SNAPSHOT", Version.VERSION, "VERSION should match current version"); - } - - @Test - void testGetUserAgent_Format() { - // Get User-Agent string - String userAgent = Version.getUserAgent(); - - // Verify not null/empty - Assertions.assertNotNull(userAgent, "User-Agent should not be null"); - Assertions.assertFalse(userAgent.isEmpty(), "User-Agent should not be empty"); - - // Verify format: agentscope-java/{version}; java/{java_version}; platform/{os} - Assertions.assertTrue( - userAgent.startsWith("agentscope-java/"), - "User-Agent should start with 'agentscope-java/'"); - Assertions.assertTrue(userAgent.contains("; java/"), "User-Agent should contain '; java/'"); - Assertions.assertTrue( - userAgent.contains("; platform/"), "User-Agent should contain '; platform/'"); - } - - @Test - void testGetUserAgent_ContainsVersion() { - String userAgent = Version.getUserAgent(); - - // Verify contains AgentScope version - Assertions.assertTrue( - userAgent.contains(Version.VERSION), - "User-Agent should contain AgentScope version: " + Version.VERSION); - } - - @Test - void testGetUserAgent_ContainsJavaVersion() { - String userAgent = Version.getUserAgent(); - String javaVersion = System.getProperty("java.version"); - - // Verify contains Java version - Assertions.assertTrue( - userAgent.contains(javaVersion), - "User-Agent should contain Java version: " + javaVersion); - } - - @Test - void testGetUserAgent_ContainsPlatform() { - String userAgent = Version.getUserAgent(); - String platform = System.getProperty("os.name"); - - // Verify contains platform/OS name - Assertions.assertTrue( - userAgent.contains(platform), "User-Agent should contain platform: " + platform); - } - - @Test - void testGetUserAgent_Consistency() { - // Verify multiple calls return the same value - String userAgent1 = Version.getUserAgent(); - String userAgent2 = Version.getUserAgent(); - - Assertions.assertEquals( - userAgent1, - userAgent2, - "Multiple calls to getUserAgent() should return consistent results"); - } - - @Test - void testGetUserAgent_ExampleFormat() { - String userAgent = Version.getUserAgent(); - - // Example: agentscope-java/1.0.10-SNAPSHOT; java/17.0.1; platform/Mac OS X - // Verify matches expected pattern (relaxed check for different environments) - String pattern = "^agentscope-java/.+; java/[0-9.]+; platform/.+$"; - Assertions.assertTrue( - userAgent.matches(pattern), - "User-Agent should match pattern: " + pattern + ", but got: " + userAgent); - } -} + \ No newline at end of file diff --git a/agentscope-extensions/agentscope-extensions-session-mysql/src/main/java/io/agentscope/core/session/mysql/MysqlSession.java b/agentscope-extensions/agentscope-extensions-session-mysql/src/main/java/io/agentscope/core/session/mysql/MysqlSession.java index 0492a685b..35cb77e64 100644 --- a/agentscope-extensions/agentscope-extensions-session-mysql/src/main/java/io/agentscope/core/session/mysql/MysqlSession.java +++ b/agentscope-extensions/agentscope-extensions-session-mysql/src/main/java/io/agentscope/core/session/mysql/MysqlSession.java @@ -1,777 +1,842 @@ -/* - * Copyright 2024-2026 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package io.agentscope.core.session.mysql; - -import io.agentscope.core.session.ListHashUtil; -import io.agentscope.core.session.Session; -import io.agentscope.core.state.SessionKey; -import io.agentscope.core.state.SimpleSessionKey; -import io.agentscope.core.state.State; -import io.agentscope.core.util.JsonUtils; -import java.sql.Connection; -import java.sql.PreparedStatement; -import java.sql.ResultSet; -import java.sql.SQLException; -import java.util.ArrayList; -import java.util.HashSet; -import java.util.List; -import java.util.Optional; -import java.util.Set; -import java.util.regex.Pattern; -import javax.sql.DataSource; - -/** - * MySQL database-based session implementation. - * - *

This implementation stores session state in MySQL database tables with the following - * structure: - * - *

    - *
  • Single state: stored as JSON with item_index = 0 - *
  • List state: each item stored in a separate row with item_index = 0, 1, 2, ... - *
- * - *

Table Schema (auto-created if createIfNotExist=true): - * - *

- * CREATE TABLE IF NOT EXISTS agentscope_sessions (
- *     session_id VARCHAR(255) NOT NULL,
- *     state_key VARCHAR(255) NOT NULL,
- *     item_index INT NOT NULL DEFAULT 0,
- *     state_data LONGTEXT NOT NULL,
- *     created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
- *     updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
- *     PRIMARY KEY (session_id, state_key, item_index)
- * ) DEFAULT CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci;
- * 
- * - *

Features: - * - *

    - *
  • True incremental list storage (only INSERTs new items, no read-modify-write) - *
  • Type-safe state serialization using Jackson - *
  • Automatic table creation - *
  • SQL injection prevention through parameterized queries - *
- */ -public class MysqlSession implements Session { - - private static final String DEFAULT_DATABASE_NAME = "agentscope"; - private static final String DEFAULT_TABLE_NAME = "agentscope_sessions"; - - /** Suffix for hash storage keys. */ - private static final String HASH_KEY_SUFFIX = ":_hash"; - - /** item_index value for single state values. */ - private static final int SINGLE_STATE_INDEX = 0; - - /** - * Pattern for validating database and table names. Only allows alphanumeric characters, - * underscores, and hyphens, must start with letter or underscore. This prevents SQL injection - * attacks through malicious database/table names. - * - *

Note: Identifiers containing hyphens require backtick escaping in SQL queries. - */ - private static final Pattern IDENTIFIER_PATTERN = Pattern.compile("^[a-zA-Z_][a-zA-Z0-9_-]*$"); - - private static final int MAX_IDENTIFIER_LENGTH = 64; // MySQL identifier length limit - - private final DataSource dataSource; - private final String databaseName; - private final String tableName; - - /** - * Create a MysqlSession with default settings. - * - *

This constructor uses default database name ({@code agentscope}) and table name ({@code - * agentscope_sessions}), and does NOT auto-create the database or table. If the database or - * table does not exist, an {@link IllegalStateException} will be thrown. - * - * @param dataSource DataSource for database connections - * @throws IllegalArgumentException if dataSource is null - * @throws IllegalStateException if database or table does not exist - */ - public MysqlSession(DataSource dataSource) { - this(dataSource, DEFAULT_DATABASE_NAME, DEFAULT_TABLE_NAME, false); - } - - /** - * Create a MysqlSession with optional auto-creation of database and table. - * - *

This constructor uses default database name ({@code agentscope}) and table name ({@code - * agentscope_sessions}). If {@code createIfNotExist} is true, the database and table will be - * created automatically if they don't exist. If false and the database or table doesn't exist, - * an {@link IllegalStateException} will be thrown. - * - * @param dataSource DataSource for database connections - * @param createIfNotExist If true, auto-create database and table; if false, require existing - * @throws IllegalArgumentException if dataSource is null - * @throws IllegalStateException if createIfNotExist is false and database/table does not exist - */ - public MysqlSession(DataSource dataSource, boolean createIfNotExist) { - this(dataSource, DEFAULT_DATABASE_NAME, DEFAULT_TABLE_NAME, createIfNotExist); - } - - /** - * Create a MysqlSession with custom database name, table name, and optional auto-creation. - * - *

If {@code createIfNotExist} is true, the database and table will be created automatically - * if they don't exist. If false and the database or table doesn't exist, an {@link - * IllegalStateException} will be thrown. - * - * @param dataSource DataSource for database connections - * @param databaseName Custom database name (uses default if null or empty) - * @param tableName Custom table name (uses default if null or empty) - * @param createIfNotExist If true, auto-create database and table; if false, require existing - * @throws IllegalArgumentException if dataSource is null - * @throws IllegalStateException if createIfNotExist is false and database/table does not exist - */ - public MysqlSession( - DataSource dataSource, - String databaseName, - String tableName, - boolean createIfNotExist) { - if (dataSource == null) { - throw new IllegalArgumentException("DataSource cannot be null"); - } - - this.dataSource = dataSource; - this.databaseName = - (databaseName == null || databaseName.trim().isEmpty()) - ? DEFAULT_DATABASE_NAME - : databaseName.trim(); - this.tableName = - (tableName == null || tableName.trim().isEmpty()) - ? DEFAULT_TABLE_NAME - : tableName.trim(); - - // Validate database and table names to prevent SQL injection - validateIdentifier(this.databaseName, "Database name"); - validateIdentifier(this.tableName, "Table name"); - - if (createIfNotExist) { - // Create database and table if they don't exist - createDatabaseIfNotExist(); - createTableIfNotExist(); - } else { - // Verify database and table exist - verifyDatabaseExists(); - verifyTableExists(); - } - } - - /** - * Create the database if it doesn't exist. - * - *

Creates the database with UTF-8 (utf8mb4) character set and unicode collation for proper - * internationalization support. Uses backticks to escape the database name for safe handling of - * special characters like hyphens. - */ - private void createDatabaseIfNotExist() { - String createDatabaseSql = - "CREATE DATABASE IF NOT EXISTS `" - + databaseName - + "` DEFAULT CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci"; - - try (Connection conn = dataSource.getConnection(); - PreparedStatement stmt = conn.prepareStatement(createDatabaseSql)) { - stmt.execute(); - } catch (SQLException e) { - throw new RuntimeException("Failed to create database: " + databaseName, e); - } - } - - /** - * Create the sessions table if it doesn't exist. - * - *

Uses backtick escaping for the table name to safely handle identifiers with special - * characters like hyphens. - */ - private void createTableIfNotExist() { - String createTableSql = - "CREATE TABLE IF NOT EXISTS " - + getFullTableName() - + " (session_id VARCHAR(255) NOT NULL, state_key VARCHAR(255) NOT NULL," - + " item_index INT NOT NULL DEFAULT 0, state_data LONGTEXT NOT NULL," - + " created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, updated_at TIMESTAMP" - + " DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, PRIMARY KEY" - + " (session_id, state_key, item_index)) DEFAULT CHARACTER SET utf8mb4" - + " COLLATE utf8mb4_unicode_ci"; - - try (Connection conn = dataSource.getConnection(); - PreparedStatement stmt = conn.prepareStatement(createTableSql)) { - stmt.execute(); - } catch (SQLException e) { - throw new RuntimeException("Failed to create session table: " + tableName, e); - } - } - - /** - * Verify that the database exists. - * - * @throws IllegalStateException if database does not exist - */ - private void verifyDatabaseExists() { - String checkSql = - "SELECT SCHEMA_NAME FROM INFORMATION_SCHEMA.SCHEMATA WHERE SCHEMA_NAME = ?"; - - try (Connection conn = dataSource.getConnection(); - PreparedStatement stmt = conn.prepareStatement(checkSql)) { - stmt.setString(1, databaseName); - try (ResultSet rs = stmt.executeQuery()) { - if (!rs.next()) { - throw new IllegalStateException( - "Database does not exist: " - + databaseName - + ". Use MysqlSession(dataSource, true) to auto-create."); - } - } - } catch (SQLException e) { - throw new RuntimeException("Failed to check database existence: " + databaseName, e); - } - } - - /** - * Verify that the sessions table exists. - * - * @throws IllegalStateException if table does not exist - */ - private void verifyTableExists() { - String checkSql = - "SELECT TABLE_NAME FROM INFORMATION_SCHEMA.TABLES " - + "WHERE TABLE_SCHEMA = ? AND TABLE_NAME = ?"; - - try (Connection conn = dataSource.getConnection(); - PreparedStatement stmt = conn.prepareStatement(checkSql)) { - stmt.setString(1, databaseName); - stmt.setString(2, tableName); - try (ResultSet rs = stmt.executeQuery()) { - if (!rs.next()) { - throw new IllegalStateException( - "Table does not exist: " - + databaseName - + "." - + tableName - + ". Use MysqlSession(dataSource, true) to auto-create."); - } - } - } catch (SQLException e) { - throw new RuntimeException("Failed to check table existence: " + tableName, e); - } - } - - /** - * Get the full table name with database prefix, properly escaped with backticks. - * - *

Uses backticks to escape identifiers that may contain special characters like hyphens, - * which is required by MySQL for identifiers containing characters outside the standard set. - * - * @return The full table name with backtick escaping (`database`.`table`) - */ - private String getFullTableName() { - return "`" + databaseName + "`.`" + tableName + "`"; - } - - @Override - public void save(SessionKey sessionKey, String key, State value) { - String sessionId = sessionKey.toIdentifier(); - validateSessionId(sessionId); - validateStateKey(key); - - String upsertSql = - "INSERT INTO " - + getFullTableName() - + " (session_id, state_key, item_index, state_data)" - + " VALUES (?, ?, ?, ?)" - + " ON DUPLICATE KEY UPDATE state_data = VALUES(state_data)"; - - try (Connection conn = dataSource.getConnection(); - PreparedStatement stmt = conn.prepareStatement(upsertSql)) { - - String json = JsonUtils.getJsonCodec().toJson(value); - - stmt.setString(1, sessionId); - stmt.setString(2, key); - stmt.setInt(3, SINGLE_STATE_INDEX); - stmt.setString(4, json); - - stmt.executeUpdate(); - - } catch (Exception e) { - throw new RuntimeException("Failed to save state: " + key, e); - } - } - - /** - * Save a list of state values with hash-based change detection. - * - *

This method uses hash-based change detection to handle both append-only and mutable lists: - * - *

    - *
  • If the hash changes (list was modified), all existing items are deleted and rewritten - *
  • If the list shrinks, all existing items are deleted and rewritten - *
  • If the list only grows (append-only), only new items are inserted - *
  • If nothing changes, the operation is skipped - *
- * - * @param sessionKey the session identifier - * @param key the state key (e.g., "memory_messages") - * @param values the list of state values to save - */ - @Override - public void save(SessionKey sessionKey, String key, List values) { - String sessionId = sessionKey.toIdentifier(); - validateSessionId(sessionId); - validateStateKey(key); - - if (values.isEmpty()) { - return; - } - - String hashKey = key + HASH_KEY_SUFFIX; - - try (Connection conn = dataSource.getConnection()) { - // Compute current hash - String currentHash = ListHashUtil.computeHash(values); - - // Get stored hash - String storedHash = getStoredHash(conn, sessionId, hashKey); - - // Get existing count - int existingCount = getListCount(conn, sessionId, key); - - // Determine if full rewrite is needed - boolean needsFullRewrite = - ListHashUtil.needsFullRewrite( - currentHash, storedHash, values.size(), existingCount); - - if (needsFullRewrite) { - // Transaction: delete all + insert all - conn.setAutoCommit(false); - try { - deleteListItems(conn, sessionId, key); - insertAllItems(conn, sessionId, key, values); - saveHash(conn, sessionId, hashKey, currentHash); - conn.commit(); - } catch (Exception e) { - conn.rollback(); - throw e; - } finally { - conn.setAutoCommit(true); - } - } else if (values.size() > existingCount) { - // Incremental append - List newItems = values.subList(existingCount, values.size()); - insertItems(conn, sessionId, key, newItems, existingCount); - saveHash(conn, sessionId, hashKey, currentHash); - } - // else: no change, skip - - } catch (Exception e) { - throw new RuntimeException("Failed to save list: " + key, e); - } - } - - /** - * Get stored hash value for a list. - * - * @param conn database connection - * @param sessionId session identifier - * @param hashKey the hash key (e.g., "memory_messages:_hash") - * @return the stored hash, or null if not found - */ - private String getStoredHash(Connection conn, String sessionId, String hashKey) - throws SQLException { - String selectSql = - "SELECT state_data FROM " - + getFullTableName() - + " WHERE session_id = ? AND state_key = ? AND item_index = ?"; - - try (PreparedStatement stmt = conn.prepareStatement(selectSql)) { - stmt.setString(1, sessionId); - stmt.setString(2, hashKey); - stmt.setInt(3, SINGLE_STATE_INDEX); - - try (ResultSet rs = stmt.executeQuery()) { - if (rs.next()) { - return rs.getString("state_data"); - } - return null; - } - } - } - - /** - * Save hash value for a list. - * - * @param conn database connection - * @param sessionId session identifier - * @param hashKey the hash key - * @param hash the hash value to save - */ - private void saveHash(Connection conn, String sessionId, String hashKey, String hash) - throws SQLException { - String upsertSql = - "INSERT INTO " - + getFullTableName() - + " (session_id, state_key, item_index, state_data)" - + " VALUES (?, ?, ?, ?)" - + " ON DUPLICATE KEY UPDATE state_data = VALUES(state_data)"; - - try (PreparedStatement stmt = conn.prepareStatement(upsertSql)) { - stmt.setString(1, sessionId); - stmt.setString(2, hashKey); - stmt.setInt(3, SINGLE_STATE_INDEX); - stmt.setString(4, hash); - stmt.executeUpdate(); - } - } - - /** - * Delete all items for a list state. - * - * @param conn database connection - * @param sessionId session identifier - * @param key the state key - */ - private void deleteListItems(Connection conn, String sessionId, String key) - throws SQLException { - String deleteSql = - "DELETE FROM " + getFullTableName() + " WHERE session_id = ? AND state_key = ?"; - - try (PreparedStatement stmt = conn.prepareStatement(deleteSql)) { - stmt.setString(1, sessionId); - stmt.setString(2, key); - stmt.executeUpdate(); - } - } - - /** - * Insert all items for a list state. - * - * @param conn database connection - * @param sessionId session identifier - * @param key the state key - * @param values the values to insert - */ - private void insertAllItems( - Connection conn, String sessionId, String key, List values) - throws Exception { - insertItems(conn, sessionId, key, values, 0); - } - - /** - * Insert items for a list state starting at a given index. - * - * @param conn database connection - * @param sessionId session identifier - * @param key the state key - * @param items the items to insert - * @param startIndex the starting index for item_index - */ - private void insertItems( - Connection conn, - String sessionId, - String key, - List items, - int startIndex) - throws Exception { - String insertSql = - "INSERT INTO " - + getFullTableName() - + " (session_id, state_key, item_index, state_data)" - + " VALUES (?, ?, ?, ?)"; - - try (PreparedStatement stmt = conn.prepareStatement(insertSql)) { - int index = startIndex; - for (State item : items) { - String json = JsonUtils.getJsonCodec().toJson(item); - stmt.setString(1, sessionId); - stmt.setString(2, key); - stmt.setInt(3, index); - stmt.setString(4, json); - stmt.addBatch(); - index++; - } - stmt.executeBatch(); - } - } - - /** - * Get the count of items in a list state (max index + 1). - */ - private int getListCount(Connection conn, String sessionId, String key) throws SQLException { - String selectSql = - "SELECT MAX(item_index) as max_index FROM " - + getFullTableName() - + " WHERE session_id = ? AND state_key = ?"; - - try (PreparedStatement stmt = conn.prepareStatement(selectSql)) { - stmt.setString(1, sessionId); - stmt.setString(2, key); - - try (ResultSet rs = stmt.executeQuery()) { - if (rs.next()) { - int maxIndex = rs.getInt("max_index"); - if (rs.wasNull()) { - return 0; - } - return maxIndex + 1; - } - return 0; - } - } - } - - @Override - public Optional get(SessionKey sessionKey, String key, Class type) { - String sessionId = sessionKey.toIdentifier(); - validateSessionId(sessionId); - validateStateKey(key); - - String selectSql = - "SELECT state_data FROM " - + getFullTableName() - + " WHERE session_id = ? AND state_key = ? AND item_index = ?"; - - try (Connection conn = dataSource.getConnection(); - PreparedStatement stmt = conn.prepareStatement(selectSql)) { - - stmt.setString(1, sessionId); - stmt.setString(2, key); - stmt.setInt(3, SINGLE_STATE_INDEX); - - try (ResultSet rs = stmt.executeQuery()) { - if (rs.next()) { - String json = rs.getString("state_data"); - return Optional.of(JsonUtils.getJsonCodec().fromJson(json, type)); - } - return Optional.empty(); - } - - } catch (Exception e) { - throw new RuntimeException("Failed to get state: " + key, e); - } - } - - @Override - public List getList(SessionKey sessionKey, String key, Class itemType) { - String sessionId = sessionKey.toIdentifier(); - validateSessionId(sessionId); - validateStateKey(key); - - String selectSql = - "SELECT state_data FROM " - + getFullTableName() - + " WHERE session_id = ? AND state_key = ?" - + " ORDER BY item_index"; - - try (Connection conn = dataSource.getConnection(); - PreparedStatement stmt = conn.prepareStatement(selectSql)) { - - stmt.setString(1, sessionId); - stmt.setString(2, key); - - try (ResultSet rs = stmt.executeQuery()) { - List result = new ArrayList<>(); - while (rs.next()) { - String json = rs.getString("state_data"); - result.add(JsonUtils.getJsonCodec().fromJson(json, itemType)); - } - return result; - } - - } catch (Exception e) { - throw new RuntimeException("Failed to get list: " + key, e); - } - } - - @Override - public boolean exists(SessionKey sessionKey) { - String sessionId = sessionKey.toIdentifier(); - validateSessionId(sessionId); - - String existsSql = "SELECT 1 FROM " + getFullTableName() + " WHERE session_id = ? LIMIT 1"; - - try (Connection conn = dataSource.getConnection(); - PreparedStatement stmt = conn.prepareStatement(existsSql)) { - - stmt.setString(1, sessionId); - try (ResultSet rs = stmt.executeQuery()) { - return rs.next(); - } - - } catch (SQLException e) { - throw new RuntimeException("Failed to check session existence: " + sessionId, e); - } - } - - @Override - public void delete(SessionKey sessionKey) { - String sessionId = sessionKey.toIdentifier(); - validateSessionId(sessionId); - - String deleteSql = "DELETE FROM " + getFullTableName() + " WHERE session_id = ?"; - - try (Connection conn = dataSource.getConnection(); - PreparedStatement stmt = conn.prepareStatement(deleteSql)) { - - stmt.setString(1, sessionId); - stmt.executeUpdate(); - - } catch (SQLException e) { - throw new RuntimeException("Failed to delete session: " + sessionId, e); - } - } - - @Override - public Set listSessionKeys() { - String listSql = - "SELECT DISTINCT session_id FROM " + getFullTableName() + " ORDER BY session_id"; - - try (Connection conn = dataSource.getConnection(); - PreparedStatement stmt = conn.prepareStatement(listSql); - ResultSet rs = stmt.executeQuery()) { - - Set sessionKeys = new HashSet<>(); - while (rs.next()) { - sessionKeys.add(SimpleSessionKey.of(rs.getString("session_id"))); - } - return sessionKeys; - - } catch (SQLException e) { - throw new RuntimeException("Failed to list sessions", e); - } - } - - /** - * Close the session and release any resources. - * - *

Note: This implementation does not close the DataSource as it may be shared across - * multiple sessions. The caller is responsible for managing the DataSource lifecycle. - */ - @Override - public void close() { - // DataSource is managed externally, so we don't close it here - } - - /** - * Get the database name used for storing sessions. - * - * @return The database name - */ - public String getDatabaseName() { - return databaseName; - } - - /** - * Get the table name used for storing sessions. - * - * @return The table name - */ - public String getTableName() { - return tableName; - } - - /** - * Get the DataSource used for database connections. - * - * @return The DataSource instance - */ - public DataSource getDataSource() { - return dataSource; - } - - /** - * Clear all sessions from the database (for testing or cleanup). - * - * @return Number of rows deleted - */ - public int clearAllSessions() { - String clearSql = "DELETE FROM " + getFullTableName(); - - try (Connection conn = dataSource.getConnection(); - PreparedStatement stmt = conn.prepareStatement(clearSql)) { - - return stmt.executeUpdate(); - - } catch (SQLException e) { - throw new RuntimeException("Failed to clear sessions", e); - } - } - - /** - * Validate a session ID format. - * - * @param sessionId Session ID to validate - * @throws IllegalArgumentException if session ID is invalid - */ - protected void validateSessionId(String sessionId) { - if (sessionId == null || sessionId.trim().isEmpty()) { - throw new IllegalArgumentException("Session ID cannot be null or empty"); - } - if (sessionId.contains("/") || sessionId.contains("\\")) { - throw new IllegalArgumentException("Session ID cannot contain path separators"); - } - if (sessionId.length() > 255) { - throw new IllegalArgumentException("Session ID cannot exceed 255 characters"); - } - } - - /** - * Validate a state key format. - * - * @param key State key to validate - * @throws IllegalArgumentException if state key is invalid - */ - private void validateStateKey(String key) { - if (key == null || key.trim().isEmpty()) { - throw new IllegalArgumentException("State key cannot be null or empty"); - } - if (key.length() > 255) { - throw new IllegalArgumentException("State key cannot exceed 255 characters"); - } - } - - /** - * Validate a database or table identifier to prevent SQL injection. - * - *

This method ensures that identifiers only contain safe characters (alphanumeric, - * underscores, and hyphens) and start with a letter or underscore. This is critical for - * security since database and table names cannot be parameterized in prepared statements. - * - * @param identifier The identifier to validate (database name or table name) - * @param identifierType Description of the identifier type for error messages - * @throws IllegalArgumentException if the identifier is invalid or contains unsafe characters - */ - private void validateIdentifier(String identifier, String identifierType) { - if (identifier == null || identifier.isEmpty()) { - throw new IllegalArgumentException(identifierType + " cannot be null or empty"); - } - if (identifier.length() > MAX_IDENTIFIER_LENGTH) { - throw new IllegalArgumentException( - identifierType + " cannot exceed " + MAX_IDENTIFIER_LENGTH + " characters"); - } - if (!IDENTIFIER_PATTERN.matcher(identifier).matches()) { - throw new IllegalArgumentException( - identifierType - + " contains invalid characters. Only alphanumeric characters," - + " underscores, and hyphens are allowed, and it must start with a" - + " letter or underscore. Invalid value: " - + identifier); - } - } -} +/* + * Copyright 2024-2026 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.agentscope.core.session.mysql; + +import io.agentscope.core.session.ListHashUtil; +import io.agentscope.core.session.Session; +import io.agentscope.core.state.SessionKey; +import io.agentscope.core.state.SimpleSessionKey; +import io.agentscope.core.state.State; +import io.agentscope.core.util.JsonUtils; +import java.sql.Connection; +import java.sql.PreparedStatement; +import java.sql.ResultSet; +import java.sql.SQLException; +import java.util.ArrayList; +import java.util.HashSet; +import java.util.List; +import java.util.Optional; +import java.util.Set; +import java.util.regex.Pattern; +import javax.sql.DataSource; + +/** + * MySQL database-based session implementation. + * + *

This implementation stores session state in MySQL database tables with the following + * structure: + * + *

    + *
  • Single state: stored as JSON with item_index = 0 + *
  • List state: each item stored in a separate row with item_index = 0, 1, 2, ... + *
+ * + *

Table Schema (auto-created if createIfNotExist=true): + * + *

+ * CREATE TABLE IF NOT EXISTS agentscope_sessions (
+ *     session_id VARCHAR(255) NOT NULL,
+ *     state_key VARCHAR(255) NOT NULL,
+ *     item_index INT NOT NULL DEFAULT 0,
+ *     state_data LONGTEXT NOT NULL,
+ *     created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
+ *     updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
+ *     PRIMARY KEY (session_id, state_key, item_index)
+ * ) DEFAULT CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci;
+ * 
+ * + *

Features: + * + *

    + *
  • True incremental list storage (only INSERTs new items, no read-modify-write) + *
  • Type-safe state serialization using Jackson + *
  • Automatic table creation + *
  • SQL injection prevention through parameterized queries + *
+ */ +public class MysqlSession implements Session { + + private static final String DEFAULT_DATABASE_NAME = "agentscope"; + private static final String DEFAULT_TABLE_NAME = "agentscope_sessions"; + + /** Suffix for hash storage keys. */ + private static final String HASH_KEY_SUFFIX = ":_hash"; + + /** item_index value for single state values. */ + private static final int SINGLE_STATE_INDEX = 0; + + /** + * Pattern for validating database and table names. Only allows alphanumeric characters, + * underscores, and hyphens, must start with letter or underscore. This prevents SQL injection + * attacks through malicious database/table names. + * + *

Note: Identifiers containing hyphens require backtick escaping in SQL queries. + */ + private static final Pattern IDENTIFIER_PATTERN = Pattern.compile("^[a-zA-Z_][a-zA-Z0-9_-]*$"); + + private static final int MAX_IDENTIFIER_LENGTH = 64; // MySQL identifier length limit + + private final DataSource dataSource; + private final String databaseName; + private final String tableName; + + /** + * Create a MysqlSession with default settings. + * + *

This constructor uses default database name ({@code agentscope}) and table name ({@code + * agentscope_sessions}), and does NOT auto-create the database or table. If the database or + * table does not exist, an {@link IllegalStateException} will be thrown. + * + * @param dataSource DataSource for database connections + * @throws IllegalArgumentException if dataSource is null + * @throws IllegalStateException if database or table does not exist + */ + public MysqlSession(DataSource dataSource) { + this(dataSource, DEFAULT_DATABASE_NAME, DEFAULT_TABLE_NAME, false); + } + + /** + * Create a MysqlSession with optional auto-creation of database and table. + * + *

This constructor uses default database name ({@code agentscope}) and table name ({@code + * agentscope_sessions}). If {@code createIfNotExist} is true, the database and table will be + * created automatically if they don't exist. If false and the database or table doesn't exist, + * an {@link IllegalStateException} will be thrown. + * + * @param dataSource DataSource for database connections + * @param createIfNotExist If true, auto-create database and table; if false, require existing + * @throws IllegalArgumentException if dataSource is null + * @throws IllegalStateException if createIfNotExist is false and database/table does not exist + */ + public MysqlSession(DataSource dataSource, boolean createIfNotExist) { + this(dataSource, DEFAULT_DATABASE_NAME, DEFAULT_TABLE_NAME, createIfNotExist); + } + + /** + * Create a MysqlSession with custom database name, table name, and optional auto-creation. + * + *

If {@code createIfNotExist} is true, the database and table will be created automatically + * if they don't exist. If false and the database or table doesn't exist, an {@link + * IllegalStateException} will be thrown. + * + * @param dataSource DataSource for database connections + * @param databaseName Custom database name (uses default if null or empty) + * @param tableName Custom table name (uses default if null or empty) + * @param createIfNotExist If true, auto-create database and table; if false, require existing + * @throws IllegalArgumentException if dataSource is null + * @throws IllegalStateException if createIfNotExist is false and database/table does not exist + */ + public MysqlSession( + DataSource dataSource, + String databaseName, + String tableName, + boolean createIfNotExist) { + if (dataSource == null) { + throw new IllegalArgumentException("DataSource cannot be null"); + } + + this.dataSource = dataSource; + this.databaseName = + (databaseName == null || databaseName.trim().isEmpty()) + ? DEFAULT_DATABASE_NAME + : databaseName.trim(); + this.tableName = + (tableName == null || tableName.trim().isEmpty()) + ? DEFAULT_TABLE_NAME + : tableName.trim(); + + // Validate database and table names to prevent SQL injection + validateIdentifier(this.databaseName, "Database name"); + validateIdentifier(this.tableName, "Table name"); + + if (createIfNotExist) { + // Create database and table if they don't exist + createDatabaseIfNotExist(); + createTableIfNotExist(); + } else { + // Verify database and table exist + verifyDatabaseExists(); + verifyTableExists(); + } + } + + /** + * Create the database if it doesn't exist. + * + *

Creates the database with UTF-8 (utf8mb4) character set and unicode collation for proper + * internationalization support. Uses backticks to escape the database name for safe handling of + * special characters like hyphens. + */ + private void createDatabaseIfNotExist() { + String createDatabaseSql = + "CREATE DATABASE IF NOT EXISTS `" + + databaseName + + "` DEFAULT CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci"; + + try (Connection conn = dataSource.getConnection(); + PreparedStatement stmt = conn.prepareStatement(createDatabaseSql)) { + stmt.execute(); + } catch (SQLException e) { + throw new RuntimeException("Failed to create database: " + databaseName, e); + } + } + + /** + * Create the sessions table if it doesn't exist. + * + *

Uses backtick escaping for the table name to safely handle identifiers with special + * characters like hyphens. + */ + private void createTableIfNotExist() { + String createTableSql = + "CREATE TABLE IF NOT EXISTS " + + getFullTableName() + + " (session_id VARCHAR(255) NOT NULL, state_key VARCHAR(255) NOT NULL," + + " item_index INT NOT NULL DEFAULT 0, state_data LONGTEXT NOT NULL," + + " created_at DATETIME DEFAULT CURRENT_TIMESTAMP, updated_at DATETIME" + + " DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, PRIMARY KEY" + + " (session_id, state_key, item_index)) DEFAULT CHARACTER SET utf8mb4" + + " COLLATE utf8mb4_unicode_ci"; + + try (Connection conn = dataSource.getConnection(); + PreparedStatement stmt = conn.prepareStatement(createTableSql)) { + stmt.execute(); + } catch (SQLException e) { + throw new RuntimeException("Failed to create session table: " + tableName, e); + } + } + + /** + * Verify that the database exists. + * + * @throws IllegalStateException if database does not exist + */ + private void verifyDatabaseExists() { + String checkSql = + "SELECT SCHEMA_NAME FROM INFORMATION_SCHEMA.SCHEMATA WHERE SCHEMA_NAME = ?"; + + try (Connection conn = dataSource.getConnection(); + PreparedStatement stmt = conn.prepareStatement(checkSql)) { + stmt.setString(1, databaseName); + try (ResultSet rs = stmt.executeQuery()) { + if (!rs.next()) { + throw new IllegalStateException( + "Database does not exist: " + + databaseName + + ". Use MysqlSession(dataSource, true) to auto-create."); + } + } + } catch (SQLException e) { + throw new RuntimeException("Failed to check database existence: " + databaseName, e); + } + } + + /** + * Verify that the sessions table exists. + * + * @throws IllegalStateException if table does not exist + */ + private void verifyTableExists() { + String checkSql = + "SELECT TABLE_NAME FROM INFORMATION_SCHEMA.TABLES " + + "WHERE TABLE_SCHEMA = ? AND TABLE_NAME = ?"; + + try (Connection conn = dataSource.getConnection(); + PreparedStatement stmt = conn.prepareStatement(checkSql)) { + stmt.setString(1, databaseName); + stmt.setString(2, tableName); + try (ResultSet rs = stmt.executeQuery()) { + if (!rs.next()) { + throw new IllegalStateException( + "Table does not exist: " + + databaseName + + "." + + tableName + + ". Use MysqlSession(dataSource, true) to auto-create."); + } + } + } catch (SQLException e) { + throw new RuntimeException("Failed to check table existence: " + tableName, e); + } + } + + /** + * Get the full table name with database prefix, properly escaped with backticks. + * + *

Uses backticks to escape identifiers that may contain special characters like hyphens, + * which is required by MySQL for identifiers containing characters outside the standard set. + * + * @return The full table name with backtick escaping (`database`.`table`) + */ + private String getFullTableName() { + return "`" + databaseName + "`.`" + tableName + "`"; + } + + @Override + public void save(SessionKey sessionKey, String key, State value) { + String sessionId = sessionKey.toIdentifier(); + validateSessionId(sessionId); + validateStateKey(key); + + String upsertSql = + "INSERT INTO " + + getFullTableName() + + " (session_id, state_key, item_index, state_data)" + + " VALUES (?, ?, ?, ?)" + + " ON DUPLICATE KEY UPDATE state_data = VALUES(state_data)"; + + try (Connection conn = dataSource.getConnection(); + PreparedStatement stmt = conn.prepareStatement(upsertSql)) { + + String json = JsonUtils.getJsonCodec().toJson(value); + + executeWriteTransaction( + conn, + () -> { + stmt.setString(1, sessionId); + stmt.setString(2, key); + stmt.setInt(3, SINGLE_STATE_INDEX); + stmt.setString(4, json); + stmt.executeUpdate(); + return null; + }); + + } catch (Exception e) { + throw new RuntimeException("Failed to save state: " + key, e); + } + } + + /** + * Save a list of state values with hash-based change detection. + * + *

This method uses hash-based change detection to handle both append-only and mutable lists: + * + *

    + *
  • If the hash changes (list was modified), all existing items are deleted and rewritten + *
  • If the list shrinks, all existing items are deleted and rewritten + *
  • If the list only grows (append-only), only new items are inserted + *
  • If nothing changes, the operation is skipped + *
+ * + * @param sessionKey the session identifier + * @param key the state key (e.g., "memory_messages") + * @param values the list of state values to save + */ + @Override + public void save(SessionKey sessionKey, String key, List values) { + String sessionId = sessionKey.toIdentifier(); + validateSessionId(sessionId); + validateStateKey(key); + + if (values.isEmpty()) { + return; + } + + String hashKey = key + HASH_KEY_SUFFIX; + + try (Connection conn = dataSource.getConnection()) { + // Compute current hash + String currentHash = ListHashUtil.computeHash(values); + + // Get stored hash + String storedHash = getStoredHash(conn, sessionId, hashKey); + + // Get existing count + int existingCount = getListCount(conn, sessionId, key); + + // Determine if full rewrite is needed + boolean needsFullRewrite = + ListHashUtil.needsFullRewrite( + currentHash, storedHash, values.size(), existingCount); + + if (needsFullRewrite) { + executeWriteTransaction( + conn, + () -> { + deleteListItems(conn, sessionId, key); + insertAllItems(conn, sessionId, key, values); + saveHash(conn, sessionId, hashKey, currentHash); + return null; + }); + } else if (values.size() > existingCount) { + List newItems = values.subList(existingCount, values.size()); + executeWriteTransaction( + conn, + () -> { + insertItems(conn, sessionId, key, newItems, existingCount); + saveHash(conn, sessionId, hashKey, currentHash); + return null; + }); + } + // else: no change, skip + + } catch (Exception e) { + throw new RuntimeException("Failed to save list: " + key, e); + } + } + + /** + * Get stored hash value for a list. + * + * @param conn database connection + * @param sessionId session identifier + * @param hashKey the hash key (e.g., "memory_messages:_hash") + * @return the stored hash, or null if not found + */ + private String getStoredHash(Connection conn, String sessionId, String hashKey) + throws SQLException { + String selectSql = + "SELECT state_data FROM " + + getFullTableName() + + " WHERE session_id = ? AND state_key = ? AND item_index = ?"; + + try (PreparedStatement stmt = conn.prepareStatement(selectSql)) { + stmt.setString(1, sessionId); + stmt.setString(2, hashKey); + stmt.setInt(3, SINGLE_STATE_INDEX); + + try (ResultSet rs = stmt.executeQuery()) { + if (rs.next()) { + return rs.getString("state_data"); + } + return null; + } + } + } + + /** + * Save hash value for a list. + * + * @param conn database connection + * @param sessionId session identifier + * @param hashKey the hash key + * @param hash the hash value to save + */ + private void saveHash(Connection conn, String sessionId, String hashKey, String hash) + throws SQLException { + String upsertSql = + "INSERT INTO " + + getFullTableName() + + " (session_id, state_key, item_index, state_data)" + + " VALUES (?, ?, ?, ?)" + + " ON DUPLICATE KEY UPDATE state_data = VALUES(state_data)"; + + try (PreparedStatement stmt = conn.prepareStatement(upsertSql)) { + stmt.setString(1, sessionId); + stmt.setString(2, hashKey); + stmt.setInt(3, SINGLE_STATE_INDEX); + stmt.setString(4, hash); + stmt.executeUpdate(); + } + } + + /** + * Delete all items for a list state. + * + * @param conn database connection + * @param sessionId session identifier + * @param key the state key + */ + private void deleteListItems(Connection conn, String sessionId, String key) + throws SQLException { + String deleteSql = + "DELETE FROM " + getFullTableName() + " WHERE session_id = ? AND state_key = ?"; + + try (PreparedStatement stmt = conn.prepareStatement(deleteSql)) { + stmt.setString(1, sessionId); + stmt.setString(2, key); + stmt.executeUpdate(); + } + } + + /** + * Insert all items for a list state. + * + * @param conn database connection + * @param sessionId session identifier + * @param key the state key + * @param values the values to insert + */ + private void insertAllItems( + Connection conn, String sessionId, String key, List values) + throws Exception { + insertItems(conn, sessionId, key, values, 0); + } + + /** + * Insert items for a list state starting at a given index. + * + * @param conn database connection + * @param sessionId session identifier + * @param key the state key + * @param items the items to insert + * @param startIndex the starting index for item_index + */ + private void insertItems( + Connection conn, + String sessionId, + String key, + List items, + int startIndex) + throws Exception { + String insertSql = + "INSERT INTO " + + getFullTableName() + + " (session_id, state_key, item_index, state_data)" + + " VALUES (?, ?, ?, ?)"; + + try (PreparedStatement stmt = conn.prepareStatement(insertSql)) { + int index = startIndex; + for (State item : items) { + String json = JsonUtils.getJsonCodec().toJson(item); + stmt.setString(1, sessionId); + stmt.setString(2, key); + stmt.setInt(3, index); + stmt.setString(4, json); + stmt.addBatch(); + index++; + } + stmt.executeBatch(); + } + } + + /** + * Get the count of items in a list state (max index + 1). + */ + private int getListCount(Connection conn, String sessionId, String key) throws SQLException { + String selectSql = + "SELECT MAX(item_index) as max_index FROM " + + getFullTableName() + + " WHERE session_id = ? AND state_key = ?"; + + try (PreparedStatement stmt = conn.prepareStatement(selectSql)) { + stmt.setString(1, sessionId); + stmt.setString(2, key); + + try (ResultSet rs = stmt.executeQuery()) { + if (rs.next()) { + int maxIndex = rs.getInt("max_index"); + if (rs.wasNull()) { + return 0; + } + return maxIndex + 1; + } + return 0; + } + } + } + + @Override + public Optional get(SessionKey sessionKey, String key, Class type) { + String sessionId = sessionKey.toIdentifier(); + validateSessionId(sessionId); + validateStateKey(key); + + String selectSql = + "SELECT state_data FROM " + + getFullTableName() + + " WHERE session_id = ? AND state_key = ? AND item_index = ?"; + + try (Connection conn = dataSource.getConnection(); + PreparedStatement stmt = conn.prepareStatement(selectSql)) { + + stmt.setString(1, sessionId); + stmt.setString(2, key); + stmt.setInt(3, SINGLE_STATE_INDEX); + + try (ResultSet rs = stmt.executeQuery()) { + if (rs.next()) { + String json = rs.getString("state_data"); + return Optional.of(JsonUtils.getJsonCodec().fromJson(json, type)); + } + return Optional.empty(); + } + + } catch (Exception e) { + throw new RuntimeException("Failed to get state: " + key, e); + } + } + + @Override + public List getList(SessionKey sessionKey, String key, Class itemType) { + String sessionId = sessionKey.toIdentifier(); + validateSessionId(sessionId); + validateStateKey(key); + + String selectSql = + "SELECT state_data FROM " + + getFullTableName() + + " WHERE session_id = ? AND state_key = ?" + + " ORDER BY item_index"; + + try (Connection conn = dataSource.getConnection(); + PreparedStatement stmt = conn.prepareStatement(selectSql)) { + + stmt.setString(1, sessionId); + stmt.setString(2, key); + + try (ResultSet rs = stmt.executeQuery()) { + List result = new ArrayList<>(); + while (rs.next()) { + String json = rs.getString("state_data"); + result.add(JsonUtils.getJsonCodec().fromJson(json, itemType)); + } + return result; + } + + } catch (Exception e) { + throw new RuntimeException("Failed to get list: " + key, e); + } + } + + @Override + public boolean exists(SessionKey sessionKey) { + String sessionId = sessionKey.toIdentifier(); + validateSessionId(sessionId); + + String existsSql = "SELECT 1 FROM " + getFullTableName() + " WHERE session_id = ? LIMIT 1"; + + try (Connection conn = dataSource.getConnection(); + PreparedStatement stmt = conn.prepareStatement(existsSql)) { + + stmt.setString(1, sessionId); + try (ResultSet rs = stmt.executeQuery()) { + return rs.next(); + } + + } catch (SQLException e) { + throw new RuntimeException("Failed to check session existence: " + sessionId, e); + } + } + + @Override + public void delete(SessionKey sessionKey) { + String sessionId = sessionKey.toIdentifier(); + validateSessionId(sessionId); + + String deleteSql = "DELETE FROM " + getFullTableName() + " WHERE session_id = ?"; + + try (Connection conn = dataSource.getConnection(); + PreparedStatement stmt = conn.prepareStatement(deleteSql)) { + + executeWriteTransaction( + conn, + () -> { + stmt.setString(1, sessionId); + stmt.executeUpdate(); + return null; + }); + + } catch (Exception e) { + throw new RuntimeException("Failed to delete session: " + sessionId, e); + } + } + + @Override + public Set listSessionKeys() { + String listSql = + "SELECT DISTINCT session_id FROM " + getFullTableName() + " ORDER BY session_id"; + + try (Connection conn = dataSource.getConnection(); + PreparedStatement stmt = conn.prepareStatement(listSql); + ResultSet rs = stmt.executeQuery()) { + + Set sessionKeys = new HashSet<>(); + while (rs.next()) { + sessionKeys.add(SimpleSessionKey.of(rs.getString("session_id"))); + } + return sessionKeys; + + } catch (SQLException e) { + throw new RuntimeException("Failed to list sessions", e); + } + } + + /** + * Close the session and release any resources. + * + *

Note: This implementation does not close the DataSource as it may be shared across + * multiple sessions. The caller is responsible for managing the DataSource lifecycle. + */ + @Override + public void close() { + // DataSource is managed externally, so we don't close it here + } + + /** + * Get the database name used for storing sessions. + * + * @return The database name + */ + public String getDatabaseName() { + return databaseName; + } + + /** + * Get the table name used for storing sessions. + * + * @return The table name + */ + public String getTableName() { + return tableName; + } + + /** + * Get the DataSource used for database connections. + * + * @return The DataSource instance + */ + public DataSource getDataSource() { + return dataSource; + } + + /** + * Clear all sessions from the database (for testing or cleanup). + * + * @return Number of rows deleted + * @deprecated Use {@link #truncateAllSessions()} instead + */ + @Deprecated + public int clearAllSessions() { + String clearSql = "DELETE FROM " + getFullTableName(); + + try (Connection conn = dataSource.getConnection(); + PreparedStatement stmt = conn.prepareStatement(clearSql)) { + + return executeWriteTransaction(conn, stmt::executeUpdate); + + } catch (Exception e) { + throw new RuntimeException("Failed to clear sessions", e); + } + } + + /** + * Truncate session table from the database (for testing or cleanup). + *

+ * This method clears all session records by executing a TRUNCATE TABLE statement on the + * sessions table. TRUNCATE is faster than DELETE as it resets the table without logging + * individual row deletions and reclaims storage space immediately. + * + *

+ * Note: The TRUNCATE operation requires DROP privileges in MySQL. + * + * @return typically 0 if successful + */ + public int truncateAllSessions() { + String clearSql = "TRUNCATE TABLE " + getFullTableName(); + + try (Connection conn = dataSource.getConnection(); + PreparedStatement stmt = conn.prepareStatement(clearSql)) { + + return executeWriteTransaction(conn, stmt::executeUpdate); + + } catch (Exception e) { + throw new RuntimeException("Failed to truncate sessions", e); + } + } + + private T executeWriteTransaction(Connection conn, SqlOperation operation) + throws Exception { + boolean originalAutoCommit = conn.getAutoCommit(); + if (originalAutoCommit) { + conn.setAutoCommit(false); + } + + try { + T result = operation.execute(); + conn.commit(); + return result; + } catch (Exception e) { + try { + conn.rollback(); + } catch (SQLException rollbackException) { + e.addSuppressed(rollbackException); + } + throw e; + } finally { + if (originalAutoCommit) { + conn.setAutoCommit(true); + } + } + } + + @FunctionalInterface + private interface SqlOperation { + T execute() throws Exception; + } + + /** + * Validate a session ID format. + * + * @param sessionId Session ID to validate + * @throws IllegalArgumentException if session ID is invalid + */ + protected void validateSessionId(String sessionId) { + if (sessionId == null || sessionId.trim().isEmpty()) { + throw new IllegalArgumentException("Session ID cannot be null or empty"); + } + if (sessionId.contains("/") || sessionId.contains("\\")) { + throw new IllegalArgumentException("Session ID cannot contain path separators"); + } + if (sessionId.length() > 255) { + throw new IllegalArgumentException("Session ID cannot exceed 255 characters"); + } + } + + /** + * Validate a state key format. + * + * @param key State key to validate + * @throws IllegalArgumentException if state key is invalid + */ + private void validateStateKey(String key) { + if (key == null || key.trim().isEmpty()) { + throw new IllegalArgumentException("State key cannot be null or empty"); + } + if (key.length() > 255) { + throw new IllegalArgumentException("State key cannot exceed 255 characters"); + } + } + + /** + * Validate a database or table identifier to prevent SQL injection. + * + *

This method ensures that identifiers only contain safe characters (alphanumeric, + * underscores, and hyphens) and start with a letter or underscore. This is critical for + * security since database and table names cannot be parameterized in prepared statements. + * + * @param identifier The identifier to validate (database name or table name) + * @param identifierType Description of the identifier type for error messages + * @throws IllegalArgumentException if the identifier is invalid or contains unsafe characters + */ + private void validateIdentifier(String identifier, String identifierType) { + if (identifier == null || identifier.isEmpty()) { + throw new IllegalArgumentException(identifierType + " cannot be null or empty"); + } + if (identifier.length() > MAX_IDENTIFIER_LENGTH) { + throw new IllegalArgumentException( + identifierType + " cannot exceed " + MAX_IDENTIFIER_LENGTH + " characters"); + } + if (!IDENTIFIER_PATTERN.matcher(identifier).matches()) { + throw new IllegalArgumentException( + identifierType + + " contains invalid characters. Only alphanumeric characters," + + " underscores, and hyphens are allowed, and it must start with a" + + " letter or underscore. Invalid value: " + + identifier); + } + } +} diff --git a/agentscope-extensions/agentscope-extensions-session-mysql/src/test/java/io/agentscope/core/session/mysql/MysqlSessionTest.java b/agentscope-extensions/agentscope-extensions-session-mysql/src/test/java/io/agentscope/core/session/mysql/MysqlSessionTest.java index 9f1766b72..90eb1fd5b 100644 --- a/agentscope-extensions/agentscope-extensions-session-mysql/src/test/java/io/agentscope/core/session/mysql/MysqlSessionTest.java +++ b/agentscope-extensions/agentscope-extensions-session-mysql/src/test/java/io/agentscope/core/session/mysql/MysqlSessionTest.java @@ -1,551 +1,628 @@ -/* - * Copyright 2024-2026 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package io.agentscope.core.session.mysql; - -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertFalse; -import static org.junit.jupiter.api.Assertions.assertThrows; -import static org.junit.jupiter.api.Assertions.assertTrue; -import static org.mockito.ArgumentMatchers.anyString; -import static org.mockito.Mockito.atLeast; -import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; - -import io.agentscope.core.state.SessionKey; -import io.agentscope.core.state.SimpleSessionKey; -import io.agentscope.core.state.State; -import java.sql.Connection; -import java.sql.PreparedStatement; -import java.sql.ResultSet; -import java.sql.SQLException; -import java.util.List; -import java.util.Optional; -import java.util.Set; -import javax.sql.DataSource; -import org.junit.jupiter.api.AfterEach; -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.DisplayName; -import org.junit.jupiter.api.Test; -import org.mockito.Mock; -import org.mockito.MockitoAnnotations; - -/** - * Unit tests for MysqlSession. - * - *

These tests use mocked DataSource and Connection to verify the behavior of MysqlSession - * without requiring an actual MySQL database. - */ -@DisplayName("MysqlSession Tests") -public class MysqlSessionTest { - - @Mock private DataSource mockDataSource; - - @Mock private Connection mockConnection; - - @Mock private PreparedStatement mockStatement; - - @Mock private ResultSet mockResultSet; - - private AutoCloseable mockitoCloseable; - - @BeforeEach - void setUp() throws SQLException { - mockitoCloseable = MockitoAnnotations.openMocks(this); - when(mockDataSource.getConnection()).thenReturn(mockConnection); - when(mockConnection.prepareStatement(anyString())).thenReturn(mockStatement); - } - - @AfterEach - void tearDown() throws Exception { - if (mockitoCloseable != null) { - mockitoCloseable.close(); - } - } - - @Test - @DisplayName("Should throw exception when DataSource is null") - void testConstructorWithNullDataSource() { - assertThrows( - IllegalArgumentException.class, - () -> new MysqlSession(null), - "DataSource cannot be null"); - } - - @Test - @DisplayName("Should throw exception when DataSource is null with createIfNotExist flag") - void testConstructorWithNullDataSourceAndCreateIfNotExist() { - assertThrows( - IllegalArgumentException.class, - () -> new MysqlSession(null, true), - "DataSource cannot be null"); - } - - @Test - @DisplayName("Should create session with createIfNotExist=true") - void testConstructorWithCreateIfNotExistTrue() throws SQLException { - when(mockStatement.execute()).thenReturn(true); - - MysqlSession session = new MysqlSession(mockDataSource, true); - - assertEquals("agentscope", session.getDatabaseName()); - assertEquals("agentscope_sessions", session.getTableName()); - assertEquals(mockDataSource, session.getDataSource()); - } - - @Test - @DisplayName("Should throw exception when database does not exist and createIfNotExist=false") - void testConstructorWithCreateIfNotExistFalseAndDatabaseNotExist() throws SQLException { - when(mockStatement.executeQuery()).thenReturn(mockResultSet); - when(mockResultSet.next()).thenReturn(false); - - assertThrows( - IllegalStateException.class, - () -> new MysqlSession(mockDataSource, false), - "Database does not exist"); - } - - @Test - @DisplayName("Should throw exception when table does not exist and createIfNotExist=false") - void testConstructorWithCreateIfNotExistFalseAndTableNotExist() throws SQLException { - when(mockStatement.executeQuery()).thenReturn(mockResultSet); - when(mockResultSet.next()).thenReturn(true, false); - - assertThrows( - IllegalStateException.class, - () -> new MysqlSession(mockDataSource, false), - "Table does not exist"); - } - - @Test - @DisplayName("Should create session when both database and table exist") - void testConstructorWithCreateIfNotExistFalseAndBothExist() throws SQLException { - when(mockStatement.executeQuery()).thenReturn(mockResultSet); - when(mockResultSet.next()).thenReturn(true, true); - - MysqlSession session = new MysqlSession(mockDataSource, false); - - assertEquals("agentscope", session.getDatabaseName()); - assertEquals("agentscope_sessions", session.getTableName()); - } - - @Test - @DisplayName("Should create session with custom database and table name") - void testConstructorWithCustomDatabaseAndTableName() throws SQLException { - when(mockStatement.execute()).thenReturn(true); - - MysqlSession session = new MysqlSession(mockDataSource, "custom_db", "custom_table", true); - - assertEquals("custom_db", session.getDatabaseName()); - assertEquals("custom_table", session.getTableName()); - } - - @Test - @DisplayName("Should use default database name when null is provided") - void testConstructorWithNullDatabaseNameUsesDefault() throws SQLException { - when(mockStatement.execute()).thenReturn(true); - - MysqlSession session = new MysqlSession(mockDataSource, null, "custom_table", true); - - assertEquals("agentscope", session.getDatabaseName()); - assertEquals("custom_table", session.getTableName()); - } - - @Test - @DisplayName("Should use default database name when empty string is provided") - void testConstructorWithEmptyDatabaseNameUsesDefault() throws SQLException { - when(mockStatement.execute()).thenReturn(true); - - MysqlSession session = new MysqlSession(mockDataSource, " ", "custom_table", true); - - assertEquals("agentscope", session.getDatabaseName()); - assertEquals("custom_table", session.getTableName()); - } - - @Test - @DisplayName("Should use default table name when null is provided") - void testConstructorWithNullTableNameUsesDefault() throws SQLException { - when(mockStatement.execute()).thenReturn(true); - - MysqlSession session = new MysqlSession(mockDataSource, "custom_db", null, true); - - assertEquals("custom_db", session.getDatabaseName()); - assertEquals("agentscope_sessions", session.getTableName()); - } - - @Test - @DisplayName("Should use default table name when empty string is provided") - void testConstructorWithEmptyTableNameUsesDefault() throws SQLException { - when(mockStatement.execute()).thenReturn(true); - - MysqlSession session = new MysqlSession(mockDataSource, "custom_db", "", true); - - assertEquals("custom_db", session.getDatabaseName()); - assertEquals("agentscope_sessions", session.getTableName()); - } - - @Test - @DisplayName("Should get DataSource correctly") - void testGetDataSource() throws SQLException { - when(mockStatement.execute()).thenReturn(true); - - MysqlSession session = new MysqlSession(mockDataSource, true); - assertEquals(mockDataSource, session.getDataSource()); - } - - @Test - @DisplayName("Should save and get single state correctly") - void testSaveAndGetSingleState() throws SQLException { - when(mockStatement.execute()).thenReturn(true); - when(mockStatement.executeUpdate()).thenReturn(1); - when(mockStatement.executeQuery()).thenReturn(mockResultSet); - when(mockResultSet.next()).thenReturn(true); - when(mockResultSet.getString("state_data")) - .thenReturn("{\"value\":\"test_value\",\"count\":42}"); - - MysqlSession session = new MysqlSession(mockDataSource, true); - SessionKey sessionKey = SimpleSessionKey.of("session1"); - TestState state = new TestState("test_value", 42); - - // Save state - session.save(sessionKey, "testModule", state); - - // Verify save operations - verify(mockStatement, atLeast(1)).executeUpdate(); - - // Get state - Optional loaded = session.get(sessionKey, "testModule", TestState.class); - assertTrue(loaded.isPresent()); - assertEquals("test_value", loaded.get().value()); - assertEquals(42, loaded.get().count()); - } - - @Test - @DisplayName("Should save and get list state correctly") - void testSaveAndGetListState() throws SQLException { - when(mockStatement.execute()).thenReturn(true); - when(mockStatement.executeUpdate()).thenReturn(1); - when(mockStatement.executeQuery()).thenReturn(mockResultSet); - - // Mock sequence for save(): - // 1. getStoredHash() query - no hash found (next=false) - // 2. getListCount() query - no items (next=true, wasNull=true) - // Then for getList(): - // 3. getList() query - 2 rows (next=true, true, false) - when(mockResultSet.next()).thenReturn(false, true, true, true, false); - when(mockResultSet.getInt("max_index")).thenReturn(0); - when(mockResultSet.wasNull()).thenReturn(true); - when(mockResultSet.getString("state_data")) - .thenReturn("{\"value\":\"value1\",\"count\":1}") - .thenReturn("{\"value\":\"value2\",\"count\":2}"); - - MysqlSession session = new MysqlSession(mockDataSource, true); - SessionKey sessionKey = SimpleSessionKey.of("session1"); - List states = List.of(new TestState("value1", 1), new TestState("value2", 2)); - - // Save list state - session.save(sessionKey, "testList", states); - - // Get list state - List loaded = session.getList(sessionKey, "testList", TestState.class); - assertEquals(2, loaded.size()); - assertEquals("value1", loaded.get(0).value()); - assertEquals("value2", loaded.get(1).value()); - } - - @Test - @DisplayName("Should return empty for non-existent state") - void testGetNonExistentState() throws SQLException { - when(mockStatement.execute()).thenReturn(true); - when(mockStatement.executeQuery()).thenReturn(mockResultSet); - when(mockResultSet.next()).thenReturn(false); - - MysqlSession session = new MysqlSession(mockDataSource, true); - SessionKey sessionKey = SimpleSessionKey.of("non_existent"); - - Optional state = session.get(sessionKey, "testModule", TestState.class); - assertFalse(state.isPresent()); - } - - @Test - @DisplayName("Should return empty list for non-existent list state") - void testGetNonExistentListState() throws SQLException { - when(mockStatement.execute()).thenReturn(true); - when(mockStatement.executeQuery()).thenReturn(mockResultSet); - when(mockResultSet.next()).thenReturn(false); - - MysqlSession session = new MysqlSession(mockDataSource, true); - SessionKey sessionKey = SimpleSessionKey.of("non_existent"); - - List states = session.getList(sessionKey, "testList", TestState.class); - assertTrue(states.isEmpty()); - } - - @Test - @DisplayName("Should return true when session exists") - void testSessionExists() throws SQLException { - when(mockStatement.execute()).thenReturn(true); - when(mockStatement.executeQuery()).thenReturn(mockResultSet); - when(mockResultSet.next()).thenReturn(true); - - MysqlSession session = new MysqlSession(mockDataSource, true); - SessionKey sessionKey = SimpleSessionKey.of("session1"); - - assertTrue(session.exists(sessionKey)); - } - - @Test - @DisplayName("Should return false when session does not exist") - void testSessionDoesNotExist() throws SQLException { - when(mockStatement.execute()).thenReturn(true); - when(mockStatement.executeQuery()).thenReturn(mockResultSet); - when(mockResultSet.next()).thenReturn(false); - - MysqlSession session = new MysqlSession(mockDataSource, true); - SessionKey sessionKey = SimpleSessionKey.of("non_existent"); - - assertFalse(session.exists(sessionKey)); - } - - @Test - @DisplayName("Should delete session correctly") - void testDeleteSession() throws SQLException { - when(mockStatement.execute()).thenReturn(true); - when(mockStatement.executeUpdate()).thenReturn(1); - - MysqlSession session = new MysqlSession(mockDataSource, true); - SessionKey sessionKey = SimpleSessionKey.of("session1"); - - session.delete(sessionKey); - - verify(mockStatement).setString(1, "session1"); - verify(mockStatement).executeUpdate(); - } - - @Test - @DisplayName("Should list all session keys when empty") - void testListSessionKeysEmpty() throws SQLException { - when(mockStatement.execute()).thenReturn(true); - when(mockStatement.executeQuery()).thenReturn(mockResultSet); - when(mockResultSet.next()).thenReturn(false); - - MysqlSession session = new MysqlSession(mockDataSource, true); - Set sessionKeys = session.listSessionKeys(); - - assertTrue(sessionKeys.isEmpty()); - } - - @Test - @DisplayName("Should list all session keys") - void testListSessionKeysWithResults() throws SQLException { - when(mockStatement.execute()).thenReturn(true); - when(mockStatement.executeQuery()).thenReturn(mockResultSet); - when(mockResultSet.next()).thenReturn(true, true, false); - when(mockResultSet.getString("session_id")).thenReturn("session1", "session2"); - - MysqlSession session = new MysqlSession(mockDataSource, true); - Set sessionKeys = session.listSessionKeys(); - - assertEquals(2, sessionKeys.size()); - assertTrue(sessionKeys.contains(SimpleSessionKey.of("session1"))); - assertTrue(sessionKeys.contains(SimpleSessionKey.of("session2"))); - } - - @Test - @DisplayName("Should clear all sessions") - void testClearAllSessions() throws SQLException { - when(mockStatement.execute()).thenReturn(true); - when(mockStatement.executeUpdate()).thenReturn(5); - - MysqlSession session = new MysqlSession(mockDataSource, true); - int deleted = session.clearAllSessions(); - - assertEquals(5, deleted); - } - - @Test - @DisplayName("Should not close DataSource when closing session") - void testClose() throws SQLException { - when(mockStatement.execute()).thenReturn(true); - - MysqlSession session = new MysqlSession(mockDataSource, true); - session.close(); - assertEquals(mockDataSource, session.getDataSource()); - } - - // ==================== SQL Injection Prevention Tests ==================== - - @Test - @DisplayName("Should reject database name with semicolon (SQL injection)") - void testConstructorRejectsDatabaseNameWithSemicolon() { - assertThrows( - IllegalArgumentException.class, - () -> - new MysqlSession( - mockDataSource, "db; DROP DATABASE mysql; --", "table", true), - "Database name contains invalid characters"); - } - - @Test - @DisplayName("Should reject table name with semicolon (SQL injection)") - void testConstructorRejectsTableNameWithSemicolon() { - assertThrows( - IllegalArgumentException.class, - () -> - new MysqlSession( - mockDataSource, "valid_db", "table; DROP TABLE users; --", true), - "Table name contains invalid characters"); - } - - @Test - @DisplayName("Should reject database name with space") - void testConstructorRejectsDatabaseNameWithSpace() { - assertThrows( - IllegalArgumentException.class, - () -> new MysqlSession(mockDataSource, "db name", "table", true), - "Database name contains invalid characters"); - } - - @Test - @DisplayName("Should reject table name with space") - void testConstructorRejectsTableNameWithSpace() { - assertThrows( - IllegalArgumentException.class, - () -> new MysqlSession(mockDataSource, "valid_db", "table name", true), - "Table name contains invalid characters"); - } - - @Test - @DisplayName("Should reject database name starting with number") - void testConstructorRejectsDatabaseNameStartingWithNumber() { - assertThrows( - IllegalArgumentException.class, - () -> new MysqlSession(mockDataSource, "123db", "table", true), - "Database name contains invalid characters"); - } - - @Test - @DisplayName("Should reject table name starting with number") - void testConstructorRejectsTableNameStartingWithNumber() { - assertThrows( - IllegalArgumentException.class, - () -> new MysqlSession(mockDataSource, "valid_db", "123table", true), - "Table name contains invalid characters"); - } - - @Test - @DisplayName("Should reject database name exceeding max length") - void testConstructorRejectsDatabaseNameExceedingMaxLength() { - String longName = "a".repeat(65); - assertThrows( - IllegalArgumentException.class, - () -> new MysqlSession(mockDataSource, longName, "table", true), - "Database name cannot exceed 64 characters"); - } - - @Test - @DisplayName("Should reject table name exceeding max length") - void testConstructorRejectsTableNameExceedingMaxLength() { - String longName = "a".repeat(65); - assertThrows( - IllegalArgumentException.class, - () -> new MysqlSession(mockDataSource, "valid_db", longName, true), - "Table name cannot exceed 64 characters"); - } - - @Test - @DisplayName("Should accept valid database and table names") - void testConstructorAcceptsValidDatabaseAndTableNames() throws SQLException { - when(mockStatement.execute()).thenReturn(true); - - MysqlSession session = - new MysqlSession(mockDataSource, "my_database_123", "my_table_456", true); - - assertEquals("my_database_123", session.getDatabaseName()); - assertEquals("my_table_456", session.getTableName()); - } - - @Test - @DisplayName("Should accept names starting with underscore") - void testConstructorAcceptsNameStartingWithUnderscore() throws SQLException { - when(mockStatement.execute()).thenReturn(true); - - MysqlSession session = - new MysqlSession(mockDataSource, "_private_db", "_private_table", true); - - assertEquals("_private_db", session.getDatabaseName()); - assertEquals("_private_table", session.getTableName()); - } - - @Test - @DisplayName("Should accept max length names") - void testConstructorAcceptsMaxLengthNames() throws SQLException { - when(mockStatement.execute()).thenReturn(true); - - String maxLengthName = "a".repeat(64); - MysqlSession session = new MysqlSession(mockDataSource, maxLengthName, maxLengthName, true); - - assertEquals(maxLengthName, session.getDatabaseName()); - assertEquals(maxLengthName, session.getTableName()); - } - - @Test - @DisplayName("Should accept database name with hyphens") - void testConstructorAcceptsDatabaseNameWithHyphens() throws SQLException { - when(mockStatement.execute()).thenReturn(true); - - MysqlSession session = new MysqlSession(mockDataSource, "my-test-db", "my_table", true); - - assertEquals("my-test-db", session.getDatabaseName()); - assertEquals("my_table", session.getTableName()); - } - - @Test - @DisplayName("Should accept table name with hyphens") - void testConstructorAcceptsTableNameWithHyphens() throws SQLException { - when(mockStatement.execute()).thenReturn(true); - - MysqlSession session = new MysqlSession(mockDataSource, "my_db", "my-test-table", true); - - assertEquals("my_db", session.getDatabaseName()); - assertEquals("my-test-table", session.getTableName()); - } - - @Test - @DisplayName("Should accept database and table names with hyphens") - void testConstructorAcceptsDatabaseAndTableNamesWithHyphens() throws SQLException { - when(mockStatement.execute()).thenReturn(true); - - MysqlSession session = new MysqlSession(mockDataSource, "xxx-xxx-xx", "test-table", true); - - assertEquals("xxx-xxx-xx", session.getDatabaseName()); - assertEquals("test-table", session.getTableName()); - } - - @Test - @DisplayName("Should accept name with underscore and hyphen") - void testConstructorAcceptsNameWithUnderscoreAndHyphen() throws SQLException { - when(mockStatement.execute()).thenReturn(true); - - MysqlSession session = - new MysqlSession(mockDataSource, "my_test-db", "my_table-test", true); - - assertEquals("my_test-db", session.getDatabaseName()); - assertEquals("my_table-test", session.getTableName()); - } - - /** Simple test state record for testing. */ - public record TestState(String value, int count) implements State {} -} +/* + * Copyright 2024-2026 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.agentscope.core.session.mysql; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.Mockito.atLeast; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import io.agentscope.core.state.SessionKey; +import io.agentscope.core.state.SimpleSessionKey; +import io.agentscope.core.state.State; +import java.sql.Connection; +import java.sql.PreparedStatement; +import java.sql.ResultSet; +import java.sql.SQLException; +import java.util.List; +import java.util.Optional; +import java.util.Set; +import javax.sql.DataSource; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.Test; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; + +/** + * Unit tests for MysqlSession. + * + *

These tests use mocked DataSource and Connection to verify the behavior of MysqlSession + * without requiring an actual MySQL database. + */ +@DisplayName("MysqlSession Tests") +public class MysqlSessionTest { + + @Mock private DataSource mockDataSource; + + @Mock private Connection mockConnection; + + @Mock private PreparedStatement mockStatement; + + @Mock private ResultSet mockResultSet; + + private AutoCloseable mockitoCloseable; + + @BeforeEach + void setUp() throws SQLException { + mockitoCloseable = MockitoAnnotations.openMocks(this); + when(mockDataSource.getConnection()).thenReturn(mockConnection); + when(mockConnection.prepareStatement(anyString())).thenReturn(mockStatement); + } + + @AfterEach + void tearDown() throws Exception { + if (mockitoCloseable != null) { + mockitoCloseable.close(); + } + } + + @Test + @DisplayName("Should throw exception when DataSource is null") + void testConstructorWithNullDataSource() { + assertThrows( + IllegalArgumentException.class, + () -> new MysqlSession(null), + "DataSource cannot be null"); + } + + @Test + @DisplayName("Should throw exception when DataSource is null with createIfNotExist flag") + void testConstructorWithNullDataSourceAndCreateIfNotExist() { + assertThrows( + IllegalArgumentException.class, + () -> new MysqlSession(null, true), + "DataSource cannot be null"); + } + + @Test + @DisplayName("Should create session with createIfNotExist=true") + void testConstructorWithCreateIfNotExistTrue() throws SQLException { + when(mockStatement.execute()).thenReturn(true); + + MysqlSession session = new MysqlSession(mockDataSource, true); + + assertEquals("agentscope", session.getDatabaseName()); + assertEquals("agentscope_sessions", session.getTableName()); + assertEquals(mockDataSource, session.getDataSource()); + } + + @Test + @DisplayName("Should throw exception when database does not exist and createIfNotExist=false") + void testConstructorWithCreateIfNotExistFalseAndDatabaseNotExist() throws SQLException { + when(mockStatement.executeQuery()).thenReturn(mockResultSet); + when(mockResultSet.next()).thenReturn(false); + + assertThrows( + IllegalStateException.class, + () -> new MysqlSession(mockDataSource, false), + "Database does not exist"); + } + + @Test + @DisplayName("Should throw exception when table does not exist and createIfNotExist=false") + void testConstructorWithCreateIfNotExistFalseAndTableNotExist() throws SQLException { + when(mockStatement.executeQuery()).thenReturn(mockResultSet); + when(mockResultSet.next()).thenReturn(true, false); + + assertThrows( + IllegalStateException.class, + () -> new MysqlSession(mockDataSource, false), + "Table does not exist"); + } + + @Test + @DisplayName("Should create session when both database and table exist") + void testConstructorWithCreateIfNotExistFalseAndBothExist() throws SQLException { + when(mockStatement.executeQuery()).thenReturn(mockResultSet); + when(mockResultSet.next()).thenReturn(true, true); + + MysqlSession session = new MysqlSession(mockDataSource, false); + + assertEquals("agentscope", session.getDatabaseName()); + assertEquals("agentscope_sessions", session.getTableName()); + } + + @Test + @DisplayName("Should create session with custom database and table name") + void testConstructorWithCustomDatabaseAndTableName() throws SQLException { + when(mockStatement.execute()).thenReturn(true); + + MysqlSession session = new MysqlSession(mockDataSource, "custom_db", "custom_table", true); + + assertEquals("custom_db", session.getDatabaseName()); + assertEquals("custom_table", session.getTableName()); + } + + @Test + @DisplayName("Should use default database name when null is provided") + void testConstructorWithNullDatabaseNameUsesDefault() throws SQLException { + when(mockStatement.execute()).thenReturn(true); + + MysqlSession session = new MysqlSession(mockDataSource, null, "custom_table", true); + + assertEquals("agentscope", session.getDatabaseName()); + assertEquals("custom_table", session.getTableName()); + } + + @Test + @DisplayName("Should use default database name when empty string is provided") + void testConstructorWithEmptyDatabaseNameUsesDefault() throws SQLException { + when(mockStatement.execute()).thenReturn(true); + + MysqlSession session = new MysqlSession(mockDataSource, " ", "custom_table", true); + + assertEquals("agentscope", session.getDatabaseName()); + assertEquals("custom_table", session.getTableName()); + } + + @Test + @DisplayName("Should use default table name when null is provided") + void testConstructorWithNullTableNameUsesDefault() throws SQLException { + when(mockStatement.execute()).thenReturn(true); + + MysqlSession session = new MysqlSession(mockDataSource, "custom_db", null, true); + + assertEquals("custom_db", session.getDatabaseName()); + assertEquals("agentscope_sessions", session.getTableName()); + } + + @Test + @DisplayName("Should use default table name when empty string is provided") + void testConstructorWithEmptyTableNameUsesDefault() throws SQLException { + when(mockStatement.execute()).thenReturn(true); + + MysqlSession session = new MysqlSession(mockDataSource, "custom_db", "", true); + + assertEquals("custom_db", session.getDatabaseName()); + assertEquals("agentscope_sessions", session.getTableName()); + } + + @Test + @DisplayName("Should get DataSource correctly") + void testGetDataSource() throws SQLException { + when(mockStatement.execute()).thenReturn(true); + + MysqlSession session = new MysqlSession(mockDataSource, true); + assertEquals(mockDataSource, session.getDataSource()); + } + + @Test + @DisplayName("Should save and get single state correctly") + void testSaveAndGetSingleState() throws SQLException { + when(mockStatement.execute()).thenReturn(true); + when(mockStatement.executeUpdate()).thenReturn(1); + when(mockStatement.executeQuery()).thenReturn(mockResultSet); + when(mockResultSet.next()).thenReturn(true); + when(mockResultSet.getString("state_data")) + .thenReturn("{\"value\":\"test_value\",\"count\":42}"); + + MysqlSession session = new MysqlSession(mockDataSource, true); + SessionKey sessionKey = SimpleSessionKey.of("session1"); + TestState state = new TestState("test_value", 42); + + // Save state + session.save(sessionKey, "testModule", state); + + // Verify save operations + verify(mockStatement, atLeast(1)).executeUpdate(); + + // Get state + Optional loaded = session.get(sessionKey, "testModule", TestState.class); + assertTrue(loaded.isPresent()); + assertEquals("test_value", loaded.get().value()); + assertEquals(42, loaded.get().count()); + } + + @Test + @DisplayName("Should save and get list state correctly") + void testSaveAndGetListState() throws SQLException { + when(mockStatement.execute()).thenReturn(true); + when(mockStatement.executeUpdate()).thenReturn(1); + when(mockStatement.executeQuery()).thenReturn(mockResultSet); + + // Mock sequence for save(): + // 1. getStoredHash() query - no hash found (next=false) + // 2. getListCount() query - no items (next=true, wasNull=true) + // Then for getList(): + // 3. getList() query - 2 rows (next=true, true, false) + when(mockResultSet.next()).thenReturn(false, true, true, true, false); + when(mockResultSet.getInt("max_index")).thenReturn(0); + when(mockResultSet.wasNull()).thenReturn(true); + when(mockResultSet.getString("state_data")) + .thenReturn("{\"value\":\"value1\",\"count\":1}") + .thenReturn("{\"value\":\"value2\",\"count\":2}"); + + MysqlSession session = new MysqlSession(mockDataSource, true); + SessionKey sessionKey = SimpleSessionKey.of("session1"); + List states = List.of(new TestState("value1", 1), new TestState("value2", 2)); + + // Save list state + session.save(sessionKey, "testList", states); + + // Get list state + List loaded = session.getList(sessionKey, "testList", TestState.class); + assertEquals(2, loaded.size()); + assertEquals("value1", loaded.get(0).value()); + assertEquals("value2", loaded.get(1).value()); + } + + @Test + @DisplayName("Should commit single-state writes when connection auto-commit is disabled") + void testSaveSingleStateCommitsWhenAutoCommitDisabled() throws SQLException { + when(mockStatement.execute()).thenReturn(true); + when(mockConnection.getAutoCommit()).thenReturn(false); + when(mockStatement.executeUpdate()).thenReturn(1); + + MysqlSession session = new MysqlSession(mockDataSource, true); + SessionKey sessionKey = SimpleSessionKey.of("session1"); + + session.save(sessionKey, "testModule", new TestState("test_value", 42)); + + verify(mockConnection).getAutoCommit(); + verify(mockConnection).commit(); + verify(mockConnection, never()).setAutoCommit(false); + verify(mockConnection, never()).setAutoCommit(true); + } + + @Test + @DisplayName("Should commit append writes when connection auto-commit is disabled") + void testSaveListAppendCommitsWhenAutoCommitDisabled() throws SQLException { + when(mockStatement.execute()).thenReturn(true); + when(mockConnection.getAutoCommit()).thenReturn(false); + when(mockStatement.executeQuery()).thenReturn(mockResultSet); + when(mockResultSet.next()).thenReturn(false, true); + when(mockResultSet.getInt("max_index")).thenReturn(0); + when(mockResultSet.wasNull()).thenReturn(true); + when(mockStatement.executeUpdate()).thenReturn(1); + + MysqlSession session = new MysqlSession(mockDataSource, true); + SessionKey sessionKey = SimpleSessionKey.of("session1"); + List states = List.of(new TestState("value1", 1), new TestState("value2", 2)); + + session.save(sessionKey, "testList", states); + + verify(mockConnection).getAutoCommit(); + verify(mockConnection).commit(); + verify(mockConnection, never()).setAutoCommit(false); + verify(mockConnection, never()).setAutoCommit(true); + } + + @Test + @DisplayName("Should restore auto-commit after full rewrite transaction") + void testSaveListFullRewriteRestoresOriginalAutoCommit() throws SQLException { + when(mockStatement.execute()).thenReturn(true); + when(mockConnection.getAutoCommit()).thenReturn(true); + when(mockStatement.executeQuery()).thenReturn(mockResultSet); + when(mockResultSet.next()).thenReturn(true, true); + when(mockResultSet.getString("state_data")).thenReturn("stale_hash"); + when(mockResultSet.getInt("max_index")).thenReturn(0); + when(mockResultSet.wasNull()).thenReturn(false); + when(mockStatement.executeUpdate()).thenReturn(1); + + MysqlSession session = new MysqlSession(mockDataSource, true); + SessionKey sessionKey = SimpleSessionKey.of("session1"); + List states = List.of(new TestState("value1", 1), new TestState("value2", 2)); + + session.save(sessionKey, "testList", states); + + verify(mockConnection).setAutoCommit(false); + verify(mockConnection).commit(); + verify(mockConnection).setAutoCommit(true); + } + + @Test + @DisplayName("Should return empty for non-existent state") + void testGetNonExistentState() throws SQLException { + when(mockStatement.execute()).thenReturn(true); + when(mockStatement.executeQuery()).thenReturn(mockResultSet); + when(mockResultSet.next()).thenReturn(false); + + MysqlSession session = new MysqlSession(mockDataSource, true); + SessionKey sessionKey = SimpleSessionKey.of("non_existent"); + + Optional state = session.get(sessionKey, "testModule", TestState.class); + assertFalse(state.isPresent()); + } + + @Test + @DisplayName("Should return empty list for non-existent list state") + void testGetNonExistentListState() throws SQLException { + when(mockStatement.execute()).thenReturn(true); + when(mockStatement.executeQuery()).thenReturn(mockResultSet); + when(mockResultSet.next()).thenReturn(false); + + MysqlSession session = new MysqlSession(mockDataSource, true); + SessionKey sessionKey = SimpleSessionKey.of("non_existent"); + + List states = session.getList(sessionKey, "testList", TestState.class); + assertTrue(states.isEmpty()); + } + + @Test + @DisplayName("Should return true when session exists") + void testSessionExists() throws SQLException { + when(mockStatement.execute()).thenReturn(true); + when(mockStatement.executeQuery()).thenReturn(mockResultSet); + when(mockResultSet.next()).thenReturn(true); + + MysqlSession session = new MysqlSession(mockDataSource, true); + SessionKey sessionKey = SimpleSessionKey.of("session1"); + + assertTrue(session.exists(sessionKey)); + } + + @Test + @DisplayName("Should return false when session does not exist") + void testSessionDoesNotExist() throws SQLException { + when(mockStatement.execute()).thenReturn(true); + when(mockStatement.executeQuery()).thenReturn(mockResultSet); + when(mockResultSet.next()).thenReturn(false); + + MysqlSession session = new MysqlSession(mockDataSource, true); + SessionKey sessionKey = SimpleSessionKey.of("non_existent"); + + assertFalse(session.exists(sessionKey)); + } + + @Test + @DisplayName("Should delete session correctly") + void testDeleteSession() throws SQLException { + when(mockStatement.execute()).thenReturn(true); + when(mockStatement.executeUpdate()).thenReturn(1); + + MysqlSession session = new MysqlSession(mockDataSource, true); + SessionKey sessionKey = SimpleSessionKey.of("session1"); + + session.delete(sessionKey); + + verify(mockStatement).setString(1, "session1"); + verify(mockStatement).executeUpdate(); + } + + @Test + @DisplayName("Should list all session keys when empty") + void testListSessionKeysEmpty() throws SQLException { + when(mockStatement.execute()).thenReturn(true); + when(mockStatement.executeQuery()).thenReturn(mockResultSet); + when(mockResultSet.next()).thenReturn(false); + + MysqlSession session = new MysqlSession(mockDataSource, true); + Set sessionKeys = session.listSessionKeys(); + + assertTrue(sessionKeys.isEmpty()); + } + + @Test + @DisplayName("Should list all session keys") + void testListSessionKeysWithResults() throws SQLException { + when(mockStatement.execute()).thenReturn(true); + when(mockStatement.executeQuery()).thenReturn(mockResultSet); + when(mockResultSet.next()).thenReturn(true, true, false); + when(mockResultSet.getString("session_id")).thenReturn("session1", "session2"); + + MysqlSession session = new MysqlSession(mockDataSource, true); + Set sessionKeys = session.listSessionKeys(); + + assertEquals(2, sessionKeys.size()); + assertTrue(sessionKeys.contains(SimpleSessionKey.of("session1"))); + assertTrue(sessionKeys.contains(SimpleSessionKey.of("session2"))); + } + + @Test + @DisplayName("Should clear all sessions") + void testClearAllSessions() throws SQLException { + when(mockStatement.execute()).thenReturn(true); + when(mockStatement.executeUpdate()).thenReturn(5); + + MysqlSession session = new MysqlSession(mockDataSource, true); + int deleted = session.clearAllSessions(); + + assertEquals(5, deleted); + } + + @Test + @DisplayName("Should truncate session table") + void testTruncateAllSessions() throws SQLException { + when(mockStatement.execute()).thenReturn(true); + when(mockStatement.executeUpdate()).thenReturn(0); + + MysqlSession session = new MysqlSession(mockDataSource, true); + int success = session.truncateAllSessions(); + + assertEquals(0, success); + } + + @Test + @DisplayName("Should not close DataSource when closing session") + void testClose() throws SQLException { + when(mockStatement.execute()).thenReturn(true); + + MysqlSession session = new MysqlSession(mockDataSource, true); + session.close(); + assertEquals(mockDataSource, session.getDataSource()); + } + + // ==================== SQL Injection Prevention Tests ==================== + + @Test + @DisplayName("Should reject database name with semicolon (SQL injection)") + void testConstructorRejectsDatabaseNameWithSemicolon() { + assertThrows( + IllegalArgumentException.class, + () -> + new MysqlSession( + mockDataSource, "db; DROP DATABASE mysql; --", "table", true), + "Database name contains invalid characters"); + } + + @Test + @DisplayName("Should reject table name with semicolon (SQL injection)") + void testConstructorRejectsTableNameWithSemicolon() { + assertThrows( + IllegalArgumentException.class, + () -> + new MysqlSession( + mockDataSource, "valid_db", "table; DROP TABLE users; --", true), + "Table name contains invalid characters"); + } + + @Test + @DisplayName("Should reject database name with space") + void testConstructorRejectsDatabaseNameWithSpace() { + assertThrows( + IllegalArgumentException.class, + () -> new MysqlSession(mockDataSource, "db name", "table", true), + "Database name contains invalid characters"); + } + + @Test + @DisplayName("Should reject table name with space") + void testConstructorRejectsTableNameWithSpace() { + assertThrows( + IllegalArgumentException.class, + () -> new MysqlSession(mockDataSource, "valid_db", "table name", true), + "Table name contains invalid characters"); + } + + @Test + @DisplayName("Should reject database name starting with number") + void testConstructorRejectsDatabaseNameStartingWithNumber() { + assertThrows( + IllegalArgumentException.class, + () -> new MysqlSession(mockDataSource, "123db", "table", true), + "Database name contains invalid characters"); + } + + @Test + @DisplayName("Should reject table name starting with number") + void testConstructorRejectsTableNameStartingWithNumber() { + assertThrows( + IllegalArgumentException.class, + () -> new MysqlSession(mockDataSource, "valid_db", "123table", true), + "Table name contains invalid characters"); + } + + @Test + @DisplayName("Should reject database name exceeding max length") + void testConstructorRejectsDatabaseNameExceedingMaxLength() { + String longName = "a".repeat(65); + assertThrows( + IllegalArgumentException.class, + () -> new MysqlSession(mockDataSource, longName, "table", true), + "Database name cannot exceed 64 characters"); + } + + @Test + @DisplayName("Should reject table name exceeding max length") + void testConstructorRejectsTableNameExceedingMaxLength() { + String longName = "a".repeat(65); + assertThrows( + IllegalArgumentException.class, + () -> new MysqlSession(mockDataSource, "valid_db", longName, true), + "Table name cannot exceed 64 characters"); + } + + @Test + @DisplayName("Should accept valid database and table names") + void testConstructorAcceptsValidDatabaseAndTableNames() throws SQLException { + when(mockStatement.execute()).thenReturn(true); + + MysqlSession session = + new MysqlSession(mockDataSource, "my_database_123", "my_table_456", true); + + assertEquals("my_database_123", session.getDatabaseName()); + assertEquals("my_table_456", session.getTableName()); + } + + @Test + @DisplayName("Should accept names starting with underscore") + void testConstructorAcceptsNameStartingWithUnderscore() throws SQLException { + when(mockStatement.execute()).thenReturn(true); + + MysqlSession session = + new MysqlSession(mockDataSource, "_private_db", "_private_table", true); + + assertEquals("_private_db", session.getDatabaseName()); + assertEquals("_private_table", session.getTableName()); + } + + @Test + @DisplayName("Should accept max length names") + void testConstructorAcceptsMaxLengthNames() throws SQLException { + when(mockStatement.execute()).thenReturn(true); + + String maxLengthName = "a".repeat(64); + MysqlSession session = new MysqlSession(mockDataSource, maxLengthName, maxLengthName, true); + + assertEquals(maxLengthName, session.getDatabaseName()); + assertEquals(maxLengthName, session.getTableName()); + } + + @Test + @DisplayName("Should accept database name with hyphens") + void testConstructorAcceptsDatabaseNameWithHyphens() throws SQLException { + when(mockStatement.execute()).thenReturn(true); + + MysqlSession session = new MysqlSession(mockDataSource, "my-test-db", "my_table", true); + + assertEquals("my-test-db", session.getDatabaseName()); + assertEquals("my_table", session.getTableName()); + } + + @Test + @DisplayName("Should accept table name with hyphens") + void testConstructorAcceptsTableNameWithHyphens() throws SQLException { + when(mockStatement.execute()).thenReturn(true); + + MysqlSession session = new MysqlSession(mockDataSource, "my_db", "my-test-table", true); + + assertEquals("my_db", session.getDatabaseName()); + assertEquals("my-test-table", session.getTableName()); + } + + @Test + @DisplayName("Should accept database and table names with hyphens") + void testConstructorAcceptsDatabaseAndTableNamesWithHyphens() throws SQLException { + when(mockStatement.execute()).thenReturn(true); + + MysqlSession session = new MysqlSession(mockDataSource, "xxx-xxx-xx", "test-table", true); + + assertEquals("xxx-xxx-xx", session.getDatabaseName()); + assertEquals("test-table", session.getTableName()); + } + + @Test + @DisplayName("Should accept name with underscore and hyphen") + void testConstructorAcceptsNameWithUnderscoreAndHyphen() throws SQLException { + when(mockStatement.execute()).thenReturn(true); + + MysqlSession session = + new MysqlSession(mockDataSource, "my_test-db", "my_table-test", true); + + assertEquals("my_test-db", session.getDatabaseName()); + assertEquals("my_table-test", session.getTableName()); + } + + /** Simple test state record for testing. */ + public record TestState(String value, int count) implements State {} +} diff --git a/agentscope-extensions/agentscope-extensions-session-mysql/src/test/java/io/agentscope/core/session/mysql/e2e/MysqlSessionE2ETest.java b/agentscope-extensions/agentscope-extensions-session-mysql/src/test/java/io/agentscope/core/session/mysql/e2e/MysqlSessionE2ETest.java index 6cbab1ab9..84d9096db 100644 --- a/agentscope-extensions/agentscope-extensions-session-mysql/src/test/java/io/agentscope/core/session/mysql/e2e/MysqlSessionE2ETest.java +++ b/agentscope-extensions/agentscope-extensions-session-mysql/src/test/java/io/agentscope/core/session/mysql/e2e/MysqlSessionE2ETest.java @@ -1,264 +1,364 @@ -/* - * Copyright 2024-2026 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package io.agentscope.core.session.mysql.e2e; - -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertFalse; -import static org.junit.jupiter.api.Assertions.assertThrows; -import static org.junit.jupiter.api.Assertions.assertTrue; - -import io.agentscope.core.session.mysql.MysqlSession; -import io.agentscope.core.state.SessionKey; -import io.agentscope.core.state.SimpleSessionKey; -import io.agentscope.core.state.State; -import java.sql.Connection; -import java.sql.SQLException; -import java.sql.Statement; -import java.util.List; -import java.util.Optional; -import java.util.Set; -import java.util.UUID; -import javax.sql.DataSource; -import org.h2.jdbcx.JdbcDataSource; -import org.junit.jupiter.api.AfterEach; -import org.junit.jupiter.api.DisplayName; -import org.junit.jupiter.api.Tag; -import org.junit.jupiter.api.Test; -import org.junit.jupiter.api.parallel.Execution; -import org.junit.jupiter.api.parallel.ExecutionMode; - -/** - * End-to-end tests for {@link MysqlSession} using an in-memory H2 database in MySQL compatibility - * mode. - * - *

This makes the E2E tests runnable in CI without provisioning a real MySQL instance and without - * requiring any environment variables. - */ -@Tag("e2e") -@Execution(ExecutionMode.CONCURRENT) -@DisplayName("Session MySQL Storage E2E Tests") -class MysqlSessionE2ETest { - - private String createdSchemaName; - private DataSource dataSource; - - @AfterEach - void cleanupDatabase() { - if (dataSource == null || createdSchemaName == null) { - return; - } - try (Connection conn = dataSource.getConnection(); - Statement stmt = conn.createStatement()) { - stmt.execute("DROP SCHEMA IF EXISTS " + createdSchemaName + " CASCADE"); - } catch (SQLException e) { - // best-effort cleanup - System.err.println( - "Failed to drop e2e schema " + createdSchemaName + ": " + e.getMessage()); - } finally { - createdSchemaName = null; - dataSource = null; - } - } - - @Test - @DisplayName("Smoke: auto-create database/table + save/load/list/delete flow") - void testMysqlSessionEndToEndFlow() { - System.out.println("\n=== Test: MysqlSession E2E Flow ==="); - - dataSource = createH2DataSource(); - String schemaName = generateSafeIdentifier("AGENTSCOPE_E2E").toUpperCase(); - String tableName = generateSafeIdentifier("AGENTSCOPE_SESSIONS").toUpperCase(); - createdSchemaName = schemaName; - - initSchemaAndTable(dataSource, schemaName, tableName); - MysqlSession session = new MysqlSession(dataSource, schemaName, tableName, false); - - // Prepare test states - TestState stateA = new TestState("hello", 1); - TestState stateB = new TestState("world", 2); - - String sessionIdStr = "mysql_e2e_session_" + UUID.randomUUID(); - SessionKey sessionKey = SimpleSessionKey.of(sessionIdStr); - - // Save single states - session.save(sessionKey, "moduleA", stateA); - session.save(sessionKey, "moduleB", stateB); - assertTrue(session.exists(sessionKey)); - - // Load states - Optional loadedA = session.get(sessionKey, "moduleA", TestState.class); - Optional loadedB = session.get(sessionKey, "moduleB", TestState.class); - - assertTrue(loadedA.isPresent()); - assertTrue(loadedB.isPresent()); - assertEquals("hello", loadedA.get().value()); - assertEquals("world", loadedB.get().value()); - assertEquals(1, loadedA.get().count()); - assertEquals(2, loadedB.get().count()); - - // listSessionKeys - Set sessionKeys = session.listSessionKeys(); - assertTrue( - sessionKeys.contains(sessionKey), "listSessionKeys should contain saved session"); - - // delete session - session.delete(sessionKey); - assertFalse(session.exists(sessionKey)); - } - - @Test - @DisplayName("Save and load list state correctly") - void testSaveAndLoadListState() { - System.out.println("\n=== Test: Save and Load List State ==="); - - dataSource = createH2DataSource(); - String schemaName = generateSafeIdentifier("AGENTSCOPE_E2E").toUpperCase(); - String tableName = generateSafeIdentifier("AGENTSCOPE_SESSIONS").toUpperCase(); - createdSchemaName = schemaName; - - initSchemaAndTable(dataSource, schemaName, tableName); - MysqlSession session = new MysqlSession(dataSource, schemaName, tableName, false); - - String sessionIdStr = "mysql_e2e_list_" + UUID.randomUUID(); - SessionKey sessionKey = SimpleSessionKey.of(sessionIdStr); - - // Save list state - List states = List.of(new TestState("item1", 1), new TestState("item2", 2)); - session.save(sessionKey, "stateList", states); - - // Load list state - List loaded = session.getList(sessionKey, "stateList", TestState.class); - assertEquals(2, loaded.size()); - assertEquals("item1", loaded.get(0).value()); - assertEquals("item2", loaded.get(1).value()); - - // Add more items incrementally - List moreStates = - List.of( - new TestState("item1", 1), - new TestState("item2", 2), - new TestState("item3", 3)); - session.save(sessionKey, "stateList", moreStates); - - // Verify all items - List allLoaded = session.getList(sessionKey, "stateList", TestState.class); - assertEquals(3, allLoaded.size()); - assertEquals("item3", allLoaded.get(2).value()); - } - - @Test - @DisplayName("Session does not exist should return false") - void testSessionNotExists() { - System.out.println("\n=== Test: Session Not Exists ==="); - - dataSource = createH2DataSource(); - String schemaName = generateSafeIdentifier("AGENTSCOPE_E2E").toUpperCase(); - String tableName = generateSafeIdentifier("AGENTSCOPE_SESSIONS").toUpperCase(); - createdSchemaName = schemaName; - - initSchemaAndTable(dataSource, schemaName, tableName); - MysqlSession session = new MysqlSession(dataSource, schemaName, tableName, false); - - SessionKey sessionKey = SimpleSessionKey.of("non_existent_" + UUID.randomUUID()); - assertFalse(session.exists(sessionKey)); - } - - @Test - @DisplayName("Get non-existent state should return empty") - void testGetNonExistentState() { - System.out.println("\n=== Test: Get Non-Existent State ==="); - - dataSource = createH2DataSource(); - String schemaName = generateSafeIdentifier("AGENTSCOPE_E2E").toUpperCase(); - String tableName = generateSafeIdentifier("AGENTSCOPE_SESSIONS").toUpperCase(); - createdSchemaName = schemaName; - - initSchemaAndTable(dataSource, schemaName, tableName); - MysqlSession session = new MysqlSession(dataSource, schemaName, tableName, false); - - SessionKey sessionKey = SimpleSessionKey.of("missing_" + UUID.randomUUID()); - Optional result = session.get(sessionKey, "moduleA", TestState.class); - assertFalse(result.isPresent()); - } - - @Test - @DisplayName("createIfNotExist=false should fail fast when database/table do not exist") - void testCreateIfNotExistFalseFailsWhenMissing() { - System.out.println("\n=== Test: createIfNotExist=false with missing schema ==="); - - dataSource = createH2DataSource(); - String schemaName = generateSafeIdentifier("AGENTSCOPE_E2E_MISSING").toUpperCase(); - String tableName = generateSafeIdentifier("AGENTSCOPE_SESSIONS_MISSING").toUpperCase(); - - assertThrows( - IllegalStateException.class, - () -> new MysqlSession(dataSource, schemaName, tableName, false)); - } - - private static DataSource createH2DataSource() { - String dbName = "mysql_session_e2e_" + UUID.randomUUID().toString().replace("-", ""); - JdbcDataSource ds = new JdbcDataSource(); - ds.setURL("jdbc:h2:mem:" + dbName + ";MODE=MySQL;DB_CLOSE_DELAY=-1"); - ds.setUser("sa"); - ds.setPassword(""); - return ds; - } - - /** Generates a safe MySQL identifier (letters/numbers/underscore) and keeps it <= 64 chars. */ - private static String generateSafeIdentifier(String prefix) { - String suffix = UUID.randomUUID().toString().replace("-", "_"); - String raw = prefix + "_" + suffix; - if (!Character.isLetter(raw.charAt(0)) && raw.charAt(0) != '_') { - raw = "_" + raw; - } - if (raw.length() > 64) { - raw = raw.substring(0, 64); - } - raw = raw.replaceAll("_+$", "_e2e"); - if (raw.length() > 64) { - raw = raw.substring(0, 64); - } - return raw; - } - - private static void initSchemaAndTable( - DataSource dataSource, String schemaName, String tableName) throws RuntimeException { - try (Connection conn = dataSource.getConnection(); - Statement stmt = conn.createStatement()) { - stmt.execute("CREATE SCHEMA IF NOT EXISTS " + schemaName); - stmt.execute("SET SCHEMA " + schemaName); - stmt.execute("DROP TABLE IF EXISTS " + tableName); - // Table structure with item_index for incremental list storage - stmt.execute( - "CREATE TABLE " - + tableName - + " (" - + "session_id VARCHAR(255) NOT NULL, " - + "state_key VARCHAR(255) NOT NULL, " - + "item_index INT NOT NULL DEFAULT 0, " - + "state_data LONGTEXT NOT NULL, " - + "created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, " - + "updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, " - + "PRIMARY KEY (session_id, state_key, item_index)" - + ")"); - } catch (SQLException e) { - throw new RuntimeException("Failed to init schema/table for H2 e2e", e); - } - } - - /** Simple test state record for testing. */ - public record TestState(String value, int count) implements State {} -} +/* + * Copyright 2024-2026 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.agentscope.core.session.mysql.e2e; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import io.agentscope.core.session.mysql.MysqlSession; +import io.agentscope.core.state.SessionKey; +import io.agentscope.core.state.SimpleSessionKey; +import io.agentscope.core.state.State; +import java.io.PrintWriter; +import java.sql.Connection; +import java.sql.SQLException; +import java.sql.SQLFeatureNotSupportedException; +import java.sql.Statement; +import java.util.List; +import java.util.Optional; +import java.util.Set; +import java.util.UUID; +import java.util.logging.Logger; +import javax.sql.DataSource; +import org.h2.jdbcx.JdbcDataSource; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.Tag; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.parallel.Execution; +import org.junit.jupiter.api.parallel.ExecutionMode; + +/** + * End-to-end tests for {@link MysqlSession} using an in-memory H2 database in MySQL compatibility + * mode. + * + *

This makes the E2E tests runnable in CI without provisioning a real MySQL instance and without + * requiring any environment variables. + */ +@Tag("e2e") +@Execution(ExecutionMode.CONCURRENT) +@DisplayName("Session MySQL Storage E2E Tests") +class MysqlSessionE2ETest { + + private String createdSchemaName; + private DataSource dataSource; + + @AfterEach + void cleanupDatabase() { + if (dataSource == null || createdSchemaName == null) { + return; + } + try (Connection conn = dataSource.getConnection(); + Statement stmt = conn.createStatement()) { + stmt.execute("DROP SCHEMA IF EXISTS " + createdSchemaName + " CASCADE"); + } catch (SQLException e) { + // best-effort cleanup + System.err.println( + "Failed to drop e2e schema " + createdSchemaName + ": " + e.getMessage()); + } finally { + createdSchemaName = null; + dataSource = null; + } + } + + @Test + @DisplayName("Smoke: auto-create database/table + save/load/list/delete flow") + void testMysqlSessionEndToEndFlow() { + System.out.println("\n=== Test: MysqlSession E2E Flow ==="); + + dataSource = createH2DataSource(); + String schemaName = generateSafeIdentifier("AGENTSCOPE_E2E").toUpperCase(); + String tableName = generateSafeIdentifier("AGENTSCOPE_SESSIONS").toUpperCase(); + createdSchemaName = schemaName; + + initSchemaAndTable(dataSource, schemaName, tableName); + MysqlSession session = new MysqlSession(dataSource, schemaName, tableName, false); + + // Prepare test states + TestState stateA = new TestState("hello", 1); + TestState stateB = new TestState("world", 2); + + String sessionIdStr = "mysql_e2e_session_" + UUID.randomUUID(); + SessionKey sessionKey = SimpleSessionKey.of(sessionIdStr); + + // Save single states + session.save(sessionKey, "moduleA", stateA); + session.save(sessionKey, "moduleB", stateB); + assertTrue(session.exists(sessionKey)); + + // Load states + Optional loadedA = session.get(sessionKey, "moduleA", TestState.class); + Optional loadedB = session.get(sessionKey, "moduleB", TestState.class); + + assertTrue(loadedA.isPresent()); + assertTrue(loadedB.isPresent()); + assertEquals("hello", loadedA.get().value()); + assertEquals("world", loadedB.get().value()); + assertEquals(1, loadedA.get().count()); + assertEquals(2, loadedB.get().count()); + + // listSessionKeys + Set sessionKeys = session.listSessionKeys(); + assertTrue( + sessionKeys.contains(sessionKey), "listSessionKeys should contain saved session"); + + // delete session + session.delete(sessionKey); + assertFalse(session.exists(sessionKey)); + } + + @Test + @DisplayName("Save and load list state correctly") + void testSaveAndLoadListState() { + System.out.println("\n=== Test: Save and Load List State ==="); + + dataSource = createH2DataSource(); + String schemaName = generateSafeIdentifier("AGENTSCOPE_E2E").toUpperCase(); + String tableName = generateSafeIdentifier("AGENTSCOPE_SESSIONS").toUpperCase(); + createdSchemaName = schemaName; + + initSchemaAndTable(dataSource, schemaName, tableName); + MysqlSession session = new MysqlSession(dataSource, schemaName, tableName, false); + + String sessionIdStr = "mysql_e2e_list_" + UUID.randomUUID(); + SessionKey sessionKey = SimpleSessionKey.of(sessionIdStr); + + // Save list state + List states = List.of(new TestState("item1", 1), new TestState("item2", 2)); + session.save(sessionKey, "stateList", states); + + // Load list state + List loaded = session.getList(sessionKey, "stateList", TestState.class); + assertEquals(2, loaded.size()); + assertEquals("item1", loaded.get(0).value()); + assertEquals("item2", loaded.get(1).value()); + + // Add more items incrementally + List moreStates = + List.of( + new TestState("item1", 1), + new TestState("item2", 2), + new TestState("item3", 3)); + session.save(sessionKey, "stateList", moreStates); + + // Verify all items + List allLoaded = session.getList(sessionKey, "stateList", TestState.class); + assertEquals(3, allLoaded.size()); + assertEquals("item3", allLoaded.get(2).value()); + } + + @Test + @DisplayName( + "Writes should persist when DataSource connections start with auto-commit disabled") + void testWritesPersistWhenAutoCommitDisabled() { + System.out.println("\n=== Test: Writes Persist With Auto-Commit Disabled ==="); + + dataSource = new AutoCommitDisabledDataSource(createH2DataSource()); + String schemaName = generateSafeIdentifier("AGENTSCOPE_E2E").toUpperCase(); + String tableName = generateSafeIdentifier("AGENTSCOPE_SESSIONS").toUpperCase(); + createdSchemaName = schemaName; + + initSchemaAndTable(dataSource, schemaName, tableName); + MysqlSession session = new MysqlSession(dataSource, schemaName, tableName, false); + SessionKey sessionKey = SimpleSessionKey.of("auto_commit_off_" + UUID.randomUUID()); + + session.save(sessionKey, "moduleA", new TestState("hello", 1)); + Optional singleState = session.get(sessionKey, "moduleA", TestState.class); + assertTrue(singleState.isPresent()); + assertEquals("hello", singleState.get().value()); + + List initialStates = + List.of(new TestState("item1", 1), new TestState("item2", 2)); + session.save(sessionKey, "stateList", initialStates); + + List appendedStates = + List.of( + new TestState("item1", 1), + new TestState("item2", 2), + new TestState("item3", 3)); + session.save(sessionKey, "stateList", appendedStates); + + List loadedStates = session.getList(sessionKey, "stateList", TestState.class); + assertEquals(3, loadedStates.size()); + assertEquals("item3", loadedStates.get(2).value()); + + session.delete(sessionKey); + assertFalse(session.exists(sessionKey)); + } + + @Test + @DisplayName("Session does not exist should return false") + void testSessionNotExists() { + System.out.println("\n=== Test: Session Not Exists ==="); + + dataSource = createH2DataSource(); + String schemaName = generateSafeIdentifier("AGENTSCOPE_E2E").toUpperCase(); + String tableName = generateSafeIdentifier("AGENTSCOPE_SESSIONS").toUpperCase(); + createdSchemaName = schemaName; + + initSchemaAndTable(dataSource, schemaName, tableName); + MysqlSession session = new MysqlSession(dataSource, schemaName, tableName, false); + + SessionKey sessionKey = SimpleSessionKey.of("non_existent_" + UUID.randomUUID()); + assertFalse(session.exists(sessionKey)); + } + + @Test + @DisplayName("Get non-existent state should return empty") + void testGetNonExistentState() { + System.out.println("\n=== Test: Get Non-Existent State ==="); + + dataSource = createH2DataSource(); + String schemaName = generateSafeIdentifier("AGENTSCOPE_E2E").toUpperCase(); + String tableName = generateSafeIdentifier("AGENTSCOPE_SESSIONS").toUpperCase(); + createdSchemaName = schemaName; + + initSchemaAndTable(dataSource, schemaName, tableName); + MysqlSession session = new MysqlSession(dataSource, schemaName, tableName, false); + + SessionKey sessionKey = SimpleSessionKey.of("missing_" + UUID.randomUUID()); + Optional result = session.get(sessionKey, "moduleA", TestState.class); + assertFalse(result.isPresent()); + } + + @Test + @DisplayName("createIfNotExist=false should fail fast when database/table do not exist") + void testCreateIfNotExistFalseFailsWhenMissing() { + System.out.println("\n=== Test: createIfNotExist=false with missing schema ==="); + + dataSource = createH2DataSource(); + String schemaName = generateSafeIdentifier("AGENTSCOPE_E2E_MISSING").toUpperCase(); + String tableName = generateSafeIdentifier("AGENTSCOPE_SESSIONS_MISSING").toUpperCase(); + + assertThrows( + IllegalStateException.class, + () -> new MysqlSession(dataSource, schemaName, tableName, false)); + } + + private static final class AutoCommitDisabledDataSource implements DataSource { + + private final DataSource delegate; + + private AutoCommitDisabledDataSource(DataSource delegate) { + this.delegate = delegate; + } + + @Override + public Connection getConnection() throws SQLException { + Connection connection = delegate.getConnection(); + connection.setAutoCommit(false); + return connection; + } + + @Override + public Connection getConnection(String username, String password) throws SQLException { + Connection connection = delegate.getConnection(username, password); + connection.setAutoCommit(false); + return connection; + } + + @Override + public T unwrap(Class iface) throws SQLException { + return delegate.unwrap(iface); + } + + @Override + public boolean isWrapperFor(Class iface) throws SQLException { + return delegate.isWrapperFor(iface); + } + + @Override + public PrintWriter getLogWriter() throws SQLException { + return delegate.getLogWriter(); + } + + @Override + public void setLogWriter(PrintWriter out) throws SQLException { + delegate.setLogWriter(out); + } + + @Override + public void setLoginTimeout(int seconds) throws SQLException { + delegate.setLoginTimeout(seconds); + } + + @Override + public int getLoginTimeout() throws SQLException { + return delegate.getLoginTimeout(); + } + + @Override + public Logger getParentLogger() throws SQLFeatureNotSupportedException { + return delegate.getParentLogger(); + } + } + + private static DataSource createH2DataSource() { + String dbName = "mysql_session_e2e_" + UUID.randomUUID().toString().replace("-", ""); + JdbcDataSource ds = new JdbcDataSource(); + ds.setURL("jdbc:h2:mem:" + dbName + ";MODE=MySQL;DB_CLOSE_DELAY=-1"); + ds.setUser("sa"); + ds.setPassword(""); + return ds; + } + + /** Generates a safe MySQL identifier (letters/numbers/underscore) and keeps it <= 64 chars. */ + private static String generateSafeIdentifier(String prefix) { + String suffix = UUID.randomUUID().toString().replace("-", "_"); + String raw = prefix + "_" + suffix; + if (!Character.isLetter(raw.charAt(0)) && raw.charAt(0) != '_') { + raw = "_" + raw; + } + if (raw.length() > 64) { + raw = raw.substring(0, 64); + } + raw = raw.replaceAll("_+$", "_e2e"); + if (raw.length() > 64) { + raw = raw.substring(0, 64); + } + return raw; + } + + private static void initSchemaAndTable( + DataSource dataSource, String schemaName, String tableName) throws RuntimeException { + try (Connection conn = dataSource.getConnection(); + Statement stmt = conn.createStatement()) { + stmt.execute("CREATE SCHEMA IF NOT EXISTS " + schemaName); + stmt.execute("SET SCHEMA " + schemaName); + stmt.execute("DROP TABLE IF EXISTS " + tableName); + // Table structure with item_index for incremental list storage + stmt.execute( + "CREATE TABLE " + + tableName + + " (" + + "session_id VARCHAR(255) NOT NULL, " + + "state_key VARCHAR(255) NOT NULL, " + + "item_index INT NOT NULL DEFAULT 0, " + + "state_data LONGTEXT NOT NULL, " + + "created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, " + + "updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, " + + "PRIMARY KEY (session_id, state_key, item_index)" + + ")"); + } catch (SQLException e) { + throw new RuntimeException("Failed to init schema/table for H2 e2e", e); + } + } + + /** Simple test state record for testing. */ + public record TestState(String value, int count) implements State {} +}