Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
116 changes: 116 additions & 0 deletions core/src/main/java/io/roastedroot/treesitter/TreeSitterPool.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
package io.roastedroot.treesitter;

import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import java.util.concurrent.ArrayBlockingQueue;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.Consumer;
import java.util.function.Function;

/**
 * A thread-safe pool of {@link TreeSitter} instances for concurrent usage.
 *
 * <p>Each {@link TreeSitter} wraps a Chicory {@code Instance} which is <b>not thread-safe</b>.
 * This pool manages a fixed number of {@code TreeSitter} instances and lends them to callers
 * via {@link #execute(Function)} or {@link #execute(Consumer)}. When all instances are in use,
 * callers block until one becomes available — virtual threads will park without pinning a
 * carrier thread since the pool uses {@link ArrayBlockingQueue} (backed by {@code ReentrantLock}).
 *
 * <p>The underlying Wasm {@code Module} is loaded once (singleton) and shared across all
 * instances — only the Chicory {@code Instance} (memory, globals, stack) is duplicated.
 *
 * <p>Lifecycle: {@link #close()} closes every instance the pool owns, whether or not it is
 * currently lent out. Callers must therefore make sure all submitted work has completed before
 * closing the pool (in the example below the executor is declared last, so it is closed, and
 * its tasks awaited, before the pool). Once closed, {@code execute} fails fast with
 * {@link IllegalStateException}.
 *
 * <p>Usage with virtual threads:
 * <pre>{@code
 * try (var pool = new TreeSitterPool(Runtime.getRuntime().availableProcessors());
 *      var executor = Executors.newVirtualThreadPerTaskExecutor()) {
 *
 *     List<Future<String>> futures = sources.stream()
 *         .map(source -> executor.submit(() ->
 *             pool.execute(ts -> {
 *                 try (var parser = ts.newParser()) {
 *                     parser.setLanguage(Language.JAVA);
 *                     try (var tree = parser.parseString(source)) {
 *                         return tree.rootNode().toSexp();
 *                     }
 *                 }
 *             })
 *         ))
 *         .toList();
 * }
 * }</pre>
 */
public class TreeSitterPool implements AutoCloseable {

    /** Instances that are currently idle and available for borrowing. */
    private final BlockingQueue<TreeSitter> pool;

    /** Every instance owned by this pool, idle or borrowed; used by {@link #size()} and {@link #close()}. */
    private final List<TreeSitter> allInstances;

    /** Set once by the first {@link #close()} call; makes close idempotent and lets execute fail fast. */
    private final AtomicBoolean closed = new AtomicBoolean();

    /**
     * Creates a pool with the given number of {@link TreeSitter} instances.
     *
     * <p>If creating one of the instances fails, the instances created so far are closed before
     * the failure is rethrown, so a failed construction does not leak resources.
     *
     * @param poolSize the number of instances to pre-create (must be &gt; 0)
     * @throws IllegalArgumentException if poolSize is not positive
     */
    public TreeSitterPool(int poolSize) {
        if (poolSize <= 0) {
            throw new IllegalArgumentException("poolSize must be positive, got: " + poolSize);
        }
        List<TreeSitter> created = new ArrayList<>(poolSize);
        try {
            for (int i = 0; i < poolSize; i++) {
                created.add(TreeSitter.create());
            }
        } catch (RuntimeException | Error e) {
            // Nobody else holds a reference to the partially built pool, so release
            // what was already created; secondary failures ride along as suppressed.
            for (TreeSitter ts : created) {
                try {
                    ts.close();
                } catch (RuntimeException closeFailure) {
                    e.addSuppressed(closeFailure);
                }
            }
            throw e;
        }
        this.allInstances = List.copyOf(created);
        // Capacity equals the instance count, so returning an instance with offer() never fails.
        this.pool = new ArrayBlockingQueue<>(poolSize, false, created);
    }

    /**
     * Borrows a {@link TreeSitter} instance, applies the given function, and returns the instance
     * to the pool. Blocks if no instance is currently available.
     *
     * @param action the function to execute with a borrowed instance
     * @param <T> the return type
     * @return the result of the function
     * @throws NullPointerException if {@code action} is null
     * @throws IllegalStateException if the pool has been closed
     * @throws InterruptedException if the calling thread is interrupted while waiting
     */
    public <T> T execute(Function<TreeSitter, T> action) throws InterruptedException {
        // Validate before borrowing so a bad argument never occupies an instance.
        Objects.requireNonNull(action, "action");
        TreeSitter ts = borrow();
        try {
            return action.apply(ts);
        } finally {
            pool.offer(ts);
        }
    }

    /**
     * Borrows a {@link TreeSitter} instance, runs the given consumer, and returns the instance
     * to the pool. Blocks if no instance is currently available.
     *
     * @param action the consumer to execute with a borrowed instance
     * @throws NullPointerException if {@code action} is null
     * @throws IllegalStateException if the pool has been closed
     * @throws InterruptedException if the calling thread is interrupted while waiting
     */
    public void execute(Consumer<TreeSitter> action) throws InterruptedException {
        Objects.requireNonNull(action, "action");
        TreeSitter ts = borrow();
        try {
            action.accept(ts);
        } finally {
            pool.offer(ts);
        }
    }

    /**
     * Returns the total number of instances managed by this pool.
     */
    public int size() {
        return allInstances.size();
    }

    /**
     * Closes all pooled {@link TreeSitter} instances.
     *
     * <p>Calling this method more than once has no further effect. Every instance is closed even
     * if closing an earlier one fails; the first failure is rethrown afterwards with any later
     * failures attached as suppressed exceptions.
     */
    @Override
    public void close() {
        if (!closed.compareAndSet(false, true)) {
            return;
        }
        RuntimeException failure = null;
        for (TreeSitter ts : allInstances) {
            try {
                ts.close();
            } catch (RuntimeException e) {
                if (failure == null) {
                    failure = e;
                } else {
                    failure.addSuppressed(e);
                }
            }
        }
        if (failure != null) {
            throw failure;
        }
    }

    /** Takes an idle instance, blocking if necessary, and refuses to lend once the pool is closed. */
    private TreeSitter borrow() throws InterruptedException {
        if (closed.get()) {
            throw new IllegalStateException("TreeSitterPool is closed");
        }
        TreeSitter ts = pool.take();
        // The pool may have been closed while this thread was waiting for an instance.
        if (closed.get()) {
            pool.offer(ts);
            throw new IllegalStateException("TreeSitterPool is closed");
        }
        return ts;
    }
}
137 changes: 137 additions & 0 deletions core/src/test/java/io/roastedroot/treesitter/TreeSitterPoolTest.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
package io.roastedroot.treesitter;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;

import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.atomic.AtomicInteger;

import org.junit.jupiter.api.Test;

/**
 * Tests for {@link TreeSitterPool}: size reporting, argument validation, and parsing through
 * borrowed instances from one thread and from many virtual threads. Each parsing test also
 * prints its wall-clock duration to standard output.
 */
class TreeSitterPoolTest {

    /** Builds {@code count} minimal Java sources named {@code Task0}, {@code Task1}, and so on. */
    private static List<String> javaSources(int count) {
        List<String> generated = new ArrayList<>();
        int index = 0;
        while (index < count) {
            generated.add("class Task" + index + " { void run() {} }");
            index++;
        }
        return generated;
    }

    /** Whole milliseconds elapsed since the given {@link System#nanoTime()} reading. */
    private static long millisSince(long startNanos) {
        return (System.nanoTime() - startNanos) / 1_000_000;
    }

    @Test
    void poolSizeIsReported() {
        final int requested = 4;
        try (TreeSitterPool pool = new TreeSitterPool(requested)) {
            assertEquals(requested, pool.size());
        }
    }

    @Test
    void rejectsNonPositiveSize() {
        for (int invalid : new int[] {0, -1}) {
            assertThrows(IllegalArgumentException.class, () -> new TreeSitterPool(invalid));
        }
    }

    @Test
    void singleThreadParsing() throws Exception {
        long startedAt = System.nanoTime();
        try (TreeSitterPool pool = new TreeSitterPool(1)) {
            // Block-bodied lambda with a return value selects the Function overload of execute.
            String rendered = pool.execute(ts -> {
                try (var parser = ts.newParser()) {
                    parser.setLanguage(Language.JAVA);
                    try (var tree = parser.parseString("class Foo {}")) {
                        return tree.rootNode().toSexp();
                    }
                }
            });
            assertNotNull(rendered);
            assertTrue(rendered.contains("class_declaration"));
        }
        long elapsedMs = millisSince(startedAt);
        System.out.println("singleThreadParsing: " + elapsedMs + " ms (pool=1, tasks=1)");
    }

    @Test
    void concurrentParsingWithVirtualThreads() throws Exception {
        int poolSize = 4;
        int taskCount = 20;
        List<String> sources = javaSources(taskCount);
        AtomicInteger parsed = new AtomicInteger();

        long startedAt = System.nanoTime();
        // The executor is declared last so it is closed (and its tasks awaited) before the pool.
        try (var pool = new TreeSitterPool(poolSize);
                var executor = Executors.newVirtualThreadPerTaskExecutor()) {

            List<Future<String>> pending = new ArrayList<>();
            for (String source : sources) {
                Future<String> submitted = executor.submit(() ->
                        pool.execute(ts -> {
                            try (var parser = ts.newParser()) {
                                parser.setLanguage(Language.JAVA);
                                try (var tree = parser.parseString(source)) {
                                    parsed.incrementAndGet();
                                    return tree.rootNode().toSexp();
                                }
                            }
                        }));
                pending.add(submitted);
            }

            for (Future<String> result : pending) {
                String rendered = result.get();
                assertNotNull(rendered);
                assertTrue(rendered.contains("class_declaration"));
            }
        }
        long elapsedMs = millisSince(startedAt);
        System.out.println("concurrentParsingWithVirtualThreads: " + elapsedMs + " ms (pool=" + poolSize + ", tasks=" + taskCount + ")");

        assertEquals(taskCount, parsed.get());
    }

    @Test
    void sequentialParsingBaseline() throws Exception {
        int taskCount = 20;
        List<String> sources = javaSources(taskCount);

        long startedAt = System.nanoTime();
        // One unpooled instance and one parser reused for every source.
        try (TreeSitter ts = TreeSitter.create();
                var parser = ts.newParser()) {
            parser.setLanguage(Language.JAVA);
            for (String source : sources) {
                try (var tree = parser.parseString(source)) {
                    assertNotNull(tree.rootNode().toSexp());
                }
            }
        }
        long elapsedMs = millisSince(startedAt);
        System.out.println("sequentialParsingBaseline: " + elapsedMs + " ms (no pool, tasks=" + taskCount + ")");
    }

    @Test
    void voidExecuteWorks() throws Exception {
        AtomicInteger invocations = new AtomicInteger();

        long startedAt = System.nanoTime();
        try (TreeSitterPool pool = new TreeSitterPool(2)) {
            // Block-bodied lambda without a return value selects the Consumer overload of execute.
            pool.execute(ts -> {
                try (var parser = ts.newParser()) {
                    parser.setLanguage(Language.JSON);
                    try (var tree = parser.parseString("{\"key\": 1}")) {
                        assertNotNull(tree.rootNode());
                        invocations.incrementAndGet();
                    }
                }
            });
        }
        long elapsedMs = millisSince(startedAt);
        System.out.println("voidExecuteWorks: " + elapsedMs + " ms (pool=2, tasks=1)");

        assertEquals(1, invocations.get());
    }
}
17 changes: 16 additions & 1 deletion pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,10 @@
<properties>
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
<project.reporting.outputEncoding>UTF-8</project.reporting.outputEncoding>
<maven.compiler.release>17</maven.compiler.release>
<maven.compiler.release>21</maven.compiler.release>

<chicory.version>1.7.5</chicory.version>
<endive.version>latest</endive.version>
<junit.version>5.11.4</junit.version>
<jackson.version>2.17.2</jackson.version>

Expand All @@ -28,6 +29,20 @@
<build-helper-maven-plugin.version>3.6.1</build-helper-maven-plugin.version>
</properties>

<dependencyManagement>
<dependencies>
<!--
<dependency>
<groupId>run.endive</groupId>
<artifactId>bom</artifactId>
<version>${endive.version}</version>
<type>pom</type>
<scope>import</scope>
</dependency>
-->
</dependencies>
</dependencyManagement>

<build>
<pluginManagement>
<plugins>
Expand Down