package weka.classifiers.functions;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Enumeration;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.Random;
import java.util.Vector;
import weka.classifiers.RandomizableClassifier;
import weka.classifiers.UpdateableBatchProcessor;
import weka.classifiers.UpdateableClassifier;
import weka.classifiers.lazy.kstar.KStarConstants;
import weka.core.Aggregateable;
import weka.core.Attribute;
import weka.core.Capabilities;
import weka.core.DenseInstance;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Option;
import weka.core.OptionHandler;
import weka.core.RevisionUtils;
import weka.core.SelectedTag;
import weka.core.Tag;
import weka.core.TestInstances;
import weka.core.Utils;
import weka.core.WeightedInstancesHandler;
import weka.core.json.JSONInstances;
import weka.core.stemmers.NullStemmer;
import weka.core.stemmers.Stemmer;
import weka.core.stopwords.Null;
import weka.core.stopwords.StopwordsHandler;
import weka.core.tokenizers.Tokenizer;
import weka.core.tokenizers.WordTokenizer;
import weka.gui.knowledgeflow.KnowledgeFlowApp;

/* loaded from: input_file:weka/classifiers/functions/SGDText.class */
public class SGDText extends RandomizableClassifier implements UpdateableClassifier, UpdateableBatchProcessor, WeightedInstancesHandler, Aggregateable<SGDText> {
    /** Serialization version identifier. */
    private static final long serialVersionUID = 7200171484002029584L;
    /** The model: maps each token to its accumulated count and learned weight. */
    protected LinkedHashMap<String, Count> m_dictionary;
    /** Whether tokens are converted to lowercase before stemming. */
    protected boolean m_lowercaseTokens;
    /** Update counter; set to 1 by reset() and incremented after each processed instance. */
    protected double m_t;
    /** Bias term of the linear model. */
    protected double m_bias;
    /** Number of instances passed to buildClassifier; 0 makes the weight decay use m_t instead. */
    protected double m_numInstances;
    /** Copy of the training data created with capacity 0; toString() reads the class attribute name from it. */
    protected Instances m_data;
    /** Token counts of the most recently tokenized instance (not serialized). */
    protected transient LinkedHashMap<String, Count> m_inputVector;
    /** Loss function ID: hinge loss. */
    public static final int HINGE = 0;
    /** Loss function ID: log loss. */
    public static final int LOGLOSS = 1;
    /** Selectable loss functions. */
    public static final Tag[] TAGS_SELECTION = {new Tag(0, "Hinge loss (SVM)"), new Tag(1, "Log loss (logistic regression)")};
    /** SGD model fitted to the SVM outputs when probability output for SVMs is enabled. */
    protected SGD m_svmProbs;
    /** Dataset structure (prediction, class) used for instances handed to m_svmProbs. */
    protected Instances m_fitLogisticStructure;
    /** Periodic pruning interval in instances; values <= 0 disable periodic pruning. */
    protected int m_periodicP = 0;
    /** Minimum accumulated count a dictionary word needs in order to be used or kept. */
    protected double m_minWordP = 3.0d;
    /** Minimum absolute weight a dictionary word needs in order to be used or kept. */
    protected double m_minAbsCoefficient = 0.001d;
    /** Use word frequencies rather than a binary bag of words. */
    protected boolean m_wordFrequencies = false;
    /** Whether document length is normalized in dotProd(). */
    protected boolean m_normalize = false;
    /** Target norm of a document after normalization. */
    protected double m_norm = 1.0d;
    /** The L-norm used to measure document length. */
    protected double m_lnorm = 2.0d;
    /** Stopwords handler; never null when set through setStopwordsHandler(). */
    protected StopwordsHandler m_StopwordsHandler = new Null();
    /** Tokenizer applied to string attribute values. */
    protected Tokenizer m_tokenizer = new WordTokenizer();
    /** Stemmer applied to each token; never null when set through setStemmer(). */
    protected Stemmer m_stemmer = new NullStemmer();
    /** Regularization constant. */
    protected double m_lambda = 1.0E-4d;
    /** Learning rate. */
    protected double m_learningRate = 0.01d;
    /** Number of passes over the data in batch training. */
    protected int m_epochs = 500;
    /** ID of the selected loss function (HINGE or LOGLOSS). */
    protected int m_loss = 0;
    /** Whether to fit a logistic model to the SVM output to produce probabilities. */
    protected boolean m_fitLogistic = false;
    /** Number of models merged in via aggregate() since the last finalizeAggregation(). */
    protected int m_numModels = 0;

    /* loaded from: input_file:weka/classifiers/functions/SGDText$Count.class */
    /**
     * Dictionary entry holding the accumulated count of a token and the
     * weight (model coefficient) associated with it.
     */
    public static class Count implements Serializable {
        /** Serialization version identifier. */
        private static final long serialVersionUID = 2104201532017340967L;

        /** Accumulated (instance-weighted) occurrence count. */
        public double m_count;

        /** Model coefficient; left at its default of zero by the constructor. */
        public double m_weight;

        /**
         * Creates an entry with the given initial count.
         *
         * @param d the initial count
         */
        public Count(double d) {
            m_count = d;
        }
    }

    /**
     * Computes the loss-derivative factor used to scale a gradient step.
     *
     * @param d the margin, i.e. the signed label times the model output
     * @return for hinge loss, 1 when the margin is below 1 and 0 otherwise;
     *         for log loss, 1 / (1 + exp(d))
     */
    protected double dloss(double d) {
        if (m_loss == HINGE) {
            return d < 1.0d ? 1.0d : 0.0d;
        }
        // Log loss: both branches evaluate 1 / (1 + exp(d)); the form is chosen
        // so that exp() is only ever called with a non-positive argument.
        if (d < 0.0d) {
            return 1.0d / (Math.exp(d) + 1.0d);
        }
        double expNeg = Math.exp(-d);
        return expNeg / (expNeg + 1.0d);
    }

    /**
     * Returns the capabilities of this classifier: string, nominal, date and
     * numeric attributes, missing values, and a binary class that may be missing.
     *
     * @return the capabilities
     */
    @Override
    public Capabilities getCapabilities() {
        Capabilities result = super.getCapabilities();
        result.disableAll();

        // Enabled in this fixed order.
        Capabilities.Capability[] supported = {
            Capabilities.Capability.STRING_ATTRIBUTES,
            Capabilities.Capability.NOMINAL_ATTRIBUTES,
            Capabilities.Capability.DATE_ATTRIBUTES,
            Capabilities.Capability.NUMERIC_ATTRIBUTES,
            Capabilities.Capability.MISSING_VALUES,
            Capabilities.Capability.BINARY_CLASS,
            Capabilities.Capability.MISSING_CLASS_VALUES
        };
        for (Capabilities.Capability capability : supported) {
            result.enable(capability);
        }

        result.setMinimumNumberInstances(0);
        return result;
    }

    /**
     * Sets the stemmer; {@code null} selects a {@link NullStemmer}.
     *
     * @param stemmer the stemmer to use, or null
     */
    public void setStemmer(Stemmer stemmer) {
        m_stemmer = (stemmer == null) ? new NullStemmer() : stemmer;
    }

    /**
     * Returns the stemmer currently in use.
     *
     * @return the stemmer
     */
    public Stemmer getStemmer() {
        return m_stemmer;
    }

    /**
     * Returns the tip text for the stemmer property.
     *
     * @return the tip text
     */
    public String stemmerTipText() {
        return "The stemming algorithm to use on the words.";
    }

    /**
     * Sets the tokenizer; {@code null} selects a default {@link WordTokenizer},
     * mirroring how setStemmer and setStopwordsHandler treat null. Storing a
     * null tokenizer would otherwise fail later in getOptions() and
     * tokenizeInstance().
     *
     * @param tokenizer the tokenizer to use, or null for the default
     */
    public void setTokenizer(Tokenizer tokenizer) {
        m_tokenizer = (tokenizer == null) ? new WordTokenizer() : tokenizer;
    }

    /**
     * Returns the tokenizer currently in use.
     *
     * @return the tokenizer
     */
    public Tokenizer getTokenizer() {
        return m_tokenizer;
    }

    /**
     * Returns the tip text for the tokenizer property.
     *
     * @return the tip text
     */
    public String tokenizerTipText() {
        return "The tokenizing algorithm to use on the strings.";
    }

    /**
     * Returns the tip text for the useWordFrequencies property.
     *
     * @return the tip text
     */
    public String useWordFrequenciesTipText() {
        return "Use word frequencies rather than binary bag of words representation";
    }

    /**
     * Sets whether word frequencies are used instead of binary presence.
     *
     * @param z true to use word frequencies
     */
    public void setUseWordFrequencies(boolean z) {
        m_wordFrequencies = z;
    }

    /**
     * Returns whether word frequencies are used instead of binary presence.
     *
     * @return true if word frequencies are used
     */
    public boolean getUseWordFrequencies() {
        return m_wordFrequencies;
    }

    /**
     * Returns the tip text for the lowercaseTokens property.
     *
     * @return the tip text
     */
    public String lowercaseTokensTipText() {
        return "Whether to convert all tokens to lowercase";
    }

    /**
     * Sets whether tokens are converted to lowercase.
     *
     * @param z true to lowercase tokens
     */
    public void setLowercaseTokens(boolean z) {
        m_lowercaseTokens = z;
    }

    /**
     * Returns whether tokens are converted to lowercase.
     *
     * @return true if tokens are lowercased
     */
    public boolean getLowercaseTokens() {
        return m_lowercaseTokens;
    }

    /**
     * Sets the stopwords handler; {@code null} selects the {@link Null} handler.
     *
     * @param stopwordsHandler the handler to use, or null
     */
    public void setStopwordsHandler(StopwordsHandler stopwordsHandler) {
        m_StopwordsHandler = (stopwordsHandler == null) ? new Null() : stopwordsHandler;
    }

    /**
     * Returns the stopwords handler currently in use.
     *
     * @return the stopwords handler
     */
    public StopwordsHandler getStopwordsHandler() {
        return m_StopwordsHandler;
    }

    /**
     * Returns the tip text for the stopwordsHandler property.
     *
     * @return the tip text
     */
    public String stopwordsHandlerTipText() {
        return "The stopwords handler to use (Null means no stopwords are used).";
    }

    /**
     * Returns the tip text for the periodicPruning property.
     *
     * @return the tip text
     */
    public String periodicPruningTipText() {
        return "How often (number of instances) to prune the dictionary of low frequency terms. 0 means don't prune. Setting a positive integer n means prune after every n instances";
    }

    /**
     * Sets the periodic pruning interval.
     *
     * @param i the interval in instances; values of 0 or less disable periodic pruning
     */
    public void setPeriodicPruning(int i) {
        m_periodicP = i;
    }

    /**
     * Returns the periodic pruning interval.
     *
     * @return the interval in instances
     */
    public int getPeriodicPruning() {
        return m_periodicP;
    }

    /**
     * Returns the tip text for the minWordFrequency property.
     *
     * @return the tip text
     */
    public String minWordFrequencyTipText() {
        return "Ignore any words that don't occur at least min frequency times in the training data. If periodic pruning is turned on, then the dictionary is pruned according to this value";
    }

    /**
     * Sets the minimum word frequency.
     *
     * @param d the minimum accumulated count for a word to be used
     */
    public void setMinWordFrequency(double d) {
        m_minWordP = d;
    }

    /**
     * Returns the minimum word frequency.
     *
     * @return the minimum accumulated count for a word to be used
     */
    public double getMinWordFrequency() {
        return m_minWordP;
    }

    /**
     * Returns the tip text for the minAbsoluteCoefficientValue property.
     *
     * @return the tip text
     */
    public String minAbsoluteCoefficientValueTipText() {
        return "The minimum absolute magnitude for model coefficients. Terms with weights smaller than this value are ignored. If periodic pruning is turned on then this is also used to determine if a word should be removed from the dictionary.";
    }

    /**
     * Sets the minimum absolute coefficient value.
     *
     * @param d the minimum absolute weight for a word to be used
     */
    public void setMinAbsoluteCoefficientValue(double d) {
        m_minAbsCoefficient = d;
    }

    /**
     * Returns the minimum absolute coefficient value.
     *
     * @return the minimum absolute weight for a word to be used
     */
    public double getMinAbsoluteCoefficientValue() {
        return m_minAbsCoefficient;
    }

    /**
     * Returns the tip text for the normalizeDocLength property.
     *
     * @return the tip text
     */
    public String normalizeDocLengthTipText() {
        return "If true then document length is normalized according to the settings for norm and lnorm";
    }

    /**
     * Sets whether document length is normalized.
     *
     * @param z true to normalize document length
     */
    public void setNormalizeDocLength(boolean z) {
        m_normalize = z;
    }

    /**
     * Returns whether document length is normalized.
     *
     * @return true if document length is normalized
     */
    public boolean getNormalizeDocLength() {
        return m_normalize;
    }

    /**
     * Returns the tip text for the norm property.
     *
     * @return the tip text
     */
    public String normTipText() {
        return "The norm of the instances after normalization.";
    }

    /**
     * Returns the norm that instances have after normalization.
     *
     * @return the norm
     */
    public double getNorm() {
        return m_norm;
    }

    /**
     * Sets the norm that instances have after normalization.
     *
     * @param d the norm
     */
    public void setNorm(double d) {
        m_norm = d;
    }

    /**
     * Returns the tip text for the LNorm property.
     *
     * @return the tip text
     */
    public String LNormTipText() {
        return "The LNorm to use for document length normalization.";
    }

    /**
     * Returns the L-norm used for document length normalization.
     *
     * @return the L-norm
     */
    public double getLNorm() {
        return m_lnorm;
    }

    /**
     * Sets the L-norm used for document length normalization.
     *
     * @param d the L-norm
     */
    public void setLNorm(double d) {
        m_lnorm = d;
    }

    /**
     * Returns the tip text for the lambda property.
     *
     * @return the tip text
     */
    public String lambdaTipText() {
        return "The regularization constant. (default = 0.0001)";
    }

    /**
     * Sets the regularization constant.
     *
     * @param d the regularization constant
     */
    public void setLambda(double d) {
        m_lambda = d;
    }

    /**
     * Returns the regularization constant.
     *
     * @return the regularization constant
     */
    public double getLambda() {
        return m_lambda;
    }

    /**
     * Sets the learning rate.
     *
     * @param d the learning rate
     */
    public void setLearningRate(double d) {
        m_learningRate = d;
    }

    /**
     * Returns the learning rate.
     *
     * @return the learning rate
     */
    public double getLearningRate() {
        return m_learningRate;
    }

    /**
     * Returns the tip text for the learningRate property.
     *
     * @return the tip text
     */
    public String learningRateTipText() {
        return "The learning rate.";
    }

    /**
     * Returns the tip text for the epochs property.
     *
     * @return the tip text
     */
    public String epochsTipText() {
        return "The number of epochs to perform (batch learning). The total number of iterations is epochs * num instances.";
    }

    /**
     * Sets the number of epochs used in batch training.
     *
     * @param i the number of epochs
     */
    public void setEpochs(int i) {
        m_epochs = i;
    }

    /**
     * Returns the number of epochs used in batch training.
     *
     * @return the number of epochs
     */
    public int getEpochs() {
        return m_epochs;
    }

    /**
     * Sets the loss function. The request is ignored unless the tag was
     * created from this class's own {@link #TAGS_SELECTION} array
     * (compared by identity).
     *
     * @param selectedTag the selected loss function
     */
    public void setLossFunction(SelectedTag selectedTag) {
        if (selectedTag.getTags() != TAGS_SELECTION) {
            return;
        }
        m_loss = selectedTag.getSelectedTag().getID();
    }

    /**
     * Returns the current loss function as a tag over {@link #TAGS_SELECTION}.
     *
     * @return the selected loss function
     */
    public SelectedTag getLossFunction() {
        return new SelectedTag(m_loss, TAGS_SELECTION);
    }

    /**
     * Returns the tip text for the lossFunction property. Only the two losses
     * offered by {@link #TAGS_SELECTION} are mentioned; the previous text also
     * advertised a squared loss that this classifier does not provide.
     *
     * @return the tip text
     */
    public String lossFunctionTipText() {
        return "The loss function to use. Hinge loss (SVM) or log loss (logistic regression).";
    }

    /**
     * Sets whether a logistic model is fitted to the SVM output.
     *
     * @param z true to produce probability estimates for the SVM
     */
    public void setOutputProbsForSVM(boolean z) {
        m_fitLogistic = z;
    }

    /**
     * Returns whether a logistic model is fitted to the SVM output.
     *
     * @return true if probability estimates are produced for the SVM
     */
    public boolean getOutputProbsForSVM() {
        return m_fitLogistic;
    }

    /**
     * Returns the tip text for the outputProbsForSVM property.
     *
     * @return the tip text
     */
    public String outputProbsForSVMTipText() {
        return "Fit a logistic regression to the output of SVM for producing probability estimates";
    }

    /**
     * Lists the command-line options understood by this classifier, followed
     * by the options of the superclass.
     *
     * <p>The synopsis of the probability flag now matches the flag that
     * setOptions actually parses ("-output-probs"), the stopwords handler
     * option name no longer carries a leading dash, and typos in the help
     * texts are corrected.
     *
     * @return an enumeration of all available options
     */
    @Override
    public Enumeration<Option> listOptions() {
        Vector<Option> options = new Vector<>();
        options.add(new Option("\tSet the loss function to minimize. 0 = hinge loss (SVM), 1 = log loss (logistic regression)\n\t(default = 0)", "F", 1, "-F"));
        options.add(new Option("\tOutput probabilities for SVMs (fits a logistic\n\tmodel to the output of the SVM)", "output-probs", 0, "-output-probs"));
        options.add(new Option("\tThe learning rate (default = 0.01).", "L", 1, "-L"));
        options.add(new Option("\tThe lambda regularization constant (default = 0.0001)", "R", 1, "-R <double>"));
        options.add(new Option("\tThe number of epochs to perform (batch learning only, default = 500)", "E", 1, "-E <integer>"));
        options.add(new Option("\tUse word frequencies instead of binary bag of words.", "W", 0, "-W"));
        options.add(new Option("\tHow often to prune the dictionary of low frequency words (default = 0, i.e. don't prune)", "P", 1, "-P <# instances>"));
        options.add(new Option("\tMinimum word frequency. Words with less than this frequency are ignored.\n\tIf periodic pruning is turned on then this is also used to determine which\n\twords to remove from the dictionary (default = 3).", "M", 1, "-M <double>"));
        options.add(new Option("\tMinimum absolute value of coefficients in the model.\n\tIf periodic pruning is turned on then this\n\tis also used to prune words from the dictionary\n\t(default = 0.001)", "min-coeff", 1, "-min-coeff <double>"));
        options.add(new Option("\tNormalize document length (use in conjunction with -norm and -lnorm)", "normalize", 0, "-normalize"));
        options.add(new Option("\tSpecify the norm that each instance must have (default 1.0)", "norm", 1, "-norm <num>"));
        options.add(new Option("\tSpecify L-norm to use (default 2.0)", "lnorm", 1, "-lnorm <num>"));
        options.add(new Option("\tConvert all tokens to lowercase before adding to the dictionary.", "lowercase", 0, "-lowercase"));
        options.add(new Option("\tThe stopwords handler to use (default Null).", "stopwords-handler", 1, "-stopwords-handler <spec>"));
        options.add(new Option("\tThe tokenizing algorithm (classname plus parameters) to use.\n\t(default: " + WordTokenizer.class.getName() + ")", "tokenizer", 1, "-tokenizer <spec>"));
        options.add(new Option("\tThe stemming algorithm (classname plus parameters) to use.", "stemmer", 1, "-stemmer <spec>"));
        options.addAll(Collections.list(super.listOptions()));
        return options.elements();
    }

    /**
     * Parses the given options and configures this classifier. The model is
     * reset first. Value options that are absent keep their current setting;
     * flags are set to whether they are present. A missing stemmer or
     * stopwords handler selects the null implementation, and a missing
     * tokenizer selects a {@link WordTokenizer}.
     *
     * <p>Plain literals and class literals replace the unrelated GUI constant
     * and the by-name class lookups that were used here before.
     *
     * @param strArr the options to parse
     * @throws Exception if an option value cannot be parsed, a component
     *         specification is invalid, or unknown options remain
     */
    @Override
    public void setOptions(String[] strArr) throws Exception {
        reset();

        String lossSpec = Utils.getOption('F', strArr);
        if (lossSpec.length() != 0) {
            setLossFunction(new SelectedTag(Integer.parseInt(lossSpec), TAGS_SELECTION));
        }

        setOutputProbsForSVM(Utils.getFlag("output-probs", strArr));

        String lambdaSpec = Utils.getOption('R', strArr);
        if (lambdaSpec.length() > 0) {
            setLambda(Double.parseDouble(lambdaSpec));
        }

        String learningRateSpec = Utils.getOption('L', strArr);
        if (learningRateSpec.length() > 0) {
            setLearningRate(Double.parseDouble(learningRateSpec));
        }

        String epochsSpec = Utils.getOption("E", strArr);
        if (epochsSpec.length() > 0) {
            setEpochs(Integer.parseInt(epochsSpec));
        }

        setUseWordFrequencies(Utils.getFlag("W", strArr));

        String pruningSpec = Utils.getOption("P", strArr);
        if (pruningSpec.length() > 0) {
            setPeriodicPruning(Integer.parseInt(pruningSpec));
        }

        String minFrequencySpec = Utils.getOption("M", strArr);
        if (minFrequencySpec.length() > 0) {
            setMinWordFrequency(Double.parseDouble(minFrequencySpec));
        }

        String minCoefficientSpec = Utils.getOption("min-coeff", strArr);
        if (minCoefficientSpec.length() > 0) {
            setMinAbsoluteCoefficientValue(Double.parseDouble(minCoefficientSpec));
        }

        setNormalizeDocLength(Utils.getFlag("normalize", strArr));

        String normSpec = Utils.getOption("norm", strArr);
        if (normSpec.length() > 0) {
            setNorm(Double.parseDouble(normSpec));
        }

        String lnormSpec = Utils.getOption("lnorm", strArr);
        if (lnormSpec.length() > 0) {
            setLNorm(Double.parseDouble(lnormSpec));
        }

        setLowercaseTokens(Utils.getFlag("lowercase", strArr));

        // Component specifications are "classname [options]"; the class name is
        // blanked out of the split array so only its own options are passed on.
        String stemmerSpec = Utils.getOption("stemmer", strArr);
        if (stemmerSpec.length() == 0) {
            setStemmer(null);
        } else {
            String[] stemmerParts = Utils.splitOptions(stemmerSpec);
            if (stemmerParts.length == 0) {
                throw new Exception("Invalid stemmer specification string");
            }
            String stemmerClass = stemmerParts[0];
            stemmerParts[0] = "";
            setStemmer((Stemmer) Utils.forName(Stemmer.class, stemmerClass, stemmerParts));
        }

        String stopwordsSpec = Utils.getOption("stopwords-handler", strArr);
        if (stopwordsSpec.length() == 0) {
            setStopwordsHandler(null);
        } else {
            String[] stopwordsParts = Utils.splitOptions(stopwordsSpec);
            if (stopwordsParts.length == 0) {
                throw new Exception("Invalid StopwordsHandler specification string");
            }
            String stopwordsClass = stopwordsParts[0];
            stopwordsParts[0] = "";
            setStopwordsHandler((StopwordsHandler) Utils.forName(StopwordsHandler.class, stopwordsClass, stopwordsParts));
        }

        String tokenizerSpec = Utils.getOption("tokenizer", strArr);
        if (tokenizerSpec.length() == 0) {
            setTokenizer(new WordTokenizer());
        } else {
            String[] tokenizerParts = Utils.splitOptions(tokenizerSpec);
            if (tokenizerParts.length == 0) {
                throw new Exception("Invalid tokenizer specification string");
            }
            String tokenizerClass = tokenizerParts[0];
            tokenizerParts[0] = "";
            setTokenizer((Tokenizer) Utils.forName(Tokenizer.class, tokenizerClass, tokenizerParts));
        }

        super.setOptions(strArr);
        Utils.checkForRemainingOptions(strArr);
    }

    /**
     * Returns the current settings of this classifier as command-line
     * options, followed by the options of the superclass.
     *
     * <p>Uses a typed list and plain literals instead of unrelated GUI/test
     * constants, and skips the tokenizer entry when no tokenizer is set
     * rather than failing with a NullPointerException.
     *
     * @return the current options
     */
    @Override
    public String[] getOptions() {
        ArrayList<String> options = new ArrayList<>();

        options.add("-F");
        options.add(String.valueOf(getLossFunction().getSelectedTag().getID()));
        if (getOutputProbsForSVM()) {
            options.add("-output-probs");
        }
        options.add("-L");
        options.add(String.valueOf(getLearningRate()));
        options.add("-R");
        options.add(String.valueOf(getLambda()));
        options.add("-E");
        options.add(String.valueOf(getEpochs()));
        if (getUseWordFrequencies()) {
            options.add("-W");
        }
        options.add("-P");
        options.add(String.valueOf(getPeriodicPruning()));
        options.add("-M");
        options.add(String.valueOf(getMinWordFrequency()));
        options.add("-min-coeff");
        options.add(String.valueOf(getMinAbsoluteCoefficientValue()));
        if (getNormalizeDocLength()) {
            options.add("-normalize");
        }
        options.add("-norm");
        options.add(String.valueOf(getNorm()));
        options.add("-lnorm");
        options.add(String.valueOf(getLNorm()));
        if (getLowercaseTokens()) {
            options.add("-lowercase");
        }

        // Components are written as "classname [options]".
        if (getStopwordsHandler() != null) {
            options.add("-stopwords-handler");
            String spec = getStopwordsHandler().getClass().getName();
            if (getStopwordsHandler() instanceof OptionHandler) {
                spec = spec + " " + Utils.joinOptions(((OptionHandler) getStopwordsHandler()).getOptions());
            }
            options.add(spec.trim());
        }

        if (getTokenizer() != null) {
            options.add("-tokenizer");
            String spec = getTokenizer().getClass().getName();
            if (getTokenizer() instanceof OptionHandler) {
                spec = spec + " " + Utils.joinOptions(getTokenizer().getOptions());
            }
            options.add(spec.trim());
        }

        if (getStemmer() != null) {
            options.add("-stemmer");
            String spec = getStemmer().getClass().getName();
            if (getStemmer() instanceof OptionHandler) {
                spec = spec + " " + Utils.joinOptions(((OptionHandler) getStemmer()).getOptions());
            }
            options.add(spec.trim());
        }

        Collections.addAll(options, super.getOptions());
        return options.toArray(new String[0]);
    }

    /**
     * Returns a description of this classifier.
     *
     * @return the description
     */
    public String globalInfo() {
        return "Implements stochastic gradient descent for learning a linear binary class SVM or binary class logistic regression on text data. Operates directly (and only) on String attributes. Other types of input attributes are accepted but ignored during training and classification.";
    }

    /**
     * Resets the model: the update counter goes back to 1, the bias to 0 and
     * the dictionary is discarded. The bias is set with a plain literal
     * rather than an unrelated KStar constant.
     */
    public void reset() {
        m_t = 1.0d;
        m_bias = 0.0d;
        m_dictionary = null;
    }

    /**
     * Builds the classifier from scratch. The model is reset, the data is
     * checked against the capabilities, and, if there are instances, a
     * shuffled copy is trained on and the dictionary is pruned afterwards.
     *
     * @param instances the training data
     * @throws Exception if the data cannot be handled or training fails
     */
    @Override
    public void buildClassifier(Instances instances) throws Exception {
        reset();
        getCapabilities().testWithFail(instances);

        m_dictionary = new LinkedHashMap<>(10000);
        m_numInstances = instances.numInstances();
        m_data = new Instances(instances, 0);

        Instances trainingCopy = new Instances(instances);
        boolean calibrateSvm = m_fitLogistic && m_loss == HINGE;
        if (calibrateSvm) {
            initializeSVMProbs(trainingCopy);
        }

        if (trainingCopy.numInstances() <= 0) {
            return;
        }
        trainingCopy.randomize(new Random(getSeed()));
        train(trainingCopy);
        pruneDictionary(true);
    }

    /**
     * Creates the auxiliary SGD model that maps SVM outputs to probabilities,
     * together with the two-attribute structure (numeric "pred", nominal
     * "class") of the instances it is trained on.
     *
     * <p>The log loss is selected using SGD's own tag array. This class's
     * setLossFunction ignores tags built from a foreign array (identity
     * check); SGD presumably does the same, in which case the previous call
     * left the auxiliary model on its default loss. NOTE(review): confirm
     * against SGD.setLossFunction.
     *
     * @param instances the training data; its class attribute supplies the two class labels
     * @throws Exception if the auxiliary model cannot be initialized
     */
    protected void initializeSVMProbs(Instances instances) throws Exception {
        m_svmProbs = new SGD();
        m_svmProbs.setLossFunction(new SelectedTag(SGD.LOGLOSS, SGD.TAGS_SELECTION));
        m_svmProbs.setLearningRate(m_learningRate);
        m_svmProbs.setLambda(m_lambda);
        m_svmProbs.setEpochs(m_epochs);

        ArrayList<String> classLabels = new ArrayList<>(2);
        classLabels.add(instances.classAttribute().value(0));
        classLabels.add(instances.classAttribute().value(1));

        ArrayList<Attribute> attributes = new ArrayList<>(2);
        attributes.add(new Attribute("pred"));
        attributes.add(new Attribute("class", classLabels));

        m_fitLogisticStructure = new Instances("data", attributes, 0);
        m_fitLogisticStructure.setClassIndex(1);
        m_svmProbs.buildClassifier(m_fitLogisticStructure);
    }

    /**
     * Runs the configured number of epochs over the data. The dictionary is
     * only extended during the first epoch.
     *
     * @param instances the training data
     * @throws Exception if an update fails
     */
    protected void train(Instances instances) throws Exception {
        for (int epoch = 0; epoch < m_epochs; epoch++) {
            boolean firstPass = (epoch == 0);
            for (int index = 0; index < instances.numInstances(); index++) {
                updateClassifier(instances.instance(index), firstPass);
            }
        }
    }

    /**
     * Updates the model with a single instance, also adding its tokens to the
     * dictionary.
     *
     * @param instance the training instance
     * @throws Exception if the update fails
     */
    @Override
    public void updateClassifier(Instance instance) throws Exception {
        final boolean updateDictionary = true;
        updateClassifier(instance, updateDictionary);
    }

    /**
     * Performs one stochastic gradient step for the given instance. Instances
     * with a missing class are skipped entirely.
     *
     * <p>Plain 0.0 literals replace the unrelated KStar constant previously
     * used for the zero comparisons.
     *
     * @param instance the training instance
     * @param z whether the instance's tokens are also added to the dictionary
     * @throws Exception if the auxiliary probability model cannot be updated
     */
    protected void updateClassifier(Instance instance, boolean z) throws Exception {
        if (instance.classIsMissing()) {
            return;
        }
        tokenizeInstance(instance, z);

        // Feed the current SVM output to the auxiliary probability model
        // before the linear model itself is changed.
        if (m_loss == HINGE && m_fitLogistic) {
            double[] values = {svmOutput(), instance.classValue()};
            DenseInstance calibrationInstance = new DenseInstance(instance.weight(), values);
            calibrationInstance.setDataset(m_fitLogisticStructure);
            m_svmProbs.updateClassifier(calibrationInstance);
        }

        double wx = dotProd(m_inputVector);
        // Class value 0 is mapped to label -1, anything else to +1.
        double y = instance.classValue() == 0.0d ? -1.0d : 1.0d;
        double margin = y * (wx + m_bias);

        // Weight decay: divided by the training set size in batch mode, or by
        // the update counter when no training set size is known.
        double divisor = m_numInstances == 0.0d ? m_t : m_numInstances;
        double shrink = 1.0d - ((m_learningRate * m_lambda) / divisor);
        for (Count dictionaryEntry : m_dictionary.values()) {
            dictionaryEntry.m_weight *= shrink;
        }

        // For hinge loss only instances inside the margin produce a step.
        if (m_loss != HINGE || margin < 1.0d) {
            double step = m_learningRate * y * dloss(margin);
            for (Map.Entry<String, Count> input : m_inputVector.entrySet()) {
                double value = m_wordFrequencies ? input.getValue().m_count : 1.0d;
                Count dictionaryEntry = m_dictionary.get(input.getKey());
                if (dictionaryEntry != null) {
                    dictionaryEntry.m_weight += step * value;
                }
            }
            m_bias += step;
        }

        m_t += 1.0d;
    }

    /**
     * Tokenizes all non-missing string attributes of the instance into the
     * input vector, which is cleared first. Each token is optionally
     * lowercased, then stemmed, and dropped if it is a stopword. Counts are
     * incremented by the instance weight.
     *
     * @param instance the instance to tokenize
     * @param z whether to also accumulate the tokens in the dictionary and
     *          then give periodic pruning a chance to run
     */
    protected void tokenizeInstance(Instance instance, boolean z) {
        if (m_inputVector != null) {
            m_inputVector.clear();
        } else {
            m_inputVector = new LinkedHashMap<>();
        }

        for (int att = 0; att < instance.numAttributes(); att++) {
            if (!instance.attribute(att).isString() || instance.isMissing(att)) {
                continue;
            }
            m_tokenizer.tokenize(instance.stringValue(att));
            while (m_tokenizer.hasMoreElements()) {
                String token = m_tokenizer.nextElement();
                if (m_lowercaseTokens) {
                    token = token.toLowerCase();
                }
                String word = m_stemmer.stem(token);
                if (m_StopwordsHandler.isStopword(word)) {
                    continue;
                }

                Count inInput = m_inputVector.get(word);
                if (inInput != null) {
                    inInput.m_count += instance.weight();
                } else {
                    m_inputVector.put(word, new Count(instance.weight()));
                }

                if (!z) {
                    continue;
                }
                Count inDictionary = m_dictionary.get(word);
                if (inDictionary != null) {
                    inDictionary.m_count += instance.weight();
                } else {
                    m_dictionary.put(word, new Count(instance.weight()));
                }
            }
        }

        if (z) {
            pruneDictionary(false);
        }
    }

    /**
     * Removes dictionary entries whose count is below the minimum word
     * frequency or whose absolute weight is below the minimum coefficient.
     * Unless forced, pruning only happens when periodic pruning is enabled
     * and the update counter is a multiple of the pruning interval.
     *
     * <p>Does nothing when no dictionary exists yet, so callers such as
     * batchFinished() no longer fail with a NullPointerException before a
     * model has been built.
     *
     * @param z true to prune regardless of the periodic pruning settings
     */
    protected void pruneDictionary(boolean z) {
        if (m_dictionary == null) {
            return;
        }
        boolean periodicPruneDue = m_periodicP > 0 && !(m_t % m_periodicP > 0.0d);
        if (!z && !periodicPruneDue) {
            return;
        }
        Iterator<Map.Entry<String, Count>> entries = m_dictionary.entrySet().iterator();
        while (entries.hasNext()) {
            Count entry = entries.next().getValue();
            if (entry.m_count < m_minWordP || Math.abs(entry.m_weight) < m_minAbsCoefficient) {
                entries.remove();
            }
        }
    }

    /**
     * Returns the raw linear model output for the current input vector.
     *
     * @return the dot product with the model weights plus the bias
     */
    protected double svmOutput() {
        double wx = dotProd(m_inputVector);
        return wx + m_bias;
    }

    /**
     * Computes the class distribution for the given instance.
     *
     * <p>With hinge loss and probability output enabled, the distribution
     * comes from the auxiliary model applied to the raw output. With log loss
     * the logistic function of the raw output is used. Otherwise all mass
     * goes to class 0 for non-positive output and to class 1 for positive
     * output. A plain 0.0 literal replaces the unrelated KStar constant.
     *
     * @param instance the instance to classify
     * @return a two-element class distribution
     * @throws Exception if the auxiliary model fails
     */
    @Override
    public double[] distributionForInstance(Instance instance) throws Exception {
        tokenizeInstance(instance, false);
        double output = dotProd(m_inputVector) + m_bias;

        if (m_loss == HINGE && m_fitLogistic) {
            double[] values = {output, Utils.missingValue()};
            DenseInstance calibrationInstance = new DenseInstance(instance.weight(), values);
            calibrationInstance.setDataset(m_fitLogisticStructure);
            return m_svmProbs.distributionForInstance(calibrationInstance);
        }

        double[] distribution = new double[2];
        boolean nonPositive = output <= 0.0d;
        if (m_loss == LOGLOSS) {
            // Both branches keep the argument of exp() non-positive.
            if (nonPositive) {
                distribution[0] = 1.0d / (1.0d + Math.exp(output));
                distribution[1] = 1.0d - distribution[0];
            } else {
                distribution[1] = 1.0d / (1.0d + Math.exp(-output));
                distribution[0] = 1.0d - distribution[1];
            }
        } else {
            distribution[nonPositive ? 0 : 1] = 1.0d;
        }
        return distribution;
    }

    /**
     * Computes the dot product between the given input vector and the model
     * weights. Only dictionary words that meet both the minimum word
     * frequency and the minimum absolute coefficient contribute.
     *
     * <p>When normalization is enabled, input values are scaled by
     * norm / length, where length is the L-norm of the input values. If that
     * length is not positive (for example all counts are zero), the values
     * are left unscaled instead of dividing by zero and producing NaN.
     *
     * @param map the tokenized input, token to count
     * @return the dot product
     */
    protected double dotProd(Map<String, Count> map) {
        double scale = 1.0d;
        if (m_normalize) {
            double powerSum = 0.0d;
            for (Count input : map.values()) {
                double value = m_wordFrequencies ? input.m_count : 1.0d;
                powerSum += Math.pow(Math.abs(value), m_lnorm);
            }
            double length = Math.pow(powerSum, 1.0d / m_lnorm);
            if (length > 0.0d) {
                scale = m_norm / length;
            }
        }

        double result = 0.0d;
        for (Map.Entry<String, Count> input : map.entrySet()) {
            double value = m_wordFrequencies ? input.getValue().m_count : 1.0d;
            value *= scale;
            Count dictionaryEntry = m_dictionary.get(input.getKey());
            if (dictionaryEntry != null
                && dictionaryEntry.m_count >= m_minWordP
                && Math.abs(dictionaryEntry.m_weight) >= m_minAbsCoefficient) {
                result += value * dictionaryEntry.m_weight;
            }
        }
        return result;
    }

    /**
     * Returns a textual description of the model: loss function, effective
     * dictionary size, one line per contributing word (weight, word, count)
     * and the bias.
     *
     * <p>Uses StringBuilder, reuses getDictionarySize() instead of repeating
     * its counting loop, and uses plain space literals instead of an
     * unrelated test constant.
     *
     * @return the model description, or a notice if no model has been built
     */
    public String toString() {
        if (m_dictionary == null) {
            return "SGDText: No model built yet.\n";
        }

        StringBuilder text = new StringBuilder();
        text.append("SGDText:\n\n");
        text.append("Loss function: ");
        text.append(m_loss == HINGE ? "Hinge loss (SVM)\n\n" : "Log loss (logistic regression)\n\n");
        text.append("Dictionary size: ").append(getDictionarySize()).append("\n\n");
        text.append(m_data.classAttribute().name()).append(" = \n\n");

        boolean firstTerm = true;
        for (Map.Entry<String, Count> entry : m_dictionary.entrySet()) {
            Count count = entry.getValue();
            if (count.m_count >= m_minWordP && Math.abs(count.m_weight) >= m_minAbsCoefficient) {
                text.append(firstTerm ? "   " : " + ");
                text.append(Utils.doubleToString(count.m_weight, 12, 4))
                    .append(" ")
                    .append(entry.getKey())
                    .append(" ")
                    .append(count.m_count)
                    .append("\n");
                firstTerm = false;
            }
        }

        if (m_bias > 0.0d) {
            text.append(" + ").append(Utils.doubleToString(m_bias, 12, 4));
        } else {
            text.append(" - ").append(Utils.doubleToString(-m_bias, 12, 4));
        }
        return text.toString();
    }

    /**
     * Returns the dictionary itself (not a copy); it is null after reset()
     * until a model is built.
     *
     * @return the dictionary, possibly null
     */
    public LinkedHashMap<String, Count> getDictionary() {
        return m_dictionary;
    }

    /**
     * Returns the number of dictionary words that meet both the minimum word
     * frequency and the minimum absolute coefficient.
     *
     * @return the effective dictionary size, 0 if there is no dictionary
     */
    public int getDictionarySize() {
        if (m_dictionary == null) {
            return 0;
        }
        int size = 0;
        for (Count count : m_dictionary.values()) {
            if (count.m_count >= m_minWordP && Math.abs(count.m_weight) >= m_minAbsCoefficient) {
                size++;
            }
        }
        return size;
    }

    /**
     * Returns the bias term of the model.
     *
     * @return the bias
     */
    public double bias() {
        return m_bias;
    }

    /**
     * Sets the bias term of the model.
     *
     * @param d the new bias
     */
    public void setBias(double d) {
        m_bias = d;
    }

    /**
     * Returns the revision string of this class.
     *
     * @return the revision
     */
    @Override
    public String getRevision() {
        final String revisionKeyword = "$Revision: 13280 $";
        return RevisionUtils.extract(revisionKeyword);
    }

    /**
     * Adds another model to this one: counts and weights of matching words
     * are summed, unseen words are copied, and the biases are summed.
     * finalizeAggregation() later turns the sums into averages.
     *
     * <p>A model without a dictionary is now rejected with an explicit
     * exception instead of a NullPointerException.
     *
     * @param sGDText the model to add
     * @return this model
     * @throws Exception if either model has not been built yet
     */
    @Override
    public SGDText aggregate(SGDText sGDText) throws Exception {
        if (m_dictionary == null) {
            throw new Exception("No model built yet, can't aggregate");
        }
        if (sGDText == null || sGDText.getDictionary() == null) {
            throw new Exception("Model to aggregate has not been built yet, can't aggregate");
        }

        for (Map.Entry<String, Count> other : sGDText.getDictionary().entrySet()) {
            Count otherCount = other.getValue();
            Count mine = m_dictionary.get(other.getKey());
            if (mine == null) {
                // Copy rather than share, so later updates stay independent.
                Count copy = new Count(otherCount.m_count);
                copy.m_weight = otherCount.m_weight;
                m_dictionary.put(other.getKey(), copy);
            } else {
                mine.m_count += otherCount.m_count;
                mine.m_weight += otherCount.m_weight;
            }
        }

        m_bias += sGDText.bias();
        m_numModels++;
        return this;
    }

    /**
     * Completes an aggregation: prunes the dictionary, then divides all
     * counts, weights and the bias by the number of models involved (the
     * aggregated models plus this one), and resets the model counter.
     *
     * @throws Exception if no models have been aggregated
     */
    @Override
    public void finalizeAggregation() throws Exception {
        if (m_numModels == 0) {
            throw new Exception("Unable to finalize aggregation - haven't seen any models to aggregate");
        }

        pruneDictionary(true);

        final int modelCount = m_numModels + 1;
        for (Count count : m_dictionary.values()) {
            count.m_count /= modelCount;
            count.m_weight /= modelCount;
        }
        m_bias /= modelCount;

        m_numModels = 0;
    }

    /**
     * Signals the end of a batch of incremental updates; forces a dictionary
     * prune.
     *
     * @throws Exception declared by the interface
     */
    @Override
    public void batchFinished() throws Exception {
        final boolean force = true;
        pruneDictionary(force);
    }

    /**
     * Command-line entry point.
     *
     * @param strArr the command-line options
     */
    public static void main(String[] strArr) {
        SGDText classifier = new SGDText();
        runClassifier(classifier, strArr);
    }
}
