package net.librec.recommender;

import com.google.common.collect.BiMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import net.librec.common.LibrecException;
import net.librec.conf.Configuration;
import net.librec.data.DataModel;
import net.librec.eval.Measure;
import net.librec.eval.RecommenderEvaluator;
import net.librec.math.structure.SparseTensor;
import net.librec.math.structure.TensorEntry;
import net.librec.recommender.item.RecommendedItem;
import net.librec.recommender.item.RecommendedItemList;
import net.librec.recommender.item.RecommendedList;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;

/* loaded from: input_file:net/librec/recommender/TensorRecommender.class */
/**
 * Base class for recommenders trained on a {@link SparseTensor}.
 *
 * <p>{@link #recommend(RecommenderContext)} drives the lifecycle:
 * {@link #setup()}, then {@link #trainModel()}, then {@link #recommend()}, then {@link #cleanup()}.
 * Subclasses supply {@link #trainModel()} and {@link #predict(int[])}.
 */
public abstract class TensorRecommender implements Recommender {
    /** Whether ranking output is requested ({@code rec.recommender.isranking}). */
    protected boolean isRanking;
    /** Ranked-list length ({@code rec.recommender.ranking.topn}); only read when {@link #isRanking} is true. */
    protected int topN;
    /** Configuration taken from the context in {@link #setup()}. */
    protected Configuration conf;
    /** Context supplied via {@link #recommend(RecommenderContext)} or {@link #setContext(RecommenderContext)}. */
    protected RecommenderContext context;
    /** Training data, cast from the data model's train data set. */
    protected SparseTensor trainTensor;
    /** Test data, cast from the data model's test data set. */
    protected SparseTensor testTensor;
    /** Validation data, cast from the data model's valid data set. */
    protected SparseTensor validTensor;
    /** Most recently produced recommendation list. */
    protected RecommendedList recommendedList;
    /** User mapping copied from the data model in {@link #setup()}. */
    public BiMap<String, Integer> userMappingData;
    /** Item mapping copied from the data model in {@link #setup()}. */
    public BiMap<String, Integer> itemMappingData;
    /** Value of {@code rec.recommender.earlyStop}. */
    protected boolean earlyStop;
    /** Number of dimensions of the training tensor. */
    protected int numDimensions;
    /** Size of each training-tensor dimension, as returned by {@code trainTensor.dimensions()}. */
    protected int[] dimensions;
    /** Value of {@code rec.factor.number} (default 10). */
    protected int numFactors;
    /** Current training loss; compared against {@link #lastLoss} by {@link #isConverged(int)}. */
    protected double loss;
    /** Enables the bold-driver branch of {@link #updateLRate(int)}; not assigned by this class. */
    protected boolean isBoldDriver;
    /** Multiplicative learning-rate decay used by {@link #updateLRate(int)}; not assigned by this class. */
    protected float decay;
    /** Value of {@code rec.iterator.learnrate} (default 0.01). */
    protected float learnRate;
    /** Cap for {@link #learnRate} ({@code rec.iterator.learnrate.maximum}, default 1000); ignored when &lt;= 0. */
    protected float maxLearnRate;
    /** Iteration count; not assigned by this class. */
    protected int numIterations;
    /** Index of the user axis in the training tensor. */
    protected int userDimension;
    /** Index of the item axis in the training tensor. */
    protected int itemDimension;
    /** Value of {@code rec.tensor.regularization} (default 0.01). */
    protected float reg;
    /** Size of the user dimension of the training tensor, filled in by {@link #setup()}. */
    protected int numUsers;
    /** Size of the item dimension of the training tensor, filled in by {@link #setup()}. */
    protected int numItems;
    /** Mean of all training values; used as fallback when a prediction is NaN. */
    protected double globalMean;
    protected final Log LOG = LogFactory.getLog(getClass());
    /** Loss recorded by the previous {@link #isConverged(int)} call. */
    protected double lastLoss = 0.0d;
    /** Value of {@code rec.recommender.verbose} (default true); enables per-iteration logging. */
    protected boolean verbose = true;
    /**
     * Largest training value; clamping upper bound for {@link #predict(int[], boolean)}.
     * Starts at the most negative finite double (not {@code Double.MIN_NORMAL}, which is positive).
     */
    protected double maxRate = -Double.MAX_VALUE;
    /** Smallest training value; clamping lower bound for {@link #predict(int[], boolean)}. */
    protected double minRate = Double.MAX_VALUE;

    /**
     * Reads hyper-parameters from the configuration and loads the train/test/valid tensors.
     * Also computes rating statistics and copies tensor shape and mapping data into fields.
     *
     * @throws LibrecException declared for subclasses; not thrown by this implementation
     */
    protected void setup() throws LibrecException {
        this.conf = this.context.getConf();
        this.isRanking = this.conf.getBoolean("rec.recommender.isranking");
        if (this.isRanking) {
            this.topN = this.conf.getInt("rec.recommender.ranking.topn", 5).intValue();
        }
        this.earlyStop = this.conf.getBoolean("rec.recommender.earlyStop");
        this.verbose = this.conf.getBoolean("rec.recommender.verbose", true);
        this.learnRate = this.conf.getFloat("rec.iterator.learnrate", Float.valueOf(0.01f)).floatValue();
        this.maxLearnRate = this.conf.getFloat("rec.iterator.learnrate.maximum", Float.valueOf(1000.0f)).floatValue();
        this.numFactors = this.conf.getInt("rec.factor.number", 10).intValue();
        this.reg = this.conf.getFloat("rec.tensor.regularization", Float.valueOf(0.01f)).floatValue();

        DataModel dataModel = getDataModel();
        this.trainTensor = (SparseTensor) dataModel.getTrainDataSet();
        this.testTensor = (SparseTensor) dataModel.getTestDataSet();
        this.validTensor = (SparseTensor) dataModel.getValidDataSet();

        computeRatingStatistics();

        this.numDimensions = this.trainTensor.numDimensions();
        this.dimensions = this.trainTensor.dimensions();
        this.userMappingData = dataModel.getUserMappingData();
        this.itemMappingData = dataModel.getItemMappingData();
        this.userDimension = this.trainTensor.getUserDimension();
        this.itemDimension = this.trainTensor.getItemDimension();
        // numUsers sizes the RecommendedItemList in recommendRating()/recommendRank(),
        // so it must reflect the tensor rather than stay at its default of 0.
        this.numUsers = sizeOfDimension(this.userDimension, this.numUsers);
        this.numItems = sizeOfDimension(this.itemDimension, this.numItems);
    }

    /**
     * Scans the training tensor once and sets {@link #maxRate}, {@link #minRate} and {@link #globalMean}.
     * Values are accumulated in locals so repeated {@link #setup()} calls do not inherit stale bounds.
     */
    private void computeRatingStatistics() {
        double sum = 0.0d;
        int count = 0;
        double observedMax = Double.NEGATIVE_INFINITY;
        double observedMin = Double.POSITIVE_INFINITY;
        Iterator<TensorEntry> it = this.trainTensor.iterator();
        while (it.hasNext()) {
            double value = it.next().get();
            observedMax = observedMax > value ? observedMax : value;
            observedMin = observedMin < value ? observedMin : value;
            sum += value;
            count++;
        }
        if (count == 0) {
            this.LOG.warn("Training tensor is empty; rating bounds are left unchanged and the global mean is NaN.");
            this.globalMean = Double.NaN;
            return;
        }
        this.maxRate = observedMax;
        this.minRate = observedMin;
        this.globalMean = sum / count;
    }

    /**
     * Returns the size of the given training-tensor dimension.
     *
     * @param dimension index into {@link #dimensions}
     * @param fallback  value returned when the index (or the dimensions array) is unusable
     */
    private int sizeOfDimension(int dimension, int fallback) {
        if (this.dimensions == null || dimension < 0 || dimension >= this.dimensions.length) {
            return fallback;
        }
        return this.dimensions[dimension];
    }

    /**
     * Runs the full pipeline against the given context: setup, training, recommendation, cleanup.
     * The result is stored in {@link #recommendedList}.
     */
    @Override // net.librec.recommender.Recommender
    public void recommend(RecommenderContext recommenderContext) throws LibrecException {
        this.context = recommenderContext;
        setup();
        this.LOG.info("Job Setup completed.");
        trainModel();
        this.LOG.info("Job Train completed.");
        this.recommendedList = recommend();
        this.LOG.info("Job End.");
        cleanup();
    }

    /** Trains the model on {@link #trainTensor}; implemented by subclasses. */
    protected abstract void trainModel() throws LibrecException;

    /**
     * Produces the recommendation list.
     * Ranking is used when {@link #isRanking} is set with a positive {@link #topN}; otherwise rating prediction.
     */
    protected RecommendedList recommend() throws LibrecException {
        if (!this.isRanking || this.topN <= 0) {
            this.recommendedList = recommendRating();
        } else {
            this.recommendedList = recommendRank();
        }
        return this.recommendedList;
    }

    /**
     * Ranking hook. This base implementation only allocates an empty list sized by {@link #numUsers};
     * subclasses supporting ranking are expected to override it.
     */
    protected RecommendedList recommendRank() throws LibrecException {
        this.recommendedList = new RecommendedItemList(this.numUsers - 1, this.numUsers);
        return this.recommendedList;
    }

    /**
     * Predicts a clamped value for every entry of {@link #testTensor}.
     * NaN predictions are replaced by {@link #globalMean}.
     */
    protected RecommendedList recommendRating() throws LibrecException {
        this.recommendedList = new RecommendedItemList(this.numUsers - 1, this.numUsers);
        Iterator<TensorEntry> it = this.testTensor.iterator();
        while (it.hasNext()) {
            TensorEntry entry = it.next();
            int[] keys = entry.keys();
            int userIdx = entry.key(this.userDimension);
            int itemIdx = entry.key(this.itemDimension);
            double prediction = predict(keys, true);
            if (Double.isNaN(prediction)) {
                prediction = this.globalMean;
            }
            this.recommendedList.addUserItemIdx(userIdx, itemIdx, prediction);
        }
        return this.recommendedList;
    }

    /** Raw model prediction for the entry at the given tensor keys; implemented by subclasses. */
    protected abstract double predict(int[] iArr) throws LibrecException;

    /**
     * Predicts via {@link #predict(int[])} and, when {@code z} is true, clamps to [{@link #minRate}, {@link #maxRate}].
     * NaN results are returned unchanged.
     */
    protected double predict(int[] iArr, boolean z) throws LibrecException {
        double predict = predict(iArr);
        if (z) {
            if (predict > this.maxRate) {
                predict = this.maxRate;
            } else if (predict < this.minRate) {
                predict = this.minRate;
            }
        }
        return predict;
    }

    /**
     * Logs progress (when verbose), validates {@link #loss}, and records it as {@link #lastLoss}.
     *
     * @param i current iteration number (used only for logging)
     * @return true when |loss| &lt; 1e-5. NOTE(review): this tests the absolute loss, not the logged
     *         delta_loss; confirm that is intended.
     * @throws LibrecException if the loss is NaN or infinite
     */
    protected boolean isConverged(int i) throws LibrecException {
        float f = (float) (this.lastLoss - this.loss);
        if (this.verbose) {
            this.LOG.info(getClass().getSimpleName() + " iter " + i + ": loss = " + this.loss + ", delta_loss = " + f);
        }
        if (Double.isNaN(this.loss) || Double.isInfinite(this.loss)) {
            throw new LibrecException("Loss = NaN or Infinity: current settings does not fit the recommender! Change the settings and try again!");
        }
        boolean z = Math.abs(this.loss) < 1.0E-5d;
        this.lastLoss = this.loss;
        return z;
    }

    /**
     * Adapts {@link #learnRate} after an iteration. A negative learning rate disables adaptation.
     *
     * <ul>
     *   <li>Bold driver (when {@link #isBoldDriver} is set and i &gt; 1): grow by 5% if |lastLoss| &gt; |loss|, else halve.</li>
     *   <li>Otherwise, multiply by {@link #decay} when it lies in (0, 1).</li>
     * </ul>
     * The result is capped at {@link #maxLearnRate} when that is positive.
     * NOTE(review): isBoldDriver and decay are not assigned in this class; presumably subclasses set them.
     */
    protected void updateLRate(int i) {
        if (this.learnRate < 0.0d) {
            return;
        }
        if (this.isBoldDriver && i > 1) {
            this.learnRate = Math.abs(this.lastLoss) > Math.abs(this.loss) ? this.learnRate * 1.05f : this.learnRate * 0.5f;
        } else if (this.decay > 0.0f && this.decay < 1.0f) {
            this.learnRate *= this.decay;
        }
        if (this.maxLearnRate <= 0.0f || this.learnRate <= this.maxLearnRate) {
            return;
        }
        this.learnRate = this.maxLearnRate;
    }

    /** Post-recommendation hook; no-op here. */
    protected void cleanup() throws LibrecException {
    }

    /** Not implemented for tensor recommenders: always returns 0.0. */
    @Override // net.librec.recommender.Recommender
    public double evaluate(RecommenderEvaluator recommenderEvaluator) throws LibrecException {
        return 0.0d;
    }

    /** Not implemented for tensor recommenders: always returns null. */
    @Override // net.librec.recommender.Recommender
    public Map<Measure.MeasureValue, Double> evaluateMap() throws LibrecException {
        return null;
    }

    /** Returns the data model of the current context. */
    @Override // net.librec.recommender.Recommender
    public DataModel getDataModel() {
        return this.context.getDataModel();
    }

    /** Not implemented: no-op. */
    @Override // net.librec.recommender.Recommender
    public void loadModel(String str) {
    }

    /** Not implemented: no-op. */
    @Override // net.librec.recommender.Recommender
    public void saveModel(String str) {
    }

    /** Not implemented: always returns null (see {@link #recommendedList} for the produced list). */
    @Override // net.librec.recommender.Recommender
    public List<RecommendedItem> getRecommendedList() {
        return null;
    }

    /** Sets the context used by {@link #setup()} and {@link #getDataModel()}. */
    @Override // net.librec.recommender.Recommender
    public void setContext(RecommenderContext recommenderContext) {
        this.context = recommenderContext;
    }
}
