/*
 * Decompiled with CFR 0.152.
 */
package cc.mallet.util;

import cc.mallet.types.Alphabet;
import cc.mallet.types.FeatureVector;
import cc.mallet.util.Randoms;
import java.text.NumberFormat;
import java.util.Arrays;

/**
 * Static utilities for sampling from multivariate normal and Wishart distributions, together
 * with the dense matrix routines those samplers need.
 *
 * <p>Every matrix is a flat {@code double[]} stored in row-major order: element (row, col) of a
 * {@code dimension x dimension} matrix lives at index {@code row * dimension + col}. Arrays
 * described as "lower triangular" hold zeros above the diagonal.
 */
public class MVNormal {

    /**
     * Computes the Cholesky factor of a symmetric positive-definite matrix.
     *
     * <p>Only the lower triangle (including the diagonal) of {@code input2} is read. No
     * positive-definiteness check is made: non-SPD input yields NaN or infinite entries rather
     * than an exception.
     *
     * @param input2 row-major {@code numRows x numRows} matrix
     * @param numRows matrix dimension
     * @return lower-triangular {@code L} with {@code input2 = L * L^T}
     */
    public static double[] cholesky(double[] input2, int numRows) {
        double[] result = new double[input2.length];
        for (int row = 0; row < numRows; ++row) {
            int rowOffset = row * numRows;
            double sumRowSquared = 0.0;
            for (int col = 0; col < row; ++col) {
                int colOffset = col * numRows;
                double dotProduct = 0.0;
                for (int i = 0; i < col; ++i) {
                    dotProduct += result[rowOffset + i] * result[colOffset + i];
                }
                double entry = (input2[rowOffset + col] - dotProduct) / result[colOffset + col];
                result[rowOffset + col] = entry;
                sumRowSquared += entry * entry;
            }
            result[rowOffset + row] = Math.sqrt(input2[rowOffset + row] - sumRowSquared);
        }
        return result;
    }

    /**
     * Cholesky factorization that skips the leading zeros of each row.
     *
     * <p>Mathematically produces the same factor as {@link #cholesky(double[], int)}. For each
     * row, lower-triangle entries before the first non-zero input entry are left at zero, and
     * the dot products start at that column. This is cheaper for banded inputs, because a
     * Cholesky factor keeps the row profile (envelope) of its input.
     *
     * @param input2 row-major {@code numRows x numRows} symmetric positive-definite matrix
     * @param numRows matrix dimension
     * @return lower-triangular {@code L} with {@code input2 = L * L^T}
     */
    public static double[] bandCholesky(double[] input2, int numRows) {
        double[] result = new double[input2.length];
        for (int row = 0; row < numRows; ++row) {
            int rowOffset = row * numRows;
            double sumRowSquared = 0.0;
            int firstNonZero = row;
            for (int col = 0; col < row; ++col) {
                if (firstNonZero == row) {
                    if (input2[rowOffset + col] == 0.0) {
                        // Still inside the leading run of zeros: L[row][col] stays 0.
                        continue;
                    }
                    firstNonZero = col;
                }
                int colOffset = col * numRows;
                double dotProduct = 0.0;
                // L[row][i] is zero for i < firstNonZero, so those terms are skipped.
                for (int i = firstNonZero; i < col; ++i) {
                    dotProduct += result[rowOffset + i] * result[colOffset + i];
                }
                double entry = (input2[rowOffset + col] - dotProduct) / result[colOffset + col];
                result[rowOffset + col] = entry;
                sumRowSquared += entry * entry;
            }
            result[rowOffset + row] = Math.sqrt(input2[rowOffset + row] - sumRowSquared);
        }
        return result;
    }

    /**
     * Builds a lower-triangular 0/1 matrix whose row {@code r} has ones in columns
     * {@code max(0, r - bandwidth + 1) .. r} and zeros elsewhere.
     *
     * @param dim matrix dimension
     * @param bandwidth number of ones per row, counting the diagonal (fewer near the top)
     * @return row-major {@code dim x dim} matrix
     */
    public static double[] bandMatrixRoot(int dim, int bandwidth) {
        double[] result = new double[dim * dim];
        for (int row = 0; row < dim; ++row) {
            int rowOffset = row * dim;
            for (int col = Math.max(0, row - bandwidth + 1); col <= row; ++col) {
                result[rowOffset + col] = 1.0;
            }
        }
        return result;
    }

    /**
     * Draws one sample from a multivariate normal with the given mean and precision
     * (inverse covariance) matrix.
     *
     * @param mean mean vector; its length is the dimension
     * @param precision row-major precision matrix
     * @param random source of randomness
     * @return the sampled vector
     */
    public static double[] nextMVNormal(double[] mean, double[] precision, Randoms random) {
        return nextMVNormalWithCholesky(mean, cholesky(precision, mean.length), random);
    }

    /**
     * Draws one multivariate normal sample, given a precomputed Cholesky factor of the precision.
     *
     * <p>With precision {@code P = L * L^T} and {@code z} standard normal, the vector {@code x}
     * solving {@code L^T x = z} has covariance {@code P^{-1}}. The mean is then added.
     *
     * @param mean mean vector; its length is the dimension
     * @param precisionLowerTriangular lower-triangular {@code L} with {@code precision = L * L^T}
     * @param random source of randomness
     * @return the sampled vector
     */
    public static double[] nextMVNormalWithCholesky(double[] mean, double[] precisionLowerTriangular, Randoms random) {
        int n = mean.length;
        double[] standardNormal = new double[n];
        for (int i = 0; i < n; ++i) {
            standardNormal[i] = random.nextGaussian();
        }
        double[] result = solveWithBackSubstitution(standardNormal, precisionLowerTriangular);
        for (int i = 0; i < n; ++i) {
            result[i] += mean[i];
        }
        return result;
    }

    /**
     * Draws a multivariate normal sample conditioned on its components summing to zero.
     *
     * <p>An unconstrained sample {@code x} is drawn first. It is then corrected by
     * {@code x - S1 * (1^T x) / (1^T S1)}, where {@code S = precision^{-1}} and {@code 1} is the
     * all-ones vector. After the correction, the components sum to zero (up to rounding).
     *
     * @param mean mean vector; its length is the dimension
     * @param precisionLowerTriangular lower-triangular Cholesky factor of the precision matrix
     * @param random source of randomness
     * @return the constrained sample
     */
    public static double[] nextZeroSumMVNormalWithCholesky(double[] mean, double[] precisionLowerTriangular, Randoms random) {
        int n = mean.length;
        double[] result = nextMVNormalWithCholesky(mean, precisionLowerTriangular, random);
        double sum = 0.0;
        for (int i = 0; i < n; ++i) {
            sum += result[i];
        }
        // rowSums = precision^{-1} * 1, i.e. the row sums of the covariance matrix.
        double[] ones = new double[n];
        Arrays.fill(ones, 1.0);
        double[] firstSolution = solveWithForwardSubstitution(ones, precisionLowerTriangular);
        double[] rowSums = solveWithBackSubstitution(firstSolution, precisionLowerTriangular);
        double sumOfRowSums = 0.0;
        for (int i = 0; i < n; ++i) {
            sumOfRowSums += rowSums[i];
        }
        double inverseSumOfRowSums = 1.0 / sumOfRowSums;
        for (int i = 0; i < n; ++i) {
            result[i] -= inverseSumOfRowSums * rowSums[i] * sum;
        }
        return result;
    }

    /**
     * Draws {@code n} independent multivariate normal samples sharing one mean and precision.
     *
     * <p>The Cholesky factor of {@code precision} is computed once and reused for every sample.
     *
     * @param n number of samples
     * @param mean mean vector; its length is the dimension
     * @param precision row-major precision matrix
     * @param random source of randomness
     * @return array of {@code n} sampled vectors
     */
    public static double[][] nextMVNormal(int n, double[] mean, double[] precision, Randoms random) {
        double[] precisionLowerTriangular = cholesky(precision, mean.length);
        double[][] result = new double[n][];
        for (int i = 0; i < n; ++i) {
            result[i] = nextMVNormalWithCholesky(mean, precisionLowerTriangular, random);
        }
        return result;
    }

    /**
     * Draws one multivariate normal sample and wraps it in a {@link FeatureVector} over
     * {@code alphabet}.
     */
    public static FeatureVector nextFeatureVector(Alphabet alphabet, double[] mean, double[] precision, Randoms random) {
        return new FeatureVector(alphabet, nextMVNormal(mean, precision, random));
    }

    /**
     * Samples a mean vector from its normal posterior, given a known observation precision.
     *
     * <p>The posterior precision is {@code observations * precision + diag(priorPrecisionDiagonal)}.
     * The posterior mean is that matrix's inverse applied to
     * {@code diag(priorPrecisionDiagonal) * priorMean + observations * precision * observedMean}.
     *
     * @param priorMean prior mean vector; its length is the dimension
     * @param priorPrecisionDiagonal diagonal of the (diagonal) prior precision matrix
     * @param precision row-major precision matrix of a single observation
     * @param observedMean average of the observations
     * @param observations number of observations
     * @param random source of randomness
     * @return one draw from the posterior over the mean
     */
    public static double[] nextMVNormalPosterior(double[] priorMean, double[] priorPrecisionDiagonal, double[] precision, double[] observedMean, int observations, Randoms random) {
        int dimension = priorMean.length;
        double[] linearCombination = new double[dimension];
        for (int i = 0; i < dimension; ++i) {
            double innerProduct = 0.0;
            for (int j = 0; j < dimension; ++j) {
                innerProduct += precision[dimension * i + j] * observedMean[j];
            }
            linearCombination[i] = priorMean[i] * priorPrecisionDiagonal[i] + (double) observations * innerProduct;
        }
        double[] posteriorPrecision = new double[precision.length];
        for (int row = 0; row < dimension; ++row) {
            for (int col = 0; col < dimension; ++col) {
                posteriorPrecision[dimension * row + col] = (double) observations * precision[dimension * row + col];
            }
            posteriorPrecision[dimension * row + row] += priorPrecisionDiagonal[row];
        }
        double[] inversePosteriorPrecision = invertSPD(posteriorPrecision, dimension);
        double[] posteriorMean = new double[dimension];
        for (int row = 0; row < dimension; ++row) {
            double innerProduct = 0.0;
            for (int col = 0; col < dimension; ++col) {
                innerProduct += inversePosteriorPrecision[dimension * row + col] * linearCombination[col];
            }
            posteriorMean[row] = innerProduct;
        }
        return nextMVNormal(posteriorMean, posteriorPrecision, random);
    }

    /**
     * Solves {@code L^T x = b} by back substitution, where {@code L} is lower triangular.
     *
     * @param b right-hand side; its length is the dimension
     * @param lowerTriangular row-major lower-triangular matrix {@code L}
     * @return the solution {@code x}
     */
    public static double[] solveWithBackSubstitution(double[] b, double[] lowerTriangular) {
        int n = b.length;
        double[] result = new double[n];
        for (int i = n - 1; i >= 0; --i) {
            double innerProduct = 0.0;
            for (int j = i + 1; j < n; ++j) {
                // L^T[i][j] == L[j][i]
                innerProduct += result[j] * lowerTriangular[n * j + i];
            }
            result[i] = (b[i] - innerProduct) / lowerTriangular[n * i + i];
        }
        return result;
    }

    /**
     * Solves {@code L x = b} by forward substitution, where {@code L} is lower triangular.
     *
     * @param b right-hand side; its length is the dimension
     * @param lowerTriangular row-major lower-triangular matrix {@code L}
     * @return the solution {@code x}
     */
    public static double[] solveWithForwardSubstitution(double[] b, double[] lowerTriangular) {
        int n = b.length;
        double[] result = new double[n];
        for (int i = 0; i < n; ++i) {
            double innerProduct = 0.0;
            for (int j = 0; j < i; ++j) {
                innerProduct += result[j] * lowerTriangular[n * i + j];
            }
            result[i] = (b[i] - innerProduct) / lowerTriangular[n * i + i];
        }
        return result;
    }

    /**
     * Inverts a lower-triangular matrix. The inverse is also lower triangular.
     *
     * @param inputMatrix row-major lower-triangular matrix with a non-zero diagonal
     * @param dimension matrix dimension
     * @return the lower-triangular inverse
     */
    public static double[] invertLowerTriangular(double[] inputMatrix, int dimension) {
        double[] outputMatrix = new double[inputMatrix.length];
        for (int row = 0; row < dimension; ++row) {
            for (int col = 0; col <= row; ++col) {
                double x = col == row ? 1.0 : 0.0;
                for (int i = col; i < row; ++i) {
                    x -= inputMatrix[dimension * row + i] * outputMatrix[dimension * i + col];
                }
                outputMatrix[dimension * row + col] = x / inputMatrix[dimension * row + row];
            }
        }
        return outputMatrix;
    }

    /**
     * Computes {@code M^T * M} for a lower-triangular {@code M}. The result is symmetric and
     * both triangles are filled.
     *
     * @param inputMatrix row-major lower-triangular matrix {@code M}
     * @param dimension matrix dimension
     * @return row-major {@code M^T * M}
     */
    public static double[] lowerTriangularCrossproduct(double[] inputMatrix, int dimension) {
        double[] outputMatrix = new double[inputMatrix.length];
        for (int row = 0; row < dimension; ++row) {
            for (int col = row; col < dimension; ++col) {
                double innerProduct = 0.0;
                // M[i][row] and M[i][col] can both be non-zero only when i >= col (>= row).
                for (int i = col; i < dimension; ++i) {
                    innerProduct += inputMatrix[row + dimension * i] * inputMatrix[col + dimension * i];
                }
                outputMatrix[dimension * row + col] = innerProduct;
                outputMatrix[row + dimension * col] = innerProduct;
            }
        }
        return outputMatrix;
    }

    /**
     * Multiplies two lower-triangular matrices. The product is lower triangular.
     *
     * @param leftMatrix row-major lower-triangular left factor
     * @param rightMatrix row-major lower-triangular right factor
     * @param dimension matrix dimension
     * @return row-major {@code leftMatrix * rightMatrix}
     */
    public static double[] lowerTriangularProduct(double[] leftMatrix, double[] rightMatrix, int dimension) {
        double[] outputMatrix = new double[leftMatrix.length];
        for (int row = 0; row < dimension; ++row) {
            for (int col = 0; col <= row; ++col) {
                double innerProduct = 0.0;
                for (int i = col; i <= row; ++i) {
                    innerProduct += leftMatrix[dimension * row + i] * rightMatrix[dimension * i + col];
                }
                outputMatrix[dimension * row + col] = innerProduct;
            }
        }
        return outputMatrix;
    }

    /**
     * Computes {@code M * M^T} for a lower-triangular {@code M}. The result is symmetric and
     * both triangles are filled.
     */
    private static double[] lowerTriangularOuterProduct(double[] inputMatrix, int dimension) {
        double[] outputMatrix = new double[inputMatrix.length];
        for (int row = 0; row < dimension; ++row) {
            for (int col = 0; col <= row; ++col) {
                double innerProduct = 0.0;
                // M[col][i] is zero for i > col, and col <= row.
                for (int i = 0; i <= col; ++i) {
                    innerProduct += inputMatrix[dimension * row + i] * inputMatrix[dimension * col + i];
                }
                outputMatrix[dimension * row + col] = innerProduct;
                outputMatrix[dimension * col + row] = innerProduct;
            }
        }
        return outputMatrix;
    }

    /**
     * Inverts a symmetric positive-definite matrix via its Cholesky factor:
     * {@code A^{-1} = (L^{-1})^T * L^{-1}} where {@code A = L * L^T}.
     *
     * @param inputMatrix row-major symmetric positive-definite matrix
     * @param dimension matrix dimension
     * @return row-major inverse
     */
    public static double[] invertSPD(double[] inputMatrix, int dimension) {
        double[] lower = bandCholesky(inputMatrix, dimension);
        return lowerTriangularCrossproduct(invertLowerTriangular(lower, dimension), dimension);
    }

    /**
     * Draws a matrix from a Wishart distribution using the Bartlett decomposition.
     *
     * <p>{@code A} is lower triangular with {@code A[i][i] = sqrt(chi^2(degreesOfFreedom - i))}
     * and independent standard normals below the diagonal. The sample is
     * {@code (L A)(L A)^T}, where {@code L} is the Cholesky factor of the scale matrix.
     *
     * @param sqrtScaleMatrix lower-triangular {@code L} with {@code scale = L * L^T}
     * @param dimension matrix dimension
     * @param degreesOfFreedom Wishart degrees of freedom; must be at least {@code dimension}
     * @param random source of randomness
     * @return row-major symmetric sample
     * @throws IllegalArgumentException if {@code degreesOfFreedom < dimension}
     */
    public static double[] nextWishart(double[] sqrtScaleMatrix, int dimension, int degreesOfFreedom, Randoms random) {
        if (degreesOfFreedom < dimension) {
            throw new IllegalArgumentException("Wishart degrees of freedom (" + degreesOfFreedom
                    + ") must be at least the dimension (" + dimension + ")");
        }
        double[] bartlett = new double[sqrtScaleMatrix.length];
        for (int row = 0; row < dimension; ++row) {
            int rowOffset = row * dimension;
            for (int col = 0; col < row; ++col) {
                bartlett[rowOffset + col] = random.nextGaussian(0.0, 1.0);
            }
            bartlett[rowOffset + row] = Math.sqrt(random.nextChiSq(degreesOfFreedom - row));
        }
        double[] scaledBartlett = lowerTriangularProduct(sqrtScaleMatrix, bartlett, dimension);
        return lowerTriangularOuterProduct(scaledBartlett, dimension);
    }

    /**
     * Samples a precision matrix from its Wishart posterior.
     *
     * <p>{@code diag(1 / priorPrecisionDiagonal)} is added to the scatter matrix. The sample is
     * drawn from a Wishart with scale equal to the inverse of that sum and with
     * {@code observations + priorDF} degrees of freedom.
     *
     * @param scatterMatrix row-major scatter matrix of the observations (see
     *     {@link #getScatterMatrix(double[][])})
     * @param observations number of observations
     * @param priorPrecisionDiagonal diagonal values of the prior
     * @param priorDF prior degrees of freedom
     * @param dimension matrix dimension
     * @param random source of randomness
     * @return row-major symmetric sample
     */
    public static double[] nextWishartPosterior(double[] scatterMatrix, int observations, double[] priorPrecisionDiagonal, int priorDF, int dimension, Randoms random) {
        double[] scatterPlusPrior = scatterMatrix.clone();
        for (int i = 0; i < dimension; ++i) {
            scatterPlusPrior[dimension * i + i] += 1.0 / priorPrecisionDiagonal[i];
        }
        double[] posteriorScale = invertSPD(scatterPlusPrior, dimension);
        double[] sqrtScaleMatrix = cholesky(posteriorScale, dimension);
        return nextWishart(sqrtScaleMatrix, dimension, observations + priorDF, random);
    }

    /**
     * Formats a square matrix as tab-separated rows, with up to 10 fraction digits per value.
     *
     * @param matrix row-major matrix
     * @param dimension matrix dimension
     * @return one line per row, each value followed by a tab
     */
    public static String doubleArrayToString(double[] matrix, int dimension) {
        NumberFormat formatter = NumberFormat.getInstance();
        formatter.setMaximumFractionDigits(10);
        StringBuilder output = new StringBuilder();
        for (int row = 0; row < dimension; ++row) {
            for (int col = 0; col < dimension; ++col) {
                output.append(formatter.format(matrix[dimension * row + col]));
                output.append("\t");
            }
            output.append("\n");
        }
        return output.toString();
    }

    /**
     * Formats the diagonal of a square matrix as space-separated values, with up to 4 fraction
     * digits per value.
     *
     * @param matrix row-major matrix
     * @param dimension matrix dimension
     * @return diagonal values, each followed by a space
     */
    public static String diagonalToString(double[] matrix, int dimension) {
        NumberFormat formatter = NumberFormat.getInstance();
        formatter.setMaximumFractionDigits(4);
        StringBuilder output = new StringBuilder();
        for (int row = 0; row < dimension; ++row) {
            output.append(formatter.format(matrix[dimension * row + row]));
            output.append(" ");
        }
        return output.toString();
    }

    /**
     * Computes the scatter matrix {@code sum_i (x_i - mean)(x_i - mean)^T} of a set of
     * observations.
     *
     * @param observationMatrix one row per observation; every row must have the same length
     * @return row-major {@code dimension x dimension} scatter matrix
     * @throws IllegalArgumentException if there are no observations
     */
    public static double[] getScatterMatrix(double[][] observationMatrix) {
        int observations = observationMatrix.length;
        if (observations == 0) {
            throw new IllegalArgumentException("observationMatrix must contain at least one observation");
        }
        int dimension = observationMatrix[0].length;
        double[] means = new double[dimension];
        for (double[] observation : observationMatrix) {
            for (int d = 0; d < dimension; ++d) {
                means[d] += observation[d];
            }
        }
        for (int d = 0; d < dimension; ++d) {
            means[d] /= (double) observations;
        }
        double[] outputMatrix = new double[dimension * dimension];
        for (double[] observation : observationMatrix) {
            for (int d1 = 0; d1 < dimension; ++d1) {
                double deviation1 = observation[d1] - means[d1];
                for (int d2 = 0; d2 < dimension; ++d2) {
                    outputMatrix[dimension * d1 + d2] += deviation1 * (observation[d2] - means[d2]);
                }
            }
        }
        return outputMatrix;
    }

    /**
     * Demo: draws 1000 samples from a 20-dimensional standard normal, samples a precision
     * matrix from the Wishart posterior, and prints its diagonal.
     */
    public static void testCholesky() {
        int dimension = 20;
        int observations = 1000;
        double[] mean = new double[dimension];
        double[] precisionMatrix = new double[dimension * dimension];
        for (int i = 0; i < dimension; ++i) {
            precisionMatrix[dimension * i + i] = 1.0;
        }
        Randoms random = new Randoms();
        double[] scatterMatrix = getScatterMatrix(nextMVNormal(observations, mean, precisionMatrix, random));
        double[] priorPrecision = new double[dimension];
        Arrays.fill(priorPrecision, 1.0);
        double[] posteriorSample = nextWishartPosterior(scatterMatrix, observations, priorPrecision, dimension + 1, dimension, random);
        System.out.println(diagonalToString(posteriorSample, dimension));
    }

    /**
     * Demo: prints 10 unconstrained and 10 zero-sum samples from a 3-dimensional normal.
     */
    public static void main(String[] args) {
        double[] spd = new double[]{3.0, 0.0, -1.0, 0.0, 3.0, 0.0, -1.0, 0.0, 3.0};
        Randoms random = new Randoms();
        double[] mean = new double[]{1.0, 1.0, 1.0};
        double[] lower = cholesky(spd, 3);
        for (int iter = 0; iter < 10; ++iter) {
            printTabSeparated(nextMVNormalWithCholesky(mean, lower, random));
        }
        for (int iter = 0; iter < 10; ++iter) {
            printTabSeparated(nextZeroSumMVNormalWithCholesky(mean, lower, random));
        }
    }

    /** Prints each value followed by a tab, then a newline. */
    private static void printTabSeparated(double[] values) {
        for (double value : values) {
            System.out.print(value + "\t");
        }
        System.out.println();
    }
}

