diff --git a/include/simsimd/probability.h b/include/simsimd/probability.h
index 2865aa32..ca10b782 100644
--- a/include/simsimd/probability.h
+++ b/include/simsimd/probability.h
@@ -108,7 +108,7 @@ SIMSIMD_PUBLIC void simsimd_js_f16_sapphire(simsimd_f16_t const* a, simsimd_f16_
         d += ai * SIMSIMD_LOG((ai + epsilon) / (mi + epsilon));          \
         d += bi * SIMSIMD_LOG((bi + epsilon) / (mi + epsilon));          \
     }                                                                    \
-    *result = (simsimd_distance_t)d / 2;                                 \
+    *result = SIMSIMD_SQRT(((simsimd_distance_t)d / 2));                 \
     }
 
 SIMSIMD_MAKE_KL(serial, f64, f64, SIMSIMD_DEREFERENCE, SIMSIMD_F32_DIVISION_EPSILON) // simsimd_kl_f64_serial
@@ -219,12 +219,13 @@ SIMSIMD_PUBLIC void simsimd_js_f32_neon(simsimd_f32_t const *a, simsimd_f32_t co
     float32x4_t log_ratio_b_vec = _simsimd_log2_f32_neon(ratio_b_vec);
     float32x4_t prod_a_vec = vmulq_f32(a_vec, log_ratio_a_vec);
     float32x4_t prod_b_vec = vmulq_f32(b_vec, log_ratio_b_vec);
+    sum_vec = vaddq_f32(sum_vec, vaddq_f32(prod_a_vec, prod_b_vec));
     if (n != 0) goto simsimd_js_f32_neon_cycle;
 
     simsimd_f32_t log2_normalizer = 0.693147181f;
-    simsimd_f32_t sum = vaddvq_f32(sum_vec) * log2_normalizer;
-    *result = sum / 2;
+    simsimd_f32_t sum = vaddvq_f32(sum_vec) * log2_normalizer / 2;
+    *result = SIMSIMD_SQRT(sum);
 }
 
 #pragma clang attribute pop
@@ -296,8 +297,8 @@ SIMSIMD_PUBLIC void simsimd_js_f16_neon(simsimd_f16_t const *a, simsimd_f16_t co
     if (n) goto simsimd_js_f16_neon_cycle;
 
     simsimd_f32_t log2_normalizer = 0.693147181f;
-    simsimd_f32_t sum = vaddvq_f32(sum_vec) * log2_normalizer;
-    *result = sum / 2;
+    simsimd_f32_t sum = vaddvq_f32(sum_vec) * log2_normalizer / 2;
+    *result = SIMSIMD_SQRT(sum);
 }
 
 #pragma clang attribute pop
diff --git a/scripts/test.mjs b/scripts/test.mjs
index 24d9baea..8bd0842b 100644
--- a/scripts/test.mjs
+++ b/scripts/test.mjs
@@ -172,8 +172,15 @@ test("Kullback-Leibler C vs JS", () => {
 });
 
 test("Jensen-Shannon C vs JS", () => {
-    const f32sDistribution = new Float32Array([1.0 / 6, 2.0 / 6, 3.0 / 6]);
-    const result = simsimd.jensenshannon(f32sDistribution, f32sDistribution);
-    const resultjs = fallback.jensenshannon(f32sDistribution, f32sDistribution);
-    assertAlmostEqual(resultjs, result, 0.01);
+    const f32sDistribution = new Float32Array([1.0, 0.0]);
+    const f32sDistribution2 = new Float32Array([0.5, 0.5]);
+    const result = simsimd.jensenshannon(f32sDistribution, f32sDistribution2);
+    const resultjs = fallback.jensenshannon(f32sDistribution, f32sDistribution2);
+    assertAlmostEqual(result, resultjs, 0.01);
+
+    const orthogonalVec1 = new Float32Array([1.0, 0.0, 0.0]);
+    const orthogonalVec2 = new Float32Array([0.0, 1.0, 0.0]);
+    const orthoResult = simsimd.jensenshannon(orthogonalVec1, orthogonalVec2);
+    const orthoResultJs = fallback.jensenshannon(orthogonalVec1, orthogonalVec2);
+    assertAlmostEqual(orthoResult, orthoResultJs, 0.01);
 });
diff --git a/scripts/test.py b/scripts/test.py
index 7408d71e..4e8781aa 100644
--- a/scripts/test.py
+++ b/scripts/test.py
@@ -38,7 +38,6 @@ python test.py
 
 """
-
 import os
 import math
 import time
@@ -124,7 +123,7 @@ def baseline_wsum(x, y, alpha, beta):
     baseline_euclidean = lambda x, y: np.array(spd.euclidean(x, y))  #! SciPy returns a scalar
     baseline_sqeuclidean = spd.sqeuclidean
     baseline_cosine = spd.cosine
-    baseline_jensenshannon = lambda x, y: spd.jensenshannon(x, y) ** 2
+    baseline_jensenshannon = lambda x, y: spd.jensenshannon(x, y)
     baseline_hamming = lambda x, y: spd.hamming(x, y) * len(x)
     baseline_jaccard = spd.jaccard
 
@@ -453,6 +452,8 @@ def name_to_kernels(name: str):
         return baseline_fma, simd.fma
     elif name == "wsum":
         return baseline_wsum, simd.wsum
+    elif name == "jensenshannon":
+        return baseline_jensenshannon, simd.jensenshannon
     else:
         raise ValueError(f"Unknown kernel name: {name}")
 
@@ -839,12 +840,13 @@ def test_dense_bits(ndim, metric, capability, stats_fixture):
         collect_errors(metric, ndim, "bin8", accurate, accurate_dt, expected, expected_dt, result, result_dt, stats_fixture)
 
 
-@pytest.mark.skip(reason="Problems inferring the tolerance bounds for numerical errors")
+@pytest.mark.skipif(not numpy_available, reason="NumPy is not installed")
+@pytest.mark.skipif(not scipy_available, reason="SciPy is not installed")
 @pytest.mark.repeat(50)
 @pytest.mark.parametrize("ndim", [11, 97, 1536])
 @pytest.mark.parametrize("dtype", ["float32", "float16"])
 @pytest.mark.parametrize("capability", possible_capabilities)
-def test_jensen_shannon(ndim, dtype, capability):
+def test_jensen_shannon(ndim, dtype, capability, stats_fixture):
     """Compares the simd.jensenshannon() function with scipy.spatial.distance.jensenshannon(),
     measuring the accuracy error for f16, and f32 types."""
     np.random.seed()
     a = np.abs(np.random.randn(ndim)).astype(dtype)
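
Note on the change above (not part of the patch): the kernels now return the Jensen-Shannon distance, i.e. the square root of the divergence, which is what scipy.spatial.distance.jensenshannon already reports, hence the Python baseline drops its "** 2". A minimal sketch of that relationship, assuming NumPy and SciPy are installed; the arrays below are illustrative only:

    import numpy as np
    from scipy.spatial.distance import jensenshannon
    from scipy.special import rel_entr

    # Disjoint supports, mirroring the new orthogonal-vector test case above.
    p = np.array([1.0, 0.0, 0.0])
    q = np.array([0.0, 1.0, 0.0])
    m = (p + q) / 2

    # JS divergence = 0.5 * (KL(p || m) + KL(q || m)); rel_entr treats 0 * log(0 / x) as 0.
    divergence = 0.5 * (rel_entr(p, m).sum() + rel_entr(q, m).sum())
    distance = np.sqrt(divergence)  # sqrt(ln 2) ~= 0.8326 for disjoint supports

    # SciPy's jensenshannon() returns the distance, not the divergence.
    assert np.isclose(distance, jensenshannon(p, q))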