/*
 * Decompiled with CFR 0.152.
 */
package cc.mallet.topics;

import cc.mallet.topics.TopicAssignment;
import cc.mallet.types.FeatureSequence;
import cc.mallet.types.LabelSequence;
import cc.mallet.util.Randoms;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.concurrent.Callable;

/**
 * Gibbs-sampling worker that resamples the topic of every token in a contiguous
 * range of documents and reports how many assignments changed.
 *
 * <p>Per-type topic counts are stored packed: each {@code int} of
 * {@code typeTopicCounts[type]} holds a count in its high bits and a topic id in its
 * low {@link #topicBits} bits ({@code (count << topicBits) + topic}). Non-zero entries
 * come first and are kept in (mostly) descending order by swapping an entry towards
 * the front whenever its value grows; a zero entry marks the end of the used region.
 *
 * <p>For each token the sampling mass is split into three parts that are maintained
 * incrementally: {@link #smoothingOnlyMass} (sum of {@code alpha*beta/(n_t+betaSum)}),
 * a per-document {@code topicBetaMass} (sum of {@code beta*n_td/(n_t+betaSum)}) and a
 * per-token {@code topicTermMass} (sum of {@code cachedCoefficients[t] * n_wt}).
 *
 * <p>This class performs no synchronization of its own.
 */
public class WorkerCallable
implements Callable<Integer> {
    /** Documents with their current topic assignments. */
    ArrayList<TopicAssignment> data;
    /** Index in {@link #data} of the first document handled by this worker. */
    int startDoc;
    /** Maximum number of documents, starting at {@link #startDoc}, handled by this worker. */
    int numDocs;
    protected int numTopics;
    /** All-ones mask covering the low {@link #topicBits} bits of a packed entry. */
    protected int topicMask;
    /** Number of low bits of a packed entry that hold the topic id. */
    protected int topicBits;
    /** Number of rows of {@link #typeTopicCounts}. */
    protected int numTypes;
    protected double[] alpha;
    protected double alphaSum;
    protected double beta;
    /** Set to {@code beta * numTypes} by the constructor; replaceable via {@link #resetBeta}. */
    protected double betaSum;
    public static final double DEFAULT_BETA = 0.01;
    /** Running value of {@code sum_t alpha[t] * beta / (tokensPerTopic[t] + betaSum)}. */
    protected double smoothingOnlyMass = 0.0;
    /**
     * Per-topic factor {@code (alpha[t] + n_td) / (tokensPerTopic[t] + betaSum)}, where
     * {@code n_td} is the count of topic {@code t} in the document being sampled (zero
     * between documents).
     */
    protected double[] cachedCoefficients;
    /** Packed per-type topic counts; see the class comment for the layout. */
    protected int[][] typeTopicCounts;
    protected int[] tokensPerTopic;
    /** Histogram of document lengths, filled only while state saving is enabled. */
    protected int[] docLengthCounts;
    /** {@code topicDocCounts[t][c]}: number of sampled documents in which topic {@code t} had count {@code c}. */
    protected int[][] topicDocCounts;
    boolean shouldSaveState = false;
    boolean shouldBuildLocalCounts = true;
    protected Randoms random;

    /** Creates an uninitialised worker; every field keeps its default value. */
    public WorkerCallable() {
    }

    /**
     * Creates a worker for the documents {@code [startDoc, startDoc + numDocs)} of {@code data}.
     *
     * <p>The array arguments are stored by reference, not copied, and are modified
     * in place while sampling.
     *
     * @param numTopics       number of topics
     * @param alpha           per-topic document-topic smoothing, indexed by topic
     * @param alphaSum        stored as given; not recomputed from {@code alpha}
     * @param beta            topic-word smoothing; {@code betaSum} becomes {@code beta * typeTopicCounts.length}
     * @param random          source of the uniform draws used for sampling
     * @param data            documents and their topic assignments
     * @param typeTopicCounts packed per-type topic counts (see class comment)
     * @param tokensPerTopic  number of tokens currently assigned to each topic
     * @param startDoc        index of the first document to process
     * @param numDocs         maximum number of documents to process
     */
    public WorkerCallable(int numTopics, double[] alpha, double alphaSum, double beta, Randoms random, ArrayList<TopicAssignment> data, int[][] typeTopicCounts, int[] tokensPerTopic, int startDoc, int numDocs) {
        this.data = data;
        this.numTopics = numTopics;
        this.numTypes = typeTopicCounts.length;
        // Smallest all-ones mask that can hold every topic id in [0, numTopics).
        if (Integer.bitCount(numTopics) == 1) {
            this.topicMask = numTopics - 1;
        } else {
            this.topicMask = Integer.highestOneBit(numTopics) * 2 - 1;
        }
        this.topicBits = Integer.bitCount(this.topicMask);
        this.typeTopicCounts = typeTopicCounts;
        this.tokensPerTopic = tokensPerTopic;
        this.alphaSum = alphaSum;
        this.alpha = alpha;
        this.beta = beta;
        this.betaSum = beta * (double)this.numTypes;
        this.random = random;
        this.startDoc = startDoc;
        this.numDocs = numDocs;
        this.cachedCoefficients = new double[numTopics];
    }

    /** Makes {@link #call()} skip the rebuild of the count tables after sampling. */
    public void makeOnlyThread() {
        this.shouldBuildLocalCounts = false;
    }

    /** @return the live (not copied) tokens-per-topic array */
    public int[] getTokensPerTopic() {
        return this.tokensPerTopic;
    }

    /** @return the live (not copied) packed type-topic count table */
    public int[][] getTypeTopicCounts() {
        return this.typeTopicCounts;
    }

    /** @return the document-length histogram, or {@code null} if never initialised */
    public int[] getDocLengthCounts() {
        return this.docLengthCounts;
    }

    /** @return the per-topic document-count histograms, or {@code null} if never initialised */
    public int[][] getTopicDocCounts() {
        return this.topicDocCounts;
    }

    /**
     * Allocates fresh, zeroed histograms for the statistics recorded while state saving
     * is enabled.
     *
     * @param size length of each histogram; it is indexed by document length and by
     *             per-document topic count, so it must exceed the longest document
     */
    public void initializeAlphaStatistics(int size) {
        this.docLengthCounts = new int[size];
        this.topicDocCounts = new int[this.numTopics][size];
    }

    /**
     * Enables recording of document-length and topic-count histograms for every document
     * sampled from now on. Nothing in this class switches the flag off again.
     */
    public void collectAlphaStatistics() {
        this.shouldSaveState = true;
    }

    /** Replaces the topic-word smoothing parameters used by subsequent sampling. */
    public void resetBeta(double beta, double betaSum) {
        this.beta = beta;
        this.betaSum = betaSum;
    }

    /**
     * Clears {@link #tokensPerTopic} and {@link #typeTopicCounts} and recounts them from
     * the topic assignments of this worker's own document range. Tokens whose topic is
     * {@code -1} are skipped.
     *
     * @throws IllegalStateException if a type's packed table is full and has no entry for
     *                               the topic being counted
     */
    public void buildLocalTypeTopicCounts() {
        Arrays.fill(this.tokensPerTopic, 0);
        for (int[] topicCounts : this.typeTopicCounts) {
            Arrays.fill(topicCounts, 0);
        }
        for (int doc = this.startDoc; doc < this.data.size() && doc < this.startDoc + this.numDocs; ++doc) {
            TopicAssignment document = this.data.get(doc);
            FeatureSequence tokens = (FeatureSequence)document.instance.getData();
            LabelSequence topicSequence = document.topicSequence;
            int[] topics = topicSequence.getFeatures();
            for (int position = 0; position < tokens.size(); ++position) {
                int topic = topics[position];
                if (topic == -1) {
                    continue;
                }
                this.tokensPerTopic[topic]++;
                int type = tokens.getIndexAtPosition(position);
                int[] currentTypeTopicCounts = this.typeTopicCounts[type];

                // Find the entry for this topic, or the first empty slot.
                int index = 0;
                int currentTopic = currentTypeTopicCounts[index] & this.topicMask;
                while (currentTypeTopicCounts[index] > 0 && currentTopic != topic) {
                    if (++index == currentTypeTopicCounts.length) {
                        System.out.println("overflow on type " + type + " for topic " + topic);
                        StringBuilder out = new StringBuilder();
                        for (int value : currentTypeTopicCounts) {
                            out.append(value + " ");
                        }
                        System.out.println(out);
                        // Fail with context instead of indexing one past the end.
                        throw new IllegalStateException("type-topic count table is full for type " + type + "; no slot for topic " + topic);
                    }
                    currentTopic = currentTypeTopicCounts[index] & this.topicMask;
                }

                int currentValue = currentTypeTopicCounts[index] >> this.topicBits;
                if (currentValue == 0) {
                    // New topic for this type: claim the empty slot.
                    currentTypeTopicCounts[index] = (1 << this.topicBits) + topic;
                } else {
                    currentTypeTopicCounts[index] = ((currentValue + 1) << this.topicBits) + topic;
                    bubbleUp(currentTypeTopicCounts, index);
                }
            }
        }
    }

    /**
     * Samples every document in this worker's range once.
     *
     * <p>Recomputes {@link #smoothingOnlyMass} and resets {@link #cachedCoefficients}
     * from the current {@link #tokensPerTopic}, samples each document, and then, unless
     * {@link #makeOnlyThread()} was called, rebuilds the count tables via
     * {@link #buildLocalTypeTopicCounts()}.
     *
     * @return the number of tokens whose topic assignment changed
     */
    @Override
    public Integer call() throws Exception {
        this.smoothingOnlyMass = 0.0;
        for (int topic = 0; topic < this.numTopics; ++topic) {
            this.smoothingOnlyMass += this.alpha[topic] * this.beta / ((double)this.tokensPerTopic[topic] + this.betaSum);
            this.cachedCoefficients[topic] = this.alpha[topic] / ((double)this.tokensPerTopic[topic] + this.betaSum);
        }
        int changed = 0;
        for (int doc = this.startDoc; doc < this.data.size() && doc < this.startDoc + this.numDocs; ++doc) {
            TopicAssignment document = this.data.get(doc);
            FeatureSequence tokenSequence = (FeatureSequence)document.instance.getData();
            LabelSequence topicSequence = document.topicSequence;
            changed += this.sampleTopicsForOneDoc(tokenSequence, topicSequence, true);
        }
        if (this.shouldBuildLocalCounts) {
            this.buildLocalTypeTopicCounts();
        }
        return changed;
    }

    /**
     * Resamples the topic of every token of one document, updating
     * {@code topicSequence}, {@link #typeTopicCounts}, {@link #tokensPerTopic} and the
     * cached masses in place.
     *
     * <p>Rounding in the incrementally maintained masses can make a bucket walk run off
     * its end; the walks are therefore clamped to their last valid candidate, and if no
     * topic is selected at all the token falls back to topic {@code numTopics - 1}
     * (reported on {@code System.err}) before any count is updated.
     *
     * @param tokenSequence            type index of each token
     * @param topicSequence            current topic of each token ({@code -1} = unassigned); overwritten
     * @param readjustTopicsAndStats   not read by this implementation
     * @return the number of tokens whose topic changed
     * @throws IllegalStateException if a type's packed table is full and has no entry for
     *                               the newly sampled topic
     */
    protected int sampleTopicsForOneDoc(FeatureSequence tokenSequence, FeatureSequence topicSequence, boolean readjustTopicsAndStats) {
        int[] oneDocTopics = topicSequence.getFeatures();
        int docLength = tokenSequence.getLength();

        // Per-document topic counts plus an ascending dense list of the topics in use.
        int[] localTopicCounts = new int[this.numTopics];
        int[] localTopicIndex = new int[this.numTopics];
        for (int position = 0; position < docLength; ++position) {
            if (oneDocTopics[position] == -1) {
                continue;
            }
            localTopicCounts[oneDocTopics[position]]++;
        }
        int denseIndex = 0;
        for (int topic = 0; topic < this.numTopics; ++topic) {
            if (localTopicCounts[topic] != 0) {
                localTopicIndex[denseIndex] = topic;
                ++denseIndex;
            }
        }
        int nonZeroTopics = denseIndex;

        // Document-specific mass and coefficients for the topics present in this document.
        double topicBetaMass = 0.0;
        for (denseIndex = 0; denseIndex < nonZeroTopics; ++denseIndex) {
            int topic = localTopicIndex[denseIndex];
            int count = localTopicCounts[topic];
            topicBetaMass += this.beta * (double)count / ((double)this.tokensPerTopic[topic] + this.betaSum);
            this.cachedCoefficients[topic] = (this.alpha[topic] + (double)count) / ((double)this.tokensPerTopic[topic] + this.betaSum);
        }

        double topicTermMass;
        double[] topicTermScores = new double[this.numTopics];
        int changed = 0;
        for (int position = 0; position < docLength; ++position) {
            int type = tokenSequence.getIndexAtPosition(position);
            int oldTopic = oneDocTopics[position];
            int[] currentTypeTopicCounts = this.typeTopicCounts[type];

            if (oldTopic != -1) {
                // Take the token out of every statistic except the type-topic table,
                // which is handled during the scan below.
                this.smoothingOnlyMass -= this.alpha[oldTopic] * this.beta / ((double)this.tokensPerTopic[oldTopic] + this.betaSum);
                topicBetaMass -= this.beta * (double)localTopicCounts[oldTopic] / ((double)this.tokensPerTopic[oldTopic] + this.betaSum);
                localTopicCounts[oldTopic]--;
                if (localTopicCounts[oldTopic] == 0) {
                    // Topic vanished from the document: close the gap in the dense list.
                    denseIndex = 0;
                    while (localTopicIndex[denseIndex] != oldTopic) {
                        ++denseIndex;
                    }
                    while (denseIndex < nonZeroTopics) {
                        if (denseIndex < localTopicIndex.length - 1) {
                            localTopicIndex[denseIndex] = localTopicIndex[denseIndex + 1];
                        }
                        ++denseIndex;
                    }
                    --nonZeroTopics;
                }
                this.tokensPerTopic[oldTopic]--;
                assert (this.tokensPerTopic[oldTopic] >= 0) : "old Topic " + oldTopic + " below 0";
                this.smoothingOnlyMass += this.alpha[oldTopic] * this.beta / ((double)this.tokensPerTopic[oldTopic] + this.betaSum);
                topicBetaMass += this.beta * (double)localTopicCounts[oldTopic] / ((double)this.tokensPerTopic[oldTopic] + this.betaSum);
                this.cachedCoefficients[oldTopic] = (this.alpha[oldTopic] + (double)localTopicCounts[oldTopic]) / ((double)this.tokensPerTopic[oldTopic] + this.betaSum);
            }

            // Single pass over the non-zero entries of this type: decrement the old
            // topic's entry once, and score every entry for the topic-term bucket.
            int index = 0;
            boolean alreadyDecremented = oldTopic == -1;
            topicTermMass = 0.0;
            while (index < currentTypeTopicCounts.length && currentTypeTopicCounts[index] > 0) {
                int currentTopic = currentTypeTopicCounts[index] & this.topicMask;
                int currentValue = currentTypeTopicCounts[index] >> this.topicBits;
                if (!alreadyDecremented && currentTopic == oldTopic) {
                    --currentValue;
                    currentTypeTopicCounts[index] = currentValue == 0 ? 0 : (currentValue << this.topicBits) + oldTopic;
                    // Sink the shrunken entry below any larger neighbours.
                    for (int subIndex = index; subIndex < currentTypeTopicCounts.length - 1 && currentTypeTopicCounts[subIndex] < currentTypeTopicCounts[subIndex + 1]; ++subIndex) {
                        int swapped = currentTypeTopicCounts[subIndex];
                        currentTypeTopicCounts[subIndex] = currentTypeTopicCounts[subIndex + 1];
                        currentTypeTopicCounts[subIndex + 1] = swapped;
                    }
                    alreadyDecremented = true;
                    // Do not advance: whatever now sits at this index still needs scoring.
                } else {
                    double score = this.cachedCoefficients[currentTopic] * (double)currentValue;
                    topicTermMass += score;
                    topicTermScores[index] = score;
                    ++index;
                }
            }
            // Entries of topicTermScores at or beyond this point are stale from earlier tokens.
            int scoredEntries = index;

            double sample = this.random.nextUniform() * (this.smoothingOnlyMass + topicBetaMass + topicTermMass);
            double origSample = sample;
            int newTopic = -1;
            if (sample < topicTermMass) {
                // Topic-term bucket. Entry 0 is always consumed so that a draw of exactly
                // 0.0 selects it, and the walk never leaves the entries scored above.
                int i = 0;
                sample -= topicTermScores[0];
                while (sample > 0.0 && i < scoredEntries - 1) {
                    ++i;
                    sample -= topicTermScores[i];
                }
                newTopic = currentTypeTopicCounts[i] & this.topicMask;
                int currentValue = currentTypeTopicCounts[i] >> this.topicBits;
                currentTypeTopicCounts[i] = ((currentValue + 1) << this.topicBits) + newTopic;
                bubbleUp(currentTypeTopicCounts, i);
            } else {
                sample -= topicTermMass;
                if (sample < topicBetaMass) {
                    // Document-topic bucket: walk the topics present in this document.
                    sample /= this.beta;
                    for (denseIndex = 0; denseIndex < nonZeroTopics; ++denseIndex) {
                        int topic = localTopicIndex[denseIndex];
                        sample -= (double)localTopicCounts[topic] / ((double)this.tokensPerTopic[topic] + this.betaSum);
                        if (sample <= 0.0) {
                            newTopic = topic;
                            break;
                        }
                    }
                } else {
                    // Smoothing-only bucket: walk all topics, stopping at the last one
                    // even if rounding leaves a sliver of mass unaccounted for.
                    sample -= topicBetaMass;
                    sample /= this.beta;
                    newTopic = 0;
                    sample -= this.alpha[newTopic] / ((double)this.tokensPerTopic[newTopic] + this.betaSum);
                    while (sample > 0.0 && newTopic < this.numTopics - 1) {
                        ++newTopic;
                        sample -= this.alpha[newTopic] / ((double)this.tokensPerTopic[newTopic] + this.betaSum);
                    }
                }

                // Resolve a failed draw BEFORE touching the type-topic table, so the
                // table is never searched or written with topic -1.
                if (newTopic == -1) {
                    System.err.println("WorkerCallable sampling error: " + origSample + " " + sample + " " + this.smoothingOnlyMass + " " + topicBetaMass + " " + topicTermMass);
                    newTopic = this.numTopics - 1;
                }

                // Find the entry for the new topic, or the first empty slot.
                index = 0;
                while (currentTypeTopicCounts[index] > 0 && (currentTypeTopicCounts[index] & this.topicMask) != newTopic) {
                    if (++index == currentTypeTopicCounts.length) {
                        System.err.println("type: " + type + " new topic: " + newTopic);
                        for (int k = 0; k < currentTypeTopicCounts.length; ++k) {
                            System.err.print((currentTypeTopicCounts[k] & this.topicMask) + ":" + (currentTypeTopicCounts[k] >> this.topicBits) + " ");
                        }
                        System.err.println();
                        // Fail with context instead of indexing one past the end.
                        throw new IllegalStateException("type-topic count table is full for type " + type + "; no slot for topic " + newTopic);
                    }
                }
                if (currentTypeTopicCounts[index] == 0) {
                    currentTypeTopicCounts[index] = (1 << this.topicBits) + newTopic;
                } else {
                    int currentValue = currentTypeTopicCounts[index] >> this.topicBits;
                    currentTypeTopicCounts[index] = ((currentValue + 1) << this.topicBits) + newTopic;
                    bubbleUp(currentTypeTopicCounts, index);
                }
            }

            // Put the token back into the statistics under its new topic.
            oneDocTopics[position] = newTopic;
            this.smoothingOnlyMass -= this.alpha[newTopic] * this.beta / ((double)this.tokensPerTopic[newTopic] + this.betaSum);
            topicBetaMass -= this.beta * (double)localTopicCounts[newTopic] / ((double)this.tokensPerTopic[newTopic] + this.betaSum);
            localTopicCounts[newTopic]++;
            if (localTopicCounts[newTopic] == 1) {
                // Topic is new to the document: insert it, keeping the dense list ascending.
                for (denseIndex = nonZeroTopics; denseIndex > 0 && localTopicIndex[denseIndex - 1] > newTopic; --denseIndex) {
                    localTopicIndex[denseIndex] = localTopicIndex[denseIndex - 1];
                }
                localTopicIndex[denseIndex] = newTopic;
                ++nonZeroTopics;
            }
            this.tokensPerTopic[newTopic]++;
            this.cachedCoefficients[newTopic] = (this.alpha[newTopic] + (double)localTopicCounts[newTopic]) / ((double)this.tokensPerTopic[newTopic] + this.betaSum);
            this.smoothingOnlyMass += this.alpha[newTopic] * this.beta / ((double)this.tokensPerTopic[newTopic] + this.betaSum);
            topicBetaMass += this.beta * (double)localTopicCounts[newTopic] / ((double)this.tokensPerTopic[newTopic] + this.betaSum);
            if (newTopic != oldTopic) {
                ++changed;
            }
        }

        if (this.shouldSaveState) {
            this.docLengthCounts[docLength]++;
            for (denseIndex = 0; denseIndex < nonZeroTopics; ++denseIndex) {
                int topic = localTopicIndex[denseIndex];
                this.topicDocCounts[topic][localTopicCounts[topic]]++;
            }
        }

        // Strip the document-specific counts from the coefficients again.
        for (denseIndex = 0; denseIndex < nonZeroTopics; ++denseIndex) {
            int topic = localTopicIndex[denseIndex];
            this.cachedCoefficients[topic] = this.alpha[topic] / ((double)this.tokensPerTopic[topic] + this.betaSum);
        }
        return changed;
    }

    /** Swaps the entry at {@code index} towards the front while it exceeds its predecessor. */
    private static void bubbleUp(int[] packedCounts, int index) {
        while (index > 0 && packedCounts[index] > packedCounts[index - 1]) {
            int swapped = packedCounts[index];
            packedCounts[index] = packedCounts[index - 1];
            packedCounts[index - 1] = swapped;
            --index;
        }
    }
}

