diff --git a/CHANGELOG.md b/CHANGELOG.md index 93f38d4a..6d35463c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added - `DType.isUnsigned()` — `true` for the unsigned integer primitives (`U8`–`U64`), `false` otherwise. ([#159](https://github.com/dfa1/vortex-java/issues/159)) +- The `vortex.zstd` encoder now writes nullable columns (primitive and utf8/binary): null positions are stripped before compression and validity is emitted as a Bool child, matching the Rust reference layout. When `vortex.zstd` is the configured encoder and cascading is disabled, nullable primitive columns route to it directly instead of being wrapped in `vortex.masked`; with cascading enabled they keep the `vortex.masked` layout. ### Changed diff --git a/TODO.md b/TODO.md index 5cf4c394..edde250a 100644 --- a/TODO.md +++ b/TODO.md @@ -108,8 +108,3 @@ See [docs/compatibility.md](docs/compatibility.md) for the full encoding support boundaries (`valuesPerFrame * byteWidth`), compress each slice independently, emit one `ZstdFrameMetadata` per frame. Enables partial decompression during slice scans. -- [ ] **Nullable arrays (encode)** — `ZstdEncoding.Encoder` has no null handling. Fix: accept nullable input (e.g. `Integer[]` or a validity mask alongside the data array). Strip null positions before compression. Encode the validity bitmap as a Bool child (child[0]) in the `EncodeNode`. Mirrors what Rust does: only valid values go into the compressed payload. 
- diff --git a/integration/src/test/java/io/github/dfa1/vortex/integration/JavaWritesRustReadsIntegrationTest.java b/integration/src/test/java/io/github/dfa1/vortex/integration/JavaWritesRustReadsIntegrationTest.java index 3a1a5e40..c2c19bf6 100644 --- a/integration/src/test/java/io/github/dfa1/vortex/integration/JavaWritesRustReadsIntegrationTest.java +++ b/integration/src/test/java/io/github/dfa1/vortex/integration/JavaWritesRustReadsIntegrationTest.java @@ -1174,6 +1174,78 @@ void javaWriter_rustReader_zstd_utf8(@TempDir Path tmp) throws IOException { assertThat(decoded).containsExactly(data); } + @Test + void javaWriter_rustReader_zstd_nullableI64(@TempDir Path tmp) throws IOException { + // Given — nullable primitive I64 written with ZstdEncoding. A configured zstd encoder + // declares acceptsNullable, so the writer routes the NullableData straight to it instead + // of masked-wrapping: zstd strips nulls before compression and emits validity as child[0]. + // Verifies that nullable-zstd layout against the Rust reader. 
+ Path file = tmp.resolve("java_zstd_nullable_i64.vtx"); + DType.Struct schema = new DType.Struct(List.of("v"), List.of(new DType.Primitive(PType.I64, true)), false); + Long[] data = {10L, null, 30L, null, 50L, 60L}; + try (var ch = FileChannel.open(file, StandardOpenOption.CREATE, StandardOpenOption.WRITE); + var sut = VortexWriter.create(ch, schema, WriteOptions.defaults(), + List.of(new ZstdEncodingEncoder()))) { + // When + sut.writeChunk(Map.of("v", data)); + } + + // Then — Rust reads a nullable BigInt vector; null positions survive, values round-trip + String uri = file.toAbsolutePath().toUri().toString(); + DataSource ds = DataSource.open(SESSION, uri); + Scan scan = ds.scan(ScanOptions.of()); + var values = new ArrayList<Long>(); + while (scan.hasNext()) { + Partition partition = scan.next(); + try (ArrowReader reader = partition.scanArrow(ALLOCATOR)) { + while (reader.loadNextBatch()) { + VectorSchemaRoot root = reader.getVectorSchemaRoot(); + BigIntVector vec = (BigIntVector) root.getVector("v"); + for (int i = 0; i < root.getRowCount(); i++) { + values.add(vec.isNull(i) ? null : vec.get(i)); + } + } + } + } + assertThat(values).containsExactly(10L, null, 30L, null, 50L, 60L); + } + + @Test + void javaWriter_rustReader_zstd_nullableUtf8(@TempDir Path tmp) throws IOException { + // Given — nullable Utf8 written with ZstdEncoding. Nullable string columns reach the zstd + // encoder directly (a String[] carrying nulls is not the NullableData shape that + // writeSegment masked-wraps), so this exercises the encoder's nullable varbin path against + // the Rust reader: nulls stripped before compression, validity emitted as child[0]. 
+ Path file = tmp.resolve("java_zstd_nullable_utf8.vtx"); + DType.Struct schema = new DType.Struct(List.of("s"), List.of(new DType.Utf8(true)), false); + String[] data = {"hello", null, "world", null}; + try (var ch = FileChannel.open(file, StandardOpenOption.CREATE, StandardOpenOption.WRITE); + var sut = VortexWriter.create(ch, schema, WriteOptions.defaults(), + List.of(new ZstdEncodingEncoder()))) { + // When + sut.writeChunk(Map.of("s", data)); + } + + // Then — Rust reads the nullable Utf8 vector; null positions survive, values round-trip + String uri = file.toAbsolutePath().toUri().toString(); + DataSource ds = DataSource.open(SESSION, uri); + Scan scan = ds.scan(ScanOptions.of()); + var values = new ArrayList<String>(); + while (scan.hasNext()) { + Partition partition = scan.next(); + try (ArrowReader reader = partition.scanArrow(ALLOCATOR)) { + while (reader.loadNextBatch()) { + VectorSchemaRoot root = reader.getVectorSchemaRoot(); + VarCharVector vec = (VarCharVector) root.getVector("s"); + for (int i = 0; i < root.getRowCount(); i++) { + values.add(vec.isNull(i) ? null : vec.getObject(i).toString()); + } + } + } + } + assertThat(values).containsExactly("hello", null, "world", null); + } + + @Test void javaWriter_rustReader_list_i64(@TempDir Path tmp) throws IOException { // Given — ListEncoding: exercises elements_len + offset_ptype proto fields (byte-order risk) diff --git a/writer/src/main/java/io/github/dfa1/vortex/writer/VortexWriter.java b/writer/src/main/java/io/github/dfa1/vortex/writer/VortexWriter.java index f116c112..ab86aab9 100644 --- a/writer/src/main/java/io/github/dfa1/vortex/writer/VortexWriter.java +++ b/writer/src/main/java/io/github/dfa1/vortex/writer/VortexWriter.java @@ -520,10 +520,13 @@ private int writeSegment(DType dtype, Object data, EncodingEncoder encodingOverr // Non-extension nullable columns (Primitive, Utf8) wrap with MaskedEncodingEncoder here. 
// FbsExtension columns route through ExtEncodingEncoder.encode which itself delegates to // MaskedEncodingEncoder when its storage data is NullableData — handled inside ExtEncoding. + // Exception: a configured encoder that embeds validity itself (acceptsNullable, e.g. + // vortex.zstd) takes the NullableData straight, so no masked wrapper is inserted. if (encodingOverride == null && data instanceof io.github.dfa1.vortex.writer.encode.NullableData && !(dtype instanceof DType.Extension)) { - encodingOverride = new MaskedEncodingEncoder(); + EncodingEncoder nullableCapable = nullableCapableEncoder(dtype); + encodingOverride = nullableCapable != null ? nullableCapable : new MaskedEncodingEncoder(); } // Variant columns bypass the cascade: the container encoding is structural, not a // compressible primitive codec, so route straight to the dedicated encoder. @@ -595,6 +598,22 @@ private EncodingEncoder findEncoder(DType dtype) { throw new UnsupportedOperationException("no encoder for dtype: " + dtype); } + /// Returns the configured encoder for `dtype` that consumes a [NullableData] carrier directly + /// (embedding its own validity), or `null` to fall back to `vortex.masked` wrapping. Only the + /// first-match flat path is considered; with cascading enabled the compressor owns selection, + /// so nullable columns keep the masked layout. 
+ private EncodingEncoder nullableCapableEncoder(DType dtype) { + if (options.allowedCascading() > 0) { + return null; + } + for (EncodingEncoder c : encodings) { + if (c.accepts(dtype) && c.acceptsNullable(dtype)) { + return c; + } + } + return null; + } + private void write(MemorySegment seg) throws IOException { ByteBuffer buf = seg.asByteBuffer(); while (buf.hasRemaining()) { diff --git a/writer/src/main/java/io/github/dfa1/vortex/writer/encode/EncodingEncoder.java b/writer/src/main/java/io/github/dfa1/vortex/writer/encode/EncodingEncoder.java index 434707ef..ae6119a0 100644 --- a/writer/src/main/java/io/github/dfa1/vortex/writer/encode/EncodingEncoder.java +++ b/writer/src/main/java/io/github/dfa1/vortex/writer/encode/EncodingEncoder.java @@ -19,6 +19,20 @@ public interface EncodingEncoder { /// @return `true` if this encoding can encode arrays of `dtype` boolean accepts(DType dtype); + /// Whether this encoder consumes a [NullableData] carrier directly for `dtype`, embedding + /// the validity bitmap itself (as a Bool child) rather than relying on an enclosing + /// `vortex.masked` wrapper. + /// + /// The writer uses this to route nullable columns: an encoder returning `true` receives the + /// [NullableData] straight, otherwise the column is masked-wrapped over a non-nullable child. + /// Default `false` — most encoders only handle dense data. + /// + /// @param dtype the dtype to test + /// @return `true` if this encoding can encode nullable `dtype` from a [NullableData] carrier + default boolean acceptsNullable(DType dtype) { + return false; + } + /// Encodes `data` to bytes using the provided arena for output buffer allocation. 
/// /// @param dtype logical type of the data diff --git a/writer/src/main/java/io/github/dfa1/vortex/writer/encode/ZstdEncodingEncoder.java b/writer/src/main/java/io/github/dfa1/vortex/writer/encode/ZstdEncodingEncoder.java index 5126db40..4335669e 100644 --- a/writer/src/main/java/io/github/dfa1/vortex/writer/encode/ZstdEncodingEncoder.java +++ b/writer/src/main/java/io/github/dfa1/vortex/writer/encode/ZstdEncodingEncoder.java @@ -13,6 +13,7 @@ import java.lang.foreign.MemorySegment; import java.lang.foreign.ValueLayout; import java.nio.charset.StandardCharsets; +import java.util.ArrayList; import java.util.List; /// Write-only encoder for `vortex.zstd`. @@ -32,13 +33,35 @@ public boolean accepts(DType dtype) { return dtype instanceof DType.Primitive || dtype instanceof DType.Utf8 || dtype instanceof DType.Binary; } + @Override + public boolean acceptsNullable(DType dtype) { + // Nullable utf8/binary arrive as a String[] carrying nulls (handled in encode), not a + // NullableData carrier; only primitive nullable columns are routed here as NullableData. 
+ return dtype instanceof DType.Primitive; + } + @Override public EncodeResult encode(DType dtype, Object data, EncodeContext ctx) { + if (data instanceof NullableData nd) { + if (!(dtype instanceof DType.Primitive dt)) { + throw new VortexException(EncodingId.VORTEX_ZSTD, + "NullableData is only supported for primitive dtypes, got " + dtype); + } + return encodeNullablePrimitive(dt, nd, ctx); + } if (dtype instanceof DType.Primitive dt) { return encodePrimitive(dt, data, ctx.arena()); } if (dtype instanceof DType.Utf8 || dtype instanceof DType.Binary) { - return encodeVarBin((String[]) data, ctx.arena()); + String[] strings = (String[]) data; + if (containsNull(strings)) { + if (!dtype.nullable()) { + throw new VortexException(EncodingId.VORTEX_ZSTD, + "non-nullable " + dtype + " contains null"); + } + return encodeNullableVarBin(strings, ctx); + } + return encodeVarBin(strings, ctx.arena()); } throw new VortexException(EncodingId.VORTEX_ZSTD, "unsupported dtype: " + dtype); } @@ -70,6 +93,112 @@ private static EncodeResult buildResult(MemorySegment raw, long n, Arena arena) return new EncodeResult(root, List.of(compressed), null, null); } + private static EncodeResult encodeNullablePrimitive(DType.Primitive dt, NullableData nd, EncodeContext ctx) { + Arena arena = ctx.arena(); + int byteWidth = dt.ptype().byteSize(); + boolean[] validity = nd.validity(); + // Strip null positions: only valid values reach the compressed payload (mirrors the Rust + // reference). The decoder scatters them back over the validity mask carried by child[0]. 
+ MemorySegment full = primitiveToLeBytes(dt.ptype(), nd.values(), arena); + MemorySegment packed = packValidBytes(full, validity, byteWidth, arena); + return buildNullableResult(packed, countValid(validity), validity, ctx); + } + + private static EncodeResult encodeNullableVarBin(String[] strings, EncodeContext ctx) { + boolean[] validity = validityOf(strings); + String[] valid = stripNulls(strings); + MemorySegment packed = buildLengthPrefixed(valid, ctx.arena()); + return buildNullableResult(packed, valid.length, validity, ctx); + } + + private static EncodeResult buildNullableResult( + MemorySegment raw, long nValues, boolean[] validity, EncodeContext ctx) { + // Zero-copy: compress the arena-native packed segment into another arena segment. + MemorySegment compressed; + try (ZstdCompressCtx cctx = new ZstdCompressCtx()) { + compressed = cctx.compress(ctx.arena(), raw); + } + byte[] meta = new ProtoZstdMetadata( + 0, + List.of(new ProtoZstdFrameMetadata(raw.byteSize(), nValues)) + ).encode(); + + EncodeResult validityResult = new BoolEncodingEncoder().encode(DType.BOOL, validity, ctx); + // The frame payload owns buffer[0]; the validity child's buffers follow, so shift its + // buffer indices by one. 
+ EncodeNode validityNode = EncodeNode.remapBufferIndices(validityResult.rootNode(), 1); + + List<MemorySegment> buffers = new ArrayList<>(1 + validityResult.buffers().size()); + buffers.add(compressed); + buffers.addAll(validityResult.buffers()); + + EncodeNode root = new EncodeNode(EncodingId.VORTEX_ZSTD, MemorySegment.ofArray(meta), + new EncodeNode[]{validityNode}, new int[]{0}); + return new EncodeResult(root, buffers, null, null); + } + + private static MemorySegment packValidBytes( + MemorySegment full, boolean[] validity, int byteWidth, Arena arena) { + long validBytes = (long) countValid(validity) * byteWidth; + MemorySegment packed = arena.allocate(Math.max(validBytes, 1), byteWidth); + long pos = 0; + for (int i = 0; i < validity.length; i++) { + if (validity[i]) { + MemorySegment.copy(full, (long) i * byteWidth, packed, pos, byteWidth); + pos += byteWidth; + } + } + return packed.asSlice(0, validBytes); + } + + private static int countValid(boolean[] validity) { + int count = 0; + for (boolean valid : validity) { + if (valid) { + count++; + } + } + return count; + } + + private static boolean containsNull(String[] strings) { + for (String s : strings) { + if (s == null) { + return true; + } + } + return false; + } + + private static boolean[] validityOf(String[] strings) { + boolean[] validity = new boolean[strings.length]; + for (int i = 0; i < strings.length; i++) { + validity[i] = strings[i] != null; + } + return validity; + } + + private static String[] stripNulls(String[] strings) { + String[] valid = new String[countNonNull(strings)]; + int j = 0; + for (String s : strings) { + if (s != null) { + valid[j++] = s; + } + } + return valid; + } + + private static int countNonNull(String[] strings) { + int count = 0; + for (String s : strings) { + if (s != null) { + count++; + } + } + return count; + } + + private static MemorySegment primitiveToLeBytes(PType ptype, Object data, Arena arena) { return switch (ptype) { case I8, U8 -> { diff --git 
a/writer/src/test/java/io/github/dfa1/vortex/writer/encode/ZstdEncodingEncoderTest.java b/writer/src/test/java/io/github/dfa1/vortex/writer/encode/ZstdEncodingEncoderTest.java index 591fb8ef..fbfc741d 100644 --- a/writer/src/test/java/io/github/dfa1/vortex/writer/encode/ZstdEncodingEncoderTest.java +++ b/writer/src/test/java/io/github/dfa1/vortex/writer/encode/ZstdEncodingEncoderTest.java @@ -86,6 +86,112 @@ void encode_emptyArray_roundTrips() { assertThat(decoded.length()).isZero(); } + @Test + void encode_nullablePrimitive_roundTrips() { + // Given — nulls at positions 1 and 3; the storage array holds zero placeholders there, + // validity marks the real rows. Only valid values must reach the compressed payload, + // so the decoder can scatter them back over the validity mask carried by child[0]. + int[] storage = {10, 0, 30, 0}; + boolean[] validity = {true, false, true, false}; + DType i32Nullable = new DType.Primitive(PType.I32, true); + NullableData data = new NullableData(storage, validity); + + // When + EncodeResult result = ENCODER.encode(i32Nullable, data, EncodeTestHelper.testCtx()); + DecodeContext ctx = DecodeTestHelper.toDecodeContext( + result, validity.length, i32Nullable, TestRegistry.ofDecoders(new BoolEncodingDecoder())); + MaskedArray decoded = (MaskedArray) DECODER.decode(ctx); + + // Then + assertThat(decoded.length()).isEqualTo(4); + assertThat(decoded.isValid(0)).isTrue(); + assertThat(decoded.isValid(1)).isFalse(); + assertThat(decoded.isValid(2)).isTrue(); + assertThat(decoded.isValid(3)).isFalse(); + IntArray child = (IntArray) decoded.inner(); + assertThat(child.getInt(0)).isEqualTo(10); + assertThat(child.getInt(2)).isEqualTo(30); + } + + @Test + void encode_nullableUtf8_roundTrips() { + // Given — a String[] carrying nulls; the encoder must strip them, compress only the + // valid strings, and emit the validity bitmap as child[0]. 
+ String[] data = {"hello", null, "world", null}; + DType utf8Nullable = new DType.Utf8(true); + + // When + EncodeResult result = ENCODER.encode(utf8Nullable, data, EncodeTestHelper.testCtx()); + DecodeContext ctx = DecodeTestHelper.toDecodeContext( + result, data.length, utf8Nullable, TestRegistry.ofDecoders(new BoolEncodingDecoder())); + MaskedArray decoded = (MaskedArray) DECODER.decode(ctx); + + // Then + assertThat(decoded.length()).isEqualTo(4); + assertThat(decoded.isValid(0)).isTrue(); + assertThat(decoded.isValid(1)).isFalse(); + assertThat(decoded.isValid(2)).isTrue(); + assertThat(decoded.isValid(3)).isFalse(); + VarBinArray child = (VarBinArray) decoded.inner(); + assertThat(child.getString(0)).isEqualTo("hello"); + assertThat(child.getString(2)).isEqualTo("world"); + } + + @Test + void encode_allNullPrimitive_roundTrips() { + // Given — every row null: zero valid values reach the payload, so the compressed frame + // is built from a 0-byte slice. Guards the empty-payload corner of packValidBytes / + // zstd compress-empty, and the all-false validity bitmap. + int[] storage = {0, 0, 0}; + boolean[] validity = {false, false, false}; + DType i32Nullable = new DType.Primitive(PType.I32, true); + NullableData data = new NullableData(storage, validity); + + // When + EncodeResult result = ENCODER.encode(i32Nullable, data, EncodeTestHelper.testCtx()); + DecodeContext ctx = DecodeTestHelper.toDecodeContext( + result, validity.length, i32Nullable, TestRegistry.ofDecoders(new BoolEncodingDecoder())); + MaskedArray decoded = (MaskedArray) DECODER.decode(ctx); + + // Then + assertThat(decoded.length()).isEqualTo(3); + assertThat(decoded.isValid(0)).isFalse(); + assertThat(decoded.isValid(1)).isFalse(); + assertThat(decoded.isValid(2)).isFalse(); + } + + @Test + void encode_allNullUtf8_roundTrips() { + // Given — every string null: stripNulls yields an empty array, so the length-prefixed + // payload is 0 bytes. 
Guards the empty-payload corner of the nullable varbin path. + String[] data = {null, null, null}; + DType utf8Nullable = new DType.Utf8(true); + + // When + EncodeResult result = ENCODER.encode(utf8Nullable, data, EncodeTestHelper.testCtx()); + DecodeContext ctx = DecodeTestHelper.toDecodeContext( + result, data.length, utf8Nullable, TestRegistry.ofDecoders(new BoolEncodingDecoder())); + MaskedArray decoded = (MaskedArray) DECODER.decode(ctx); + + // Then + assertThat(decoded.length()).isEqualTo(3); + assertThat(decoded.isValid(0)).isFalse(); + assertThat(decoded.isValid(1)).isFalse(); + assertThat(decoded.isValid(2)).isFalse(); + } + + @Test + void encode_nonNullableUtf8WithNull_throwsVortexException() { + // Given — a non-nullable Utf8 dtype whose data carries a stray null. The encoder must + // reject it rather than silently emit a nullable layout the dtype does not declare. + String[] data = {"a", null, "c"}; + DType utf8 = new DType.Utf8(false); + + // When / Then + assertThatThrownBy(() -> ENCODER.encode(utf8, data, EncodeTestHelper.testCtx())) + .isInstanceOf(VortexException.class); + } + @Test void encode_unsupportedDtype_throwsVortexException() { assertThatThrownBy(() -> ENCODER.encode(DType.NULL, null, EncodeTestHelper.testCtx()))