package net.librec.recommender.cf.ranking;

import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import net.librec.annotation.ModelData;
import net.librec.common.LibrecException;
import net.librec.math.algorithm.Maths;
import net.librec.math.algorithm.Randoms;
import net.librec.math.structure.SparseMatrix;
import net.librec.recommender.MatrixFactorizationRecommender;

/**
 * Bayesian Personalized Ranking (BPR) recommender on top of matrix factorization.
 *
 * <p>Each iteration draws {@code numUsers * SAMPLES_PER_USER} triples (u, i, j), where i is an
 * item u has in the training matrix and j is an item u does not have. For each triple it applies
 * one stochastic gradient step that pushes the score of i above the score of j, with L2
 * regularization on the user and item factors.
 */
@ModelData({"isRanking", "bpr", "userFactors", "itemFactors"})
public class BPRRecommender extends MatrixFactorizationRecommender {
    /** Number of SGD triples drawn per user in every training iteration. */
    private static final int SAMPLES_PER_USER = 100;

    /** Per-user set of training items, used for fast "already observed?" checks. */
    private List<Set<Integer>> userItemsSet;

    /** Per-user list of training items, used for index-based positive sampling. */
    private List<List<Integer>> userItemsList;

    /**
     * Performs the parent class setup; BPR needs no additional configuration.
     *
     * @throws LibrecException if the parent setup fails
     */
    @Override
    public void setup() throws LibrecException {
        super.setup();
    }

    /**
     * Trains the user/item factors with BPR stochastic gradient descent.
     *
     * @throws LibrecException if no user can produce a training triple (every user has either no
     *     training items or all items), or if prediction / convergence checking fails
     */
    @Override
    protected void trainModel() throws LibrecException {
        // Built once: the original re-extracted trainMatrix columns for every single sample.
        userItemsList = getUserItemsList(trainMatrix);
        userItemsSet = getUserItemsSet(trainMatrix);

        // Without this guard, user rejection sampling below would loop forever.
        if (!hasSampleableUser()) {
            throw new LibrecException("BPR cannot sample training triples: every user has either "
                    + "no training items or all " + numItems + " items");
        }

        // long arithmetic: numUsers * SAMPLES_PER_USER overflows int for very large user counts.
        long numSamples = (long) numUsers * SAMPLES_PER_USER;
        for (int iter = 1; iter <= numIterations; iter++) {
            loss = 0.0d;
            for (long s = 0; s < numSamples; s++) {
                int userIdx = sampleUser();
                List<Integer> posItems = userItemsList.get(userIdx);
                int posItemIdx = posItems.get(Randoms.uniform(posItems.size()));
                int negItemIdx = sampleNegativeItem(userItemsSet.get(userIdx));
                updateFactors(userIdx, posItemIdx, negItemIdx);
            }
            // isConverged(iter) is evaluated before earlyStop, matching the original ordering,
            // in case it has side effects (e.g. tracking the previous loss) — verify in parent.
            if (isConverged(iter) && earlyStop) {
                return;
            }
            updateLRate(iter);
        }
    }

    /** Returns true if a user with these training items can yield both a positive and a negative. */
    private boolean isSampleable(Set<Integer> items) {
        return !items.isEmpty() && items.size() != numItems;
    }

    /** Returns true if at least one user is sampleable. */
    private boolean hasSampleableUser() {
        for (Set<Integer> items : userItemsSet) {
            if (isSampleable(items)) {
                return true;
            }
        }
        return false;
    }

    /** Draws users uniformly at random until one is sampleable (caller guarantees one exists). */
    private int sampleUser() {
        while (true) {
            int userIdx = Randoms.uniform(numUsers);
            if (isSampleable(userItemsSet.get(userIdx))) {
                return userIdx;
            }
        }
    }

    /** Draws items uniformly at random until one is not among the user's observed items. */
    private int sampleNegativeItem(Set<Integer> observedItems) {
        int itemIdx;
        do {
            itemIdx = Randoms.uniform(numItems);
        } while (observedItems.contains(itemIdx));
        return itemIdx;
    }

    /**
     * Applies one BPR SGD step for the triple (user, positive item, negative item) and adds the
     * triple's loss contribution to {@code loss}.
     */
    private void updateFactors(int userIdx, int posItemIdx, int negItemIdx) throws LibrecException {
        double diff = predict(userIdx, posItemIdx) - predict(userIdx, negItemIdx);
        loss += negLogSigmoid(diff);
        // sigma(-diff) == 1 - sigma(diff): the common gradient scale of the BPR objective.
        double gradScale = Maths.logistic(-diff);
        for (int f = 0; f < numFactors; f++) {
            // Read all three old values first so every update uses pre-step factors.
            double userValue = userFactors.get(userIdx, f);
            double posValue = itemFactors.get(posItemIdx, f);
            double negValue = itemFactors.get(negItemIdx, f);
            userFactors.add(userIdx, f,
                    learnRate * ((gradScale * (posValue - negValue)) - (regUser * userValue)));
            itemFactors.add(posItemIdx, f,
                    learnRate * ((gradScale * userValue) - (regItem * posValue)));
            itemFactors.add(negItemIdx, f,
                    learnRate * ((gradScale * (-userValue)) - (regItem * negValue)));
            loss += (regUser * userValue * userValue) + (regItem * posValue * posValue)
                    + (regItem * negValue * negValue);
        }
    }

    /**
     * Computes {@code -log(sigmoid(x)) = log(1 + exp(-x))} without overflow or underflow, so the
     * loss stays finite even when sigmoid(x) would round to 0.
     */
    private static double negLogSigmoid(double x) {
        return x >= 0 ? Math.log1p(Math.exp(-x)) : -x + Math.log1p(Math.exp(x));
    }

    /** Builds, for every user row, the list of column (item) indices present in the matrix. */
    private List<List<Integer>> getUserItemsList(SparseMatrix sparseMatrix) {
        List<List<Integer>> itemsPerUser = new ArrayList<>(numUsers);
        for (int u = 0; u < numUsers; u++) {
            itemsPerUser.add(sparseMatrix.getColumns(u));
        }
        return itemsPerUser;
    }

    /** Builds, for every user row, the set of column (item) indices present in the matrix. */
    private List<Set<Integer>> getUserItemsSet(SparseMatrix sparseMatrix) {
        List<Set<Integer>> itemsPerUser = new ArrayList<>(numUsers);
        for (int u = 0; u < numUsers; u++) {
            itemsPerUser.add(new HashSet<>(sparseMatrix.getColumns(u)));
        }
        return itemsPerUser;
    }
}
