Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,12 @@ import java.sql.Struct
import java.lang.invoke.MethodHandles
import java.util.Properties
import java.util.concurrent.Executor
import java.util.concurrent.locks.ReentrantLock
import javax.annotation.concurrent.NotThreadSafe
import kotlin.concurrent.withLock
import org.slf4j.Logger
import org.slf4j.LoggerFactory
import javax.annotation.concurrent.GuardedBy

private const val MAX_POOLED_STATEMENTS = 32

Expand Down Expand Up @@ -79,7 +82,9 @@ internal class JdbcConnection(
private var networkTimeout = 0
private val holdability = ResultSet.CLOSE_CURSORS_AT_COMMIT
private val warnings = mutableListOf<SQLWarning>()
@GuardedBy("poolLock")
private val preparedStatementPool = LinkedHashMap<String, JdbcPreparedStatement>()
private val poolLock = ReentrantLock()

private val _metaData by lazy { JdbcDatabaseMetaData(this, database, connectionURL) }

Expand Down Expand Up @@ -132,11 +137,15 @@ internal class JdbcConnection(
checkClosed()
checkResultSetType(resultSetType)
checkResultSetConcurrency(resultSetConcurrency)
preparedStatementPool.remove(sql)?.let {
it.reopen()
return it
val pooled = poolLock.withLock {
if (closed) { null } else { preparedStatementPool.remove(sql) }
}
return if (pooled != null) {
pooled.reopen()
pooled
} else {
JdbcPreparedStatement(this, database, sql, resultSetType, resultSetConcurrency, resultSetHoldability)
}
return JdbcPreparedStatement(this, database, sql, resultSetType, resultSetConcurrency, resultSetHoldability)
}

override fun prepareStatement(
Expand Down Expand Up @@ -317,23 +326,34 @@ internal class JdbcConnection(
}

private fun closePreparedStatementPool() {
preparedStatementPool.run {
values.forEachCatching(JdbcPreparedStatement::closePooled)
clear()
val snapshot = poolLock.withLock {
if (preparedStatementPool.isEmpty()) {
return
}
ArrayList(preparedStatementPool.values).also { preparedStatementPool.clear() }
}
snapshot.forEachCatching(JdbcPreparedStatement::closePooled)
}

internal fun returnPreparedStatement(statement: JdbcPreparedStatement): Boolean {
if (closed) {
return false
}
if (preparedStatementPool.size >= MAX_POOLED_STATEMENTS && !preparedStatementPool.containsKey(statement.sql)) {
val eldest = preparedStatementPool.entries.iterator().next()
preparedStatementPool.remove(eldest.key)
runCatching { eldest.value.closePooled() }
var accepted = false
var evicted: JdbcPreparedStatement? = null
if (!closed) {
poolLock.withLock {
if (!closed) {
accepted = true
val containsKey = preparedStatementPool.containsKey(statement.sql)
if (preparedStatementPool.size >= MAX_POOLED_STATEMENTS && !containsKey) {
evicted = preparedStatementPool.entries.iterator().next().also {
preparedStatementPool.remove(it.key)
}.value
}
preparedStatementPool[statement.sql] = statement
}
}
}
preparedStatementPool[statement.sql] = statement
return true
evicted?.let { runCatching(it::closePooled) }
return accepted
}

override fun isClosed(): Boolean = closed
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1369,6 +1369,42 @@ internal class JdbcConnectionTest {
assertNotSame(statements[0], evictedStatement)
}

@Test
fun preparedStatementPoolReturnRejectedAfterCloseDoesNotLeak() {
val sql = "SELECT 1"
val preparedStatement = connection.prepareStatement(sql) as JdbcPreparedStatement
connection.close()
assertFalse(connection.returnPreparedStatement(preparedStatement))
assertFailsWith<SQLException> { connection.prepareStatement(sql) }
}

@Test
fun preparedStatementPoolSurvivesConcurrentReturnsAndClose() {
repeat(8) {
val freshConnection = JdbcConnection(SharedDatabase(mock()), connectionURL, properties)
val statements = (0 until 64).map { i ->
freshConnection.prepareStatement("SELECT $i") as JdbcPreparedStatement
}
val startLatch = CountDownLatch(1)
val doneLatch = CountDownLatch(statements.size + 1)
statements.forEach { stmt ->
thread(start = true, name = "stmt-return") {
startLatch.await(2, TimeUnit.SECONDS)
runCatching { stmt.close() }
doneLatch.countDown()
}
}
thread(start = true, name = "conn-close") {
startLatch.await(2, TimeUnit.SECONDS)
runCatching { freshConnection.close() }
doneLatch.countDown()
}
startLatch.countDown()
assertTrue(doneLatch.await(5, TimeUnit.SECONDS))
assertTrue(freshConnection.isClosed)
}
}

@Test
fun executeQueryOnReadOnlyConnectionDoesNotBeginTransaction() {
val database = mock<SQLDatabase> {
Expand Down
Loading