package net.librec.recommender.cf.ranking;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Set;
import net.librec.annotation.ModelData;
import net.librec.common.LibrecException;
import net.librec.math.algorithm.Maths;
import net.librec.math.structure.SparseMatrix;
import net.librec.recommender.MatrixFactorizationRecommender;

/**
 * CLiMF (Collaborative Less-is-More Filtering) ranking recommender (Shi et al., RecSys 2012).
 *
 * <p>For every user {@code i} with relevant (training) items {@code j, k} and scores
 * {@code f_ij = predict(i, j)}, training performs gradient ascent on
 *
 * <pre>
 *   F = sum_i sum_j [ ln g(f_ij) + sum_k ln(1 - g(f_ik - f_ij)) ]
 *       - regUser/2 * ||U||^2 - regItem/2 * ||V||^2
 * </pre>
 *
 * <p>Here {@code g} is the logistic function and {@code k} ranges over the user's relevant items.
 * The {@code loss} field holds the value of {@code F}. Every term of {@code F} is at most zero,
 * and the factor updates add {@code learnRate * gradient}, i.e. they ascend {@code F}.
 */
@ModelData({"isRanking", "climf", "userFactors", "itemFactors"})
public class CLIMFRecommender extends MatrixFactorizationRecommender {

    /** For each user index, the set of item indices that user has in the training matrix. */
    private List<Set<Integer>> userItemsSet;

    /**
     * Delegates to the parent setup. CLiMF needs no extra configuration.
     * Kept {@code public} so existing callers are unaffected.
     */
    @Override
    public void setup() throws LibrecException {
        super.setup();
    }

    /**
     * Trains the user and item factors for up to {@code numIterations} epochs.
     *
     * <p>Each epoch visits every user in turn and proceeds in three steps:
     * <ol>
     *   <li>Compute the gradients of the objective with respect to {@code U_i} and every relevant
     *       {@code V_j}, all from the current (not yet updated) factors.</li>
     *   <li>Apply those gradients.</li>
     *   <li>Accumulate the user's share of the objective into {@code loss}.</li>
     * </ol>
     * Training stops early when {@code earlyStop} is set and {@code isConverged} reports convergence.
     *
     * @throws LibrecException if prediction or the convergence check fails
     */
    @Override
    protected void trainModel() throws LibrecException {
        userItemsSet = getUserItemsSet(trainMatrix);
        for (int iter = 1; iter <= numIterations; iter++) {
            loss = 0.0d;
            for (int userIdx = 0; userIdx < numUsers; userIdx++) {
                int[] items = toIntArray(userItemsSet.get(userIdx));
                // The factors only change after both gradients are computed, so the scores are
                // loop-invariant; compute them once instead of inside every factor/pair loop.
                double[] scores = predictScores(userIdx, items);

                double[] userGradient = computeUserGradient(userIdx, items, scores);
                double[][] itemGradients = computeItemGradients(userIdx, items, scores);

                for (int factorIdx = 0; factorIdx < numFactors; factorIdx++) {
                    userFactors.add(userIdx, factorIdx, learnRate * userGradient[factorIdx]);
                }
                for (int pos = 0; pos < items.length; pos++) {
                    for (int factorIdx = 0; factorIdx < numFactors; factorIdx++) {
                        itemFactors.add(items[pos], factorIdx, learnRate * itemGradients[pos][factorIdx]);
                    }
                }

                loss += userObjective(userIdx, items);
            }
            // Item regularisation is counted once per epoch, not once per (user, item) pair.
            loss -= 0.5d * regItem * itemFactorNormSquared();

            // isConverged has side effects (logging, NaN check), so it must be evaluated first.
            if (isConverged(iter) && earlyStop) {
                break;
            }
            updateLRate(iter);
        }
    }

    /**
     * Computes dF/dU_i for one user:
     * {@code sum_j [ g(-f_ij) V_j + sum_{k != j} g(f_ik - f_ij) (V_j - V_k) ] - regUser * U_i}.
     *
     * <p>The paper's factor {@code g'(x) / (1 - g(x))} equals {@code g(x)}. The closed form is used
     * because the quotient divides by zero once {@code g(x)} rounds to 1.0.
     *
     * @param userIdx user index
     * @param items the user's relevant item indices
     * @param scores {@code scores[p] == predict(userIdx, items[p])} under the current factors
     * @return gradient of length {@code numFactors}
     */
    private double[] computeUserGradient(int userIdx, int[] items, double[] scores) {
        double[] gradient = new double[numFactors];
        for (int factorIdx = 0; factorIdx < numFactors; factorIdx++) {
            double sgd = -regUser * userFactors.get(userIdx, factorIdx);
            for (int j = 0; j < items.length; j++) {
                double itemValue = itemFactors.get(items[j], factorIdx);
                sgd += Maths.logistic(-scores[j]) * itemValue;
                for (int k = 0; k < items.length; k++) {
                    if (k != j) {
                        double compareValue = itemFactors.get(items[k], factorIdx);
                        sgd += Maths.logistic(scores[k] - scores[j]) * (itemValue - compareValue);
                    }
                }
            }
            gradient[factorIdx] = sgd;
        }
        return gradient;
    }

    /**
     * Computes dF/dV_j for each relevant item {@code j} of one user:
     * {@code [ g(-f_ij) + sum_{k != j} (g(d) - g(-d)) ] U_i - regItem * V_j}, with {@code d = f_ik - f_ij}.
     *
     * <p>{@code g(d) - g(-d)} is the simplified form of the paper's
     * {@code g'(-d) * (1/(1 - g(d)) - 1/(1 - g(-d)))}. Unlike the original expression, it cannot
     * divide by zero when the logistic saturates.
     *
     * @param userIdx user index
     * @param items the user's relevant item indices
     * @param scores {@code scores[p] == predict(userIdx, items[p])} under the current factors
     * @return {@code gradients[p]} is the gradient (length {@code numFactors}) for {@code items[p]}
     */
    private double[][] computeItemGradients(int userIdx, int[] items, double[] scores) {
        double[][] gradients = new double[items.length][numFactors];
        for (int j = 0; j < items.length; j++) {
            // The scalar multiplying U_i does not depend on the factor index, so compute it once.
            double coefficient = Maths.logistic(-scores[j]);
            for (int k = 0; k < items.length; k++) {
                if (k != j) {
                    double diff = scores[k] - scores[j];
                    coefficient += Maths.logistic(diff) - Maths.logistic(-diff);
                }
            }
            for (int factorIdx = 0; factorIdx < numFactors; factorIdx++) {
                gradients[j][factorIdx] = coefficient * userFactors.get(userIdx, factorIdx)
                        - regItem * itemFactors.get(items[j], factorIdx);
            }
        }
        return gradients;
    }

    /**
     * Evaluates one user's share of {@code F} with the freshly updated factors.
     * This is the ranking terms over the user's relevant items plus {@code -regUser/2 * ||U_i||^2}.
     */
    private double userObjective(int userIdx, int[] items) throws LibrecException {
        double[] scores = predictScores(userIdx, items);
        double objective = 0.0d;
        for (int j = 0; j < items.length; j++) {
            objective += logLogistic(scores[j]);
            // ln(1 - g(f_ik - f_ij)) == ln g(f_ij - f_ik). The k == j term (ln 0.5) is kept
            // because the sum over k in F runs over all relevant items.
            for (int k = 0; k < items.length; k++) {
                objective += logLogistic(scores[j] - scores[k]);
            }
        }
        double userNorm = 0.0d;
        for (int factorIdx = 0; factorIdx < numFactors; factorIdx++) {
            double value = userFactors.get(userIdx, factorIdx);
            userNorm += value * value;
        }
        return objective - 0.5d * regUser * userNorm;
    }

    /** Returns the squared Frobenius norm of the item-factor matrix. */
    private double itemFactorNormSquared() {
        double norm = 0.0d;
        for (int itemIdx = 0; itemIdx < numItems; itemIdx++) {
            for (int factorIdx = 0; factorIdx < numFactors; factorIdx++) {
                double value = itemFactors.get(itemIdx, factorIdx);
                norm += value * value;
            }
        }
        return norm;
    }

    /** Returns {@code predict(userIdx, items[p])} for every position {@code p}. */
    private double[] predictScores(int userIdx, int[] items) throws LibrecException {
        double[] scores = new double[items.length];
        for (int pos = 0; pos < items.length; pos++) {
            scores[pos] = predict(userIdx, items[pos]);
        }
        return scores;
    }

    /**
     * Computes {@code ln g(x) = -ln(1 + e^{-x})} without overflow or log(0).
     * The naive {@code Math.log(Maths.logistic(x))} reaches -Infinity once the logistic saturates.
     */
    private static double logLogistic(double x) {
        return x >= 0 ? -Math.log1p(Math.exp(-x)) : x - Math.log1p(Math.exp(x));
    }

    /** Copies a set of item indices into an array, in the set's iteration order. */
    private static int[] toIntArray(Set<Integer> items) {
        int[] result = new int[items.size()];
        int pos = 0;
        for (int item : items) {
            result[pos++] = item;
        }
        return result;
    }

    /**
     * Builds, for each user index {@code 0..numUsers-1}, the set of column (item) indices
     * returned by {@code sparseMatrix.getColumns(user)}.
     */
    private List<Set<Integer>> getUserItemsSet(SparseMatrix sparseMatrix) {
        List<Set<Integer>> userItems = new ArrayList<>(numUsers);
        for (int userIdx = 0; userIdx < numUsers; userIdx++) {
            userItems.add(new HashSet<Integer>(sparseMatrix.getColumns(userIdx)));
        }
        return userItems;
    }
}
