diff --git a/writer/src/main/java/io/github/dfa1/vortex/writer/VortexWriter.java b/writer/src/main/java/io/github/dfa1/vortex/writer/VortexWriter.java index 9a5dd840..ce72f45c 100644 --- a/writer/src/main/java/io/github/dfa1/vortex/writer/VortexWriter.java +++ b/writer/src/main/java/io/github/dfa1/vortex/writer/VortexWriter.java @@ -11,7 +11,11 @@ import io.github.dfa1.vortex.encoding.EncodingId; import io.github.dfa1.vortex.writer.encode.EncodeContext; import io.github.dfa1.vortex.writer.encode.EncodeNode; +import io.github.dfa1.vortex.proto.ScalarValue; import io.github.dfa1.vortex.writer.encode.EncodeResult; +import io.github.dfa1.vortex.writer.encode.NullableData; +import io.github.dfa1.vortex.writer.encode.StructData; +import io.github.dfa1.vortex.writer.encode.StructEncodingEncoder; import io.github.dfa1.vortex.writer.encode.EncodingEncoder; import io.github.dfa1.vortex.writer.encode.AlpEncodingEncoder; import io.github.dfa1.vortex.writer.encode.BitpackedEncodingEncoder; @@ -75,6 +79,11 @@ public final class VortexWriter implements Closeable { private static final int LAYOUT_CHUNKED = 1; private static final int LAYOUT_STRUCT = 2; private static final int LAYOUT_DICT = 3; + private static final int LAYOUT_ZONED = 4; + + // Stat ordinals in the Rust `Stat` enum (see ZonedStatsSchema). v1 emits MAX + MIN only. + private static final int STAT_MAX = 3; + private static final int STAT_MIN = 4; // Columns with global cardinality below this threshold are dict-encoded across all chunks. // Kept low: global dict hurts high-cardinality F64 columns (ALP codes beat U16 dict codes). @@ -107,6 +116,12 @@ public final class VortexWriter implements Closeable { private final Map dictColRefs = new LinkedHashMap<>(); private boolean firstChunkSeen = false; + // Per-column zone-maps, populated by flushZoneMaps() in close() when enableZoneMaps is set. 
+ private final Map zoneMaps = new LinkedHashMap<>(); + // Stats (ScalarValue bytes) of the most recently written segment, captured for ChunkRef. + private byte[] lastStatsMin; + private byte[] lastStatsMax; + private VortexWriter( WritableByteChannel channel, DType.Struct schema, WriteOptions options, List encodings ) { @@ -451,7 +466,7 @@ public void writeChunk(Map columns) throws IOException { } else { long rowCount = arrayLength(data); int segIdx = writeSegment(colDtype, data); - colChunks.get(colName).add(new ChunkRef(segIdx, rowCount)); + colChunks.get(colName).add(new ChunkRef(segIdx, rowCount, lastStatsMin, lastStatsMax)); } } firstChunkSeen = true; @@ -460,6 +475,7 @@ public void writeChunk(Map columns) throws IOException { @Override public void close() throws IOException { flushDictColumns(); + flushZoneMaps(); ByteBuffer footerBuf = buildFooter(); long footerOff = bytesWritten; write(footerBuf); @@ -550,6 +566,8 @@ private int writeSegment(DType dtype, Object data, EncodingEncoder encodingOverr bytesWritten += 4; segs.add(new SegRef(offset, bytesWritten - offset)); + lastStatsMin = result.statsMin(); + lastStatsMax = result.statsMax(); return segIdx; } } @@ -656,6 +674,159 @@ private int buildArrayNodeFlatBuffer(FlatBufferBuilder fbb, EncodeNode node, int fbb, encIdx, metaOff, childVec, bufIdxVec, statsOff); } + /// Emits a per-column `vortex.stats` zone-map (one zone per chunk) for every fixed-width + /// primitive column whose chunks all carry min/max stats. Must run before [#buildFooter] + /// so the stats-table segments are present in `segment_specs`. 
+    private void flushZoneMaps() throws IOException {
+        if (!options.enableZoneMaps()) {
+            return;
+        }
+        for (Map.Entry<String, List<ChunkRef>> e : colChunks.entrySet()) {
+            String colName = e.getKey();
+            List<ChunkRef> chunks = e.getValue();
+            if (chunks.isEmpty()) {
+                continue;
+            }
+            DType colDtype = schema.fieldTypes().get(schema.fieldNames().indexOf(colName));
+            if (!(colDtype instanceof DType.Primitive prim) || !isZoneMappable(prim.ptype())) {
+                continue;
+            }
+            if (!chunks.stream().allMatch(ChunkRef::hasStats)) {
+                continue;
+            }
+            // The zoned metadata records a single zone length (options.chunkSize()) while the
+            // stats table below holds one row per written chunk. A ChunkRef's row count is the
+            // length of the array handed to writeChunk, so the two only describe the same row
+            // ranges when the chunks actually fall on chunkSize boundaries. Columns that do
+            // not are left without a zone-map rather than given a misleading one.
+            if (!chunksAlignToZones(chunks, options.chunkSize())) {
+                continue;
+            }
+            int nZones = chunks.size();
+            // Every chunk passed hasStats() above, so every min/max entry is marked valid.
+            boolean[] allValid = new boolean[nZones];
+            java.util.Arrays.fill(allValid, true);
+            // Left at the default (all false): no stat is flagged as truncated.
+            boolean[] notTruncated = new boolean[nZones];
+            DType nullablePrim = new DType.Primitive(prim.ptype(), true);
+            // Field order mirrors ZonedStatsSchema.statsTableDtype for present stats MAX(3), MIN(4):
+            // [max, max_is_truncated, min, min_is_truncated].
+            DType.Struct statsDtype = new DType.Struct(
+                    List.of("max", "max_is_truncated", "min", "min_is_truncated"),
+                    List.of(nullablePrim, new DType.Bool(false), nullablePrim, new DType.Bool(false)),
+                    false);
+            // Field data must be listed in the same order as statsDtype's field names.
+            StructData sd = new StructData(List.of(
+                    new NullableData(statColumn(prim.ptype(), chunks, true), allValid),
+                    notTruncated,
+                    new NullableData(statColumn(prim.ptype(), chunks, false), allValid.clone()),
+                    notTruncated.clone()));
+            int zonesSegIdx = writeSegment(statsDtype, sd, new StructEncodingEncoder());
+            zoneMaps.put(colName, new ZoneMapRef(zonesSegIdx, nZones, options.chunkSize()));
+        }
+    }
+
+    /// Returns `true` when "one zone per chunk" is consistent with a fixed zone length of
+    /// `zoneLen` rows: every chunk except the last holds exactly `zoneLen` rows and the last
+    /// holds between 1 and `zoneLen` rows. Returns `false` for an empty list or a
+    /// non-positive `zoneLen`.
+    private static boolean chunksAlignToZones(List<ChunkRef> chunks, long zoneLen) {
+        if (chunks.isEmpty() || zoneLen <= 0) {
+            return false;
+        }
+        int last = chunks.size() - 1;
+        for (int i = 0; i < last; i++) {
+            if (chunks.get(i).rowCount() != zoneLen) {
+                return false;
+            }
+        }
+        long tailRows = chunks.get(last).rowCount();
+        return tailRows > 0 && tailRows <= zoneLen;
+    }
+
+    /// Wraps a column's data layout in a `vortex.stats` (zoned) layout when a zone-map was
+    /// emitted for it; otherwise returns the data layout unchanged.
+    private int wrapZoneMap(FlatBufferBuilder fbb, String colName, int dataLayout, long colRows) {
+        ZoneMapRef zm = zoneMaps.get(colName);
+        if (zm == null) {
+            // No zone-map was recorded for this column: hand back the layout untouched.
+            return dataLayout;
+        }
+        // Flat child with one row per zone, backed by the stats-table segment.
+        int zonesSegV = Layout.createSegmentsVector(fbb, new long[]{zm.zonesSegIdx()});
+        int zonesFlat = Layout.createLayout(fbb, LAYOUT_FLAT, zm.nZones(), 0, 0, zonesSegV);
+        // Children are ordered [data, zones].
+        int childV = Layout.createChildrenVector(fbb, new int[]{dataLayout, zonesFlat});
+        int metaV = Layout.createMetadataVector(fbb, zonedMetadataBytes(zm.zoneLen()));
+        return Layout.createLayout(fbb, LAYOUT_ZONED, colRows, metaV, childV, 0);
+    }
+
+    /// `vortex.stats` metadata: `u32` zone length (LE) + a 1-byte stat bitset with the MAX and
+    /// MIN bits set (LSB-first), matching [io.github.dfa1.vortex.inspect] `ZonedStatsSchema`.
+    ///
+    /// Throws [IllegalArgumentException] when `zoneLen` is negative or larger than the
+    /// biggest `u32`, instead of silently writing a truncated length.
+    private static byte[] zonedMetadataBytes(long zoneLen) {
+        if (zoneLen < 0 || zoneLen > 0xFFFF_FFFFL) {
+            throw new IllegalArgumentException("zone length does not fit in u32: " + zoneLen);
+        }
+        byte[] meta = new byte[5];
+        // After the range check the int cast keeps exactly the 32 bits of the u32 value.
+        ByteBuffer.wrap(meta).order(ByteOrder.LITTLE_ENDIAN).putInt((int) zoneLen);
+        meta[4] = (byte) ((1 << STAT_MAX) | (1 << STAT_MIN));
+        return meta;
+    }
+
+    /// Whether a zone-map is emitted for columns of this primitive type: every integer width
+    /// plus `F32` and `F64`. `F16` is excluded; `statColumn` rejects it as well.
+    private static boolean isZoneMappable(PType ptype) {
+        return switch (ptype) {
+            case I8, I16, I32, I64, U8, U16, U32, U64, F32, F64 -> true;
+            case F16 -> false;
+        };
+    }
+
+    /// Builds the per-zone min (or max) values array in the storage shape the primitive encoder
+    /// expects, decoding each chunk's serialised [ScalarValue] stat.
+    private static Object statColumn(PType ptype, List<ChunkRef> chunks, boolean max) {
+        int n = chunks.size();
+        // Unsigned types share the signed array of the same width; the narrowing casts keep
+        // the low-order bits of the decoded value.
+        // NOTE(review): assumes the primitive encoder reinterprets those bits by ptype — confirm.
+        return switch (ptype) {
+            case I8, U8 -> {
+                byte[] a = new byte[n];
+                for (int i = 0; i < n; i++) {
+                    a[i] = (byte) scalarLong(chunks.get(i), max);
+                }
+                yield a;
+            }
+            case I16, U16 -> {
+                short[] a = new short[n];
+                for (int i = 0; i < n; i++) {
+                    a[i] = (short) scalarLong(chunks.get(i), max);
+                }
+                yield a;
+            }
+            case I32, U32 -> {
+                int[] a = new int[n];
+                for (int i = 0; i < n; i++) {
+                    a[i] = (int) scalarLong(chunks.get(i), max);
+                }
+                yield a;
+            }
+            case I64, U64 -> {
+                long[] a = new long[n];
+                for (int i = 0; i < n; i++) {
+                    a[i] = scalarLong(chunks.get(i), max);
+                }
+                yield a;
+            }
+            case F32 -> {
+                float[] a = new float[n];
+                for (int i = 0; i < n; i++) {
+                    a[i] = (float) scalarDouble(chunks.get(i), max);
+                }
+                yield a;
+            }
+            case F64 -> {
+                double[] a = new double[n];
+                for (int i = 0; i < n; i++) {
+                    a[i] = scalarDouble(chunks.get(i), max);
+                }
+                yield a;
+            }
+            case F16 -> throw new IllegalStateException("F16 is not zone-mappable");
+        };
+    }
+
+    /// Decodes the chunk's max (when `max` is true) or min stat as an integer, reading
+    /// `int64_value` if present and `uint64_value` otherwise. Throws [IllegalStateException]
+    /// when the scalar carries neither.
+    private static long scalarLong(ChunkRef cr, boolean max) {
+        ScalarValue sv = decodeScalar(max ? cr.statsMax() : cr.statsMin());
+        if (sv.int64_value() != null) {
+            return sv.int64_value();
+        }
+        if (sv.uint64_value() != null) {
+            return sv.uint64_value();
+        }
+        throw new IllegalStateException("expected integer scalar stat");
+    }
+
+    /// Decodes the chunk's max (when `max` is true) or min stat as a floating-point value,
+    /// reading `f64_value` if present and `f32_value` otherwise. Throws
+    /// [IllegalStateException] when the scalar carries neither.
+    private static double scalarDouble(ChunkRef cr, boolean max) {
+        ScalarValue sv = decodeScalar(max ? cr.statsMax() : cr.statsMin());
+        if (sv.f64_value() != null) {
+            return sv.f64_value();
+        }
+        if (sv.f32_value() != null) {
+            return sv.f32_value();
+        }
+        throw new IllegalStateException("expected float scalar stat");
+    }
+
+    /// Decodes a serialised [ScalarValue] from `bytes`, rethrowing the decoder's
+    /// [IOException] as [java.io.UncheckedIOException].
+    private static ScalarValue decodeScalar(byte[] bytes) {
+        MemorySegment seg = MemorySegment.ofArray(bytes);
+        try {
+            return ScalarValue.decode(seg, 0, seg.byteSize());
+        } catch (IOException ex) {
+            throw new java.io.UncheckedIOException(ex);
+        }
+    }
+
 private ByteBuffer buildFooter() { var fbb = new FlatBufferBuilder(512); @@ -675,7 +846,8 @@ private ByteBuffer buildFooter() { int ls1 = LayoutSpec.createLayoutSpec(fbb, fbb.createString("vortex.chunked")); int ls2 = LayoutSpec.createLayoutSpec(fbb, fbb.createString("vortex.struct")); int ls3 = LayoutSpec.createLayoutSpec(fbb, fbb.createString("vortex.dict")); - int lsv = Footer.createLayoutSpecsVector(fbb, new int[]{ls0, ls1, ls2, ls3}); + int ls4 = LayoutSpec.createLayoutSpec(fbb, fbb.createString("vortex.stats")); + int lsv = Footer.createLayoutSpecsVector(fbb, new int[]{ls0, ls1, ls2, ls3, ls4}); // segment_specs (inline struct vector — write in reverse order) Footer.startSegmentSpecsVector(fbb, segs.size()); @@ -716,7 +888,8 @@ private ByteBuffer buildLayout() { colRows += cr.rowCount(); } int childV = Layout.createChildrenVector(fbb, flats); - colLayouts[c] = Layout.createLayout(fbb, LAYOUT_CHUNKED, colRows, 0, childV, 0); + int dataChunked = Layout.createLayout(fbb, LAYOUT_CHUNKED, colRows, 0, childV, 0); + colLayouts[c] = wrapZoneMap(fbb, colName, dataChunked, colRows); if (totalRows == 0) { totalRows = colRows; } @@ -816,7 +989,7 @@ private void writeGlobalDictColumn(String colName, DType.Primitive dtype, List codesSegIdxes, diff --git a/writer/src/main/java/io/github/dfa1/vortex/writer/encode/StructEncodingEncoder.java b/writer/src/main/java/io/github/dfa1/vortex/writer/encode/StructEncodingEncoder.java index 828e6bf4..a13b54f5 100644 --- 
a/writer/src/main/java/io/github/dfa1/vortex/writer/encode/StructEncodingEncoder.java +++ b/writer/src/main/java/io/github/dfa1/vortex/writer/encode/StructEncodingEncoder.java @@ -49,7 +49,15 @@ public EncodeResult encode(DType dtype, Object data, EncodeContext ctx) { EncodeNode[] children = new EncodeNode[fields.size()]; for (int i = 0; i < fields.size(); i++) { DType fieldDtype = fieldTypes.get(i); - EncodeResult fieldResult = findEncoding(fieldDtype).encode(fieldDtype, fields.get(i), ctx); + Object fieldData = fields.get(i); + // Nullable fields carry a NullableData(values, validity) pair; route them through the + // masked encoder (values + validity bitmap), mirroring VortexWriter.writeSegment. The + // plain fallbacks only handle dense arrays. + EncodingEncoder fieldEncoder = + (fieldData instanceof NullableData && !(fieldDtype instanceof DType.Extension)) + ? new MaskedEncodingEncoder() + : findEncoding(fieldDtype); + EncodeResult fieldResult = fieldEncoder.encode(fieldDtype, fieldData, ctx); int bufOffset = allBuffers.size(); children[i] = EncodeNode.remapBufferIndices(fieldResult.rootNode(), bufOffset); allBuffers.addAll(fieldResult.buffers()); diff --git a/writer/src/test/java/io/github/dfa1/vortex/writer/WriterZoneMapTest.java b/writer/src/test/java/io/github/dfa1/vortex/writer/WriterZoneMapTest.java new file mode 100644 index 00000000..f8822ec2 --- /dev/null +++ b/writer/src/test/java/io/github/dfa1/vortex/writer/WriterZoneMapTest.java @@ -0,0 +1,123 @@ +package io.github.dfa1.vortex.writer; + +import io.github.dfa1.vortex.core.DType; +import io.github.dfa1.vortex.core.PType; +import io.github.dfa1.vortex.reader.Layout; +import io.github.dfa1.vortex.reader.SegmentSpec; +import io.github.dfa1.vortex.reader.VortexReader; +import io.github.dfa1.vortex.reader.array.LongArray; +import io.github.dfa1.vortex.reader.array.MaskedArray; +import io.github.dfa1.vortex.reader.array.StructArray; +import org.junit.jupiter.api.Test; +import 
org.junit.jupiter.api.io.TempDir; + +import java.io.IOException; +import java.lang.foreign.Arena; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.nio.channels.FileChannel; +import java.nio.file.Path; +import java.nio.file.StandardOpenOption; +import java.util.List; +import java.util.Map; + +import static org.assertj.core.api.Assertions.assertThat; + +/// Round-trip coverage for writer-side `vortex.stats` (zone-map) emission. +class WriterZoneMapTest { + + private static final DType.Struct SCHEMA = new DType.Struct( + List.of("v"), List.of(new DType.Primitive(PType.I64, false)), false); + + // Three zones of four rows: [0..3], [4..7], [8..11]. + private static Path write(Path tmp, boolean zoneMaps) throws IOException { + WriteOptions opts = new WriteOptions(4, zoneMaps, 0.90, 0, true, false); + Path file = tmp.resolve("zoned-" + zoneMaps + ".vtx"); + try (var ch = FileChannel.open(file, StandardOpenOption.CREATE, StandardOpenOption.WRITE); + var sut = VortexWriter.create(ch, SCHEMA, opts)) { + for (int z = 0; z < 3; z++) { + long[] v = new long[4]; + for (int i = 0; i < 4; i++) { + v[i] = z * 4L + i; + } + sut.writeChunk(Map.of("v", v)); + } + } + return file; + } + + @Test + void enableZoneMaps_wrapsColumnInZonedLayoutWithMetadata(@TempDir Path tmp) throws IOException { + // Given a file written with zone maps on + Path file = write(tmp, true); + + // When + try (VortexReader reader = VortexReader.open(file)) { + Layout column = reader.layout().children().get(0); + + // Then the column is a vortex.stats layout: [data, zones], zone_len=4, MAX+MIN bitset + assertThat(column.isZoned()).isTrue(); + assertThat(column.children()).hasSize(2); + ByteBuffer meta = column.metadata().duplicate().order(ByteOrder.LITTLE_ENDIAN); + assertThat(meta.getInt(meta.position())).isEqualTo(4); // zone_len + assertThat(meta.get(meta.position() + 4)).isEqualTo((byte) 0x18); // bits 3(MAX)+4(MIN) + } + } + + @Test + void disableZoneMaps_leavesColumnUnwrapped(@TempDir 
Path tmp) throws IOException { + // Given zone maps off + Path file = write(tmp, false); + + // When / Then the column is the plain chunked data layout + try (VortexReader reader = VortexReader.open(file)) { + assertThat(reader.layout().children().get(0).isZoned()).isFalse(); + } + } + + @Test + void zoneMaps_dataStillRoundTrips(@TempDir Path tmp) throws IOException { + // Given a zone-mapped file + Path file = write(tmp, true); + + // When the data is scanned (the zoned wrapper is transparent for reads) + long[] all; + try (VortexReader reader = VortexReader.open(file)) { + all = VortexReads.readAllLongs(reader, "v"); + } + + // Then every row is intact + assertThat(all).containsExactly(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11); + } + + @Test + void zoneMaps_statsPayloadDecodesPerZoneMinMax(@TempDir Path tmp) throws IOException { + // Given a zone-mapped file (3 zones of 4 rows) + Path file = write(tmp, true); + + // When the per-zone stats table is decoded the way the inspector decodes Rust files + try (VortexReader reader = VortexReader.open(file)) { + Layout zonesFlat = reader.layout().children().get(0).children().get(1); + SegmentSpec spec = reader.footer().segmentSpecs().get(zonesFlat.segments().getFirst()); + DType nullableI64 = new DType.Primitive(PType.I64, true); + DType.Struct statsDtype = new DType.Struct( + List.of("max", "max_is_truncated", "min", "min_is_truncated"), + List.of(nullableI64, new DType.Bool(false), nullableI64, new DType.Bool(false)), + false); + + try (Arena arena = Arena.ofConfined()) { + StructArray stats = (StructArray) reader.decodeFlatSegment(spec, statsDtype, 3, arena); + LongArray max = (LongArray) ((MaskedArray) stats.field("max")).inner(); + LongArray min = (LongArray) ((MaskedArray) stats.field("min")).inner(); + + // Then min/max per zone match the source data + assertThat(min.getLong(0)).isZero(); + assertThat(max.getLong(0)).isEqualTo(3); + assertThat(min.getLong(1)).isEqualTo(4); + assertThat(max.getLong(1)).isEqualTo(7); + 
assertThat(min.getLong(2)).isEqualTo(8); + assertThat(max.getLong(2)).isEqualTo(11); + } + } + } +}