/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.search.startree;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.function.BiConsumer;
import java.util.function.Consumer;
import org.apache.lucene.codecs.DocValuesProducer;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.SegmentReader;
import org.apache.lucene.util.FixedBitSet;
import org.opensearch.common.lucene.Lucene;
import org.opensearch.index.codec.composite.CompositeIndexFieldInfo;
import org.opensearch.index.codec.composite.CompositeIndexReader;
import org.opensearch.index.compositeindex.datacube.Dimension;
import org.opensearch.index.compositeindex.datacube.MetricStat;
import org.opensearch.index.compositeindex.datacube.startree.index.StarTreeValues;
import org.opensearch.index.compositeindex.datacube.startree.utils.StarTreeUtils;
import org.opensearch.index.compositeindex.datacube.startree.utils.iterator.SortedNumericStarTreeValuesIterator;
import org.opensearch.index.query.QueryShardContext;
import org.opensearch.search.aggregations.Aggregator;
import org.opensearch.search.aggregations.StarTreeBucketCollector;
import org.opensearch.search.aggregations.StarTreePreComputeCollector;
import org.opensearch.search.aggregations.support.ValuesSource;
import org.opensearch.search.internal.SearchContext;
import org.opensearch.search.startree.StarTreeQueryContext;
import org.opensearch.search.startree.StarTreeTraversalUtil;
import org.opensearch.search.startree.filter.DimensionFilter;
import org.opensearch.search.startree.filter.StarTreeFilter;

/**
 * Static helpers for evaluating aggregations against a star-tree composite index.
 */
public class StarTreeQueryHelper {

    /**
     * Returns whether the request in {@code context} is eligible for star-tree evaluation.
     * The request must carry aggregations, the mapping must contain a composite index, and no
     * post filter may be present.
     */
    public static boolean isStarTreeSupported(SearchContext context) {
        return context.aggregations() != null
            && context.mapperService().isCompositeIndexPresent()
            && context.parsedPostFilter() == null;
    }

    /**
     * Returns the star-tree field recorded in the shard's {@link StarTreeQueryContext}.
     *
     * @return the star-tree field info, or {@code null} when no star-tree query context is attached
     */
    public static CompositeIndexFieldInfo getSupportedStarTree(QueryShardContext context) {
        StarTreeQueryContext starTreeQueryContext = context.getStarTreeQueryContext();
        return starTreeQueryContext == null ? null : starTreeQueryContext.getStarTree();
    }

    /**
     * Reads the star-tree values for {@code starTree} from the segment behind {@code context}.
     *
     * @return the star-tree values, or {@code null} when the segment's doc-values producer is not a
     *         {@link CompositeIndexReader}
     * @throws IOException if reading the composite index values fails
     */
    public static StarTreeValues getStarTreeValues(LeafReaderContext context, CompositeIndexFieldInfo starTree) throws IOException {
        SegmentReader reader = Lucene.segmentReader(context.reader());
        if (reader.getDocValuesReader() instanceof CompositeIndexReader compositeIndexReader) {
            return (StarTreeValues) compositeIndexReader.getCompositeIndexValues(starTree);
        }
        return null;
    }

    /**
     * Feeds every stored value of {@code metric} for {@code valuesSource} into {@code valueConsumer}.
     * Only star-tree entries that match the shard's base star-tree filter are visited.
     * {@code finalConsumer} is invoked exactly once afterwards.
     *
     * @param valuesSource must be a {@link ValuesSource.Numeric.FieldData}; it is cast unconditionally
     * @param metric       passed as the metric component of the fully-qualified star-tree metric field
     *                     name (presumably a {@link MetricStat} type name — confirm with callers)
     * @throws IOException if reading star-tree values fails
     */
    public static void precomputeLeafUsingStarTree(
        SearchContext context,
        ValuesSource.Numeric valuesSource,
        LeafReaderContext ctx,
        CompositeIndexFieldInfo starTree,
        String metric,
        Consumer<Long> valueConsumer,
        Runnable finalConsumer
    ) throws IOException {
        StarTreeValues starTreeValues = getStarTreeValues(ctx, starTree);
        assert starTreeValues != null;
        String fieldName = ((ValuesSource.Numeric.FieldData) valuesSource).getIndexFieldName();
        String metricName = StarTreeUtils.fullyQualifiedFieldNameForStarTreeMetricsDocValues(starTree.getField(), fieldName, metric);
        SortedNumericStarTreeValuesIterator valuesIterator = (SortedNumericStarTreeValuesIterator) starTreeValues
            .getMetricValuesIterator(metricName);
        FixedBitSet filteredValues = getStarTreeFilteredValues(context, ctx, starTreeValues);

        forEachSetBit(filteredValues, entry -> {
            // Entries that have no value for this metric are skipped.
            if (valuesIterator.advanceExact(entry)) {
                int count = valuesIterator.entryValueCount();
                for (int i = 0; i < count; i++) {
                    valueConsumer.accept(valuesIterator.nextValue());
                }
            }
        });
        finalConsumer.run();
    }

    /**
     * Returns the star-tree entries of segment {@code ctx} that match the base query star-tree filter.
     * A per-segment result cached on the star-tree query context is reused when available.
     *
     * @throws IOException if traversing the star tree fails
     */
    public static FixedBitSet getStarTreeFilteredValues(SearchContext context, LeafReaderContext ctx, StarTreeValues starTreeValues)
        throws IOException {
        StarTreeQueryContext starTreeQueryContext = context.getQueryShardContext().getStarTreeQueryContext();
        FixedBitSet result = starTreeQueryContext.maybeGetCachedNodeIdsForSegment(ctx.ord);
        if (result == null) {
            result = StarTreeTraversalUtil.getStarTreeResult(starTreeValues, starTreeQueryContext.getBaseQueryStarTreeFilter(), context);
        }
        // Called unconditionally, as before; whether it stores anything is decided by the "maybe" setter.
        starTreeQueryContext.maybeSetCachedNodeIdsForSegment(ctx.ord, result);
        return result;
    }

    /**
     * Returns the single dimension whose field equals {@code dimensionName}.
     *
     * @throws IllegalStateException if no dimension, or more than one dimension, has that field
     */
    public static Dimension getMatchingDimensionOrThrow(String dimensionName, List<Dimension> orderedDimensions) {
        Dimension matchingDimension = getMatchingDimensionOrNull(dimensionName, orderedDimensions);
        if (matchingDimension == null) {
            throw new IllegalStateException("No matching dimension found for [" + dimensionName + "]");
        }
        return matchingDimension;
    }

    /**
     * Returns the single dimension whose field equals {@code dimensionName}.
     *
     * @return the match, or {@code null} when there is no match or the name is ambiguous
     *         (more than one dimension has that field)
     */
    public static Dimension getMatchingDimensionOrNull(String dimensionName, List<Dimension> orderedDimensions) {
        Dimension match = null;
        for (Dimension dimension : orderedDimensions) {
            if (dimension.getField().equals(dimensionName)) {
                if (match != null) {
                    return null; // a second match makes the lookup ambiguous
                }
                match = dimension;
            }
        }
        return match;
    }

    /**
     * Creates a child bucket collector that aggregates a star-tree metric into buckets.
     * For each collected entry, {@code growArrays} is called with the bucket ordinal. Then, if the
     * entry has a value for the metric, {@code updateBucket} is called with (bucket, value).
     * Only the first value of each entry is read.
     *
     * @param valuesSource must be a {@link ValuesSource.Numeric.FieldData}; it is cast unconditionally
     * @param parentCollector the collector whose star-tree state the new collector is built from;
     *                        must not be {@code null}
     * @throws IOException declared for compatibility with collector construction
     */
    public static StarTreeBucketCollector getStarTreeBucketMetricCollector(
        final CompositeIndexFieldInfo starTree,
        final String metric,
        final ValuesSource.Numeric valuesSource,
        StarTreeBucketCollector parentCollector,
        final Consumer<Long> growArrays,
        final BiConsumer<Long, Long> updateBucket
    ) throws IOException {
        assert parentCollector != null;
        // The superclass constructor StarTreeBucketCollector(parentCollector) runs implicitly before
        // the field initializers below, so the inherited starTreeValues field is already assigned.
        return new StarTreeBucketCollector(parentCollector) {
            private final String metricName = StarTreeUtils.fullyQualifiedFieldNameForStarTreeMetricsDocValues(
                starTree.getField(),
                ((ValuesSource.Numeric.FieldData) valuesSource).getIndexFieldName(),
                metric
            );
            private final SortedNumericStarTreeValuesIterator metricValuesIterator = (SortedNumericStarTreeValuesIterator) this.starTreeValues
                .getMetricValuesIterator(this.metricName);

            @Override
            public void collectStarTreeEntry(int starTreeEntryBit, long bucket) throws IOException {
                growArrays.accept(bucket);
                if (!this.metricValuesIterator.advanceExact(starTreeEntryBit)) {
                    return; // no value for this entry
                }
                long metricValue = this.metricValuesIterator.nextValue();
                updateBucket.accept(bucket, metricValue);
            }
        };
    }

    /**
     * Returns the iterator over the star tree's {@code _doc_count} metric (DOC_COUNT stat).
     */
    public static SortedNumericStarTreeValuesIterator getDocCountsIterator(StarTreeValues starTreeValues, CompositeIndexFieldInfo starTree) {
        String metricName = StarTreeUtils.fullyQualifiedFieldNameForStarTreeMetricsDocValues(
            starTree.getField(),
            "_doc_count",
            MetricStat.DOC_COUNT.getTypeName()
        );
        return (SortedNumericStarTreeValuesIterator) starTreeValues.getMetricValuesIterator(metricName);
    }

    /**
     * Passes every set bit of the collector's matching-docs bit set to
     * {@link StarTreeBucketCollector#collectStarTreeEntry(int, long)} with bucket ordinal {@code 0}.
     *
     * @throws IOException if collecting an entry fails
     */
    public static void preComputeBucketsWithStarTree(StarTreeBucketCollector starTreeBucketCollector) throws IOException {
        FixedBitSet matchingDocsBitSet = starTreeBucketCollector.getMatchingDocsBitSet();
        forEachSetBit(matchingDocsBitSet, bit -> starTreeBucketCollector.collectStarTreeEntry(bit, 0L));
    }

    /**
     * Returns a new {@link StarTreeFilter} that contains every dimension of {@code baseStarTreeFilter}.
     * It also contains {@code dimensionToMerge} with {@code dimensionFiltersToMerge}, but only if the
     * base filter does not already constrain that dimension. When it does, the base filters win.
     * The base filter object is not modified here.
     */
    public static StarTreeFilter mergeDimensionFilterIfNotExists(
        StarTreeFilter baseStarTreeFilter,
        String dimensionToMerge,
        List<DimensionFilter> dimensionFiltersToMerge
    ) {
        HashMap<String, List<DimensionFilter>> dimensionFilterMap = new HashMap<>(baseStarTreeFilter.getDimensions().size());
        for (String baseDimension : baseStarTreeFilter.getDimensions()) {
            dimensionFilterMap.put(baseDimension, baseStarTreeFilter.getFiltersForDimension(baseDimension));
        }
        if (!dimensionFilterMap.containsKey(dimensionToMerge)) {
            dimensionFilterMap.put(dimensionToMerge, dimensionFiltersToMerge);
        }
        return new StarTreeFilter(dimensionFilterMap);
    }

    /**
     * Traverses the star tree with the base query filter extended by {@code dimensionFiltersToMerge}.
     * Each filter is merged only if its dimension is not yet constrained. So for a dimension already in
     * the base filter, or repeated in the list, only the first filter for that dimension takes effect.
     *
     * @throws IOException if traversing the star tree fails
     */
    public static FixedBitSet getStarTreeResult(StarTreeValues starTreeValues, SearchContext context, List<DimensionFilter> dimensionFiltersToMerge)
        throws IOException {
        StarTreeFilter starTreeFilter = context.getQueryShardContext().getStarTreeQueryContext().getBaseQueryStarTreeFilter();
        for (DimensionFilter dimensionFilter : dimensionFiltersToMerge) {
            starTreeFilter = mergeDimensionFilterIfNotExists(starTreeFilter, dimensionFilter.getMatchingDimension(), List.of(dimensionFilter));
        }
        return StarTreeTraversalUtil.getStarTreeResult(starTreeValues, starTreeFilter, context);
    }

    /**
     * Convenience overload of {@link #collectDimensionFilters(List, Aggregator[])} for a single initial
     * filter. {@code initialDimensionFilter} must not be {@code null} ({@link List#of} rejects nulls).
     */
    public static List<DimensionFilter> collectDimensionFilters(DimensionFilter initialDimensionFilter, Aggregator[] subAggregators) {
        return collectDimensionFilters(List.of(initialDimensionFilter), subAggregators);
    }

    /**
     * Returns a new mutable list with {@code initialDimensionFilters} first.
     * It is followed by the dimension filters reported by every sub-aggregator that is a
     * {@link StarTreePreComputeCollector}. A {@code null} filter list from a collector is treated as empty.
     */
    public static List<DimensionFilter> collectDimensionFilters(List<DimensionFilter> initialDimensionFilters, Aggregator[] subAggregators) {
        List<DimensionFilter> dimensionFiltersToMerge = new ArrayList<>(initialDimensionFilters);
        for (Aggregator subAgg : subAggregators) {
            if (subAgg instanceof StarTreePreComputeCollector collector) {
                List<DimensionFilter> childFilters = collector.getDimensionFilters();
                if (childFilters != null) {
                    dimensionFiltersToMerge.addAll(childFilters);
                }
            }
        }
        return dimensionFiltersToMerge;
    }

    /** Callback for {@link #forEachSetBit}; may throw {@link IOException}, unlike {@code IntConsumer}. */
    @FunctionalInterface
    private interface SetBitVisitor {
        void visit(int bit) throws IOException;
    }

    /**
     * Invokes {@code visitor} for each set bit of {@code bits} in ascending order.
     * {@link FixedBitSet#nextSetBit(int)} signals exhaustion with {@code Integer.MAX_VALUE}
     * (Lucene's {@code NO_MORE_DOCS}). It is never called with an index at or beyond {@code length()}.
     */
    private static void forEachSetBit(FixedBitSet bits, SetBitVisitor visitor) throws IOException {
        final int numBits = bits.length();
        if (numBits == 0) {
            return;
        }
        int bit = bits.nextSetBit(0);
        while (bit != Integer.MAX_VALUE) {
            visitor.visit(bit);
            bit = bit + 1 < numBits ? bits.nextSetBit(bit + 1) : Integer.MAX_VALUE;
        }
    }
}

