/*
 * Decompiled with CFR 0.152.
 */
package org.encog.ml.hmm.train.kmeans;

import java.util.Collection;
import java.util.Collections;
import java.util.List;
import org.encog.ml.MLMethod;
import org.encog.ml.TrainingImplementationType;
import org.encog.ml.data.MLDataPair;
import org.encog.ml.data.MLDataSet;
import org.encog.ml.data.MLSequenceSet;
import org.encog.ml.data.basic.BasicMLDataSet;
import org.encog.ml.hmm.HiddenMarkovModel;
import org.encog.ml.hmm.alog.ViterbiCalculator;
import org.encog.ml.hmm.distributions.StateDistribution;
import org.encog.ml.hmm.train.kmeans.Clusters;
import org.encog.ml.train.MLTrain;
import org.encog.ml.train.strategy.Strategy;
import org.encog.neural.networks.training.propagation.TrainingContinuation;

/**
 * Trains a {@link HiddenMarkovModel} with a k-means style (segmental) procedure.
 * <p>
 * Every observation of every training sequence is assigned to one of the
 * model's states (a cluster). Each call to {@link #iteration()} does three things:
 * <ul>
 * <li>It builds a fresh model from the template's structure.</li>
 * <li>It re-estimates the initial-state probabilities, transition probabilities
 * and per-state observation distributions from the current assignment.</li>
 * <li>It re-assigns each observation to the state on the Viterbi path computed
 * with the new model.</li>
 * </ul>
 * Training is reported as done once a pass leaves every assignment unchanged.
 * <p>
 * Strategies, pause/resume and explicit error values are not supported; the
 * corresponding methods are no-ops or return empty results.
 */
public class TrainKMeans
implements MLTrain {
    /** Current assignment of observations to states. */
    private final Clusters clusters;
    /** Number of hidden states in the model being trained. */
    private final int states;
    /** The training sequences; also returned by {@link #getTraining()}. */
    private final MLSequenceSet sequences;
    /** Template model; each iteration clones its structure to build a new model. */
    private final HiddenMarkovModel modelHMM;
    /** Most recently trained model (the template until the first iteration). */
    private HiddenMarkovModel method;
    /** True once an iteration left every cluster assignment unchanged. */
    private boolean done;
    /** Number of iterations performed (or the value last set explicitly). */
    private int iteration;

    /**
     * Creates a trainer for the given model and training data.
     *
     * @param method    the model whose structure (state count, distribution type)
     *                  is used as the template for training
     * @param sequences the observation sequences to train on
     */
    public TrainKMeans(HiddenMarkovModel method, MLSequenceSet sequences) {
        this.method = method;
        this.modelHMM = method;
        this.sequences = sequences;
        this.states = method.getStateCount();
        this.clusters = new Clusters(this.states, sequences);
        this.done = false;
    }

    /** Strategies are not supported by this trainer; the argument is ignored. */
    @Override
    public void addStrategy(Strategy strategy) {
    }

    /** Always {@code false}: this trainer cannot be paused and resumed. */
    @Override
    public boolean canContinue() {
        return false;
    }

    /** No cleanup is required. */
    @Override
    public void finishTraining() {
    }

    /**
     * Returns a placeholder error: {@code 0.0} once the cluster assignment has
     * converged, {@code 100.0} otherwise.
     */
    @Override
    public double getError() {
        return this.done ? 0.0 : 100.0;
    }

    @Override
    public TrainingImplementationType getImplementationType() {
        return TrainingImplementationType.Iterative;
    }

    @Override
    public int getIteration() {
        return this.iteration;
    }

    /** Returns the most recently trained model, or the template before training. */
    @Override
    public MLMethod getMethod() {
        return this.method;
    }

    /**
     * Returns an empty list, because {@link #addStrategy(Strategy)} does not
     * retain strategies. This avoids returning {@code null} to callers that
     * iterate the result.
     */
    @Override
    public List<Strategy> getStrategies() {
        return Collections.emptyList();
    }

    @Override
    public MLDataSet getTraining() {
        return this.sequences;
    }

    @Override
    public boolean isTrainingDone() {
        return this.done;
    }

    /**
     * Performs one k-means pass: re-estimates a fresh model from the current
     * cluster assignment, then re-clusters observations by their Viterbi states.
     */
    @Override
    public void iteration() {
        HiddenMarkovModel hmm = this.modelHMM.cloneStructure();
        this.learnPi(hmm);
        this.learnTransition(hmm);
        this.learnOpdf(hmm);
        this.done = this.optimizeCluster(hmm);
        this.method = hmm;
        this.iteration++;
    }

    /**
     * Performs up to {@code count} iterations. It stops early once training is
     * done, because further passes would rebuild the model from the same,
     * unchanged assignment.
     *
     * @param count maximum number of iterations to run; non-positive runs none
     */
    @Override
    public void iteration(int count) {
        for (int i = 0; i < count && !this.done; ++i) {
            this.iteration();
        }
    }

    /**
     * Fits each state's observation distribution to the observations currently
     * assigned to that state. A state with no observations gets a new distribution
     * from the template model.
     */
    private void learnOpdf(HiddenMarkovModel hmm) {
        for (int i = 0; i < hmm.getStateCount(); ++i) {
            Collection<MLDataPair> clusterObservations = this.clusters.cluster(i);
            if (clusterObservations.isEmpty()) {
                StateDistribution o = this.modelHMM.createNewDistribution();
                hmm.setStateDistribution(i, o);
                continue;
            }
            BasicMLDataSet temp = new BasicMLDataSet();
            for (MLDataPair pair : clusterObservations) {
                temp.add(pair);
            }
            hmm.getStateDistribution(i).fit(temp);
        }
    }

    /**
     * Estimates the initial-state probabilities as the relative frequency of
     * each state among the first observations of the training sequences.
     * <p>
     * Empty sequences are skipped, since they have no first observation. The
     * counts are normalised by the number of sequences actually counted, so pi
     * sums to 1. If no sequence contributes, a uniform distribution is used
     * instead of dividing by zero.
     */
    private void learnPi(HiddenMarkovModel hmm) {
        double[] counts = new double[this.states];
        int counted = 0;
        for (MLDataSet sequence : this.sequences.getSequences()) {
            if (sequence.size() < 1) {
                continue;
            }
            counts[this.clusters.cluster(sequence.get(0))] += 1.0;
            counted++;
        }
        for (int i = 0; i < this.states; ++i) {
            double p = counted == 0
                ? 1.0 / (double) this.states
                : counts[i] / (double) counted;
            hmm.setPi(i, p);
        }
    }

    /**
     * Estimates transition probabilities by counting consecutive state pairs in
     * the current assignment and normalising each row. A state that is never
     * left gets a uniform outgoing distribution.
     */
    private void learnTransition(HiddenMarkovModel hmm) {
        int stateCount = hmm.getStateCount();
        for (int i = 0; i < stateCount; ++i) {
            for (int j = 0; j < stateCount; ++j) {
                hmm.setTransitionProbability(i, j, 0.0);
            }
        }
        // The model's transition matrix is used as the count accumulator.
        for (MLDataSet obsSeq : this.sequences.getSequences()) {
            if (obsSeq.size() < 2) {
                continue;
            }
            int secondState = this.clusters.cluster(obsSeq.get(0));
            for (int i = 1; i < obsSeq.size(); ++i) {
                int firstState = secondState;
                secondState = this.clusters.cluster(obsSeq.get(i));
                hmm.setTransitionProbability(firstState, secondState,
                    hmm.getTransitionProbability(firstState, secondState) + 1.0);
            }
        }
        for (int i = 0; i < stateCount; ++i) {
            double sum = 0.0;
            for (int j = 0; j < stateCount; ++j) {
                sum += hmm.getTransitionProbability(i, j);
            }
            for (int j = 0; j < stateCount; ++j) {
                double p = sum == 0.0
                    ? 1.0 / (double) stateCount
                    : hmm.getTransitionProbability(i, j) / sum;
                hmm.setTransitionProbability(i, j, p);
            }
        }
    }

    /**
     * Re-assigns every observation to the state on its sequence's Viterbi path
     * under {@code hmm}.
     *
     * @return {@code true} if no assignment changed (training has converged)
     */
    private boolean optimizeCluster(HiddenMarkovModel hmm) {
        boolean modified = false;
        for (MLDataSet obsSeq : this.sequences.getSequences()) {
            ViterbiCalculator vc = new ViterbiCalculator(obsSeq, hmm);
            int[] path = vc.stateSequence();
            for (int i = 0; i < path.length; ++i) {
                MLDataPair o = obsSeq.get(i);
                int current = this.clusters.cluster(o);
                if (current == path[i]) {
                    continue;
                }
                modified = true;
                this.clusters.remove(o, current);
                this.clusters.put(o, path[i]);
            }
        }
        return !modified;
    }

    /** Pausing is not supported; always returns {@code null}. */
    @Override
    public TrainingContinuation pause() {
        return null;
    }

    /** Resuming is not supported; the argument is ignored. */
    @Override
    public void resume(TrainingContinuation state) {
    }

    /** The error is derived from convergence state; explicit values are ignored. */
    @Override
    public void setError(double error) {
    }

    @Override
    public void setIteration(int iteration) {
        this.iteration = iteration;
    }
}

