package net.librec.recommender.cf.ranking;

import com.google.common.cache.LoadingCache;
import com.google.common.collect.Table;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import java.util.concurrent.ExecutionException;
import net.librec.annotation.ModelData;
import net.librec.common.LibrecException;
import net.librec.math.algorithm.Randoms;
import net.librec.math.structure.DenseMatrix;
import net.librec.math.structure.DenseVector;
import net.librec.math.structure.SparseVector;
import net.librec.math.structure.VectorEntry;
import net.librec.recommender.MatrixFactorizationRecommender;

/**
 * Factored item-similarity recommender trained with a squared-error objective (the class and
 * config keys call it "FISMrmse"; presumably the RMSE variant of FISM — confirm against the
 * reference paper).
 *
 * <p>The score for user {@code u} and item {@code j} is
 * {@code b_u + b_j + n^(-alpha) * sum_{i in R(u), i != j} p_i . q_j}, where {@code R(u)} are the
 * items {@code u} has entries for in the training matrix, {@code n = |R(u) \ {j}|}, {@code p_i} is
 * row {@code i} of {@code P} and {@code q_j} is row {@code j} of {@code Q}.
 */
@ModelData({"isRanking", "fismrmse", "P", "Q", "itemBiases", "userBiases"})
public class FISMrmseRecommender extends MatrixFactorizationRecommender {
    /** Multiplier on the training-set size giving the number of unrated cells sampled per iteration. */
    private int rho;
    /** Exponent used to normalise the item-item sum by the user's neighbourhood size. */
    private float alpha;
    /** {@code trainMatrix.size()}, used as the count of rated cells when sizing the negative space. */
    private int trainMatrixSize;
    /** L2 regularisation weight for user and item biases. */
    private double regBias;
    /** Per-item bias terms (length numItems). */
    private DenseVector itemBiases;
    /** Per-user bias terms (length numUsers). */
    private DenseVector userBiases;
    /** Factors of the items a user has rated (numItems x numFactors). */
    private DenseMatrix P;
    /** Factors of the item being scored (numItems x numFactors). */
    private DenseMatrix Q;
    /** Cache of item (column) indices present in each user's training row. */
    protected LoadingCache<Integer, List<Integer>> userItemsCache;
    /** Guava cache spec for {@link #userItemsCache}; static, so the last {@code setup()} wins. */
    protected static String cacheSpec;

    /**
     * Allocates and initialises the model parameters and reads the FISMrmse configuration.
     *
     * <p>Reads {@code rec.fismrmse.rho}, {@code rec.fismrmse.alpha},
     * {@code rec.bias.regularization} (default 0.01) and {@code guava.cache.spec}
     * (default {@code maximumSize=200,expireAfterAccess=2m}).
     *
     * @throws LibrecException if the parent setup fails
     */
    @Override
    public void setup() throws LibrecException {
        super.setup();
        // The initialisation order is kept as-is in case init() draws from a shared random source.
        this.P = new DenseMatrix(this.numItems, this.numFactors);
        this.Q = new DenseMatrix(this.numItems, this.numFactors);
        this.P.init(0.01d);
        this.Q.init(0.01d);
        this.itemBiases = new DenseVector(this.numItems);
        this.userBiases = new DenseVector(this.numUsers);
        this.itemBiases.init(0.01d);
        this.userBiases.init(0.01d);
        this.trainMatrixSize = this.trainMatrix.size();
        // NOTE(review): rho/alpha have no defaults; a missing key likely fails on unboxing — confirm.
        this.rho = this.conf.getInt("rec.fismrmse.rho").intValue();
        this.alpha = this.conf.getFloat("rec.fismrmse.alpha").floatValue();
        this.regBias = this.conf.getDouble("rec.bias.regularization", Double.valueOf(0.01d)).doubleValue();
        cacheSpec = this.conf.get("guava.cache.spec", "maximumSize=200,expireAfterAccess=2m");
        this.userItemsCache = this.trainMatrix.rowColumnsCache(cacheSpec);
    }

    /**
     * Trains the model.
     *
     * <p>Each iteration builds a training table from the training data plus
     * {@code rho * trainMatrixSize} randomly chosen unrated cells with target 0. It then runs one
     * pass over every cell:
     * <ul>
     *   <li>biases are updated immediately, per cell;</li>
     *   <li>{@code P}/{@code Q} updates are accumulated and applied at the end of the iteration.</li>
     * </ul>
     *
     * @throws LibrecException if the sampling space does not fit in an {@code int} or negative
     *     sampling fails
     */
    @Override
    protected void trainModel() throws LibrecException {
        // Computed in long so oversized matrices are rejected instead of silently overflowing.
        long requestedSamples = (long) this.rho * this.trainMatrixSize;
        long unratedCells = (long) this.numUsers * this.numItems - this.trainMatrixSize;
        if (requestedSamples > Integer.MAX_VALUE || unratedCells > Integer.MAX_VALUE) {
            throw new LibrecException("FISMrmse negative sampling exceeds int range: requested="
                    + requestedSamples + ", unrated=" + unratedCells);
        }
        int sampleSize = (int) requestedSamples;
        int negativeSpace = (int) unratedCells;

        for (int iter = 1; iter <= this.numIterations; iter++) {
            this.loss = 0.0d;
            DenseMatrix pDelta = new DenseMatrix(this.numItems, this.numFactors);
            DenseMatrix qDelta = new DenseMatrix(this.numItems, this.numFactors);

            // NOTE(review): assumes getDataTable() returns a fresh table per call; if it were the
            // live backing table, sampled zeros would accumulate across iterations — confirm.
            Table<Integer, Integer, Double> cells = this.trainMatrix.getDataTable();
            addNegativeSamples(cells, sampleSize, negativeSpace);

            for (Table.Cell<Integer, Integer, Double> cell : cells.cellSet()) {
                accumulateCellGradient(cell.getRowKey(), cell.getColumnKey(), cell.getValue(), pDelta, qDelta);
            }

            this.P = this.P.add(pDelta);
            this.Q = this.Q.add(qDelta);
            this.loss *= 0.5d;
            if (isConverged(iter) && this.earlyStop) {
                return;
            }
            updateLRate(iter);
        }
    }

    /**
     * Puts {@code sampleSize} randomly chosen unrated cells into {@code cells} with value 0.
     *
     * <p>Unrated cells are numbered 0.. in row-major (user, then item) order. The cells whose
     * numbers were drawn from {@code [0, negativeSpace)} are the ones added.
     *
     * @param cells table receiving the sampled cells
     * @param sampleSize number of indices to draw; nothing is sampled when {@code <= 0}
     * @param negativeSpace number of unrated cells (exclusive upper bound of the draw)
     * @throws LibrecException if {@link Randoms#randInts} fails (e.g. sampleSize > negativeSpace)
     */
    private void addNegativeSamples(Table<Integer, Integer, Double> cells, int sampleSize, int negativeSpace)
            throws LibrecException {
        if (sampleSize <= 0) {
            return;
        }
        List<Integer> drawn;
        try {
            drawn = Randoms.randInts(sampleSize, 0, negativeSpace);
        } catch (Exception e) {
            throw new LibrecException("FISMrmse could not sample " + sampleSize
                    + " negative cells out of " + negativeSpace, e);
        }
        int[] targets = new int[drawn.size()];
        int k = 0;
        for (Integer value : drawn) {
            targets[k++] = value;
        }
        if (targets.length == 0) {
            return;
        }
        // The walk below visits unrated cells in increasing order, so targets must be ascending.
        Arrays.sort(targets);

        int next = 0;
        int position = 0; // ordinal of the current unrated cell
        for (int u = 0; u < this.numUsers; u++) {
            for (int j = 0; j < this.numItems; j++) {
                if (this.trainMatrix.get(u, j) != 0.0d) {
                    continue; // rated cell: not part of the negative space
                }
                if (position == targets[next]) {
                    cells.put(u, j, 0.0d);
                    // Skip duplicate draws of the same cell so later targets can still match.
                    while (next < targets.length && targets[next] == position) {
                        next++;
                    }
                    if (next == targets.length) {
                        return;
                    }
                }
                position++;
            }
        }
    }

    /**
     * Performs the gradient step for one training cell.
     *
     * <p>Updates the user and item biases in place. Accumulates the {@code Q}/{@code P} updates into
     * {@code qDelta}/{@code pDelta} and adds this cell's error and regularisation terms to
     * {@code loss}.
     *
     * @param user row index of the cell
     * @param item column index of the cell
     * @param target observed value (0 for sampled negatives)
     * @param pDelta accumulator for updates to {@code P}
     * @param qDelta accumulator for updates to {@code Q}
     */
    private void accumulateCellGradient(int user, int item, double target, DenseMatrix pDelta, DenseMatrix qDelta) {
        SparseVector ratedItems = this.trainMatrix.row(user);
        double userBias = this.userBiases.get(user);
        double itemBias = this.itemBiases.get(item);

        // One pass gathers both the similarity sum and the per-factor sum of the neighbours' P rows.
        // P is untouched until the iteration ends, so this equals re-summing per factor.
        double similarity = 0.0d;
        double[] pSum = new double[this.numFactors];
        int neighbours = 0;
        for (VectorEntry entry : ratedItems) {
            int other = entry.index();
            if (other == item) {
                continue; // the target item never contributes to its own score
            }
            similarity += DenseMatrix.rowMult(this.P, other, this.Q, item);
            for (int f = 0; f < this.numFactors; f++) {
                pSum[f] += this.P.get(other, f);
            }
            neighbours++;
        }
        double weight = neighbours > 0 ? Math.pow(neighbours, -this.alpha) : 0.0d;
        double error = userBias + itemBias + weight * similarity - target;
        this.loss += error * error;

        this.userBiases.add(user, -this.learnRate * (error + this.regBias * userBias));
        this.itemBiases.add(item, -this.learnRate * (error + this.regBias * itemBias));
        this.loss += this.regBias * userBias * userBias + this.regBias * itemBias * itemBias;

        for (int f = 0; f < this.numFactors; f++) {
            double qjf = this.Q.get(item, f);
            qDelta.add(item, f, -this.learnRate * (error * weight * pSum[f] + this.regItem * qjf));
            this.loss += this.regItem * qjf * qjf;
        }

        for (VectorEntry entry : ratedItems) {
            int other = entry.index();
            if (other == item) {
                continue;
            }
            for (int f = 0; f < this.numFactors; f++) {
                double pif = this.P.get(other, f);
                pDelta.add(other, f, -this.learnRate * (error * weight * this.Q.get(item, f) + this.regItem * pif));
                this.loss += this.regItem * pif * pif;
            }
        }
    }

    /**
     * Scores {@code itemIdx} for {@code userIdx} using the model formula described on the class.
     *
     * @param userIdx user row index
     * @param itemIdx item column index
     * @return predicted score
     * @throws LibrecException if the user's rated items cannot be loaded from the cache
     */
    @Override
    public double predict(int userIdx, int itemIdx) throws LibrecException {
        double bias = this.userBiases.get(userIdx) + this.itemBiases.get(itemIdx);
        List<Integer> ratedItems;
        try {
            ratedItems = this.userItemsCache.get(userIdx);
        } catch (ExecutionException e) {
            throw new LibrecException("FISMrmse failed to load rated items for user " + userIdx, e);
        }
        double similarity = 0.0d;
        int neighbours = 0;
        for (int other : ratedItems) {
            if (other != itemIdx) {
                similarity += DenseMatrix.rowMult(this.P, other, this.Q, itemIdx);
                neighbours++;
            }
        }
        double weight = neighbours > 0 ? Math.pow(neighbours, -this.alpha) : 0.0d;
        return bias + weight * similarity;
    }
}
