Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

- `DType.isUnsigned()` — `true` for the unsigned integer primitives (`U8`–`U64`), `false` otherwise. ([#159](https://github.com/dfa1/vortex-java/issues/159))
- The `vortex.zstd` encoder now writes nullable columns (primitive and utf8/binary): null positions are stripped before compression and validity is emitted as a Bool child, matching the Rust reference layout. When `vortex.zstd` is the configured encoder, nullable primitive columns route to it directly instead of being wrapped in `vortex.masked`.
- `new ZstdEncodingEncoder(valuesPerFrame)` splits the payload into independently compressed frames of `valuesPerFrame` values each (one `ZstdFrameMetadata` per frame), letting a slice scan decompress only the frames overlapping its row range. The no-arg constructor still emits a single frame. ([#170](https://github.com/dfa1/vortex-java/pull/170))

### Changed

Expand Down
7 changes: 0 additions & 7 deletions TODO.md
Original file line number Diff line number Diff line change
Expand Up @@ -101,10 +101,3 @@ Per-encoding gotchas:

See [docs/compatibility.md](docs/compatibility.md) for the full encoding support table and S3 fixture status.

### `vortex.zstd` known limitations

- [ ] **Multi-frame encode** — `ZstdEncoding.Encoder` always produces a single frame for the whole array.
Fix: accept a `valuesPerFrame` parameter (default: all values in one frame). Split the raw byte buffer at frame
boundaries (`valuesPerFrame * byteWidth`), compress each slice independently, emit one `ZstdFrameMetadata` per frame.
Enables partial decompression during slice scans.

Original file line number Diff line number Diff line change
Expand Up @@ -1209,6 +1209,25 @@ void javaWriter_rustReader_zstd_utf8(@TempDir Path tmp) throws IOException {
assertThat(decoded).containsExactly(data);
}

@Test
void javaWriter_rustReader_zstd_multiFrameI64(@TempDir Path tmp) throws IOException {
// Given — ZstdEncoding split into frames of 3 values: 7 values -> 3 frames (3, 3, 1), each
// an independently compressed zstd frame with its own ZstdFrameMetadata. Verifies the
// multi-frame wire layout against the Rust reader.
Path file = tmp.resolve("java_zstd_multiframe_i64.vtx");
long[] data = {1L, 2L, 3L, 4L, 5L, 6L, 7L};
try (var ch = FileChannel.open(file, StandardOpenOption.CREATE, StandardOpenOption.WRITE);
var sut = VortexWriter.create(ch, TS_SCHEMA, WriteOptions.defaults(),
List.of(new ZstdEncodingEncoder(3)))) {
// When
sut.writeChunk(Map.of("ts", data));
}

// Then
long[] decoded = readLongColumn(file, "ts");
assertThat(decoded).containsExactly(data);
}

@Test
void javaWriter_rustReader_zstd_nullableI64(@TempDir Path tmp) throws IOException {
// Given — nullable primitive I64 written with ZstdEncoding. A configured zstd encoder
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,28 @@
import java.util.List;

/// Write-only encoder for `vortex.zstd`.
///
/// By default the whole array compresses into a single zstd frame. Construct with a positive
/// `valuesPerFrame` to split the payload into independently compressed frames of that many values
/// each (the last frame holds the remainder), emitting one `ZstdFrameMetadata` per frame. Multiple
/// frames let a slice scan decompress only the frames overlapping its row range.
public final class ZstdEncodingEncoder implements EncodingEncoder {

/// Public no-arg constructor required by [java.util.ServiceLoader].
/// Values per zstd frame; `0` (or any non-positive value) means a single frame for the whole array.
private final long valuesPerFrame;

/// Public no-arg constructor required by [java.util.ServiceLoader]; compresses each array into
/// a single frame.
public ZstdEncodingEncoder() {
this(0);
}

/// Creates an encoder that splits the payload into frames of `valuesPerFrame` values each.
///
/// @param valuesPerFrame the number of values per zstd frame; non-positive means a single frame
/// for the whole array
public ZstdEncodingEncoder(long valuesPerFrame) {
this.valuesPerFrame = valuesPerFrame;
}

@Override
Expand Down Expand Up @@ -66,78 +84,142 @@ public EncodeResult encode(DType dtype, Object data, EncodeContext ctx) {
throw new VortexException(EncodingId.VORTEX_ZSTD, "unsupported dtype: " + dtype);
}

private static EncodeResult encodePrimitive(DType.Primitive dt, Object data, Arena arena) {
private EncodeResult encodePrimitive(DType.Primitive dt, Object data, Arena arena) {
int byteWidth = dt.ptype().byteSize();
MemorySegment raw = primitiveToLeBytes(dt.ptype(), data, arena);
long n = primitiveLength(dt.ptype(), data);
return buildResult(raw, n, arena);
return buildResult(raw, uniformLayout(n, byteWidth), arena);
}

private static EncodeResult encodeVarBin(String[] strings, Arena arena) {
private EncodeResult encodeVarBin(String[] strings, Arena arena) {
MemorySegment raw = buildLengthPrefixed(strings, arena);
return buildResult(raw, strings.length, arena);
return buildResult(raw, varBinLayout(raw, strings.length), arena);
}

private static EncodeResult buildResult(MemorySegment raw, long n, Arena arena) {
// Zero-copy: compress the arena-native raw segment straight into another arena segment,
// no heap byte[] bounce on either side. The compressed slice is owned by the caller arena.
MemorySegment compressed;
try (ZstdCompressCtx cctx = new ZstdCompressCtx()) {
compressed = cctx.compress(arena, raw);
}
byte[] meta = new ProtoZstdMetadata(
0,
List.of(new ProtoZstdFrameMetadata(raw.byteSize(), n))
).encode();
EncodeNode root = new EncodeNode(EncodingId.VORTEX_ZSTD, MemorySegment.ofArray(meta),
new EncodeNode[0], new int[]{0});
return new EncodeResult(root, List.of(compressed), null, null);
private EncodeResult buildResult(MemorySegment raw, FrameLayout layout, Arena arena) {
// Zero-copy: each frame is an arena-native slice of raw, compressed straight into another
// arena segment. A single-value-per-array config yields one frame (the prior behaviour).
Frames frames = compressFrames(raw, layout, arena);
EncodeNode root = new EncodeNode(EncodingId.VORTEX_ZSTD, MemorySegment.ofArray(frames.metadata()),
new EncodeNode[0], frameBufferIndices(frames.compressed().size(), 0));
return new EncodeResult(root, List.copyOf(frames.compressed()), null, null);
}

private static EncodeResult encodeNullablePrimitive(DType.Primitive dt, NullableData nd, EncodeContext ctx) {
private EncodeResult encodeNullablePrimitive(DType.Primitive dt, NullableData nd, EncodeContext ctx) {
Arena arena = ctx.arena();
int byteWidth = dt.ptype().byteSize();
boolean[] validity = nd.validity();
// Strip null positions: only valid values reach the compressed payload (mirrors the Rust
// reference). The decoder scatters them back over the validity mask carried by child[0].
MemorySegment full = primitiveToLeBytes(dt.ptype(), nd.values(), arena);
MemorySegment packed = packValidBytes(full, validity, byteWidth, arena);
return buildNullableResult(packed, countValid(validity), validity, ctx);
return buildNullableResult(packed, uniformLayout(countValid(validity), byteWidth), validity, ctx);
}

private static EncodeResult encodeNullableVarBin(NullableData nd, EncodeContext ctx) {
private EncodeResult encodeNullableVarBin(NullableData nd, EncodeContext ctx) {
// Strip null positions: only valid strings reach the compressed payload (mirrors the Rust
// reference). The decoder scatters them back over the validity mask carried by child[0].
String[] valid = stripNulls((String[]) nd.values());
MemorySegment packed = buildLengthPrefixed(valid, ctx.arena());
return buildNullableResult(packed, valid.length, nd.validity(), ctx);
return buildNullableResult(packed, varBinLayout(packed, valid.length), nd.validity(), ctx);
}

private static EncodeResult buildNullableResult(
MemorySegment raw, long nValues, boolean[] validity, EncodeContext ctx) {
// Zero-copy: compress the arena-native packed segment into another arena segment.
MemorySegment compressed;
try (ZstdCompressCtx cctx = new ZstdCompressCtx()) {
compressed = cctx.compress(ctx.arena(), raw);
}
byte[] meta = new ProtoZstdMetadata(
0,
List.of(new ProtoZstdFrameMetadata(raw.byteSize(), nValues))
).encode();
private EncodeResult buildNullableResult(
MemorySegment raw, FrameLayout layout, boolean[] validity, EncodeContext ctx) {
Frames frames = compressFrames(raw, layout, ctx.arena());
int frameCount = frames.compressed().size();

EncodeResult validityResult = new BoolEncodingEncoder().encode(DType.BOOL, validity, ctx);
// The frame payload owns buffer[0]; the validity child's buffers follow, so shift its
// buffer indices by one.
EncodeNode validityNode = EncodeNode.remapBufferIndices(validityResult.rootNode(), 1);
// The frame payloads own buffer[0..frameCount-1]; the validity child's buffers follow, so
// shift its buffer indices past them.
EncodeNode validityNode = EncodeNode.remapBufferIndices(validityResult.rootNode(), frameCount);

List<MemorySegment> buffers = new ArrayList<>(1 + validityResult.buffers().size());
buffers.add(compressed);
List<MemorySegment> buffers = new ArrayList<>(frameCount + validityResult.buffers().size());
buffers.addAll(frames.compressed());
buffers.addAll(validityResult.buffers());

EncodeNode root = new EncodeNode(EncodingId.VORTEX_ZSTD, MemorySegment.ofArray(meta),
new EncodeNode[]{validityNode}, new int[]{0});
EncodeNode root = new EncodeNode(EncodingId.VORTEX_ZSTD, MemorySegment.ofArray(frames.metadata()),
new EncodeNode[]{validityNode}, frameBufferIndices(frameCount, 0));
return new EncodeResult(root, buffers, null, null);
}

/// Byte spans and value counts of each frame; spans sum to the payload size.
private record FrameLayout(long[] byteLengths, long[] valueCounts) {
}

/// Compressed frame payloads paired with the encoded `ZstdMetadata` describing them.
private record Frames(List<MemorySegment> compressed, byte[] metadata) {
}

private static int[] frameBufferIndices(int frameCount, int base) {
int[] indices = new int[frameCount];
for (int i = 0; i < frameCount; i++) {
indices[i] = base + i;
}
return indices;
}

/// Frame layout for `n` fixed-width values (`byteWidth` bytes each): `valuesPerFrame` values
/// per frame, the last frame holding the remainder. One frame when framing is disabled.
private FrameLayout uniformLayout(long n, int byteWidth) {
if (valuesPerFrame <= 0 || n <= valuesPerFrame) {
return new FrameLayout(new long[]{n * byteWidth}, new long[]{n});
}
int frameCount = (int) ((n + valuesPerFrame - 1) / valuesPerFrame);
long[] byteLengths = new long[frameCount];
long[] valueCounts = new long[frameCount];
long remaining = n;
for (int f = 0; f < frameCount; f++) {
long count = Math.min(valuesPerFrame, remaining);
valueCounts[f] = count;
byteLengths[f] = count * byteWidth;
remaining -= count;
}
return new FrameLayout(byteLengths, valueCounts);
}

/// Frame layout for a length-prefixed varbin payload: `valuesPerFrame` values per frame, with
/// each frame's byte span found by walking the 4-byte length prefixes to a value boundary.
private FrameLayout varBinLayout(MemorySegment raw, long nValues) {
if (valuesPerFrame <= 0 || nValues <= valuesPerFrame) {
return new FrameLayout(new long[]{raw.byteSize()}, new long[]{nValues});
}
int frameCount = (int) ((nValues + valuesPerFrame - 1) / valuesPerFrame);
long[] byteLengths = new long[frameCount];
long[] valueCounts = new long[frameCount];
long pos = 0;
long valueIdx = 0;
for (int f = 0; f < frameCount; f++) {
long count = Math.min(valuesPerFrame, nValues - valueIdx);
long start = pos;
for (long k = 0; k < count; k++) {
int len = raw.get(PTypeIO.LE_INT, pos);
pos += 4L + len;
}
byteLengths[f] = pos - start;
valueCounts[f] = count;
valueIdx += count;
}
return new FrameLayout(byteLengths, valueCounts);
}

private static Frames compressFrames(MemorySegment raw, FrameLayout layout, Arena arena) {
int frameCount = layout.byteLengths().length;
List<MemorySegment> compressed = new ArrayList<>(frameCount);
List<ProtoZstdFrameMetadata> metas = new ArrayList<>(frameCount);
long offset = 0;
try (ZstdCompressCtx cctx = new ZstdCompressCtx()) {
for (int f = 0; f < frameCount; f++) {
long len = layout.byteLengths()[f];
compressed.add(cctx.compress(arena, raw.asSlice(offset, len)));
metas.add(new ProtoZstdFrameMetadata(len, layout.valueCounts()[f]));
offset += len;
}
}
byte[] metadata = new ProtoZstdMetadata(0, List.copyOf(metas)).encode();
return new Frames(compressed, metadata);
}

private static MemorySegment packValidBytes(
MemorySegment full, boolean[] validity, int byteWidth, Arena arena) {
long validBytes = (long) countValid(validity) * byteWidth;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -427,4 +427,86 @@ void encode_i32_metadata_framesCount_isNonZero() throws Exception {
assertThat(meta.frames()).isNotEmpty();
}
}

@Nested
class MultiFrame {

private static final ZstdEncodingEncoder FRAMED = new ZstdEncodingEncoder(4);

@Test
void encode_i32_splitsIntoFrames_andRoundTrips() throws Exception {
// Given — 10 values, 4 per frame: 3 frames (4, 4, 2), one compressed buffer each.
int[] data = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9};

// When
EncodeResult result = FRAMED.encode(DTypes.I32, data, EncodeTestHelper.testCtx());

// Then
var metaSeg = result.rootNode().metadata();
ProtoZstdMetadata meta = ProtoZstdMetadata.decode(metaSeg, 0, metaSeg.byteSize());
assertThat(meta.frames()).hasSize(3);
assertThat(meta.frames().get(0).n_values()).isEqualTo(4);
assertThat(meta.frames().get(2).n_values()).isEqualTo(2);
assertThat(result.buffers()).hasSize(3);

DecodeContext ctx = DecodeTestHelper.toDecodeContext(result, data.length, DTypes.I32, ReadRegistry.empty());
IntArray decoded = (IntArray) DECODER.decode(ctx);
for (int i = 0; i < data.length; i++) {
assertThat(decoded.getInt(i)).as("index %d", i).isEqualTo(data[i]);
}
}

@Test
void encode_varBin_splitsOnValueBoundaries_andRoundTrips() throws Exception {
// Given — 5 strings, 2 per frame: 3 frames (2, 2, 1). Entries vary in length, so the
// frame byte spans must be found by walking the length prefixes, not a fixed stride.
ZstdEncodingEncoder framedByTwo = new ZstdEncodingEncoder(2);
String[] data = {"a", "bb", "ccc", "d", "eeeee"};

// When
EncodeResult result = framedByTwo.encode(DTypes.UTF8, data, EncodeTestHelper.testCtx());

// Then
var metaSeg = result.rootNode().metadata();
ProtoZstdMetadata meta = ProtoZstdMetadata.decode(metaSeg, 0, metaSeg.byteSize());
assertThat(meta.frames()).hasSize(3);

DecodeContext ctx = DecodeTestHelper.toDecodeContext(result, data.length, DTypes.UTF8, ReadRegistry.empty());
VarBinArray decoded = (VarBinArray) DECODER.decode(ctx);
for (int i = 0; i < data.length; i++) {
assertThat(decoded.getString(i)).as("index %d", i).isEqualTo(data[i]);
}
}

@Test
void encode_nullablePrimitive_framesOverValidValues_andRoundTrips() throws Exception {
// Given — 7 rows, 5 valid. Frames cover only the packed valid values (4 + 1), and the
// validity child's buffers must trail the two frame buffers.
int[] storage = {10, 0, 20, 30, 0, 40, 50};
boolean[] validity = {true, false, true, true, false, true, true};
DType i32Nullable = new DType.Primitive(PType.I32, true);
NullableData data = new NullableData(storage, validity);

// When
EncodeResult result = FRAMED.encode(i32Nullable, data, EncodeTestHelper.testCtx());

// Then
var metaSeg = result.rootNode().metadata();
ProtoZstdMetadata meta = ProtoZstdMetadata.decode(metaSeg, 0, metaSeg.byteSize());
assertThat(meta.frames()).hasSize(2);
assertThat(meta.frames().get(0).n_values()).isEqualTo(4);
assertThat(meta.frames().get(1).n_values()).isEqualTo(1);

DecodeContext ctx = DecodeTestHelper.toDecodeContext(
result, validity.length, i32Nullable, TestRegistry.ofDecoders(new BoolEncodingDecoder()));
MaskedArray decoded = (MaskedArray) DECODER.decode(ctx);
assertThat(decoded.length()).isEqualTo(7);
assertThat(decoded.isValid(1)).isFalse();
assertThat(decoded.isValid(4)).isFalse();
IntArray child = (IntArray) decoded.inner();
assertThat(child.getInt(0)).isEqualTo(10);
assertThat(child.getInt(2)).isEqualTo(20);
assertThat(child.getInt(6)).isEqualTo(50);
}
}
}
Loading