Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Added

- `DType.isUnsigned()` — `true` for the unsigned integer primitives (`U8`–`U64`), `false` otherwise. ([#159](https://github.com/dfa1/vortex-java/issues/159))
- The `vortex.zstd` encoder now writes nullable columns (primitive and utf8/binary): null positions are stripped before compression and validity is emitted as a Bool child, matching the Rust reference layout. When `vortex.zstd` is the configured encoder, nullable primitive columns route to it directly instead of being wrapped in `vortex.masked`.

### Changed

Expand Down
5 changes: 0 additions & 5 deletions TODO.md
Original file line number Diff line number Diff line change
Expand Up @@ -108,8 +108,3 @@ See [docs/compatibility.md](docs/compatibility.md) for the full encoding support
boundaries (`valuesPerFrame * byteWidth`), compress each slice independently, emit one `ZstdFrameMetadata` per frame.
Enables partial decompression during slice scans.

- [ ] **Nullable arrays (encode)** — `ZstdEncoding.Encoder` has no null handling.
Fix: accept nullable input (e.g. `Integer[]` or a validity mask alongside the data array). Strip null positions before
compression. Encode the validity bitmap as a Bool child (child[0]) in the `EncodeNode`. Mirrors what Rust does: only
valid values go into the compressed payload.

Original file line number Diff line number Diff line change
Expand Up @@ -1174,6 +1174,78 @@ void javaWriter_rustReader_zstd_utf8(@TempDir Path tmp) throws IOException {
assertThat(decoded).containsExactly(data);
}

@Test
void javaWriter_rustReader_zstd_nullableI64(@TempDir Path tmp) throws IOException {
// Given — nullable primitive I64 written with ZstdEncoding. A configured zstd encoder
// declares acceptsNullable, so the writer routes the NullableData straight to it instead
// of masked-wrapping: zstd strips nulls before compression and emits validity as child[0].
// Verifies that nullable-zstd layout against the Rust reader.
Path file = tmp.resolve("java_zstd_nullable_i64.vtx");
DType.Struct schema = new DType.Struct(List.of("v"), List.of(new DType.Primitive(PType.I64, true)), false);
Long[] data = {10L, null, 30L, null, 50L, 60L};
try (var ch = FileChannel.open(file, StandardOpenOption.CREATE, StandardOpenOption.WRITE);
var sut = VortexWriter.create(ch, schema, WriteOptions.defaults(),
List.of(new ZstdEncodingEncoder()))) {
// When
sut.writeChunk(Map.of("v", data));
}

// Then — Rust reads a nullable BigInt vector; null positions survive, values round-trip
String uri = file.toAbsolutePath().toUri().toString();
DataSource ds = DataSource.open(SESSION, uri);
Scan scan = ds.scan(ScanOptions.of());
var values = new ArrayList<Long>();
while (scan.hasNext()) {
Partition partition = scan.next();
try (ArrowReader reader = partition.scanArrow(ALLOCATOR)) {
while (reader.loadNextBatch()) {
VectorSchemaRoot root = reader.getVectorSchemaRoot();
BigIntVector vec = (BigIntVector) root.getVector("v");
for (int i = 0; i < root.getRowCount(); i++) {
values.add(vec.isNull(i) ? null : vec.get(i));
}
}
}
}
assertThat(values).containsExactly(10L, null, 30L, null, 50L, 60L);
}

@Test
void javaWriter_rustReader_zstd_nullableUtf8(@TempDir Path tmp) throws IOException {
// Given — nullable Utf8 written with ZstdEncoding. Nullable string columns reach the zstd
// encoder directly (a String[] carrying nulls is not the NullableData shape that
// writeSegment masked-wraps), so this exercises the encoder's nullable varbin path against
// the Rust reader: nulls stripped before compression, validity emitted as child[0].
Path file = tmp.resolve("java_zstd_nullable_utf8.vtx");
DType.Struct schema = new DType.Struct(List.of("s"), List.of(new DType.Utf8(true)), false);
String[] data = {"hello", null, "world", null};
try (var ch = FileChannel.open(file, StandardOpenOption.CREATE, StandardOpenOption.WRITE);
var sut = VortexWriter.create(ch, schema, WriteOptions.defaults(),
List.of(new ZstdEncodingEncoder()))) {
// When
sut.writeChunk(Map.of("s", data));
}

// Then — Rust reads the nullable Utf8 vector; null positions survive, values round-trip
String uri = file.toAbsolutePath().toUri().toString();
DataSource ds = DataSource.open(SESSION, uri);
Scan scan = ds.scan(ScanOptions.of());
var values = new ArrayList<String>();
while (scan.hasNext()) {
Partition partition = scan.next();
try (ArrowReader reader = partition.scanArrow(ALLOCATOR)) {
while (reader.loadNextBatch()) {
VectorSchemaRoot root = reader.getVectorSchemaRoot();
VarCharVector vec = (VarCharVector) root.getVector("s");
for (int i = 0; i < root.getRowCount(); i++) {
values.add(vec.isNull(i) ? null : vec.getObject(i).toString());
}
}
}
}
assertThat(values).containsExactly("hello", null, "world", null);
}

@Test
void javaWriter_rustReader_list_i64(@TempDir Path tmp) throws IOException {
// Given — ListEncoding: exercises elements_len + offset_ptype proto fields (byte-order risk)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -520,10 +520,13 @@ private int writeSegment(DType dtype, Object data, EncodingEncoder encodingOverr
// Non-extension nullable columns (Primitive, Utf8) wrap with MaskedEncodingEncoder here.
// FbsExtension columns route through ExtEncodingEncoder.encode which itself delegates to
// MaskedEncodingEncoder when its storage data is NullableData — handled inside ExtEncoding.
// Exception: a configured encoder that embeds validity itself (acceptsNullable, e.g.
// vortex.zstd) takes the NullableData straight, so no masked wrapper is inserted.
if (encodingOverride == null
&& data instanceof io.github.dfa1.vortex.writer.encode.NullableData
&& !(dtype instanceof DType.Extension)) {
encodingOverride = new MaskedEncodingEncoder();
EncodingEncoder nullableCapable = nullableCapableEncoder(dtype);
encodingOverride = nullableCapable != null ? nullableCapable : new MaskedEncodingEncoder();
}
// Variant columns bypass the cascade: the container encoding is structural, not a
// compressible primitive codec, so route straight to the dedicated encoder.
Expand Down Expand Up @@ -595,6 +598,22 @@ private EncodingEncoder findEncoder(DType dtype) {
throw new UnsupportedOperationException("no encoder for dtype: " + dtype);
}

/// Returns the configured encoder for `dtype` that consumes a [NullableData] carrier directly
/// (embedding its own validity), or `null` to fall back to `vortex.masked` wrapping. Only the
/// first-match flat path is considered; with cascading enabled the compressor owns selection,
/// so nullable columns keep the masked layout.
private EncodingEncoder nullableCapableEncoder(DType dtype) {
if (options.allowedCascading() > 0) {
return null;
}
for (EncodingEncoder c : encodings) {
if (c.accepts(dtype) && c.acceptsNullable(dtype)) {
return c;
}
}
return null;
}

private void write(MemorySegment seg) throws IOException {
ByteBuffer buf = seg.asByteBuffer();
while (buf.hasRemaining()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,20 @@ public interface EncodingEncoder {
/// @return `true` if this encoding can encode arrays of `dtype`
boolean accepts(DType dtype);

/// Whether this encoder consumes a [NullableData] carrier directly for `dtype`, embedding
/// the validity bitmap itself (as a Bool child) rather than relying on an enclosing
/// `vortex.masked` wrapper.
///
/// The writer uses this to route nullable columns: an encoder returning `true` receives the
/// [NullableData] straight, otherwise the column is masked-wrapped over a non-nullable child.
/// Default `false` — most encoders only handle dense data.
///
/// @param dtype the dtype to test
/// @return `true` if this encoding can encode nullable `dtype` from a [NullableData] carrier
default boolean acceptsNullable(DType dtype) {
return false;
}

/// Encodes `data` to bytes using the provided arena for output buffer allocation.
///
/// @param dtype logical type of the data
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import java.lang.foreign.MemorySegment;
import java.lang.foreign.ValueLayout;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.List;

/// Write-only encoder for `vortex.zstd`.
Expand All @@ -32,13 +33,35 @@ public boolean accepts(DType dtype) {
return dtype instanceof DType.Primitive || dtype instanceof DType.Utf8 || dtype instanceof DType.Binary;
}

@Override
public boolean acceptsNullable(DType dtype) {
// Nullable utf8/binary arrive as a String[] carrying nulls (handled in encode), not a
// NullableData carrier; only primitive nullable columns are routed here as NullableData.
return dtype instanceof DType.Primitive;
}

@Override
public EncodeResult encode(DType dtype, Object data, EncodeContext ctx) {
if (data instanceof NullableData nd) {
if (!(dtype instanceof DType.Primitive dt)) {
throw new VortexException(EncodingId.VORTEX_ZSTD,
"NullableData is only supported for primitive dtypes, got " + dtype);
}
return encodeNullablePrimitive(dt, nd, ctx);
}
if (dtype instanceof DType.Primitive dt) {
return encodePrimitive(dt, data, ctx.arena());
}
if (dtype instanceof DType.Utf8 || dtype instanceof DType.Binary) {
return encodeVarBin((String[]) data, ctx.arena());
String[] strings = (String[]) data;
if (containsNull(strings)) {
if (!dtype.nullable()) {
throw new VortexException(EncodingId.VORTEX_ZSTD,
"non-nullable " + dtype + " contains null");
}
return encodeNullableVarBin(strings, ctx);
}
return encodeVarBin(strings, ctx.arena());
}
throw new VortexException(EncodingId.VORTEX_ZSTD, "unsupported dtype: " + dtype);
}
Expand Down Expand Up @@ -70,6 +93,112 @@ private static EncodeResult buildResult(MemorySegment raw, long n, Arena arena)
return new EncodeResult(root, List.of(compressed), null, null);
}

private static EncodeResult encodeNullablePrimitive(DType.Primitive dt, NullableData nd, EncodeContext ctx) {
Arena arena = ctx.arena();
int byteWidth = dt.ptype().byteSize();
boolean[] validity = nd.validity();
// Strip null positions: only valid values reach the compressed payload (mirrors the Rust
// reference). The decoder scatters them back over the validity mask carried by child[0].
MemorySegment full = primitiveToLeBytes(dt.ptype(), nd.values(), arena);
MemorySegment packed = packValidBytes(full, validity, byteWidth, arena);
return buildNullableResult(packed, countValid(validity), validity, ctx);
}

private static EncodeResult encodeNullableVarBin(String[] strings, EncodeContext ctx) {
boolean[] validity = validityOf(strings);
String[] valid = stripNulls(strings);
MemorySegment packed = buildLengthPrefixed(valid, ctx.arena());
return buildNullableResult(packed, valid.length, validity, ctx);
}

private static EncodeResult buildNullableResult(
MemorySegment raw, long nValues, boolean[] validity, EncodeContext ctx) {
// Zero-copy: compress the arena-native packed segment into another arena segment.
MemorySegment compressed;
try (ZstdCompressCtx cctx = new ZstdCompressCtx()) {
compressed = cctx.compress(ctx.arena(), raw);
}
byte[] meta = new ProtoZstdMetadata(
0,
List.of(new ProtoZstdFrameMetadata(raw.byteSize(), nValues))
).encode();

EncodeResult validityResult = new BoolEncodingEncoder().encode(DType.BOOL, validity, ctx);
// The frame payload owns buffer[0]; the validity child's buffers follow, so shift its
// buffer indices by one.
EncodeNode validityNode = EncodeNode.remapBufferIndices(validityResult.rootNode(), 1);

List<MemorySegment> buffers = new ArrayList<>(1 + validityResult.buffers().size());
buffers.add(compressed);
buffers.addAll(validityResult.buffers());

EncodeNode root = new EncodeNode(EncodingId.VORTEX_ZSTD, MemorySegment.ofArray(meta),
new EncodeNode[]{validityNode}, new int[]{0});
return new EncodeResult(root, buffers, null, null);
}

private static MemorySegment packValidBytes(
MemorySegment full, boolean[] validity, int byteWidth, Arena arena) {
long validBytes = (long) countValid(validity) * byteWidth;
MemorySegment packed = arena.allocate(Math.max(validBytes, 1), byteWidth);
long pos = 0;
for (int i = 0; i < validity.length; i++) {
if (validity[i]) {
MemorySegment.copy(full, (long) i * byteWidth, packed, pos, byteWidth);
pos += byteWidth;
}
}
return packed.asSlice(0, validBytes);
}

private static int countValid(boolean[] validity) {
int count = 0;
for (boolean valid : validity) {
if (valid) {
count++;
}
}
return count;
}

private static boolean containsNull(String[] strings) {
for (String s : strings) {
if (s == null) {
return true;
}
}
return false;
}

private static boolean[] validityOf(String[] strings) {
boolean[] validity = new boolean[strings.length];
for (int i = 0; i < strings.length; i++) {
validity[i] = strings[i] != null;
}
return validity;
}

private static String[] stripNulls(String[] strings) {
String[] valid = new String[countNonNull(strings)];
int j = 0;
for (String s : strings) {
if (s != null) {
valid[j++] = s;
}
}
return valid;
}

private static int countNonNull(String[] strings) {
int count = 0;
for (String s : strings) {
if (s != null) {
count++;
}
}
return count;
}

private static MemorySegment primitiveToLeBytes(PType ptype, Object data, Arena arena) {
return switch (ptype) {
case I8, U8 -> {
Expand Down
Loading
Loading