Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,15 @@
import io.github.dfa1.vortex.core.DType;
import io.github.dfa1.vortex.core.PType;
import io.github.dfa1.vortex.core.VortexException;
import io.github.dfa1.vortex.encoding.PTypeIO;
import org.junit.jupiter.api.Nested;
import org.junit.jupiter.api.Test;

import java.lang.foreign.Arena;
import java.lang.foreign.MemorySegment;
import java.lang.foreign.ValueLayout;
import java.util.ArrayList;
import java.util.List;

import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
Expand Down Expand Up @@ -155,6 +157,58 @@ void foldSumsThroughCodes() {
assertThat(total).isEqualTo(6L);
}
}

@Test
void allCodeTypes_getForEachFoldMaterialize() {
try (Arena arena = Arena.ofConfined()) {
// Given — same selection [2,0,1] expressed as U8/U16/U32/U64 codes
LongArray values = longArray(arena, 100L, 200L, 300L);
long[] expected = {300L, 100L, 200L};
List<Array> codeVariants = List.of(
byteArray(arena, (byte) 2, (byte) 0, (byte) 1),
shortArray(arena, U16, (short) 2, (short) 0, (short) 1),
intArray(arena, U32, 2, 0, 1),
longArray(arena, U64, 2L, 0L, 1L));

for (Array codes : codeVariants) {
DictLongArray sut = DictLongArray.of(I64, 3, values, codes);
String label = codes.getClass().getSimpleName();

// getter
for (int i = 0; i < 3; i++) {
assertThat(sut.getLong(i)).as(label).isEqualTo(expected[i]);
}
// forEach
var seen = new ArrayList<Long>();
sut.forEachLong(seen::add);
assertThat(seen).as(label).containsExactly(300L, 100L, 200L);
// fold
assertThat(sut.fold(0L, Long::sum)).as(label).isEqualTo(600L);
// materialize
MemorySegment m = sut.materialize(arena);
for (int i = 0; i < 3; i++) {
assertThat(m.getAtIndex(PTypeIO.LE_LONG, i)).as(label).isEqualTo(expected[i]);
}
}
}
}

@Test
void bulkOps_invalidCodesType_throw() {
// Given — a record built directly (bypassing of()) with a non-int codes array
try (Arena arena = Arena.ofConfined()) {
LongArray values = longArray(arena, 1L);
DictLongArray sut = new DictLongArray(I64, 1, values, floatArray(arena, 0f));

// When / Then — every bulk path hits the defensive default arm
assertThatThrownBy(() -> sut.materialize(arena))
.isInstanceOf(VortexException.class).hasMessageContaining("invalid codes");
assertThatThrownBy(() -> sut.forEachLong(v -> { }))
.isInstanceOf(VortexException.class);
assertThatThrownBy(() -> sut.fold(0L, Long::sum))
.isInstanceOf(VortexException.class);
}
}
}

@Nested
Expand Down Expand Up @@ -198,6 +252,51 @@ void foldThroughCodes() {
assertThat(total).isEqualTo(11);
}
}

@Test
void allCodeTypes_getForEachFoldMaterialize() {
try (Arena arena = Arena.ofConfined()) {
IntArray values = intArray(arena, I32, 100, 200, 300);
int[] expected = {300, 100, 200};
List<Array> codeVariants = List.of(
byteArray(arena, (byte) 2, (byte) 0, (byte) 1),
shortArray(arena, U16, (short) 2, (short) 0, (short) 1),
intArray(arena, U32, 2, 0, 1),
longArray(arena, U64, 2L, 0L, 1L));

for (Array codes : codeVariants) {
DictIntArray sut = DictIntArray.of(I32, 3, values, codes);
String label = codes.getClass().getSimpleName();

for (int i = 0; i < 3; i++) {
assertThat(sut.getInt(i)).as(label).isEqualTo(expected[i]);
}
var seen = new ArrayList<Integer>();
sut.forEachInt(seen::add);
assertThat(seen).as(label).containsExactly(300, 100, 200);
assertThat(sut.fold(0, Integer::sum)).as(label).isEqualTo(600);
MemorySegment m = sut.materialize(arena);
for (int i = 0; i < 3; i++) {
assertThat(m.getAtIndex(PTypeIO.LE_INT, i)).as(label).isEqualTo(expected[i]);
}
}
}
}

@Test
void bulkOps_invalidCodesType_throw() {
try (Arena arena = Arena.ofConfined()) {
IntArray values = intArray(arena, I32, 1);
DictIntArray sut = new DictIntArray(I32, 1, values, floatArray(arena, 0f));

assertThatThrownBy(() -> sut.materialize(arena))
.isInstanceOf(VortexException.class).hasMessageContaining("invalid codes");
assertThatThrownBy(() -> sut.forEachInt(v -> { }))
.isInstanceOf(VortexException.class);
assertThatThrownBy(() -> sut.fold(0, Integer::sum))
.isInstanceOf(VortexException.class);
}
}
}

@Nested
Expand Down Expand Up @@ -241,6 +340,51 @@ void foldThroughCodes() {
assertThat(total).isEqualTo(5.0);
}
}

@Test
void allCodeTypes_getForEachFoldMaterialize() {
try (Arena arena = Arena.ofConfined()) {
DoubleArray values = doubleArray(arena, 1.5, 2.5, 3.5);
double[] expected = {3.5, 1.5, 2.5};
List<Array> codeVariants = List.of(
byteArray(arena, (byte) 2, (byte) 0, (byte) 1),
shortArray(arena, U16, (short) 2, (short) 0, (short) 1),
intArray(arena, U32, 2, 0, 1),
longArray(arena, U64, 2L, 0L, 1L));

for (Array codes : codeVariants) {
DictDoubleArray sut = DictDoubleArray.of(F64, 3, values, codes);
String label = codes.getClass().getSimpleName();

for (int i = 0; i < 3; i++) {
assertThat(sut.getDouble(i)).as(label).isEqualTo(expected[i]);
}
var seen = new ArrayList<Double>();
sut.forEachDouble(seen::add);
assertThat(seen).as(label).containsExactly(3.5, 1.5, 2.5);
assertThat(sut.fold(0.0, Double::sum)).as(label).isEqualTo(7.5);
MemorySegment m = sut.materialize(arena);
for (int i = 0; i < 3; i++) {
assertThat(m.getAtIndex(PTypeIO.LE_DOUBLE, i)).as(label).isEqualTo(expected[i]);
}
}
}
}

@Test
void bulkOps_invalidCodesType_throw() {
try (Arena arena = Arena.ofConfined()) {
DoubleArray values = doubleArray(arena, 1.0);
DictDoubleArray sut = new DictDoubleArray(F64, 1, values, floatArray(arena, 0f));

assertThatThrownBy(() -> sut.materialize(arena))
.isInstanceOf(VortexException.class).hasMessageContaining("invalid codes");
assertThatThrownBy(() -> sut.forEachDouble(v -> { }))
.isInstanceOf(VortexException.class);
assertThatThrownBy(() -> sut.fold(0.0, Double::sum))
.isInstanceOf(VortexException.class);
}
}
}

@Nested
Expand Down Expand Up @@ -270,6 +414,46 @@ void foldThroughCodes() {
assertThat(total).isEqualTo(3.5);
}
}

@Test
void allCodeTypes_getFoldMaterialize() {
try (Arena arena = Arena.ofConfined()) {
FloatArray values = floatArray(arena, 1.5f, 2.5f, 3.5f);
float[] expected = {3.5f, 1.5f, 2.5f};
List<Array> codeVariants = List.of(
byteArray(arena, (byte) 2, (byte) 0, (byte) 1),
shortArray(arena, U16, (short) 2, (short) 0, (short) 1),
intArray(arena, U32, 2, 0, 1),
longArray(arena, U64, 2L, 0L, 1L));

for (Array codes : codeVariants) {
DictFloatArray sut = DictFloatArray.of(F32, 3, values, codes);
String label = codes.getClass().getSimpleName();

for (int i = 0; i < 3; i++) {
assertThat(sut.getFloat(i)).as(label).isEqualTo(expected[i]);
}
assertThat(sut.fold(0.0, Double::sum)).as(label).isEqualTo(7.5);
MemorySegment m = sut.materialize(arena);
for (int i = 0; i < 3; i++) {
assertThat(m.getAtIndex(PTypeIO.LE_FLOAT, i)).as(label).isEqualTo(expected[i]);
}
}
}
}

@Test
void bulkOps_invalidCodesType_throw() {
try (Arena arena = Arena.ofConfined()) {
FloatArray values = floatArray(arena, 1.0f);
DictFloatArray sut = new DictFloatArray(F32, 1, values, doubleArray(arena, 0.0));

assertThatThrownBy(() -> sut.materialize(arena))
.isInstanceOf(VortexException.class).hasMessageContaining("invalid codes");
assertThatThrownBy(() -> sut.fold(0.0, Double::sum))
.isInstanceOf(VortexException.class);
}
}
}

@Nested
Expand Down