package net.librec.recommender.context.rating;

import com.google.common.cache.LoadingCache;
import com.google.common.collect.HashBasedTable;
import com.google.common.collect.Table;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;
import net.librec.annotation.ModelData;
import net.librec.common.LibrecException;
import net.librec.math.algorithm.Randoms;
import net.librec.math.structure.DenseMatrix;
import net.librec.math.structure.DenseVector;
import net.librec.math.structure.MatrixEntry;
import net.librec.math.structure.SparseMatrix;
import net.librec.recommender.cf.rating.BiasedMFRecommender;

/**
 * Time-aware rating recommender following the timeSVD++ formulation: a biased matrix
 * factorization with SVD++-style implicit feedback plus time-dependent biases and user factors.
 *
 * <p>The estimate for user {@code u}, item {@code i} on day {@code t} (whole days between the
 * rating timestamp and the earliest timestamp in the datetime matrix) is
 *
 * <pre>
 *   mu + (b_i + b_{i,bin(t)}) * (c_u + c_{u,t}) + b_u + alpha_u * dev_u(t) + b_{u,t}
 *      + |N(u)|^(-1/2) * sum_{j in N(u)} (y_j . q_i)
 *      + (p_u + alpha_{u,k} * dev_u(t) + p_{u,k,t}) . q_i
 * </pre>
 *
 * where {@code N(u)} is the list of items in user {@code u}'s training row. All parameters are
 * learned with plain SGD in {@link #trainModel()}.
 *
 * <p>NOTE(review): {@code @ModelData} lists {@code userFactors}/{@code itemFactors}, but this model
 * keeps its latent factors in {@link #P}/{@link #Q} and in the time-specific structures below —
 * confirm that model persistence covers them.
 */
@ModelData({"isRating", "timesvd", "userFactors", "itemFactors", "userBiases", "itemBiases", "trainMatrix", "timeMatrix"})
public class TimeSVDRecommender extends BiasedMFRecommender {
    /** Number of day offsets spanned by the timestamps: (last day - first day) + 1. */
    private int numDays;

    /** Mean rating day of each user; users with no training items get the global mean day. */
    private DenseVector userMeanDate;

    /**
     * Exponent of the time-deviation function {@link #dev(int, int)}; read from the
     * {@code rec.learnrate.decay} key (default 0.015).
     */
    private float beta;

    /** Number of time bins for the item-bias drift {@link #Bit} ({@code rec.numBins}, default 6). */
    private int numBins;

    /** Implicit-feedback item factors y_j (numItems x numFactors). */
    private DenseMatrix Y;

    /** Time-binned item-bias offsets b_{i,bin(t)} (numItems x numBins). */
    private DenseMatrix Bit;

    /** Day-specific user-bias offsets b_{u,t}, keyed (user, day); entries are created during training. */
    private Table<Integer, Integer, Double> But;

    /** Per-user drift coefficient alpha_u of the user bias. */
    private DenseVector Alpha;

    /** Per-user, per-factor drift coefficients alpha_{u,k} of the user factors (numUsers x numFactors). */
    private DenseMatrix Auk;

    /** Day-specific user-factor offsets: user -> table keyed (factor, day); entries are created during training. */
    private Map<Integer, Table<Integer, Integer, Double>> Pukt;

    /** Per-user scaling c_u applied to the item bias. */
    private DenseVector Cu;

    /** Day-specific per-user scaling c_{u,t} applied to the item bias (numUsers x numDays). */
    private DenseMatrix Cut;

    /**
     * Smallest value in {@link #timeMatrix}, treated as epoch milliseconds by {@link #days(long)}.
     * Instance (not static) state so separate recommender instances cannot overwrite each other.
     */
    private long minTimestamp;

    /** Largest value in {@link #timeMatrix}; instance state for the same isolation reason. */
    private long maxTimestamp;

    /** Guava cache spec used to build {@link #userItemsCache} ({@code guava.cache.spec}). */
    protected static String cacheSpec;

    /** Per-user item (column) indices of the training matrix, from {@code trainMatrix.rowColumnsCache}. */
    private LoadingCache<Integer, List<Integer>> userItemsCache;

    /**
     * Rating timestamps indexed like the rating matrix. Instance (not static) state so that
     * separate recommender instances do not share or overwrite it.
     */
    private SparseMatrix timeMatrix;

    /** Item latent factors q_i (numItems x numFactors). */
    protected DenseMatrix Q;

    /** User latent factors p_u (numUsers x numFactors). */
    protected DenseMatrix P;

    /**
     * Prepares the model for training.
     *
     * <p>Reads the hyper-parameters and loads the datetime matrix together with its time range.
     * It then allocates and initialises every parameter and precomputes each user's mean rating
     * day.
     *
     * @throws LibrecException propagated from {@code super.setup()}
     * @throws IllegalStateException if a user's training items cannot be loaded from the cache
     */
    @Override
    public void setup() throws LibrecException {
        super.setup();
        this.beta = this.conf.getFloat("rec.learnrate.decay", Float.valueOf(0.015f)).floatValue();
        this.numBins = this.conf.getInt("rec.numBins", 6).intValue();
        this.timeMatrix = (SparseMatrix) getDataModel().getDatetimeDataSet();
        getMaxAndMinTimeStamp();
        this.numDays = days(this.maxTimestamp, this.minTimestamp) + 1;

        // Keep this allocation/init order unchanged: init() presumably draws from a shared RNG,
        // so reordering would change results for a fixed seed.
        this.userBiases = new DenseVector(this.numUsers);
        this.userBiases.init();
        this.itemBiases = new DenseVector(this.numItems);
        this.itemBiases.init();
        this.Alpha = new DenseVector(this.numUsers);
        this.Alpha.init();
        this.Bit = new DenseMatrix(this.numItems, this.numBins);
        this.Bit.init();
        this.Y = new DenseMatrix(this.numItems, this.numFactors);
        this.Y.init();
        this.Auk = new DenseMatrix(this.numUsers, this.numFactors);
        this.Auk.init();
        this.But = HashBasedTable.create();
        this.Pukt = new HashMap<>();
        this.Cu = new DenseVector(this.numUsers);
        this.Cu.init();
        this.Cut = new DenseMatrix(this.numUsers, this.numDays);
        this.Cut.init();
        cacheSpec = this.conf.get("guava.cache.spec", "maximumSize=200,expireAfterAccess=2m");
        this.userItemsCache = this.trainMatrix.rowColumnsCache(cacheSpec);
        this.P = new DenseMatrix(this.numUsers, this.numFactors);
        this.Q = new DenseMatrix(this.numItems, this.numFactors);
        this.P.init();
        this.Q.init();

        this.userMeanDate = computeUserMeanDates();
    }

    /**
     * Computes each user's mean rating day over the items in their training row.
     *
     * <p>Users with no training items fall back to the mean day of all training entries with a
     * positive value. That fallback is 0 when there are no such entries.
     */
    private DenseVector computeUserMeanDates() {
        double daySum = 0.0d;
        int positiveCount = 0;
        Iterator<MatrixEntry> it = this.trainMatrix.iterator();
        while (it.hasNext()) {
            MatrixEntry entry = it.next();
            if (entry.get() > 0.0d) {
                daySum += dayIndex(entry.row(), entry.column());
                positiveCount++;
            }
        }
        // Guard against NaN when the training matrix has no positive entries.
        double globalMeanDay = positiveCount > 0 ? daySum / positiveCount : 0.0d;

        DenseVector meanDates = new DenseVector(this.numUsers);
        for (int userIdx = 0; userIdx < this.numUsers; userIdx++) {
            List<Integer> items = ratedItems(userIdx);
            if (items.isEmpty()) {
                meanDates.set(userIdx, globalMeanDay);
                continue;
            }
            double userDaySum = 0.0d;
            for (int itemIdx : items) {
                userDaySum += dayIndex(userIdx, itemIdx);
            }
            meanDates.set(userIdx, userDaySum / items.size());
        }
        return meanDates;
    }

    /**
     * Runs up to {@code numIterations} SGD epochs over the training matrix.
     *
     * <p>Each epoch accumulates half the regularised squared error into {@code loss} and stops
     * early once {@code isConverged} reports convergence.
     *
     * @throws LibrecException propagated from {@code isConverged}
     */
    @Override
    protected void trainModel() throws LibrecException {
        for (int iter = 1; iter <= this.numIterations; iter++) {
            this.loss = 0.0d;
            Iterator<MatrixEntry> it = this.trainMatrix.iterator();
            while (it.hasNext()) {
                MatrixEntry entry = it.next();
                sgdUpdate(entry.row(), entry.column(), entry.get());
            }
            this.loss *= 0.5d;
            if (isConverged(iter)) {
                return;
            }
        }
    }

    /**
     * Performs one SGD step for a single training rating and adds its (unhalved) loss to
     * {@code loss}.
     *
     * <p>All gradients use the parameter values read before any update in this step.
     */
    private void sgdUpdate(int u, int i, double rating) {
        int t = dayIndex(u, i);
        int bin = bin(t);
        double devUt = dev(u, t);

        double bi = this.itemBiases.get(i);
        double bit = this.Bit.get(i, bin);
        double bu = this.userBiases.get(u);
        double cu = this.Cu.get(u);
        double cut = this.Cut.get(u, t);
        // b_{u,t} is created lazily the first time (u, t) is seen.
        if (!this.But.contains(u, t)) {
            this.But.put(u, t, Randoms.random());
        }
        double but = this.But.get(u, t);
        double au = this.Alpha.get(u);

        double pred = this.globalMean + ((bi + bit) * (cu + cut)) + bu + (au * devUt) + but;

        // SVD++ implicit-feedback term: |N(u)|^-1/2 * sum_j y_j . q_i
        List<Integer> items = ratedItems(u);
        double implicitDot = 0.0d;
        for (int j : items) {
            implicitDot += DenseMatrix.rowMult(this.Y, j, this.Q, i);
        }
        double norm = items.isEmpty() ? 0.0d : Math.pow(items.size(), -0.5d);
        pred += implicitDot * norm;

        Table<Integer, Integer, Double> pukt = this.Pukt.get(u);
        if (pukt == null) {
            pukt = HashBasedTable.create();
            this.Pukt.put(u, pukt);
        }
        for (int k = 0; k < this.numFactors; k++) {
            // p_{u,k,t} is created lazily the first time (k, t) is seen for this user.
            if (!pukt.contains(k, t)) {
                pukt.put(k, t, Randoms.random());
            }
            pred += (this.P.get(u, k) + (this.Auk.get(u, k) * devUt) + pukt.get(k, t)) * this.Q.get(i, k);
        }

        double err = pred - rating;
        this.loss += err * err;

        // Bias-side parameters.
        this.itemBiases.add(i, (-this.learnRate) * ((err * (cu + cut)) + (this.regBias * bi)));
        this.loss += this.regBias * bi * bi;
        this.Bit.add(i, bin, (-this.learnRate) * ((err * (cu + cut)) + (this.regBias * bit)));
        this.loss += this.regBias * bit * bit;
        this.Cu.add(u, (-this.learnRate) * ((err * (bi + bit)) + (this.regBias * cu)));
        this.loss += this.regBias * cu * cu;
        this.Cut.add(u, t, (-this.learnRate) * ((err * (bi + bit)) + (this.regBias * cut)));
        this.loss += this.regBias * cut * cut;
        this.userBiases.add(u, (-this.learnRate) * (err + (this.regBias * bu)));
        this.loss += this.regBias * bu * bu;
        this.Alpha.add(u, (-this.learnRate) * ((err * devUt) + (this.regBias * au)));
        this.loss += this.regBias * au * au;
        this.But.put(u, t, but - (this.learnRate * (err + (this.regBias * but))));
        this.loss += this.regBias * but * but;

        // Factor-side parameters, one latent dimension at a time.
        for (int k = 0; k < this.numFactors; k++) {
            double qik = this.Q.get(i, k);
            double puk = this.P.get(u, k);
            double auk = this.Auk.get(u, k);
            double pukt_k = pukt.get(k, t);
            double userFactor = puk + (auk * devUt) + pukt_k;

            double ySum = 0.0d;
            for (int j : items) {
                ySum += this.Y.get(j, k);
            }

            this.Q.add(i, k, (-this.learnRate) * ((err * (userFactor + (norm * ySum))) + (this.regItem * qik)));
            this.loss += this.regItem * qik * qik;
            this.P.add(u, k, (-this.learnRate) * ((err * qik) + (this.regUser * puk)));
            this.loss += this.regUser * puk * puk;
            this.Auk.add(u, k, (-this.learnRate) * ((err * qik * devUt) + (this.regUser * auk)));
            this.loss += this.regUser * auk * auk;
            pukt.put(k, t, pukt_k - (this.learnRate * ((err * qik) + (this.regUser * pukt_k))));
            this.loss += this.regUser * pukt_k * pukt_k;

            for (int j : items) {
                double yjk = this.Y.get(j, k);
                this.Y.add(j, k, (-this.learnRate) * ((err * norm * qik) + (this.regItem * yjk)));
                this.loss += this.regItem * yjk * yjk;
            }
        }
    }

    /**
     * Predicts the rating of user {@code userIdx} for item {@code itemIdx} at the pair's
     * timestamp.
     *
     * <p>Day-specific terms ({@code b_{u,t}}, {@code p_{u,k,t}}) that were never trained
     * contribute 0. The day offset is clamped to the model's day range, so pairs with
     * out-of-range timestamps are indexed safely instead of overflowing {@code Cut} and the bins.
     *
     * @param userIdx user index
     * @param itemIdx item index
     * @return the predicted rating
     * @throws IllegalStateException if the user's training items cannot be loaded from the cache
     */
    @Override
    public double predict(int userIdx, int itemIdx) throws LibrecException {
        int t = dayIndex(userIdx, itemIdx);
        int bin = bin(t);
        double devUt = dev(userIdx, t);

        Double but = this.But.get(userIdx, t);
        double pred = this.globalMean
                + ((this.itemBiases.get(itemIdx) + this.Bit.get(itemIdx, bin)) * (this.Cu.get(userIdx) + this.Cut.get(userIdx, t)))
                + this.userBiases.get(userIdx)
                + (this.Alpha.get(userIdx) * devUt)
                + (but != null ? but : 0.0d);

        List<Integer> items = ratedItems(userIdx);
        double implicitDot = 0.0d;
        for (int j : items) {
            implicitDot += DenseMatrix.rowMult(this.Y, j, this.Q, itemIdx);
        }
        pred += implicitDot * (items.isEmpty() ? 0.0d : Math.pow(items.size(), -0.5d));

        Table<Integer, Integer, Double> pukt = this.Pukt.get(userIdx);
        for (int k = 0; k < this.numFactors; k++) {
            double userFactor = this.P.get(userIdx, k) + (this.Auk.get(userIdx, k) * devUt);
            if (pukt != null) {
                Double offset = pukt.get(k, t);
                if (offset != null) {
                    userFactor += offset;
                }
            }
            pred += userFactor * this.Q.get(itemIdx, k);
        }
        return pred;
    }

    /**
     * Time deviation of day {@code day} from user {@code userIdx}'s mean rating day:
     * {@code sign(d) * |d|^beta}.
     */
    private double dev(int userIdx, int day) {
        double delta = day - this.userMeanDate.get(userIdx);
        return Math.signum(delta) * Math.pow(Math.abs(delta), this.beta);
    }

    /** Maps a day offset in [0, numDays) to a bin index in [0, numBins). */
    private int bin(int day) {
        return (int) ((day / (this.numDays + 0.0d)) * this.numBins);
    }

    /** Converts a millisecond duration to whole days (truncated). */
    private static int days(long millis) {
        return (int) TimeUnit.MILLISECONDS.toDays(millis);
    }

    /** Whole days between two millisecond timestamps, regardless of their order. */
    private static int days(long millisA, long millisB) {
        return days(Math.abs(millisA - millisB));
    }

    /**
     * Day offset of the (userIdx, itemIdx) timestamp from {@link #minTimestamp}, clamped to
     * [0, numDays - 1].
     *
     * <p>The clamp keeps the offset a valid column of {@link #Cut} and a valid input to
     * {@link #bin(int)} for timestamps outside [min, max]. NOTE(review): this assumes
     * {@code SparseMatrix.get} returns 0 for pairs absent from the datetime matrix; such pairs
     * then map to the last day — confirm this is the desired fallback.
     */
    private int dayIndex(int userIdx, int itemIdx) {
        int day = days((long) this.timeMatrix.get(userIdx, itemIdx), this.minTimestamp);
        return Math.max(0, Math.min(day, this.numDays - 1));
    }

    /**
     * Returns the item indices in user {@code userIdx}'s training row from {@link #userItemsCache}.
     *
     * @throws IllegalStateException if the cache loader fails; the failure is kept as the cause
     */
    private List<Integer> ratedItems(int userIdx) {
        try {
            return this.userItemsCache.get(userIdx);
        } catch (ExecutionException e) {
            throw new IllegalStateException("Failed to load training items of user " + userIdx, e);
        }
    }

    /** Scans {@link #timeMatrix} and records its smallest and largest timestamps. */
    private void getMaxAndMinTimeStamp() {
        this.minTimestamp = Long.MAX_VALUE;
        this.maxTimestamp = Long.MIN_VALUE;
        Iterator<MatrixEntry> it = this.timeMatrix.iterator();
        while (it.hasNext()) {
            long timestamp = (long) it.next().get();
            if (timestamp < this.minTimestamp) {
                this.minTimestamp = timestamp;
            }
            if (timestamp > this.maxTimestamp) {
                this.maxTimestamp = timestamp;
            }
        }
    }
}
