Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 25 additions & 14 deletions zstd/src/main/java/io/github/dfa1/zstd/ZstdCompressStream.java
Original file line number Diff line number Diff line change
Expand Up @@ -41,39 +41,50 @@ public ZstdCompressStream() {
///
/// @param level the compression level
public ZstdCompressStream(int level) {
super(create(level, null));
this(level, null);
}

/// Creates a streaming compressor at `level` using `dictionary`.
///
/// @param level the compression level
/// @param dictionary the dictionary to compress against
/// @param dictionary the dictionary to compress against, or `null` for none
public ZstdCompressStream(int level, ZstdDictionary dictionary) {
super(create(level, dictionary));
// Own the context first, so any failure setting it up is cleaned up by
// close() — one release path, no leak on a half-built stream.
super(createCctx());
try {
Zstd.call(() -> (long) Bindings.CCTX_SET_PARAMETER.invokeExact(
ptr(), ZstdCompressParameter.COMPRESSION_LEVEL.value(), level));
if (dictionary != null) {
loadDictionary(dictionary);
}
} catch (Throwable t) {
close();
throw rethrow(t);
}
}

private static MemorySegment create(int level, ZstdDictionary dictionary) {
private static MemorySegment createCctx() {
try {
MemorySegment cctx = (MemorySegment) Bindings.CREATE_CCTX.invokeExact();
if (MemorySegment.NULL.equals(cctx)) {
throw new ZstdException("ZSTD_createCCtx returned NULL");
}
Zstd.call(() -> (long) Bindings.CCTX_SET_PARAMETER.invokeExact(
cctx, ZstdCompressParameter.COMPRESSION_LEVEL.value(), level));
if (dictionary != null) {
try (Arena staging = Arena.ofConfined()) {
byte[] raw = dictionary.raw();
MemorySegment d = Zstd.copyIn(staging, raw);
Zstd.call(() -> (long) Bindings.CCTX_LOAD_DICTIONARY.invokeExact(
cctx, d, (long) raw.length));
}
}
return cctx;
} catch (Throwable t) {
throw rethrow(t);
}
}

private void loadDictionary(ZstdDictionary dictionary) {
try (Arena staging = Arena.ofConfined()) {
byte[] raw = dictionary.raw();
MemorySegment d = Zstd.copyIn(staging, raw);
Zstd.call(() -> (long) Bindings.CCTX_LOAD_DICTIONARY.invokeExact(
ptr(), d, (long) raw.length));
}
}

/// Compresses as much of `src` as fits into `dst` in one step.
///
/// Advance the source by [ZstdStreamResult#bytesConsumed()] and write out
Expand Down
35 changes: 23 additions & 12 deletions zstd/src/main/java/io/github/dfa1/zstd/ZstdDecompressStream.java
Original file line number Diff line number Diff line change
Expand Up @@ -18,36 +18,47 @@ public final class ZstdDecompressStream extends NativeObject {

/// Creates a streaming decompressor.
public ZstdDecompressStream() {
super(create(null));
this(null);
}

/// Creates a streaming decompressor for frames built with `dictionary`.
///
/// @param dictionary the dictionary the frames were compressed against
/// @param dictionary the dictionary the frames were compressed against, or `null` for none
public ZstdDecompressStream(ZstdDictionary dictionary) {
super(create(dictionary));
// Own the context first, so any failure setting it up is cleaned up by
// close() — one release path, no leak on a half-built stream.
super(createDctx());
try {
if (dictionary != null) {
loadDictionary(dictionary);
}
} catch (Throwable t) {
close();
throw rethrow(t);
}
}

private static MemorySegment create(ZstdDictionary dictionary) {
private static MemorySegment createDctx() {
try {
MemorySegment dctx = (MemorySegment) Bindings.CREATE_DCTX.invokeExact();
if (MemorySegment.NULL.equals(dctx)) {
throw new ZstdException("ZSTD_createDCtx returned NULL");
}
if (dictionary != null) {
try (Arena staging = Arena.ofConfined()) {
byte[] raw = dictionary.raw();
MemorySegment d = Zstd.copyIn(staging, raw);
Zstd.call(() -> (long) Bindings.DCTX_LOAD_DICTIONARY.invokeExact(
dctx, d, (long) raw.length));
}
}
return dctx;
} catch (Throwable t) {
throw rethrow(t);
}
}

private void loadDictionary(ZstdDictionary dictionary) {
try (Arena staging = Arena.ofConfined()) {
byte[] raw = dictionary.raw();
MemorySegment d = Zstd.copyIn(staging, raw);
Zstd.call(() -> (long) Bindings.DCTX_LOAD_DICTIONARY.invokeExact(
ptr(), d, (long) raw.length));
}
}

/// Decompresses as much of `src` as fits into `dst` in one step.
///
/// Advance the source by [ZstdStreamResult#bytesConsumed()] and read out
Expand Down
33 changes: 22 additions & 11 deletions zstd/src/main/java/io/github/dfa1/zstd/ZstdInputStream.java
Original file line number Diff line number Diff line change
Expand Up @@ -57,24 +57,39 @@ public ZstdInputStream(InputStream in) {
/// @param dictionary the dictionary the frame was compressed against, or `null` for none
public ZstdInputStream(InputStream in, ZstdDictionary dictionary) {
this.in = in;
MemorySegment d = null;
try {
this.dctx = (MemorySegment) Bindings.CREATE_DCTX.invokeExact();
if (MemorySegment.NULL.equals(dctx)) {
d = (MemorySegment) Bindings.CREATE_DCTX.invokeExact();
if (MemorySegment.NULL.equals(d)) {
throw new ZstdException("ZSTD_createDCtx returned NULL");
}
this.dctx = d;
if (dictionary != null) {
loadDictionary(dictionary);
}
this.inCap = (long) Bindings.DSTREAM_IN_SIZE.invokeExact();
this.outCap = (long) Bindings.DSTREAM_OUT_SIZE.invokeExact();
this.inSeg = arena.allocate(inCap);
this.outSeg = arena.allocate(outCap);
this.feed = new byte[Math.toIntExact(inCap)];
this.hold = new byte[Math.toIntExact(outCap)];
} catch (Throwable t) {
// Free the context if it was created, then the arena, so a failed
// constructor leaks neither the native dctx nor the arena buffers.
if (d != null && !MemorySegment.NULL.equals(d)) {
freeDctx(d);
}
arena.close();
throw rethrow(t);
}
this.inSeg = arena.allocate(inCap);
this.outSeg = arena.allocate(outCap);
this.feed = new byte[Math.toIntExact(inCap)];
this.hold = new byte[Math.toIntExact(outCap)];
}

private static void freeDctx(MemorySegment dctx) {
try {
var _ = (long) Bindings.FREE_DCTX.invokeExact(dctx);
} catch (Throwable _) {
// best-effort free
}
}

private void loadDictionary(ZstdDictionary dictionary) {
Expand Down Expand Up @@ -158,11 +173,7 @@ public void close() throws IOException {
return;
}
closed = true;
try {
var _ = (long) Bindings.FREE_DCTX.invokeExact(dctx);
} catch (Throwable _) {
// best-effort free
}
freeDctx(dctx);
arena.close();
in.close();
}
Expand Down
31 changes: 21 additions & 10 deletions zstd/src/main/java/io/github/dfa1/zstd/ZstdOutputStream.java
Original file line number Diff line number Diff line change
Expand Up @@ -93,25 +93,40 @@ private void setPledgedSrcSize(long pledgedSrcSize) {
/// @param dictionary the dictionary to compress against, or `null` for none
public ZstdOutputStream(OutputStream out, int level, ZstdDictionary dictionary) {
this.out = out;
MemorySegment c = null;
try {
this.cctx = (MemorySegment) Bindings.CREATE_CCTX.invokeExact();
if (MemorySegment.NULL.equals(cctx)) {
c = (MemorySegment) Bindings.CREATE_CCTX.invokeExact();
if (MemorySegment.NULL.equals(c)) {
throw new ZstdException("ZSTD_createCCtx returned NULL");
}
this.cctx = c;
Zstd.call(() -> (long) Bindings.CCTX_SET_PARAMETER.invokeExact(
cctx, ZSTD_C_COMPRESSION_LEVEL, level));
if (dictionary != null) {
loadDictionary(dictionary);
}
this.inCap = (long) Bindings.CSTREAM_IN_SIZE.invokeExact();
this.outCap = (long) Bindings.CSTREAM_OUT_SIZE.invokeExact();
this.inSeg = arena.allocate(inCap);
this.outSeg = arena.allocate(outCap);
this.drain = new byte[Math.toIntExact(outCap)];
} catch (Throwable t) {
// Free the context if it was created, then the arena, so a failed
// constructor leaks neither the native cctx nor the arena buffers.
if (c != null && !MemorySegment.NULL.equals(c)) {
freeCctx(c);
}
arena.close();
throw rethrow(t);
}
this.inSeg = arena.allocate(inCap);
this.outSeg = arena.allocate(outCap);
this.drain = new byte[Math.toIntExact(outCap)];
}

private static void freeCctx(MemorySegment cctx) {
try {
var _ = (long) Bindings.FREE_CCTX.invokeExact(cctx);
} catch (Throwable _) {
// best-effort free
}
}

private void loadDictionary(ZstdDictionary dictionary) {
Expand Down Expand Up @@ -172,11 +187,7 @@ public void close() throws IOException {
out.flush();
} finally {
closed = true;
try {
var _ = (long) Bindings.FREE_CCTX.invokeExact(cctx);
} catch (Throwable _) {
// best-effort free
}
freeCctx(cctx);
arena.close();
out.close();
}
Expand Down
Loading