package net.librec.recommender.context.rating;

import java.util.Iterator;
import net.librec.annotation.ModelData;
import net.librec.common.LibrecException;
import net.librec.math.algorithm.Maths;
import net.librec.math.structure.DenseMatrix;
import net.librec.math.structure.SparseVector;
import net.librec.math.structure.VectorEntry;
import net.librec.recommender.SocialRecommender;

/**
 * Recommender Social Trust Ensemble (RSTE).
 *
 * <p>The score of user u for item j blends the user's own latent score with the weighted average
 * latent score of the users linked from u in row u of {@code socialMatrix}:
 * {@code x = a * U_u.V_j + (1 - a) * sum_k(S_uk * U_k.V_j) / sum_k(S_uk)}, where {@code a} is
 * {@code rec.user.social.ratio}. When the row's weight sum is not positive the social term is 0.
 *
 * <p>Training minimises the squared error between {@code Maths.logistic(x)} and the rating mapped
 * by {@code Maths.normalize(r, minRate, maxRate)}, plus L2 terms on the factors. It uses
 * full-batch gradient descent: all gradients are accumulated first and then applied once per
 * iteration.
 */
@ModelData({"isRating", "rste", "userFactors", "itemFactors", "userSocialRatio", "socialMatrix"})
public class RSTERecommender extends SocialRecommender {
    /** Weight {@code a} of a user's own latent score; {@code 1 - a} goes to the social term. */
    private float userSocialRatio;

    /**
     * Runs the base social setup, then reads {@code rec.user.social.ratio} (default 0.8).
     *
     * @throws LibrecException propagated from {@code super.setup()}
     */
    @Override
    public void setup() throws LibrecException {
        super.setup();
        this.userSocialRatio = this.conf.getFloat("rec.user.social.ratio", Float.valueOf(0.8f)).floatValue();
    }

    /**
     * Runs up to {@code numIterations} rounds of full-batch gradient descent.
     *
     * <p>Each round computes the loss and all gradients against the current factors. It then
     * applies them in one step scaled by {@code -learnRate}.
     */
    @Override
    protected void trainModel() throws LibrecException {
        for (int iter = 1; iter <= numIterations; iter++) {
            loss = 0.0d;
            DenseMatrix userGradients = new DenseMatrix(numUsers, numFactors);
            DenseMatrix itemGradients = new DenseMatrix(numItems, numFactors);

            for (int userIdx = 0; userIdx < numUsers; userIdx++) {
                accumulateUserGradients(userIdx, userGradients, itemGradients);
            }

            userFactors = userFactors.add(userGradients.scale(-learnRate));
            itemFactors = itemFactors.add(itemGradients.scale(-learnRate));
            loss *= 0.5d;

            // isConverged(...) is evaluated before earlyStop, as in the original code, so it is
            // called every iteration (in case it records the loss history).
            if (isConverged(iter) && earlyStop) {
                return;
            }
            updateLRate(iter);
        }
    }

    /**
     * Adds the contribution of every training rating of {@code userIdx} to {@code loss} and to
     * the two gradient matrices.
     *
     * <p>With {@code delta = logisticGradientValue(x) * (logistic(x) - normalize(r))}, the rating
     * (u, j) contributes:
     * <ul>
     *   <li>to U_u: {@code a * delta * V_j + regUser * U_u};</li>
     *   <li>to V_j: {@code delta * (a * U_u + (1 - a) * meanU) + regItem * V_j}, where
     *       {@code meanU = sum_k(S_uk * U_k) / sum_k(S_uk)};</li>
     *   <li>to U_k for each user k in row u of the social matrix:
     *       {@code (1 - a) * delta * S_uk / sum_k(S_uk) * V_j}. This is the exact derivative of
     *       the social term of x with respect to U_k.</li>
     * </ul>
     * The regularisation terms are added once per rating, as the loss terms are.
     *
     * @param userIdx       inner index of the rating user
     * @param userGradients accumulator for user-factor gradients (numUsers x numFactors)
     * @param itemGradients accumulator for item-factor gradients (numItems x numFactors)
     */
    private void accumulateUserGradients(int userIdx, DenseMatrix userGradients, DenseMatrix itemGradients) {
        SparseVector socialRow = socialMatrix.row(userIdx);
        int[] neighbourIndices = socialRow.getIndex();

        // Cache the row weights (looked up once) and build the weighted neighbour factor vector.
        double[] neighbourWeights = new double[neighbourIndices.length];
        double weightSum = 0.0d;
        double[] neighbourFactorMean = new double[numFactors];
        for (int n = 0; n < neighbourIndices.length; n++) {
            double weight = socialRow.get(neighbourIndices[n]);
            neighbourWeights[n] = weight;
            weightSum += weight;
            for (int f = 0; f < numFactors; f++) {
                neighbourFactorMean[f] += weight * userFactors.get(neighbourIndices[n], f);
            }
        }
        // A non-positive weight sum disables the social term, matching predict().
        boolean hasSocial = weightSum > 0.0d;
        for (int f = 0; f < numFactors; f++) {
            neighbourFactorMean[f] = hasSocial ? neighbourFactorMean[f] / weightSum : 0.0d;
        }

        for (VectorEntry ratingEntry : trainMatrix.row(userIdx)) {
            int itemIdx = ratingEntry.index();
            double target = Maths.normalize(ratingEntry.get(), minRate, maxRate);

            double ownScore = DenseMatrix.rowMult(userFactors, userIdx, itemFactors, itemIdx);
            // sum_k S_uk (U_k . V_j) / sum_k S_uk  ==  meanU . V_j
            double neighbourScore = 0.0d;
            for (int f = 0; f < numFactors; f++) {
                neighbourScore += neighbourFactorMean[f] * itemFactors.get(itemIdx, f);
            }
            double score = blend(ownScore, neighbourScore);

            double error = Maths.logistic(score) - target;
            loss += error * error;
            double delta = Maths.logisticGradientValue(score) * error;

            for (int f = 0; f < numFactors; f++) {
                double userValue = userFactors.get(userIdx, f);
                double itemValue = itemFactors.get(itemIdx, f);
                userGradients.add(userIdx, f, userSocialRatio * delta * itemValue + regUser * userValue);
                itemGradients.add(itemIdx, f, delta * blend(userValue, neighbourFactorMean[f]) + regItem * itemValue);
                loss += regUser * userValue * userValue + regItem * itemValue * itemValue;
            }

            // Propagate the error through the social term to each linked user's factors,
            // weighted by that user's normalised social weight.
            if (hasSocial) {
                double socialDelta = (1.0f - userSocialRatio) * delta / weightSum;
                for (int n = 0; n < neighbourIndices.length; n++) {
                    double coefficient = socialDelta * neighbourWeights[n];
                    int neighbourIdx = neighbourIndices[n];
                    for (int f = 0; f < numFactors; f++) {
                        userGradients.add(neighbourIdx, f, coefficient * itemFactors.get(itemIdx, f));
                    }
                }
            }
        }
    }

    /** Returns {@code a * own + (1 - a) * social} with {@code a = userSocialRatio}. */
    private double blend(double own, double social) {
        return (userSocialRatio * own) + ((1.0f - userSocialRatio) * social);
    }

    /**
     * Computes the blended RSTE score for a user-item pair with the current factors.
     *
     * <p>NOTE(review): this returns the raw score {@code x}. Training fits {@code logistic(x)} to
     * the normalised rating, so confirm the caller maps the value back to the rating scale.
     *
     * @param userIdx inner user index
     * @param itemIdx inner item index
     * @return {@code a * U_u.V_j + (1 - a) * sum_k(S_uk * U_k.V_j) / sum_k(S_uk)}; the social
     *         part is 0 when the weight sum is not positive
     */
    @Override
    public double predict(int userIdx, int itemIdx) {
        double ownScore = DenseMatrix.rowMult(userFactors, userIdx, itemFactors, itemIdx);
        SparseVector socialRow = socialMatrix.row(userIdx);
        double weightedScoreSum = 0.0d;
        double weightSum = 0.0d;
        for (int neighbourIdx : socialRow.getIndex()) {
            double weight = socialRow.get(neighbourIdx);
            weightedScoreSum += weight * DenseMatrix.rowMult(userFactors, neighbourIdx, itemFactors, itemIdx);
            weightSum += weight;
        }
        return blend(ownScore, weightSum > 0.0d ? weightedScoreSum / weightSum : 0.0d);
    }
}
