From ff6e656cd9dc450ab1c085b15c6d0edcc8f95dd5 Mon Sep 17 00:00:00 2001 From: kshackleton1 Date: Sat, 6 Jun 2026 16:10:01 +0100 Subject: [PATCH] Fix close pool race between JdbcConnection.close() and statement return Signed-off-by: kshackleton1 --- .../selekt/jdbc/connection/JdbcConnection.kt | 52 +++++++++++++------ .../jdbc/connection/JdbcConnectionTest.kt | 36 +++++++++++++ 2 files changed, 72 insertions(+), 16 deletions(-) diff --git a/selekt-jdbc/src/main/kotlin/com/bloomberg/selekt/jdbc/connection/JdbcConnection.kt b/selekt-jdbc/src/main/kotlin/com/bloomberg/selekt/jdbc/connection/JdbcConnection.kt index 41f3edee4e..e25dc62000 100644 --- a/selekt-jdbc/src/main/kotlin/com/bloomberg/selekt/jdbc/connection/JdbcConnection.kt +++ b/selekt-jdbc/src/main/kotlin/com/bloomberg/selekt/jdbc/connection/JdbcConnection.kt @@ -43,9 +43,12 @@ import java.sql.Struct import java.lang.invoke.MethodHandles import java.util.Properties import java.util.concurrent.Executor +import java.util.concurrent.locks.ReentrantLock import javax.annotation.concurrent.NotThreadSafe +import kotlin.concurrent.withLock import org.slf4j.Logger import org.slf4j.LoggerFactory +import javax.annotation.concurrent.GuardedBy private const val MAX_POOLED_STATEMENTS = 32 @@ -79,7 +82,9 @@ internal class JdbcConnection( private var networkTimeout = 0 private val holdability = ResultSet.CLOSE_CURSORS_AT_COMMIT private val warnings = mutableListOf() + @GuardedBy("poolLock") private val preparedStatementPool = LinkedHashMap() + private val poolLock = ReentrantLock() private val _metaData by lazy { JdbcDatabaseMetaData(this, database, connectionURL) } @@ -132,11 +137,15 @@ internal class JdbcConnection( checkClosed() checkResultSetType(resultSetType) checkResultSetConcurrency(resultSetConcurrency) - preparedStatementPool.remove(sql)?.let { - it.reopen() - return it + val pooled = poolLock.withLock { + if (closed) { null } else { preparedStatementPool.remove(sql) } + } + return if (pooled != null) { + pooled.reopen() + pooled + } else { + JdbcPreparedStatement(this, database, sql, resultSetType, resultSetConcurrency, resultSetHoldability) } - return JdbcPreparedStatement(this, database, sql, resultSetType, resultSetConcurrency, resultSetHoldability) } override fun prepareStatement( @@ -317,23 +326,34 @@ internal class JdbcConnection( } private fun closePreparedStatementPool() { - preparedStatementPool.run { - values.forEachCatching(JdbcPreparedStatement::closePooled) - clear() + val snapshot = poolLock.withLock { + if (preparedStatementPool.isEmpty()) { + return + } + ArrayList(preparedStatementPool.values).also { preparedStatementPool.clear() } } + snapshot.forEachCatching(JdbcPreparedStatement::closePooled) } internal fun returnPreparedStatement(statement: JdbcPreparedStatement): Boolean { - if (closed) { - return false - } - if (preparedStatementPool.size >= MAX_POOLED_STATEMENTS && !preparedStatementPool.containsKey(statement.sql)) { - val eldest = preparedStatementPool.entries.iterator().next() - preparedStatementPool.remove(eldest.key) - runCatching { eldest.value.closePooled() } + var accepted = false + var evicted: JdbcPreparedStatement? = null + if (!closed) { + poolLock.withLock { + if (!closed) { + accepted = true + val containsKey = preparedStatementPool.containsKey(statement.sql) + if (preparedStatementPool.size >= MAX_POOLED_STATEMENTS && !containsKey) { + evicted = preparedStatementPool.entries.iterator().next().also { + preparedStatementPool.remove(it.key) + }.value + } + preparedStatementPool[statement.sql] = statement + } + } } - preparedStatementPool[statement.sql] = statement - return true + evicted?.let { runCatching(it::closePooled) } + return accepted } override fun isClosed(): Boolean = closed diff --git a/selekt-jdbc/src/test/kotlin/com/bloomberg/selekt/jdbc/connection/JdbcConnectionTest.kt b/selekt-jdbc/src/test/kotlin/com/bloomberg/selekt/jdbc/connection/JdbcConnectionTest.kt index 8b92abe7f7..853fe83801 100644 --- a/selekt-jdbc/src/test/kotlin/com/bloomberg/selekt/jdbc/connection/JdbcConnectionTest.kt +++ b/selekt-jdbc/src/test/kotlin/com/bloomberg/selekt/jdbc/connection/JdbcConnectionTest.kt @@ -1369,6 +1369,42 @@ internal class JdbcConnectionTest { assertNotSame(statements[0], evictedStatement) } + @Test + fun preparedStatementPoolReturnRejectedAfterCloseDoesNotLeak() { + val sql = "SELECT 1" + val preparedStatement = connection.prepareStatement(sql) as JdbcPreparedStatement + connection.close() + assertFalse(connection.returnPreparedStatement(preparedStatement)) + assertFailsWith { connection.prepareStatement(sql) } + } + + @Test + fun preparedStatementPoolSurvivesConcurrentReturnsAndClose() { + repeat(8) { + val freshConnection = JdbcConnection(SharedDatabase(mock()), connectionURL, properties) + val statements = (0 until 64).map { i -> + freshConnection.prepareStatement("SELECT $i") as JdbcPreparedStatement + } + val startLatch = CountDownLatch(1) + val doneLatch = CountDownLatch(statements.size + 1) + statements.forEach { stmt -> + thread(start = true, name = "stmt-return") { + startLatch.await(2, TimeUnit.SECONDS) + runCatching { stmt.close() } + doneLatch.countDown() + } + } + thread(start = true, name = "conn-close") { + startLatch.await(2, TimeUnit.SECONDS) + runCatching { freshConnection.close() } + doneLatch.countDown() + } + startLatch.countDown() + assertTrue(doneLatch.await(5, TimeUnit.SECONDS)) + assertTrue(freshConnection.isClosed) + } + } + @Test fun executeQueryOnReadOnlyConnectionDoesNotBeginTransaction() { val database = mock {