/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysds.runtime.compress.colgroup;

import java.io.DataInput;
import java.io.DataOutput;
import java.io.IOException;
import java.util.Arrays;
import org.apache.sysds.runtime.DMLScriptException;
import org.apache.sysds.runtime.compress.CompressionSettings;
import org.apache.sysds.runtime.compress.colgroup.ADictionary;
import org.apache.sysds.runtime.compress.colgroup.ColGroup;
import org.apache.sysds.runtime.compress.colgroup.Dictionary;
import org.apache.sysds.runtime.compress.colgroup.QDictionary;
import org.apache.sysds.runtime.compress.utils.ABitmap;
import org.apache.sysds.runtime.compress.utils.Bitmap;
import org.apache.sysds.runtime.compress.utils.BitmapLossy;
import org.apache.sysds.runtime.functionobjects.Builtin;
import org.apache.sysds.runtime.functionobjects.KahanFunction;
import org.apache.sysds.runtime.functionobjects.KahanPlus;
import org.apache.sysds.runtime.functionobjects.KahanPlusSq;
import org.apache.sysds.runtime.functionobjects.Mean;
import org.apache.sysds.runtime.functionobjects.ReduceAll;
import org.apache.sysds.runtime.functionobjects.ReduceCol;
import org.apache.sysds.runtime.functionobjects.ReduceRow;
import org.apache.sysds.runtime.instructions.cp.KahanObject;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.matrix.data.Pair;
import org.apache.sysds.runtime.matrix.operators.AggregateUnaryOperator;
import org.apache.sysds.runtime.matrix.operators.ScalarOperator;

/**
 * Base class for column groups whose distinct values (tuples over {@code _colIndexes}) are held in a
 * dictionary ({@link ADictionary}). Provides shared helpers for aggregation, pre-aggregation against a
 * vector, scalar operations on the dictionary, serialization, and a per-thread scratch-buffer pool.
 */
public abstract class ColGroupValue
extends ColGroup {
    private static final long serialVersionUID = 3786247536054353658L;

    /**
     * Per-thread scratch buffers: key = int buffer, value = double buffer. Both stay {@code null} until
     * {@link #setupThreadLocalMemory(int)} is called on the current thread.
     */
    private static final ThreadLocal<Pair<int[], double[]>> memPool = new ThreadLocal<Pair<int[], double[]>>(){

        @Override
        protected Pair<int[], double[]> initialValue() {
            return new Pair<int[], double[]>();
        }
    };

    /** Dictionary holding the distinct value tuples of this group. */
    protected ADictionary _dict;

    /** No-arg constructor (presumably for deserialization via {@link #readFields(DataInput)}; verify). */
    protected ColGroupValue() {
    }

    /**
     * Builds the dictionary from a bitmap.
     *
     * @param colIndices column indexes covered by this group
     * @param numRows    number of rows in the group
     * @param ubm        bitmap of distinct values; a {@code Full} bitmap yields a {@link Dictionary}, a
     *                   {@code Lossy} bitmap yields a {@link QDictionary} and marks the group lossy
     * @param cs         compression settings; {@code sortValuesByLength} enables frequency sorting
     */
    protected ColGroupValue(int[] colIndices, int numRows, ABitmap ubm, CompressionSettings cs) {
        super(colIndices, numRows);
        this._lossy = false;
        this._zeros = ubm.containsZero();
        // Sort before the dictionary is built so the dictionary reflects the sorted order.
        // 65536 is a row-count threshold inlined by the decompiler (presumably a shared constant upstream; TODO confirm).
        if (cs.sortValuesByLength && numRows > 65536) {
            ubm.sortValuesByFrequency();
        }
        switch (ubm.getType()) {
            case Full:
                this._dict = new Dictionary(((Bitmap) ubm).getValues());
                break;
            case Lossy:
                this._dict = new QDictionary((BitmapLossy) ubm);
                this._lossy = true;
                break;
        }
    }

    /**
     * Wraps an existing dictionary.
     *
     * @param colIndices column indexes covered by this group
     * @param numRows    number of rows in the group
     * @param dict       dictionary to use (not copied)
     */
    protected ColGroupValue(int[] colIndices, int numRows, ADictionary dict) {
        super(colIndices, numRows);
        this._dict = dict;
    }

    /** @return number of distinct tuples in the dictionary for this group's column count */
    public int getNumValues() {
        return this._dict.getNumberOfValues(this._colIndexes.length);
    }

    @Override
    public double[] getValues() {
        return this._dict.getValues();
    }

    /**
     * @return the quantized byte values of the dictionary
     * @throws ClassCastException if the dictionary is not a {@link QDictionary}
     */
    public byte[] getByteValues() {
        return ((QDictionary) this._dict).getValuesByte();
    }

    /**
     * Returns the dictionary values as a single-column block. When the group contains zeros, one extra
     * trailing row is allocated and left unset (presumably zero in a freshly allocated block).
     */
    @Override
    public MatrixBlock getValuesAsBlock() {
        double[] values = this.getValues();
        int vlen = values.length;
        int rlen = this._zeros ? vlen + 1 : vlen;
        MatrixBlock ret = new MatrixBlock(rlen, 1, false);
        for (int i = 0; i < vlen; ++i) {
            ret.quickSetValue(i, 0, values[i]);
        }
        return ret;
    }

    /**
     * @return per-value counts, computed into a (possibly thread-pooled) zeroed buffer with one extra slot
     *         when the group contains zeros
     */
    public final int[] getCounts() {
        return this.getCounts(allocIVector(countsLength(), true));
    }

    /**
     * @param rl first row (inclusive)
     * @param ru last row (exclusive, by the convention of {@link #unaryAggregateOperations(AggregateUnaryOperator, double[])})
     * @return per-value counts restricted to the row range
     */
    public final int[] getCounts(int rl, int ru) {
        return this.getCounts(rl, ru, allocIVector(countsLength(), true));
    }

    /** Size of a counts buffer: one slot per distinct value plus one for zeros if present. */
    private int countsLength() {
        return this.getNumValues() + (this._zeros ? 1 : 0);
    }

    @Override
    public boolean getIfCountsType() {
        return true;
    }

    /** Delegates to the dictionary's all-zero-tuple lookup for this group's column count. */
    protected int containsAllZeroValue() {
        return this._dict.hasZeroTuple(this._colIndexes.length);
    }

    /**
     * Dot product of the tuple at {@code valIx} with the entries of {@code b} selected by this group's columns.
     *
     * @param valIx    tuple index into the dictionary
     * @param b        dense vector indexed by absolute column index
     * @param dictVals flattened dictionary values (row-major, {@code numCols} per tuple)
     * @return the weighted sum
     */
    protected final double sumValues(int valIx, double[] b, double[] dictVals) {
        int numCols = this.getNumCols();
        int valOff = valIx * numCols;
        double val = 0.0;
        for (int i = 0; i < numCols; ++i) {
            val += dictVals[valOff + i] * b[this._colIndexes[i]];
        }
        return val;
    }

    /** Same as {@link #preaggValues(int, double[], boolean, double[])} with a pooled result buffer. */
    protected final double[] preaggValues(int numVals, double[] b, double[] dictVals) {
        return this.preaggValues(numVals, b, false, dictVals);
    }

    /**
     * Pre-aggregates every dictionary tuple against {@code b}.
     *
     * @param numVals  number of tuples to process
     * @param b        dense vector indexed by absolute column index
     * @param allocNew if true, always allocate a new array; otherwise reuse the thread-local buffer
     * @param dictVals flattened dictionary values
     * @return array with at least {@code numVals + 1} entries whose first {@code numVals} hold the results;
     *         a pooled buffer may be longer and is not cleared
     */
    protected final double[] preaggValues(int numVals, double[] b, boolean allocNew, double[] dictVals) {
        double[] ret = allocNew ? new double[numVals + 1] : allocDVector(numVals + 1, false);
        if (this._colIndexes.length == 1) {
            double bval = b[this._colIndexes[0]];
            for (int k = 0; k < numVals; ++k) {
                ret[k] = dictVals[k] * bval;
            }
        } else {
            for (int k = 0; k < numVals; ++k) {
                ret[k] = this.sumValues(k, b, dictVals);
            }
        }
        return ret;
    }

    /** Folds all dictionary values (and 0 if the group has zeros) into {@code c[0]} with {@code builtin}. */
    protected void computeMxx(double[] c, Builtin builtin) {
        if (this._zeros) {
            c[0] = builtin.execute(c[0], 0.0);
        }
        c[0] = this._dict.aggregate(c[0], builtin);
    }

    /** Per-column variant of {@link #computeMxx}: updates {@code c[col]} for each of this group's columns. */
    protected void computeColMxx(double[] c, Builtin builtin) {
        if (this._zeros) {
            for (int col : this._colIndexes) {
                c[col] = builtin.execute(c[col], 0.0);
            }
        }
        this._dict.aggregateCols(c, builtin, this._colIndexes);
    }

    /** Applies {@code op} to a clone of the dictionary, leaving this group's dictionary untouched. */
    protected ADictionary applyScalarOp(ScalarOperator op) {
        return this._dict.clone().apply(op);
    }

    /** Delegates to {@link ADictionary#applyScalarOp(ScalarOperator, double, int)}. */
    protected ADictionary applyScalarOp(ScalarOperator op, double newVal, int numCols) {
        return this._dict.applyScalarOp(op, newVal, numCols);
    }

    @Override
    public void unaryAggregateOperations(AggregateUnaryOperator op, double[] c) {
        this.unaryAggregateOperations(op, c, 0, this._numRows);
    }

    /**
     * Dispatches sum / sum-of-squares / mean and min / max aggregations to the matching compute routine
     * according to the operator's index function (all, per row, or per column).
     *
     * @throws DMLScriptException for any other aggregate function
     */
    @Override
    public void unaryAggregateOperations(AggregateUnaryOperator op, double[] c, int rl, int ru) {
        if (op.aggOp.increOp.fn instanceof KahanPlus || op.aggOp.increOp.fn instanceof KahanPlusSq
            || op.aggOp.increOp.fn instanceof Mean) {
            boolean mean = op.aggOp.increOp.fn instanceof Mean;
            // Declared as KahanFunction (the type the compute* methods take) so both the KahanPlus and the
            // KahanPlusSq instance fit; a 'KahanPlus' declaration only compiles if KahanPlusSq extends KahanPlus.
            KahanFunction kplus = (op.aggOp.increOp.fn instanceof KahanPlus || mean)
                ? KahanPlus.getKahanPlusFnObject()
                : KahanPlusSq.getKahanPlusSqFnObject();
            if (op.indexFn instanceof ReduceAll) {
                this.computeSum(c, kplus);
            } else if (op.indexFn instanceof ReduceCol) {
                this.computeRowSums(c, kplus, rl, ru, mean);
            } else if (op.indexFn instanceof ReduceRow) {
                this.computeColSums(c, kplus);
            }
        } else if (op.aggOp.increOp.fn instanceof Builtin
            && (((Builtin) op.aggOp.increOp.fn).getBuiltinCode() == Builtin.BuiltinCode.MAX
                || ((Builtin) op.aggOp.increOp.fn).getBuiltinCode() == Builtin.BuiltinCode.MIN)) {
            Builtin builtin = (Builtin) op.aggOp.increOp.fn;
            if (op.indexFn instanceof ReduceAll) {
                this.computeMxx(c, builtin);
            } else if (op.indexFn instanceof ReduceCol) {
                this.computeRowMxx(c, builtin, rl, ru);
            } else if (op.indexFn instanceof ReduceRow) {
                this.computeColMxx(c, builtin);
            }
        } else {
            throw new DMLScriptException("Unknown UnaryAggregate operator on CompressedMatrixBlock");
        }
    }

    /**
     * Adds {@code val} to the (sum, correction) pair stored at {@code c[rix]}, {@code c[rix + 1]} using
     * {@code kbuff} as scratch, and writes the updated pair back.
     */
    protected void setandExecute(double[] c, KahanObject kbuff, KahanPlus kplus2, double val, int rix) {
        kbuff.set(c[rix], c[rix + 1]);
        kplus2.execute2(kbuff, val);
        c[rix] = kbuff._sum;
        c[rix + 1] = kbuff._correction;
    }

    /** Installs fresh int and double scratch buffers of length {@code len} for the current thread. */
    public static void setupThreadLocalMemory(int len) {
        Pair<int[], double[]> p = new Pair<int[], double[]>();
        p.setKey(new int[len]);
        p.setValue(new double[len]);
        memPool.set(p);
    }

    /** Removes the current thread's scratch buffers. */
    public static void cleanupThreadLocalMemory() {
        memPool.remove();
    }

    /**
     * Returns the thread-local double buffer, or a fresh zero-filled array when no buffer was set up or the
     * pooled one is shorter than {@code len}.
     *
     * @param len   minimum required length
     * @param reset whether to zero the first {@code len} entries of a pooled buffer
     * @return an array of length at least {@code len}
     */
    protected static double[] allocDVector(int len, boolean reset) {
        double[] tmp = memPool.get().getValue();
        // Fall back instead of returning an undersized buffer, which would fail with an index error later.
        if (tmp == null || tmp.length < len) {
            return new double[len];
        }
        if (reset) {
            Arrays.fill(tmp, 0, len, 0.0);
        }
        return tmp;
    }

    /**
     * Int counterpart of {@link #allocDVector(int, boolean)}.
     *
     * @param len   minimum required length
     * @param reset whether to zero the first {@code len} entries of a pooled buffer
     * @return an array of length at least {@code len}
     */
    protected static int[] allocIVector(int len, boolean reset) {
        int[] tmp = memPool.get().getKey();
        if (tmp == null || tmp.length < len) {
            return new int[len];
        }
        if (reset) {
            Arrays.fill(tmp, 0, len, 0);
        }
        return tmp;
    }

    @Override
    public String toString() {
        StringBuilder sb = new StringBuilder();
        sb.append(super.toString());
        sb.append(String.format("\n%15s%5d ", "Columns:", this._colIndexes.length));
        sb.append(Arrays.toString(this._colIndexes));
        sb.append(String.format("\n%15s%5d ", "Values:", this._dict.getValues().length));
        sb.append(Arrays.toString(this._dict.getValues()));
        return sb.toString();
    }

    @Override
    public boolean isLossy() {
        return this._lossy;
    }

    /** Reads the layout written by {@link #write(DataOutput)}. */
    @Override
    public void readFields(DataInput in) throws IOException {
        this._numRows = in.readInt();
        int numCols = in.readInt();
        this._zeros = in.readBoolean();
        this._lossy = in.readBoolean();
        this._colIndexes = new int[numCols];
        for (int i = 0; i < numCols; ++i) {
            this._colIndexes[i] = in.readInt();
        }
        // The lossy flag selects which dictionary format to read.
        this._dict = ADictionary.read(in, this._lossy);
    }

    /** Writes numRows, numCols, zeros flag, lossy flag, the column indexes, then the dictionary. */
    @Override
    public void write(DataOutput out) throws IOException {
        out.writeInt(this._numRows);
        out.writeInt(this.getNumCols());
        out.writeBoolean(this._zeros);
        out.writeBoolean(this._lossy);
        for (int col : this._colIndexes) {
            out.writeInt(col);
        }
        this._dict.write(out);
    }

    /** @return number of bytes {@link #write(DataOutput)} produces */
    @Override
    public long getExactSizeOnDisk() {
        long ret = 4L   // numRows
            + 4L        // numCols
            + 1L        // zeros flag
            + 1L        // lossy flag
            + 4L * this._colIndexes.length; // column indexes; long arithmetic avoids int overflow
        return ret + this._dict.getExactSizeOnDisk();
    }

    /** Computes per-value counts into {@code counts} and returns it. */
    public abstract int[] getCounts(int[] counts);

    /** Computes per-value counts over rows [{@code rl}, {@code ru}) into {@code counts} and returns it. */
    public abstract int[] getCounts(int rl, int ru, int[] counts);

    /** Aggregates all cells of this group into {@code c} using {@code kplus}. */
    protected abstract void computeSum(double[] c, KahanFunction kplus);

    /** Aggregates each row in [{@code rl}, {@code ru}) into {@code c}; {@code mean} flags a mean aggregate. */
    protected abstract void computeRowSums(double[] c, KahanFunction kplus, int rl, int ru, boolean mean);

    /** Aggregates each column of this group into {@code c}. */
    protected abstract void computeColSums(double[] c, KahanFunction kplus);

    /** Computes row-wise min/max via {@code builtin} over rows [{@code rl}, {@code ru}) into {@code c}. */
    protected abstract void computeRowMxx(double[] c, Builtin builtin, int rl, int ru);
}

