5353import  java .util .concurrent .Callable ;
5454import  java .util .stream .Collectors ;
5555
56+ import  static  org .opensearch .knn .profile .StopWatchUtils .startStopWatch ;
57+ import  static  org .opensearch .knn .profile .StopWatchUtils .stopStopWatchAndLog ;
58+ 
5659/** 
5760 * {@link KNNQuery} executes approximate nearest neighbor search (ANN) on a segment level. 
5861 * {@link NativeEngineKnnVectorQuery} executes approximate nearest neighbor search but gives 
@@ -274,13 +277,19 @@ private List<PerLeafResult> doSearch(
274277
275278        // For memory optimized search, it should kick off 2nd search if optimistic 
276279        if  (knnQuery .isMemoryOptimizedSearch () && perLeafResults .size () > 1 ) {
277-             run2ndOptimisticSearch (perLeafResults , knnWeight , leafReaderContexts , k , indexSearcher );
280+             log .debug (
281+                 "Running second deep dive search in optimistic while memory optimized search is enabled. perLeafResults.size()={}" ,
282+                 perLeafResults .size ()
283+             );
284+             final  StopWatch  stopWatch  = startStopWatch (log );
285+             reentrantSearch (perLeafResults , knnWeight , leafReaderContexts , k , indexSearcher );
286+             stopStopWatchAndLog (log , stopWatch , "2ndOptimisticSearch" , knnQuery .getShardId (), "All Shards" , knnQuery .getField ());
278287        }
279288
280289        return  perLeafResults ;
281290    }
282291
283-     private  void  run2ndOptimisticSearch (
292+     private  void  reentrantSearch (
284293        final  List <PerLeafResult > perLeafResults ,
285294        final  KNNWeight  knnWeight ,
286295        final  List <LeafReaderContext > leafReaderContexts ,
@@ -299,7 +308,7 @@ private void run2ndOptimisticSearch(
299308
300309        assert  (perLeafResults .size () == leafReaderContexts .size ());
301310
302-         // Get collector manager  first 
311+         // Get memory optimized knn weight  first, it's safe get it, we checked it already.  
303312        final  MemoryOptimizedKNNWeight  memoryOptKNNWeight  = (MemoryOptimizedKNNWeight ) knnWeight ;
304313
305314        // How many results have we collected? 
@@ -313,25 +322,24 @@ private void run2ndOptimisticSearch(
313322            return ;
314323        }
315324
316-         // Build segment to results table 
317-         final  Map <Integer , TopDocs > segmentOrdToResults  = new  HashMap <>(leafReaderContexts .size ());
318-         for  (int  i  = 0 ; i  < perLeafResults .size (); i ++) {
319-             segmentOrdToResults .put (leafReaderContexts .get (i ).ord , perLeafResults .get (i ).getResult ());
320-         }
321- 
322325        // Start 2nd deep dive, and get the minimum bar. 
323326        final  float  minTopKScore  = OptimisticSearchStrategyUtils .findKthLargestScore (perLeafResults , knnQuery .getK (), totalResults );
324327
325328        // Select candidate segments for 2nd search. Pick whatever segment returned all vectors whose score values are greater than `kth` 
326329        // value in the merged results. 
327330        final  List <Callable <TopDocs >> secondDeepDiveTasks  = new  ArrayList <>();
328331        final  List <Integer > contextIndices  = new  ArrayList <>();
332+         final  Map <Integer , TopDocs > segmentOrdToResults  = new  HashMap <>();
333+ 
329334        for  (int  i  = 0 ; i  < leafReaderContexts .size (); ++i ) {
330335            final  LeafReaderContext  leafReaderContext  = leafReaderContexts .get (i );
331336            final  PerLeafResult  perLeafResult  = perLeafResults .get (i );
332-             final  TopDocs  perLeaf  = segmentOrdToResults .get (leafReaderContext . ord );
337+             final  TopDocs  perLeaf  = perLeafResults .get (i ). getResult ( );
333338            if  (perLeaf .scoreDocs .length  > 0  && perLeafResult .getSearchMode () == PerLeafResult .SearchMode .APPROXIMATE_SEARCH ) {
334339                if  (FORCE_REENTER_TESTING  || perLeaf .scoreDocs [perLeaf .scoreDocs .length  - 1 ].score  >= minTopKScore ) {
340+                     // For the target segment, save top results. Which will be used as seeds. 
341+                     segmentOrdToResults .put (leafReaderContext .ord , perLeaf );
342+ 
335343                    // All this leaf's hits are at or above the global topK min score; explore it further 
336344                    secondDeepDiveTasks .add (
337345                        () -> knnWeight .approximateSearch (
@@ -348,24 +356,24 @@ private void run2ndOptimisticSearch(
348356
349357        // Kick off 2nd search tasks 
350358        if  (secondDeepDiveTasks .isEmpty () == false ) {
351-             final  ReentrantKnnCollectorManager  knnCollectorManagerPhase2  = new  ReentrantKnnCollectorManager (
359+             final  ReentrantKnnCollectorManager  reentrantCollectorManager  = new  ReentrantKnnCollectorManager (
352360                new  TopKnnCollectorManager (k , indexSearcher ),
353361                segmentOrdToResults ,
354362                knnQuery .getQueryVector (),
355363                knnQuery .getField ()
356364            );
357365
358366            // Make weight use reentrant collector manager 
359-             memoryOptKNNWeight .setOptimistic2ndKnnCollectorManager ( knnCollectorManagerPhase2 );
367+             memoryOptKNNWeight .setReentrantKNNCollectorManager ( reentrantCollectorManager );
360368
361369            final  List <TopDocs > deepDiveTopDocs  = indexSearcher .getTaskExecutor ().invokeAll (secondDeepDiveTasks );
362370
363371            // Override results for target context 
364372            for  (int  i  = 0 ; i  < deepDiveTopDocs .size (); ++i ) {
365373                // Override with the new results 
366-                 final  TopDocs  resultsFrom2ncDeepDive  = deepDiveTopDocs .get (i );
374+                 final  TopDocs  resultsFromDeepDive  = deepDiveTopDocs .get (i );
367375                final  PerLeafResult  perLeafResult  = perLeafResults .get (contextIndices .get (i ));
368-                 perLeafResult .setResult (resultsFrom2ncDeepDive );
376+                 perLeafResult .setResult (resultsFromDeepDive );
369377            }
370378        }
371379    }
0 commit comments