/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.knn.index.query.nativelib;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Objects;
import java.util.Set;
import java.util.concurrent.Callable;
import java.util.stream.Collectors;
import lombok.Generated;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.MatchNoDocsQuery;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.QueryVisitor;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.ScoreMode;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.TotalHits;
import org.apache.lucene.search.Weight;
import org.apache.lucene.util.Bits;
import org.opensearch.common.StopWatch;
import org.opensearch.knn.index.KNNSettings;
import org.opensearch.knn.index.query.ExactSearcher;
import org.opensearch.knn.index.query.KNNQuery;
import org.opensearch.knn.index.query.KNNWeight;
import org.opensearch.knn.index.query.PerLeafResult;
import org.opensearch.knn.index.query.ResultUtil;
import org.opensearch.knn.index.query.TopDocsDISI;
import org.opensearch.knn.index.query.common.QueryUtils;
import org.opensearch.knn.index.query.rescore.RescoreContext;
import org.opensearch.knn.profile.KNNProfileUtil;
import org.opensearch.knn.profile.LongMetric;
import org.opensearch.search.profile.AbstractProfileBreakdown;
import org.opensearch.search.profile.ContextualProfileBreakdown;
import org.opensearch.search.profile.query.QueryProfiler;

/**
 * Query that executes the wrapped {@link KNNQuery} eagerly while its {@link Weight} is being created.
 *
 * <p>{@link #createWeight} performs the following steps:
 * <ol>
 *   <li>Runs the native-engine search on every leaf of the searcher's reader, using the searcher's task executor.</li>
 *   <li>Optionally rescores the first-pass results with an exact search.</li>
 *   <li>Optionally expands the hits to all of their nested siblings.</li>
 *   <li>Merges the per-leaf results and returns the weight of the query built by
 *       {@link QueryUtils#createDocAndScoreQuery} over the merged documents.</li>
 * </ol>
 * When the merged result is empty, a {@link MatchNoDocsQuery} weight is returned instead.
 */
public class NativeEngineKnnVectorQuery extends Query {
    @Generated
    private static final Logger log = LogManager.getLogger(NativeEngineKnnVectorQuery.class);
    private final KNNQuery knnQuery;
    private final QueryUtils queryUtils;
    private final boolean expandNestedDocs;

    /**
     * Runs the k-NN search and returns a weight that matches exactly the resulting top documents.
     *
     * <p>When the query's {@link RescoreContext} enables rescoring, each leaf is first searched for an
     * oversampled {@code firstPassK}, and the candidates are then rescored by exact search down to {@code k}.
     * Otherwise each leaf is searched for {@code k} directly.
     *
     * @param indexSearcher searcher whose reader is searched and whose task executor runs the per-leaf work
     * @param scoreMode     score mode forwarded to the created weights
     * @param boost         boost applied to the returned weight; the inner KNN weight is always created with 1.0
     * @return a weight over the merged top documents, or a match-no-docs weight when there are none
     * @throws IOException if searching, rescoring or expanding a leaf fails
     */
    @Override
    public Weight createWeight(IndexSearcher indexSearcher, ScoreMode scoreMode, float boost) throws IOException {
        final IndexReader reader = indexSearcher.getIndexReader();
        final KNNWeight knnWeight = createKnnWeight(indexSearcher, scoreMode);
        final List<LeafReaderContext> leafReaderContexts = reader.leaves();
        final RescoreContext rescoreContext = knnQuery.getRescoreContext();
        final int finalK = knnQuery.getK();

        List<PerLeafResult> perLeafResults;
        if (rescoreContext == null || !rescoreContext.isRescoreEnabled()) {
            perLeafResults = doSearch(indexSearcher, leafReaderContexts, knnWeight, finalK);
        } else {
            final boolean isShardLevelRescoringDisabled = KNNSettings.isShardLevelRescoringDisabledForDiskBasedVector(knnQuery.getIndexName());
            final int dimension = knnQuery.getQueryVector().length;
            final int firstPassK = rescoreContext.getFirstPassK(finalK, isShardLevelRescoringDisabled, dimension);
            perLeafResults = doSearch(indexSearcher, leafReaderContexts, knnWeight, firstPassK);
            if (!isShardLevelRescoringDisabled) {
                // Keep only the best firstPassK candidates across all leaves before rescoring them.
                ResultUtil.reduceToTopK(perLeafResults, firstPassK);
            }
            final StopWatch stopWatch = new StopWatch().start();
            perLeafResults = doRescore(indexSearcher, leafReaderContexts, knnWeight, perLeafResults, finalK);
            final long rescoreTime = stopWatch.stop().totalTime().millis();
            log.debug("Rescoring results took {} ms. oversampled k:{}, segments:{}", rescoreTime, firstPassK, leafReaderContexts.size());
        }
        ResultUtil.reduceToTopK(perLeafResults, finalK);

        if (expandNestedDocs) {
            final StopWatch stopWatch = new StopWatch().start();
            // NOTE(review): quantized vectors are requested only when there is no RescoreContext at all.
            // A RescoreContext that is present but disabled expands with useQuantizedVectorsForSearch=false.
            // Behavior kept as-is; confirm this is intended.
            perLeafResults = retrieveAll(indexSearcher, leafReaderContexts, knnWeight, perLeafResults, rescoreContext == null);
            final long timeInMillis = stopWatch.stop().totalTime().millis();
            if (log.isDebugEnabled()) {
                final long totalNestedDocs = perLeafResults.stream().mapToLong(perLeafResult -> perLeafResult.getResult().scoreDocs.length).sum();
                log.debug("Expanding of nested docs took {} ms. totalNestedDocs:{} ", timeInMillis, totalNestedDocs);
            }
        }

        final TopDocs topK = mergeLeafResults(leafReaderContexts, perLeafResults);
        if (topK.scoreDocs.length == 0) {
            return new MatchNoDocsQuery().createWeight(indexSearcher, scoreMode, boost);
        }
        return queryUtils.createDocAndScoreQuery(reader, topK, knnWeight).createWeight(indexSearcher, scoreMode, boost);
    }

    /**
     * Creates the weight of the wrapped {@link KNNQuery} with a neutral boost of 1.0.
     *
     * <p>When a profiler is active, {@code getQueryBreakdown(knnQuery)} is called before the weight is created
     * and {@code pollLastElement()} afterwards. The poll sits in {@code finally} so the two calls stay paired
     * even if weight creation throws.
     */
    private KNNWeight createKnnWeight(IndexSearcher indexSearcher, ScoreMode scoreMode) throws IOException {
        final QueryProfiler profiler = KNNProfileUtil.getProfiler(indexSearcher);
        if (profiler == null) {
            return (KNNWeight) knnQuery.createWeight(indexSearcher, scoreMode, 1.0f);
        }
        profiler.getQueryBreakdown(knnQuery);
        try {
            return (KNNWeight) knnQuery.createWeight(indexSearcher, scoreMode, 1.0f);
        } finally {
            profiler.pollLastElement();
        }
    }

    /**
     * Shifts each leaf's doc ids by that leaf's {@code docBase} and merges all leaves into one {@link TopDocs}.
     *
     * <p>The shift is applied in place to the {@link ScoreDoc}s held by {@code perLeafResults}. Element
     * {@code i} of {@code perLeafResults} is assumed to belong to element {@code i} of {@code leafReaderContexts}.
     */
    private TopDocs mergeLeafResults(List<LeafReaderContext> leafReaderContexts, List<PerLeafResult> perLeafResults) {
        final TopDocs[] topDocs = new TopDocs[perLeafResults.size()];
        for (int leaf = 0; leaf < perLeafResults.size(); leaf++) {
            final TopDocs leafTopDocs = perLeafResults.get(leaf).getResult();
            final int docBase = leafReaderContexts.get(leaf).docBase;
            for (ScoreDoc scoreDoc : leafTopDocs.scoreDocs) {
                scoreDoc.doc += docBase;
            }
            topDocs[leaf] = leafTopDocs;
        }
        return TopDocs.merge(getTotalTopDoc(topDocs), topDocs);
    }

    /**
     * Returns how many documents the merge should keep.
     *
     * <p>Without nested expansion this is the query's {@code k}. With expansion, every expanded doc is kept,
     * so it is the total number of docs across all leaves.
     */
    private int getTotalTopDoc(TopDocs[] topDocs) {
        if (!expandNestedDocs) {
            return knnQuery.getK();
        }
        int sum = 0;
        for (TopDocs topDoc : topDocs) {
            sum += topDoc.scoreDocs.length;
        }
        return sum;
    }

    /**
     * Replaces each leaf's hits with an exact search over all of their nested siblings.
     * Each leaf is processed as one task on the searcher's task executor.
     *
     * <p>When a profiler is active, the size of each leaf's expanded result is written to the
     * {@code "num_nested_docs"} metric of this query's per-leaf breakdown.
     */
    private List<PerLeafResult> retrieveAll(IndexSearcher indexSearcher, List<LeafReaderContext> leafReaderContexts, KNNWeight knnWeight, List<PerLeafResult> perLeafResults, boolean useQuantizedVectors) throws IOException {
        // Same searcher for every leaf, so look the profiler up once.
        final QueryProfiler profiler = KNNProfileUtil.getProfiler(indexSearcher);
        final List<Callable<PerLeafResult>> nestedQueryTasks = new ArrayList<>(leafReaderContexts.size());
        for (int i = 0; i < perLeafResults.size(); i++) {
            final LeafReaderContext leafReaderContext = leafReaderContexts.get(i);
            final int leafIndex = i;
            nestedQueryTasks.add(() -> {
                final PerLeafResult result = retrieveLeafResult(leafReaderContext, knnWeight, perLeafResults, useQuantizedVectors, leafIndex);
                if (profiler != null) {
                    final AbstractProfileBreakdown profile = ((ContextualProfileBreakdown) profiler.getProfileBreakdown(this)).context(leafReaderContext);
                    final LongMetric metric = (LongMetric) profile.getMetric("num_nested_docs");
                    metric.setValue(Long.valueOf(result.getResult().scoreDocs.length));
                }
                return result;
            });
        }
        return indexSearcher.getTaskExecutor().invokeAll(nestedQueryTasks);
    }

    /**
     * Expands the hits of one leaf to all of their siblings under the query's parents filter, then scores
     * them with an exact search. An empty leaf result is returned unchanged.
     *
     * @param finalI index of the leaf's entry in {@code perLeafResults}
     */
    private PerLeafResult retrieveLeafResult(LeafReaderContext leafReaderContext, KNNWeight knnWeight, List<PerLeafResult> perLeafResults, boolean useQuantizedVectors, int finalI) throws IOException {
        final PerLeafResult perLeafResult = perLeafResults.get(finalI);
        if (perLeafResult.getResult().scoreDocs.length == 0) {
            return perLeafResult;
        }
        final Set<Integer> docIds = collectDocIds(perLeafResult.getResult());
        final DocIdSetIterator allSiblings = queryUtils.getAllSiblings(leafReaderContext, docIds, knnQuery.getParentsFilter(), perLeafResult.getFilterBits());
        final ExactSearcher.ExactSearcherContext exactSearcherContext = ExactSearcher.ExactSearcherContext.builder()
            .matchedDocsIterator(allSiblings)
            .numberOfMatchedDocs(allSiblings.cost())
            .useQuantizedVectorsForSearch(useQuantizedVectors)
            // k is set to the iterator's cost(), presumably so every matched sibling is returned.
            .k((int) allSiblings.cost())
            .field(knnQuery.getField())
            .radius(knnQuery.getRadius())
            .floatQueryVector(knnQuery.getQueryVector())
            .byteQueryVector(knnQuery.getByteQueryVector())
            .isMemoryOptimizedSearchEnabled(knnQuery.isMemoryOptimizedSearch())
            .build();
        final TopDocs rescoreResult = knnWeight.exactSearch(leafReaderContext, exactSearcherContext);
        return new PerLeafResult(perLeafResult.getFilterBits(), rescoreResult);
    }

    /**
     * Runs {@link #searchLeaf} for every leaf on the searcher's task executor, asking each leaf for {@code k} results.
     */
    private List<PerLeafResult> doSearch(IndexSearcher indexSearcher, List<LeafReaderContext> leafReaderContexts, KNNWeight knnWeight, int k) throws IOException {
        final List<Callable<PerLeafResult>> tasks = new ArrayList<>(leafReaderContexts.size());
        for (LeafReaderContext leafReaderContext : leafReaderContexts) {
            tasks.add(() -> searchLeaf(leafReaderContext, knnWeight, k));
        }
        return indexSearcher.getTaskExecutor().invokeAll(tasks);
    }

    /**
     * Rescores every leaf's first-pass candidates by exact search, without quantized vectors, keeping up to
     * {@code k} per leaf. Each leaf is processed as one task on the searcher's task executor.
     */
    private List<PerLeafResult> doRescore(IndexSearcher indexSearcher, List<LeafReaderContext> leafReaderContexts, KNNWeight knnWeight, List<PerLeafResult> perLeafResults, int k) throws IOException {
        final List<Callable<PerLeafResult>> rescoreTasks = new ArrayList<>(leafReaderContexts.size());
        for (int i = 0; i < perLeafResults.size(); i++) {
            final LeafReaderContext leafReaderContext = leafReaderContexts.get(i);
            final int leafIndex = i;
            rescoreTasks.add(() -> rescoreLeaf(leafReaderContext, knnWeight, perLeafResults.get(leafIndex), k));
        }
        return indexSearcher.getTaskExecutor().invokeAll(rescoreTasks);
    }

    /**
     * Rescores one leaf's candidates with an exact search.
     *
     * <p>With a parents filter, all siblings of the candidates are rescored. Otherwise only the candidates
     * themselves are rescored. An empty leaf result is returned unchanged.
     */
    private PerLeafResult rescoreLeaf(LeafReaderContext leafReaderContext, KNNWeight knnWeight, PerLeafResult perLeafResult, int k) throws IOException {
        if (perLeafResult.getResult().scoreDocs.length == 0) {
            return perLeafResult;
        }
        final DocIdSetIterator matchedDocs;
        if (knnQuery.getParentsFilter() != null) {
            final Set<Integer> docIds = collectDocIds(perLeafResult.getResult());
            matchedDocs = queryUtils.getAllSiblings(leafReaderContext, docIds, knnQuery.getParentsFilter(), perLeafResult.getFilterBits());
        } else {
            matchedDocs = new TopDocsDISI(perLeafResult.getResult());
        }
        final ExactSearcher.ExactSearcherContext exactSearcherContext = ExactSearcher.ExactSearcherContext.builder()
            .matchedDocsIterator(matchedDocs)
            .numberOfMatchedDocs(matchedDocs.cost())
            .useQuantizedVectorsForSearch(false)
            .k(k)
            .radius(knnQuery.getRadius())
            .field(knnQuery.getField())
            .floatQueryVector(knnQuery.getQueryVector())
            .byteQueryVector(knnQuery.getByteQueryVector())
            .isMemoryOptimizedSearchEnabled(knnQuery.isMemoryOptimizedSearch())
            .parentsFilter(knnQuery.getParentsFilter())
            .build();
        final TopDocs rescoreResult = knnWeight.exactSearch(leafReaderContext, exactSearcherContext);
        return new PerLeafResult(perLeafResult.getFilterBits(), rescoreResult);
    }

    /** Returns the set of leaf-local doc ids contained in {@code topDocs}. */
    private static Set<Integer> collectDocIds(TopDocs topDocs) {
        return Arrays.stream(topDocs.scoreDocs).map(scoreDoc -> scoreDoc.doc).collect(Collectors.toSet());
    }

    /**
     * Searches one leaf with the KNN weight, then drops any hit whose doc is not live in the leaf.
     * When the leaf has no live-docs bitset, the weight's result is returned as-is.
     */
    private PerLeafResult searchLeaf(LeafReaderContext ctx, KNNWeight queryWeight, int k) throws IOException {
        final PerLeafResult perLeafResult = queryWeight.searchLeaf(ctx, k);
        final Bits liveDocs = ctx.reader().getLiveDocs();
        if (liveDocs == null) {
            return perLeafResult;
        }
        final ScoreDoc[] filteredScoreDoc = Arrays.stream(perLeafResult.getResult().scoreDocs)
            .filter(scoreDoc -> liveDocs.get(scoreDoc.doc))
            .toArray(ScoreDoc[]::new);
        final TotalHits totalHits = new TotalHits((long) filteredScoreDoc.length, TotalHits.Relation.EQUAL_TO);
        return new PerLeafResult(perLeafResult.getFilterBits(), new TopDocs(totalHits, filteredScoreDoc));
    }

    @Override
    public String toString(String field) {
        return getClass().getSimpleName() + "[" + field + "]..." + KNNQuery.class.getSimpleName() + "[" + knnQuery.toString() + "]";
    }

    @Override
    public void visit(QueryVisitor visitor) {
        visitor.visitLeaf(this);
    }

    /**
     * Two instances are equal only when they wrap the very same {@link KNNQuery} instance (reference equality).
     * This is consistent with {@link #hashCode()}: identical references yield identical hash codes.
     */
    @Override
    public boolean equals(Object obj) {
        if (!sameClassAs(obj)) {
            return false;
        }
        return knnQuery == ((NativeEngineKnnVectorQuery) obj).knnQuery;
    }

    @Override
    public int hashCode() {
        return Objects.hash(classHash(), knnQuery.hashCode());
    }

    @Generated
    public KNNQuery getKnnQuery() {
        return this.knnQuery;
    }

    @Generated
    public QueryUtils getQueryUtils() {
        return this.queryUtils;
    }

    @Generated
    public boolean isExpandNestedDocs() {
        return this.expandNestedDocs;
    }

    @Generated
    public NativeEngineKnnVectorQuery(KNNQuery knnQuery, QueryUtils queryUtils, boolean expandNestedDocs) {
        this.knnQuery = knnQuery;
        this.queryUtils = queryUtils;
        this.expandNestedDocs = expandNestedDocs;
    }
}

