/*
 * Decompiled with CFR 0.152.
 */
package edu.stanford.nlp.scoref;

import edu.stanford.nlp.io.IOUtils;
import edu.stanford.nlp.stats.ClassicCounter;
import edu.stanford.nlp.stats.Counter;
import edu.stanford.nlp.util.logging.Redwood;
import java.io.PrintWriter;
import java.util.Map;
import java.util.SortedMap;
import java.util.TreeMap;

/**
 * A linear classifier over sparse {@code String}-keyed features, trained online one example at a
 * time.
 *
 * <p>Each call to {@link #learn(Counter, double, double, Loss)} takes a gradient step on a
 * {@link Loss}, using a per-feature step size supplied by a {@link LearningRateSchedule}. An
 * L1-style shrinkage toward zero is applied lazily: a feature's weight is only regularized when that
 * feature is updated, by an amount proportional to the number of examples seen since its previous
 * update.
 *
 * <p>NOTE(review): instances hold mutable training state and use no synchronization; treat them as
 * not thread-safe.
 */
public class SimpleLinearClassifier {
    private final Loss defaultLoss;
    private final LearningRateSchedule learningRateSchedule;
    private final double regularizationStrength;
    /** Current weight of each feature. */
    private final Counter<String> weights;
    /** Value of {@link #examplesSeen} when each feature was last updated (drives lazy regularization). */
    private final Counter<String> accessTimes;
    /** Number of calls to {@link #learn(Counter, double, double, Loss)} on this instance. */
    private int examplesSeen;

    /**
     * Creates a classifier that starts with no weights.
     *
     * @param loss default loss used by {@link #learn(Counter, double, double)} and {@link #label}
     * @param learningRateSchedule per-feature step-size schedule
     * @param regularizationStrength scale of the lazily applied L1-style shrinkage
     */
    public SimpleLinearClassifier(Loss loss, LearningRateSchedule learningRateSchedule, double regularizationStrength) {
        this(loss, learningRateSchedule, regularizationStrength, null);
    }

    /**
     * Creates a classifier, optionally initialized from serialized weights.
     *
     * @param loss default loss used by {@link #learn(Counter, double, double)} and {@link #label}
     * @param learningRateSchedule per-feature step-size schedule
     * @param regularizationStrength scale of the lazily applied L1-style shrinkage
     * @param modelFile location handed to {@code IOUtils.readObjectFromURLOrClasspathOrFileSystem}
     *     (e.g. a file written by {@link #writeWeights}), or {@code null} to start with no weights
     * @throws RuntimeException wrapping the underlying failure if {@code modelFile} cannot be read
     */
    public SimpleLinearClassifier(Loss loss, LearningRateSchedule learningRateSchedule, double regularizationStrength, String modelFile) {
        this.weights = modelFile != null ? loadWeights(modelFile) : new ClassicCounter<String>();
        this.defaultLoss = loss;
        this.regularizationStrength = regularizationStrength;
        this.learningRateSchedule = learningRateSchedule;
        this.accessTimes = new ClassicCounter<String>();
        this.examplesSeen = 0;
    }

    /** Deserializes a weight counter, wrapping any failure in a {@link RuntimeException}. */
    @SuppressWarnings("unchecked") // the file is expected to hold a Counter<String> written by writeWeights
    private static Counter<String> loadWeights(String modelFile) {
        try {
            return (Counter<String>) IOUtils.readObjectFromURLOrClasspathOrFileSystem(modelFile);
        } catch (Exception e) {
            throw new RuntimeException("Error loading weights from " + modelFile, e);
        }
    }

    /**
     * Performs one training update using the default loss.
     *
     * @param features feature values of the example
     * @param label target label passed to {@link Loss#derivative}
     * @param weight importance weight scaling both the gradient and the regularization step
     */
    public void learn(Counter<String> features, double label, double weight) {
        this.learn(features, label, weight, this.defaultLoss);
    }

    /**
     * Performs one training update with an explicit loss.
     *
     * <p>Only features whose scaled gradient is non-zero are touched. For each such feature, the
     * accumulated regularization since its last update is applied first, then the gradient step.
     *
     * @param features feature values of the example
     * @param label target label passed to {@link Loss#derivative}
     * @param weight importance weight scaling both the gradient and the regularization step
     * @param loss loss whose derivative drives the update
     */
    public void learn(Counter<String> features, double label, double weight, Loss loss) {
        ++this.examplesSeen;
        // Evaluated once, with the weights as they were before this example's updates.
        double dloss = loss.derivative(label, this.weightFeatureProduct(features));
        for (Map.Entry<String, Double> feature : features.entrySet()) {
            double dfeature = weight * (-dloss * feature.getValue());
            if (dfeature == 0.0) {
                // No gradient: leave the weight alone; its pending regularization keeps accruing.
                continue;
            }
            String featureName = feature.getKey();
            this.learningRateSchedule.update(featureName, dfeature);
            double lr = this.learningRateSchedule.getLearningRate(featureName);
            double w = this.weights.getCount(featureName);
            // Lazy regularization: one penalty step for every example since the last update.
            double dreg = weight * this.regularizationStrength * ((double) this.examplesSeen - this.accessTimes.getCount(featureName));
            double afterReg = w - Math.signum(w) * dreg * lr;
            // Shrinkage may drive the weight to zero but never across it.
            double shrunk = Math.signum(afterReg) != Math.signum(w) ? 0.0 : afterReg;
            this.weights.setCount(featureName, shrunk + dfeature * lr);
            this.accessTimes.setCount(featureName, this.examplesSeen);
        }
    }

    /**
     * Scores an example with the default loss's {@link Loss#predict}.
     *
     * @param features feature values of the example
     * @return the default loss's prediction for the weight-feature dot product
     */
    public double label(Counter<String> features) {
        return this.defaultLoss.predict(this.weightFeatureProduct(features));
    }

    /**
     * Computes the dot product between {@code features} and the current weights.
     *
     * @param features feature values of the example
     * @return sum over features of value times the feature's current weight
     */
    public double weightFeatureProduct(Counter<String> features) {
        double product = 0.0;
        for (Map.Entry<String, Double> feature : features.entrySet()) {
            product += feature.getValue() * this.weights.getCount(feature.getKey());
        }
        return product;
    }

    /** Overwrites the weight of a single feature. */
    public void setWeight(String featureName, double weight) {
        this.weights.setCount(featureName, weight);
    }

    /**
     * Returns a snapshot of all weights. Entries are ordered by decreasing absolute weight, with
     * ties broken by feature name.
     *
     * <p>The ordering is computed from a copy of the magnitudes taken at call time. Later training
     * therefore cannot corrupt lookups on the returned map. {@link Double#compare} keeps the
     * comparator consistent even for NaN weights, so no entries are collapsed.
     *
     * @return a new sorted map from feature name to weight
     */
    public SortedMap<String, Double> getWeightVector() {
        final Counter<String> magnitudes = new ClassicCounter<String>();
        for (Map.Entry<String, Double> e : this.weights.entrySet()) {
            magnitudes.setCount(e.getKey(), Math.abs(e.getValue()));
        }
        TreeMap<String, Double> sorted = new TreeMap<String, Double>((f1, f2) -> {
            int byMagnitude = Double.compare(magnitudes.getCount(f2), magnitudes.getCount(f1));
            return byMagnitude != 0 ? byMagnitude : f1.compareTo(f2);
        });
        for (Map.Entry<String, Double> e : this.weights.entrySet()) {
            sorted.put(e.getKey(), e.getValue());
        }
        return sorted;
    }

    /** Logs the weight vector (see {@link #getWeightVector}) to the {@code scoref.train} Redwood channel. */
    public void printWeightVector() {
        this.printWeightVector(null);
    }

    /**
     * Prints one {@code "feature => weight"} line per feature, largest magnitude first.
     *
     * @param writer destination, or {@code null} to log via Redwood on channel {@code scoref.train}
     */
    public void printWeightVector(PrintWriter writer) {
        for (Map.Entry<String, Double> e : this.getWeightVector().entrySet()) {
            String line = e.getKey() + " => " + e.getValue();
            if (writer == null) {
                Redwood.log("scoref.train", line);
            } else {
                writer.println(line);
            }
        }
    }

    /**
     * Serializes the weight counter to {@code fname}.
     *
     * @throws Exception if writing fails
     */
    public void writeWeights(String fname) throws Exception {
        IOUtils.writeObjectToFile(this.weights, fname);
    }

    /**
     * Logistic loss.
     *
     * <p>{@code predict} returns {@code 1 - 1/(1 + e^p)}. {@code derivative} returns
     * {@code -y / (1 + e^(y*p))}.
     */
    public static Loss log() {
        return new Loss() {
            @Override
            public double predict(double product) {
                return 1.0 - 1.0 / (1.0 + Math.exp(product));
            }

            @Override
            public double derivative(double label, double product) {
                return -label / (1.0 + Math.exp(label * product));
            }

            @Override
            public String toString() {
                return "log";
            }
        };
    }

    /**
     * Quadratically smoothed hinge loss.
     *
     * <p>With margin {@code m = y*p}, the derivative is:
     * <ul>
     *   <li>0 when {@code m >= 1};</li>
     *   <li>{@code (m - 1) * y / gamma} when {@code 1 - gamma <= m < 1};</li>
     *   <li>{@code -y} otherwise.</li>
     * </ul>
     * {@code predict} is the raw score.
     *
     * @param gamma width of the quadratic region; 0 yields the plain hinge
     */
    public static Loss quadraticallySmoothedSVM(final double gamma) {
        return new Loss() {
            @Override
            public double predict(double product) {
                return product;
            }

            @Override
            public double derivative(double label, double product) {
                double mistake = label * product;
                if (mistake >= 1.0) {
                    return 0.0;
                }
                // With gamma == 0 this branch is unreachable (m >= 1 was handled above), so no 0/0.
                return mistake >= 1.0 - gamma ? (mistake - 1.0) * label / gamma : -label;
            }

            @Override
            public String toString() {
                return String.format("quadraticallySmoothed(%s)", gamma);
            }
        };
    }

    /** Hinge loss, i.e. {@link #quadraticallySmoothedSVM} with {@code gamma = 0}. */
    public static Loss hinge() {
        return SimpleLinearClassifier.quadraticallySmoothedSVM(0.0);
    }

    /**
     * Max-margin loss.
     *
     * <p>The derivative is 0 when {@code product < -h} and 1 otherwise. It ignores the label.
     * {@code predict} is unsupported.
     *
     * @param h margin threshold
     */
    public static Loss maxMargin(final double h) {
        return new Loss() {
            @Override
            public double predict(double product) {
                throw new UnsupportedOperationException("Predict not implemented for max margin");
            }

            @Override
            public double derivative(double label, double product) {
                return product < -h ? 0.0 : 1.0;
            }

            @Override
            public String toString() {
                return String.format("max-margin(%s)", h);
            }
        };
    }

    /**
     * "Risk" loss.
     *
     * <p>{@code predict} returns {@code 1/(1 + e^p)}. {@code derivative} is its derivative with
     * respect to {@code p}, namely {@code -e^p / (1 + e^p)^2}, and ignores the label.
     */
    public static Loss risk() {
        return new Loss() {
            @Override
            public double predict(double product) {
                return 1.0 / (1.0 + Math.exp(product));
            }

            @Override
            public double derivative(double label, double product) {
                // -e^p / (1 + e^p)^2 is symmetric in p, so evaluate it at -|p|. This is identical
                // for p <= 0. For large positive p it avoids Infinity / Infinity = NaN, which
                // would otherwise poison the weights in learn().
                double e = Math.exp(-Math.abs(product));
                return -e / Math.pow(1.0 + e, 2.0);
            }

            @Override
            public String toString() {
                return "risk";
            }
        };
    }

    /** A schedule that always returns {@code eta} and keeps no state. */
    public static LearningRateSchedule constant(final double eta) {
        return new LearningRateSchedule() {
            @Override
            public double getLearningRate(String feature) {
                return eta;
            }

            @Override
            public void update(String feature, double gradient) {
                // Stateless schedule: nothing to record.
            }

            @Override
            public String toString() {
                return String.format("constant(%s)", eta);
            }
        };
    }

    /** Per-feature rate {@code eta / (1 + n)^p}, where {@code n} counts that feature's updates. */
    public static LearningRateSchedule invScaling(final double eta, final double p) {
        return new CountBasedLearningRate() {
            @Override
            public double getCounterIncrement(double gradient) {
                return 1.0;
            }

            @Override
            public double getLearningRate(double count) {
                return eta / Math.pow(1.0 + count, p);
            }

            @Override
            public String toString() {
                return String.format("invScaling(%s, %s)", eta, p);
            }
        };
    }

    /**
     * AdaGrad: per-feature rate {@code eta / (tau + sqrt(G))}.
     *
     * <p>{@code G} is the sum of squared gradients passed to {@code update} for that feature.
     */
    public static LearningRateSchedule adaGrad(final double eta, final double tau) {
        return new CountBasedLearningRate() {
            @Override
            public double getCounterIncrement(double gradient) {
                return gradient * gradient;
            }

            @Override
            public double getLearningRate(double count) {
                return eta / (tau + Math.sqrt(count));
            }

            @Override
            public String toString() {
                return String.format("adaGrad(%s, %s)", eta, tau);
            }
        };
    }

    /**
     * Schedule whose rate is a function of a per-feature accumulator.
     *
     * <p>Each {@code update} adds {@link #getCounterIncrement} of the gradient to the accumulator.
     */
    private static abstract class CountBasedLearningRate implements LearningRateSchedule {
        private final Counter<String> counter = new ClassicCounter<String>();

        @Override
        public void update(String feature, double gradient) {
            this.counter.incrementCount(feature, this.getCounterIncrement(gradient));
        }

        @Override
        public double getLearningRate(String feature) {
            return this.getLearningRate(this.counter.getCount(feature));
        }

        /** Amount to add to the feature's accumulator for the given gradient. */
        public abstract double getCounterIncrement(double gradient);

        /** Learning rate as a function of the feature's accumulated value. */
        public abstract double getLearningRate(double count);
    }

    /** Supplies a per-feature learning rate, optionally adapting to observed gradients. */
    public static interface LearningRateSchedule {
        /** Records a (scaled, negated) gradient observed for {@code feature}. */
        public void update(String feature, double gradient);

        /** Returns the current step size for {@code feature}. */
        public double getLearningRate(String feature);
    }

    /** A loss over the raw score (weight-feature dot product). */
    public static interface Loss {
        /** Maps a raw score to a prediction. */
        public double predict(double product);

        /** Derivative of the loss with respect to the raw score, for the given label. */
        public double derivative(double label, double product);
    }
}

