Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import static io.github.dfa1.vortex.core.io.PTypeIO.LE_INT;

import io.github.dfa1.vortex.core.error.VortexException;
import io.github.dfa1.vortex.core.model.DType;
import io.github.dfa1.vortex.core.io.IoBounds;
import io.github.dfa1.vortex.reader.array.Array;
Expand All @@ -26,6 +27,13 @@
/// FlatBuffer parsing, buffer-offset arithmetic, and encoding-spec lookup.
public final class FlatSegmentDecoder {

/// Hard cap on array-node recursion depth. The encoded array tree nests through child nodes
/// (validity, patches, run-ends, dictionary codes/values, …); a crafted or self-referential
/// FlatBuffer can drive [#convertArrayNode] into unbounded recursion and a [StackOverflowError]
/// — an `Error`, so it would bypass the [VortexException]
/// contract. 64 is well past any real encoding's nesting.
static final int MAX_ARRAY_TREE_DEPTH = 64;

private final ReadRegistry registry;

/// Creates a decoder backed by the given registry.
Expand Down Expand Up @@ -64,20 +72,25 @@ public Array decode(MemorySegment seg, List<String> encodingSpecs,
dataOffset += bufDesc.length();
}

ArrayNode rootNode = convertArrayNode(fbArray.root(), encodingSpecs);
ArrayNode rootNode = convertArrayNode(fbArray.root(), encodingSpecs, 0);
var ctx = new DecodeContext(rootNode, dtype, rowCount, bufs, registry, arena);
return registry.decode(ctx);
}

private static ArrayNode convertArrayNode(
io.github.dfa1.vortex.core.fbs.FbsArrayNode fbs,
List<String> encodingSpecs
List<String> encodingSpecs,
int depth
) {
if (depth > MAX_ARRAY_TREE_DEPTH) {
throw new VortexException(
"array tree depth exceeds limit (" + MAX_ARRAY_TREE_DEPTH + ")");
}
String rawEncodingId = encodingSpecs.get(fbs.encoding());

ArrayNode[] children = new ArrayNode[fbs.childrenLength()];
for (int i = 0; i < children.length; i++) {
children[i] = convertArrayNode(fbs.children(i), encodingSpecs);
children[i] = convertArrayNode(fbs.children(i), encodingSpecs, depth + 1);
}

int[] bufferIndices = new int[fbs.buffersLength()];
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,14 @@ final class PostscriptParser {
/// real encoding's metadata footprint (the largest is FSST's symbol table at ~32 KiB).
static final int MAX_LAYOUT_METADATA_BYTES = 4 * 1024 * 1024;

/// Hard cap on DType-tree recursion depth. A `DType` nests through Struct fields, List/
/// FixedSizeList element types, and Extension storage types; like the layout tree, a crafted
/// or self-referential FlatBuffer can drive [#convertDType(io.github.dfa1.vortex.core.fbs.FbsDType, int)]
/// into unbounded recursion and a [StackOverflowError] — which, being an `Error`, would escape
/// the [VortexException] sanitization and leak the reader's memory-mapped Arena. 64 is well past
/// any real schema's nesting.
static final int MAX_DTYPE_DEPTH = 64;

private PostscriptParser() {
}

Expand Down Expand Up @@ -110,7 +118,7 @@ static ParsedFile parseBlobs(MemorySegment footerBuf, MemorySegment layoutBuf, M

DType dtype = null;
if (dtypeBuf != null && dtypeBuf.byteSize() > 0) {
dtype = convertDType(io.github.dfa1.vortex.core.fbs.FbsDType.getRootAsFbsDType(dtypeBuf));
dtype = convertDType(io.github.dfa1.vortex.core.fbs.FbsDType.getRootAsFbsDType(dtypeBuf), 0);
}

return new ParsedFile(footer, dtype, layout);
Expand Down Expand Up @@ -188,7 +196,11 @@ private static Layout convertLayout(io.github.dfa1.vortex.core.fbs.FbsLayout l,
return new Layout(encodingId, l.rowCount(), metadata, List.copyOf(children), List.copyOf(segments));
}

private static DType convertDType(io.github.dfa1.vortex.core.fbs.FbsDType fbs) {
private static DType convertDType(io.github.dfa1.vortex.core.fbs.FbsDType fbs, int depth) {
if (depth > MAX_DTYPE_DEPTH) {
throw new VortexException(
"DType tree depth exceeds limit (" + MAX_DTYPE_DEPTH + ")");
}
int typeType = fbs.typeType();
return switch (typeType) {
case FbsType.FbsNull -> new DType.Null(true);
Expand Down Expand Up @@ -224,21 +236,21 @@ private static DType convertDType(io.github.dfa1.vortex.core.fbs.FbsDType fbs) {
names.add(s.names(i));
}
for (int i = 0; i < s.dtypesLength(); i++) {
types.add(convertDType(s.dtypes(i)));
types.add(convertDType(s.dtypes(i), depth + 1));
}
yield new DType.Struct(List.copyOf(names), List.copyOf(types), s.nullable());
}
case FbsType.FbsList -> {
var l = fbs.type(new io.github.dfa1.vortex.core.fbs.FbsList());
yield new DType.List(convertDType(l.elementType()), l.nullable());
yield new DType.List(convertDType(l.elementType(), depth + 1), l.nullable());
}
case FbsType.FbsFixedSizeList -> {
var fsl = fbs.type(new FbsFixedSizeList());
yield new DType.FixedSizeList(convertDType(fsl.elementType()), (int) fsl.size(), fsl.nullable());
yield new DType.FixedSizeList(convertDType(fsl.elementType(), depth + 1), (int) fsl.size(), fsl.nullable());
}
case FbsType.FbsExtension -> {
var e = fbs.type(new FbsExtension());
DType storage = convertDType(e.storageDtype());
DType storage = convertDType(e.storageDtype(), depth + 1);
yield new DType.Extension(
e.id(),
storage,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,10 @@ public static VortexHttpReader open(URI uri, ReadRegistry registry, HttpClient c
: null;

var parsed = PostscriptParser.parseBlobs(footerBuf, layoutBuf, dtypeBuf);
// Reject footer segmentSpecs that fall outside the file before any rawSegment() builds a
// Range request from them. The local-file path runs this inside PostscriptParser.parse;
// the HTTP path calls parseBlobs directly and must validate here too.
PostscriptParser.validateSegmentSpecs(parsed.footer().segmentSpecs(), fileSize);

return new VortexHttpReader(
uri, client, fileSize, trailer.version(),
Expand Down Expand Up @@ -155,26 +159,46 @@ private static TailFetch fetchTail(URI uri, HttpClient client) throws IOExceptio
// Content-Range: bytes <start>-<end>/<total>
String cr = resp.headers().firstValue("Content-Range")
.orElseThrow(() -> new VortexException("206 response missing Content-Range from " + uri));
String spec = cr.substring("bytes ".length()); // "<start>-<end>/<total>"
return parseContentRange(cr, body, uri);
}

if (status == 200) {
// Server returned full file (no Range support)
return new TailFetch(body, 0L, body.length);
}

throw new VortexException("HTTP " + status + " fetching tail of " + uri);
}

/// Parses a `bytes <start>-<end>/<total>` Content-Range header from an untrusted server.
/// Any structural defect (missing `bytes ` prefix, missing `-`/`/`, non-numeric fields)
/// surfaces as a [VortexException] rather than a raw [NumberFormatException] or
/// [StringIndexOutOfBoundsException].
private static TailFetch parseContentRange(String contentRange, byte[] body, URI uri) {
try {
String prefix = "bytes ";
if (!contentRange.startsWith(prefix)) {
throw new VortexException("malformed Content-Range '" + contentRange + "' from " + uri);
}
String spec = contentRange.substring(prefix.length()); // "<start>-<end>/<total>"
int dash = spec.indexOf('-');
int slash = spec.indexOf('/');
if (dash < 0 || slash < 0 || dash > slash) {
throw new VortexException("malformed Content-Range '" + contentRange + "' from " + uri);
}
long start = Long.parseLong(spec.substring(0, dash));
long end = Long.parseLong(spec.substring(dash + 1, slash));
long total = Long.parseLong(spec.substring(slash + 1));
long start = Long.parseLong(spec.substring(0, spec.indexOf('-')));
long end = Long.parseLong(spec.substring(spec.indexOf('-') + 1, slash));
long expected = end - start + 1;
if (body.length != expected) {
throw new VortexException(
"HTTP tail from %s: Content-Range declares %d bytes but body has %d"
.formatted(uri, expected, body.length));
}
return new TailFetch(body, start, total);
} catch (NumberFormatException e) {
throw new VortexException("malformed Content-Range '" + contentRange + "' from " + uri, e);
}

if (status == 200) {
// Server returned full file (no Range support)
return new TailFetch(body, 0L, body.length);
}

throw new VortexException("HTTP " + status + " fetching tail of " + uri);
}

private static byte[] fetchRange(URI uri, long from, long to, HttpClient client) throws IOException {
Expand Down Expand Up @@ -211,8 +235,8 @@ private static MemorySegment fetchBlob(
URI uri, HttpClient client
) throws IOException {
if (offset >= tailStart) {
int relOffset = (int) (offset - tailStart);
return MemorySegment.ofArray(tail).asSlice(relOffset, length);
long relOffset = offset - tailStart;
return IoBounds.slice(MemorySegment.ofArray(tail), relOffset, length);
}
byte[] bytes = fetchRange(uri, offset, offset + length - 1, client);
return MemorySegment.ofArray(bytes);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import io.github.dfa1.vortex.core.model.PType;
import io.github.dfa1.vortex.core.error.VortexException;
import io.github.dfa1.vortex.core.model.EncodingId;
import io.github.dfa1.vortex.core.io.IoBounds;
import io.github.dfa1.vortex.core.io.PTypeIO;
import io.github.dfa1.vortex.core.proto.ProtoZstdMetadata;
import io.github.dfa1.vortex.reader.array.Array;
Expand Down Expand Up @@ -63,7 +64,17 @@ public Array decode(DecodeContext ctx) {
int frameCount = meta.frames().size();
long totalUncompressed = 0;
for (int i = 0; i < frameCount; i++) {
totalUncompressed += meta.frames().get(i).uncompressed_size();
// Validate each frame's declared size (rejects negative / >2 GB) and accumulate
// overflow-safely, so a crafted metadata cannot wrap the total to a small positive
// value and under-allocate, nor drive arena.allocate negative. The per-frame cap also
// guards the (int) narrowing at the asSlice call site in decompressFrames.
int frameSize = IoBounds.toIntSize(meta.frames().get(i).uncompressed_size());
try {
totalUncompressed = Math.addExact(totalUncompressed, frameSize);
} catch (ArithmeticException e) {
throw new VortexException(EncodingId.VORTEX_ZSTD,
"total uncompressed size overflows", e);
}
}

MemorySegment decompressed = decompressFrames(ctx, meta, frameCount, totalUncompressed);
Expand Down Expand Up @@ -112,7 +123,7 @@ private static VarBinArray buildScatteredVarBin(
long scanPos = 0;
for (long i = 0; i < rowCount; i++) {
if (validity.getBoolean(i)) {
int len = validValues.get(PTypeIO.LE_INT, scanPos);
int len = readVarBinLen(validValues, scanPos);
scanPos += 4L + len;
totalDataBytes += len;
}
Expand All @@ -122,6 +133,9 @@ private static VarBinArray buildScatteredVarBin(
MemorySegment offsets = ctx.arena().allocate((rowCount + 1) * 4L, 4);
offsets.setAtIndex(PTypeIO.LE_INT, 0, 0);

// Second pass reads the same positions the first pass already bounds-checked via
// readVarBinLen, so a raw get/copy here cannot overrun; values is sized to the validated
// total. Keep both passes in lockstep — any edit to the cursor advance must stay identical.
long readPos = 0;
long dataPos = 0;
for (long i = 0; i < rowCount; i++) {
Expand Down Expand Up @@ -163,7 +177,7 @@ private static MemorySegment decompressFrames(
long outOffset = 0;
for (int i = 0; i < frameCount; i++) {
MemorySegment src = asNative(ctx.buffer(frameBufferBase + i), scratch);
int uncompSize = (int) meta.frames().get(i).uncompressed_size();
int uncompSize = IoBounds.toIntSize(meta.frames().get(i).uncompressed_size());
MemorySegment dst = out.asSlice(outOffset, uncompSize);
long written = dictionary == null
? dctx.decompress(dst, src)
Expand Down Expand Up @@ -236,11 +250,28 @@ private static Array buildPrimitive(DType.Primitive dt, long n, MemorySegment de
};
}

/// Reads a 4-byte little-endian length prefix at `pos` from a decompressed VarBin payload and
/// validates that both the prefix and the `len` bytes that follow lie within `src`. Without this,
/// a crafted payload with a negative or oversized length would advance the cursor out of bounds
/// and surface as a raw [IndexOutOfBoundsException] instead of a
/// [io.github.dfa1.vortex.core.error.VortexException].
///
/// @param src the decompressed VarBin payload segment
/// @param pos byte offset of the length prefix within `src`
/// @return the validated element length in bytes
private static int readVarBinLen(MemorySegment src, long pos) {
IoBounds.checkRange(pos, 4, src.byteSize());
int len = src.get(PTypeIO.LE_INT, pos);
// checkRange rejects len < 0 and a [pos+4, pos+4+len) range that overruns src.
IoBounds.checkRange(pos + 4L, len, src.byteSize());
return len;
}

private static VarBinArray buildVarBin(DType dtype, long n, MemorySegment decompressed, DecodeContext ctx) {
long totalDataBytes = 0;
long pos = 0;
for (long i = 0; i < n; i++) {
int len = decompressed.get(PTypeIO.LE_INT, pos);
int len = readVarBinLen(decompressed, pos);
pos += 4 + len;
totalDataBytes += len;
}
Expand All @@ -249,6 +280,9 @@ private static VarBinArray buildVarBin(DType dtype, long n, MemorySegment decomp
MemorySegment offsets = ctx.arena().allocate((n + 1) * 4L, 4);
offsets.setAtIndex(PTypeIO.LE_INT, 0, 0);

// Second pass reads the same positions the first pass already bounds-checked via
// readVarBinLen, so a raw get/copy here cannot overrun; values is sized to the validated
// total. Keep both passes in lockstep — any edit to the cursor advance must stay identical.
pos = 0;
long dataPos = 0;
for (long i = 0; i < n; i++) {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
package io.github.dfa1.vortex.reader;

import io.github.dfa1.vortex.core.error.VortexException;
import io.github.dfa1.vortex.core.fbs.FbsArray;
import io.github.dfa1.vortex.core.fbs.FbsArrayNode;
import io.github.dfa1.vortex.core.fbs.FbsBuilder;
import io.github.dfa1.vortex.core.io.PTypeIO;
import io.github.dfa1.vortex.core.model.DType;
import org.junit.jupiter.api.Test;

import java.lang.foreign.Arena;
import java.lang.foreign.MemorySegment;
import java.util.List;

import static org.assertj.core.api.Assertions.assertThatThrownBy;

/// Adversarial test for the encoded-array-tree recursion in
/// [FlatSegmentDecoder]'s `convertArrayNode`.
///
/// The decoder walks the array node tree recursively (validity, patches, run-ends, dictionary
/// codes/values, …). Without the [FlatSegmentDecoder#MAX_ARRAY_TREE_DEPTH] cap a crafted segment
/// with thousands of nested children produces a [StackOverflowError] — an `Error` that escapes the
/// "malformed input must surface as [VortexException]" contract (ADR 0003). This pins that contract.
class ArrayNodeDepthBombSecurityTest {

private static final DType DTYPE = DType.I32;

private final FlatSegmentDecoder sut = new FlatSegmentDecoder(ReadRegistry.empty());

@Test
void deeplyNestedArrayTree_throwsVortexException() {
try (Arena arena = Arena.ofConfined()) {
// Given — a flat segment whose FbsArray root nests 65536 levels of single-child nodes.
// Real encodings nest only a handful of levels; 65536 reliably blows the JVM stack on
// the recursive convertArrayNode walk if the depth cap is removed.
byte[] fb = deeplyNestedArrayFlatBuffer(65536);
MemorySegment seg = arena.allocate(fb.length + 4L);
MemorySegment.copy(MemorySegment.ofArray(fb), 0, seg, 0, fb.length);
seg.set(PTypeIO.LE_INT, fb.length, fb.length);

// When / Then — must surface as VortexException, not StackOverflowError
assertThatThrownBy(() -> sut.decode(seg, List.of("vortex.flat"), DTYPE, 1, arena))
.isInstanceOf(VortexException.class);
}
}

/// Builds a minimal `FbsArray` whose root node has `depth` levels of single-child nesting,
/// each level a buffer-less node referencing encoding index 0.
private static byte[] deeplyNestedArrayFlatBuffer(int depth) {
FbsBuilder b = new FbsBuilder(depth * 32);
// Leaf first; FlatBuffer requires children be finished before parents.
int emptyChildren = FbsArrayNode.createChildrenVector(b, new int[0]);
int emptyBuffers = FbsArrayNode.createBuffersVector(b, new int[0]);
int current = FbsArrayNode.createFbsArrayNode(b, 0, 0, emptyChildren, emptyBuffers, 0);
for (int i = 0; i < depth; i++) {
int childV = FbsArrayNode.createChildrenVector(b, new int[]{current});
int bufV = FbsArrayNode.createBuffersVector(b, new int[0]);
current = FbsArrayNode.createFbsArrayNode(b, 0, 0, childV, bufV, 0);
}
FbsArray.startBuffersVector(b, 0);
int buffers = b.endVector();
int array = FbsArray.createFbsArray(b, current, buffers);
FbsArray.finishFbsArrayBuffer(b, array);
return b.sizedByteArray();
}
}
Loading
Loading