Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
193 changes: 187 additions & 6 deletions writer/src/main/java/io/github/dfa1/vortex/writer/VortexWriter.java
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,11 @@
import io.github.dfa1.vortex.encoding.EncodingId;
import io.github.dfa1.vortex.writer.encode.EncodeContext;
import io.github.dfa1.vortex.writer.encode.EncodeNode;
import io.github.dfa1.vortex.proto.ScalarValue;
import io.github.dfa1.vortex.writer.encode.EncodeResult;
import io.github.dfa1.vortex.writer.encode.NullableData;
import io.github.dfa1.vortex.writer.encode.StructData;
import io.github.dfa1.vortex.writer.encode.StructEncodingEncoder;
import io.github.dfa1.vortex.writer.encode.EncodingEncoder;
import io.github.dfa1.vortex.writer.encode.AlpEncodingEncoder;
import io.github.dfa1.vortex.writer.encode.BitpackedEncodingEncoder;
Expand Down Expand Up @@ -75,6 +79,11 @@ public final class VortexWriter implements Closeable {
private static final int LAYOUT_CHUNKED = 1;
private static final int LAYOUT_STRUCT = 2;
private static final int LAYOUT_DICT = 3;
private static final int LAYOUT_ZONED = 4;

// Stat ordinals in the Rust `Stat` enum (see ZonedStatsSchema). v1 emits MAX + MIN only.
private static final int STAT_MAX = 3;
private static final int STAT_MIN = 4;

// Columns with global cardinality below this threshold are dict-encoded across all chunks.
// Kept low: global dict hurts high-cardinality F64 columns (ALP codes beat U16 dict codes).
Expand Down Expand Up @@ -107,6 +116,12 @@ public final class VortexWriter implements Closeable {
private final Map<String, DictColRef> dictColRefs = new LinkedHashMap<>();
private boolean firstChunkSeen = false;

// Per-column zone-maps, populated by flushZoneMaps() in close() when enableZoneMaps is set.
private final Map<String, ZoneMapRef> zoneMaps = new LinkedHashMap<>();
// Stats (ScalarValue bytes) of the most recently written segment, captured for ChunkRef.
private byte[] lastStatsMin;
private byte[] lastStatsMax;

private VortexWriter(
WritableByteChannel channel, DType.Struct schema, WriteOptions options, List<EncodingEncoder> encodings
) {
Expand Down Expand Up @@ -451,7 +466,7 @@ public void writeChunk(Map<String, Object> columns) throws IOException {
} else {
long rowCount = arrayLength(data);
int segIdx = writeSegment(colDtype, data);
colChunks.get(colName).add(new ChunkRef(segIdx, rowCount));
colChunks.get(colName).add(new ChunkRef(segIdx, rowCount, lastStatsMin, lastStatsMax));
}
}
firstChunkSeen = true;
Expand All @@ -460,6 +475,7 @@ public void writeChunk(Map<String, Object> columns) throws IOException {
@Override
public void close() throws IOException {
flushDictColumns();
flushZoneMaps();
ByteBuffer footerBuf = buildFooter();
long footerOff = bytesWritten;
write(footerBuf);
Expand Down Expand Up @@ -550,6 +566,8 @@ private int writeSegment(DType dtype, Object data, EncodingEncoder encodingOverr
bytesWritten += 4;

segs.add(new SegRef(offset, bytesWritten - offset));
lastStatsMin = result.statsMin();
lastStatsMax = result.statsMax();
return segIdx;
}
}
Expand Down Expand Up @@ -656,6 +674,159 @@ private int buildArrayNodeFlatBuffer(FlatBufferBuilder fbb, EncodeNode node, int
fbb, encIdx, metaOff, childVec, bufIdxVec, statsOff);
}

/// Emits a per-column `vortex.stats` zone-map (one zone per chunk) for every fixed-width
/// primitive column whose chunks all carry min/max stats. Must run before [#buildFooter]
/// so the stats-table segments are present in `segment_specs`.
///
/// A column gets no zone-map (and its layout is therefore left unwrapped) when any of the
/// following holds: it has no chunks, it is not a zone-mappable primitive, at least one chunk
/// lacks min/max stats, or its chunk row counts do not line up with the zone length taken from
/// `options.chunkSize()`. The last check matters because chunk row counts come from the data
/// handed to the writer, while the zone-map advertises one fixed zone length; a table with one
/// row per chunk only describes the right row ranges when the two agree.
///
/// @throws IOException propagated from `writeSegment` while writing a stats table
private void flushZoneMaps() throws IOException {
    if (!options.enableZoneMaps()) {
        return;
    }
    // Every zone-map written below advertises this single, fixed zone length.
    long zoneLen = options.chunkSize();
    for (Map.Entry<String, List<ChunkRef>> e : colChunks.entrySet()) {
        String colName = e.getKey();
        List<ChunkRef> chunks = e.getValue();
        if (chunks.isEmpty()) {
            continue;
        }
        DType colDtype = schema.fieldTypes().get(schema.fieldNames().indexOf(colName));
        if (!(colDtype instanceof DType.Primitive prim) || !isZoneMappable(prim.ptype())) {
            continue;
        }
        if (!chunks.stream().allMatch(ChunkRef::hasStats)) {
            continue;
        }
        // Skipping is always safe: without a zone-map the column simply cannot be pruned.
        if (!chunksAlignWithZones(chunks, zoneLen)) {
            continue;
        }
        int nZones = chunks.size();
        // Every zone has a stat value, and none of them is truncated.
        boolean[] allValid = new boolean[nZones];
        java.util.Arrays.fill(allValid, true);
        boolean[] notTruncated = new boolean[nZones];
        DType nullablePrim = new DType.Primitive(prim.ptype(), true);
        // Field order mirrors ZonedStatsSchema.statsTableDtype for present stats MAX(3), MIN(4):
        // [max, max_is_truncated, min, min_is_truncated].
        DType.Struct statsDtype = new DType.Struct(
            List.of("max", "max_is_truncated", "min", "min_is_truncated"),
            List.of(nullablePrim, new DType.Bool(false), nullablePrim, new DType.Bool(false)),
            false);
        StructData sd = new StructData(List.of(
            new NullableData(statColumn(prim.ptype(), chunks, true), allValid),
            notTruncated,
            new NullableData(statColumn(prim.ptype(), chunks, false), allValid.clone()),
            notTruncated.clone()));
        int zonesSegIdx = writeSegment(statsDtype, sd, new StructEncodingEncoder());
        zoneMaps.put(colName, new ZoneMapRef(zonesSegIdx, nZones, zoneLen));
    }
}

/// Reports whether a "one zone per chunk" table is consistent with a fixed zone length of
/// `zoneLen` rows: every chunk except the last must hold exactly `zoneLen` rows, and the last
/// must hold between 1 and `zoneLen` rows.
private static boolean chunksAlignWithZones(List<ChunkRef> chunks, long zoneLen) {
    if (chunks.isEmpty() || zoneLen <= 0) {
        return false;
    }
    int lastIdx = chunks.size() - 1;
    for (int i = 0; i < lastIdx; i++) {
        if (chunks.get(i).rowCount() != zoneLen) {
            return false;
        }
    }
    long tailRows = chunks.get(lastIdx).rowCount();
    return tailRows > 0 && tailRows <= zoneLen;
}

/// Returns the layout to register for `colName`. When [#zoneMaps] holds an entry for the
/// column, the result is a `LAYOUT_ZONED` node with two children, the data layout followed by
/// a flat layout over the stats-table segment, and with the zone metadata attached. With no
/// entry, `dataLayout` is handed back untouched.
private int wrapZoneMap(FlatBufferBuilder fbb, String colName, int dataLayout, long colRows) {
    ZoneMapRef ref = zoneMaps.get(colName);
    if (ref != null) {
        // Builder calls stay in this exact order so the serialised bytes do not change.
        long[] statsSegments = {ref.zonesSegIdx()};
        int statsSegVec = Layout.createSegmentsVector(fbb, statsSegments);
        int statsTable = Layout.createLayout(fbb, LAYOUT_FLAT, ref.nZones(), 0, 0, statsSegVec);
        int[] kids = {dataLayout, statsTable};
        int kidsVec = Layout.createChildrenVector(fbb, kids);
        byte[] zonedMeta = zonedMetadataBytes(ref.zoneLen());
        int metaVec = Layout.createMetadataVector(fbb, zonedMeta);
        return Layout.createLayout(fbb, LAYOUT_ZONED, colRows, metaVec, kidsVec, 0);
    }
    return dataLayout;
}

/// `vortex.stats` metadata: `u32` zone length (LE) + a 1-byte stat bitset with the MAX and
/// MIN bits set (LSB-first), matching [io.github.dfa1.vortex.inspect] `ZonedStatsSchema`.
///
/// @param zoneLen rows per zone; must be representable as an unsigned 32-bit integer
/// @return a fresh 5-byte array: bytes 0-3 hold the zone length, byte 4 the stat bitset
/// @throws IllegalArgumentException if `zoneLen` is negative or greater than `2^32 - 1`,
///         because silently truncating it would write a wrong zone length into the file
private static byte[] zonedMetadataBytes(long zoneLen) {
    if (zoneLen < 0 || zoneLen > 0xFFFF_FFFFL) {
        throw new IllegalArgumentException("zone length does not fit in u32: " + zoneLen);
    }
    byte[] meta = new byte[5];
    // After the range check, the (int) cast keeps exactly the 32 bits of the u32 encoding.
    ByteBuffer.wrap(meta).order(ByteOrder.LITTLE_ENDIAN).putInt((int) zoneLen);
    meta[4] = (byte) ((1 << STAT_MAX) | (1 << STAT_MIN));
    return meta;
}

/// Reports whether a zone-map is emitted for columns of the given primitive type: all
/// integer widths plus `F32` and `F64` qualify, `F16` does not. The switch is kept
/// exhaustive without a default branch so that a new `PType` constant fails compilation here.
private static boolean isZoneMappable(PType ptype) {
    boolean mappable = switch (ptype) {
        case F16 -> false;
        case F32, F64 -> true;
        case I8, I16, I32, I64 -> true;
        case U8, U16, U32, U64 -> true;
    };
    return mappable;
}

/// Collects one stat value per chunk into a primitive array whose element type follows
/// `ptype`: `byte[]` for 8-bit, `short[]` for 16-bit, `int[]` for 32-bit and `long[]` for
/// 64-bit integers (signed and unsigned share the same storage), `float[]` for `F32` and
/// `double[]` for `F64`. Each value is decoded from the chunk's serialised [ScalarValue]
/// stat; `max` selects the maximum, otherwise the minimum is used.
///
/// Throws `IllegalStateException` for `F16`, which is not zone-mappable.
private static Object statColumn(PType ptype, List<ChunkRef> chunks, boolean max) {
    return switch (ptype) {
        case I8, U8 -> {
            byte[] out = new byte[chunks.size()];
            int zone = 0;
            for (ChunkRef chunk : chunks) {
                out[zone++] = (byte) scalarLong(chunk, max);
            }
            yield out;
        }
        case I16, U16 -> {
            short[] out = new short[chunks.size()];
            int zone = 0;
            for (ChunkRef chunk : chunks) {
                out[zone++] = (short) scalarLong(chunk, max);
            }
            yield out;
        }
        case I32, U32 -> chunks.stream()
            .mapToInt(chunk -> (int) scalarLong(chunk, max))
            .toArray();
        case I64, U64 -> chunks.stream()
            .mapToLong(chunk -> scalarLong(chunk, max))
            .toArray();
        case F32 -> {
            float[] out = new float[chunks.size()];
            int zone = 0;
            for (ChunkRef chunk : chunks) {
                out[zone++] = (float) scalarDouble(chunk, max);
            }
            yield out;
        }
        case F64 -> chunks.stream()
            .mapToDouble(chunk -> scalarDouble(chunk, max))
            .toArray();
        case F16 -> throw new IllegalStateException("F16 is not zone-mappable");
    };
}

/// Decodes the chunk's maximum (`max == true`) or minimum stat and returns it as a `long`.
/// The signed 64-bit field is preferred; the unsigned field is used only when the signed one
/// is absent, and its bits are returned as-is.
///
/// Throws `IllegalStateException` when the decoded scalar carries neither integer field.
private static long scalarLong(ChunkRef cr, boolean max) {
    byte[] encoded = max ? cr.statsMax() : cr.statsMin();
    ScalarValue stat = decodeScalar(encoded);
    var signed = stat.int64_value();
    if (signed != null) {
        return signed;
    }
    var unsigned = stat.uint64_value();
    if (unsigned != null) {
        return unsigned;
    }
    throw new IllegalStateException("expected integer scalar stat");
}

/// Decodes the chunk's maximum (`max == true`) or minimum stat and returns it as a `double`.
/// The 64-bit float field is preferred; the 32-bit field is widened and used only when the
/// 64-bit one is absent.
///
/// Throws `IllegalStateException` when the decoded scalar carries neither float field.
private static double scalarDouble(ChunkRef cr, boolean max) {
    byte[] encoded = max ? cr.statsMax() : cr.statsMin();
    ScalarValue stat = decodeScalar(encoded);
    var wide = stat.f64_value();
    if (wide != null) {
        return wide;
    }
    var narrow = stat.f32_value();
    if (narrow != null) {
        return narrow;
    }
    throw new IllegalStateException("expected float scalar stat");
}

/// Parses a serialised [ScalarValue] from the whole of `bytes`. A checked `IOException`
/// raised by the decoder is rethrown as `UncheckedIOException` so that the stat helpers can
/// be used from lambdas and switch expressions.
private static ScalarValue decodeScalar(byte[] bytes) {
    try {
        MemorySegment view = MemorySegment.ofArray(bytes);
        long length = view.byteSize();
        return ScalarValue.decode(view, 0, length);
    } catch (IOException ioe) {
        throw new java.io.UncheckedIOException(ioe);
    }
}

private ByteBuffer buildFooter() {
var fbb = new FlatBufferBuilder(512);

Expand All @@ -675,7 +846,8 @@ private ByteBuffer buildFooter() {
int ls1 = LayoutSpec.createLayoutSpec(fbb, fbb.createString("vortex.chunked"));
int ls2 = LayoutSpec.createLayoutSpec(fbb, fbb.createString("vortex.struct"));
int ls3 = LayoutSpec.createLayoutSpec(fbb, fbb.createString("vortex.dict"));
int lsv = Footer.createLayoutSpecsVector(fbb, new int[]{ls0, ls1, ls2, ls3});
int ls4 = LayoutSpec.createLayoutSpec(fbb, fbb.createString("vortex.stats"));
int lsv = Footer.createLayoutSpecsVector(fbb, new int[]{ls0, ls1, ls2, ls3, ls4});

// segment_specs (inline struct vector — write in reverse order)
Footer.startSegmentSpecsVector(fbb, segs.size());
Expand Down Expand Up @@ -716,7 +888,8 @@ private ByteBuffer buildLayout() {
colRows += cr.rowCount();
}
int childV = Layout.createChildrenVector(fbb, flats);
colLayouts[c] = Layout.createLayout(fbb, LAYOUT_CHUNKED, colRows, 0, childV, 0);
int dataChunked = Layout.createLayout(fbb, LAYOUT_CHUNKED, colRows, 0, childV, 0);
colLayouts[c] = wrapZoneMap(fbb, colName, dataChunked, colRows);
if (totalRows == 0) {
totalRows = colRows;
}
Expand Down Expand Up @@ -816,7 +989,7 @@ private void writeGlobalDictColumn(String colName, DType.Primitive dtype, List<O
for (Object chunk : chunks) {
long rowCount = arrayLength(chunk);
int segIdx = writeSegment(dtype, chunk);
colChunks.get(colName).add(new ChunkRef(segIdx, rowCount));
colChunks.get(colName).add(new ChunkRef(segIdx, rowCount, lastStatsMin, lastStatsMax));
}
return;
}
Expand Down Expand Up @@ -860,7 +1033,7 @@ private void writeGlobalDictUtf8Column(String colName, DType.Utf8 dtype, List<Ob
for (Object chunk : chunks) {
long rowCount = arrayLength(chunk);
int segIdx = writeSegment(dtype, chunk);
colChunks.get(colName).add(new ChunkRef(segIdx, rowCount));
colChunks.get(colName).add(new ChunkRef(segIdx, rowCount, lastStatsMin, lastStatsMax));
}
return;
}
Expand Down Expand Up @@ -1075,7 +1248,15 @@ private static Object buildCodesArray(Object data, PType ptype, Map<Object, Inte
/// Location of one written segment: `offset` is the byte position where the segment starts
/// and `len` the number of bytes it occupies (recorded as `bytesWritten - offset` once the
/// segment has been written).
private record SegRef(long offset, long len) {
}

private record ChunkRef(int segIdx, long rowCount) {
/// One written chunk of a column: the index of its segment, its row count, and the
/// serialised min/max stats captured when the segment was written (either may be `null`).
private record ChunkRef(int segIdx, long rowCount, byte[] statsMin, byte[] statsMax) {
    /// True only when both the minimum and the maximum stat were captured.
    boolean hasStats() {
        return !(statsMin == null || statsMax == null);
    }
}

/// Per-column zone-map: the flat segment holding the per-zone stats table, the zone
/// count (one zone per chunk), and the logical rows per zone.
///
/// - `zonesSegIdx`: segment index returned when the stats table was written
/// - `nZones`: number of rows in the stats table, used as the row count of its flat layout
/// - `zoneLen`: rows per zone, serialised as the `u32` in the `vortex.stats` metadata
private record ZoneMapRef(int zonesSegIdx, long nZones, long zoneLen) {
}

private record DictColRef(int valuesSegIdx, long valuesLen, List<Integer> codesSegIdxes,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,15 @@ public EncodeResult encode(DType dtype, Object data, EncodeContext ctx) {
EncodeNode[] children = new EncodeNode[fields.size()];
for (int i = 0; i < fields.size(); i++) {
DType fieldDtype = fieldTypes.get(i);
EncodeResult fieldResult = findEncoding(fieldDtype).encode(fieldDtype, fields.get(i), ctx);
Object fieldData = fields.get(i);
// Nullable fields carry a NullableData(values, validity) pair; route them through the
// masked encoder (values + validity bitmap), mirroring VortexWriter.writeSegment. The
// plain fallbacks only handle dense arrays.
EncodingEncoder fieldEncoder =
(fieldData instanceof NullableData && !(fieldDtype instanceof DType.Extension))
? new MaskedEncodingEncoder()
: findEncoding(fieldDtype);
EncodeResult fieldResult = fieldEncoder.encode(fieldDtype, fieldData, ctx);
int bufOffset = allBuffers.size();
children[i] = EncodeNode.remapBufferIndices(fieldResult.rootNode(), bufOffset);
allBuffers.addAll(fieldResult.buffers());
Expand Down
Loading