diff --git a/zstd/src/main/java/io/github/dfa1/zstd/ZstdCompressStream.java b/zstd/src/main/java/io/github/dfa1/zstd/ZstdCompressStream.java index cbb19da..3676d74 100644 --- a/zstd/src/main/java/io/github/dfa1/zstd/ZstdCompressStream.java +++ b/zstd/src/main/java/io/github/dfa1/zstd/ZstdCompressStream.java @@ -99,8 +99,8 @@ private void loadDictionary(ZstdDictionary dictionary) { public ZstdStreamResult compress(MemorySegment dst, MemorySegment src, ZstdEndDirective directive) { NativeCall.requireNative(dst, "dst"); NativeCall.requireNative(src, "src"); - in.set(src, src.byteSize(), 0); - out.set(dst, dst.byteSize(), 0); + in.set(src, src.byteSize()); + out.set(dst, dst.byteSize()); long remaining = NativeCall.checkReturnValue(() -> (long) Bindings.COMPRESS_STREAM2.invokeExact( ptr(), out.segment(), in.segment(), directive.value())); return new ZstdStreamResult(in.pos(), out.pos(), remaining); diff --git a/zstd/src/main/java/io/github/dfa1/zstd/ZstdDecompressStream.java b/zstd/src/main/java/io/github/dfa1/zstd/ZstdDecompressStream.java index cd44173..e0bedd7 100644 --- a/zstd/src/main/java/io/github/dfa1/zstd/ZstdDecompressStream.java +++ b/zstd/src/main/java/io/github/dfa1/zstd/ZstdDecompressStream.java @@ -70,8 +70,8 @@ private void loadDictionary(ZstdDictionary dictionary) { /// @return how much was consumed and produced, and the remaining hint /// @throws ZstdException if the frame is invalid public ZstdStreamResult decompress(MemorySegment dst, MemorySegment src) { - in.set(src, src.byteSize(), 0); - out.set(dst, dst.byteSize(), 0); + in.set(src, src.byteSize()); + out.set(dst, dst.byteSize()); long remaining = NativeCall.checkReturnValue(() -> (long) Bindings.DECOMPRESS_STREAM.invokeExact( ptr(), out.segment(), in.segment())); return new ZstdStreamResult(in.pos(), out.pos(), remaining); diff --git a/zstd/src/main/java/io/github/dfa1/zstd/ZstdException.java b/zstd/src/main/java/io/github/dfa1/zstd/ZstdException.java index 6802298..718514c 100644 --- a/zstd/src/main/java/io/github/dfa1/zstd/ZstdException.java +++ b/zstd/src/main/java/io/github/dfa1/zstd/ZstdException.java @@ -1,5 +1,7 @@ package io.github.dfa1.zstd; +import java.io.Serial; + /// Thrown when a zstd native call reports an error. /// /// Unchecked: zstd errors on valid use of this API indicate either corrupt @@ -7,6 +9,7 @@ /// recoverable I/O condition. public final class ZstdException extends RuntimeException { + @Serial private static final long serialVersionUID = 1L; /// The zstd error category for this failure. diff --git a/zstd/src/main/java/io/github/dfa1/zstd/ZstdInputStream.java b/zstd/src/main/java/io/github/dfa1/zstd/ZstdInputStream.java index 99a3eff..de63fd2 100644 --- a/zstd/src/main/java/io/github/dfa1/zstd/ZstdInputStream.java +++ b/zstd/src/main/java/io/github/dfa1/zstd/ZstdInputStream.java @@ -29,7 +29,6 @@ public final class ZstdInputStream extends InputStream { private final MemorySegment dctx; private final MemorySegment inSeg; private final MemorySegment outSeg; - private final long inCap; private final long outCap; private final ZstdStreamBuffer inBuf = new ZstdStreamBuffer(arena); private final ZstdStreamBuffer outBufView = new ZstdStreamBuffer(arena); @@ -67,7 +66,7 @@ public ZstdInputStream(InputStream in, ZstdDictionary dictionary) { if (dictionary != null) { loadDictionary(dictionary); } - this.inCap = (long) Bindings.DSTREAM_IN_SIZE.invokeExact(); + long inCap = (long) Bindings.DSTREAM_IN_SIZE.invokeExact(); this.outCap = (long) Bindings.DSTREAM_OUT_SIZE.invokeExact(); this.inSeg = arena.allocate(inCap); this.outSeg = arena.allocate(outCap); @@ -140,9 +139,9 @@ private boolean produce() throws IOException { return false; } MemorySegment.copy(feed, 0, inSeg, JAVA_BYTE, 0, r); - inBuf.set(inSeg, r, 0); + inBuf.set(inSeg, r); } - outBufView.set(outSeg, outCap, 0); + outBufView.set(outSeg, outCap); lastHint = NativeCall.checkReturnValue(() -> (long) Bindings.DECOMPRESS_STREAM.invokeExact( dctx, outBufView.segment(), inBuf.segment())); int produced = Math.toIntExact(outBufView.pos()); diff --git a/zstd/src/main/java/io/github/dfa1/zstd/ZstdOutputStream.java b/zstd/src/main/java/io/github/dfa1/zstd/ZstdOutputStream.java index 2f1b89c..f1e03f0 100644 --- a/zstd/src/main/java/io/github/dfa1/zstd/ZstdOutputStream.java +++ b/zstd/src/main/java/io/github/dfa1/zstd/ZstdOutputStream.java @@ -153,7 +153,7 @@ public void write(byte[] b, int off, int len) throws IOException { while (remaining > 0) { int chunk = (int) Math.min(remaining, inCap); MemorySegment.copy(b, pos, inSeg, JAVA_BYTE, 0, chunk); - in.set(inSeg, chunk, 0); + in.set(inSeg, chunk); do { drainOutput(ZSTD_E_CONTINUE); } while (in.pos() < chunk); @@ -165,7 +165,7 @@ public void write(byte[] b, int off, int len) throws IOException { @Override public void flush() throws IOException { ensureOpen(); - in.set(inSeg, 0, 0); + in.set(inSeg, 0); long remainingHint; do { remainingHint = drainOutput(ZSTD_E_FLUSH); @@ -179,7 +179,7 @@ public void close() throws IOException { return; } try { - in.set(inSeg, 0, 0); + in.set(inSeg, 0); long remainingHint; do { remainingHint = drainOutput(ZSTD_E_END); @@ -196,7 +196,7 @@ public void close() throws IOException { /// Runs one compressStream2 call and writes whatever it produced to `out`. /// Returns the zstd "remaining" hint (0 means the directive is fully flushed). private long drainOutput(int directive) throws IOException { - outBuf.set(outSeg, outCap, 0); + outBuf.set(outSeg, outCap); long remainingHint = NativeCall.checkReturnValue(() -> (long) Bindings.COMPRESS_STREAM2.invokeExact( cctx, outBuf.segment(), in.segment(), directive)); int produced = Math.toIntExact(outBuf.pos()); diff --git a/zstd/src/main/java/io/github/dfa1/zstd/ZstdSkippableContent.java b/zstd/src/main/java/io/github/dfa1/zstd/ZstdSkippableContent.java index 80d22f4..9b892a6 100644 --- a/zstd/src/main/java/io/github/dfa1/zstd/ZstdSkippableContent.java +++ b/zstd/src/main/java/io/github/dfa1/zstd/ZstdSkippableContent.java @@ -16,9 +16,9 @@ public record ZstdSkippableContent(byte[] content, int magicVariant) { /// @return `true` if `o` is a [ZstdSkippableContent] with equal content bytes and variant @Override public boolean equals(Object o) { - return o instanceof ZstdSkippableContent other - && magicVariant == other.magicVariant - && Arrays.equals(content, other.content); + return o instanceof ZstdSkippableContent(byte[] otherContent, int otherVariant) + && magicVariant == otherVariant + && Arrays.equals(content, otherContent); } /// Hash code consistent with [#equals(Object)], derived from the content bytes @@ -35,6 +35,7 @@ public int hashCode() { /// /// @return a string with the content length and magic variant @Override + @SuppressWarnings("NullableProblems") // toString never returns null; we just don't pull in JB @NotNull public String toString() { return "ZstdSkippableContent[content=" + content.length + " bytes, magicVariant=" + magicVariant + "]"; } diff --git a/zstd/src/main/java/io/github/dfa1/zstd/ZstdStreamBuffer.java b/zstd/src/main/java/io/github/dfa1/zstd/ZstdStreamBuffer.java index 9e2bb01..c231e91 100644 --- a/zstd/src/main/java/io/github/dfa1/zstd/ZstdStreamBuffer.java +++ b/zstd/src/main/java/io/github/dfa1/zstd/ZstdStreamBuffer.java @@ -26,10 +26,11 @@ MemorySegment segment() { return struct; } - void set(MemorySegment buffer, long size, long pos) { + /// Points the buffer at `buffer` with the given size and a fresh position of 0. + void set(MemorySegment buffer, long size) { struct.set(ADDRESS, OFF_PTR, buffer); struct.set(JAVA_LONG, OFF_SIZE, size); - struct.set(JAVA_LONG, OFF_POS, pos); + struct.set(JAVA_LONG, OFF_POS, 0L); } long size() { diff --git a/zstd/src/test/java/io/github/dfa1/zstd/ZstdFrameTest.java b/zstd/src/test/java/io/github/dfa1/zstd/ZstdFrameTest.java index d3758be..0add02f 100644 --- a/zstd/src/test/java/io/github/dfa1/zstd/ZstdFrameTest.java +++ b/zstd/src/test/java/io/github/dfa1/zstd/ZstdFrameTest.java @@ -5,6 +5,8 @@ import org.junit.jupiter.api.Test; import java.io.ByteArrayOutputStream; +import java.lang.foreign.Arena; +import java.lang.foreign.MemorySegment; import java.nio.charset.StandardCharsets; import java.util.ArrayList; import java.util.List; @@ -188,4 +190,36 @@ private ZstdDictionary trainDict() { return ZstdDictionary.train(samples, 8 * 1024); } } + + @Nested + class SegmentOverloads { + + @Test + void mirrorTheByteArrayOverloadsForANativeFrame() { + // Given a frame as both a byte[] and the same bytes in a native segment + byte[] frame = Zstd.compress(PAYLOAD); + try (Arena arena = Arena.ofConfined()) { + MemorySegment seg = Zstd.copyIn(arena, frame); + + // When inspected through the zero-copy MemorySegment overloads + // Then each agrees with its byte[] counterpart + assertThat(ZstdFrame.isZstdFrame(seg)).isEqualTo(ZstdFrame.isZstdFrame(frame)); + assertThat(ZstdFrame.compressedSize(seg)).isEqualTo(ZstdFrame.compressedSize(frame)); + assertThat(ZstdFrame.decompressedBound(seg)).isEqualTo(ZstdFrame.decompressedBound(frame)); + assertThat(ZstdFrame.dictId(seg)).isEqualTo(ZstdFrame.dictId(frame)); + } + } + + @Test + void recognisesASkippableNativeFrame() { + // Given a skippable frame in a native segment + byte[] frame = ZstdFrame.writeSkippableFrame("meta".getBytes(StandardCharsets.UTF_8), 5); + try (Arena arena = Arena.ofConfined()) { + MemorySegment seg = Zstd.copyIn(arena, frame); + + // When tested through the MemorySegment overload / Then it is skippable + assertThat(ZstdFrame.isSkippableFrame(seg)).isTrue(); + } + } + } }