Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions zstd/src/main/java/io/github/dfa1/zstd/Bindings.java
Original file line number Diff line number Diff line change
Expand Up @@ -158,8 +158,11 @@ final class Bindings {
NativeLibrary.lookup("ZSTD_getDictID_fromDDict", FunctionDescriptor.of(JAVA_INT, ADDRESS));

// ZSTD_bounds { size_t error; int lowerBound; int upperBound; } — returned by value
private static final MemoryLayout BOUNDS_LAYOUT =
MemoryLayout.structLayout(JAVA_LONG, JAVA_INT, JAVA_INT);
static final MemoryLayout BOUNDS_LAYOUT =
MemoryLayout.structLayout(
JAVA_LONG.withName("error"),
JAVA_INT.withName("lowerBound"),
JAVA_INT.withName("upperBound"));
// ZSTD_bounds ZSTD_cParam_getBounds(ZSTD_cParameter) / ZSTD_dParam_getBounds(ZSTD_dParameter)
static final MethodHandle CPARAM_GET_BOUNDS =
NativeLibrary.lookup("ZSTD_cParam_getBounds", FunctionDescriptor.of(BOUNDS_LAYOUT, JAVA_INT));
Expand Down
81 changes: 81 additions & 0 deletions zstd/src/main/java/io/github/dfa1/zstd/NativeCall.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
package io.github.dfa1.zstd;

import java.lang.foreign.MemorySegment;
import java.nio.charset.StandardCharsets;

/// Package-private helpers that adapt raw FFM downcalls to the zstd error
/// convention: run a native call, decode a zstd `size_t` error code into a
/// {@link ZstdException}, and guard native-segment arguments. Shared by the
/// binding classes so the conventions live in one place.
final class NativeCall {

/// A native call returning a zstd `size_t` status that may encode an error.
@FunctionalInterface
interface ZstdCall {
long run() throws Throwable;
}

/// Invokes a size-returning zstd call and converts a zstd error code into a
/// {@link ZstdException}.
static long checkReturnValue(ZstdCall c) {
long code;
try {
code = c.run();
} catch (Throwable t) {
throw rethrow(t);
}
if (isError(code)) {
throw new ZstdException(errorName(code), ZstdErrorCode.of(errorCode(code)));
}
return code;
}

static boolean isError(long code) {
try {
return ((int) Bindings.IS_ERROR.invokeExact(code)) != 0;
} catch (Throwable t) {
throw rethrow(t);
}
}

private static int errorCode(long code) {
try {
return (int) Bindings.GET_ERROR_CODE.invokeExact(code);
} catch (Throwable t) {
throw rethrow(t);
}
}

@SuppressWarnings("restricted") // reinterpret needed to read a C string of unknown length
private static String errorName(long code) {
try {
MemorySegment p = (MemorySegment) Bindings.GET_ERROR_NAME.invokeExact(code);
return p.reinterpret(Long.MAX_VALUE).getString(0, StandardCharsets.US_ASCII);
} catch (Throwable t) {
throw rethrow(t);
}
}

/// Guards a zero-copy entry point: the segment handed to zstd must be backed
/// by native (off-heap) memory, since its address is dereferenced in C. Fails
/// fast with a clear message instead of the FFM linker's cryptic error.
static MemorySegment requireNative(MemorySegment seg, String name) {
if (!seg.isNative()) {
throw new IllegalArgumentException(
name + " must be a native (off-heap) MemorySegment; got a heap segment");
}
return seg;
}

/// Rethrows any `Throwable` as if unchecked, laundering the checked
/// `Throwable` that {@link java.lang.invoke.MethodHandle#invokeExact} declares.
/// The shared sink for every binding class's native-call catch blocks.
@SuppressWarnings("unchecked")
static <E extends Throwable> RuntimeException rethrow(Throwable t) throws E {
throw (E) t;
}

private NativeCall() {
// no instances
}
}
92 changes: 15 additions & 77 deletions zstd/src/main/java/io/github/dfa1/zstd/Zstd.java
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ public static byte[] compress(byte[] src, int level) {
MemorySegment in = copyIn(arena, src);
long bound = compressBound(src.length);
MemorySegment out = arena.allocate(bound);
long written = call(() -> (long) Bindings.COMPRESS.invokeExact(
long written = NativeCall.checkReturnValue(() -> (long) Bindings.COMPRESS.invokeExact(
out, bound, in, (long) src.length, level));
return copyOut(out, written);
}
Expand Down Expand Up @@ -77,7 +77,7 @@ public static byte[] decompress(byte[] compressed, int maxSize) {
try (Arena arena = Arena.ofConfined()) {
MemorySegment in = copyIn(arena, compressed);
MemorySegment out = arena.allocate(Math.max(maxSize, 1));
long written = call(() -> (long) Bindings.DECOMPRESS.invokeExact(
long written = NativeCall.checkReturnValue(() -> (long) Bindings.DECOMPRESS.invokeExact(
out, (long) maxSize, in, (long) compressed.length));
return copyOut(out, written);
}
Expand All @@ -91,12 +91,12 @@ public static byte[] decompress(byte[] compressed, int maxSize) {
/// @return the decompressed length in bytes
/// @throws ZstdException if the frame is invalid or does not store its size
public static long decompressedSize(MemorySegment frame) {
requireNative(frame, "frame");
NativeCall.requireNative(frame, "frame");
long size;
try {
size = (long) Bindings.GET_FRAME_CONTENT_SIZE.invokeExact(frame, frame.byteSize());
} catch (Throwable t) {
throw sneaky(t);
throw NativeCall.rethrow(t);
}
if (size == CONTENTSIZE_UNKNOWN) {
throw new ZstdException("decompressed size not stored in frame");
Expand All @@ -116,7 +116,7 @@ public static long compressBound(long srcSize) {
try {
return (long) Bindings.COMPRESS_BOUND.invokeExact(srcSize);
} catch (Throwable t) {
throw sneaky(t);
throw NativeCall.rethrow(t);
}
}

Expand All @@ -126,7 +126,7 @@ private static long frameContentSize(byte[] compressed) {
MemorySegment in = copyIn(arena, compressed);
return (long) Bindings.GET_FRAME_CONTENT_SIZE.invokeExact(in, (long) compressed.length);
} catch (Throwable t) {
throw sneaky(t);
throw NativeCall.rethrow(t);
}
}

Expand All @@ -137,7 +137,7 @@ public static int maxCompressionLevel() {
try {
return (int) Bindings.MAX_C_LEVEL.invokeExact();
} catch (Throwable t) {
throw sneaky(t);
throw NativeCall.rethrow(t);
}
}

Expand All @@ -148,7 +148,7 @@ public static int minCompressionLevel() {
try {
return (int) Bindings.MIN_C_LEVEL.invokeExact();
} catch (Throwable t) {
throw sneaky(t);
throw NativeCall.rethrow(t);
}
}

Expand All @@ -159,7 +159,7 @@ public static int defaultCompressionLevel() {
try {
return (int) Bindings.DEFAULT_C_LEVEL.invokeExact();
} catch (Throwable t) {
throw sneaky(t);
throw NativeCall.rethrow(t);
}
}

Expand All @@ -172,7 +172,7 @@ public static long estimateCompressContextSize(int level) {
try {
return (long) Bindings.ESTIMATE_CCTX_SIZE.invokeExact(level);
} catch (Throwable t) {
throw sneaky(t);
throw NativeCall.rethrow(t);
}
}

Expand All @@ -183,7 +183,7 @@ public static long estimateDecompressContextSize() {
try {
return (long) Bindings.ESTIMATE_DCTX_SIZE.invokeExact();
} catch (Throwable t) {
throw sneaky(t);
throw NativeCall.rethrow(t);
}
}

Expand All @@ -197,7 +197,7 @@ public static long estimateCompressDictSize(long dictSize, int level) {
try {
return (long) Bindings.ESTIMATE_CDICT_SIZE.invokeExact(dictSize, level);
} catch (Throwable t) {
throw sneaky(t);
throw NativeCall.rethrow(t);
}
}

Expand All @@ -210,7 +210,7 @@ public static long estimateDecompressDictSize(long dictSize) {
try {
return (long) Bindings.ESTIMATE_DDICT_SIZE.invokeExact(dictSize, 0); // ZSTD_dlm_byCopy
} catch (Throwable t) {
throw sneaky(t);
throw NativeCall.rethrow(t);
}
}

Expand All @@ -223,69 +223,12 @@ public static String version() {
MemorySegment p = (MemorySegment) Bindings.VERSION_STRING.invokeExact();
return p.reinterpret(Long.MAX_VALUE).getString(0, StandardCharsets.US_ASCII);
} catch (Throwable t) {
throw sneaky(t);
throw NativeCall.rethrow(t);
}
}

// --- package-private helpers shared with the context classes ---

/// A native call returning a zstd `size_t` status that may encode an error.
@FunctionalInterface
interface SizeCall {
long run() throws Throwable;
}

/// Invokes a size-returning zstd call and converts a zstd error code into a
/// {@link ZstdException}.
static long call(SizeCall c) {
long code;
try {
code = c.run();
} catch (Throwable t) {
throw sneaky(t);
}
if (isError(code)) {
throw new ZstdException(errorName(code), ZstdErrorCode.of(errorCode(code)));
}
return code;
}

static boolean isError(long code) {
try {
return ((int) Bindings.IS_ERROR.invokeExact(code)) != 0;
} catch (Throwable t) {
throw sneaky(t);
}
}

private static int errorCode(long code) {
try {
return (int) Bindings.GET_ERROR_CODE.invokeExact(code);
} catch (Throwable t) {
throw sneaky(t);
}
}

@SuppressWarnings("restricted") // reinterpret needed to read a C string of unknown length
private static String errorName(long code) {
try {
MemorySegment p = (MemorySegment) Bindings.GET_ERROR_NAME.invokeExact(code);
return p.reinterpret(Long.MAX_VALUE).getString(0, StandardCharsets.US_ASCII);
} catch (Throwable t) {
throw sneaky(t);
}
}

/// Guards a zero-copy entry point: the segment handed to zstd must be backed
/// by native (off-heap) memory, since its address is dereferenced in C. Fails
/// fast with a clear message instead of the FFM linker's cryptic error.
static MemorySegment requireNative(MemorySegment seg, String name) {
if (!seg.isNative()) {
throw new IllegalArgumentException(
name + " must be a native (off-heap) MemorySegment; got a heap segment");
}
return seg;
}
// Native-call status checking and segment guards live in NativeCall.

static MemorySegment copyIn(Arena arena, byte[] src) {
MemorySegment seg = arena.allocate(Math.max(src.length, 1));
Expand All @@ -299,11 +242,6 @@ static byte[] copyOut(MemorySegment seg, long len) {
return out;
}

@SuppressWarnings("unchecked")
private static <E extends Throwable> RuntimeException sneaky(Throwable t) throws E {
throw (E) t;
}

private Zstd() {
// no instances
}
Expand Down
26 changes: 17 additions & 9 deletions zstd/src/main/java/io/github/dfa1/zstd/ZstdBounds.java
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package io.github.dfa1.zstd;

import java.lang.foreign.Arena;
import java.lang.foreign.MemoryLayout.PathElement;
import java.lang.foreign.MemorySegment;
import java.lang.foreign.SegmentAllocator;
import java.lang.invoke.MethodHandle;
Expand All @@ -15,23 +16,30 @@
/// @param upperBound the largest accepted value, inclusive
public record ZstdBounds(int lowerBound, int upperBound) {

// Offsets into the returned ZSTD_bounds struct, derived from the named layout
// rather than hand-counted, so they track the struct definition.
private static final long ERROR_OFFSET = Bindings.BOUNDS_LAYOUT.byteOffset(PathElement.groupElement("error"));
private static final long LOWER_OFFSET = Bindings.BOUNDS_LAYOUT.byteOffset(PathElement.groupElement("lowerBound"));
private static final long UPPER_OFFSET = Bindings.BOUNDS_LAYOUT.byteOffset(PathElement.groupElement("upperBound"));

/// Calls a `*_getBounds` function (which returns a `ZSTD_bounds` struct by
/// value: `{ size_t error; int lowerBound; int upperBound; }`).
static ZstdBounds query(MethodHandle getBounds, int parameter) {
try (Arena arena = Arena.ofConfined()) {
// getBounds returns a ZSTD_bounds struct by value. For a struct return,
// the FFM linker prepends a SegmentAllocator parameter to the handle:
// it allocates BOUNDS_LAYOUT.byteSize() bytes from that allocator, the
// native call writes the struct there, and the handle returns a segment
// viewing it. Passing the arena makes the struct arena-owned (freed on
// close); the cast satisfies invokeExact's exact-type requirement.
MemorySegment bounds = (MemorySegment) getBounds.invokeExact((SegmentAllocator) arena, parameter);
long error = bounds.get(JAVA_LONG, 0);
if (Zstd.isError(error)) {
long error = bounds.get(JAVA_LONG, ERROR_OFFSET);
if (NativeCall.isError(error)) {
throw new ZstdException("parameter has no queryable bounds");
}
return new ZstdBounds(bounds.get(JAVA_INT, 8), bounds.get(JAVA_INT, 12));
return new ZstdBounds(bounds.get(JAVA_INT, LOWER_OFFSET), bounds.get(JAVA_INT, UPPER_OFFSET));
} catch (Throwable t) {
throw rethrow(t);
throw NativeCall.rethrow(t);
}
}

@SuppressWarnings("unchecked")
private static <E extends Throwable> RuntimeException rethrow(Throwable t) throws E {
throw (E) t;
}
}
Loading
Loading