/*
 * Decompiled with CFR 0.152.
 */
package org.apache.lucene.internal.vectorization;

import java.io.IOException;
import java.lang.foreign.MemorySegment;
import java.util.Optional;
import org.apache.lucene.index.FloatVectorValues;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.internal.vectorization.MemorySegmentBulkVectorOps;
import org.apache.lucene.store.FilterIndexInput;
import org.apache.lucene.store.IndexInput;
import org.apache.lucene.store.MemorySegmentAccessInput;
import org.apache.lucene.util.VectorUtil;
import org.apache.lucene.util.hnsw.RandomVectorScorer;

/**
 * Query-time {@link RandomVectorScorer} over float vectors that live in a single {@link
 * MemorySegment}, with a bulk path that scores up to four stored vectors per {@link #vectorOp}
 * call.
 *
 * <p>Stored vector {@code ord} is read from byte offset {@code ord * vectorByteSize} of {@link
 * #seg}. Instances are not safe for concurrent {@link #bulkScore} calls, because that method writes
 * into the shared {@link #scratchScores} buffer.
 */
abstract sealed class Lucene99MemorySegmentFloatVectorScorer
extends RandomVectorScorer.AbstractRandomVectorScorer {
    /** Vector values backing {@link #seg}; used by the single-node {@code score} path. */
    final FloatVectorValues values;
    /** Byte length of one stored vector (from {@code values.getVectorByteLength()}). */
    final int vectorByteSize;
    /** Segment holding all stored vectors back to back. */
    final MemorySegment seg;
    /** The query vector every stored vector is compared against. */
    final float[] query;
    /** Receives the four raw results of each {@link #vectorOp} call. */
    final float[] scratchScores = new float[4];

    /**
     * Creates a scorer for {@code query} over the vectors in {@code input}, if the input can be
     * viewed as one memory segment.
     *
     * @param type the similarity function to score with
     * @param input the vector data; test-only filter wrappers are removed first via {@link
     *     FilterIndexInput#unwrapOnlyTest}
     * @param values the vector values describing {@code input}
     * @param query the query vector
     * @return the scorer, or {@link Optional#empty()} if the unwrapped input is not a {@link
     *     MemorySegmentAccessInput} or {@code segmentSliceOrNull} returns {@code null}
     * @throws IllegalArgumentException if {@code input} is shorter than {@code values.size()}
     *     vectors, or if {@code query}'s length does not match the stored vector length
     * @throws IOException propagated from the underlying input/values calls
     */
    public static Optional<Lucene99MemorySegmentFloatVectorScorer> create(
            VectorSimilarityFunction type, IndexInput input, FloatVectorValues values, float[] query)
            throws IOException {
        IndexInput unwrapped = FilterIndexInput.unwrapOnlyTest(input);
        if (!(unwrapped instanceof MemorySegmentAccessInput msInput)) {
            return Optional.empty();
        }
        MemorySegment seg = msInput.segmentSliceOrNull(0L, msInput.length());
        if (seg == null) {
            return Optional.empty();
        }
        checkInvariants(values.size(), values.getVectorByteLength(), unwrapped);
        // Exhaustive enum switch: a new similarity constant becomes a compile error here.
        return switch (type) {
            case COSINE -> Optional.of(new CosineScorer(seg, values, query));
            case DOT_PRODUCT -> Optional.of(new DotProductScorer(seg, values, query));
            case EUCLIDEAN -> Optional.of(new EuclideanScorer(seg, values, query));
            case MAXIMUM_INNER_PRODUCT -> Optional.of(new MaxInnerProductScorer(seg, values, query));
        };
    }

    /**
     * @param seg segment holding the stored vectors back to back
     * @param values vector values describing {@code seg}
     * @param query the query vector
     * @throws IllegalArgumentException if {@code query.length * Float.BYTES} differs from {@code
     *     values.getVectorByteLength()}
     */
    Lucene99MemorySegmentFloatVectorScorer(MemorySegment seg, FloatVectorValues values, float[] query) {
        super(values);
        this.values = values;
        this.seg = seg;
        this.vectorByteSize = values.getVectorByteLength();
        this.query = query;
        // bulkScore passes query.length as the element count while stepping through the segment in
        // vectorByteSize strides, so a mismatched query would score against the wrong bytes
        // without failing. Reject it once, up front.
        if ((long) query.length * Float.BYTES != this.vectorByteSize) {
            throw new IllegalArgumentException("vector query dimension: " + query.length
                    + " differs from field dimension: " + (this.vectorByteSize / Float.BYTES));
        }
    }

    /**
     * Verifies that {@code input} is long enough to hold {@code maxOrd} vectors.
     *
     * @throws IllegalArgumentException if {@code input.length() < vectorByteLength * maxOrd}
     */
    static void checkInvariants(int maxOrd, int vectorByteLength, IndexInput input) {
        // Widen before multiplying so large indexes cannot overflow int.
        if (input.length() < (long) vectorByteLength * (long) maxOrd) {
            throw new IllegalArgumentException("input length is less than expected vector data");
        }
    }

    /**
     * Rejects ordinals outside {@code [0, maxOrd())}.
     *
     * @throws IllegalArgumentException if {@code ord} is out of range
     */
    final void checkOrdinal(int ord) {
        if (ord < 0 || ord >= this.maxOrd()) {
            throw new IllegalArgumentException("illegal ordinal: " + ord);
        }
    }

    /**
     * Scores {@code nodes[0..numNodes)} against the query, writing the normalized scores to {@code
     * scores[0..numNodes)}.
     *
     * <p>Nodes are processed four at a time. For a trailing group of one to three nodes, the unused
     * slots of the 4-way op are filled with an already-present offset, and their results are
     * discarded. Ordinals are not range-checked here (unlike {@code score}).
     *
     * @return the maximum score written, or {@link Float#NEGATIVE_INFINITY} if {@code numNodes <= 0}
     */
    @Override
    public float bulkScore(int[] nodes, float[] scores, int numNodes) throws IOException {
        final int elementCount = this.query.length;
        final int limit = numNodes & ~3; // largest multiple of 4 not exceeding numNodes
        float maxScore = Float.NEGATIVE_INFINITY;
        int i = 0;
        for (; i < limit; i += 4) {
            this.vectorOp(this.seg, this.scratchScores,
                    offsetOf(nodes[i]), offsetOf(nodes[i + 1]),
                    offsetOf(nodes[i + 2]), offsetOf(nodes[i + 3]),
                    elementCount);
            maxScore = storeScores(scores, i, 4, maxScore);
        }
        int remaining = numNodes - i;
        if (remaining > 0) {
            long off1 = offsetOf(nodes[i]);
            long off2 = remaining > 1 ? offsetOf(nodes[i + 1]) : off1;
            long off3 = remaining > 2 ? offsetOf(nodes[i + 2]) : off1;
            // Pad the unused fourth (and possibly second/third) slot with a valid offset.
            this.vectorOp(this.seg, this.scratchScores, off1, off2, off3, off3, elementCount);
            maxScore = storeScores(scores, i, remaining, maxScore);
        }
        return maxScore;
    }

    /** Returns the byte offset of stored vector {@code ord} within {@link #seg}. */
    private long offsetOf(int ord) {
        return (long) ord * (long) this.vectorByteSize;
    }

    /**
     * Normalizes the first {@code count} entries of {@link #scratchScores}, writes them to {@code
     * scores[start..start+count)}, and returns the running maximum.
     */
    private float storeScores(float[] scores, int start, int count, float maxScore) {
        for (int j = 0; j < count; j++) {
            float score = this.normalizeRawScore(this.scratchScores[j]);
            scores[start + j] = score;
            maxScore = Math.max(maxScore, score);
        }
        return maxScore;
    }

    /**
     * Computes four raw similarities between {@link #query} and the vectors at the given byte
     * offsets of {@code seg}, writing them to {@code scores[0..3]}.
     *
     * @param seg segment holding the stored vectors
     * @param scores output buffer of length at least 4
     * @param node1Offset byte offset of the first vector
     * @param node2Offset byte offset of the second vector
     * @param node3Offset byte offset of the third vector
     * @param node4Offset byte offset of the fourth vector
     * @param elementCount number of float components per vector
     */
    abstract void vectorOp(MemorySegment seg, float[] scores, long node1Offset, long node2Offset, long node3Offset, long node4Offset, int elementCount);

    /**
     * Converts one raw {@link #vectorOp} result into the score returned to callers.
     * NOTE(review): presumably matches the scale of the corresponding
     * {@code VectorSimilarityFunction.compare}; confirm.
     */
    abstract float normalizeRawScore(float rawScore);

    /** Cosine scorer: raw output of {@code cosineBulk}, normalized via {@code VectorUtil.normalizeToUnitInterval}. */
    static final class CosineScorer
    extends Lucene99MemorySegmentFloatVectorScorer {
        static final MemorySegmentBulkVectorOps.Cosine COS_OPS = MemorySegmentBulkVectorOps.COS_INSTANCE;

        CosineScorer(MemorySegment seg, FloatVectorValues values, float[] query) {
            super(seg, values, query);
        }

        @Override
        public float score(int node) throws IOException {
            this.checkOrdinal(node);
            return VectorSimilarityFunction.COSINE.compare(this.query, this.values.vectorValue(node));
        }

        @Override
        void vectorOp(MemorySegment seg, float[] scores, long node1Offset, long node2Offset, long node3Offset, long node4Offset, int elementCount) {
            COS_OPS.cosineBulk(seg, scores, this.query, node1Offset, node2Offset, node3Offset, node4Offset, elementCount);
        }

        @Override
        float normalizeRawScore(float rawScore) {
            return VectorUtil.normalizeToUnitInterval(rawScore);
        }
    }

    /** Dot-product scorer: raw output of {@code dotProductBulk}, normalized via {@code VectorUtil.normalizeToUnitInterval}. */
    static final class DotProductScorer
    extends Lucene99MemorySegmentFloatVectorScorer {
        static final MemorySegmentBulkVectorOps.DotProduct DOT_OPS = MemorySegmentBulkVectorOps.DOT_INSTANCE;

        DotProductScorer(MemorySegment seg, FloatVectorValues values, float[] query) {
            super(seg, values, query);
        }

        @Override
        public float score(int node) throws IOException {
            this.checkOrdinal(node);
            return VectorSimilarityFunction.DOT_PRODUCT.compare(this.query, this.values.vectorValue(node));
        }

        @Override
        void vectorOp(MemorySegment seg, float[] scores, long node1Offset, long node2Offset, long node3Offset, long node4Offset, int elementCount) {
            DOT_OPS.dotProductBulk(seg, scores, this.query, node1Offset, node2Offset, node3Offset, node4Offset, elementCount);
        }

        @Override
        float normalizeRawScore(float rawScore) {
            return VectorUtil.normalizeToUnitInterval(rawScore);
        }
    }

    /** Euclidean scorer: raw output of {@code sqrDistanceBulk}, normalized via {@code VectorUtil.normalizeDistanceToUnitInterval}. */
    static final class EuclideanScorer
    extends Lucene99MemorySegmentFloatVectorScorer {
        static final MemorySegmentBulkVectorOps.SqrDistance SQR_OPS = MemorySegmentBulkVectorOps.SQR_INSTANCE;

        EuclideanScorer(MemorySegment seg, FloatVectorValues values, float[] query) {
            super(seg, values, query);
        }

        @Override
        public float score(int node) throws IOException {
            this.checkOrdinal(node);
            return VectorSimilarityFunction.EUCLIDEAN.compare(this.query, this.values.vectorValue(node));
        }

        @Override
        void vectorOp(MemorySegment seg, float[] scores, long node1Offset, long node2Offset, long node3Offset, long node4Offset, int elementCount) {
            SQR_OPS.sqrDistanceBulk(seg, scores, this.query, node1Offset, node2Offset, node3Offset, node4Offset, elementCount);
        }

        @Override
        float normalizeRawScore(float rawScore) {
            return VectorUtil.normalizeDistanceToUnitInterval(rawScore);
        }
    }

    /** Max-inner-product scorer: raw output of {@code dotProductBulk}, scaled via {@code VectorUtil.scaleMaxInnerProductScore}. */
    static final class MaxInnerProductScorer
    extends Lucene99MemorySegmentFloatVectorScorer {
        static final MemorySegmentBulkVectorOps.DotProduct DOT_OPS = MemorySegmentBulkVectorOps.DOT_INSTANCE;

        MaxInnerProductScorer(MemorySegment seg, FloatVectorValues values, float[] query) {
            super(seg, values, query);
        }

        @Override
        public float score(int node) throws IOException {
            this.checkOrdinal(node);
            return VectorSimilarityFunction.MAXIMUM_INNER_PRODUCT.compare(this.query, this.values.vectorValue(node));
        }

        @Override
        void vectorOp(MemorySegment seg, float[] scores, long node1Offset, long node2Offset, long node3Offset, long node4Offset, int elementCount) {
            DOT_OPS.dotProductBulk(seg, scores, this.query, node1Offset, node2Offset, node3Offset, node4Offset, elementCount);
        }

        @Override
        float normalizeRawScore(float rawScore) {
            return VectorUtil.scaleMaxInnerProductScore(rawScore);
        }
    }
}

