/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysds.runtime.instructions.fed;

import java.util.concurrent.Future;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest;
import org.apache.sysds.runtime.controlprogram.federated.FederatedResponse;
import org.apache.sysds.runtime.controlprogram.federated.FederationMap;
import org.apache.sysds.runtime.controlprogram.federated.FederationUtils;
import org.apache.sysds.runtime.instructions.InstructionUtils;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
import org.apache.sysds.runtime.instructions.fed.BinaryFEDInstruction;
import org.apache.sysds.runtime.instructions.fed.FEDInstruction;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.matrix.operators.Operator;

/**
 * Federated matrix multiplication ({@code ba+*}) of two inputs, at least one of which is federated.
 *
 * <p>{@link #processInstruction(ExecutionContext)} supports these layouts, checked in order:
 * <ol>
 *   <li>column-federated left input times row-federated right input with aligned federation maps:
 *       the instruction runs on every site, the per-site results are fetched and summed locally;</li>
 *   <li>row-federated left input: the right input is broadcast to all sites; a single-column result
 *       is fetched and row-bound locally, any wider result stays federated (row-partitioned);</li>
 *   <li>row-federated right input: the left input is broadcast in slices to the sites, and the
 *       per-site results are fetched and summed locally.</li>
 * </ol>
 * Any other combination raises a {@link DMLRuntimeException}.
 */
public class AggregateBinaryFEDInstruction
extends BinaryFEDInstruction {
    public AggregateBinaryFEDInstruction(Operator op, CPOperand in1, CPOperand in2, CPOperand out, String opcode, String istr) {
        super(FEDInstruction.FEDType.AggregateBinary, op, in1, in2, out, opcode, istr);
    }

    /**
     * Parses a {@code ba+*} instruction string into a federated aggregate-binary instruction.
     *
     * @param str instruction string whose fields are: opcode, in1, in2, out, k
     * @return the parsed instruction, using a matmult operator built from {@code k}
     * @throws DMLRuntimeException if the opcode is not {@code ba+*}
     */
    public static AggregateBinaryFEDInstruction parseInstruction(String str) {
        String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
        String opcode = parts[0];
        if (!opcode.equalsIgnoreCase("ba+*")) {
            throw new DMLRuntimeException("AggregateBinaryFEDInstruction.parseInstruction():: Unknown opcode " + opcode);
        }
        // four fields follow the opcode: in1, in2, out, k
        InstructionUtils.checkNumFields(parts, 4);
        CPOperand in1 = new CPOperand(parts[1]);
        CPOperand in2 = new CPOperand(parts[2]);
        CPOperand out = new CPOperand(parts[3]);
        // parallelism argument for the matmult operator (presumably the thread count — confirm)
        int k = Integer.parseInt(parts[4]);
        return new AggregateBinaryFEDInstruction(InstructionUtils.getMatMultOperator(k), in1, in2, out, opcode, str);
    }

    /**
     * Executes the federated matrix multiplication, dispatching on the federation layout of
     * both inputs.
     *
     * @param ec execution context holding the input and output variables
     * @throws DMLRuntimeException if the combination of federated inputs is not supported
     */
    @Override
    public void processInstruction(ExecutionContext ec) {
        MatrixObject mo1 = ec.getMatrixObject(input1);
        MatrixObject mo2 = ec.getMatrixObject(input2);
        if (mo1.isFederated(FederationMap.FType.COL) && mo2.isFederated(FederationMap.FType.ROW)
            && mo1.getFedMapping().isAligned(mo2.getFedMapping(), true)) {
            processAlignedColTimesRow(ec, mo1, mo2);
        } else if (mo1.isFederated(FederationMap.FType.ROW)) {
            processRowFederatedLeft(ec, mo1, mo2);
        } else if (mo2.isFederated(FederationMap.FType.ROW)) {
            processRowFederatedRight(ec, mo1, mo2);
        } else {
            throw new DMLRuntimeException("Federated AggregateBinary not supported with the following federated objects: " + mo1.isFederated() + ":" + mo1.getFedMapping() + " " + mo2.isFederated() + ":" + mo2.getFedMapping());
        }
    }

    /** Column-federated left times aligned row-federated right: run per site, sum results locally. */
    private void processAlignedColTimesRow(ExecutionContext ec, MatrixObject mo1, MatrixObject mo2) {
        FederatedRequest fr1 = FederationUtils.callInstruction(instString, output,
            new CPOperand[]{input1, input2},
            new long[]{mo1.getFedMapping().getID(), mo2.getFedMapping().getID()});
        // read back the per-site result variable created by fr1
        FederatedRequest fr2 = new FederatedRequest(FederatedRequest.RequestType.GET_VAR, fr1.getID());
        // only fr1 creates a remote variable; fr2 merely reads it, so nothing else needs removal
        FederatedRequest fr3 = mo2.getFedMapping().cleanup(getTID(), fr1.getID());
        Future<FederatedResponse>[] tmp = mo1.getFedMapping().execute(getTID(), fr1, fr2, fr3);
        MatrixBlock ret = FederationUtils.aggAdd(tmp);
        ec.setMatrixOutput(output.getName(), ret);
    }

    /** Row-federated left: broadcast the right input; fetch single-column results, else keep federated. */
    private void processRowFederatedLeft(ExecutionContext ec, MatrixObject mo1, MatrixObject mo2) {
        FederationMap fedMap1 = mo1.getFedMapping();
        FederatedRequest fr1 = fedMap1.broadcast(mo2);
        FederatedRequest fr2 = FederationUtils.callInstruction(instString, output,
            new CPOperand[]{input1, input2}, new long[]{fedMap1.getID(), fr1.getID()});
        if (mo2.getNumColumns() == 1L) {
            // single-column result: fetch each site's rows, remove broadcast and result remotely
            FederatedRequest fr3 = new FederatedRequest(FederatedRequest.RequestType.GET_VAR, fr2.getID());
            FederatedRequest fr4 = fedMap1.cleanup(getTID(), fr1.getID(), fr2.getID());
            Future<FederatedResponse>[] tmp = fedMap1.execute(getTID(), fr1, fr2, fr3, fr4);
            MatrixBlock ret = FederationUtils.rbind(tmp);
            ec.setMatrixOutput(output.getName(), ret);
        } else {
            // wider result stays on the sites: only the broadcast copy is removed, and the output
            // becomes a row-federated object that references fr2's remote variables
            FederatedRequest fr3 = fedMap1.cleanup(getTID(), fr1.getID());
            fedMap1.execute(getTID(), true, fr1, fr2, fr3);
            MatrixObject out = ec.getMatrixObject(output);
            out.getDataCharacteristics().set(mo1.getNumRows(), mo2.getNumColumns(), (int) mo1.getBlocksize());
            out.setFedMapping(fedMap1.copyWithNewID(fr2.getID(), mo2.getNumColumns()));
            out.getFedMapping().setType(FederationMap.FType.ROW);
        }
    }

    /** Row-federated right: broadcast the left input in slices, run per site, sum results locally. */
    private void processRowFederatedRight(ExecutionContext ec, MatrixObject mo1, MatrixObject mo2) {
        FederationMap fedMap2 = mo2.getFedMapping();
        FederatedRequest[] fr1 = fedMap2.broadcastSliced(mo1, true);
        // only the first slice's ID is referenced (presumably all slices share one variable ID — confirm)
        FederatedRequest fr2 = FederationUtils.callInstruction(instString, output,
            new CPOperand[]{input1, input2}, new long[]{fr1[0].getID(), fedMap2.getID()});
        FederatedRequest fr3 = new FederatedRequest(FederatedRequest.RequestType.GET_VAR, fr2.getID());
        FederatedRequest fr4 = fedMap2.cleanup(getTID(), fr1[0].getID(), fr2.getID());
        Future<FederatedResponse>[] tmp = fedMap2.execute(getTID(), fr1, new FederatedRequest[]{fr2, fr3, fr4});
        MatrixBlock ret = FederationUtils.aggAdd(tmp);
        ec.setMatrixOutput(output.getName(), ret);
    }
}

