/*
 * Decompiled with CFR 0.152.
 */
package phase;

import blbutil.FloatArray;
import blbutil.FloatList;
import ints.IntArray;
import ints.IntList;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.concurrent.atomic.AtomicLong;
import phase.BasicPhaseStates;
import phase.EstPhase;
import phase.HmmUpdater;
import phase.MarkerCluster;
import phase.PbwtPhaseIbs;
import phase.PhaseData;
import phase.SamplePhase;
import vcf.Markers;

/**
 * Runs forward and backward HMM passes for one target sample at a time in
 * order to (a) decide, at each unphased heterozygote cluster, whether the two
 * haplotypes should be swapped, and (b) impute alleles at clusters that
 * contain a missing genotype.
 *
 * <p>Three parallel sets of HMM values are maintained (index 0, 1 and 2 of
 * {@code fwd}, {@code bwd}, {@code sum} and {@code mismatches}). Indices 1
 * and 2 are the ones exchanged when haplotypes are swapped and are the ones
 * used to impute allele 1 and allele 2 respectively. Index 0 is the set that
 * seeds the other two at the start of each segment.
 * NOTE(review): index 0 presumably corresponds to the combined/unphased
 * genotype data produced by {@code BasicPhaseStates.ibsStates} - confirm.
 *
 * <p>Instances keep mutable scratch state in unsynchronized instance fields,
 * so a single instance must not be used by several threads at once. The
 * swap-rate counters are static {@code AtomicLong}s shared by all instances.
 */
public class PhaseBaum1 {
    private final PhaseData phaseData;
    /** True when the current iteration index is below the configured burn-in count. */
    private final boolean burnin;
    private final EstPhase estPhase;
    private final Markers markers;
    private final int nMarkers;
    /** Per missing-genotype cluster: an allele index for each HMM state. */
    private final List<int[]> refAlleles;
    /** Indexed as [hmm set 0..2][cluster][state]; filled by {@code states.ibsStates}. */
    private final byte[][][] mismatches;
    private final float pMismatch;
    /** Scratch {no-mismatch, mismatch} probabilities; overwritten for every cluster. */
    private final float[] emProbs;
    private final int maxStates;
    private final BasicPhaseStates states;
    /** One likelihood ratio per processed unphased cluster, in forward order. */
    private final FloatList lrList;
    /** Number of HMM states in use for the sample currently being phased. */
    private int nStates;
    private final float[][] fwd;
    private final float[][] bwd;
    private final float[] sum;
    /** Backward values saved at each missing-genotype cluster (HMM set 1). */
    private final List<float[]> missProbs1;
    /** Backward values saved at each missing-genotype cluster (HMM set 2). */
    private final List<float[]> missProbs2;
    /** Backward values of HMM set 1 saved at each unphased cluster. */
    private final List<float[]> bwdHet1;
    /** Backward values of HMM set 2 saved at each unphased cluster. */
    private final List<float[]> bwdHet2;
    /** Whether haplotypes 1 and 2 are currently exchanged relative to the input phase. */
    private boolean swapHaps = false;
    /** Cursor into the missing-genotype cluster lists; walks down in bwd, up in fwd. */
    private int missIndex = -1;
    /** Number of times {@code swapHaps} changed value for the current sample. */
    private int swapCnt = 0;
    private static final AtomicLong nSwaps = new AtomicLong(0L);
    private static final AtomicLong nUnphHets = new AtomicLong(0L);

    /**
     * Returns the ratio of accumulated swaps to accumulated unphased
     * heterozygotes since the previous call, and resets both counters to zero.
     *
     * <p>Each counter is read and cleared in one atomic step so that
     * increments made concurrently by {@link #phase(int)} are never dropped.
     *
     * @return swaps divided by unphased heterozygotes; {@code NaN} when both
     *         counters are zero (floating-point 0/0)
     */
    public static double getAndResetSwapRate() {
        long swaps = nSwaps.getAndSet(0L);
        long unphHets = nUnphHets.getAndSet(0L);
        return (double) swaps / (double) unphHets;
    }

    /**
     * Creates a phaser whose scratch buffers are sized from the phase data
     * reachable through the supplied object.
     *
     * @param pbwtPhaseIbs source of the {@code PhaseData} and of the IBS
     *        states used to build the HMM state space
     */
    public PhaseBaum1(PbwtPhaseIbs pbwtPhaseIbs) {
        this.phaseData = pbwtPhaseIbs.phaseData();
        this.burnin = this.phaseData.it() < this.phaseData.fpd().par().burnin();
        this.estPhase = this.phaseData.estPhase();
        this.markers = this.phaseData.fpd().stage1TargGT().markers();
        this.nMarkers = this.markers.size();
        this.maxStates = this.phaseData.fpd().par().phase_states();
        this.states = new BasicPhaseStates(pbwtPhaseIbs, this.maxStates);
        this.lrList = new FloatList(200);
        this.refAlleles = new ArrayList<int[]>();
        this.mismatches = new byte[3][this.nMarkers][this.maxStates];
        this.pMismatch = this.phaseData.pMismatch();
        this.emProbs = new float[]{1.0f - this.pMismatch, this.pMismatch};
        this.fwd = new float[3][this.maxStates];
        this.bwd = new float[3][this.maxStates];
        this.sum = new float[3];
        this.missProbs1 = new ArrayList<float[]>();
        this.missProbs2 = new ArrayList<float[]>();
        this.bwdHet1 = new ArrayList<float[]>();
        this.bwdHet2 = new ArrayList<float[]>();
    }

    /**
     * Returns the number of target samples.
     *
     * @return the number of target samples
     */
    public int nTargSamples() {
        return this.phaseData.fpd().targGT().nSamples();
    }

    /**
     * Updates the estimated phase of the specified sample and adds the
     * sample's unphased-heterozygote and swap counts to the global counters.
     * Nothing is recomputed when the sample has neither unphased nor missing
     * genotypes.
     *
     * @param sample index of the sample passed to {@code EstPhase.get/set}
     */
    public void phase(int sample) {
        SamplePhase samplePhase = this.estPhase.get(sample);
        this.swapHaps = false;
        this.swapCnt = 0;
        int nUnph = samplePhase.unphased().size();
        int nMissing = samplePhase.missing().size();
        if (nMissing > 0 || nUnph > 0) {
            this.lrList.clear();
            this.ensureCapacity(nUnph, nMissing);
            MarkerCluster markerCluster = new MarkerCluster(this.phaseData, sample);
            // The backward pass decrements missIndex from the cluster count,
            // the forward pass then increments it again from where bwd stopped.
            this.missIndex = markerCluster.nMissingGTClusters();
            this.nStates = this.states.ibsStates(sample, markerCluster,
                    this.refAlleles, this.mismatches);
            // Backward first: the forward pass consumes the stored bwdHet*
            // and missProbs* arrays.
            this.bwdAlg(markerCluster);
            this.fwdAlg(samplePhase, markerCluster);
            this.updatePhase(sample, samplePhase);
            this.estPhase.set(sample, samplePhase);
        }
        nUnphHets.addAndGet(nUnph);
        nSwaps.addAndGet(this.swapCnt);
    }

    /**
     * Grows the per-cluster scratch lists so that they hold at least the
     * requested number of {@code maxStates}-length arrays. Lists never shrink.
     *
     * @param nUnph required size of the unphased-cluster lists
     * @param nMissing required size of the missing-genotype lists
     */
    private void ensureCapacity(int nUnph, int nMissing) {
        for (int j = this.refAlleles.size(); j < nMissing; ++j) {
            this.refAlleles.add(new int[this.maxStates]);
            this.missProbs1.add(new float[this.maxStates]);
            this.missProbs2.add(new float[this.maxStates]);
        }
        for (int j = this.bwdHet1.size(); j < nUnph; ++j) {
            this.bwdHet1.add(new float[this.maxStates]);
            this.bwdHet2.add(new float[this.maxStates]);
        }
    }

    /**
     * Runs the forward pass over all clusters, stopping at every unphased
     * cluster to decide the swap state before continuing.
     */
    private void fwdAlg(SamplePhase samplePhase, MarkerCluster markerCluster) {
        IntArray unphClusters = markerCluster.unphClusters();
        Arrays.fill(this.fwd[0], 0, this.nStates, 1.0f / (float) this.nStates);
        this.sum[0] = 1.0f;
        int start = 0;
        for (int j = 0, n = unphClusters.size(); j < n; ++j) {
            int end = unphClusters.get(j);
            this.fwdAlg(samplePhase, markerCluster, start, end);
            this.phaseHet(j);
            start = end;
        }
        // Trailing segment after the last unphased cluster.
        if (start < markerCluster.nClusters()) {
            this.fwdAlg(samplePhase, markerCluster, start, markerCluster.nClusters());
        }
    }

    /**
     * Advances the three forward HMMs over clusters {@code [start, end)}.
     * If the current swap state is set, the segment's haplotype data are
     * exchanged first. HMM sets 1 and 2 are re-seeded from set 0 at the
     * beginning of the segment.
     *
     * @param start first cluster (inclusive)
     * @param end last cluster (exclusive)
     */
    private void fwdAlg(SamplePhase samplePhase, MarkerCluster markerCluster, int start, int end) {
        if (this.swapHaps) {
            this.swapHaps(samplePhase, markerCluster, start, end);
        }
        System.arraycopy(this.fwd[0], 0, this.fwd[1], 0, this.nStates);
        System.arraycopy(this.fwd[0], 0, this.fwd[2], 0, this.nStates);
        this.sum[2] = this.sum[0];
        this.sum[1] = this.sum[0];
        FloatArray pRecomb = markerCluster.pRecomb();
        for (int m = start; m < end; ++m) {
            float pRec = pRecomb.get(m);
            // Mismatch probability scales with the number of markers in the cluster.
            int clusterSize = markerCluster.clusterEnd(m) - markerCluster.clusterStart(m);
            this.emProbs[1] = (float) clusterSize * this.pMismatch;
            this.emProbs[0] = 1.0f - this.emProbs[1];
            for (int h = 0; h < 3; ++h) {
                this.sum[h] = HmmUpdater.fwdUpdate(this.fwd[h], this.sum[h], pRec,
                        this.emProbs, this.mismatches[h][m], this.nStates);
            }
            if (markerCluster.clustHasMissingGT(m)) {
                this.imputeAlleles(samplePhase, markerCluster, m);
            }
        }
    }

    /**
     * Exchanges the haplotype-1 and haplotype-2 mismatch rows for clusters
     * {@code [start, end)} and asks the sample to swap its haplotypes over
     * the corresponding marker interval.
     */
    private void swapHaps(SamplePhase samplePhase, MarkerCluster markerCluster, int start, int end) {
        byte[][] hap1 = this.mismatches[1];
        byte[][] hap2 = this.mismatches[2];
        for (int m = start; m < end; ++m) {
            byte[] tmp = hap1[m];
            hap1[m] = hap2[m];
            hap2[m] = tmp;
        }
        samplePhase.swapHaps(markerCluster.clusterStart(start), markerCluster.clusterEnd(end - 1));
    }

    /**
     * Imputes both alleles at a missing-genotype cluster by combining the
     * stored backward values with the current forward values, then advances
     * {@code missIndex}. Does nothing if the cluster has no missing genotype.
     *
     * @param cluster index of the cluster being processed by the forward pass
     */
    private void imputeAlleles(SamplePhase samplePhase, MarkerCluster markerCluster, int cluster) {
        if (!markerCluster.clustHasMissingGT(cluster)) {
            return;
        }
        if (this.swapHaps) {
            // The backward values were stored before the swap decision was
            // known, so exchange them to match the swapped forward values.
            float[] tmp = this.missProbs1.get(this.missIndex);
            this.missProbs1.set(this.missIndex, this.missProbs2.get(this.missIndex));
            this.missProbs2.set(this.missIndex, tmp);
        }
        float[] probs1 = this.missProbs1.get(this.missIndex);
        float[] probs2 = this.missProbs2.get(this.missIndex);
        int[] stateAlleles = this.refAlleles.get(this.missIndex);
        for (int k = 0; k < this.nStates; ++k) {
            probs1[k] *= this.fwd[1][k];
            probs2[k] *= this.fwd[2][k];
        }
        assert (markerCluster.clusterEnd(cluster) - markerCluster.clusterStart(cluster) == 1);
        int marker = markerCluster.clusterStart(cluster);
        this.imputeAlleles(samplePhase, marker, probs1, probs2, stateAlleles);
        ++this.missIndex;
    }

    /**
     * Sums the state weights per allele and stores, for each haplotype, the
     * allele with the largest total. Ties keep the lowest allele index.
     *
     * @param marker marker whose alleles are set
     * @param probs1 per-state weights for haplotype 1
     * @param probs2 per-state weights for haplotype 2
     * @param stateAlleles allele index carried by each state
     */
    private void imputeAlleles(SamplePhase samplePhase, int marker, float[] probs1, float[] probs2, int[] stateAlleles) {
        int nAlleles = this.markers.marker(marker).nAlleles();
        float[] alProbs1 = new float[nAlleles];
        float[] alProbs2 = new float[nAlleles];
        for (int k = 0; k < this.nStates; ++k) {
            int allele = stateAlleles[k];
            alProbs1[allele] += probs1[k];
            alProbs2[allele] += probs2[k];
        }
        int best1 = 0;
        int best2 = 0;
        for (int a = 1; a < nAlleles; ++a) {
            if (alProbs1[a] > alProbs1[best1]) {
                best1 = a;
            }
            if (alProbs2[a] > alProbs2[best2]) {
                best2 = a;
            }
        }
        samplePhase.setAllele1(marker, best1);
        samplePhase.setAllele2(marker, best2);
    }

    /**
     * Runs the backward pass from the last cluster to the first, saving the
     * haplotype-specific backward values at every unphased cluster and at
     * every missing-genotype cluster for later use by the forward pass.
     */
    private void bwdAlg(MarkerCluster markerCluster) {
        IntArray unphClusters = markerCluster.unphClusters();
        Arrays.fill(this.bwd[0], 0, this.nStates, 1.0f / (float) this.nStates);
        int end = markerCluster.nClusters() - 1;
        if (markerCluster.clustHasMissingGT(end)) {
            // The last cluster is never visited by the segment loop below,
            // so its (initial) backward values are stored here.
            --this.missIndex;
            System.arraycopy(this.bwd[0], 0, this.missProbs1.get(this.missIndex), 0, this.nStates);
            System.arraycopy(this.bwd[0], 0, this.missProbs2.get(this.missIndex), 0, this.nStates);
        }
        for (int j = unphClusters.size() - 1; j >= 0; --j) {
            int start = unphClusters.get(j) - 1;
            assert (start >= 0);
            this.bwdAlg(markerCluster, start, end);
            System.arraycopy(this.bwd[1], 0, this.bwdHet1.get(j), 0, this.nStates);
            System.arraycopy(this.bwd[2], 0, this.bwdHet2.get(j), 0, this.nStates);
            end = start;
        }
        this.bwdAlg(markerCluster, 0, end);
    }

    /**
     * Moves the three backward HMMs from cluster {@code end} down to cluster
     * {@code start}. HMM sets 1 and 2 are re-seeded from set 0 at the
     * beginning of the segment.
     *
     * @param start lowest cluster reached (inclusive)
     * @param end cluster whose backward values are held on entry
     */
    private void bwdAlg(MarkerCluster markerCluster, int start, int end) {
        FloatArray pRecomb = markerCluster.pRecomb();
        System.arraycopy(this.bwd[0], 0, this.bwd[1], 0, this.nStates);
        System.arraycopy(this.bwd[0], 0, this.bwd[2], 0, this.nStates);
        for (int m = end - 1; m >= start; --m) {
            int mP1 = m + 1;
            float pRec = pRecomb.get(mP1);
            // NOTE(review): the mismatch probability is derived from the size
            // of cluster m while the mismatch flags are those of cluster
            // m + 1 - confirm this asymmetry is intended.
            int clusterSize = markerCluster.clusterEnd(m) - markerCluster.clusterStart(m);
            this.emProbs[1] = (float) clusterSize * this.pMismatch;
            this.emProbs[0] = 1.0f - this.emProbs[1];
            for (int h = 0; h < 3; ++h) {
                HmmUpdater.bwdUpdate(this.bwd[h], pRec, this.emProbs,
                        this.mismatches[h][mP1], this.nStates);
            }
            if (markerCluster.clustHasMissingGT(m)) {
                --this.missIndex;
                System.arraycopy(this.bwd[1], 0, this.missProbs1.get(this.missIndex), 0, this.nStates);
                System.arraycopy(this.bwd[2], 0, this.missProbs2.get(this.missIndex), 0, this.nStates);
            }
        }
    }

    /**
     * Decides the swap state at an unphased cluster by comparing the
     * product of the "straight" forward/backward pairings with the product
     * of the "crossed" pairings, counts a swap when the state changes, and
     * records the likelihood ratio favouring the chosen state.
     *
     * @param hetIndex index of the unphased cluster in forward order
     */
    private void phaseHet(int hetIndex) {
        float[] b1 = this.bwdHet1.get(hetIndex);
        float[] b2 = this.bwdHet2.get(hetIndex);
        float p11 = 0.0f;
        float p12 = 0.0f;
        float p21 = 0.0f;
        float p22 = 0.0f;
        for (int k = 0; k < this.nStates; ++k) {
            p11 += this.fwd[1][k] * b1[k];
            p12 += this.fwd[1][k] * b2[k];
            p21 += this.fwd[2][k] * b1[k];
            p22 += this.fwd[2][k] * b2[k];
        }
        boolean prevSwapHaps = this.swapHaps;
        float noSwapLik = p11 * p22;
        float swapLik = p12 * p21;
        this.swapHaps = noSwapLik < swapLik;
        if (this.swapHaps != prevSwapHaps) {
            ++this.swapCnt;
        }
        this.lrList.add(this.swapHaps ? swapLik / noSwapLik : noSwapLik / swapLik);
    }

    /**
     * After burn-in, replaces the sample's list of unphased markers with the
     * subset whose recorded likelihood ratio is below the threshold derived
     * from {@code phaseData.leaveUnphasedProp(sample)}. Does nothing during
     * burn-in or when the sample has no unphased markers.
     */
    private void updatePhase(int sample, SamplePhase samplePhase) {
        IntArray prevUnph = samplePhase.unphased();
        if (prevUnph.size() == 0 || this.burnin) {
            return;
        }
        float leaveUnphasedProp = this.phaseData.leaveUnphasedProp(sample);
        IntList nextUnph = new IntList();
        float threshold = PhaseBaum1.threshold(this.lrList, leaveUnphasedProp);
        for (int j = 0, n = prevUnph.size(); j < n; ++j) {
            if (this.lrList.get(j) < threshold) {
                nextUnph.add(prevUnph.get(j));
            }
        }
        samplePhase.setUnphased(IntArray.create(nextUnph, this.nMarkers));
    }

    /**
     * Returns the element of the sorted likelihood ratios at rank
     * {@code round(prop * size)}, capped at the last element.
     *
     * @param lrList likelihood ratios; must be non-empty
     * @param prop proportion used to select the rank
     * @return the likelihood ratio at the selected rank
     */
    private static float threshold(FloatList lrList, float prop) {
        float[] sorted = lrList.toArray();
        Arrays.sort(sorted);
        int rank = (int) Math.floor(prop * (float) sorted.length + 0.5f);
        return sorted[Math.min(rank, sorted.length - 1)];
    }
}

