package net.librec.recommender.context.rating;

import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import net.librec.annotation.ModelData;
import net.librec.common.LibrecException;
import net.librec.math.algorithm.Maths;
import net.librec.math.structure.DenseMatrix;
import net.librec.math.structure.MatrixEntry;
import net.librec.recommender.SocialRecommender;

/**
 * SoRec social recommender: jointly factorizes the user-item rating matrix and the user-user
 * social matrix, with the user latent factors ({@code userFactors}) shared by both
 * factorizations.
 *
 * <p>NOTE(review): this appears to implement Ma et al., "SoRec: Social Recommendation Using
 * Probabilistic Matrix Factorization" (CIKM 2008) — confirm against the project docs.
 *
 * <p>Configuration keys read in {@link #setup()}:
 * <ul>
 *   <li>{@code rec.rate.social.regularization} (default {@code 0.01}) — weight of the social
 *       reconstruction error and its gradients.</li>
 *   <li>{@code rec.user.social.regularization} (default {@code 0.01}) — L2 penalty on the
 *       social factor matrix.</li>
 * </ul>
 */
@ModelData({"isRating", "sorec", "userFactors", "itemFactors"})
public class SoRecRecommender extends SocialRecommender {
    /** Latent factors of users in their role as the column (target) of a social link. */
    private DenseMatrix userSocialFactors;
    /** Weight applied to the social reconstruction loss and gradients. */
    private float regRateSocial;
    /** L2 regularization coefficient for {@link #userSocialFactors}. */
    private float regUserSocial;
    /** Per-user value of {@code socialMatrix.columnSize(u)} (incoming-link count, presumably). */
    private List<Integer> inDegrees;
    /** Per-user value of {@code socialMatrix.rowSize(u)} (outgoing-link count, presumably). */
    private List<Integer> outDegrees;

    /**
     * Reads the social regularization settings, initializes the social factor matrix and caches
     * each user's in/out degree in the social matrix.
     *
     * @throws LibrecException propagated from the parent setup
     */
    @Override
    public void setup() throws LibrecException {
        super.setup();
        regRateSocial = conf.getFloat("rec.rate.social.regularization", 0.01f);
        regUserSocial = conf.getFloat("rec.user.social.regularization", 0.01f);

        userSocialFactors = new DenseMatrix(numUsers, numFactors);
        userSocialFactors.init(initMean, initStd);

        inDegrees = new ArrayList<>(numUsers);
        outDegrees = new ArrayList<>(numUsers);
        for (int userIdx = 0; userIdx < numUsers; userIdx++) {
            inDegrees.add(socialMatrix.columnSize(userIdx));
            outDegrees.add(socialMatrix.rowSize(userIdx));
        }
    }

    /**
     * Runs batch gradient descent for up to {@code numIterations} iterations. Each iteration
     * accumulates gradients from the rating and social matrices, then updates all three factor
     * matrices from the same pre-update snapshot.
     *
     * @throws LibrecException propagated from prediction or convergence checking
     */
    @Override
    protected void trainModel() throws LibrecException {
        for (int iter = 1; iter <= numIterations; iter++) {
            loss = 0.0d;
            DenseMatrix userGradients = new DenseMatrix(numUsers, numFactors);
            DenseMatrix itemGradients = new DenseMatrix(numItems, numFactors);
            DenseMatrix userSocialGradients = new DenseMatrix(numUsers, numFactors);

            accumulateRatingGradients(userGradients, itemGradients);
            accumulateSocialGradients(userGradients, userSocialGradients);

            // Gradients were computed entirely from the old factors; apply them together.
            userFactors = userFactors.add(userGradients.scale(-learnRate));
            itemFactors = itemFactors.add(itemGradients.scale(-learnRate));
            userSocialFactors = userSocialFactors.add(userSocialGradients.scale(-learnRate));

            loss *= 0.5d;
            // isConverged is evaluated before earlyStop to preserve the original call sequence.
            if (isConverged(iter) && earlyStop) {
                return;
            }
            updateLRate(iter);
        }
    }

    /**
     * Adds the rating-reconstruction gradients (with L2 terms) for every training entry into the
     * given accumulators and adds the corresponding terms to {@code loss}.
     */
    private void accumulateRatingGradients(DenseMatrix userGradients, DenseMatrix itemGradients)
            throws LibrecException {
        for (MatrixEntry entry : trainMatrix) {
            int userIdx = entry.row();
            int itemIdx = entry.column();
            double rating = entry.get();

            double prediction = predict(userIdx, itemIdx);
            // Error between the squashed prediction and the rating normalized by minRate/maxRate.
            double error = Maths.logistic(prediction) - Maths.normalize(rating, minRate, maxRate);
            // Constant across factors, so compute it once per entry.
            double gradient = Maths.logisticGradientValue(prediction);
            loss += error * error;

            for (int factorIdx = 0; factorIdx < numFactors; factorIdx++) {
                double userValue = userFactors.get(userIdx, factorIdx);
                double itemValue = itemFactors.get(itemIdx, factorIdx);
                userGradients.add(userIdx, factorIdx, gradient * error * itemValue + regUser * userValue);
                itemGradients.add(itemIdx, factorIdx, gradient * error * userValue + regItem * itemValue);
                loss += regUser * userValue * userValue + regItem * itemValue * itemValue;
            }
        }
    }

    /**
     * Adds the social-reconstruction gradients for every positive social entry into the given
     * accumulators and adds the corresponding terms to {@code loss}.
     */
    private void accumulateSocialGradients(DenseMatrix userGradients, DenseMatrix userSocialGradients)
            throws LibrecException {
        for (MatrixEntry entry : socialMatrix) {
            double trust = entry.get();
            if (trust > 0.0d) {
                int userIdx = entry.row();
                int socialIdx = entry.column();

                double prediction = DenseMatrix.rowMult(userFactors, userIdx, userSocialFactors, socialIdx);
                // Degree-normalized target: sqrt(in(k) / (out(u) + in(k))) * t_uk.
                // (The decompiled code referenced an undefined local here; it is the trustee's in-degree.)
                int inDegree = inDegrees.get(socialIdx);
                int outDegree = outDegrees.get(userIdx);
                double weight = Math.sqrt(inDegree / (outDegree + inDegree + 0.0d));
                double error = Maths.logistic(prediction) - weight * trust;
                loss += regRateSocial * error * error;

                // Shared factor of both gradients; constant across factors.
                double scale = regRateSocial * Maths.logisticGradientValue(prediction) * error;
                for (int factorIdx = 0; factorIdx < numFactors; factorIdx++) {
                    double userValue = userFactors.get(userIdx, factorIdx);
                    double socialValue = userSocialFactors.get(socialIdx, factorIdx);
                    userGradients.add(userIdx, factorIdx, scale * socialValue);
                    userSocialGradients.add(socialIdx, factorIdx, scale * userValue + regUserSocial * socialValue);
                    loss += regUserSocial * socialValue * socialValue;
                }
            }
        }
    }
}
