diff --git a/agentscope-core/pom.xml b/agentscope-core/pom.xml index 54bf500a7..46a11a786 100644 --- a/agentscope-core/pom.xml +++ b/agentscope-core/pom.xml @@ -34,9 +34,11 @@ 17 17 17 + 1.37 3.2.8 4.0.0-M13 3.11.0 + 3.5.0 3.1.2 3.1.1 3.3.1 @@ -44,6 +46,10 @@ 3.5.0 UTF-8 -Xms512m -Xmx1024m + .*Benchmark.* + 2 + 3 + 1 0.7.0 @@ -140,5 +146,69 @@ com.networknt json-schema-validator + + + org.openjdk.jmh + jmh-core + ${jmh.version} + test + + + + org.openjdk.jmh + jmh-generator-annprocess + ${jmh.version} + test + + + + + + org.apache.maven.plugins + maven-compiler-plugin + + + + org.openjdk.jmh + jmh-generator-annprocess + ${jmh.version} + + + + + + org.codehaus.mojo + exec-maven-plugin + ${exec-maven-plugin.version} + + test + true + java + + -Dbenchmark.include=${benchmark.include} + -Dbenchmark.warmupIterations=${benchmark.warmupIterations} + -Dbenchmark.measurementIterations=${benchmark.measurementIterations} + -Dbenchmark.forks=${benchmark.forks} + -cp + + io.agentscope.core.benchmark.BenchmarkLauncher + + + + + run-benchmarks + + exec + + + + + + + + + + + diff --git a/agentscope-core/src/test/java/io/agentscope/core/benchmark/BenchmarkLauncher.java b/agentscope-core/src/test/java/io/agentscope/core/benchmark/BenchmarkLauncher.java new file mode 100644 index 000000000..864972407 --- /dev/null +++ b/agentscope-core/src/test/java/io/agentscope/core/benchmark/BenchmarkLauncher.java @@ -0,0 +1,47 @@ +/* + * Copyright 2024-2026 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.agentscope.core.benchmark; + +import org.openjdk.jmh.results.format.ResultFormatType; +import org.openjdk.jmh.runner.Runner; +import org.openjdk.jmh.runner.RunnerException; +import org.openjdk.jmh.runner.options.Options; +import org.openjdk.jmh.runner.options.OptionsBuilder; + +/** Launcher for running the core JMH benchmark suite from Maven test classpath. */ +public final class BenchmarkLauncher { + + private BenchmarkLauncher() {} + + public static void main(String[] args) throws RunnerException { + String includePattern = System.getProperty("benchmark.include", ".*Benchmark.*"); + int warmupIterations = Integer.getInteger("benchmark.warmupIterations", 2); + int measurementIterations = Integer.getInteger("benchmark.measurementIterations", 3); + int forks = Integer.getInteger("benchmark.forks", 1); + + Options options = + new OptionsBuilder() + .include(includePattern) + .warmupIterations(warmupIterations) + .measurementIterations(measurementIterations) + .forks(forks) + .resultFormat(ResultFormatType.JSON) + .result("target/jmh-result.json") + .build(); + + new Runner(options).run(); + } +} diff --git a/agentscope-core/src/test/java/io/agentscope/core/benchmark/BenchmarkSupport.java b/agentscope-core/src/test/java/io/agentscope/core/benchmark/BenchmarkSupport.java new file mode 100644 index 000000000..ad6677314 --- /dev/null +++ b/agentscope-core/src/test/java/io/agentscope/core/benchmark/BenchmarkSupport.java @@ -0,0 +1,70 @@ +/* + * Copyright 2024-2026 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.agentscope.core.benchmark; + +import io.agentscope.core.ReActAgent; +import io.agentscope.core.agent.AgentBase; +import io.agentscope.core.agent.test.MockModel; +import io.agentscope.core.agent.test.TestUtils; +import io.agentscope.core.memory.InMemoryMemory; +import io.agentscope.core.message.Msg; +import io.agentscope.core.pipeline.MsgHub; +import io.agentscope.core.pipeline.SequentialPipeline; +import io.agentscope.core.tool.Toolkit; +import java.time.Duration; +import java.util.ArrayList; +import java.util.List; + +/** Shared fixtures and helpers for JMH benchmarks. */ +public final class BenchmarkSupport { + + public static final Duration DEFAULT_TIMEOUT = Duration.ofSeconds(5); + + private BenchmarkSupport() {} + + public static ReActAgent createAgent(String name, String responseText) { + return ReActAgent.builder() + .name(name) + .sysPrompt("You are a benchmark agent.") + .model(new MockModel(responseText)) + .toolkit(new Toolkit()) + .memory(new InMemoryMemory()) + .build(); + } + + public static Msg createInputMessage(String text) { + return TestUtils.createUserMessage("benchmark-user", text); + } + + public static SequentialPipeline createSequentialPipeline(int agentCount) { + List agents = new ArrayList<>(); + for (int index = 0; index < agentCount; index++) { + agents.add(createAgent("PipelineAgent" + index, "pipeline-response-" + index)); + } + return new SequentialPipeline(agents); + } + + public static MsgHub createEnteredHub(int participantCount) { + List participants = new ArrayList<>(); + for (int index = 0; index < participantCount; index++) { + participants.add(createAgent("HubAgent" + index, "hub-response-" + index)); + } + + MsgHub hub = MsgHub.builder().name("BenchmarkHub").participants(participants).build(); + hub.enter().block(DEFAULT_TIMEOUT); + return hub; + } +} diff --git a/agentscope-core/src/test/java/io/agentscope/core/benchmark/MsgHubBenchmark.java b/agentscope-core/src/test/java/io/agentscope/core/benchmark/MsgHubBenchmark.java new file mode 100644 index 000000000..6881e66c1 --- /dev/null +++ b/agentscope-core/src/test/java/io/agentscope/core/benchmark/MsgHubBenchmark.java @@ -0,0 +1,79 @@ +/* + * Copyright 2024-2026 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.agentscope.core.benchmark; + +import io.agentscope.core.message.Msg; +import io.agentscope.core.pipeline.MsgHub; +import java.util.concurrent.TimeUnit; +import org.openjdk.jmh.annotations.Benchmark; +import org.openjdk.jmh.annotations.BenchmarkMode; +import org.openjdk.jmh.annotations.Level; +import org.openjdk.jmh.annotations.Measurement; +import org.openjdk.jmh.annotations.Mode; +import org.openjdk.jmh.annotations.OutputTimeUnit; +import org.openjdk.jmh.annotations.Scope; +import org.openjdk.jmh.annotations.Setup; +import org.openjdk.jmh.annotations.State; +import org.openjdk.jmh.annotations.TearDown; +import org.openjdk.jmh.annotations.Warmup; + +/** Benchmarks for MsgHub lifecycle and single-message distribution. */ +public class MsgHubBenchmark { + + @BenchmarkMode(Mode.SampleTime) + @OutputTimeUnit(TimeUnit.MILLISECONDS) + @Warmup(iterations = 2, time = 1) + @Measurement(iterations = 3, time = 1) + @State(Scope.Thread) + public static class LifecycleState { + + @Benchmark + public MsgHub enterAndExitLifecycle() { + MsgHub hub = BenchmarkSupport.createEnteredHub(3); + hub.exit().block(BenchmarkSupport.DEFAULT_TIMEOUT); + return hub; + } + } + + @BenchmarkMode({Mode.SampleTime, Mode.Throughput}) + @OutputTimeUnit(TimeUnit.MILLISECONDS) + @Warmup(iterations = 2, time = 1) + @Measurement(iterations = 3, time = 1) + @State(Scope.Thread) + public static class BroadcastState { + + private MsgHub hub; + private Msg message; + + @Setup(Level.Iteration) + public void setUp() { + hub = BenchmarkSupport.createEnteredHub(3); + message = BenchmarkSupport.createInputMessage("Broadcast benchmark message."); + } + + @TearDown(Level.Iteration) + public void tearDown() { + if (hub != null) { + hub.close(); + } + } + + @Benchmark + public Void singleMessageBroadcast() { + return hub.broadcast(message).block(BenchmarkSupport.DEFAULT_TIMEOUT); + } + } +} diff --git a/agentscope-core/src/test/java/io/agentscope/core/benchmark/README.md b/agentscope-core/src/test/java/io/agentscope/core/benchmark/README.md new file mode 100644 index 000000000..b85c4f96e --- /dev/null +++ b/agentscope-core/src/test/java/io/agentscope/core/benchmark/README.md @@ -0,0 +1,40 @@ +# AgentScope Core Benchmarks + +This directory contains the first-phase JMH benchmark harness for `agentscope-core`. + +## Included benchmarks + +- `ReActAgentBenchmark` �single-agent calls with a mock model +- `SequentialPipelineBenchmark` �sequential pipeline orchestration over three agents +- `MsgHubBenchmark` �MsgHub lifecycle and single-message distribution + +## Compile benchmark sources + +```bash +mvn -pl agentscope-core -am -DskipTests test-compile +``` + +## Run the full benchmark suite + +```bash +mvn -pl agentscope-core -am -DskipTests test-compile exec:java \ + -Dexec.classpathScope=test \ + -Dexec.mainClass=io.agentscope.core.benchmark.BenchmarkLauncher +``` + +## Run a focused benchmark quickly + +```bash +mvn -pl agentscope-core -am -DskipTests test-compile exec:java \ + -Dexec.classpathScope=test \ + -Dexec.mainClass=io.agentscope.core.benchmark.BenchmarkLauncher \ + -Dbenchmark.include=.*ReActAgentBenchmark.* \ + -Dbenchmark.warmupIterations=1 \ + -Dbenchmark.measurementIterations=1 \ + -Dbenchmark.forks=1 +``` + +Results are written to `agentscope-core/target/jmh-result.json`. + + + diff --git a/agentscope-core/src/test/java/io/agentscope/core/benchmark/ReActAgentBenchmark.java b/agentscope-core/src/test/java/io/agentscope/core/benchmark/ReActAgentBenchmark.java new file mode 100644 index 000000000..b4d22ca98 --- /dev/null +++ b/agentscope-core/src/test/java/io/agentscope/core/benchmark/ReActAgentBenchmark.java @@ -0,0 +1,53 @@ +/* + * Copyright 2024-2026 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.agentscope.core.benchmark; + +import io.agentscope.core.ReActAgent; +import io.agentscope.core.message.Msg; +import java.util.concurrent.TimeUnit; +import org.openjdk.jmh.annotations.Benchmark; +import org.openjdk.jmh.annotations.BenchmarkMode; +import org.openjdk.jmh.annotations.Level; +import org.openjdk.jmh.annotations.Measurement; +import org.openjdk.jmh.annotations.Mode; +import org.openjdk.jmh.annotations.OutputTimeUnit; +import org.openjdk.jmh.annotations.Scope; +import org.openjdk.jmh.annotations.Setup; +import org.openjdk.jmh.annotations.State; +import org.openjdk.jmh.annotations.Warmup; + +/** Benchmarks for single-agent execution with a mock model. */ +@BenchmarkMode({Mode.SampleTime, Mode.Throughput}) +@OutputTimeUnit(TimeUnit.MILLISECONDS) +@Warmup(iterations = 2, time = 1) +@Measurement(iterations = 3, time = 1) +@State(Scope.Thread) +public class ReActAgentBenchmark { + + private ReActAgent agent; + private Msg input; + + @Setup(Level.Iteration) + public void setUp() { + agent = BenchmarkSupport.createAgent("BenchmarkAgent", "agent-benchmark-response"); + input = BenchmarkSupport.createInputMessage("Summarize the benchmark request."); + } + + @Benchmark + public Msg singleAgentCall() { + return agent.call(input).block(BenchmarkSupport.DEFAULT_TIMEOUT); + } +} diff --git a/agentscope-core/src/test/java/io/agentscope/core/benchmark/SequentialPipelineBenchmark.java b/agentscope-core/src/test/java/io/agentscope/core/benchmark/SequentialPipelineBenchmark.java new file mode 100644 index 000000000..bfbab9f78 --- /dev/null +++ b/agentscope-core/src/test/java/io/agentscope/core/benchmark/SequentialPipelineBenchmark.java @@ -0,0 +1,53 @@ +/* + * Copyright 2024-2026 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.agentscope.core.benchmark; + +import io.agentscope.core.message.Msg; +import io.agentscope.core.pipeline.SequentialPipeline; +import java.util.concurrent.TimeUnit; +import org.openjdk.jmh.annotations.Benchmark; +import org.openjdk.jmh.annotations.BenchmarkMode; +import org.openjdk.jmh.annotations.Level; +import org.openjdk.jmh.annotations.Measurement; +import org.openjdk.jmh.annotations.Mode; +import org.openjdk.jmh.annotations.OutputTimeUnit; +import org.openjdk.jmh.annotations.Scope; +import org.openjdk.jmh.annotations.Setup; +import org.openjdk.jmh.annotations.State; +import org.openjdk.jmh.annotations.Warmup; + +/** Benchmarks for sequential pipeline orchestration. */ +@BenchmarkMode({Mode.SampleTime, Mode.Throughput}) +@OutputTimeUnit(TimeUnit.MILLISECONDS) +@Warmup(iterations = 2, time = 1) +@Measurement(iterations = 3, time = 1) +@State(Scope.Thread) +public class SequentialPipelineBenchmark { + + private SequentialPipeline pipeline; + private Msg input; + + @Setup(Level.Iteration) + public void setUp() { + pipeline = BenchmarkSupport.createSequentialPipeline(3); + input = BenchmarkSupport.createInputMessage("Run the sequential pipeline benchmark."); + } + + @Benchmark + public Msg sequentialPipelineExecution() { + return pipeline.execute(input).block(BenchmarkSupport.DEFAULT_TIMEOUT); + } +} diff --git a/agentscope-extensions/agentscope-extensions-session-mysql/src/main/java/io/agentscope/core/session/mysql/MysqlSession.java b/agentscope-extensions/agentscope-extensions-session-mysql/src/main/java/io/agentscope/core/session/mysql/MysqlSession.java index f60ff3eea..5fc380d25 100644 --- a/agentscope-extensions/agentscope-extensions-session-mysql/src/main/java/io/agentscope/core/session/mysql/MysqlSession.java +++ b/agentscope-extensions/agentscope-extensions-session-mysql/src/main/java/io/agentscope/core/session/mysql/MysqlSession.java @@ -303,12 +303,16 @@ public void save(SessionKey sessionKey, String key, State value) { String json = JsonUtils.getJsonCodec().toJson(value); - stmt.setString(1, sessionId); - stmt.setString(2, key); - stmt.setInt(3, SINGLE_STATE_INDEX); - stmt.setString(4, json); - - stmt.executeUpdate(); + executeWriteTransaction( + conn, + () -> { + stmt.setString(1, sessionId); + stmt.setString(2, key); + stmt.setInt(3, SINGLE_STATE_INDEX); + stmt.setString(4, json); + stmt.executeUpdate(); + return null; + }); } catch (Exception e) { throw new RuntimeException("Failed to save state: " + key, e); @@ -359,24 +363,23 @@ public void save(SessionKey sessionKey, String key, List values currentHash, storedHash, values.size(), existingCount); if (needsFullRewrite) { - // Transaction: delete all + insert all - conn.setAutoCommit(false); - try { - deleteListItems(conn, sessionId, key); - insertAllItems(conn, sessionId, key, values); - saveHash(conn, sessionId, hashKey, currentHash); - conn.commit(); - } catch (Exception e) { - conn.rollback(); - throw e; - } finally { - conn.setAutoCommit(true); - } + executeWriteTransaction( + conn, + () -> { + deleteListItems(conn, sessionId, key); + insertAllItems(conn, sessionId, key, values); + saveHash(conn, sessionId, hashKey, currentHash); + return null; + }); } else if (values.size() > existingCount) { - // Incremental append List newItems = values.subList(existingCount, values.size()); - insertItems(conn, sessionId, key, newItems, existingCount); - saveHash(conn, sessionId, hashKey, currentHash); + executeWriteTransaction( + conn, + () -> { + insertItems(conn, sessionId, key, newItems, existingCount); + saveHash(conn, sessionId, hashKey, currentHash); + return null; + }); } // else: no change, skip @@ -629,10 +632,15 @@ public void delete(SessionKey sessionKey) { try (Connection conn = dataSource.getConnection(); PreparedStatement stmt = conn.prepareStatement(deleteSql)) { - stmt.setString(1, sessionId); - stmt.executeUpdate(); + executeWriteTransaction( + conn, + () -> { + stmt.setString(1, sessionId); + stmt.executeUpdate(); + return null; + }); - } catch (SQLException e) { + } catch (Exception e) { throw new RuntimeException("Failed to delete session: " + sessionId, e); } } @@ -708,9 +716,9 @@ public int clearAllSessions() { try (Connection conn = dataSource.getConnection(); PreparedStatement stmt = conn.prepareStatement(clearSql)) { - return stmt.executeUpdate(); + return executeWriteTransaction(conn, stmt::executeUpdate); - } catch (SQLException e) { + } catch (Exception e) { throw new RuntimeException("Failed to clear sessions", e); } } @@ -733,13 +741,43 @@ public int truncateAllSessions() { try (Connection conn = dataSource.getConnection(); PreparedStatement stmt = conn.prepareStatement(clearSql)) { - return stmt.executeUpdate(); + return executeWriteTransaction(conn, stmt::executeUpdate); - } catch (SQLException e) { + } catch (Exception e) { throw new RuntimeException("Failed to truncate sessions", e); } } + private T executeWriteTransaction(Connection conn, SqlOperation operation) + throws Exception { + boolean originalAutoCommit = conn.getAutoCommit(); + if (originalAutoCommit) { + conn.setAutoCommit(false); + } + + try { + T result = operation.execute(); + conn.commit(); + return result; + } catch (Exception e) { + try { + conn.rollback(); + } catch (SQLException rollbackException) { + e.addSuppressed(rollbackException); + } + throw e; + } finally { + if (originalAutoCommit) { + conn.setAutoCommit(true); + } + } + } + + @FunctionalInterface + private interface SqlOperation { + T execute() throws Exception; + } + /** * Validate a session ID format. * diff --git a/agentscope-extensions/agentscope-extensions-session-mysql/src/test/java/io/agentscope/core/session/mysql/MysqlSessionTest.java b/agentscope-extensions/agentscope-extensions-session-mysql/src/test/java/io/agentscope/core/session/mysql/MysqlSessionTest.java index b720b7376..72f2baa59 100644 --- a/agentscope-extensions/agentscope-extensions-session-mysql/src/test/java/io/agentscope/core/session/mysql/MysqlSessionTest.java +++ b/agentscope-extensions/agentscope-extensions-session-mysql/src/test/java/io/agentscope/core/session/mysql/MysqlSessionTest.java @@ -21,6 +21,7 @@ import static org.junit.jupiter.api.Assertions.assertTrue; import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.Mockito.atLeast; +import static org.mockito.Mockito.never; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; @@ -265,6 +266,70 @@ void testSaveAndGetListState() throws SQLException { assertEquals("value2", loaded.get(1).value()); } + @Test + @DisplayName("Should commit single-state writes when connection auto-commit is disabled") + void testSaveSingleStateCommitsWhenAutoCommitDisabled() throws SQLException { + when(mockStatement.execute()).thenReturn(true); + when(mockConnection.getAutoCommit()).thenReturn(false); + when(mockStatement.executeUpdate()).thenReturn(1); + + MysqlSession session = new MysqlSession(mockDataSource, true); + SessionKey sessionKey = SimpleSessionKey.of("session1"); + + session.save(sessionKey, "testModule", new TestState("test_value", 42)); + + verify(mockConnection).getAutoCommit(); + verify(mockConnection).commit(); + verify(mockConnection, never()).setAutoCommit(false); + verify(mockConnection, never()).setAutoCommit(true); + } + + @Test + @DisplayName("Should commit append writes when connection auto-commit is disabled") + void testSaveListAppendCommitsWhenAutoCommitDisabled() throws SQLException { + when(mockStatement.execute()).thenReturn(true); + when(mockConnection.getAutoCommit()).thenReturn(false); + when(mockStatement.executeQuery()).thenReturn(mockResultSet); + when(mockResultSet.next()).thenReturn(false, true); + when(mockResultSet.getInt("max_index")).thenReturn(0); + when(mockResultSet.wasNull()).thenReturn(true); + when(mockStatement.executeUpdate()).thenReturn(1); + + MysqlSession session = new MysqlSession(mockDataSource, true); + SessionKey sessionKey = SimpleSessionKey.of("session1"); + List states = List.of(new TestState("value1", 1), new TestState("value2", 2)); + + session.save(sessionKey, "testList", states); + + verify(mockConnection).getAutoCommit(); + verify(mockConnection).commit(); + verify(mockConnection, never()).setAutoCommit(false); + verify(mockConnection, never()).setAutoCommit(true); + } + + @Test + @DisplayName("Should restore auto-commit after full rewrite transaction") + void testSaveListFullRewriteRestoresOriginalAutoCommit() throws SQLException { + when(mockStatement.execute()).thenReturn(true); + when(mockConnection.getAutoCommit()).thenReturn(true); + when(mockStatement.executeQuery()).thenReturn(mockResultSet); + when(mockResultSet.next()).thenReturn(true, true); + when(mockResultSet.getString("state_data")).thenReturn("stale_hash"); + when(mockResultSet.getInt("max_index")).thenReturn(0); + when(mockResultSet.wasNull()).thenReturn(false); + when(mockStatement.executeUpdate()).thenReturn(1); + + MysqlSession session = new MysqlSession(mockDataSource, true); + SessionKey sessionKey = SimpleSessionKey.of("session1"); + List states = List.of(new TestState("value1", 1), new TestState("value2", 2)); + + session.save(sessionKey, "testList", states); + + verify(mockConnection).setAutoCommit(false); + verify(mockConnection).commit(); + verify(mockConnection).setAutoCommit(true); + } + @Test @DisplayName("Should return empty for non-existent state") void testGetNonExistentState() throws SQLException { diff --git a/agentscope-extensions/agentscope-extensions-session-mysql/src/test/java/io/agentscope/core/session/mysql/e2e/MysqlSessionE2ETest.java b/agentscope-extensions/agentscope-extensions-session-mysql/src/test/java/io/agentscope/core/session/mysql/e2e/MysqlSessionE2ETest.java index 6cbab1ab9..7194384c8 100644 --- a/agentscope-extensions/agentscope-extensions-session-mysql/src/test/java/io/agentscope/core/session/mysql/e2e/MysqlSessionE2ETest.java +++ b/agentscope-extensions/agentscope-extensions-session-mysql/src/test/java/io/agentscope/core/session/mysql/e2e/MysqlSessionE2ETest.java @@ -24,13 +24,16 @@ import io.agentscope.core.state.SessionKey; import io.agentscope.core.state.SimpleSessionKey; import io.agentscope.core.state.State; +import java.io.PrintWriter; import java.sql.Connection; import java.sql.SQLException; +import java.sql.SQLFeatureNotSupportedException; import java.sql.Statement; import java.util.List; import java.util.Optional; import java.util.Set; import java.util.UUID; +import java.util.logging.Logger; import javax.sql.DataSource; import org.h2.jdbcx.JdbcDataSource; import org.junit.jupiter.api.AfterEach; @@ -159,6 +162,45 @@ void testSaveAndLoadListState() { assertEquals("item3", allLoaded.get(2).value()); } + @Test + @DisplayName( + "Writes should persist when DataSource connections start with auto-commit disabled") + void testWritesPersistWhenAutoCommitDisabled() { + System.out.println("\n=== Test: Writes Persist With Auto-Commit Disabled ==="); + + dataSource = new AutoCommitDisabledDataSource(createH2DataSource()); + String schemaName = generateSafeIdentifier("AGENTSCOPE_E2E").toUpperCase(); + String tableName = generateSafeIdentifier("AGENTSCOPE_SESSIONS").toUpperCase(); + createdSchemaName = schemaName; + + initSchemaAndTable(dataSource, schemaName, tableName); + MysqlSession session = new MysqlSession(dataSource, schemaName, tableName, false); + SessionKey sessionKey = SimpleSessionKey.of("auto_commit_off_" + UUID.randomUUID()); + + session.save(sessionKey, "moduleA", new TestState("hello", 1)); + Optional singleState = session.get(sessionKey, "moduleA", TestState.class); + assertTrue(singleState.isPresent()); + assertEquals("hello", singleState.get().value()); + + List initialStates = + List.of(new TestState("item1", 1), new TestState("item2", 2)); + session.save(sessionKey, "stateList", initialStates); + + List appendedStates = + List.of( + new TestState("item1", 1), + new TestState("item2", 2), + new TestState("item3", 3)); + session.save(sessionKey, "stateList", appendedStates); + + List loadedStates = session.getList(sessionKey, "stateList", TestState.class); + assertEquals(3, loadedStates.size()); + assertEquals("item3", loadedStates.get(2).value()); + + session.delete(sessionKey); + assertFalse(session.exists(sessionKey)); + } + @Test @DisplayName("Session does not exist should return false") void testSessionNotExists() { @@ -208,6 +250,64 @@ void testCreateIfNotExistFalseFailsWhenMissing() { () -> new MysqlSession(dataSource, schemaName, tableName, false)); } + private static final class AutoCommitDisabledDataSource implements DataSource { + + private final DataSource delegate; + + private AutoCommitDisabledDataSource(DataSource delegate) { + this.delegate = delegate; + } + + @Override + public Connection getConnection() throws SQLException { + Connection connection = delegate.getConnection(); + connection.setAutoCommit(false); + return connection; + } + + @Override + public Connection getConnection(String username, String password) throws SQLException { + Connection connection = delegate.getConnection(username, password); + connection.setAutoCommit(false); + return connection; + } + + @Override + public T unwrap(Class iface) throws SQLException { + return delegate.unwrap(iface); + } + + @Override + public boolean isWrapperFor(Class iface) throws SQLException { + return delegate.isWrapperFor(iface); + } + + @Override + public PrintWriter getLogWriter() throws SQLException { + return delegate.getLogWriter(); + } + + @Override + public void setLogWriter(PrintWriter out) throws SQLException { + delegate.setLogWriter(out); + } + + @Override + public void setLoginTimeout(int seconds) throws SQLException { + delegate.setLoginTimeout(seconds); + } + + @Override + public int getLoginTimeout() throws SQLException { + return delegate.getLoginTimeout(); + } + + @Override + public Logger getParentLogger() throws SQLFeatureNotSupportedException { + return delegate.getParentLogger(); + } + } + private static DataSource createH2DataSource() { String dbName = "mysql_session_e2e_" + UUID.randomUUID().toString().replace("-", ""); JdbcDataSource ds = new JdbcDataSource();