diff --git a/zstd/src/main/java/io/github/dfa1/zstd/ZstdCompressStream.java b/zstd/src/main/java/io/github/dfa1/zstd/ZstdCompressStream.java index 44b028c..75fec0c 100644 --- a/zstd/src/main/java/io/github/dfa1/zstd/ZstdCompressStream.java +++ b/zstd/src/main/java/io/github/dfa1/zstd/ZstdCompressStream.java @@ -41,39 +41,50 @@ public ZstdCompressStream() { /// /// @param level the compression level public ZstdCompressStream(int level) { - super(create(level, null)); + this(level, null); } /// Creates a streaming compressor at `level` using `dictionary`. /// /// @param level the compression level - /// @param dictionary the dictionary to compress against + /// @param dictionary the dictionary to compress against, or `null` for none public ZstdCompressStream(int level, ZstdDictionary dictionary) { - super(create(level, dictionary)); + // Own the context first, so any failure setting it up is cleaned up by + // close() — one release path, no leak on a half-built stream. + super(createCctx()); + try { + Zstd.call(() -> (long) Bindings.CCTX_SET_PARAMETER.invokeExact( + ptr(), ZstdCompressParameter.COMPRESSION_LEVEL.value(), level)); + if (dictionary != null) { + loadDictionary(dictionary); + } + } catch (Throwable t) { + close(); + throw rethrow(t); + } } - private static MemorySegment create(int level, ZstdDictionary dictionary) { + private static MemorySegment createCctx() { try { MemorySegment cctx = (MemorySegment) Bindings.CREATE_CCTX.invokeExact(); if (MemorySegment.NULL.equals(cctx)) { throw new ZstdException("ZSTD_createCCtx returned NULL"); } - Zstd.call(() -> (long) Bindings.CCTX_SET_PARAMETER.invokeExact( - cctx, ZstdCompressParameter.COMPRESSION_LEVEL.value(), level)); - if (dictionary != null) { - try (Arena staging = Arena.ofConfined()) { - byte[] raw = dictionary.raw(); - MemorySegment d = Zstd.copyIn(staging, raw); - Zstd.call(() -> (long) Bindings.CCTX_LOAD_DICTIONARY.invokeExact( - cctx, d, (long) raw.length)); - } - } return cctx; } catch (Throwable t) { throw rethrow(t); } } + private void loadDictionary(ZstdDictionary dictionary) { + try (Arena staging = Arena.ofConfined()) { + byte[] raw = dictionary.raw(); + MemorySegment d = Zstd.copyIn(staging, raw); + Zstd.call(() -> (long) Bindings.CCTX_LOAD_DICTIONARY.invokeExact( + ptr(), d, (long) raw.length)); + } + } + /// Compresses as much of `src` as fits into `dst` in one step. /// /// Advance the source by [ZstdStreamResult#bytesConsumed()] and write out diff --git a/zstd/src/main/java/io/github/dfa1/zstd/ZstdDecompressStream.java b/zstd/src/main/java/io/github/dfa1/zstd/ZstdDecompressStream.java index 6f8f482..c42aa2f 100644 --- a/zstd/src/main/java/io/github/dfa1/zstd/ZstdDecompressStream.java +++ b/zstd/src/main/java/io/github/dfa1/zstd/ZstdDecompressStream.java @@ -18,36 +18,47 @@ public final class ZstdDecompressStream extends NativeObject { /// Creates a streaming decompressor. public ZstdDecompressStream() { - super(create(null)); + this(null); } /// Creates a streaming decompressor for frames built with `dictionary`. /// - /// @param dictionary the dictionary the frames were compressed against + /// @param dictionary the dictionary the frames were compressed against, or `null` for none public ZstdDecompressStream(ZstdDictionary dictionary) { - super(create(dictionary)); + // Own the context first, so any failure setting it up is cleaned up by + // close() — one release path, no leak on a half-built stream. + super(createDctx()); + try { + if (dictionary != null) { + loadDictionary(dictionary); + } + } catch (Throwable t) { + close(); + throw rethrow(t); + } } - private static MemorySegment create(ZstdDictionary dictionary) { + private static MemorySegment createDctx() { try { MemorySegment dctx = (MemorySegment) Bindings.CREATE_DCTX.invokeExact(); if (MemorySegment.NULL.equals(dctx)) { throw new ZstdException("ZSTD_createDCtx returned NULL"); } - if (dictionary != null) { - try (Arena staging = Arena.ofConfined()) { - byte[] raw = dictionary.raw(); - MemorySegment d = Zstd.copyIn(staging, raw); - Zstd.call(() -> (long) Bindings.DCTX_LOAD_DICTIONARY.invokeExact( - dctx, d, (long) raw.length)); - } - } return dctx; } catch (Throwable t) { throw rethrow(t); } } + private void loadDictionary(ZstdDictionary dictionary) { + try (Arena staging = Arena.ofConfined()) { + byte[] raw = dictionary.raw(); + MemorySegment d = Zstd.copyIn(staging, raw); + Zstd.call(() -> (long) Bindings.DCTX_LOAD_DICTIONARY.invokeExact( + ptr(), d, (long) raw.length)); + } + } + /// Decompresses as much of `src` as fits into `dst` in one step. /// /// Advance the source by [ZstdStreamResult#bytesConsumed()] and read out diff --git a/zstd/src/main/java/io/github/dfa1/zstd/ZstdInputStream.java b/zstd/src/main/java/io/github/dfa1/zstd/ZstdInputStream.java index e16ab4d..1380875 100644 --- a/zstd/src/main/java/io/github/dfa1/zstd/ZstdInputStream.java +++ b/zstd/src/main/java/io/github/dfa1/zstd/ZstdInputStream.java @@ -57,24 +57,39 @@ public ZstdInputStream(InputStream in) { /// @param dictionary the dictionary the frame was compressed against, or `null` for none public ZstdInputStream(InputStream in, ZstdDictionary dictionary) { this.in = in; + MemorySegment d = null; try { - this.dctx = (MemorySegment) Bindings.CREATE_DCTX.invokeExact(); - if (MemorySegment.NULL.equals(dctx)) { + d = (MemorySegment) Bindings.CREATE_DCTX.invokeExact(); + if (MemorySegment.NULL.equals(d)) { throw new ZstdException("ZSTD_createDCtx returned NULL"); } + this.dctx = d; if (dictionary != null) { loadDictionary(dictionary); } this.inCap = (long) Bindings.DSTREAM_IN_SIZE.invokeExact(); this.outCap = (long) Bindings.DSTREAM_OUT_SIZE.invokeExact(); + this.inSeg = arena.allocate(inCap); + this.outSeg = arena.allocate(outCap); + this.feed = new byte[Math.toIntExact(inCap)]; + this.hold = new byte[Math.toIntExact(outCap)]; } catch (Throwable t) { + // Free the context if it was created, then the arena, so a failed + // constructor leaks neither the native dctx nor the arena buffers. + if (d != null && !MemorySegment.NULL.equals(d)) { + freeDctx(d); + } arena.close(); throw rethrow(t); } - this.inSeg = arena.allocate(inCap); - this.outSeg = arena.allocate(outCap); - this.feed = new byte[Math.toIntExact(inCap)]; - this.hold = new byte[Math.toIntExact(outCap)]; + } + + private static void freeDctx(MemorySegment dctx) { + try { + var _ = (long) Bindings.FREE_DCTX.invokeExact(dctx); + } catch (Throwable _) { + // best-effort free + } } private void loadDictionary(ZstdDictionary dictionary) { @@ -158,11 +173,7 @@ public void close() throws IOException { return; } closed = true; - try { - var _ = (long) Bindings.FREE_DCTX.invokeExact(dctx); - } catch (Throwable _) { - // best-effort free - } + freeDctx(dctx); arena.close(); in.close(); } diff --git a/zstd/src/main/java/io/github/dfa1/zstd/ZstdOutputStream.java b/zstd/src/main/java/io/github/dfa1/zstd/ZstdOutputStream.java index 87d59e5..2a42e30 100644 --- a/zstd/src/main/java/io/github/dfa1/zstd/ZstdOutputStream.java +++ b/zstd/src/main/java/io/github/dfa1/zstd/ZstdOutputStream.java @@ -93,11 +93,13 @@ private void setPledgedSrcSize(long pledgedSrcSize) { /// @param dictionary the dictionary to compress against, or `null` for none public ZstdOutputStream(OutputStream out, int level, ZstdDictionary dictionary) { this.out = out; + MemorySegment c = null; try { - this.cctx = (MemorySegment) Bindings.CREATE_CCTX.invokeExact(); - if (MemorySegment.NULL.equals(cctx)) { + c = (MemorySegment) Bindings.CREATE_CCTX.invokeExact(); + if (MemorySegment.NULL.equals(c)) { throw new ZstdException("ZSTD_createCCtx returned NULL"); } + this.cctx = c; Zstd.call(() -> (long) Bindings.CCTX_SET_PARAMETER.invokeExact( cctx, ZSTD_C_COMPRESSION_LEVEL, level)); if (dictionary != null) { @@ -105,13 +107,26 @@ public ZstdOutputStream(OutputStream out, int level, ZstdDictionary dictionary) } this.inCap = (long) Bindings.CSTREAM_IN_SIZE.invokeExact(); this.outCap = (long) Bindings.CSTREAM_OUT_SIZE.invokeExact(); + this.inSeg = arena.allocate(inCap); + this.outSeg = arena.allocate(outCap); + this.drain = new byte[Math.toIntExact(outCap)]; } catch (Throwable t) { + // Free the context if it was created, then the arena, so a failed + // constructor leaks neither the native cctx nor the arena buffers. + if (c != null && !MemorySegment.NULL.equals(c)) { + freeCctx(c); + } arena.close(); throw rethrow(t); } - this.inSeg = arena.allocate(inCap); - this.outSeg = arena.allocate(outCap); - this.drain = new byte[Math.toIntExact(outCap)]; + } + + private static void freeCctx(MemorySegment cctx) { + try { + var _ = (long) Bindings.FREE_CCTX.invokeExact(cctx); + } catch (Throwable _) { + // best-effort free + } } private void loadDictionary(ZstdDictionary dictionary) { @@ -172,11 +187,7 @@ public void close() throws IOException { out.flush(); } finally { closed = true; - try { - var _ = (long) Bindings.FREE_CCTX.invokeExact(cctx); - } catch (Throwable _) { - // best-effort free - } + freeCctx(cctx); arena.close(); out.close(); }