package net.librec.recommender.cf;

import com.google.common.collect.HashBasedTable;
import com.google.common.collect.Table;
import java.util.Iterator;
import net.librec.common.LibrecException;
import net.librec.math.algorithm.Gamma;
import net.librec.math.algorithm.Randoms;
import net.librec.math.structure.DenseMatrix;
import net.librec.math.structure.DenseVector;
import net.librec.math.structure.MatrixEntry;
import net.librec.recommender.ProbabilisticGraphicalRecommender;

/* loaded from: input_file:net/librec/recommender/cf/BUCMRecommender.class */
/**
 * Latent-topic recommender in which every observed (user, item, rating) training entry is
 * assigned a latent topic. The model keeps three families of multinomials, each with a Dirichlet
 * prior:
 * <ul>
 *   <li>user &rarr; topic, prior {@code alpha} (one component per topic);</li>
 *   <li>topic &rarr; item, prior {@code beta} (one component per item);</li>
 *   <li>(topic, item) &rarr; rating level, prior {@code gamma} (one component per rating level).</li>
 * </ul>
 * Topic assignments are resampled by collapsed Gibbs sampling in {@link #eStep()}, the priors are
 * re-estimated in {@link #mStep()}, posterior means are accumulated across samples in
 * {@link #readoutParams()} and averaged in {@link #estimateParams()}.
 *
 * <p>NOTE(review): "BUCM" presumably refers to the Bayesian User Community Model of Barbieri et
 * al.; confirm against the project documentation.
 */
public class BUCMRecommender extends ProbabilisticGraphicalRecommender {
    /** Count of training entries per (topic, item, rating level) under the current assignment. */
    private int[][][] topicItemRatingNum;
    /** Count of training entries per (user, topic) under the current assignment. */
    private DenseMatrix userTopicNum;
    /** Count of training entries per user. */
    private DenseVector userNum;
    /** Count of training entries per (topic, item) under the current assignment. */
    private DenseMatrix topicItemNum;
    /** Count of training entries per topic under the current assignment. */
    private DenseVector topicNum;
    /** Running sum of P(rating level | topic, item) over all read-out samples. */
    private double[][][] topicItemRatingSumProbs;
    /** Averaged P(rating level | topic, item); filled by {@link #estimateParams()}. */
    private double[][][] topicItemRatingProbs;
    /** Averaged P(topic | user); filled by {@link #estimateParams()}. */
    private DenseMatrix userTopicProbs;
    /** Running sum of P(topic | user) over all read-out samples. */
    private DenseMatrix userTopicSumProbs;
    /** Averaged P(item | topic); filled by {@link #estimateParams()}. */
    private DenseMatrix topicItemProbs;
    /** Running sum of P(item | topic) over all read-out samples. */
    private DenseMatrix topicItemSumProbs;
    /** Dirichlet prior on the user-topic multinomials (one entry per topic). */
    private DenseVector alpha;
    /** Dirichlet prior on the topic-item multinomials (one entry per item). */
    private DenseVector beta;
    /** Dirichlet prior on the (topic, item)-rating multinomials (one entry per rating level). */
    private DenseVector gamma;
    /** Current topic assignment of every training entry, keyed by (user, item). */
    protected Table<Integer, Integer, Integer> topics;
    protected int numTopics;
    protected int numRatingLevels;

    /**
     * Reads hyperparameters, allocates the count and accumulator structures and gives every
     * training entry a uniformly random initial topic.
     *
     * <p>Configuration keys: {@code rec.pgm.topic.number} (default 10), {@code rec.bucm.alpha}
     * (default 1/numTopics), {@code rec.bucm.beta} (default 1/numItems; the misspelled legacy key
     * {@code re.bucm.beta} is still honoured when the correct key is absent) and
     * {@code rec.bucm.gamma} (default 1/numTopics).
     *
     * @throws LibrecException if the superclass setup fails
     */
    @Override
    public void setup() throws LibrecException {
        super.setup();
        numTopics = conf.getInt("rec.pgm.topic.number", 10).intValue();
        numRatingLevels = trainMatrix.getValueSet().size();

        userTopicSumProbs = new DenseMatrix(numUsers, numTopics);
        topicItemSumProbs = new DenseMatrix(numTopics, numItems);
        topicItemRatingSumProbs = new double[numTopics][numItems][numRatingLevels];

        userTopicNum = new DenseMatrix(numUsers, numTopics);
        userNum = new DenseVector(numUsers);
        topicItemNum = new DenseMatrix(numTopics, numItems);
        topicNum = new DenseVector(numTopics);
        topicItemRatingNum = new int[numTopics][numItems][numRatingLevels];

        double initAlpha = conf.getDouble("rec.bucm.alpha", 1.0d / numTopics).doubleValue();
        alpha = new DenseVector(numTopics);
        alpha.setAll(initAlpha);

        // "re.bucm.beta" was the historical (misspelled) key; it is kept as a fallback so that
        // existing configurations keep working, but "rec.bucm.beta" takes precedence.
        double initBeta = conf.getDouble("rec.bucm.beta",
                conf.getDouble("re.bucm.beta", 1.0d / numItems)).doubleValue();
        beta = new DenseVector(numItems);
        beta.setAll(initBeta);

        // NOTE(review): gamma has one entry per rating level, yet its default is 1/numTopics
        // rather than 1/numRatingLevels; kept as-is — confirm this is intended.
        double initGamma = conf.getDouble("rec.bucm.gamma", 1.0d / numTopics).doubleValue();
        gamma = new DenseVector(numRatingLevels);
        gamma.setAll(initGamma);

        topics = HashBasedTable.create();
        for (MatrixEntry entry : trainMatrix) {
            int user = entry.row();
            int item = entry.column();
            int ratingLevel = ratingScale.indexOf(entry.get());
            int topic = (int) (Randoms.uniform() * numTopics);
            topics.put(user, item, topic);
            updateCounts(user, item, ratingLevel, topic, 1);
        }
    }

    /**
     * One sweep of collapsed Gibbs sampling: each training entry's topic is removed from the
     * counts, a new topic is drawn proportionally to
     * P(topic | user) * P(item | topic) * P(rating | topic, item) computed from the remaining
     * counts plus the priors, and the counts are updated with the new topic.
     */
    @Override
    protected void eStep() {
        double sumAlpha = alpha.sum();
        double sumBeta = beta.sum();
        double sumGamma = gamma.sum();
        // Every slot is overwritten for each entry, so one buffer serves the whole sweep.
        double[] cumulative = new double[numTopics];

        for (MatrixEntry entry : trainMatrix) {
            int user = entry.row();
            int item = entry.column();
            int ratingLevel = ratingScale.indexOf(entry.get());

            int oldTopic = topics.get(user, item);
            updateCounts(user, item, ratingLevel, oldTopic, -1);

            for (int k = 0; k < numTopics; k++) {
                double pTopic = (userTopicNum.get(user, k) + alpha.get(k))
                        / (userNum.get(user) + sumAlpha);
                double pItem = (topicItemNum.get(k, item) + beta.get(item))
                        / (topicNum.get(k) + sumBeta);
                double pRating = (topicItemRatingNum[k][item][ratingLevel] + gamma.get(ratingLevel))
                        / (topicItemNum.get(k, item) + sumGamma);
                cumulative[k] = pTopic * pItem * pRating;
            }
            for (int k = 1; k < numTopics; k++) {
                cumulative[k] += cumulative[k - 1];
            }

            // The weights are unnormalised, so scale the uniform draw by the total mass.
            double threshold = Randoms.uniform() * cumulative[numTopics - 1];
            int newTopic = 0;
            // Stop at the last topic: if the total mass is 0 (underflow) or the draw hits the
            // upper bound, the search would otherwise run past the end of the topic range.
            while (newTopic < numTopics - 1 && threshold >= cumulative[newTopic]) {
                newTopic++;
            }

            topics.put(user, item, newTopic);
            updateCounts(user, item, ratingLevel, newTopic, 1);
        }
    }

    /**
     * Adds {@code delta} to every count touched by one training entry assigned to {@code topic}.
     */
    private void updateCounts(int user, int item, int ratingLevel, int topic, int delta) {
        userTopicNum.add(user, topic, delta);
        userNum.add(user, delta);
        topicItemNum.add(topic, item, delta);
        topicNum.add(topic, delta);
        topicItemRatingNum[topic][item][ratingLevel] += delta;
    }

    /**
     * Re-estimates the Dirichlet priors alpha, beta and gamma with the fixed-point update
     * x_j &larr; x_j * &Sigma;[&psi;(n_j + x_j) - &psi;(x_j)] / &Sigma;[&psi;(N + &Sigma;x) - &psi;(&Sigma;x)],
     * where &psi; is the digamma function. A component is left unchanged when its numerator is
     * exactly zero. The denominators do not depend on the component being updated, so each is
     * computed once per prior.
     */
    @Override
    protected void mStep() {
        double sumAlpha = alpha.sum();
        double sumBeta = beta.sum();
        double sumGamma = gamma.sum();

        double alphaDenominator = 0.0d;
        for (int u = 0; u < numUsers; u++) {
            alphaDenominator += Gamma.digamma(userNum.get(u) + sumAlpha) - Gamma.digamma(sumAlpha);
        }
        for (int k = 0; k < numTopics; k++) {
            double current = alpha.get(k);
            double numerator = 0.0d;
            for (int u = 0; u < numUsers; u++) {
                numerator += Gamma.digamma(userTopicNum.get(u, k) + current) - Gamma.digamma(current);
            }
            if (numerator != 0.0d) {
                alpha.set(k, current * (numerator / alphaDenominator));
            }
        }

        double betaDenominator = 0.0d;
        for (int k = 0; k < numTopics; k++) {
            betaDenominator += Gamma.digamma(topicNum.get(k) + sumBeta) - Gamma.digamma(sumBeta);
        }
        for (int j = 0; j < numItems; j++) {
            double current = beta.get(j);
            double numerator = 0.0d;
            for (int k = 0; k < numTopics; k++) {
                numerator += Gamma.digamma(topicItemNum.get(k, j) + current) - Gamma.digamma(current);
            }
            if (numerator != 0.0d) {
                beta.set(j, current * (numerator / betaDenominator));
            }
        }

        double gammaDenominator = 0.0d;
        for (int j = 0; j < numItems; j++) {
            for (int k = 0; k < numTopics; k++) {
                gammaDenominator += Gamma.digamma(topicItemNum.get(k, j) + sumGamma) - Gamma.digamma(sumGamma);
            }
        }
        for (int r = 0; r < numRatingLevels; r++) {
            double current = gamma.get(r);
            double numerator = 0.0d;
            for (int j = 0; j < numItems; j++) {
                for (int k = 0; k < numTopics; k++) {
                    numerator += Gamma.digamma(topicItemRatingNum[k][j][r] + current) - Gamma.digamma(current);
                }
            }
            if (numerator != 0.0d) {
                gamma.set(r, current * (numerator / gammaDenominator));
            }
        }
    }

    /**
     * Refreshes the averaged parameters and computes the mean negative log-likelihood of the
     * training ratings.
     *
     * @param iter the current iteration (not used by this implementation)
     * @return {@code true} once more than one sample has been read out and the loss has increased
     *     since the previous call; otherwise the loss is recorded and {@code false} is returned
     */
    @Override
    public boolean isConverged(int iter) {
        estimateParams();
        double loss = 0.0d;
        int count = 0;
        for (MatrixEntry entry : trainMatrix) {
            int ratingLevel = ratingScale.indexOf(entry.get());
            loss += -Math.log(jointProbability(entry.row(), entry.column(), ratingLevel));
            count++;
        }
        loss /= count;

        double delta = loss - lastLoss;
        if (numStats > 1 && delta > 0.0d) {
            return true;
        }
        lastLoss = loss;
        return false;
    }

    /**
     * Adds the posterior-mean estimates implied by the current counts and priors to the running
     * sums and increments {@code numStats}.
     */
    @Override
    protected void readoutParams() {
        double sumAlpha = alpha.sum();
        double sumBeta = beta.sum();
        double sumGamma = gamma.sum();

        for (int u = 0; u < numUsers; u++) {
            for (int k = 0; k < numTopics; k++) {
                userTopicSumProbs.add(u, k,
                        (userTopicNum.get(u, k) + alpha.get(k)) / (userNum.get(u) + sumAlpha));
            }
        }
        for (int k = 0; k < numTopics; k++) {
            for (int j = 0; j < numItems; j++) {
                topicItemSumProbs.add(k, j,
                        (topicItemNum.get(k, j) + beta.get(j)) / (topicNum.get(k) + sumBeta));
            }
        }
        for (int k = 0; k < numTopics; k++) {
            for (int j = 0; j < numItems; j++) {
                double[] sums = topicItemRatingSumProbs[k][j];
                for (int r = 0; r < numRatingLevels; r++) {
                    sums[r] += (topicItemRatingNum[k][j][r] + gamma.get(r))
                            / (topicItemNum.get(k, j) + sumGamma);
                }
            }
        }
        numStats++;
    }

    /** Averages the accumulated sums over the {@code numStats} read-out samples. */
    @Override
    protected void estimateParams() {
        userTopicProbs = userTopicSumProbs.scale(1.0d / numStats);
        topicItemProbs = topicItemSumProbs.scale(1.0d / numStats);
        topicItemRatingProbs = new double[numTopics][numItems][numRatingLevels];
        for (int k = 0; k < numTopics; k++) {
            for (int j = 0; j < numItems; j++) {
                for (int r = 0; r < numRatingLevels; r++) {
                    topicItemRatingProbs[k][j][r] = topicItemRatingSumProbs[k][j][r] / numStats;
                }
            }
        }
    }

    /**
     * Returns the negative log-probability of {@code user} giving {@code rating} to {@code item}
     * under the averaged parameters.
     *
     * @throws Exception declared for compatibility with overriding subclasses
     */
    protected double perplexity(int user, int item, double rating) throws Exception {
        // Use the same lookup as the rest of the class. The former (rating / minRate) - 1 formula
        // is only right when the scale is exactly {minRate, 2*minRate, ...} with no gaps.
        int ratingLevel = ratingScale.indexOf(rating);
        return -Math.log(jointProbability(user, item, ratingLevel));
    }

    /**
     * Returns &Sigma;_k P(topic k | user) * P(item | k) * P(ratingLevel | k, item) under the
     * averaged parameters.
     */
    private double jointProbability(int user, int item, int ratingLevel) {
        double probability = 0.0d;
        for (int k = 0; k < numTopics; k++) {
            probability += userTopicProbs.get(user, k) * topicItemProbs.get(k, item)
                    * topicItemRatingProbs[k][item][ratingLevel];
        }
        return probability;
    }

    /** Dispatches to {@link #predictRanking} or {@link #predictRating} depending on the task. */
    @Override
    protected double predict(int i, int i2) throws LibrecException {
        return isRanking ? predictRanking(i, i2) : predictRating(i, i2);
    }

    /**
     * Predicts a rating as the probability-weighted mean of the rating scale. Falls back to
     * {@code globalMean} when the model assigns no probability mass to the pair, instead of
     * returning NaN.
     */
    protected double predictRating(int user, int item) {
        double weighted = 0.0d;
        double total = 0.0d;
        for (int r = 0; r < numRatingLevels; r++) {
            double probability = jointProbability(user, item, r);
            weighted += probability * ratingScale.get(r).doubleValue();
            total += probability;
        }
        return total > 0.0d ? weighted / total : globalMean;
    }

    /**
     * Scores an item as the probability that the user selects it and rates it above the global
     * mean: &Sigma;_k P(k | user) * P(item | k) * &Sigma;_{r &gt; globalMean} P(r | k, item).
     */
    protected double predictRanking(int user, int item) {
        double score = 0.0d;
        for (int k = 0; k < numTopics; k++) {
            double relevantMass = 0.0d;
            for (int r = 0; r < numRatingLevels; r++) {
                if (ratingScale.get(r).doubleValue() > globalMean) {
                    relevantMass += topicItemRatingProbs[k][item][r];
                }
            }
            score += userTopicProbs.get(user, k) * topicItemProbs.get(k, item) * relevantMass;
        }
        return score;
    }
}
