/*
 * Decompiled with CFR 0.152.
 */
package org.tribuo.classification.sgd.crf;

import com.oracle.labs.mlrg.olcut.config.Config;
import com.oracle.labs.mlrg.olcut.provenance.Provenance;
import com.oracle.labs.mlrg.olcut.util.Pair;
import java.time.OffsetDateTime;
import java.util.Map;
import java.util.SplittableRandom;
import java.util.logging.Logger;
import org.tribuo.ImmutableFeatureMap;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.WeightedExamples;
import org.tribuo.classification.Label;
import org.tribuo.classification.sgd.Util;
import org.tribuo.classification.sgd.crf.CRFModel;
import org.tribuo.classification.sgd.crf.CRFParameters;
import org.tribuo.math.Parameters;
import org.tribuo.math.StochasticGradientOptimiser;
import org.tribuo.math.la.SGDVector;
import org.tribuo.math.la.Tensor;
import org.tribuo.provenance.DatasetProvenance;
import org.tribuo.provenance.ModelProvenance;
import org.tribuo.provenance.TrainerProvenance;
import org.tribuo.provenance.impl.TrainerProvenanceImpl;
import org.tribuo.sequence.SequenceDataset;
import org.tribuo.sequence.SequenceExample;
import org.tribuo.sequence.SequenceTrainer;

/**
 * Trains a {@link CRFModel} over sequences of {@link Label}s using stochastic gradient descent.
 * <p>
 * Each call to {@link #train(SequenceDataset, Map)} works on a copy of the configured
 * optimiser and on an RNG split from the trainer's RNG; that hand-off (and the invocation
 * counter increment) happens while holding the trainer's monitor.
 * <p>
 * Examples are processed either one at a time (minibatch size of 1) or aggregated into
 * minibatches whose gradients are merged before a single optimiser step.
 */
public class CRFTrainer
implements SequenceTrainer<Label>,
WeightedExamples {
    private static final Logger logger = Logger.getLogger(CRFTrainer.class.getName());
    @Config(mandatory=true, description="The gradient optimiser to use.")
    private StochasticGradientOptimiser optimiser;
    @Config(description="The number of gradient descent epochs.")
    private int epochs = 5;
    @Config(description="Log values after this many updates.")
    private int loggingInterval = -1;
    @Config(description="Minibatch size in SGD.")
    private int minibatchSize = 1;
    @Config(mandatory=true, description="Seed for the RNG used to shuffle elements.")
    private long seed;
    @Config(description="Shuffle the data before each epoch. Only turn off for debugging.")
    private boolean shuffle = true;
    // Source of the per-invocation RNGs; (re)created from the seed in postConfig().
    private SplittableRandom rng;
    // Number of times train() has been invoked on this trainer.
    private int trainInvocationCounter;

    /**
     * Creates a CRF trainer.
     *
     * @param optimiser The gradient optimiser to use.
     * @param epochs The number of passes over the training data.
     * @param loggingInterval Log the average loss after this many updates; -1 disables logging.
     * @param minibatchSize The number of examples aggregated into each gradient update.
     * @param seed The seed for the RNG used to shuffle the examples.
     */
    public CRFTrainer(StochasticGradientOptimiser optimiser, int epochs, int loggingInterval, int minibatchSize, long seed) {
        this.optimiser = optimiser;
        this.epochs = epochs;
        this.loggingInterval = loggingInterval;
        this.minibatchSize = minibatchSize;
        this.seed = seed;
        this.postConfig();
    }

    /**
     * Creates a CRF trainer with a minibatch size of 1.
     *
     * @param optimiser The gradient optimiser to use.
     * @param epochs The number of passes over the training data.
     * @param loggingInterval Log the average loss after this many updates; -1 disables logging.
     * @param seed The seed for the RNG used to shuffle the examples.
     */
    public CRFTrainer(StochasticGradientOptimiser optimiser, int epochs, int loggingInterval, long seed) {
        this(optimiser, epochs, loggingInterval, 1, seed);
    }

    /**
     * Creates a CRF trainer with a minibatch size of 1 and a logging interval of 100.
     *
     * @param optimiser The gradient optimiser to use.
     * @param epochs The number of passes over the training data.
     * @param seed The seed for the RNG used to shuffle the examples.
     */
    public CRFTrainer(StochasticGradientOptimiser optimiser, int epochs, long seed) {
        this(optimiser, epochs, 100, 1, seed);
    }

    /**
     * No-arg constructor; fields keep their declared defaults until they are set externally.
     * NOTE(review): presumably used by the configuration system via the {@code @Config}
     * fields, followed by a call to {@link #postConfig()} - confirm.
     */
    private CRFTrainer() {
    }

    /**
     * Initialises the trainer's RNG from the configured seed.
     */
    public synchronized void postConfig() {
        this.rng = new SplittableRandom(this.seed);
    }

    /**
     * Turns shuffling of the examples before each epoch on or off.
     *
     * @param shuffle If true the examples are shuffled at the start of every epoch.
     */
    public void setShuffle(boolean shuffle) {
        this.shuffle = shuffle;
    }

    /**
     * Trains a CRF model on the supplied sequence dataset.
     *
     * @param sequenceExamples The training sequences; must not contain unknown outputs.
     * @param runProvenance Provenance for this training run, stored in the model provenance.
     * @return The trained model.
     * @throws IllegalArgumentException If the dataset contains unknown outputs.
     */
    public CRFModel train(SequenceDataset<Label> sequenceExamples, Map<String, Provenance> runProvenance) {
        if (sequenceExamples.getOutputInfo().getUnknownCount() > 0) {
            throw new IllegalArgumentException("The supplied Dataset contained unknown Outputs, and this Trainer is supervised.");
        }
        // Take everything that depends on shared trainer state under the lock; the rest of
        // the method only touches these locals.
        SplittableRandom localRNG;
        StochasticGradientOptimiser localOptimiser;
        TrainerProvenance trainerProvenance;
        synchronized (this) {
            localRNG = this.rng.split();
            localOptimiser = this.optimiser.copy();
            trainerProvenance = this.getProvenance();
            this.trainInvocationCounter++;
        }
        ImmutableOutputInfo<Label> labelIDMap = sequenceExamples.getOutputIDInfo();
        ImmutableFeatureMap featureIDMap = sequenceExamples.getFeatureIDMap();

        // Convert every sequence into its vector features, label ids and weight.
        int numExamples = sequenceExamples.size();
        SGDVector[][] sgdFeatures = new SGDVector[numExamples][];
        int[][] sgdLabels = new int[numExamples][];
        double[] weights = new double[numExamples];
        int n = 0;
        for (SequenceExample<Label> example : sequenceExamples) {
            weights[n] = example.getWeight();
            Pair<int[], SGDVector[]> pair = CRFModel.convertToVector(example, featureIDMap, labelIDMap);
            sgdFeatures[n] = pair.getB();
            sgdLabels[n] = pair.getA();
            n++;
        }
        logger.info(String.format("Training SGD CRF with %d examples", n));

        CRFParameters crfParameters = new CRFParameters(featureIDMap.size(), labelIDMap.size());
        localOptimiser.initialise(crfParameters);
        // Loss accumulated since the last log message, and the number of updates applied so far.
        double loss = 0.0;
        int iteration = 0;
        for (int i = 0; i < this.epochs; i++) {
            if (this.shuffle) {
                Util.shuffleInPlace(sgdFeatures, sgdLabels, weights, localRNG);
            }
            if (this.minibatchSize == 1) {
                // Special case: apply an update after every example without aggregating.
                for (int j = 0; j < sgdFeatures.length; j++) {
                    Pair<Double, Tensor[]> output = crfParameters.valueAndGradient(sgdFeatures[j], sgdLabels[j]);
                    loss += output.getA() * weights[j];
                    Tensor[] updates = localOptimiser.step(output.getB(), weights[j]);
                    crfParameters.update(updates);
                    iteration++;
                    if (this.shouldLog(iteration)) {
                        logger.info("At iteration " + iteration + ", average loss = " + loss / (double)this.loggingInterval);
                        loss = 0.0;
                    }
                }
            } else {
                Tensor[][] gradients = new Tensor[this.minibatchSize][];
                for (int j = 0; j < sgdFeatures.length; j += this.minibatchSize) {
                    double tempWeight = 0.0;
                    int curSize = 0;
                    // Collect the gradient of each example in the minibatch. The example is
                    // indexed by k; indexing by j would reuse the first example of the batch
                    // for every slot.
                    for (int k = j; k < j + this.minibatchSize && k < sgdFeatures.length; k++) {
                        Pair<Double, Tensor[]> output = crfParameters.valueAndGradient(sgdFeatures[k], sgdLabels[k]);
                        loss += output.getA() * weights[k];
                        tempWeight += weights[k];
                        gradients[k - j] = output.getB();
                        curSize++;
                    }
                    // Merge the per-example gradients into a single update and rescale it.
                    Tensor[] updates = crfParameters.merge(gradients, curSize);
                    for (Tensor update : updates) {
                        update.scaleInPlace((double)this.minibatchSize);
                    }
                    tempWeight /= (double)this.minibatchSize;
                    updates = localOptimiser.step(updates, tempWeight);
                    crfParameters.update(updates);
                    iteration++;
                    if (this.shouldLog(iteration)) {
                        logger.info("At iteration " + iteration + ", average loss = " + loss / (double)this.loggingInterval);
                        loss = 0.0;
                    }
                }
            }
        }
        localOptimiser.finalise();
        ModelProvenance provenance = new ModelProvenance(CRFModel.class.getName(), OffsetDateTime.now(), sequenceExamples.getProvenance(), trainerProvenance, runProvenance);
        CRFModel model = new CRFModel("crf-sgd-model", provenance, featureIDMap, labelIDMap, crfParameters);
        localOptimiser.reset();
        return model;
    }

    /**
     * Decides whether the average loss should be logged after the given update.
     * <p>
     * Only positive logging intervals trigger logging: -1 is the "disabled" value, and
     * guarding on {@code > 0} also avoids the division by zero that
     * {@code iteration % 0} would raise.
     *
     * @param iteration The number of updates applied so far.
     * @return True if {@code iteration} is a multiple of a positive logging interval.
     */
    private boolean shouldLog(int iteration) {
        return this.loggingInterval > 0 && iteration % this.loggingInterval == 0;
    }

    /**
     * Returns the number of times {@link #train(SequenceDataset, Map)} has been invoked.
     *
     * @return The invocation count.
     */
    public int getInvocationCount() {
        return this.trainInvocationCounter;
    }

    /**
     * Describes the trainer by its optimiser, epochs, minibatch size and seed.
     *
     * @return A string description of this trainer.
     */
    @Override
    public String toString() {
        return "CRFTrainer(optimiser=" + this.optimiser.toString() + ",epochs=" + this.epochs + ",minibatchSize=" + this.minibatchSize + ",seed=" + this.seed + ")";
    }

    /**
     * Builds the provenance object describing this trainer.
     *
     * @return The trainer provenance.
     */
    public TrainerProvenance getProvenance() {
        return new TrainerProvenanceImpl(this);
    }
}

