Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -99,8 +99,8 @@ private void loadDictionary(ZstdDictionary dictionary) {
public ZstdStreamResult compress(MemorySegment dst, MemorySegment src, ZstdEndDirective directive) {
NativeCall.requireNative(dst, "dst");
NativeCall.requireNative(src, "src");
in.set(src, src.byteSize(), 0);
out.set(dst, dst.byteSize(), 0);
in.set(src, src.byteSize());
out.set(dst, dst.byteSize());
long remaining = NativeCall.checkReturnValue(() -> (long) Bindings.COMPRESS_STREAM2.invokeExact(
ptr(), out.segment(), in.segment(), directive.value()));
return new ZstdStreamResult(in.pos(), out.pos(), remaining);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,8 @@ private void loadDictionary(ZstdDictionary dictionary) {
/// @return how much was consumed and produced, and the remaining hint
/// @throws ZstdException if the frame is invalid
public ZstdStreamResult decompress(MemorySegment dst, MemorySegment src) {
in.set(src, src.byteSize(), 0);
out.set(dst, dst.byteSize(), 0);
in.set(src, src.byteSize());
out.set(dst, dst.byteSize());
long remaining = NativeCall.checkReturnValue(() -> (long) Bindings.DECOMPRESS_STREAM.invokeExact(
ptr(), out.segment(), in.segment()));
return new ZstdStreamResult(in.pos(), out.pos(), remaining);
Expand Down
3 changes: 3 additions & 0 deletions zstd/src/main/java/io/github/dfa1/zstd/ZstdException.java
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
package io.github.dfa1.zstd;

import java.io.Serial;

/// Thrown when a zstd native call reports an error.
///
/// Unchecked: zstd errors on valid use of this API indicate either corrupt
/// input or a programming error (e.g. an undersized destination buffer), not a
/// recoverable I/O condition.
public final class ZstdException extends RuntimeException {

@Serial
private static final long serialVersionUID = 1L;

/// The zstd error category for this failure.
Expand Down
7 changes: 3 additions & 4 deletions zstd/src/main/java/io/github/dfa1/zstd/ZstdInputStream.java
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@ public final class ZstdInputStream extends InputStream {
private final MemorySegment dctx;
private final MemorySegment inSeg;
private final MemorySegment outSeg;
private final long inCap;
private final long outCap;
private final ZstdStreamBuffer inBuf = new ZstdStreamBuffer(arena);
private final ZstdStreamBuffer outBufView = new ZstdStreamBuffer(arena);
Expand Down Expand Up @@ -67,7 +66,7 @@ public ZstdInputStream(InputStream in, ZstdDictionary dictionary) {
if (dictionary != null) {
loadDictionary(dictionary);
}
this.inCap = (long) Bindings.DSTREAM_IN_SIZE.invokeExact();
long inCap = (long) Bindings.DSTREAM_IN_SIZE.invokeExact();
this.outCap = (long) Bindings.DSTREAM_OUT_SIZE.invokeExact();
this.inSeg = arena.allocate(inCap);
this.outSeg = arena.allocate(outCap);
Expand Down Expand Up @@ -140,9 +139,9 @@ private boolean produce() throws IOException {
return false;
}
MemorySegment.copy(feed, 0, inSeg, JAVA_BYTE, 0, r);
inBuf.set(inSeg, r, 0);
inBuf.set(inSeg, r);
}
outBufView.set(outSeg, outCap, 0);
outBufView.set(outSeg, outCap);
lastHint = NativeCall.checkReturnValue(() -> (long) Bindings.DECOMPRESS_STREAM.invokeExact(
dctx, outBufView.segment(), inBuf.segment()));
int produced = Math.toIntExact(outBufView.pos());
Expand Down
8 changes: 4 additions & 4 deletions zstd/src/main/java/io/github/dfa1/zstd/ZstdOutputStream.java
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ public void write(byte[] b, int off, int len) throws IOException {
while (remaining > 0) {
int chunk = (int) Math.min(remaining, inCap);
MemorySegment.copy(b, pos, inSeg, JAVA_BYTE, 0, chunk);
in.set(inSeg, chunk, 0);
in.set(inSeg, chunk);
do {
drainOutput(ZSTD_E_CONTINUE);
} while (in.pos() < chunk);
Expand All @@ -165,7 +165,7 @@ public void write(byte[] b, int off, int len) throws IOException {
@Override
public void flush() throws IOException {
ensureOpen();
in.set(inSeg, 0, 0);
in.set(inSeg, 0);
long remainingHint;
do {
remainingHint = drainOutput(ZSTD_E_FLUSH);
Expand All @@ -179,7 +179,7 @@ public void close() throws IOException {
return;
}
try {
in.set(inSeg, 0, 0);
in.set(inSeg, 0);
long remainingHint;
do {
remainingHint = drainOutput(ZSTD_E_END);
Expand All @@ -196,7 +196,7 @@ public void close() throws IOException {
/// Runs one compressStream2 call and writes whatever it produced to `out`.
/// Returns the zstd "remaining" hint (0 means the directive is fully flushed).
private long drainOutput(int directive) throws IOException {
outBuf.set(outSeg, outCap, 0);
outBuf.set(outSeg, outCap);
long remainingHint = NativeCall.checkReturnValue(() -> (long) Bindings.COMPRESS_STREAM2.invokeExact(
cctx, outBuf.segment(), in.segment(), directive));
int produced = Math.toIntExact(outBuf.pos());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,9 @@ public record ZstdSkippableContent(byte[] content, int magicVariant) {
/// @return `true` if `o` is a [ZstdSkippableContent] with equal content bytes and variant
@Override
public boolean equals(Object o) {
return o instanceof ZstdSkippableContent other
&& magicVariant == other.magicVariant
&& Arrays.equals(content, other.content);
return o instanceof ZstdSkippableContent(byte[] otherContent, int otherVariant)
&& magicVariant == otherVariant
&& Arrays.equals(content, otherContent);
}

/// Hash code consistent with [#equals(Object)], derived from the content bytes
Expand All @@ -35,6 +35,7 @@ public int hashCode() {
///
/// @return a string with the content length and magic variant
@Override
@SuppressWarnings("NullableProblems") // toString never returns null; we just don't pull in JB @NotNull
public String toString() {
return "ZstdSkippableContent[content=" + content.length + " bytes, magicVariant=" + magicVariant + "]";
}
Expand Down
5 changes: 3 additions & 2 deletions zstd/src/main/java/io/github/dfa1/zstd/ZstdStreamBuffer.java
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,11 @@ MemorySegment segment() {
return struct;
}

void set(MemorySegment buffer, long size, long pos) {
/// Points the buffer at `buffer` with the given size and a fresh position of 0.
void set(MemorySegment buffer, long size) {
struct.set(ADDRESS, OFF_PTR, buffer);
struct.set(JAVA_LONG, OFF_SIZE, size);
struct.set(JAVA_LONG, OFF_POS, pos);
struct.set(JAVA_LONG, OFF_POS, 0L);
}

long size() {
Expand Down
34 changes: 34 additions & 0 deletions zstd/src/test/java/io/github/dfa1/zstd/ZstdFrameTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
import org.junit.jupiter.api.Test;

import java.io.ByteArrayOutputStream;
import java.lang.foreign.Arena;
import java.lang.foreign.MemorySegment;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.List;
Expand Down Expand Up @@ -188,4 +190,36 @@ private ZstdDictionary trainDict() {
return ZstdDictionary.train(samples, 8 * 1024);
}
}

@Nested
class SegmentOverloads {

@Test
void mirrorTheByteArrayOverloadsForANativeFrame() {
// Given a frame as both a byte[] and the same bytes in a native segment
byte[] frame = Zstd.compress(PAYLOAD);
try (Arena arena = Arena.ofConfined()) {
MemorySegment seg = Zstd.copyIn(arena, frame);

// When inspected through the zero-copy MemorySegment overloads
// Then each agrees with its byte[] counterpart
assertThat(ZstdFrame.isZstdFrame(seg)).isEqualTo(ZstdFrame.isZstdFrame(frame));
assertThat(ZstdFrame.compressedSize(seg)).isEqualTo(ZstdFrame.compressedSize(frame));
assertThat(ZstdFrame.decompressedBound(seg)).isEqualTo(ZstdFrame.decompressedBound(frame));
assertThat(ZstdFrame.dictId(seg)).isEqualTo(ZstdFrame.dictId(frame));
}
}

@Test
void recognisesASkippableNativeFrame() {
// Given a skippable frame in a native segment
byte[] frame = ZstdFrame.writeSkippableFrame("meta".getBytes(StandardCharsets.UTF_8), 5);
try (Arena arena = Arena.ofConfined()) {
MemorySegment seg = Zstd.copyIn(arena, frame);

// When tested through the MemorySegment overload / Then it is skippable
assertThat(ZstdFrame.isSkippableFrame(seg)).isTrue();
}
}
}
}
Loading