/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.knn.memoryoptsearch.faiss;

import java.io.IOException;
import java.util.EnumMap;
import java.util.Map;
import lombok.Generated;
import org.apache.lucene.codecs.hnsw.FlatVectorScorerUtil;
import org.apache.lucene.codecs.hnsw.FlatVectorsScorer;
import org.apache.lucene.index.ByteVectorValues;
import org.apache.lucene.index.KnnVectorValues;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.util.hnsw.RandomVectorScorer;
import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier;
import org.opensearch.knn.index.KNNVectorSimilarityFunction;
import org.opensearch.knn.index.SpaceType;
import org.opensearch.knn.plugin.script.KNNScoringUtil;

/**
 * Hands out {@link FlatVectorsScorer} instances for the Faiss memory-optimized search path.
 *
 * <p>Selection rules (see {@link #getFlatVectorsScorer(KNNVectorSimilarityFunction, boolean, SpaceType)}):
 * <ul>
 *   <li>ADC requests get the {@link ADCFlatVectorsScorer} registered for the requested {@link SpaceType}
 *       (only {@link SpaceType#L2}, {@link SpaceType#COSINESIMIL} and {@link SpaceType#INNER_PRODUCT} are registered).</li>
 *   <li>{@link KNNVectorSimilarityFunction#HAMMING} gets a byte-vector Hamming scorer.</li>
 *   <li>Everything else is delegated to Lucene's {@code Lucene99} flat vectors scorer.</li>
 * </ul>
 *
 * <p>All scorers are created once in static fields and the same instances are returned on every call.
 * This class cannot be instantiated.
 */
public final class FlatVectorsScorerProvider {
    /** Lucene's flat scorer, used for every non-ADC, non-Hamming similarity function. */
    private static final FlatVectorsScorer DELEGATE_VECTOR_SCORER = FlatVectorScorerUtil.getLucene99FlatVectorsScorer();
    /** Shared scorer for {@link KNNVectorSimilarityFunction#HAMMING} over byte vectors. */
    private static final FlatVectorsScorer HAMMING_VECTOR_SCORER = new HammingFlatVectorsScorer();
    /** ADC scorers keyed by space type; a space type absent from this map does not support ADC. */
    private static final Map<SpaceType, FlatVectorsScorer> ADC_FLAT_SCORERS = initializeAdcFlatScorers();

    /**
     * Builds the ADC scorer registry.
     *
     * @return a map from each ADC-capable space type to its scorer
     */
    private static Map<SpaceType, FlatVectorsScorer> initializeAdcFlatScorers() {
        final Map<SpaceType, FlatVectorsScorer> scorers = new EnumMap<>(SpaceType.class);
        scorers.put(SpaceType.L2, new ADCFlatVectorsScorer(KNNVectorSimilarityFunction.EUCLIDEAN, SpaceType.L2));
        scorers.put(SpaceType.COSINESIMIL, new ADCFlatVectorsScorer(KNNVectorSimilarityFunction.COSINE, SpaceType.COSINESIMIL));
        scorers.put(
            SpaceType.INNER_PRODUCT,
            new ADCFlatVectorsScorer(KNNVectorSimilarityFunction.MAXIMUM_INNER_PRODUCT, SpaceType.INNER_PRODUCT)
        );
        return scorers;
    }

    /**
     * Returns the non-ADC scorer for the given similarity function.
     *
     * @param similarityFunction the k-NN similarity function of the field
     * @return the Hamming scorer for {@link KNNVectorSimilarityFunction#HAMMING}, otherwise Lucene's flat scorer
     */
    public static FlatVectorsScorer getFlatVectorsScorer(KNNVectorSimilarityFunction similarityFunction) {
        return getFlatVectorsScorer(similarityFunction, false, null);
    }

    /**
     * Returns the scorer matching the given search configuration.
     *
     * @param similarityFunction the k-NN similarity function; ignored when {@code isAdc} is {@code true}
     * @param isAdc whether float queries are scored directly against quantized byte vectors (ADC)
     * @param spaceType the space type selecting the ADC scorer; only read when {@code isAdc} is {@code true}
     * @return a shared, non-null scorer instance
     * @throws IllegalArgumentException if {@code isAdc} is {@code true} and {@code spaceType} is {@code null}
     *         or has no registered ADC scorer
     */
    public static FlatVectorsScorer getFlatVectorsScorer(KNNVectorSimilarityFunction similarityFunction, boolean isAdc, SpaceType spaceType) {
        if (isAdc) {
            // EnumMap.get yields null for a null key or an unregistered space type. Fail fast here instead of
            // handing the caller a null scorer that would surface later as an unrelated NullPointerException.
            final FlatVectorsScorer adcScorer = ADC_FLAT_SCORERS.get(spaceType);
            if (adcScorer == null) {
                throw new IllegalArgumentException("ADC is not supported for space type: " + spaceType);
            }
            return adcScorer;
        }
        if (similarityFunction == KNNVectorSimilarityFunction.HAMMING) {
            return HAMMING_VECTOR_SCORER;
        }
        return DELEGATE_VECTOR_SCORER;
    }

    /**
     * Null-safe simple class name used in argument-validation messages.
     * This prevents a null argument from turning an intended IllegalArgumentException into an NPE.
     */
    private static String describe(KnnVectorValues knnVectorValues) {
        return knnVectorValues == null ? "null" : knnVectorValues.getClass().getSimpleName();
    }

    @Generated
    private FlatVectorsScorerProvider() {
        throw new UnsupportedOperationException("This is a utility class and cannot be instantiated");
    }

    /**
     * Scorer for ADC: a full-precision {@code float[]} query is scored against quantized vectors exposed as
     * {@link ByteVectorValues}. The raw value from {@link KNNScoringUtil} is mapped to a final score through
     * {@link SpaceType#scoreTranslation} of the configured space type.
     */
    public static class ADCFlatVectorsScorer implements FlatVectorsScorer {
        // NOTE(review): not read anywhere in this class; retained because the public constructor accepts it.
        private final KNNVectorSimilarityFunction knnSimilarityFunction;
        private final SpaceType spaceType;

        /**
         * @param knnSimilarityFunction the k-NN similarity function associated with this scorer
         * @param spaceType the space type that selects the scoring formula; must be L2, COSINESIMIL or INNER_PRODUCT
         *        for {@link #getRandomVectorScorer(VectorSimilarityFunction, KnnVectorValues, float[])} to succeed
         */
        public ADCFlatVectorsScorer(KNNVectorSimilarityFunction knnSimilarityFunction, SpaceType spaceType) {
            this.knnSimilarityFunction = knnSimilarityFunction;
            this.spaceType = spaceType;
        }

        /**
         * Byte queries are not supported by ADC.
         *
         * @throws UnsupportedOperationException always
         */
        @Override
        public RandomVectorScorer getRandomVectorScorer(VectorSimilarityFunction vectorSimilarityFunction, KnnVectorValues knnVectorValues, byte[] target) {
            throw new UnsupportedOperationException("ADC does not support byte vector search");
        }

        /**
         * Creates a scorer that compares the float {@code target} against each stored quantized byte vector.
         *
         * @param vectorSimilarityFunction ignored; the formula comes from this scorer's space type
         * @param knnVectorValues the stored vectors; must be {@link ByteVectorValues}
         * @param target the full-precision query vector
         * @return a scorer over {@code knnVectorValues}
         * @throws IllegalArgumentException if {@code knnVectorValues} is not {@link ByteVectorValues},
         *         or the space type has no ADC formula
         */
        @Override
        public RandomVectorScorer getRandomVectorScorer(VectorSimilarityFunction vectorSimilarityFunction, KnnVectorValues knnVectorValues, final float[] target) {
            if (!(knnVectorValues instanceof ByteVectorValues)) {
                throw new IllegalArgumentException(
                    "Expected " + ByteVectorValues.class.getSimpleName() + " for ADC scorer, got " + describe(knnVectorValues)
                );
            }
            final ByteVectorValues byteVectorValues = (ByteVectorValues) knnVectorValues;
            // Unqualified enum labels are valid on every Java version with switch expressions; qualified
            // labels such as `case SpaceType.L2` require Java 21.
            return switch (this.spaceType) {
                case L2 -> new RandomVectorScorer.AbstractRandomVectorScorer(knnVectorValues) {
                    @Override
                    public float score(int internalVectorId) throws IOException {
                        final byte[] quantizedByteVector = byteVectorValues.vectorValue(internalVectorId);
                        return SpaceType.L2.scoreTranslation(KNNScoringUtil.l2SquaredADC(target, quantizedByteVector));
                    }
                };
                case COSINESIMIL -> new RandomVectorScorer.AbstractRandomVectorScorer(knnVectorValues) {
                    @Override
                    public float score(int internalVectorId) throws IOException {
                        final byte[] quantizedByteVector = byteVectorValues.vectorValue(internalVectorId);
                        // Converts the inner product into a cosine distance (1 - ip) before the space-type translation.
                        return SpaceType.COSINESIMIL.scoreTranslation(1.0f - KNNScoringUtil.innerProductADC(target, quantizedByteVector));
                    }
                };
                case INNER_PRODUCT -> new RandomVectorScorer.AbstractRandomVectorScorer(knnVectorValues) {
                    @Override
                    public float score(int internalVectorId) throws IOException {
                        final byte[] quantizedByteVector = byteVectorValues.vectorValue(internalVectorId);
                        // Negates the inner product before the space-type translation.
                        return SpaceType.INNER_PRODUCT.scoreTranslation(-1.0f * KNNScoringUtil.innerProductADC(target, quantizedByteVector));
                    }
                };
                default -> throw new IllegalArgumentException("Unsupported space type: " + this.spaceType);
            };
        }

        /**
         * Scorer suppliers are not supported by ADC.
         *
         * @throws UnsupportedOperationException always
         */
        @Override
        public RandomVectorScorerSupplier getRandomVectorScorerSupplier(VectorSimilarityFunction similarityFunction, KnnVectorValues vectorValues) throws IOException {
            throw new UnsupportedOperationException("ADC does not support RandomVectorScorerSupplier");
        }
    }

    /**
     * Scorer for byte vectors compared with {@link KNNVectorSimilarityFunction#HAMMING}.
     * Only byte queries are supported.
     */
    private static class HammingFlatVectorsScorer implements FlatVectorsScorer {
        private HammingFlatVectorsScorer() {
        }

        /**
         * Creates a scorer that compares the byte {@code target} against each stored byte vector with Hamming similarity.
         *
         * @param vectorSimilarityFunction ignored; Hamming is always used
         * @param knnVectorValues the stored vectors; must be {@link ByteVectorValues}
         * @param target the byte query vector
         * @return a scorer over {@code knnVectorValues}
         * @throws IllegalArgumentException if {@code knnVectorValues} is not {@link ByteVectorValues}
         */
        @Override
        public RandomVectorScorer getRandomVectorScorer(VectorSimilarityFunction vectorSimilarityFunction, KnnVectorValues knnVectorValues, final byte[] target) {
            if (!(knnVectorValues instanceof ByteVectorValues)) {
                throw new IllegalArgumentException(
                    "Expected " + ByteVectorValues.class.getSimpleName() + " for hamming vector scorer, got " + describe(knnVectorValues)
                );
            }
            final ByteVectorValues byteVectorValues = (ByteVectorValues) knnVectorValues;
            return new RandomVectorScorer.AbstractRandomVectorScorer(knnVectorValues) {
                @Override
                public float score(int internalVectorId) throws IOException {
                    final byte[] quantizedByteVector = byteVectorValues.vectorValue(internalVectorId);
                    return KNNVectorSimilarityFunction.HAMMING.compare(target, quantizedByteVector);
                }
            };
        }

        /**
         * Scorer suppliers are not supported for Hamming.
         *
         * @throws UnsupportedOperationException always
         */
        @Override
        public RandomVectorScorerSupplier getRandomVectorScorerSupplier(VectorSimilarityFunction vectorSimilarityFunction, KnnVectorValues knnVectorValues) {
            throw new UnsupportedOperationException("Hamming vector scorer does not support RandomVectorScorerSupplier");
        }

        /**
         * Float queries are not supported for Hamming.
         *
         * @throws UnsupportedOperationException always
         */
        @Override
        public RandomVectorScorer getRandomVectorScorer(VectorSimilarityFunction vectorSimilarityFunction, KnnVectorValues knnVectorValues, float[] target) {
            throw new UnsupportedOperationException("Hamming vector scorer does not support float vector search");
        }
    }
}

