diff --git a/CHANGELOG.md b/CHANGELOG.md index 324bf120d3..de33b173d1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -13,3 +13,4 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), * Asymmetric Distance Computation for binary quantized faiss indices [#2733](https://github.com/opensearch-project/k-NN/pull/2733) * [BUGFIX] [Remote Vector Index Build] Don't fall back to CPU on terminal failures [#2773](https://github.com/opensearch-project/k-NN/pull/2773) * Add KNN timing info to core profiler [#2785](https://github.com/opensearch-project/k-NN/pull/2785) +* Add KNN timing info for lucene queries [#2802](https://github.com/opensearch-project/k-NN/pull/2802) diff --git a/src/main/java/org/opensearch/knn/index/query/KNNQueryFactory.java b/src/main/java/org/opensearch/knn/index/query/KNNQueryFactory.java index 1b7e0f9e5f..b63cc6795c 100644 --- a/src/main/java/org/opensearch/knn/index/query/KNNQueryFactory.java +++ b/src/main/java/org/opensearch/knn/index/query/KNNQueryFactory.java @@ -7,14 +7,14 @@ import lombok.NonNull; import lombok.extern.log4j.Log4j2; -import org.apache.lucene.search.KnnByteVectorQuery; -import org.apache.lucene.search.KnnFloatVectorQuery; import org.apache.lucene.search.Query; import org.apache.lucene.search.join.BitSetProducer; import org.opensearch.index.query.QueryShardContext; import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.index.engine.KNNEngine; import org.opensearch.knn.index.query.common.QueryUtils; +import org.opensearch.knn.index.query.lucene.ProfileKnnByteVectorQuery; +import org.opensearch.knn.index.query.lucene.ProfileKnnFloatVectorQuery; import org.opensearch.knn.index.query.lucenelib.NestedKnnVectorQueryFactory; import org.opensearch.knn.index.query.lucene.LuceneEngineKnnVectorQuery; import org.opensearch.knn.index.query.nativelib.NativeEngineKnnVectorQuery; @@ -179,8 +179,8 @@ private static Query getKnnVectorQuery( if (parentFilter == null) { assert expandNested == false : "expandNested is allowed to be true only for nested fields."; return vectorDataType == VectorDataType.FLOAT - ? new KnnFloatVectorQuery(fieldName, floatQueryVector, k, filterQuery) - : new KnnByteVectorQuery(fieldName, byteQueryVector, k, filterQuery); + ? new ProfileKnnFloatVectorQuery(fieldName, floatQueryVector, k, filterQuery) + : new ProfileKnnByteVectorQuery(fieldName, byteQueryVector, k, filterQuery); } // If parentFilter is not null, it is a nested query. Therefore, we delegate creation of query to {@link // NestedKnnVectorQueryFactory} diff --git a/src/main/java/org/opensearch/knn/index/query/common/QueryUtils.java b/src/main/java/org/opensearch/knn/index/query/common/QueryUtils.java index 931b1e334f..7b02927058 100644 --- a/src/main/java/org/opensearch/knn/index/query/common/QueryUtils.java +++ b/src/main/java/org/opensearch/knn/index/query/common/QueryUtils.java @@ -20,6 +20,9 @@ import org.apache.lucene.util.Bits; import org.opensearch.knn.index.query.KNNWeight; import org.opensearch.knn.index.query.iterators.GroupedNestedDocIdSetIterator; +import org.opensearch.knn.profile.KNNProfileUtil; +import org.opensearch.knn.profile.query.KNNQueryTimingType; +import org.opensearch.search.profile.ContextualProfileBreakdown; import java.io.IOException; import java.util.ArrayList; @@ -89,6 +92,35 @@ private int[] findSegmentStarts(final IndexReader reader, final int[] docs) { return starts; } + /** + * Performs the search in parallel with profiling. + * + * @param indexSearcher the index searcher + * @param leafReaderContexts the leaf reader contexts + * @param weight the search weight + * @return a list of maps, each mapping document IDs to their scores + * @throws IOException + */ + public List> doSearch( + final IndexSearcher indexSearcher, + final List leafReaderContexts, + final Weight weight, + ContextualProfileBreakdown profile + ) throws IOException { + List>> tasks = new ArrayList<>(leafReaderContexts.size()); + for (LeafReaderContext leafReaderContext : leafReaderContexts) { + tasks.add( + () -> (Map) KNNProfileUtil.profileBreakdown( + profile, + leafReaderContext, + KNNQueryTimingType.ANN_SEARCH, + () -> searchLeaf(leafReaderContext, weight) + ) + ); + } + return indexSearcher.getTaskExecutor().invokeAll(tasks); + } + /** * Performs the search in parallel. * diff --git a/src/main/java/org/opensearch/knn/index/query/lucene/ProfileDiversifyingChildrenByteKnnVectorQuery.java b/src/main/java/org/opensearch/knn/index/query/lucene/ProfileDiversifyingChildrenByteKnnVectorQuery.java new file mode 100644 index 0000000000..046243ab1a --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/query/lucene/ProfileDiversifyingChildrenByteKnnVectorQuery.java @@ -0,0 +1,74 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.query.lucene; + +import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.index.QueryTimeout; +import org.apache.lucene.search.DocIdSetIterator; +import org.apache.lucene.search.IndexSearcher; +import org.apache.lucene.search.Query; +import org.apache.lucene.search.TopDocs; +import org.apache.lucene.search.join.BitSetProducer; +import org.apache.lucene.search.join.DiversifyingChildrenByteKnnVectorQuery; +import org.apache.lucene.search.knn.KnnCollectorManager; +import org.apache.lucene.util.Bits; +import org.opensearch.knn.profile.KNNProfileUtil; +import org.opensearch.knn.profile.query.KNNQueryTimingType; +import org.opensearch.search.profile.query.QueryProfiler; + +import java.io.IOException; + +/** + * Wrapper class used for profiling {@link DiversifyingChildrenByteKnnVectorQuery} + */ +public class ProfileDiversifyingChildrenByteKnnVectorQuery extends DiversifyingChildrenByteKnnVectorQuery { + + private QueryProfiler profiler; + + public ProfileDiversifyingChildrenByteKnnVectorQuery( + String field, + byte[] target, + Query childFilter, + int k, + BitSetProducer parentsFilter + ) { + super(field, target, childFilter, k, parentsFilter); + } + + @Override + public Query rewrite(IndexSearcher indexSearcher) throws IOException { + profiler = KNNProfileUtil.getProfiler(indexSearcher); + return super.rewrite(indexSearcher); + } + + @Override + protected TopDocs approximateSearch( + LeafReaderContext context, + Bits acceptDocs, + int visitedLimit, + KnnCollectorManager knnCollectorManager + ) throws IOException { + return (TopDocs) KNNProfileUtil.profile(profiler, this, context, KNNQueryTimingType.ANN_SEARCH, () -> { + try { + return super.approximateSearch(context, acceptDocs, visitedLimit, knnCollectorManager); + } catch (IOException e) { + throw new RuntimeException(e); + } + }); + } + + @Override + protected TopDocs exactSearch(LeafReaderContext context, DocIdSetIterator acceptIterator, QueryTimeout queryTimeout) + throws IOException { + return (TopDocs) KNNProfileUtil.profile(profiler, this, context, KNNQueryTimingType.EXACT_SEARCH, () -> { + try { + return super.exactSearch(context, acceptIterator, queryTimeout); + } catch (IOException e) { + throw new RuntimeException(e); + } + }); + } +} diff --git a/src/main/java/org/opensearch/knn/index/query/lucene/ProfileDiversifyingChildrenFloatKnnVectorQuery.java b/src/main/java/org/opensearch/knn/index/query/lucene/ProfileDiversifyingChildrenFloatKnnVectorQuery.java new file mode 100644 index 0000000000..a67c9fdea1 --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/query/lucene/ProfileDiversifyingChildrenFloatKnnVectorQuery.java @@ -0,0 +1,71 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.query.lucene; + +import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.index.QueryTimeout; +import org.apache.lucene.search.*; +import org.apache.lucene.search.join.BitSetProducer; +import org.apache.lucene.search.join.DiversifyingChildrenFloatKnnVectorQuery; +import org.apache.lucene.search.knn.KnnCollectorManager; +import org.apache.lucene.util.Bits; +import org.opensearch.knn.profile.KNNProfileUtil; +import org.opensearch.knn.profile.query.KNNQueryTimingType; +import org.opensearch.search.profile.query.QueryProfiler; + +import java.io.IOException; + +/** + * Wrapper class used for profiling {@link DiversifyingChildrenFloatKnnVectorQuery} + */ +public class ProfileDiversifyingChildrenFloatKnnVectorQuery extends DiversifyingChildrenFloatKnnVectorQuery { + + private QueryProfiler profiler; + + public ProfileDiversifyingChildrenFloatKnnVectorQuery( + String field, + float[] target, + Query childFilter, + int k, + BitSetProducer parentsFilter + ) { + super(field, target, childFilter, k, parentsFilter); + } + + @Override + public Query rewrite(IndexSearcher indexSearcher) throws IOException { + profiler = KNNProfileUtil.getProfiler(indexSearcher); + return super.rewrite(indexSearcher); + } + + @Override + protected TopDocs approximateSearch( + LeafReaderContext context, + Bits acceptDocs, + int visitedLimit, + KnnCollectorManager knnCollectorManager + ) throws IOException { + return (TopDocs) KNNProfileUtil.profile(profiler, this, context, KNNQueryTimingType.ANN_SEARCH, () -> { + try { + return super.approximateSearch(context, acceptDocs, visitedLimit, knnCollectorManager); + } catch (IOException e) { + throw new RuntimeException(e); + } + }); + } + + @Override + protected TopDocs exactSearch(LeafReaderContext context, DocIdSetIterator acceptIterator, QueryTimeout queryTimeout) + throws IOException { + return (TopDocs) KNNProfileUtil.profile(profiler, this, context, KNNQueryTimingType.EXACT_SEARCH, () -> { + try { + return super.exactSearch(context, acceptIterator, queryTimeout); + } catch (IOException e) { + throw new RuntimeException(e); + } + }); + } +} diff --git a/src/main/java/org/opensearch/knn/index/query/lucene/ProfileKnnByteVectorQuery.java b/src/main/java/org/opensearch/knn/index/query/lucene/ProfileKnnByteVectorQuery.java new file mode 100644 index 0000000000..d191ec832c --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/query/lucene/ProfileKnnByteVectorQuery.java @@ -0,0 +1,64 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.query.lucene; + +import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.index.QueryTimeout; +import org.apache.lucene.search.*; +import org.apache.lucene.search.knn.KnnCollectorManager; +import org.apache.lucene.util.Bits; +import org.opensearch.knn.profile.KNNProfileUtil; +import org.opensearch.knn.profile.query.KNNQueryTimingType; +import org.opensearch.search.profile.query.QueryProfiler; + +import java.io.IOException; + +/** + * Wrapper class used for profiling {@link KnnByteVectorQuery} + */ +public class ProfileKnnByteVectorQuery extends KnnByteVectorQuery { + + private QueryProfiler profiler; + + public ProfileKnnByteVectorQuery(String field, byte[] target, int k, Query filter) { + super(field, target, k, filter); + } + + @Override + public Query rewrite(IndexSearcher indexSearcher) throws IOException { + profiler = KNNProfileUtil.getProfiler(indexSearcher); + return super.rewrite(indexSearcher); + } + + @Override + protected TopDocs approximateSearch( + LeafReaderContext context, + Bits acceptDocs, + int visitedLimit, + KnnCollectorManager knnCollectorManager + ) throws IOException { + return (TopDocs) KNNProfileUtil.profile(profiler, this, context, KNNQueryTimingType.ANN_SEARCH, () -> { + try { + return super.approximateSearch(context, acceptDocs, visitedLimit, knnCollectorManager); + } catch (IOException e) { + throw new RuntimeException(e); + } + }); + } + + @Override + protected TopDocs exactSearch(LeafReaderContext context, DocIdSetIterator acceptIterator, QueryTimeout queryTimeout) + throws IOException { + return (TopDocs) KNNProfileUtil.profile(profiler, this, context, KNNQueryTimingType.EXACT_SEARCH, () -> { + try { + return super.exactSearch(context, acceptIterator, queryTimeout); + } catch (IOException e) { + throw new RuntimeException(e); + } + }); + } + +} diff --git a/src/main/java/org/opensearch/knn/index/query/lucene/ProfileKnnFloatVectorQuery.java b/src/main/java/org/opensearch/knn/index/query/lucene/ProfileKnnFloatVectorQuery.java new file mode 100644 index 0000000000..6ef5b26e3c --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/query/lucene/ProfileKnnFloatVectorQuery.java @@ -0,0 +1,64 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.query.lucene; + +import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.index.QueryTimeout; +import org.apache.lucene.search.*; +import org.apache.lucene.search.knn.KnnCollectorManager; +import org.apache.lucene.util.Bits; +import org.opensearch.knn.profile.KNNProfileUtil; +import org.opensearch.knn.profile.query.KNNQueryTimingType; +import org.opensearch.search.profile.query.QueryProfiler; + +import java.io.IOException; + +/** + * Wrapper class used for profiling {@link KnnFloatVectorQuery} + */ +public class ProfileKnnFloatVectorQuery extends KnnFloatVectorQuery { + + private QueryProfiler profiler; + + public ProfileKnnFloatVectorQuery(String field, float[] target, int k, Query filter) { + super(field, target, k, filter); + } + + @Override + public Query rewrite(IndexSearcher indexSearcher) throws IOException { + profiler = KNNProfileUtil.getProfiler(indexSearcher); + return super.rewrite(indexSearcher); + } + + @Override + protected TopDocs approximateSearch( + LeafReaderContext context, + Bits acceptDocs, + int visitedLimit, + KnnCollectorManager knnCollectorManager + ) throws IOException { + return (TopDocs) KNNProfileUtil.profile(profiler, this, context, KNNQueryTimingType.ANN_SEARCH, () -> { + try { + return super.approximateSearch(context, acceptDocs, visitedLimit, knnCollectorManager); + } catch (IOException e) { + throw new RuntimeException(e); + } + }); + } + + @Override + protected TopDocs exactSearch(LeafReaderContext context, DocIdSetIterator acceptIterator, QueryTimeout queryTimeout) + throws IOException { + return (TopDocs) KNNProfileUtil.profile(profiler, this, context, KNNQueryTimingType.EXACT_SEARCH, () -> { + try { + return super.exactSearch(context, acceptIterator, queryTimeout); + } catch (IOException e) { + throw new RuntimeException(e); + } + }); + } + +} diff --git a/src/main/java/org/opensearch/knn/index/query/lucenelib/ExpandNestedDocsQuery.java b/src/main/java/org/opensearch/knn/index/query/lucenelib/ExpandNestedDocsQuery.java index 863fd39ed6..dbab62e678 100644 --- a/src/main/java/org/opensearch/knn/index/query/lucenelib/ExpandNestedDocsQuery.java +++ b/src/main/java/org/opensearch/knn/index/query/lucenelib/ExpandNestedDocsQuery.java @@ -22,6 +22,10 @@ import org.apache.lucene.search.Weight; import org.apache.lucene.util.Bits; import org.opensearch.knn.index.query.common.QueryUtils; +import org.opensearch.knn.profile.KNNProfileUtil; +import org.opensearch.knn.profile.query.KNNQueryTimingType; +import org.opensearch.search.profile.ContextualProfileBreakdown; +import org.opensearch.search.profile.query.QueryProfiler; import java.io.IOException; import java.util.ArrayList; @@ -45,12 +49,29 @@ public class ExpandNestedDocsQuery extends Query { @Override public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost) throws IOException { + QueryProfiler profiler = KNNProfileUtil.getProfiler(searcher); + if (profiler != null) { + profiler.getQueryBreakdown((Query) internalNestedKnnVectorQuery); + } Query docAndScoreQuery = internalNestedKnnVectorQuery.knnRewrite(searcher); + if (profiler != null) { + profiler.pollLastElement(); // removes internalNested from the stack + profiler.getQueryBreakdown(docAndScoreQuery); + } Weight weight = docAndScoreQuery.createWeight(searcher, scoreMode, boost); + if (profiler != null) { + profiler.pollLastElement(); // removes docAndScoreQuery from stack + } IndexReader reader = searcher.getIndexReader(); List leafReaderContexts = reader.leaves(); List> perLeafResults; - perLeafResults = queryUtils.doSearch(searcher, leafReaderContexts, weight); + ContextualProfileBreakdown profile = null; + if (profiler != null) { + profile = (ContextualProfileBreakdown) profiler.getProfileBreakdown(this); + perLeafResults = queryUtils.doSearch(searcher, leafReaderContexts, weight, profile); + } else { + perLeafResults = queryUtils.doSearch(searcher, leafReaderContexts, weight); + } TopDocs[] topDocs = retrieveAll(searcher, leafReaderContexts, perLeafResults); int sum = 0; for (TopDocs topDoc : topDocs) { @@ -71,11 +92,24 @@ private TopDocs[] retrieveAll( // Construct query List> nestedQueryTasks = new ArrayList<>(leafReaderContexts.size()); Weight filterWeight = getFilterWeight(indexSearcher); + QueryProfiler profiler = KNNProfileUtil.getProfiler(indexSearcher); for (int i = 0; i < perLeafResults.size(); i++) { LeafReaderContext leafReaderContext = leafReaderContexts.get(i); int finalI = i; nestedQueryTasks.add(() -> { - Bits queryFilter = queryUtils.createBits(leafReaderContext, filterWeight); + Bits queryFilter = (Bits) KNNProfileUtil.profile( + profiler, + this, + leafReaderContext, + KNNQueryTimingType.BITSET_CREATION, + () -> { + try { + return queryUtils.createBits(leafReaderContext, filterWeight); + } catch (IOException e) { + throw new RuntimeException(e); + } + } + ); DocIdSetIterator allSiblings = queryUtils.getAllSiblings( leafReaderContext, perLeafResults.get(finalI).keySet(), diff --git a/src/main/java/org/opensearch/knn/index/query/lucenelib/InternalNestedKnnByteVectoryQuery.java b/src/main/java/org/opensearch/knn/index/query/lucenelib/InternalNestedKnnByteVectoryQuery.java index e9d0222323..f8d080c8da 100644 --- a/src/main/java/org/opensearch/knn/index/query/lucenelib/InternalNestedKnnByteVectoryQuery.java +++ b/src/main/java/org/opensearch/knn/index/query/lucenelib/InternalNestedKnnByteVectoryQuery.java @@ -14,6 +14,9 @@ import org.apache.lucene.search.TopDocs; import org.apache.lucene.search.join.BitSetProducer; import org.apache.lucene.search.join.DiversifyingChildrenByteKnnVectorQuery; +import org.opensearch.knn.profile.KNNProfileUtil; +import org.opensearch.knn.profile.query.KNNQueryTimingType; +import org.opensearch.search.profile.query.QueryProfiler; import java.io.IOException; @@ -29,6 +32,8 @@ public class InternalNestedKnnByteVectoryQuery extends KnnByteVectorQuery implem private final BitSetProducer parentFilter; private final DiversifyingChildrenByteKnnVectorQuery diversifyingChildrenByteKnnVectorQuery; + private QueryProfiler profiler; + public InternalNestedKnnByteVectoryQuery( final String field, final byte[] target, @@ -47,11 +52,18 @@ public InternalNestedKnnByteVectoryQuery( @Override public Query knnRewrite(final IndexSearcher searcher) throws IOException { + profiler = KNNProfileUtil.getProfiler(searcher); return diversifyingChildrenByteKnnVectorQuery.rewrite(searcher); } @Override public TopDocs knnExactSearch(LeafReaderContext context, DocIdSetIterator acceptIterator) throws IOException { - return super.exactSearch(context, acceptIterator, null); + return (TopDocs) KNNProfileUtil.profile(profiler, this, context, KNNQueryTimingType.EXACT_SEARCH, () -> { + try { + return super.exactSearch(context, acceptIterator, null); + } catch (IOException e) { + throw new RuntimeException(e); + } + }); } } diff --git a/src/main/java/org/opensearch/knn/index/query/lucenelib/InternalNestedKnnFloatVectoryQuery.java b/src/main/java/org/opensearch/knn/index/query/lucenelib/InternalNestedKnnFloatVectoryQuery.java index 6e5408bb54..636a234a3b 100644 --- a/src/main/java/org/opensearch/knn/index/query/lucenelib/InternalNestedKnnFloatVectoryQuery.java +++ b/src/main/java/org/opensearch/knn/index/query/lucenelib/InternalNestedKnnFloatVectoryQuery.java @@ -14,6 +14,9 @@ import org.apache.lucene.search.TopDocs; import org.apache.lucene.search.join.BitSetProducer; import org.apache.lucene.search.join.DiversifyingChildrenFloatKnnVectorQuery; +import org.opensearch.knn.profile.KNNProfileUtil; +import org.opensearch.knn.profile.query.KNNQueryTimingType; +import org.opensearch.search.profile.query.QueryProfiler; import java.io.IOException; @@ -29,6 +32,8 @@ public class InternalNestedKnnFloatVectoryQuery extends KnnFloatVectorQuery impl private final BitSetProducer parentFilter; private final DiversifyingChildrenFloatKnnVectorQuery diversifyingChildrenFloatKnnVectorQuery; + private QueryProfiler profiler; + public InternalNestedKnnFloatVectoryQuery( final String field, final float[] target, @@ -47,11 +52,18 @@ public InternalNestedKnnFloatVectoryQuery( @Override public Query knnRewrite(final IndexSearcher searcher) throws IOException { + profiler = KNNProfileUtil.getProfiler(searcher); return diversifyingChildrenFloatKnnVectorQuery.rewrite(searcher); } @Override public TopDocs knnExactSearch(LeafReaderContext context, DocIdSetIterator acceptIterator) throws IOException { - return super.exactSearch(context, acceptIterator, null); + return (TopDocs) KNNProfileUtil.profile(profiler, this, context, KNNQueryTimingType.EXACT_SEARCH, () -> { + try { + return super.exactSearch(context, acceptIterator, null); + } catch (IOException e) { + throw new RuntimeException(e); + } + }); } } diff --git a/src/main/java/org/opensearch/knn/index/query/lucenelib/NestedKnnVectorQueryFactory.java b/src/main/java/org/opensearch/knn/index/query/lucenelib/NestedKnnVectorQueryFactory.java index e73f689b3d..892c495751 100644 --- a/src/main/java/org/opensearch/knn/index/query/lucenelib/NestedKnnVectorQueryFactory.java +++ b/src/main/java/org/opensearch/knn/index/query/lucenelib/NestedKnnVectorQueryFactory.java @@ -7,9 +7,9 @@ import org.apache.lucene.search.Query; import org.apache.lucene.search.join.BitSetProducer; -import org.apache.lucene.search.join.DiversifyingChildrenByteKnnVectorQuery; -import org.apache.lucene.search.join.DiversifyingChildrenFloatKnnVectorQuery; import org.opensearch.knn.index.query.common.QueryUtils; +import org.opensearch.knn.index.query.lucene.ProfileDiversifyingChildrenByteKnnVectorQuery; +import org.opensearch.knn.index.query.lucene.ProfileDiversifyingChildrenFloatKnnVectorQuery; /** * A class to create a nested knn vector query for lucene @@ -42,7 +42,7 @@ public static Query createNestedKnnVectorQuery( new InternalNestedKnnByteVectoryQuery(fieldName, vector, filterQuery, k, parentFilter) ).queryUtils(QueryUtils.getInstance()).build(); } - return new DiversifyingChildrenByteKnnVectorQuery(fieldName, vector, filterQuery, k, parentFilter); + return new ProfileDiversifyingChildrenByteKnnVectorQuery(fieldName, vector, filterQuery, k, parentFilter); } /** @@ -72,6 +72,6 @@ public static Query createNestedKnnVectorQuery( new InternalNestedKnnFloatVectoryQuery(fieldName, vector, filterQuery, k, parentFilter) ).queryUtils(QueryUtils.getInstance()).build(); } - return new DiversifyingChildrenFloatKnnVectorQuery(fieldName, vector, filterQuery, k, parentFilter); + return new ProfileDiversifyingChildrenFloatKnnVectorQuery(fieldName, vector, filterQuery, k, parentFilter); } } diff --git a/src/main/java/org/opensearch/knn/plugin/KNNPlugin.java b/src/main/java/org/opensearch/knn/plugin/KNNPlugin.java index 0e05860fb4..96754d184e 100644 --- a/src/main/java/org/opensearch/knn/plugin/KNNPlugin.java +++ b/src/main/java/org/opensearch/knn/plugin/KNNPlugin.java @@ -6,6 +6,8 @@ package org.opensearch.knn.plugin; import com.google.common.collect.ImmutableList; +import org.apache.lucene.search.KnnByteVectorQuery; +import org.apache.lucene.search.KnnFloatVectorQuery; import org.opensearch.action.ActionRequest; import org.opensearch.cluster.NamedDiff; import org.opensearch.cluster.metadata.IndexNameExpressionResolver; @@ -44,6 +46,7 @@ import org.opensearch.knn.index.query.KNNQueryBuilder; import org.opensearch.knn.index.query.KNNWeight; import org.opensearch.knn.index.query.RescoreKNNVectorQuery; +import org.opensearch.knn.index.query.lucenelib.ExpandNestedDocsQuery; import org.opensearch.knn.index.query.nativelib.NativeEngineKnnVectorQuery; import org.opensearch.knn.index.query.parser.KNNQueryBuilderParser; import org.opensearch.knn.index.util.KNNClusterUtil; @@ -192,7 +195,11 @@ public class KNNPlugin extends Plugin @Override public Optional getQueryProfileMetricsProvider() { return Optional.of((searchContext, query) -> { - if (query instanceof KNNQuery || query instanceof RescoreKNNVectorQuery) { + if (query instanceof KnnByteVectorQuery + || query instanceof KnnFloatVectorQuery + || query instanceof ExpandNestedDocsQuery + || query instanceof RescoreKNNVectorQuery + || query instanceof KNNQuery) { return KNNMetrics.getKNNQueryMetrics(); } else if (query instanceof NativeEngineKnnVectorQuery) { return KNNMetrics.getNativeMetrics(); diff --git a/src/test/java/org/opensearch/knn/index/OpenSearchIT.java b/src/test/java/org/opensearch/knn/index/OpenSearchIT.java index 15bb35e94b..b24c699972 100644 --- a/src/test/java/org/opensearch/knn/index/OpenSearchIT.java +++ b/src/test/java/org/opensearch/knn/index/OpenSearchIT.java @@ -1310,6 +1310,71 @@ public void testKNNSearchWithProfilerEnabled() throws Exception { deleteKNNIndex(INDEX_NAME); } + public void testKNNSearchWithProfilerEnabled_LuceneNested() throws Exception { + int dimension = 3; + String nestedFieldPath = "nested_field.my_vector"; + String mapping = createKnnIndexNestedMapping(dimension, nestedFieldPath, "lucene"); + createKnnIndex(INDEX_NAME, mapping); + + for (int i = 1; i <= 20; ++i) { + Float[] vector = { (float) i, (float) (i + 1), (float) (i + 2) }; + addKnnDocWithNestedField(INDEX_NAME, Integer.toString(i), nestedFieldPath, vector); + } + + int k = 10; // nearest 10 neighbors + + // Create knn search body, all fields + XContentBuilder builder = XContentFactory.jsonBuilder() + .startObject() + .field("profile", true) + .startObject("query") + .startObject("nested") + .field("path", "nested_field") + .startObject("query") + .startObject("knn") + .startObject("nested_field.my_vector") + .field("vector", new float[] { 2.0f, 2.0f, 2.0f }) + .field("k", k) + .endObject() + .endObject() + .endObject() + .endObject() + .endObject() + .endObject(); + Response response = searchKNNIndex(INDEX_NAME, builder, k); + String responseBody = EntityUtils.toString(response.getEntity()); + List results = parseProfileMetric(responseBody, QueryTimingType.SCORE.toString(), true); + assertEquals(2, results.size()); + + // Create knn search body, all fields + builder = XContentFactory.jsonBuilder() + .startObject() + .field("profile", true) + .startObject("query") + .startObject("nested") + .field("path", "nested_field") + .startObject("query") + .startObject("knn") + .startObject("nested_field.my_vector") + .field("vector", new float[] { 2.0f, 2.0f, 2.0f }) + .field("k", k) + .field("expand_nested_docs", true) + .endObject() + .endObject() + .endObject() + .endObject() + .endObject() + .endObject(); + + response = searchKNNIndex(INDEX_NAME, builder, k); + responseBody = EntityUtils.toString(response.getEntity()); + results = parseProfileMetric(responseBody, KNNQueryTimingType.EXACT_SEARCH.toString(), true); + for (Long result : results) { + assertNotEquals(0L, result.longValue()); + } + deleteKNNIndex(INDEX_NAME); + } + public void testKNNSearchWithProfilerEnabled_FaissNested() throws Exception { int dimension = 3; String nestedFieldPath = "nested_field.my_vector"; @@ -1419,6 +1484,60 @@ public void testKNNSearchWithProfilerEnabled_MultipleResults() throws Exception deleteKNNIndex(INDEX_NAME); } + public void testKNNSearchWithProfilerEnabled_LuceneFilter() throws Exception { + int dim = 3; + String mapping = createKnnIndexMapping(FIELD_NAME, dim, "hnsw", "lucene", "l2", false); + createKnnIndex(INDEX_NAME, mapping); + // Add docs with knn_vector fields + for (int i = 1; i <= 20; i++) { + Float[] vector = { (float) i, (float) (i + 1), (float) (i + 2) }; + addKnnDocWithNumericField(INDEX_NAME, Integer.toString(i), FIELD_NAME, vector, "rating", i); + } + float[] query = new float[dim]; + Arrays.fill(query, 2); + + int k = 1; + // Create knn search, P <= k + XContentBuilder builder = XContentFactory.jsonBuilder() + .startObject() + .field("profile", true) + .startObject("query") + .startObject("knn") + .startObject(FIELD_NAME) + .field("vector", query) + .field("k", k) + .startObject("filter") + .startObject("bool") + .startArray("must") + .startObject() + .startObject("range") + .startObject("rating") + .field("gte", 8) + .field("lte", 14) + .endObject() + .endObject() + .endObject() + .endArray() + .endObject() + .endObject() + .endObject() + .endObject() + .endObject() + .endObject(); + + Response response = searchKNNIndex(INDEX_NAME, builder, k); + String responseBody = EntityUtils.toString(response.getEntity()); + List results = parseProfileMetric(responseBody, KNNQueryTimingType.EXACT_SEARCH.toString(), false); + for (Long result : results) { + assertNotEquals(0L, result.longValue()); + } + results = parseProfileMetric(responseBody, KNNQueryTimingType.ANN_SEARCH.toString(), false); + for (Long result : results) { + assertNotEquals(0L, result.longValue()); + } + deleteKNNIndex(INDEX_NAME); + } + public void testKNNSearchWithProfilerEnabled_FaissFilter() throws Exception { int dim = 3; String mapping = createKnnIndexMapping(FIELD_NAME, dim, "hnsw", "faiss", "l2", false); diff --git a/src/test/java/org/opensearch/knn/index/query/KNNQueryFactoryTests.java b/src/test/java/org/opensearch/knn/index/query/KNNQueryFactoryTests.java index eabdc2a822..f3185e688d 100644 --- a/src/test/java/org/opensearch/knn/index/query/KNNQueryFactoryTests.java +++ b/src/test/java/org/opensearch/knn/index/query/KNNQueryFactoryTests.java @@ -6,13 +6,9 @@ package org.opensearch.knn.index.query; import org.apache.lucene.index.Term; -import org.apache.lucene.search.KnnByteVectorQuery; -import org.apache.lucene.search.KnnFloatVectorQuery; import org.apache.lucene.search.Query; import org.apache.lucene.search.TermQuery; import org.apache.lucene.search.join.BitSetProducer; -import org.apache.lucene.search.join.DiversifyingChildrenByteKnnVectorQuery; -import org.apache.lucene.search.join.DiversifyingChildrenFloatKnnVectorQuery; import org.apache.lucene.search.join.ToChildBlockJoinQuery; import org.junit.Before; import org.mockito.Mock; @@ -29,8 +25,8 @@ import org.opensearch.knn.index.KNNSettings; import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.index.engine.KNNEngine; +import org.opensearch.knn.index.query.lucene.*; import org.opensearch.knn.index.query.lucenelib.ExpandNestedDocsQuery; -import org.opensearch.knn.index.query.lucene.LuceneEngineKnnVectorQuery; import org.opensearch.knn.index.query.nativelib.NativeEngineKnnVectorQuery; import org.opensearch.knn.index.query.rescore.RescoreContext; @@ -139,7 +135,7 @@ public void testLuceneFloatVectorQuery() { ); // efsearch > k - Query expectedQuery1 = new LuceneEngineKnnVectorQuery(new KnnFloatVectorQuery(testFieldName, testQueryVector, 100, null)); + Query expectedQuery1 = new LuceneEngineKnnVectorQuery(new ProfileKnnFloatVectorQuery(testFieldName, testQueryVector, 100, null)); assertEquals(expectedQuery1, actualQuery1); // efsearch < k @@ -154,7 +150,7 @@ public void testLuceneFloatVectorQuery() { .vectorDataType(VectorDataType.FLOAT) .build() ); - expectedQuery1 = new LuceneEngineKnnVectorQuery(new KnnFloatVectorQuery(testFieldName, testQueryVector, testK, null)); + expectedQuery1 = new LuceneEngineKnnVectorQuery(new ProfileKnnFloatVectorQuery(testFieldName, testQueryVector, testK, null)); assertEquals(expectedQuery1, actualQuery1); actualQuery1 = KNNQueryFactory.create( @@ -167,7 +163,7 @@ public void testLuceneFloatVectorQuery() { .vectorDataType(VectorDataType.FLOAT) .build() ); - expectedQuery1 = new LuceneEngineKnnVectorQuery(new KnnFloatVectorQuery(testFieldName, testQueryVector, testK, null)); + expectedQuery1 = new LuceneEngineKnnVectorQuery(new ProfileKnnFloatVectorQuery(testFieldName, testQueryVector, testK, null)); assertEquals(expectedQuery1, actualQuery1); } @@ -185,7 +181,7 @@ public void testLuceneByteVectorQuery() { ); // efsearch > k - Query expectedQuery1 = new LuceneEngineKnnVectorQuery(new KnnByteVectorQuery(testFieldName, testByteQueryVector, 100, null)); + Query expectedQuery1 = new LuceneEngineKnnVectorQuery(new ProfileKnnByteVectorQuery(testFieldName, testByteQueryVector, 100, null)); assertEquals(expectedQuery1, actualQuery1); // efsearch < k @@ -200,7 +196,7 @@ public void testLuceneByteVectorQuery() { .vectorDataType(VectorDataType.BYTE) .build() ); - expectedQuery1 = new LuceneEngineKnnVectorQuery(new KnnByteVectorQuery(testFieldName, testByteQueryVector, testK, null)); + expectedQuery1 = new LuceneEngineKnnVectorQuery(new ProfileKnnByteVectorQuery(testFieldName, testByteQueryVector, testK, null)); assertEquals(expectedQuery1, actualQuery1); actualQuery1 = KNNQueryFactory.create( @@ -213,7 +209,7 @@ public void testLuceneByteVectorQuery() { .vectorDataType(VectorDataType.BYTE) .build() ); - expectedQuery1 = new LuceneEngineKnnVectorQuery(new KnnByteVectorQuery(testFieldName, testByteQueryVector, testK, null)); + expectedQuery1 = new LuceneEngineKnnVectorQuery(new ProfileKnnByteVectorQuery(testFieldName, testByteQueryVector, testK, null)); assertEquals(expectedQuery1, actualQuery1); } @@ -507,8 +503,8 @@ public void testCreate_whenExpandNestedDocsQueryWithNmslib_thenCreateKNNQuery() public void testCreate_whenExpandNestedDocsQueryWithLucene_thenCreateExpandNestedDocsQuery() { testExpandNestedDocsQuery(KNNEngine.LUCENE, ExpandNestedDocsQuery.class, VectorDataType.BYTE, true); testExpandNestedDocsQuery(KNNEngine.LUCENE, ExpandNestedDocsQuery.class, VectorDataType.FLOAT, true); - testExpandNestedDocsQuery(KNNEngine.LUCENE, DiversifyingChildrenByteKnnVectorQuery.class, VectorDataType.BYTE, false); - testExpandNestedDocsQuery(KNNEngine.LUCENE, DiversifyingChildrenFloatKnnVectorQuery.class, VectorDataType.FLOAT, false); + testExpandNestedDocsQuery(KNNEngine.LUCENE, ProfileDiversifyingChildrenByteKnnVectorQuery.class, VectorDataType.BYTE, false); + testExpandNestedDocsQuery(KNNEngine.LUCENE, ProfileDiversifyingChildrenFloatKnnVectorQuery.class, VectorDataType.FLOAT, false); } private void testExpandNestedDocsQuery( diff --git a/src/test/java/org/opensearch/knn/index/query/common/QueryUtilsTests.java b/src/test/java/org/opensearch/knn/index/query/common/QueryUtilsTests.java index 6733924ce7..2bb7539069 100644 --- a/src/test/java/org/opensearch/knn/index/query/common/QueryUtilsTests.java +++ b/src/test/java/org/opensearch/knn/index/query/common/QueryUtilsTests.java @@ -20,6 +20,9 @@ import org.apache.lucene.util.Bits; import org.apache.lucene.util.FixedBitSet; import org.junit.Before; +import org.opensearch.knn.profile.query.KNNQueryTimingType; +import org.opensearch.search.profile.ContextualProfileBreakdown; +import org.opensearch.search.profile.Timer; import java.util.Arrays; import java.util.Collections; @@ -30,8 +33,11 @@ import java.util.concurrent.Executor; import java.util.concurrent.Executors; +import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.times; public class QueryUtilsTests extends TestCase { private Executor executor; @@ -77,6 +83,48 @@ public void testDoSearch_whenExecuted_thenSucceed() { } + @SneakyThrows + public void testDoSearchWithProfile_whenExecuted_thenSucceed() { + IndexSearcher indexSearcher = mock(IndexSearcher.class); + when(indexSearcher.getTaskExecutor()).thenReturn(taskExecutor); + + LeafReaderContext leafReaderContext1 = mock(LeafReaderContext.class); + LeafReaderContext leafReaderContext2 = mock(LeafReaderContext.class); + List leafReaderContexts = Arrays.asList(leafReaderContext1, leafReaderContext2); + + DocIdSetIterator docIdSetIterator = mock(DocIdSetIterator.class); + when(docIdSetIterator.docID()).thenReturn(0, 1, DocIdSetIterator.NO_MORE_DOCS); + Scorer scorer = mock(Scorer.class); + when(scorer.iterator()).thenReturn(docIdSetIterator); + when(scorer.docID()).thenReturn(0, 1, DocIdSetIterator.NO_MORE_DOCS); + when(scorer.score()).thenReturn(10.f, 11.f, -1f); + + Weight weight = mock(Weight.class); + when(weight.scorer(leafReaderContext1)).thenReturn(null); + when(weight.scorer(leafReaderContext2)).thenReturn(scorer); + + ContextualProfileBreakdown profile = mock(ContextualProfileBreakdown.class); + Timer timer = mock(Timer.class); + when(profile.context(any())).thenReturn(profile); + when(profile.getTimer(KNNQueryTimingType.ANN_SEARCH)).thenReturn(timer); + + // Run + List> results = queryUtils.doSearch(indexSearcher, leafReaderContexts, weight, profile); + + // Verify + verify(profile, times(2)).context(any()); + verify(profile, times(2)).getTimer(KNNQueryTimingType.ANN_SEARCH); + verify(timer, times(2)).start(); + verify(timer, times(2)).stop(); + + assertEquals(2, results.size()); + assertEquals(0, results.get(0).size()); + assertEquals(2, results.get(1).size()); + assertEquals(10.f, results.get(1).get(0)); + assertEquals(11.f, results.get(1).get(1)); + + } + @SneakyThrows public void testGetAllSiblings_whenEmptyDocIds_thenEmptyIterator() { LeafReaderContext leafReaderContext = mock(LeafReaderContext.class); diff --git a/src/test/java/org/opensearch/knn/index/query/lucenelib/NestedKnnVectorQueryFactoryTests.java b/src/test/java/org/opensearch/knn/index/query/lucenelib/NestedKnnVectorQueryFactoryTests.java index 5e6570a74c..9402f17eaa 100644 --- a/src/test/java/org/opensearch/knn/index/query/lucenelib/NestedKnnVectorQueryFactoryTests.java +++ b/src/test/java/org/opensearch/knn/index/query/lucenelib/NestedKnnVectorQueryFactoryTests.java @@ -8,8 +8,8 @@ import junit.framework.TestCase; import org.apache.lucene.search.Query; import org.apache.lucene.search.join.BitSetProducer; -import org.apache.lucene.search.join.DiversifyingChildrenByteKnnVectorQuery; -import org.apache.lucene.search.join.DiversifyingChildrenFloatKnnVectorQuery; +import org.opensearch.knn.index.query.lucene.ProfileDiversifyingChildrenByteKnnVectorQuery; +import org.opensearch.knn.index.query.lucene.ProfileDiversifyingChildrenFloatKnnVectorQuery; import static org.mockito.Mockito.mock; @@ -50,13 +50,13 @@ public void testCreate_whenNoExpandNestedDocs_thenDiversifyingQuery() { boolean expandNestedDocs = false; assertEquals( - DiversifyingChildrenByteKnnVectorQuery.class, + ProfileDiversifyingChildrenByteKnnVectorQuery.class, NestedKnnVectorQueryFactory.createNestedKnnVectorQuery(fieldName, byteVectors, k, queryFilter, parentFilter, expandNestedDocs) .getClass() ); assertEquals( - DiversifyingChildrenFloatKnnVectorQuery.class, + ProfileDiversifyingChildrenFloatKnnVectorQuery.class, NestedKnnVectorQueryFactory.createNestedKnnVectorQuery(fieldName, floatVectors, k, queryFilter, parentFilter, expandNestedDocs) .getClass() );