package net.librec.recommender.cf.ranking;

import java.util.AbstractMap;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import net.librec.annotation.ModelData;
import net.librec.common.LibrecException;
import net.librec.math.structure.DenseMatrix;
import net.librec.math.structure.SparseVector;
import net.librec.math.structure.SymmMatrix;
import net.librec.math.structure.VectorEntry;
import net.librec.recommender.AbstractRecommender;
import net.librec.util.Lists;

/**
 * SLIM (Sparse Linear Methods) item-based ranking recommender.
 *
 * <p>Learns an item-by-item coefficient matrix with cyclic coordinate descent under an
 * elastic-net style penalty (L1 weight {@code rec.slim.regularization.l1}, L2 weight
 * {@code rec.slim.regularization.l2}, both defaulting to 1.0). A user's score for an item is
 * the sum, over that user's training entries on the item's neighbors, of rating times learned
 * coefficient.
 *
 * <p>Neighborhoods: when {@code rec.neighbors.knn.number} (default 50) is positive, each item's
 * neighbors are its {@code knn} most similar items from the context's similarity matrix (or every
 * item in its similarity row when the row has no more than {@code knn} entries). Otherwise the
 * candidate set built by {@link #createItemNNs()} is used for every item.
 */
@ModelData({"isRanking", "slim", "coefficientMatrix", "trainMatrix", "similarityMatrix", "knn"})
public class SLIMRecommender extends AbstractRecommender {
    /** Maximum number of coordinate-descent sweeps ({@code rec.iterator.maximum}). */
    protected int numIterations;
    /** Learned weights; entry (j, i) is the weight of neighbor item j when scoring item i. */
    private DenseMatrix coefficientMatrix;
    /** Per-item neighbor sets; only populated when {@link #knn} is positive. */
    private Set<Integer>[] itemNNs;
    /** L1 penalty weight ({@code rec.slim.regularization.l1}). */
    private float regL1Norm;
    /** L2 penalty weight ({@code rec.slim.regularization.l2}). */
    private float regL2Norm;
    /**
     * Neighborhood size ({@code rec.neighbors.knn.number}); non-positive means "no top-k filter".
     * NOTE(review): this field is static, so it is shared by all instances and overwritten by each
     * {@link #setup()} call; kept static for compatibility with existing callers.
     */
    protected static int knn;
    /** Item-item similarity used to select the top-k neighbors. */
    private SymmMatrix similarityMatrix;
    /** Neighbor candidates shared by every item when {@link #knn} is not positive. */
    private Set<Integer> allItems;

    /**
     * Reads the hyper-parameters from the configuration, prepares the coefficient matrix and builds
     * the item neighborhoods.
     *
     * <p>The coefficient matrix is initialized via {@code DenseMatrix#init()} (presumably a random
     * initialization — confirm) and its diagonal is then zeroed. Training never updates diagonal
     * entries, so an item's own rating never contributes to its score.
     *
     * @throws LibrecException if the base setup fails
     */
    @Override
    public void setup() throws LibrecException {
        super.setup();
        knn = this.conf.getInt("rec.neighbors.knn.number", 50).intValue();
        this.numIterations = this.conf.getInt("rec.iterator.maximum").intValue();
        this.regL1Norm = this.conf.getFloat("rec.slim.regularization.l1", Float.valueOf(1.0f)).floatValue();
        this.regL2Norm = this.conf.getFloat("rec.slim.regularization.l2", Float.valueOf(1.0f)).floatValue();
        this.coefficientMatrix = new DenseMatrix(this.numItems, this.numItems);
        this.coefficientMatrix.init();
        this.similarityMatrix = this.context.getSimilarity().getSimilarityMatrix();
        for (int item = 0; item < this.numItems; item++) {
            this.coefficientMatrix.set(item, item, 0.0d);
        }
        createItemNNs();
    }

    /**
     * Runs up to {@link #numIterations} sweeps of cyclic coordinate descent over every
     * (neighbor, item) coefficient. The sweep stops early when {@link #isConverged(int)} reports
     * convergence and early stopping is enabled.
     *
     * @throws LibrecException declared by the base class
     */
    @Override
    protected void trainModel() throws LibrecException {
        for (int iter = 1; iter <= this.numIterations; iter++) {
            this.loss = 0.0d;
            for (int item = 0; item < this.numItems; item++) {
                Set<Integer> neighbors = neighborsOf(item);
                // Dense copy of the entries of trainMatrix row `item`, indexed by entry index
                // (sized numUsers, so the indices are treated as user indices).
                double[] itemRatings = new double[this.numUsers];
                Iterator<VectorEntry> ratingIt = this.trainMatrix.rowIterator(item);
                while (ratingIt.hasNext()) {
                    VectorEntry entry = ratingIt.next();
                    itemRatings[entry.index()] = entry.get();
                }
                for (Integer neighborObj : neighbors) {
                    int neighbor = neighborObj.intValue();
                    if (neighbor != item) {
                        updateCoefficient(item, neighbor, itemRatings);
                    }
                }
            }
            if (isConverged(iter) && this.earlyStop) {
                return;
            }
        }
    }

    /**
     * Performs one coordinate-descent step on the weight of {@code neighbor} for {@code item}. It
     * also adds that coordinate's objective value, measured before the update, to {@link #loss}.
     * Does nothing when the neighbor's trainMatrix row has no entries.
     */
    private void updateCoefficient(int item, int neighbor, double[] itemRatings) {
        Iterator<VectorEntry> neighborIt = this.trainMatrix.rowIterator(neighbor);
        if (!neighborIt.hasNext()) {
            return;
        }
        double gradSum = 0.0d;
        double rateSum = 0.0d;
        double errorSum = 0.0d;
        int count = 0;
        while (neighborIt.hasNext()) {
            VectorEntry entry = neighborIt.next();
            int user = entry.index();
            double neighborRating = entry.get();
            // Residual of the target rating when this neighbor's contribution is left out.
            double error = itemRatings[user] - predict(user, item, neighbor);
            gradSum += neighborRating * error;
            rateSum += neighborRating * neighborRating;
            errorSum += error * error;
            count++;
        }
        double grad = gradSum / count;
        double rate = rateSum / count;
        double coefficient = this.coefficientMatrix.get(neighbor, item);
        // Objective term: mean squared residual + (l2 / 2) * w^2 + l1 * |w|. The L1 penalty must
        // use |w| because softThresholdUpdate can produce negative weights.
        this.loss += (errorSum / count)
                + (0.5d * this.regL2Norm * coefficient * coefficient)
                + (this.regL1Norm * Math.abs(coefficient));
        this.coefficientMatrix.set(neighbor, item, softThresholdUpdate(grad, rate));
    }

    /**
     * Elastic-net coordinate update: soft-thresholds {@code grad} by the L1 weight and divides by
     * {@code l2 + rate}; returns 0 when {@code |grad|} does not exceed the L1 weight.
     * NOTE(review): negative weights are allowed here (the {@code grad + l1} branch).
     */
    private double softThresholdUpdate(double grad, double rate) {
        if (this.regL1Norm < Math.abs(grad)) {
            double shrunk = grad > 0.0d ? grad - this.regL1Norm : grad + this.regL1Norm;
            return shrunk / (this.regL2Norm + rate);
        }
        return 0.0d;
    }

    /** Returns the neighbor set of {@code itemIdx}: its top-k set when knn > 0, else {@link #allItems}. */
    private Set<Integer> neighborsOf(int itemIdx) {
        return knn > 0 ? this.itemNNs[itemIdx] : this.allItems;
    }

    /**
     * Scores {@code itemIdx} for {@code userIdx}, optionally leaving one neighbor out.
     *
     * <p>The score sums rating times coefficient over the entries from
     * {@code trainMatrix.colIterator(userIdx)} whose index is a neighbor of {@code itemIdx}.
     *
     * @param userIdx         user index
     * @param itemIdx         item being scored
     * @param excludedItemIdx neighbor to skip, or -1 to skip none
     * @return the (possibly partial) score
     */
    protected double predict(int userIdx, int itemIdx, int excludedItemIdx) {
        Set<Integer> neighbors = neighborsOf(itemIdx);
        double score = 0.0d;
        Iterator<VectorEntry> it = this.trainMatrix.colIterator(userIdx);
        while (it.hasNext()) {
            VectorEntry entry = it.next();
            int neighbor = entry.index();
            if (neighbor != excludedItemIdx && neighbors.contains(Integer.valueOf(neighbor))) {
                score += entry.get() * this.coefficientMatrix.get(neighbor, itemIdx);
            }
        }
        return score;
    }

    /**
     * Records the current loss and decides whether training has converged. Convergence means the
     * loss decreased by less than 1e-5 (or increased) after the first iteration.
     *
     * @param iter 1-based iteration number
     * @return {@code true} when converged
     */
    @Override
    public boolean isConverged(int iter) {
        double deltaLoss = this.lastLoss - this.loss;
        this.lastLoss = this.loss;
        if (verbose) {
            this.LOG.info(getClass().getSimpleName() + " iter " + iter + ": loss = " + this.loss + ", delta_loss = " + deltaLoss);
        }
        return iter > 1 && deltaLoss < 1.0E-5d;
    }

    /**
     * Scores {@code itemIdx} for {@code userIdx} using all of the item's neighbors.
     *
     * <p>{@code itemNNs} and {@code allItems} are not listed in {@code @ModelData}, so they are
     * rebuilt lazily here when missing (presumably the case for a restored model — confirm).
     *
     * @throws LibrecException declared by the base class
     */
    @Override
    protected double predict(int userIdx, int itemIdx) throws LibrecException {
        if (this.itemNNs == null || this.itemNNs.length == 0 || (knn <= 0 && this.allItems == null)) {
            createItemNNs();
        }
        return predict(userIdx, itemIdx, -1);
    }

    /**
     * Builds the neighborhoods.
     *
     * <p>When {@link #knn} is not positive, only {@link #allItems} is filled, from
     * {@code trainMatrix.columns()}. Otherwise each item's neighbors are the {@code knn} entries
     * of its similarity row selected by {@code Lists.sortListTopK(..., true, knn)}, or the whole
     * row's index set when the row has at most {@code knn} entries.
     */
    public void createItemNNs() {
        @SuppressWarnings("unchecked") // generic array creation is not expressible directly
        Set<Integer>[] neighborSets = new HashSet[this.numItems];
        this.itemNNs = neighborSets;
        if (knn <= 0) {
            // NOTE(review): trainModel reads item ratings via rowIterator(item), yet this takes
            // trainMatrix.columns() as the candidate set; confirm the matrix orientation.
            this.allItems = new HashSet<>(this.trainMatrix.columns());
            return;
        }
        for (int item = 0; item < this.numItems; item++) {
            SparseVector similarities = this.similarityMatrix.row(item);
            if (knn < similarities.size()) {
                List<Map.Entry<Integer, Double>> candidates = new ArrayList<>(similarities.size() + 1);
                Iterator<VectorEntry> it = similarities.iterator();
                while (it.hasNext()) {
                    VectorEntry entry = it.next();
                    candidates.add(new AbstractMap.SimpleImmutableEntry<>(Integer.valueOf(entry.index()), Double.valueOf(entry.get())));
                }
                List<Map.Entry<Integer, Double>> topK = Lists.sortListTopK(candidates, true, knn);
                Set<Integer> neighbors = new HashSet<>(topK.size() * 2);
                for (Map.Entry<Integer, Double> neighbor : topK) {
                    neighbors.add(neighbor.getKey());
                }
                this.itemNNs[item] = neighbors;
            } else {
                this.itemNNs[item] = similarities.getIndexSet();
            }
        }
    }
}
