package org.ujmp.core.doublematrix.calculation.general.missingvalues;

import java.io.File;
import java.io.IOException;
import java.io.PrintStream;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import org.ujmp.core.Matrix;
import org.ujmp.core.calculation.Calculation;
import org.ujmp.core.doublematrix.DenseDoubleMatrix2D;
import org.ujmp.core.doublematrix.calculation.AbstractDoubleCalculation;
import org.ujmp.core.doublematrix.calculation.general.missingvalues.Impute;
import org.ujmp.core.util.MathUtil;
import org.ujmp.core.util.UJMPSettings;

/* loaded from: classes2.dex */
/**
 * Lazily fills the missing (NaN or infinite) cells of a source matrix by
 * iterated column-wise regression.
 *
 * <p>Starting from an initial guess (the row-mean imputation of the source
 * unless one is supplied), every column that contains missing cells is
 * predicted from all other columns of the current guess with a least-squares
 * fit computed through the pseudo-inverse. The new prediction is blended with
 * the previous guess using {@code decay}, the known source values are written
 * back, and the process repeats until the change per missing value is no
 * larger than {@code delta}.
 *
 * <p>Progress is reported on {@code System.out} and the current guess is
 * exported as dense CSV to the temp file after every iteration.
 *
 * <p>The lazy initialisation in {@link #getDouble(long...)} is not
 * synchronised.
 */
public class ImputeEM extends AbstractDoubleCalculation {
    private static final long serialVersionUID = -1272010036598212696L;

    /** Convergence threshold used when the caller does not supply one. */
    private static final double DEFAULT_DELTA = 1.0E-6d;

    /** Weight given to the previous guess when blending in a new prediction. */
    private static final double DEFAULT_DECAY = 0.66d;

    private Matrix bestGuess;
    private final double decay;
    private double delta;
    private Matrix imputed;
    private File tempFile;

    /**
     * Task that predicts one column and copies the result into the shared
     * {@code imputed} matrix of the enclosing instance.
     */
    public class PredictColumn implements Callable<Long> {
        long column;

        /**
         * @param j index of the column this task predicts
         */
        public PredictColumn(long j) {
            this.column = j;
        }

        /**
         * Predicts the column and stores it in the enclosing instance's
         * {@code imputed} matrix.
         *
         * @return the index of the processed column
         */
        @Override // java.util.concurrent.Callable
        public Long call() throws Exception {
            Matrix predicted = ImputeEM.replaceInColumn(ImputeEM.this.getSource(), ImputeEM.this.bestGuess, this.column);
            // Several tasks write into the same target matrix, so the copy is serialised.
            synchronized (ImputeEM.this.imputed) {
                long rowCount = predicted.getRowCount();
                for (long row = 0; row < rowCount; row++) {
                    ImputeEM.this.imputed.setAsDouble(predicted.getAsDouble(row, 0), row, this.column);
                }
            }
            return Long.valueOf(this.column);
        }
    }

    /**
     * Creates an imputation with default initial guess, default threshold and
     * a freshly created temp file.
     *
     * @param matrix source matrix with missing values
     * @throws IOException if the temp file cannot be created
     */
    public ImputeEM(Matrix matrix) throws IOException {
        this(matrix, null);
    }

    /**
     * Creates an imputation with default threshold and a freshly created temp
     * file.
     *
     * @param matrix  source matrix with missing values
     * @param matrix2 initial guess, or {@code null} to derive one from the row means
     * @throws IOException if the temp file cannot be created
     */
    public ImputeEM(Matrix matrix, Matrix matrix2) throws IOException {
        this(matrix, matrix2, DEFAULT_DELTA, File.createTempFile("ujmp-impute-em-" + System.currentTimeMillis(), ".csv"));
    }

    /**
     * @param matrix  source matrix with missing values
     * @param matrix2 initial guess, or {@code null} to derive one from the row means
     * @param d       convergence threshold for the change per missing value
     * @param file    file that receives the current guess after every iteration
     */
    public ImputeEM(Matrix matrix, Matrix matrix2, double d, File file) {
        super(matrix);
        this.imputed = null;
        this.decay = DEFAULT_DECAY;
        this.bestGuess = matrix2;
        this.delta = d;
        this.tempFile = file;
    }

    /**
     * Tells whether the given source column has at least one NaN or infinite cell.
     */
    private boolean containsMissingValues(long j) {
        Matrix source = getSource();
        long rowCount = source.getRowCount();
        for (long row = 0; row < rowCount; row++) {
            if (MathUtil.isNaNOrInfinite(source.getAsDouble(row, j))) {
                return true;
            }
        }
        return false;
    }

    /**
     * Runs the iterative imputation and stores the result in {@code imputed}.
     *
     * @throws RuntimeException if any step fails or if the result still
     *                          contains missing values
     */
    private void createMatrix() {
        Matrix source = getSource();
        ExecutorService executor = null;
        boolean completed = false;
        try {
            double valueCount = source.getValueCount();
            long missingCount = (long) source.countMissing(Calculation.Ret.NEW, Integer.MAX_VALUE).getEuklideanValue();
            System.out.println("missing values: " + missingCount + " (" + (((int) Math.round((missingCount * 1000.0d) / valueCount)) / 10.0d) + "%)");
            System.out.println("============================================");

            if (missingCount == 0) {
                // Nothing to impute. Iterating would divide the distance by zero,
                // produce NaN and therefore never satisfy the convergence test.
                this.imputed = source;
                completed = true;
                return;
            }

            if (this.bestGuess == null) {
                this.bestGuess = source.impute(Calculation.Ret.NEW, Impute.ImputationMethod.RowMean, new Object[0]);
            }

            executor = Executors.newFixedThreadPool(UJMPSettings.getInstance().getNumberOfThreads());
            int iteration = 0;
            double change;
            do {
                System.out.println("Iteration " + iteration++);
                this.imputed = Matrix.Factory.zeros(source.getSize());
                predictMissingColumns(executor, source);

                // Damped update: keep part of the previous guess to stabilise the iteration.
                Matrix blended = this.bestGuess.times(this.decay).plus(this.imputed.times(1.0d - this.decay));
                restoreKnownValues(source, blended);

                change = blended.euklideanDistanceTo(this.bestGuess, true) / missingCount;
                System.out.println("delta: " + change);
                System.out.println("============================================");
                this.bestGuess = blended;
                this.bestGuess.exportTo().file(this.tempFile).asDenseCSV();
                // A NaN change also ends the loop; the check below then reports the failure.
            } while (change > this.delta);

            this.imputed = this.bestGuess;
            completed = true;
            if (this.imputed.containsMissingValues()) {
                throw new RuntimeException("Matrix has still missing values after imputation");
            }
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new RuntimeException(e);
        } catch (RuntimeException e) {
            throw e;
        } catch (Exception e) {
            throw new RuntimeException(e);
        } finally {
            if (executor != null) {
                executor.shutdown();
            }
            if (!completed) {
                // Do not leave a half-filled matrix behind for later getDouble calls.
                this.imputed = null;
            }
        }
    }

    /**
     * Submits one prediction task per column with missing cells, waits for all
     * of them and prints a progress line for each finished task.
     */
    private void predictMissingColumns(ExecutorService executor, Matrix source)
            throws InterruptedException, java.util.concurrent.ExecutionException {
        ArrayList<Future<Long>> pending = new ArrayList<Future<Long>>();
        long startMillis = System.currentTimeMillis();
        long columnCount = source.getColumnCount();
        for (long column = 0; column < columnCount; column++) {
            if (containsMissingValues(column)) {
                pending.add(executor.submit(new PredictColumn(column)));
            }
        }
        for (Future<Long> task : pending) {
            long finishedColumn = task.get().longValue();
            long elapsedMillis = System.currentTimeMillis() - startMillis;
            long remainingColumns = columnCount - finishedColumn;
            // Floating-point on purpose: integer division would truncate to zero
            // or throw when no full millisecond has elapsed yet.
            double columnsPerMilli = (finishedColumn + 1) / (double) elapsedMillis;
            long remainingSeconds = (long) (remainingColumns / columnsPerMilli / 1000.0d);
            System.out.println((((finishedColumn * 1000) / columnCount) / 10.0d) + "% completed (" + remainingSeconds + " seconds remaining)");
        }
    }

    /**
     * Overwrites every cell of {@code target} whose source value is neither
     * NaN nor infinite with that source value.
     */
    private static void restoreKnownValues(Matrix source, Matrix target) {
        long rowCount = source.getRowCount();
        long columnCount = source.getColumnCount();
        for (long row = 0; row < rowCount; row++) {
            for (long column = 0; column < columnCount; column++) {
                double known = source.getAsDouble(row, column);
                if (!MathUtil.isNaNOrInfinite(known)) {
                    target.setAsDouble(known, row, column);
                }
            }
        }
    }

    /**
     * Predicts column {@code j} of {@code matrix} from the remaining columns
     * of {@code matrix2}.
     *
     * <p>Rows whose value in column {@code j} is present serve as training
     * data for a least-squares fit (with a constant bias column) solved via
     * the pseudo-inverse. The fit is then applied to all rows, and known
     * values are written back over their predictions.
     *
     * @param matrix  source matrix containing the missing values
     * @param matrix2 current guess without missing values
     * @param j       column to predict
     * @return a single-column matrix; the untouched source column if it has
     *         no missing cells
     */
    public static Matrix replaceInColumn(Matrix matrix, Matrix matrix2, long j) {
        Matrix predictors = matrix2.deleteColumns(Calculation.Ret.NEW, j);
        Matrix target = matrix.selectColumns(Calculation.Ret.NEW, j);

        ArrayList<Long> missingRows = new ArrayList<Long>();
        for (long row = target.getRowCount() - 1; row >= 0; row--) {
            if (MathUtil.isNaNOrInfinite(target.getAsDouble(row, 0))) {
                missingRows.add(Long.valueOf(row));
            }
        }
        if (missingRows.isEmpty()) {
            return target;
        }

        Matrix trainingPredictors = predictors.deleteRows(Calculation.Ret.NEW, missingRows);
        Matrix trainingTarget = target.deleteRows(Calculation.Ret.NEW, missingRows);
        DenseDoubleMatrix2D trainingBias = (DenseDoubleMatrix2D) DenseDoubleMatrix2D.Factory.ones(trainingPredictors.getRowCount(), 1L);
        Matrix coefficients = Matrix.Factory.horCat(trainingPredictors, trainingBias).pinv().mtimes(trainingTarget);

        DenseDoubleMatrix2D fullBias = (DenseDoubleMatrix2D) DenseDoubleMatrix2D.Factory.ones(predictors.getRowCount(), 1L);
        Matrix predicted = Matrix.Factory.horCat(predictors, fullBias).mtimes(coefficients);

        // Keep known values. The same test that marked rows as missing is used,
        // so infinite source cells keep their prediction instead of being copied back.
        long rowCount = target.getRowCount();
        for (long row = 0; row < rowCount; row++) {
            double known = target.getAsDouble(row, 0);
            if (!MathUtil.isNaNOrInfinite(known)) {
                predicted.setAsDouble(known, row, 0);
            }
        }
        return predicted;
    }

    /**
     * Returns the source value at the given coordinates, or the imputed value
     * when the source value is NaN or infinite. The imputation is computed on
     * first use.
     */
    @Override // org.ujmp.core.doublematrix.calculation.DoubleCalculation
    public double getDouble(long... jArr) {
        if (this.imputed == null) {
            createMatrix();
        }
        double original = getSource().getAsDouble(jArr);
        if (MathUtil.isNaNOrInfinite(original)) {
            return this.imputed.getAsDouble(jArr);
        }
        return original;
    }
}
