/*
 * Decompiled with CFR 0.152.
 */
package edu.stanford.nlp.scoref;

import edu.stanford.nlp.scoref.CompressedFeatureVector;
import edu.stanford.nlp.scoref.Compressor;
import edu.stanford.nlp.scoref.Example;
import edu.stanford.nlp.scoref.MetaFeatureExtractor;
import edu.stanford.nlp.scoref.SimpleLinearClassifier;
import edu.stanford.nlp.scoref.StatisticalCorefTrainer;
import edu.stanford.nlp.scoref.StatisticalCorefUtils;
import edu.stanford.nlp.stats.Counter;
import java.io.File;
import java.io.IOException;
import java.io.PrintWriter;
import java.util.Map;

/**
 * Scores {@link Example}s with a {@link SimpleLinearClassifier} trained on features produced by a
 * {@link MetaFeatureExtractor}.
 *
 * <p>Instances are configured through {@link Builder}. The builder's field values are captured as
 * a configuration string, which {@link #writeModel(String)} saves alongside the learned weights.
 */
public class PairwiseModel {
    /** Model name; also the sub-directory used by {@link #getDefaultOutputPath()}. */
    public final String name;
    private final int trainingExamples;
    private final int epochs;
    protected final SimpleLinearClassifier classifier;
    /**
     * Weight multiplier for updates in {@link #learn(Example, Example, Map, Compressor, double)}
     * where only one of the two examples is present; 0.0 disables such updates entirely.
     */
    private final double singletonRatio;
    /** Configuration string from {@link StatisticalCorefUtils#fieldValues}; saved as "config". */
    private final String str;
    protected final MetaFeatureExtractor meta;

    /**
     * Creates a builder for a model with the given name and feature extractor.
     *
     * @param name model name
     * @param meta extractor used to turn examples into feature counters
     * @return a new {@link Builder} with default hyper-parameters
     */
    public static Builder newBuilder(String name, MetaFeatureExtractor meta) {
        return new Builder(name, meta);
    }

    /**
     * Builds a model from the builder's settings. If the builder has a model path, it is resolved
     * by {@link #resolveModelPath(String)} and handed to the classifier.
     *
     * @param builder configured builder
     */
    public PairwiseModel(Builder builder) {
        this.name = builder.name;
        this.meta = builder.meta;
        this.trainingExamples = builder.trainingExamples;
        this.epochs = builder.epochs;
        this.singletonRatio = builder.singletonRatio;
        this.classifier = new SimpleLinearClassifier(
                builder.loss,
                builder.learningRateSchedule,
                builder.regularizationStrength,
                resolveModelPath(builder.modelFile));
        this.str = StatisticalCorefUtils.fieldValues(builder);
    }

    /**
     * Maps a builder's {@code modelFile} setting to the path given to the classifier.
     *
     * <p>{@code null} is passed through unchanged (NOTE(review): presumably the classifier then
     * starts without pre-loaded weights; confirm in SimpleLinearClassifier). A value ending in
     * ".ser" or ".gz" is used verbatim. Any other value is treated as a model name, and the path
     * becomes {@code pairwiseModelsPath + modelFile + "/model.ser"}.
     */
    private static String resolveModelPath(String modelFile) {
        if (modelFile == null) {
            return null;
        }
        if (modelFile.endsWith(".ser") || modelFile.endsWith(".gz")) {
            return modelFile;
        }
        return StatisticalCorefTrainer.pairwiseModelsPath + modelFile + "/model.ser";
    }

    /**
     * Returns the directory this model is written to by {@link #writeModel()}.
     *
     * @return {@code StatisticalCorefTrainer.pairwiseModelsPath + name + "/"}
     */
    public String getDefaultOutputPath() {
        return StatisticalCorefTrainer.pairwiseModelsPath + this.name + "/";
    }

    /** @return the underlying linear classifier (shared, not a copy) */
    public SimpleLinearClassifier getClassifier() {
        return this.classifier;
    }

    /**
     * Writes the model to {@link #getDefaultOutputPath()}.
     *
     * @throws Exception if any output file cannot be written
     */
    public void writeModel() throws Exception {
        this.writeModel(this.getDefaultOutputPath());
    }

    /**
     * Writes the model into {@code outputPath}, creating the directory (and any missing parents)
     * if needed. Three files are produced inside that directory:
     * <ul>
     *   <li>{@code config}: the builder configuration string (UTF-8)</li>
     *   <li>{@code weights}: the output of {@code classifier.printWeightVector} (UTF-8)</li>
     *   <li>{@code model.ser}: written by {@code classifier.writeWeights}</li>
     * </ul>
     * {@code outputPath} may be given with or without a trailing separator.
     *
     * @param outputPath target directory
     * @throws IOException if the directory cannot be created or a text file fails to write
     * @throws Exception propagated from the classifier's own writers
     */
    public void writeModel(String outputPath) throws Exception {
        File outDir = new File(outputPath);
        // mkdirs (not mkdir) so missing parent directories are created too; fail loudly otherwise.
        if (!outDir.isDirectory() && !outDir.mkdirs()) {
            throw new IOException("Could not create output directory: " + outDir);
        }
        try (PrintWriter writer = new PrintWriter(new File(outDir, "config"), "UTF-8")) {
            writer.print(this.str);
            checkWriter(writer, "config");
        }
        try (PrintWriter writer = new PrintWriter(new File(outDir, "weights"), "UTF-8")) {
            this.classifier.printWeightVector(writer);
            checkWriter(writer, "weights");
        }
        this.classifier.writeWeights(new File(outDir, "model.ser").getPath());
    }

    /**
     * PrintWriter never throws on write failures; surface any recorded error as an IOException.
     * {@code checkError()} also flushes the writer.
     */
    private static void checkWriter(PrintWriter writer, String fileName) throws IOException {
        if (writer.checkError()) {
            throw new IOException("Error while writing model file: " + fileName);
        }
    }

    /** Maps an example's label to a classifier target: 1.0 stays 1.0, anything else becomes -1.0. */
    private static double classifierLabel(Example example) {
        return example.label == 1.0 ? 1.0 : -1.0;
    }

    /**
     * Performs one classifier update on a single example with weight 1.0.
     *
     * @param example training example; {@code label == 1.0} is positive, anything else negative
     * @param mentionFeatures per-mention feature vectors passed to the feature extractor
     * @param compressor feature-name compressor passed to the feature extractor
     */
    public void learn(Example example, Map<Integer, CompressedFeatureVector> mentionFeatures, Compressor<String> compressor) {
        Counter<String> features = this.meta.getFeatures(example, mentionFeatures, compressor);
        this.classifier.learn(features, classifierLabel(example), 1.0);
    }

    /**
     * Performs one weighted classifier update on a single example.
     *
     * @param example training example; {@code label == 1.0} is positive, anything else negative
     * @param mentionFeatures per-mention feature vectors passed to the feature extractor
     * @param compressor feature-name compressor passed to the feature extractor
     * @param weight weight of this update
     */
    public void learn(Example example, Map<Integer, CompressedFeatureVector> mentionFeatures, Compressor<String> compressor, double weight) {
        Counter<String> features = this.meta.getFeatures(example, mentionFeatures, compressor);
        this.classifier.learn(features, classifierLabel(example), weight);
    }

    /**
     * Performs a paired update: {@code correct} is learned as positive (+1) and {@code incorrect}
     * as negative (-1).
     *
     * <p>If both are non-null, each update uses {@code weight}. If only one is non-null, that one
     * is learned with {@code weight * singletonRatio}, or skipped when {@code singletonRatio} is
     * 0.0. If both are null, nothing is learned.
     *
     * @param correct example to push towards positive, may be null
     * @param incorrect example to push towards negative, may be null
     * @param mentionFeatures per-mention feature vectors passed to the feature extractor
     * @param compressor feature-name compressor passed to the feature extractor
     * @param weight base update weight
     */
    public void learn(Example correct, Example incorrect, Map<Integer, CompressedFeatureVector> mentionFeatures, Compressor<String> compressor, double weight) {
        // Features are extracted for every non-null example up front, even if the update is then
        // skipped below.
        Counter<String> cFeatures = correct == null ? null : this.meta.getFeatures(correct, mentionFeatures, compressor);
        Counter<String> iFeatures = incorrect == null ? null : this.meta.getFeatures(incorrect, mentionFeatures, compressor);

        boolean bothPresent = correct != null && incorrect != null;
        if (!bothPresent && this.singletonRatio == 0.0) {
            return;
        }
        double updateWeight = bothPresent ? weight : weight * this.singletonRatio;
        if (correct != null) {
            this.classifier.learn(cFeatures, 1.0, updateWeight);
        }
        if (incorrect != null) {
            this.classifier.learn(iFeatures, -1.0, updateWeight);
        }
    }

    /**
     * Scores an example.
     *
     * @param example example to score
     * @param mentionFeatures per-mention feature vectors passed to the feature extractor
     * @param compressor feature-name compressor passed to the feature extractor
     * @return the value of {@code classifier.label(features)} (NOTE(review): scale/meaning is
     *     defined by SimpleLinearClassifier and the configured loss)
     */
    public double predict(Example example, Map<Integer, CompressedFeatureVector> mentionFeatures, Compressor<String> compressor) {
        Counter<String> features = this.meta.getFeatures(example, mentionFeatures, compressor);
        return this.classifier.label(features);
    }

    /** @return the configured number of training examples (stored only; not used by this class) */
    public int getNumTrainingExamples() {
        return this.trainingExamples;
    }

    /** @return the configured number of training epochs (stored only; not used by this class) */
    public int getNumEpochs() {
        return this.epochs;
    }

    /**
     * Configuration for a {@link PairwiseModel}.
     *
     * <p>NOTE(review): the field set and order may matter. The whole builder is passed to
     * {@link StatisticalCorefUtils#fieldValues} to produce the saved config string, so
     * adding, removing or reordering fields may change that output.
     */
    public static class Builder {
        private final String name;
        private final MetaFeatureExtractor meta;
        // Not read in this file. NOTE(review): presumably kept so it is recorded in the config
        // string produced by StatisticalCorefUtils.fieldValues; confirm before removing.
        private final String source = StatisticalCorefTrainer.extractedFeaturesFile;
        private int trainingExamples = 100000000;
        private int epochs = 8;
        private SimpleLinearClassifier.Loss loss = SimpleLinearClassifier.log();
        private SimpleLinearClassifier.LearningRateSchedule learningRateSchedule = SimpleLinearClassifier.adaGrad(0.05, 30.0);
        private double regularizationStrength = 1.0E-7;
        private double singletonRatio = 0.3;
        // null, a ".ser"/".gz" file path, or a model name under pairwiseModelsPath.
        private String modelFile = null;

        /**
         * @param name model name
         * @param meta feature extractor for the model
         */
        public Builder(String name, MetaFeatureExtractor meta) {
            this.name = name;
            this.meta = meta;
        }

        /** Sets the stored training-example count (default 100000000). */
        public Builder trainingExamples(int trainingExamples) {
            this.trainingExamples = trainingExamples;
            return this;
        }

        /** Sets the stored epoch count (default 8). */
        public Builder epochs(int epochs) {
            this.epochs = epochs;
            return this;
        }

        /** Sets the weight multiplier for single-sided paired updates (default 0.3; 0.0 disables them). */
        public Builder singletonRatio(double singletonRatio) {
            this.singletonRatio = singletonRatio;
            return this;
        }

        /** Sets the classifier loss (default {@code SimpleLinearClassifier.log()}). */
        public Builder loss(SimpleLinearClassifier.Loss loss) {
            this.loss = loss;
            return this;
        }

        /** Sets the classifier regularization strength (default 1.0E-7). */
        public Builder regularizationStrength(double regularizationStrength) {
            this.regularizationStrength = regularizationStrength;
            return this;
        }

        /** Sets the learning-rate schedule (default {@code SimpleLinearClassifier.adaGrad(0.05, 30.0)}). */
        public Builder learningRateSchedule(SimpleLinearClassifier.LearningRateSchedule learningRateSchedule) {
            this.learningRateSchedule = learningRateSchedule;
            return this;
        }

        /**
         * Sets the model to load: a path ending in ".ser"/".gz" is used as-is; any other value is
         * a model name resolved to {@code pairwiseModelsPath + modelFile + "/model.ser"}.
         */
        public Builder modelPath(String modelFile) {
            this.modelFile = modelFile;
            return this;
        }

        /** @return a new {@link PairwiseModel} built from this configuration */
        public PairwiseModel build() {
            return new PairwiseModel(this);
        }
    }
}

