package net.librec.recommender.cf.ranking;

import java.util.Iterator;
import net.librec.common.LibrecException;
import net.librec.math.algorithm.Maths;
import net.librec.math.structure.DenseMatrix;
import net.librec.math.structure.DenseVector;
import net.librec.math.structure.MatrixEntry;
import net.librec.recommender.MatrixFactorizationRecommender;

/**
 * Listwise matrix factorization recommender trained with stochastic gradient descent.
 *
 * <p>For each user, the training ratings define a target "top-one" distribution
 * {@code P_t(j) = exp(r_uj) / sum_k exp(r_uk)} over the items the user rated, and the current
 * factor dot products {@code s_uj} define a predicted distribution
 * {@code P_p(j) = exp(s_uj) / sum_k exp(s_uk)} over the same items. Each epoch accumulates the
 * cross-entropy {@code -P_t(j) * log P_p(j)} per training entry plus L2 penalties on the factors
 * into {@code loss}.
 */
public class ListwiseMFRecommender extends MatrixFactorizationRecommender {
    /**
     * Per-user sum of {@code exp(rating)} over that user's training ratings, indexed by user
     * index. It is the normalizer of the target distribution and is filled in {@link #setup()}.
     */
    public DenseVector userExp;

    /**
     * Runs the parent setup, then precomputes {@link #userExp} from the training matrix.
     *
     * @throws LibrecException propagated from the parent setup
     */
    @Override
    public void setup() throws LibrecException {
        super.setup();
        this.userExp = new DenseVector(this.numUsers);
        for (MatrixEntry entry : this.trainMatrix) {
            // NOTE(review): exp(rating) overflows to Infinity for ratings above ~709. That is fine
            // for typical rating scales; confirm if ratings can be unbounded (e.g. raw counts).
            this.userExp.add(entry.row(), Math.exp(entry.get()));
        }
    }

    /**
     * Trains the user and item factors by SGD over all training entries.
     *
     * <p>Runs for up to {@code numIterations} epochs. It returns early when {@code isConverged}
     * reports convergence and {@code earlyStop} is set; otherwise it calls {@code updateLRate}
     * after each epoch.
     *
     * @throws LibrecException as declared by the parent contract
     */
    @Override
    protected void trainModel() throws LibrecException {
        for (int iter = 1; iter <= this.numIterations; iter++) {
            this.loss = 0.0d;
            for (MatrixEntry entry : this.trainMatrix) {
                int userIdx = entry.row();
                int itemIdx = entry.column();
                double rating = entry.get();
                double score = DenseMatrix.rowMult(this.userFactors, userIdx, this.itemFactors, itemIdx);

                // Target top-one probability of this item among the user's rated items.
                double targetProb = Math.exp(rating) / this.userExp.get(userIdx);
                // log P_p(item) = score - log(sum_k exp(s_uk)). Log-sum-exp keeps large scores
                // from overflowing exp() and turning the loss and every update into NaN.
                double logPredictedProb = score - logSumExpScores(userIdx);
                double predictedProb = Math.exp(logPredictedProb);

                this.loss -= targetProb * logPredictedProb;

                // Loop-invariant part of the per-factor gradient, computed once per entry.
                // NOTE(review): the softmax uses raw scores, yet the gradient is scaled by the
                // logistic derivative. ListRank-MF also applies the logistic inside the softmax.
                // This is kept as-is to preserve the trained model; confirm it is intended.
                double gradCoef = (targetProb - predictedProb) * Maths.logisticGradientValue(score);
                for (int f = 0; f < this.numFactors; f++) {
                    // Read both values before updating, so each update uses the pre-step factors.
                    double userValue = this.userFactors.get(userIdx, f);
                    double itemValue = this.itemFactors.get(itemIdx, f);
                    double userDelta = gradCoef * itemValue - this.regUser * userValue;
                    double itemDelta = gradCoef * userValue - this.regItem * itemValue;
                    this.userFactors.add(userIdx, f, this.learnRate * userDelta);
                    this.itemFactors.add(itemIdx, f, this.learnRate * itemDelta);
                    this.loss += 0.5d * this.regUser * userValue * userValue
                            + 0.5d * this.regItem * itemValue * itemValue;
                }
            }
            if (isConverged(iter) && this.earlyStop) {
                return;
            }
            updateLRate(iter);
        }
    }

    /**
     * Computes {@code log(sum_k exp(s_uk))} over the items the user rated in the training
     * matrix, where {@code s_uk} is the current user/item factor dot product.
     *
     * <p>A running maximum is maintained and the partial sum is rescaled whenever it grows, so no
     * intermediate {@code exp} can overflow.
     *
     * @param userIdx user index whose rated items form the softmax support
     * @return the log of the softmax denominator ({@code -Infinity} if the user has no items)
     */
    private double logSumExpScores(int userIdx) {
        double max = Double.NEGATIVE_INFINITY;
        double scaledSum = 0.0d; // invariant: scaledSum == sum_k exp(s_uk - max) so far
        for (int item : this.trainMatrix.getColumns(userIdx)) {
            double s = DenseMatrix.rowMult(this.userFactors, userIdx, this.itemFactors, item);
            if (s > max) {
                // New maximum: rescale the existing partial sum to the new reference point.
                scaledSum = scaledSum * Math.exp(max - s) + 1.0d;
                max = s;
            } else {
                scaledSum += Math.exp(s - max);
            }
        }
        return max + Math.log(scaledSum);
    }
}
