2424import io .github .jbellis .jvector .graph .disk .NeighborsScoreCache ;
2525import io .github .jbellis .jvector .graph .disk .OnDiskGraphIndex ;
2626import io .github .jbellis .jvector .graph .similarity .BuildScoreProvider ;
27+ import io .github .jbellis .jvector .graph .similarity .SearchScoreProvider ;
2728import io .github .jbellis .jvector .util .Bits ;
2829import io .github .jbellis .jvector .vector .VectorSimilarityFunction ;
2930import io .github .jbellis .jvector .vector .VectorizationProvider ;
3031import io .github .jbellis .jvector .vector .types .VectorFloat ;
3132import io .github .jbellis .jvector .vector .types .VectorTypeSupport ;
3233import org .apache .logging .log4j .Logger ;
3334import org .junit .After ;
35+ import org .junit .Assert ;
3436import org .junit .Before ;
3537import org .junit .Test ;
3638
5052public class OnHeapGraphIndexTest extends RandomizedTest {
5153 private final static Logger log = org .apache .logging .log4j .LogManager .getLogger (OnHeapGraphIndexTest .class );
5254 private static final VectorTypeSupport VECTOR_TYPE_SUPPORT = VectorizationProvider .getInstance ().getVectorTypeSupport ();
53- private static final int numBaseVectors = 100 ;
54- private static final int numNewVectors = 100 ;
55- private static final int numAllVectors = numBaseVectors + numNewVectors ;
56- private static final int dimension = 16 ;
55+ private static final int NUM_BASE_VECTORS = 100 ;
56+ private static final int NUM_NEW_VECTORS = 100 ;
57+ private static final int NUM_ALL_VECTORS = NUM_BASE_VECTORS + NUM_NEW_VECTORS ;
58+ private static final int DIMENSION = 16 ;
5759 private static final int M = 8 ;
58- private static final int beamWidth = 100 ;
59- private static final float alpha = 1.2f ;
60- private static final float neighborOverflow = 1.2f ;
61- private static final boolean addHierarchy = false ;
60+ private static final int BEAM_WIDTH = 100 ;
61+ private static final float ALPHA = 1.2f ;
62+ private static final float NEIGHBOR_OVERFLOW = 1.2f ;
63+ private static final boolean ADD_HIERARCHY = false ;
64+ private static final int TOP_K = 10 ;
6265
6366 private Path testDirectory ;
6467
@@ -68,6 +71,9 @@ public class OnHeapGraphIndexTest extends RandomizedTest {
6871 private RandomAccessVectorValues baseVectorsRavv ;
6972 private RandomAccessVectorValues newVectorsRavv ;
7073 private RandomAccessVectorValues allVectorsRavv ;
74+ private VectorFloat <?> queryVector ;
75+ private int [] groundTruthBaseVectors ;
76+ private int [] groundTruthAllVectors ;
7177 private BuildScoreProvider baseBuildScoreProvider ;
7278 private BuildScoreProvider newBuildScoreProvider ;
7379 private BuildScoreProvider allBuildScoreProvider ;
@@ -78,42 +84,47 @@ public class OnHeapGraphIndexTest extends RandomizedTest {
7884 @ Before
7985 public void setup () throws IOException {
8086 testDirectory = Files .createTempDirectory (this .getClass ().getSimpleName ());
81- baseVectors = new ArrayList <>(numBaseVectors );
82- newVectors = new ArrayList <>(numNewVectors );
83- allVectors = new ArrayList <>(numAllVectors );
84- for (int i = 0 ; i < numBaseVectors ; i ++) {
85- VectorFloat <?> vector = createRandomVector (dimension );
87+ baseVectors = new ArrayList <>(NUM_BASE_VECTORS );
88+ newVectors = new ArrayList <>(NUM_NEW_VECTORS );
89+ allVectors = new ArrayList <>(NUM_ALL_VECTORS );
90+ for (int i = 0 ; i < NUM_BASE_VECTORS ; i ++) {
91+ VectorFloat <?> vector = createRandomVector (DIMENSION );
8692 baseVectors .add (vector );
8793 allVectors .add (vector );
8894 }
89- for (int i = 0 ; i < numNewVectors ; i ++) {
90- VectorFloat <?> vector = createRandomVector (dimension );
95+ for (int i = 0 ; i < NUM_NEW_VECTORS ; i ++) {
96+ VectorFloat <?> vector = createRandomVector (DIMENSION );
9197 newVectors .add (vector );
9298 allVectors .add (vector );
9399 }
94100
95101 // wrap the raw vectors in a RandomAccessVectorValues
96- baseVectorsRavv = new ListRandomAccessVectorValues (baseVectors , dimension );
97- newVectorsRavv = new ListRandomAccessVectorValues (newVectors , dimension );
98- allVectorsRavv = new ListRandomAccessVectorValues (allVectors , dimension );
102+ baseVectorsRavv = new ListRandomAccessVectorValues (baseVectors , DIMENSION );
103+ newVectorsRavv = new ListRandomAccessVectorValues (newVectors , DIMENSION );
104+ allVectorsRavv = new ListRandomAccessVectorValues (allVectors , DIMENSION );
99105
106+ queryVector = createRandomVector (DIMENSION );
107+ groundTruthBaseVectors = getGroundTruth (baseVectorsRavv , queryVector , TOP_K , VectorSimilarityFunction .EUCLIDEAN );
108+ groundTruthAllVectors = getGroundTruth (allVectorsRavv , queryVector , TOP_K , VectorSimilarityFunction .EUCLIDEAN );
109+
110+ // score provider using the raw, in-memory vectors
100111 baseBuildScoreProvider = BuildScoreProvider .randomAccessScoreProvider (baseVectorsRavv , VectorSimilarityFunction .EUCLIDEAN );
101112 newBuildScoreProvider = BuildScoreProvider .randomAccessScoreProvider (newVectorsRavv , VectorSimilarityFunction .EUCLIDEAN );
102113 allBuildScoreProvider = BuildScoreProvider .randomAccessScoreProvider (allVectorsRavv , VectorSimilarityFunction .EUCLIDEAN );
103114 var baseGraphIndexBuilder = new GraphIndexBuilder (baseBuildScoreProvider ,
104115 baseVectorsRavv .dimension (),
105116 M , // graph degree
106- beamWidth , // construction search depth
107- neighborOverflow , // allow degree overflow during construction by this factor
108- alpha , // relax neighbor diversity requirement by this factor
109- addHierarchy ); // add the hierarchy
117+ BEAM_WIDTH , // construction search depth
118+ NEIGHBOR_OVERFLOW , // allow degree overflow during construction by this factor
119+ ALPHA , // relax neighbor diversity requirement by this factor
120+ ADD_HIERARCHY ); // add the hierarchy
110121 var allGraphIndexBuilder = new GraphIndexBuilder (allBuildScoreProvider ,
111122 allVectorsRavv .dimension (),
112123 M , // graph degree
113- beamWidth , // construction search depth
114- neighborOverflow , // allow degree overflow during construction by this factor
115- alpha , // relax neighbor diversity requirement by this factor
116- addHierarchy ); // add the hierarchy
124+ BEAM_WIDTH , // construction search depth
125+ NEIGHBOR_OVERFLOW , // allow degree overflow during construction by this factor
126+ ALPHA , // relax neighbor diversity requirement by this factor
127+ ADD_HIERARCHY ); // add the hierarchy
117128
118129 baseGraphIndex = baseGraphIndexBuilder .build (baseVectorsRavv );
119130 allGraphIndex = allGraphIndexBuilder .build (allVectorsRavv );
@@ -156,7 +167,7 @@ public void testReconstructionOfOnHeapGraphIndex() throws IOException {
156167 validateVectors (onDiskView , baseVectorsRavv );
157168 }
158169
159- OnHeapGraphIndex reconstructedOnHeapGraphIndex = OnHeapGraphIndex .convertToHeap (onDiskGraph , neighborsScoreCacheRead , baseBuildScoreProvider , neighborOverflow , alpha );
170+ OnHeapGraphIndex reconstructedOnHeapGraphIndex = OnHeapGraphIndex .convertToHeap (onDiskGraph , neighborsScoreCacheRead , baseBuildScoreProvider , NEIGHBOR_OVERFLOW , ALPHA );
160171 TestUtil .assertGraphEquals (baseGraphIndex , reconstructedOnHeapGraphIndex );
161172 TestUtil .assertGraphEquals (onDiskGraph , reconstructedOnHeapGraphIndex );
162173
@@ -178,12 +189,18 @@ public void testIncrementalInsertionFromOnDiskIndex() throws IOException {
178189 TestUtil .assertGraphEquals (baseGraphIndex , onDiskGraph );
179190 // We will create a trivial 1:1 mapping between the new graph and the ravv
180191 final int [] graphToRavvOrdMap = IntStream .range (0 , allVectorsRavv .size ()).toArray ();
181- OnHeapGraphIndex reconstructedAllNodeOnHeapGraphIndex = GraphIndexBuilder .buildAndMergeNewNodes (onDiskGraph , neighborsScoreCache , allVectorsRavv , allBuildScoreProvider , numBaseVectors , graphToRavvOrdMap , beamWidth , neighborOverflow , alpha , addHierarchy );
192+ OnHeapGraphIndex reconstructedAllNodeOnHeapGraphIndex = GraphIndexBuilder .buildAndMergeNewNodes (onDiskGraph , neighborsScoreCache , allVectorsRavv , allBuildScoreProvider , NUM_BASE_VECTORS , graphToRavvOrdMap , BEAM_WIDTH , NEIGHBOR_OVERFLOW , ALPHA , ADD_HIERARCHY );
193+
194+ // Verify that the recall is similar
195+ float recallFromReconstructedAllNodeOnHeapGraphIndex = calculateRecall (reconstructedAllNodeOnHeapGraphIndex , allBuildScoreProvider , queryVector , groundTruthAllVectors , TOP_K );
196+ float recallFromAllGraphIndex = calculateRecall (allGraphIndex , allBuildScoreProvider , queryVector , groundTruthAllVectors , TOP_K );
197+ Assert .assertEquals (recallFromReconstructedAllNodeOnHeapGraphIndex , recallFromAllGraphIndex , 0.01f );
182198
199+ // Verify that the result sets overlap
183200 try (GraphSearcher reconstructedAllGraphSearcher = new GraphSearcher (reconstructedAllNodeOnHeapGraphIndex );
184201 GraphSearcher allGraphSearcher = new GraphSearcher (allGraphIndex )) {
185- final int topK = 10 ;
186- VectorFloat <?> queryVector = createRandomVector (dimension );
202+ final int topK = TOP_K ;
203+ VectorFloat <?> queryVector = createRandomVector (DIMENSION );
187204 var resultFromReconstructed = reconstructedAllGraphSearcher .search (allBuildScoreProvider .searchProviderFor (queryVector ), topK , Bits .ALL );
188205 var resultFromAll = allGraphSearcher .search (allBuildScoreProvider .searchProviderFor (queryVector ), topK , Bits .ALL );
189206 log .info ("Reconstructed result: {}, all result: {}" , resultFromReconstructed , resultFromAll );
@@ -210,4 +227,54 @@ private VectorFloat<?> createRandomVector(int dimension) {
210227 }
211228 return vector ;
212229 }
230+
231+ /**
232+ * Get the ground truth for a query vector
233+ * @param ravv the vectors to search
234+ * @param queryVector the query vector
235+ * @param topK the number of results to return
236+ * @param similarityFunction the similarity function to use
237+
238+ * @return the ground truth
239+ */
240+ private static int [] getGroundTruth (RandomAccessVectorValues ravv , VectorFloat <?> queryVector , int topK , VectorSimilarityFunction similarityFunction ) {
241+ var exactResults = new ArrayList <SearchResult .NodeScore >();
242+ for (int i = 0 ; i < ravv .size (); i ++) {
243+ float similarityScore = similarityFunction .compare (queryVector , ravv .getVector (i ));
244+ exactResults .add (new SearchResult .NodeScore (i , similarityScore ));
245+ }
246+ exactResults .sort ((a , b ) -> Float .compare (b .score , a .score ));
247+ return exactResults .stream ().limit (topK ).mapToInt (nodeScore -> nodeScore .node ).toArray ();
248+ }
249+
250+ private static float calculateRecall (OnHeapGraphIndex graphIndex , BuildScoreProvider buildScoreProvider , VectorFloat <?> queryVector , int [] groundTruth , int k ) throws IOException {
251+ try (GraphSearcher graphSearcher = new GraphSearcher (graphIndex )){
252+ SearchScoreProvider ssp = buildScoreProvider .searchProviderFor (queryVector );
253+ var searchResults = graphSearcher .search (ssp , k , Bits .ALL );
254+ var predicted = Arrays .stream (searchResults .getNodes ()).mapToInt (nodeScore -> nodeScore .node ).boxed ().collect (Collectors .toSet ());
255+ return calculateRecall (predicted , groundTruth , k );
256+ }
257+ }
258+ /**
259+ * Calculate the recall for a set of predicted results
260+ * @param predicted the predicted results
261+ * @param groundTruth the ground truth
262+ * @param k the number of results to consider
263+ * @return the recall
264+ */
265+ private static float calculateRecall (Set <Integer > predicted , int [] groundTruth , int k ) {
266+ int hits = 0 ;
267+ int actualK = Math .min (k , Math .min (predicted .size (), groundTruth .length ));
268+
269+ for (int i = 0 ; i < actualK ; i ++) {
270+ for (int j = 0 ; j < actualK ; j ++) {
271+ if (predicted .contains (groundTruth [j ])) {
272+ hits ++;
273+ break ;
274+ }
275+ }
276+ }
277+
278+ return ((float ) hits ) / (float ) actualK ;
279+ }
213280}
0 commit comments