Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion tests/test_jit_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from itertools import combinations
from math import comb

from minterpy.global_settings import INT_DTYPE
from minterpy.jit_compiled.common import (
n_choose_r,
combinations_iter,
Expand Down Expand Up @@ -74,7 +75,12 @@ def test_get_max_columnwise():
"""Test getting the column-wise max of a two-dimensional integer array."""
num_rows = np.random.randint(low=100, high=1000)
num_cols = np.random.randint(low=1, high=10)
xx = np.random.randint(low=0, high=100, size=(num_rows, num_cols))
xx = np.random.randint(
low=0,
high=100,
size=(num_rows, num_cols),
dtype=INT_DTYPE,
)

# Maximum by NumPy
max_ref = np.max(xx, axis=0)
Expand Down
16 changes: 11 additions & 5 deletions tests/test_multi_index_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
)
from numpy.testing import assert_, assert_equal, assert_raises

from minterpy.global_settings import INT_DTYPE
from minterpy.utils.arrays import expand_dim
from minterpy.utils.multi_index import (
find_match_between,
Expand Down Expand Up @@ -135,9 +136,14 @@ def test_lex_smaller_or_equal(SpatialDimension, PolyDegree):
"""Test lexicographically comparing two different multi-index elements."""
# Create a random multi-indices
if PolyDegree == 0:
indices_1 = np.zeros(SpatialDimension, dtype=int)
indices_1 = np.zeros(SpatialDimension, dtype=INT_DTYPE)
else:
indices_1 = np.random.randint(0, PolyDegree, SpatialDimension)
indices_1 = np.random.randint(
low=0,
high=PolyDegree,
size=SpatialDimension,
dtype=INT_DTYPE,
)

# Assertion: Equal multi-indices
assert is_lex_smaller_or_equal(indices_1, indices_1)
Expand All @@ -162,7 +168,7 @@ def test_is_lex_sorted(SpatialDimension, PolyDegree, LpDegree):

def test_is_lex_sorted_single():
"""Test if a single entry multi-index set is lexicographically sorted."""
index = np.random.randint(1, 5, (1, 10))
index = np.random.randint(1, 5, (1, 10), dtype=INT_DTYPE)

# Assertion: Always lexicographical
assert is_lex_sorted(index)
Expand All @@ -171,15 +177,15 @@ def test_is_lex_sorted_single():
def test_is_lex_sorted_random():
"""Test if a random integer array is lexicographically sorted."""
# Generate randomly, big enough so there's no chance it will be sorted
indices = np.random.randint(1, 5, (5, 12))
indices = np.random.randint(1, 5, (5, 12), dtype=INT_DTYPE)

# Assertion - Not lexicographical
assert not is_lex_sorted(indices)


def test_is_lex_sorted_duplicates():
"""Test if a multi-index set with duplicate entries is lexicographical."""
indices = np.array([[0, 0], [2, 0], [2, 0]])
indices = np.array([[0, 0], [2, 0], [2, 0]], dtype=INT_DTYPE)

# Assertion: Not lexicographical
assert not is_lex_sorted(indices)
Expand Down