diff --git a/core/src/main/java/io/roastedroot/treesitter/TreeSitterPool.java b/core/src/main/java/io/roastedroot/treesitter/TreeSitterPool.java
new file mode 100644
index 0000000..5f8d6cb
--- /dev/null
+++ b/core/src/main/java/io/roastedroot/treesitter/TreeSitterPool.java
@@ -0,0 +1,116 @@
+package io.roastedroot.treesitter;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.concurrent.ArrayBlockingQueue;
+import java.util.concurrent.BlockingQueue;
+import java.util.function.Consumer;
+import java.util.function.Function;
+
+/**
+ * A thread-safe pool of {@link TreeSitter} instances for concurrent usage.
+ *
+ * <p>Each {@link TreeSitter} wraps a Chicory {@code Instance} which is not thread-safe.
+ * This pool manages a fixed number of {@code TreeSitter} instances and lends them to callers
+ * via {@link #execute(Function)} or {@link #execute(Consumer)}. When all instances are in use,
+ * callers block until one becomes available — virtual threads will park without pinning a
+ * carrier thread since the pool uses {@link ArrayBlockingQueue} (backed by {@code ReentrantLock}).
+ *
+ * <p>The underlying Wasm {@code Module} is loaded once (singleton) and shared across all
+ * instances — only the Chicory {@code Instance} (memory, globals, stack) is duplicated.
+ *
+ * <p>Usage with virtual threads:
+ * <pre>{@code
+ * try (var pool = new TreeSitterPool(Runtime.getRuntime().availableProcessors());
+ *      var executor = Executors.newVirtualThreadPerTaskExecutor()) {
+ *
+ *     List<Future<String>> futures = files.stream()
+ *         .map(file -> executor.submit(() ->
+ *             pool.execute(ts -> {
+ *                 try (var parser = ts.newParser()) {
+ *                     parser.setLanguage(Language.JAVA);
+ *                     try (var tree = parser.parseString(source)) {
+ *                         return tree.rootNode().toSexp();
+ *                     }
+ *                 }
+ *             })
+ *         ))
+ *         .toList();
+ * }
+ * }</pre>
+ */ +public class TreeSitterPool implements AutoCloseable { + + private final BlockingQueue pool; + private final List allInstances; + + /** + * Creates a pool with the given number of {@link TreeSitter} instances. + * + * @param poolSize the number of instances to pre-create (must be > 0) + * @throws IllegalArgumentException if poolSize is not positive + */ + public TreeSitterPool(int poolSize) { + if (poolSize <= 0) { + throw new IllegalArgumentException("poolSize must be positive, got: " + poolSize); + } + this.allInstances = new ArrayList<>(poolSize); + this.pool = new ArrayBlockingQueue<>(poolSize); + for (int i = 0; i < poolSize; i++) { + TreeSitter ts = TreeSitter.create(); + allInstances.add(ts); + pool.offer(ts); + } + } + + /** + * Borrows a {@link TreeSitter} instance, applies the given function, and returns the instance + * to the pool. Blocks if no instance is currently available. + * + * @param action the function to execute with a borrowed instance + * @param the return type + * @return the result of the function + * @throws InterruptedException if the calling thread is interrupted while waiting + */ + public T execute(Function action) throws InterruptedException { + TreeSitter ts = pool.take(); + try { + return action.apply(ts); + } finally { + pool.offer(ts); + } + } + + /** + * Borrows a {@link TreeSitter} instance, runs the given consumer, and returns the instance + * to the pool. Blocks if no instance is currently available. + * + * @param action the consumer to execute with a borrowed instance + * @throws InterruptedException if the calling thread is interrupted while waiting + */ + public void execute(Consumer action) throws InterruptedException { + TreeSitter ts = pool.take(); + try { + action.accept(ts); + } finally { + pool.offer(ts); + } + } + + /** + * Returns the total number of instances managed by this pool. + */ + public int size() { + return allInstances.size(); + } + + /** + * Closes all pooled {@link TreeSitter} instances. 
+ */ + @Override + public void close() { + for (TreeSitter ts : allInstances) { + ts.close(); + } + } +} \ No newline at end of file diff --git a/core/src/test/java/io/roastedroot/treesitter/TreeSitterPoolTest.java b/core/src/test/java/io/roastedroot/treesitter/TreeSitterPoolTest.java new file mode 100644 index 0000000..972ad07 --- /dev/null +++ b/core/src/test/java/io/roastedroot/treesitter/TreeSitterPoolTest.java @@ -0,0 +1,137 @@ +package io.roastedroot.treesitter; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.Executors; +import java.util.concurrent.Future; +import java.util.concurrent.atomic.AtomicInteger; + +import org.junit.jupiter.api.Test; + +class TreeSitterPoolTest { + + @Test + void poolSizeIsReported() { + try (var pool = new TreeSitterPool(4)) { + assertEquals(4, pool.size()); + } + } + + @Test + void rejectsNonPositiveSize() { + assertThrows(IllegalArgumentException.class, () -> new TreeSitterPool(0)); + assertThrows(IllegalArgumentException.class, () -> new TreeSitterPool(-1)); + } + + @Test + void singleThreadParsing() throws Exception { + long start = System.nanoTime(); + try (var pool = new TreeSitterPool(1)) { + String sexp = pool.execute(ts -> { + try (var parser = ts.newParser()) { + parser.setLanguage(Language.JAVA); + try (var tree = parser.parseString("class Foo {}")) { + return tree.rootNode().toSexp(); + } + } + }); + assertNotNull(sexp); + assertTrue(sexp.contains("class_declaration")); + } + long elapsedMs = (System.nanoTime() - start) / 1_000_000; + System.out.println("singleThreadParsing: " + elapsedMs + " ms (pool=1, tasks=1)"); + } + + @Test + void concurrentParsingWithVirtualThreads() throws Exception { + int poolSize = 4; + int taskCount = 20; + + 
List sources = new ArrayList<>(); + for (int i = 0; i < taskCount; i++) { + sources.add("class Task" + i + " { void run() {} }"); + } + + AtomicInteger successCount = new AtomicInteger(); + + long start = System.nanoTime(); + try (var pool = new TreeSitterPool(poolSize); + var executor = Executors.newVirtualThreadPerTaskExecutor()) { + + List> futures = new ArrayList<>(); + for (String source : sources) { + futures.add(executor.submit(() -> + pool.execute(ts -> { + try (var parser = ts.newParser()) { + parser.setLanguage(Language.JAVA); + try (var tree = parser.parseString(source)) { + successCount.incrementAndGet(); + return tree.rootNode().toSexp(); + } + } + }) + )); + } + + for (Future future : futures) { + String sexp = future.get(); + assertNotNull(sexp); + assertTrue(sexp.contains("class_declaration")); + } + } + long elapsedMs = (System.nanoTime() - start) / 1_000_000; + System.out.println("concurrentParsingWithVirtualThreads: " + elapsedMs + " ms (pool=" + poolSize + ", tasks=" + taskCount + ")"); + + assertEquals(taskCount, successCount.get()); + } + + @Test + void sequentialParsingBaseline() throws Exception { + int taskCount = 20; + + List sources = new ArrayList<>(); + for (int i = 0; i < taskCount; i++) { + sources.add("class Task" + i + " { void run() {} }"); + } + + long start = System.nanoTime(); + try (var ts = TreeSitter.create(); + var parser = ts.newParser()) { + parser.setLanguage(Language.JAVA); + for (String source : sources) { + try (var tree = parser.parseString(source)) { + assertNotNull(tree.rootNode().toSexp()); + } + } + } + long elapsedMs = (System.nanoTime() - start) / 1_000_000; + System.out.println("sequentialParsingBaseline: " + elapsedMs + " ms (no pool, tasks=" + taskCount + ")"); + } + + @Test + void voidExecuteWorks() throws Exception { + AtomicInteger counter = new AtomicInteger(); + + long start = System.nanoTime(); + try (var pool = new TreeSitterPool(2)) { + pool.execute(ts -> { + try (var parser = ts.newParser()) { + 
parser.setLanguage(Language.JSON); + try (var tree = parser.parseString("{\"key\": 1}")) { + assertNotNull(tree.rootNode()); + counter.incrementAndGet(); + } + } + }); + } + long elapsedMs = (System.nanoTime() - start) / 1_000_000; + System.out.println("voidExecuteWorks: " + elapsedMs + " ms (pool=2, tasks=1)"); + + assertEquals(1, counter.get()); + } +} \ No newline at end of file diff --git a/pom.xml b/pom.xml index ab9266c..eb7e131 100644 --- a/pom.xml +++ b/pom.xml @@ -17,9 +17,10 @@ UTF-8 UTF-8 - 17 + 21 1.7.5 + latest 5.11.4 2.17.2 @@ -28,6 +29,20 @@ 3.6.1 + + + + + +