From f3bace20a90061baa9391f15e3271ffd45397e9a Mon Sep 17 00:00:00 2001 From: Davide Angelocola Date: Fri, 26 Jun 2026 20:11:30 +0200 Subject: [PATCH] feat(zstd): multi-frame encode Add ZstdEncodingEncoder(valuesPerFrame): split the payload into independently compressed zstd frames of valuesPerFrame values each (the last frame holds the remainder), emitting one ZstdFrameMetadata per frame so a slice scan can decompress only the frames overlapping its row range. The no-arg constructor keeps the single-frame behaviour. Frame boundaries fall on value boundaries: fixed stride for primitives, length-prefix walk for varbin. Works for the non-nullable and nullable (frames over packed valid values, validity child trailing the frame buffers) paths alike. The decoder already iterates frames, so this is encode-only. Co-Authored-By: Claude Opus 4.8 --- CHANGELOG.md | 1 + TODO.md | 7 - .../JavaWritesRustReadsIntegrationTest.java | 19 ++ .../writer/encode/ZstdEncodingEncoder.java | 164 +++++++++++++----- .../encode/ZstdEncodingEncoderTest.java | 82 +++++++++ 5 files changed, 225 insertions(+), 48 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index d5247bdd..5aa8b734 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,6 +11,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - `DType.isUnsigned()` — `true` for the unsigned integer primitives (`U8`–`U64`), `false` otherwise. ([#159](https://github.com/dfa1/vortex-java/issues/159)) - The `vortex.zstd` encoder now writes nullable columns (primitive and utf8/binary): null positions are stripped before compression and validity is emitted as a Bool child, matching the Rust reference layout. When `vortex.zstd` is the configured encoder, nullable primitive columns route to it directly instead of being wrapped in `vortex.masked`. +- `new ZstdEncodingEncoder(valuesPerFrame)` splits the payload into independently compressed frames of `valuesPerFrame` values each (one `ZstdFrameMetadata` per frame), letting a slice scan decompress only the frames overlapping its row range. The no-arg constructor still emits a single frame. ([#170](https://github.com/dfa1/vortex-java/pull/170)) ### Changed diff --git a/TODO.md b/TODO.md index edde250a..8895b0dc 100644 --- a/TODO.md +++ b/TODO.md @@ -101,10 +101,3 @@ Per-encoding gotchas: See [docs/compatibility.md](docs/compatibility.md) for the full encoding support table and S3 fixture status. -### `vortex.zstd` known limitations - -- [ ] **Multi-frame encode** — `ZstdEncoding.Encoder` always produces a single frame for the whole array. - Fix: accept a `valuesPerFrame` parameter (default: all values in one frame). Split the raw byte buffer at frame - boundaries (`valuesPerFrame * byteWidth`), compress each slice independently, emit one `ZstdFrameMetadata` per frame. - Enables partial decompression during slice scans. - diff --git a/integration/src/test/java/io/github/dfa1/vortex/integration/JavaWritesRustReadsIntegrationTest.java b/integration/src/test/java/io/github/dfa1/vortex/integration/JavaWritesRustReadsIntegrationTest.java index 44323a8e..88111461 100644 --- a/integration/src/test/java/io/github/dfa1/vortex/integration/JavaWritesRustReadsIntegrationTest.java +++ b/integration/src/test/java/io/github/dfa1/vortex/integration/JavaWritesRustReadsIntegrationTest.java @@ -1209,6 +1209,25 @@ void javaWriter_rustReader_zstd_utf8(@TempDir Path tmp) throws IOException { assertThat(decoded).containsExactly(data); } + @Test + void javaWriter_rustReader_zstd_multiFrameI64(@TempDir Path tmp) throws IOException { + // Given — ZstdEncoding split into frames of 3 values: 7 values -> 3 frames (3, 3, 1), each + // an independently compressed zstd frame with its own ZstdFrameMetadata. Verifies the + // multi-frame wire layout against the Rust reader. + Path file = tmp.resolve("java_zstd_multiframe_i64.vtx"); + long[] data = {1L, 2L, 3L, 4L, 5L, 6L, 7L}; + try (var ch = FileChannel.open(file, StandardOpenOption.CREATE, StandardOpenOption.WRITE); + var sut = VortexWriter.create(ch, TS_SCHEMA, WriteOptions.defaults(), + List.of(new ZstdEncodingEncoder(3)))) { + // When + sut.writeChunk(Map.of("ts", data)); + } + + // Then + long[] decoded = readLongColumn(file, "ts"); + assertThat(decoded).containsExactly(data); + } + @Test void javaWriter_rustReader_zstd_nullableI64(@TempDir Path tmp) throws IOException { // Given — nullable primitive I64 written with ZstdEncoding. A configured zstd encoder diff --git a/writer/src/main/java/io/github/dfa1/vortex/writer/encode/ZstdEncodingEncoder.java b/writer/src/main/java/io/github/dfa1/vortex/writer/encode/ZstdEncodingEncoder.java index a2076f9a..b022ea8e 100644 --- a/writer/src/main/java/io/github/dfa1/vortex/writer/encode/ZstdEncodingEncoder.java +++ b/writer/src/main/java/io/github/dfa1/vortex/writer/encode/ZstdEncodingEncoder.java @@ -17,10 +17,28 @@ import java.util.List; /// Write-only encoder for `vortex.zstd`. +/// +/// By default the whole array compresses into a single zstd frame. Construct with a positive +/// `valuesPerFrame` to split the payload into independently compressed frames of that many values +/// each (the last frame holds the remainder), emitting one `ZstdFrameMetadata` per frame. Multiple +/// frames let a slice scan decompress only the frames overlapping its row range. public final class ZstdEncodingEncoder implements EncodingEncoder { - /// Public no-arg constructor required by [java.util.ServiceLoader]. + /// Values per zstd frame; `0` (or any non-positive value) means a single frame for the whole array. + private final long valuesPerFrame; + + /// Public no-arg constructor required by [java.util.ServiceLoader]; compresses each array into + /// a single frame. public ZstdEncodingEncoder() { + this(0); + } + + /// Creates an encoder that splits the payload into frames of `valuesPerFrame` values each. + /// + /// @param valuesPerFrame the number of values per zstd frame; non-positive means a single frame + /// for the whole array + public ZstdEncodingEncoder(long valuesPerFrame) { + this.valuesPerFrame = valuesPerFrame; } @Override @@ -66,34 +84,28 @@ public EncodeResult encode(DType dtype, Object data, EncodeContext ctx) { throw new VortexException(EncodingId.VORTEX_ZSTD, "unsupported dtype: " + dtype); } - private static EncodeResult encodePrimitive(DType.Primitive dt, Object data, Arena arena) { + private EncodeResult encodePrimitive(DType.Primitive dt, Object data, Arena arena) { + int byteWidth = dt.ptype().byteSize(); MemorySegment raw = primitiveToLeBytes(dt.ptype(), data, arena); long n = primitiveLength(dt.ptype(), data); - return buildResult(raw, n, arena); + return buildResult(raw, uniformLayout(n, byteWidth), arena); } - private static EncodeResult encodeVarBin(String[] strings, Arena arena) { + private EncodeResult encodeVarBin(String[] strings, Arena arena) { MemorySegment raw = buildLengthPrefixed(strings, arena); - return buildResult(raw, strings.length, arena); + return buildResult(raw, varBinLayout(raw, strings.length), arena); } - private static EncodeResult buildResult(MemorySegment raw, long n, Arena arena) { - // Zero-copy: compress the arena-native raw segment straight into another arena segment, - // no heap byte[] bounce on either side. The compressed slice is owned by the caller arena. - MemorySegment compressed; - try (ZstdCompressCtx cctx = new ZstdCompressCtx()) { - compressed = cctx.compress(arena, raw); - } - byte[] meta = new ProtoZstdMetadata( - 0, - List.of(new ProtoZstdFrameMetadata(raw.byteSize(), n)) - ).encode(); - EncodeNode root = new EncodeNode(EncodingId.VORTEX_ZSTD, MemorySegment.ofArray(meta), - new EncodeNode[0], new int[]{0}); - return new EncodeResult(root, List.of(compressed), null, null); + private EncodeResult buildResult(MemorySegment raw, FrameLayout layout, Arena arena) { + // Zero-copy: each frame is an arena-native slice of raw, compressed straight into another + // arena segment. A single-value-per-array config yields one frame (the prior behaviour). + Frames frames = compressFrames(raw, layout, arena); + EncodeNode root = new EncodeNode(EncodingId.VORTEX_ZSTD, MemorySegment.ofArray(frames.metadata()), + new EncodeNode[0], frameBufferIndices(frames.compressed().size(), 0)); + return new EncodeResult(root, List.copyOf(frames.compressed()), null, null); } - private static EncodeResult encodeNullablePrimitive(DType.Primitive dt, NullableData nd, EncodeContext ctx) { + private EncodeResult encodeNullablePrimitive(DType.Primitive dt, NullableData nd, EncodeContext ctx) { Arena arena = ctx.arena(); int byteWidth = dt.ptype().byteSize(); boolean[] validity = nd.validity(); @@ -101,43 +113,113 @@ private static EncodeResult encodeNullablePrimitive(DType.Primitive dt, Nullable // reference). The decoder scatters them back over the validity mask carried by child[0]. MemorySegment full = primitiveToLeBytes(dt.ptype(), nd.values(), arena); MemorySegment packed = packValidBytes(full, validity, byteWidth, arena); - return buildNullableResult(packed, countValid(validity), validity, ctx); + return buildNullableResult(packed, uniformLayout(countValid(validity), byteWidth), validity, ctx); } - private static EncodeResult encodeNullableVarBin(NullableData nd, EncodeContext ctx) { + private EncodeResult encodeNullableVarBin(NullableData nd, EncodeContext ctx) { // Strip null positions: only valid strings reach the compressed payload (mirrors the Rust // reference). The decoder scatters them back over the validity mask carried by child[0]. String[] valid = stripNulls((String[]) nd.values()); MemorySegment packed = buildLengthPrefixed(valid, ctx.arena()); - return buildNullableResult(packed, valid.length, nd.validity(), ctx); + return buildNullableResult(packed, varBinLayout(packed, valid.length), nd.validity(), ctx); } - private static EncodeResult buildNullableResult( - MemorySegment raw, long nValues, boolean[] validity, EncodeContext ctx) { - // Zero-copy: compress the arena-native packed segment into another arena segment. - MemorySegment compressed; - try (ZstdCompressCtx cctx = new ZstdCompressCtx()) { - compressed = cctx.compress(ctx.arena(), raw); - } - byte[] meta = new ProtoZstdMetadata( - 0, - List.of(new ProtoZstdFrameMetadata(raw.byteSize(), nValues)) - ).encode(); + private EncodeResult buildNullableResult( + MemorySegment raw, FrameLayout layout, boolean[] validity, EncodeContext ctx) { + Frames frames = compressFrames(raw, layout, ctx.arena()); + int frameCount = frames.compressed().size(); EncodeResult validityResult = new BoolEncodingEncoder().encode(DType.BOOL, validity, ctx); - // The frame payload owns buffer[0]; the validity child's buffers follow, so shift its - // buffer indices by one. - EncodeNode validityNode = EncodeNode.remapBufferIndices(validityResult.rootNode(), 1); + // The frame payloads own buffer[0..frameCount-1]; the validity child's buffers follow, so + // shift its buffer indices past them. + EncodeNode validityNode = EncodeNode.remapBufferIndices(validityResult.rootNode(), frameCount); - List buffers = new ArrayList<>(1 + validityResult.buffers().size()); - buffers.add(compressed); + List buffers = new ArrayList<>(frameCount + validityResult.buffers().size()); + buffers.addAll(frames.compressed()); buffers.addAll(validityResult.buffers()); - EncodeNode root = new EncodeNode(EncodingId.VORTEX_ZSTD, MemorySegment.ofArray(meta), - new EncodeNode[]{validityNode}, new int[]{0}); + EncodeNode root = new EncodeNode(EncodingId.VORTEX_ZSTD, MemorySegment.ofArray(frames.metadata()), + new EncodeNode[]{validityNode}, frameBufferIndices(frameCount, 0)); return new EncodeResult(root, buffers, null, null); } + /// Byte spans and value counts of each frame; spans sum to the payload size. + private record FrameLayout(long[] byteLengths, long[] valueCounts) { + } + + /// Compressed frame payloads paired with the encoded `ZstdMetadata` describing them. + private record Frames(List compressed, byte[] metadata) { + } + + private static int[] frameBufferIndices(int frameCount, int base) { + int[] indices = new int[frameCount]; + for (int i = 0; i < frameCount; i++) { + indices[i] = base + i; + } + return indices; + } + + /// Frame layout for `n` fixed-width values (`byteWidth` bytes each): `valuesPerFrame` values + /// per frame, the last frame holding the remainder. One frame when framing is disabled. + private FrameLayout uniformLayout(long n, int byteWidth) { + if (valuesPerFrame <= 0 || n <= valuesPerFrame) { + return new FrameLayout(new long[]{n * byteWidth}, new long[]{n}); + } + int frameCount = (int) ((n + valuesPerFrame - 1) / valuesPerFrame); + long[] byteLengths = new long[frameCount]; + long[] valueCounts = new long[frameCount]; + long remaining = n; + for (int f = 0; f < frameCount; f++) { + long count = Math.min(valuesPerFrame, remaining); + valueCounts[f] = count; + byteLengths[f] = count * byteWidth; + remaining -= count; + } + return new FrameLayout(byteLengths, valueCounts); + } + + /// Frame layout for a length-prefixed varbin payload: `valuesPerFrame` values per frame, with + /// each frame's byte span found by walking the 4-byte length prefixes to a value boundary. + private FrameLayout varBinLayout(MemorySegment raw, long nValues) { + if (valuesPerFrame <= 0 || nValues <= valuesPerFrame) { + return new FrameLayout(new long[]{raw.byteSize()}, new long[]{nValues}); + } + int frameCount = (int) ((nValues + valuesPerFrame - 1) / valuesPerFrame); + long[] byteLengths = new long[frameCount]; + long[] valueCounts = new long[frameCount]; + long pos = 0; + long valueIdx = 0; + for (int f = 0; f < frameCount; f++) { + long count = Math.min(valuesPerFrame, nValues - valueIdx); + long start = pos; + for (long k = 0; k < count; k++) { + int len = raw.get(PTypeIO.LE_INT, pos); + pos += 4L + len; + } + byteLengths[f] = pos - start; + valueCounts[f] = count; + valueIdx += count; + } + return new FrameLayout(byteLengths, valueCounts); + } + + private static Frames compressFrames(MemorySegment raw, FrameLayout layout, Arena arena) { + int frameCount = layout.byteLengths().length; + List compressed = new ArrayList<>(frameCount); + List metas = new ArrayList<>(frameCount); + long offset = 0; + try (ZstdCompressCtx cctx = new ZstdCompressCtx()) { + for (int f = 0; f < frameCount; f++) { + long len = layout.byteLengths()[f]; + compressed.add(cctx.compress(arena, raw.asSlice(offset, len))); + metas.add(new ProtoZstdFrameMetadata(len, layout.valueCounts()[f])); + offset += len; + } + } + byte[] metadata = new ProtoZstdMetadata(0, List.copyOf(metas)).encode(); + return new Frames(compressed, metadata); + } + private static MemorySegment packValidBytes( MemorySegment full, boolean[] validity, int byteWidth, Arena arena) { long validBytes = (long) countValid(validity) * byteWidth; diff --git a/writer/src/test/java/io/github/dfa1/vortex/writer/encode/ZstdEncodingEncoderTest.java b/writer/src/test/java/io/github/dfa1/vortex/writer/encode/ZstdEncodingEncoderTest.java index c89d8d9a..f57feeae 100644 --- a/writer/src/test/java/io/github/dfa1/vortex/writer/encode/ZstdEncodingEncoderTest.java +++ b/writer/src/test/java/io/github/dfa1/vortex/writer/encode/ZstdEncodingEncoderTest.java @@ -427,4 +427,86 @@ void encode_i32_metadata_framesCount_isNonZero() throws Exception { assertThat(meta.frames()).isNotEmpty(); } } + + @Nested + class MultiFrame { + + private static final ZstdEncodingEncoder FRAMED = new ZstdEncodingEncoder(4); + + @Test + void encode_i32_splitsIntoFrames_andRoundTrips() throws Exception { + // Given — 10 values, 4 per frame: 3 frames (4, 4, 2), one compressed buffer each. + int[] data = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}; + + // When + EncodeResult result = FRAMED.encode(DTypes.I32, data, EncodeTestHelper.testCtx()); + + // Then + var metaSeg = result.rootNode().metadata(); + ProtoZstdMetadata meta = ProtoZstdMetadata.decode(metaSeg, 0, metaSeg.byteSize()); + assertThat(meta.frames()).hasSize(3); + assertThat(meta.frames().get(0).n_values()).isEqualTo(4); + assertThat(meta.frames().get(2).n_values()).isEqualTo(2); + assertThat(result.buffers()).hasSize(3); + + DecodeContext ctx = DecodeTestHelper.toDecodeContext(result, data.length, DTypes.I32, ReadRegistry.empty()); + IntArray decoded = (IntArray) DECODER.decode(ctx); + for (int i = 0; i < data.length; i++) { + assertThat(decoded.getInt(i)).as("index %d", i).isEqualTo(data[i]); + } + } + + @Test + void encode_varBin_splitsOnValueBoundaries_andRoundTrips() throws Exception { + // Given — 5 strings, 2 per frame: 3 frames (2, 2, 1). Entries vary in length, so the + // frame byte spans must be found by walking the length prefixes, not a fixed stride. + ZstdEncodingEncoder framedByTwo = new ZstdEncodingEncoder(2); + String[] data = {"a", "bb", "ccc", "d", "eeeee"}; + + // When + EncodeResult result = framedByTwo.encode(DTypes.UTF8, data, EncodeTestHelper.testCtx()); + + // Then + var metaSeg = result.rootNode().metadata(); + ProtoZstdMetadata meta = ProtoZstdMetadata.decode(metaSeg, 0, metaSeg.byteSize()); + assertThat(meta.frames()).hasSize(3); + + DecodeContext ctx = DecodeTestHelper.toDecodeContext(result, data.length, DTypes.UTF8, ReadRegistry.empty()); + VarBinArray decoded = (VarBinArray) DECODER.decode(ctx); + for (int i = 0; i < data.length; i++) { + assertThat(decoded.getString(i)).as("index %d", i).isEqualTo(data[i]); + } + } + + @Test + void encode_nullablePrimitive_framesOverValidValues_andRoundTrips() throws Exception { + // Given — 7 rows, 5 valid. Frames cover only the packed valid values (4 + 1), and the + // validity child's buffers must trail the two frame buffers. + int[] storage = {10, 0, 20, 30, 0, 40, 50}; + boolean[] validity = {true, false, true, true, false, true, true}; + DType i32Nullable = new DType.Primitive(PType.I32, true); + NullableData data = new NullableData(storage, validity); + + // When + EncodeResult result = FRAMED.encode(i32Nullable, data, EncodeTestHelper.testCtx()); + + // Then + var metaSeg = result.rootNode().metadata(); + ProtoZstdMetadata meta = ProtoZstdMetadata.decode(metaSeg, 0, metaSeg.byteSize()); + assertThat(meta.frames()).hasSize(2); + assertThat(meta.frames().get(0).n_values()).isEqualTo(4); + assertThat(meta.frames().get(1).n_values()).isEqualTo(1); + + DecodeContext ctx = DecodeTestHelper.toDecodeContext( + result, validity.length, i32Nullable, TestRegistry.ofDecoders(new BoolEncodingDecoder())); + MaskedArray decoded = (MaskedArray) DECODER.decode(ctx); + assertThat(decoded.length()).isEqualTo(7); + assertThat(decoded.isValid(1)).isFalse(); + assertThat(decoded.isValid(4)).isFalse(); + IntArray child = (IntArray) decoded.inner(); + assertThat(child.getInt(0)).isEqualTo(10); + assertThat(child.getInt(2)).isEqualTo(20); + assertThat(child.getInt(6)).isEqualTo(50); + } + } }