package net.librec.recommender.cf.ranking;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import net.librec.annotation.ModelData;
import net.librec.common.LibrecException;
import net.librec.math.algorithm.Randoms;
import net.librec.math.structure.MatrixEntry;
import net.librec.math.structure.SparseMatrix;
import net.librec.recommender.MatrixFactorizationRecommender;
import net.librec.util.Lists;

/**
 * RankSGD: a pairwise ranking recommender trained with stochastic gradient descent.
 *
 * <p>For every observed (user, item, rating) entry of the training matrix, a "negative" item is
 * drawn with probability proportional to its number of training ratings, redrawing while it is an
 * item the user has already rated. The squared error between the predicted score difference
 * (positive minus negative) and the rating difference (the negative item counts as rating 0) is
 * reduced by a gradient step on the user and both item factor vectors.
 */
@ModelData({"isRanking", "ranksgd", "userFactors", "itemFactors", "trainMatrix"})
public class RankSGDRecommender extends MatrixFactorizationRecommender {
    /**
     * Items with at least one training rating, paired with their sampling probability (the share
     * of all {@code numRates} training ratings they received), in the order produced by
     * {@link Lists#sortMap(Map)}.
     */
    protected List<Map.Entry<Integer, Double>> itemProbs;

    /**
     * Runs the superclass setup, then builds {@link #itemProbs} from the per-item column sizes of
     * the training matrix.
     *
     * @throws LibrecException propagated from {@code super.setup()}
     */
    @Override
    public void setup() throws LibrecException {
        super.setup();
        Map<Integer, Double> probs = new HashMap<>();
        for (int item = 0; item < numItems; item++) {
            double prob = (double) trainMatrix.columnSize(item) / numRates;
            // Items nobody rated get probability 0 and are never sampled as negatives.
            if (prob > 0.0d) {
                probs.put(item, prob);
            }
        }
        itemProbs = Lists.sortMap(probs);
    }

    /**
     * Trains the user/item factors for up to {@code numIterations} passes over the training
     * matrix, stopping early when {@code earlyStop} is set and {@code isConverged} reports
     * convergence.
     *
     * @throws LibrecException propagated from {@code predict} or the convergence helpers
     */
    @Override
    protected void trainModel() throws LibrecException {
        List<Set<Integer>> userItemsSet = getUserItemsSet(trainMatrix);

        // Flatten the item distribution into parallel arrays once, so each negative draw is a
        // binary search instead of a linear scan. The cumulative sums are accumulated in the same
        // order as a sequential scan of itemProbs, so a given random value maps to the same item.
        int numCandidates = itemProbs.size();
        int[] candidateItems = new int[numCandidates];
        double[] cumulativeProbs = new double[numCandidates];
        double runningSum = 0.0d;
        int idx = 0;
        for (Map.Entry<Integer, Double> entry : itemProbs) {
            runningSum += entry.getValue();
            candidateItems[idx] = entry.getKey();
            cumulativeProbs[idx] = runningSum;
            idx++;
        }

        // A user who has rated every candidate item has no negative to draw; the rejection
        // sampling loop below would never terminate for them, so such users are skipped.
        boolean[] hasNegative = new boolean[numUsers];
        for (int user = 0; user < numUsers; user++) {
            hasNegative[user] = hasNegativeCandidate(userItemsSet.get(user), candidateItems);
        }

        for (int iter = 1; iter <= numIterations; iter++) {
            loss = 0.0d;
            Iterator<MatrixEntry> entries = trainMatrix.iterator();
            while (entries.hasNext()) {
                MatrixEntry matrixEntry = entries.next();
                int user = matrixEntry.row();
                int posItem = matrixEntry.column();
                double posRate = matrixEntry.get();
                if (!hasNegative[user]) {
                    continue;
                }

                Set<Integer> userItems = userItemsSet.get(user);
                int negItem;
                do {
                    negItem = sampleItem(candidateItems, cumulativeProbs, Randoms.random());
                } while (userItems.contains(negItem));

                double negRate = 0.0d;
                double err = (predict(user, posItem) - predict(user, negItem)) - (posRate - negRate);
                loss += err * err;

                double step = learnRate * err;
                for (int f = 0; f < numFactors; f++) {
                    // Read all three factors before updating any of them for this dimension.
                    double puf = userFactors.get(user, f);
                    double qif = itemFactors.get(posItem, f);
                    double qjf = itemFactors.get(negItem, f);
                    userFactors.add(user, f, -step * (qif - qjf));
                    itemFactors.add(posItem, f, -step * puf);
                    itemFactors.add(negItem, f, step * puf);
                }
            }
            loss *= 0.5d;

            // isConverged is evaluated before earlyStop (as originally) so any bookkeeping it
            // performs still happens every iteration.
            if (isConverged(iter) && earlyStop) {
                break;
            }
            updateLRate(iter);
        }
    }

    /**
     * Returns the first item whose cumulative probability is {@code >= r}. If floating-point
     * rounding leaves the total just below {@code r}, returns the last item instead of failing.
     * Requires a non-empty, non-decreasing {@code cumulative} array of the same length as
     * {@code items}.
     */
    private static int sampleItem(int[] items, double[] cumulative, double r) {
        int lo = 0;
        int hi = cumulative.length - 1;
        while (lo < hi) {
            int mid = (lo + hi) >>> 1;
            if (cumulative[mid] >= r) {
                hi = mid;
            } else {
                lo = mid + 1;
            }
        }
        return items[lo];
    }

    /** Returns whether at least one entry of {@code items} is absent from {@code userItems}. */
    private static boolean hasNegativeCandidate(Set<Integer> userItems, int[] items) {
        // Fast path: candidate ids come from map keys (distinct), so if there are more
        // candidates than rated items, at least one candidate must be unrated.
        if (userItems.size() < items.length) {
            return true;
        }
        for (int item : items) {
            if (!userItems.contains(item)) {
                return true;
            }
        }
        return false;
    }

    /**
     * Builds, for each user row {@code 0..numUsers-1} of {@code sparseMatrix}, the set of column
     * (item) indices returned by {@code getColumns} for that row.
     */
    private List<Set<Integer>> getUserItemsSet(SparseMatrix sparseMatrix) {
        List<Set<Integer>> userItems = new ArrayList<>(numUsers);
        for (int user = 0; user < numUsers; user++) {
            userItems.add(new HashSet<>(sparseMatrix.getColumns(user)));
        }
        return userItems;
    }
}
