package net.librec.recommender.context.ranking;

import com.google.common.cache.LoadingCache;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Set;
import java.util.concurrent.ExecutionException;
import net.librec.annotation.ModelData;
import net.librec.common.LibrecException;
import net.librec.math.algorithm.Maths;
import net.librec.math.algorithm.Randoms;
import net.librec.math.structure.DenseMatrix;
import net.librec.math.structure.DenseVector;
import net.librec.math.structure.VectorEntry;
import net.librec.recommender.SocialRecommender;

/**
 * Social Bayesian Personalized Ranking recommender (presumably SBPR as described by Zhao et al.,
 * CIKM 2014 — confirm against the reference implementation).
 *
 * <p>For every sampled user {@code u} the model assumes the ordering: an item {@code i} that
 * {@code u} consumed ranks above an item {@code k} that only {@code u}'s social neighbours
 * consumed, which ranks above an item {@code j} consumed by neither. The {@code i > k}
 * comparison is damped by {@code 1 + s_uk}, where {@code s_uk} counts the neighbours of
 * {@code u} that have a positive entry for {@code k} in the training matrix. Users that have no
 * such social items are trained with plain BPR on the pair {@code (i, j)}.
 *
 * <p>Scores are {@code b_i + p_u · q_i}.
 */
@ModelData({"isRanking", "sbpr", "userFactors", "itemFactors", "itemBiases"})
public class SBPRRecommender extends SocialRecommender {
    /** Per-item bias added to every prediction. */
    private DenseVector itemBiases;
    /** L2 weight for item biases, read from "rec.bias.regularization" (default 0.01). */
    protected float regBias;
    /** Cache from a user index to that user's item indices in the training matrix. */
    protected LoadingCache<Integer, List<Integer>> userItemsCache;
    /**
     * Guava cache spec for {@link #userItemsCache}. NOTE(review): static but assigned from
     * {@link #setup()}, so concurrent instances overwrite each other's value.
     */
    protected static String cacheSpec;
    /** Per user: items consumed by social neighbours but not by the user, in first-seen order. */
    private List<List<Integer>> userSocialItemsSetList;

    /**
     * Reads hyper-parameters, initialises item biases and the user-item cache, and precomputes
     * each user's social-only item list.
     *
     * @throws LibrecException if the parent setup fails
     * @throws IllegalStateException if the user-item cache cannot load a user's items
     */
    @Override
    public void setup() throws LibrecException {
        super.setup();
        this.regBias = this.conf.getFloat("rec.bias.regularization", Float.valueOf(0.01f)).floatValue();
        cacheSpec = this.conf.get("guava.cache.spec", "maximumSize=5000,expireAfterAccess=50m");
        this.itemBiases = new DenseVector(this.numItems);
        this.itemBiases.init();
        this.userItemsCache = this.trainMatrix.rowColumnsCache(cacheSpec);

        this.userSocialItemsSetList = new ArrayList<>(this.numUsers);
        for (int userIdx = 0; userIdx < this.numUsers; userIdx++) {
            this.userSocialItemsSetList.add(buildSocialItems(userIdx));
        }
    }

    /**
     * Collects the items that {@code userIdx}'s social neighbours consumed but the user did not.
     * Users without training items get an empty list.
     */
    private List<Integer> buildSocialItems(int userIdx) {
        List<Integer> ownItems = getUserItems(userIdx);
        if (ownItems.isEmpty()) {
            return new ArrayList<>();
        }
        Set<Integer> ownItemSet = new HashSet<>(ownItems);
        // LinkedHashSet de-duplicates in O(1) while keeping first-seen order, so the resulting
        // list has the same element order a list-based de-duplication would produce.
        Set<Integer> socialItems = new LinkedHashSet<>();
        for (int friendIdx : this.socialMatrix.getColumns(userIdx)) {
            for (int itemIdx : getUserItems(friendIdx)) {
                if (!ownItemSet.contains(itemIdx)) {
                    socialItems.add(itemIdx);
                }
            }
        }
        return new ArrayList<>(socialItems);
    }

    /**
     * Fetches a user's training items from the cache, surfacing load failures instead of
     * continuing with a null list.
     */
    private List<Integer> getUserItems(int userIdx) {
        try {
            return this.userItemsCache.get(userIdx);
        } catch (ExecutionException e) {
            throw new IllegalStateException("Failed to load training items for user " + userIdx, e);
        }
    }

    /**
     * Runs SGD over {@code numUsers * 100} sampled triples per iteration, accumulating the
     * regularised objective into {@code this.loss}.
     *
     * <p>NOTE(review): the user and negative-item sampling loops do not terminate if no user
     * has training items, or if a user's own plus social items cover every item.
     */
    @Override
    protected void trainModel() throws LibrecException {
        for (int iter = 1; iter <= this.numIterations; iter++) {
            this.loss = 0.0d;
            int numSamples = this.numUsers * 100;
            for (int sample = 0; sample < numSamples; sample++) {
                // Sample a user that has at least one training item.
                int userIdx;
                List<Integer> userItems;
                do {
                    userIdx = Randoms.uniform(this.trainMatrix.numRows());
                    userItems = getUserItems(userIdx);
                } while (userItems.isEmpty());

                int posItemIdx = (Integer) Randoms.random(userItems);
                double posPredict = predict(userIdx, posItemIdx);
                List<Integer> socialItems = this.userSocialItemsSetList.get(userIdx);

                // Negative item: consumed neither by the user nor by any social neighbour.
                int negItemIdx;
                do {
                    negItemIdx = Randoms.uniform(this.numItems);
                } while (userItems.contains(negItemIdx) || socialItems.contains(negItemIdx));
                double negPredict = predict(userIdx, negItemIdx);

                if (!socialItems.isEmpty()) {
                    int socialItemIdx = (Integer) Randoms.random(socialItems);
                    updateWithSocialItem(userIdx, posItemIdx, socialItemIdx, negItemIdx, posPredict, negPredict);
                } else {
                    updateBpr(userIdx, posItemIdx, negItemIdx, posPredict, negPredict);
                }
            }
            if (isConverged(iter) && this.earlyStop) {
                return;
            }
            updateLRate(iter);
        }
    }

    /**
     * Performs one SGD step on the ordering pos > social > neg, damping the first comparison
     * by one plus the number of neighbours who rated the social item.
     */
    private void updateWithSocialItem(int userIdx, int posItemIdx, int socialItemIdx, int negItemIdx,
                                      double posPredict, double negPredict) throws LibrecException {
        double socialPredict = predict(userIdx, socialItemIdx);

        // Count neighbours of the user with a positive training entry for the social item.
        double socialCount = 0.0d;
        Iterator<VectorEntry> neighbours = this.socialMatrix.row(userIdx).iterator();
        while (neighbours.hasNext()) {
            int friendIdx = neighbours.next().index();
            if (friendIdx < this.trainMatrix.numRows() && this.trainMatrix.get(friendIdx, socialItemIdx) > 0.0d) {
                socialCount += 1.0d;
            }
        }
        double scale = 1.0d + socialCount;

        double posSocialDiff = (posPredict - socialPredict) / scale;
        double socialNegDiff = socialPredict - negPredict;
        this.loss += -Math.log(Maths.logistic(posSocialDiff)) - Math.log(Maths.logistic(socialNegDiff));
        double posSocialGrad = Maths.logistic(-posSocialDiff);
        double socialNegGrad = Maths.logistic(-socialNegDiff);

        double posBias = this.itemBiases.get(posItemIdx);
        this.itemBiases.add(posItemIdx, this.learnRate * ((posSocialGrad / scale) - (this.regBias * posBias)));
        this.loss += this.regBias * posBias * posBias;

        double socialBias = this.itemBiases.get(socialItemIdx);
        this.itemBiases.add(socialItemIdx,
                this.learnRate * ((((-posSocialGrad) / scale) + socialNegGrad) - (this.regBias * socialBias)));
        this.loss += this.regBias * socialBias * socialBias;

        double negBias = this.itemBiases.get(negItemIdx);
        this.itemBiases.add(negItemIdx, this.learnRate * ((-socialNegGrad) - (this.regBias * negBias)));
        this.loss += this.regBias * negBias * negBias;

        for (int f = 0; f < this.numFactors; f++) {
            // Read all factors before any update so every gradient uses the same snapshot.
            double userFactor = this.userFactors.get(userIdx, f);
            double posFactor = this.itemFactors.get(posItemIdx, f);
            double socialFactor = this.itemFactors.get(socialItemIdx, f);
            double negFactor = this.itemFactors.get(negItemIdx, f);

            this.userFactors.add(userIdx, f, this.learnRate * ((((posSocialGrad * (posFactor - socialFactor)) / scale)
                    + (socialNegGrad * (socialFactor - negFactor))) - (this.regUser * userFactor)));
            this.itemFactors.add(posItemIdx, f,
                    this.learnRate * (((posSocialGrad * userFactor) / scale) - (this.regItem * posFactor)));
            this.itemFactors.add(socialItemIdx, f, this.learnRate * (((posSocialGrad * ((-userFactor) / scale))
                    + (socialNegGrad * userFactor)) - (this.regItem * socialFactor)));
            this.itemFactors.add(negItemIdx, f,
                    this.learnRate * ((socialNegGrad * (-userFactor)) - (this.regItem * negFactor)));

            this.loss += (this.regUser * userFactor * userFactor) + (this.regItem * posFactor * posFactor)
                    + (this.regItem * negFactor * negFactor) + (this.regItem * socialFactor * socialFactor);
        }
    }

    /** Performs one plain BPR SGD step on the ordering pos > neg (user has no social items). */
    private void updateBpr(int userIdx, int posItemIdx, int negItemIdx, double posPredict, double negPredict) {
        double posNegDiff = posPredict - negPredict;
        // BPR negative log-likelihood term, matching the social branch; previously the raw
        // score difference was added here, which corrupted the loss used for convergence.
        this.loss += -Math.log(Maths.logistic(posNegDiff));
        double grad = Maths.logistic(-posNegDiff);

        double posBias = this.itemBiases.get(posItemIdx);
        this.itemBiases.add(posItemIdx, this.learnRate * (grad - (this.regBias * posBias)));
        this.loss += this.regBias * posBias * posBias;

        double negBias = this.itemBiases.get(negItemIdx);
        this.itemBiases.add(negItemIdx, this.learnRate * ((-grad) - (this.regBias * negBias)));
        this.loss += this.regBias * negBias * negBias;

        for (int f = 0; f < this.numFactors; f++) {
            double userFactor = this.userFactors.get(userIdx, f);
            double posFactor = this.itemFactors.get(posItemIdx, f);
            double negFactor = this.itemFactors.get(negItemIdx, f);

            this.userFactors.add(userIdx, f, this.learnRate * ((grad * (posFactor - negFactor)) - (this.regUser * userFactor)));
            this.itemFactors.add(posItemIdx, f, this.learnRate * ((grad * userFactor) - (this.regItem * posFactor)));
            this.itemFactors.add(negItemIdx, f, this.learnRate * ((grad * (-userFactor)) - (this.regItem * negFactor)));

            this.loss += (this.regUser * userFactor * userFactor) + (this.regItem * posFactor * posFactor)
                    + (this.regItem * negFactor * negFactor);
        }
    }

    /**
     * Scores an item for a user as the item bias plus the user/item latent-factor dot product.
     *
     * @param userIdx row index into {@code userFactors}
     * @param itemIdx item index into {@code itemBiases} and {@code itemFactors}
     * @return the predicted ranking score
     */
    @Override
    public double predict(int userIdx, int itemIdx) throws LibrecException {
        return this.itemBiases.get(itemIdx) + DenseMatrix.rowMult(this.userFactors, userIdx, this.itemFactors, itemIdx);
    }
}
