# Patch summary (free-text preamble placed before the first "diff --git" header).
#
# File 1 - writer/src/main/java/io/github/dfa1/vortex/writer/ChunkImpl.java
#   One hunk (@@ -62,38 +62,32 @@) inside adaptPrimitive. For the Byte[], Short[],
#   Integer[], Long[], Float[] and Double[] arms, the two-line ternary
#       nullable ? boxedToNullableXxx(a)
#                : rejectNullable(column, ptype);
#   is joined onto a single line. The removed and added lines carry the same
#   tokens, so this is a formatting-only change; the hunk shrinks by six lines.
#   The F16 arm appears only as trailing context and is not modified.
#
# File 2 - writer/src/test/java/io/github/dfa1/vortex/writer/ChunkImplTest.java
#   New file (@@ -0,0 +1,228 @@). Tests use JUnit Jupiter (@Test, @Nested) and
#   AssertJ (assertThat, assertThatThrownBy). Shared helpers:
#     schema(dtype)  - single-column DType.Struct whose only column is named "c"
#     prim(p, n)     - new DType.Primitive(p, n)
#     putGet(d, v)   - new ChunkImpl(schema(d)), put("c", v), then finish().get("c")
#   Nested groups:
#     PutAndFinish - unknown column, duplicate put, put() returning the same
#                    instance, finish() with a missing column, finish() returning
#                    the stored array, null value, pass-through for a DType.List.
#     Primitive    - matching primitive arrays accepted; boxed arrays become
#                    NullableData on nullable columns; boxed arrays rejected with
#                    "rejects boxed array" on non-nullable columns; a String value
#                    rejected with a message containing "expects".
#     Utf8         - String[] accepted; null elements allowed only when nullable.
#     Bool         - boolean[] accepted; Boolean[] converted or rejected depending
#                    on nullability; int[] rejected.
#
# NOTE(review): the line breaks of this patch appear to have been collapsed into
#   spaces in this copy; presumably they must be restored before the hunk line
#   counts above can match - confirm against the original commit.
# NOTE(review): finishReturnsAllColumns declares a raw "Map result". Type arguments
#   may have been lost in extraction; confirm against finish()'s declared return type.
# NOTE(review): the "...ForEveryPType" tests list I8/I16/I32/I64/F32/F64/F16 only.
#   The adaptPrimitive switch shown in file 1 also has U8/U16/U32/U64 labels, which
#   no test here exercises - consider adding them.
diff --git a/writer/src/main/java/io/github/dfa1/vortex/writer/ChunkImpl.java b/writer/src/main/java/io/github/dfa1/vortex/writer/ChunkImpl.java index ff25b783..a740d4e4 100644 --- a/writer/src/main/java/io/github/dfa1/vortex/writer/ChunkImpl.java +++ b/writer/src/main/java/io/github/dfa1/vortex/writer/ChunkImpl.java @@ -62,38 +62,32 @@ private static Object adaptPrimitive(String column, DType.Primitive dtype, Objec return switch (ptype) { case I8, U8 -> switch (value) { case byte[] a -> a; - case Byte[] a -> nullable ? boxedToNullableBytes(a) - : rejectNullable(column, ptype); + case Byte[] a -> nullable ? boxedToNullableBytes(a) : rejectNullable(column, ptype); default -> throw typeMismatch(column, ptype, value); }; case I16, U16 -> switch (value) { case short[] a -> a; - case Short[] a -> nullable ? boxedToNullableShorts(a) - : rejectNullable(column, ptype); + case Short[] a -> nullable ? boxedToNullableShorts(a) : rejectNullable(column, ptype); default -> throw typeMismatch(column, ptype, value); }; case I32, U32 -> switch (value) { case int[] a -> a; - case Integer[] a -> nullable ? boxedToNullableInts(a) - : rejectNullable(column, ptype); + case Integer[] a -> nullable ? boxedToNullableInts(a) : rejectNullable(column, ptype); default -> throw typeMismatch(column, ptype, value); }; case I64, U64 -> switch (value) { case long[] a -> a; - case Long[] a -> nullable ? boxedToNullableLongs(a) - : rejectNullable(column, ptype); + case Long[] a -> nullable ? boxedToNullableLongs(a) : rejectNullable(column, ptype); default -> throw typeMismatch(column, ptype, value); }; case F32 -> switch (value) { case float[] a -> a; - case Float[] a -> nullable ? boxedToNullableFloats(a) - : rejectNullable(column, ptype); + case Float[] a -> nullable ? boxedToNullableFloats(a) : rejectNullable(column, ptype); default -> throw typeMismatch(column, ptype, value); }; case F64 -> switch (value) { case double[] a -> a; - case Double[] a -> nullable ? 
boxedToNullableDoubles(a) - : rejectNullable(column, ptype); + case Double[] a -> nullable ? boxedToNullableDoubles(a) : rejectNullable(column, ptype); default -> throw typeMismatch(column, ptype, value); }; case F16 -> switch (value) { diff --git a/writer/src/test/java/io/github/dfa1/vortex/writer/ChunkImplTest.java b/writer/src/test/java/io/github/dfa1/vortex/writer/ChunkImplTest.java new file mode 100644 index 00000000..abd0c182 --- /dev/null +++ b/writer/src/test/java/io/github/dfa1/vortex/writer/ChunkImplTest.java @@ -0,0 +1,228 @@ +package io.github.dfa1.vortex.writer; + +import io.github.dfa1.vortex.core.DType; +import io.github.dfa1.vortex.core.PType; +import io.github.dfa1.vortex.writer.encode.NullableData; +import org.junit.jupiter.api.Nested; +import org.junit.jupiter.api.Test; + +import java.util.List; +import java.util.Map; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +class ChunkImplTest { + + private static DType.Struct schema(DType dtype) { + return new DType.Struct(List.of("c"), List.of(dtype), false); + } + + private static DType prim(PType ptype, boolean nullable) { + return new DType.Primitive(ptype, nullable); + } + + private static Object putGet(DType dtype, Object value) { + ChunkImpl sut = new ChunkImpl(schema(dtype)); + sut.put("c", value); + return sut.finish().get("c"); + } + + @Nested + class PutAndFinish { + + @Test + void unknownColumnRejected() { + // Given + ChunkImpl sut = new ChunkImpl(schema(prim(PType.I32, false))); + + // When / Then + assertThatThrownBy(() -> sut.put("nope", new int[]{1})) + .isInstanceOf(IllegalArgumentException.class).hasMessageContaining("unknown column"); + } + + @Test + void duplicatePutRejected() { + // Given + ChunkImpl sut = new ChunkImpl(schema(prim(PType.I32, false))); + sut.put("c", new int[]{1}); + + // When / Then + assertThatThrownBy(() -> sut.put("c", new int[]{2})) + 
.isInstanceOf(IllegalArgumentException.class).hasMessageContaining("duplicate"); + } + + @Test + void putReturnsSelfForChaining() { + // Given + ChunkImpl sut = new ChunkImpl(schema(prim(PType.I32, false))); + + // When + Chunk result = sut.put("c", new int[]{1}); + + // Then + assertThat(result).isSameAs(sut); + } + + @Test + void finishRejectsMissingColumn() { + // Given — schema has two columns, only one put + ChunkImpl sut = new ChunkImpl(new DType.Struct( + List.of("a", "b"), List.of(prim(PType.I32, false), prim(PType.I32, false)), false)); + sut.put("a", new int[]{1}); + + // When / Then + assertThatThrownBy(sut::finish) + .isInstanceOf(IllegalStateException.class).hasMessageContaining("missing column"); + } + + @Test + void finishReturnsAllColumns() { + // Given + ChunkImpl sut = new ChunkImpl(schema(prim(PType.I32, false))); + int[] col = {1, 2, 3}; + sut.put("c", col); + + // When + Map result = sut.finish(); + + // Then + assertThat(result).containsEntry("c", col); + } + + @Test + void nullValueRejected() { + // When / Then + assertThatThrownBy(() -> putGet(prim(PType.I32, false), null)) + .isInstanceOf(IllegalArgumentException.class).hasMessageContaining("null array"); + } + + @Test + void unadaptedDtypePassesThrough() { + // Given — a List dtype is not specially adapted; its carrier passes through unchanged + DType listDtype = new DType.List(prim(PType.I32, false), false); + Object carrier = new Object(); + + // When + Object result = putGet(listDtype, carrier); + + // Then + assertThat(result).isSameAs(carrier); + } + } + + @Nested + class Primitive { + + @Test + void plainArraysAcceptedForEveryPType() { + // Given / When / Then — the matching primitive array passes through + assertThat(putGet(prim(PType.I8, false), new byte[]{1})).isInstanceOf(byte[].class); + assertThat(putGet(prim(PType.I16, false), new short[]{1})).isInstanceOf(short[].class); + assertThat(putGet(prim(PType.I32, false), new int[]{1})).isInstanceOf(int[].class); + 
assertThat(putGet(prim(PType.I64, false), new long[]{1})).isInstanceOf(long[].class); + assertThat(putGet(prim(PType.F32, false), new float[]{1})).isInstanceOf(float[].class); + assertThat(putGet(prim(PType.F64, false), new double[]{1})).isInstanceOf(double[].class); + assertThat(putGet(prim(PType.F16, false), new short[]{1})).isInstanceOf(short[].class); + } + + @Test + void boxedArraysConvertToNullableDataOnNullableColumns() { + // Given / When / Then — null slots become invalid in the NullableData carrier + assertValidity(putGet(prim(PType.I8, true), new Byte[]{1, null})); + assertValidity(putGet(prim(PType.I16, true), new Short[]{1, null})); + assertValidity(putGet(prim(PType.I32, true), new Integer[]{1, null})); + assertValidity(putGet(prim(PType.I64, true), new Long[]{1L, null})); + assertValidity(putGet(prim(PType.F32, true), new Float[]{1f, null})); + assertValidity(putGet(prim(PType.F64, true), new Double[]{1.0, null})); + } + + @Test + void boxedArraysRejectedOnNonNullableColumns() { + // Each must hit rejectNullable (the "rejects boxed array" message), not the + // generic typeMismatch — asserting the message keeps these on the boxed arm. 
+ assertRejectsBoxed(prim(PType.I8, false), new Byte[]{1}); + assertRejectsBoxed(prim(PType.I16, false), new Short[]{1}); + assertRejectsBoxed(prim(PType.I32, false), new Integer[]{1}); + assertRejectsBoxed(prim(PType.I64, false), new Long[]{1L}); + assertRejectsBoxed(prim(PType.F32, false), new Float[]{1f}); + assertRejectsBoxed(prim(PType.F64, false), new Double[]{1.0}); + } + + private void assertRejectsBoxed(DType dtype, Object boxed) { + assertThatThrownBy(() -> putGet(dtype, boxed)) + .isInstanceOf(IllegalArgumentException.class).hasMessageContaining("rejects boxed array"); + } + + @Test + void wrongTypeRejectedForEveryPType() { + for (PType p : List.of(PType.I8, PType.I16, PType.I32, PType.I64, PType.F32, PType.F64, PType.F16)) { + assertThatThrownBy(() -> putGet(prim(p, false), "not an array")) + .as("ptype %s", p) + .isInstanceOf(IllegalArgumentException.class).hasMessageContaining("expects"); + } + } + + private void assertValidity(Object result) { + assertThat(result).isInstanceOf(NullableData.class); + assertThat(((NullableData) result).validity()).containsExactly(true, false); + } + } + + @Nested + class Utf8 { + + @Test + void stringArrayAccepted() { + assertThat(putGet(new DType.Utf8(false), new String[]{"a", "b"})).isInstanceOf(String[].class); + } + + @Test + void nullableAllowsNullElements() { + assertThat(putGet(new DType.Utf8(true), new String[]{"a", null})).isInstanceOf(String[].class); + } + + @Test + void nonNullableRejectsNullElement() { + assertThatThrownBy(() -> putGet(new DType.Utf8(false), new String[]{"a", null})) + .isInstanceOf(IllegalArgumentException.class).hasMessageContaining("null at row 1"); + } + + @Test + void wrongTypeRejected() { + assertThatThrownBy(() -> putGet(new DType.Utf8(false), new int[]{1})) + .isInstanceOf(IllegalArgumentException.class).hasMessageContaining("expects String[]"); + } + } + + @Nested + class Bool { + + @Test + void boolArrayAccepted() { + assertThat(putGet(new DType.Bool(false), new boolean[]{true, 
false})).isInstanceOf(boolean[].class); + } + + @Test + void boxedConvertsToNullableDataOnNullableColumn() { + // Given / When + Object result = putGet(new DType.Bool(true), new Boolean[]{true, null, false}); + + // Then + assertThat(result).isInstanceOf(NullableData.class); + assertThat(((NullableData) result).validity()).containsExactly(true, false, true); + } + + @Test + void boxedRejectedOnNonNullableColumn() { + assertThatThrownBy(() -> putGet(new DType.Bool(false), new Boolean[]{true})) + .isInstanceOf(IllegalArgumentException.class).hasMessageContaining("rejects Boolean[]"); + } + + @Test + void wrongTypeRejected() { + assertThatThrownBy(() -> putGet(new DType.Bool(false), new int[]{1})) + .isInstanceOf(IllegalArgumentException.class).hasMessageContaining("expects boolean[]"); + } + } +}