/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysds.lops.rewrite;

import java.util.ArrayList;
import java.util.List;
import org.apache.sysds.lops.Lop;
import org.apache.sysds.lops.rewrite.LopRewriteRule;
import org.apache.sysds.lops.rewrite.RewriteAddBroadcastLop;
import org.apache.sysds.lops.rewrite.RewriteAddChkpointInLoop;
import org.apache.sysds.lops.rewrite.RewriteAddChkpointLop;
import org.apache.sysds.lops.rewrite.RewriteAddGPUEvictLop;
import org.apache.sysds.lops.rewrite.RewriteAddPrefetchLop;
import org.apache.sysds.lops.rewrite.RewriteFixIDs;
import org.apache.sysds.lops.rewrite.RewriteUpdateGPUPlacements;
import org.apache.sysds.parser.DMLProgram;
import org.apache.sysds.parser.ForStatement;
import org.apache.sysds.parser.ForStatementBlock;
import org.apache.sysds.parser.FunctionStatement;
import org.apache.sysds.parser.FunctionStatementBlock;
import org.apache.sysds.parser.IfStatement;
import org.apache.sysds.parser.IfStatementBlock;
import org.apache.sysds.parser.Statement;
import org.apache.sysds.parser.StatementBlock;
import org.apache.sysds.parser.WhileStatement;
import org.apache.sysds.parser.WhileStatementBlock;

/**
 * Applies an ordered set of {@link LopRewriteRule}s to the lops attached to the
 * statement blocks of a {@link DMLProgram}.
 *
 * <p>Rules are applied in the order in which the constructor registers them. Each rule is
 * invoked at two levels: once on a whole list of statement blocks
 * ({@code rewriteLOPinStatementBlocks}) and once on every individual statement block
 * ({@code rewriteLOPinStatementBlock}). Either call may return a different number of
 * blocks than it was given, so callers must always use the returned lists.
 */
public class LopRewriter {
    /** Registered rewrite rules; iteration order is the application order. */
    private final List<LopRewriteRule> _lopSBRuleSet = new ArrayList<>();

    /** Creates a rewriter with the default rule set. */
    public LopRewriter() {
        // The registration order below is the order in which the rules are applied.
        this._lopSBRuleSet.add(new RewriteUpdateGPUPlacements());
        this._lopSBRuleSet.add(new RewriteAddPrefetchLop());
        this._lopSBRuleSet.add(new RewriteAddBroadcastLop());
        this._lopSBRuleSet.add(new RewriteAddChkpointLop());
        this._lopSBRuleSet.add(new RewriteAddChkpointInLoop());
        this._lopSBRuleSet.add(new RewriteAddGPUEvictLop());
        this._lopSBRuleSet.add(new RewriteFixIDs());
    }

    /**
     * Rewrites the lops of every function in every namespace of the program, followed by
     * the program's top-level statement blocks.
     *
     * @param dmlp program to rewrite; its top-level statement blocks are replaced with the
     *             rewritten list via {@code setStatementBlocks}
     */
    public void rewriteProgramLopDAGs(DMLProgram dmlp) {
        // Functions first, then the main program body.
        for (String namespaceKey : dmlp.getNamespaces().keySet()) {
            for (String fname : dmlp.getFunctionStatementBlocks(namespaceKey).keySet()) {
                FunctionStatementBlock fsblock = dmlp.getFunctionStatementBlock(namespaceKey, fname);
                this.rewriteLopDAGsFunction(fsblock);
            }
        }
        if (!this._lopSBRuleSet.isEmpty()) {
            ArrayList<StatementBlock> rewritten = this.rRewriteLops(dmlp.getStatementBlocks());
            dmlp.setStatementBlocks(rewritten);
        }
    }

    /**
     * Rewrites the lops of a single function. The list returned by the recursive rewrite
     * is discarded; only in-place changes to {@code fsb} and its body take effect.
     *
     * @param fsb function statement block to rewrite
     */
    public void rewriteLopDAGsFunction(FunctionStatementBlock fsb) {
        if (!this._lopSBRuleSet.isEmpty()) {
            this.rRewriteLop(fsb);
        }
    }

    /**
     * Attaches {@code lops} to {@code sb}, rewrites the block, and returns the lops of the
     * first block in the rewritten result.
     *
     * @param sb   statement block that receives the lops (modified via {@code setLops})
     * @param lops lops to attach before rewriting
     * @return the lops of the first statement block produced by the rewrite
     * @throws IndexOutOfBoundsException if the rules return no statement block at all
     */
    public ArrayList<Lop> rewriteLopDAG(StatementBlock sb, ArrayList<Lop> lops) {
        sb.setLops(lops);
        ArrayList<StatementBlock> rewritten = this.rRewriteLop(sb);
        return rewritten.get(0).getLops();
    }

    /**
     * Rewrites a list of statement blocks: first every rule is applied to the list as a
     * whole, then each resulting block is rewritten recursively.
     *
     * @param sbs statement blocks to rewrite; this list is cleared and refilled with the
     *            result, so the caller's list is modified in place
     * @return the same list instance as {@code sbs}, now holding the rewritten blocks
     */
    public ArrayList<StatementBlock> rRewriteLops(ArrayList<StatementBlock> sbs) {
        // Pass 1: list-level rules, each one consuming the output of the previous one.
        List<StatementBlock> current = sbs;
        for (LopRewriteRule rule : this._lopSBRuleSet) {
            current = rule.rewriteLOPinStatementBlocks(current);
        }

        // Pass 2: recurse into every block. Results are collected in a separate list
        // because 'current' may still be 'sbs', which is cleared below.
        List<StatementBlock> rewritten = new ArrayList<>();
        for (StatementBlock sb : current) {
            rewritten.addAll(this.rRewriteLop(sb));
        }

        sbs.clear();
        sbs.addAll(rewritten);
        return sbs;
    }

    /**
     * Rewrites a single statement block. Nested bodies (function, while, if/else, for) are
     * rewritten first; afterwards every rule is applied to the block itself.
     *
     * @param sb statement block to rewrite
     * @return the blocks that replace {@code sb}; a rule may return several blocks for one
     *         input, in which case the following rules are applied to each of them
     */
    public ArrayList<StatementBlock> rRewriteLop(StatementBlock sb) {
        this.rewriteNestedBodies(sb);

        ArrayList<StatementBlock> ret = new ArrayList<>();
        ret.add(sb);
        for (LopRewriteRule rule : this._lopSBRuleSet) {
            ArrayList<StatementBlock> next = new ArrayList<>();
            for (StatementBlock candidate : ret) {
                next.addAll(rule.rewriteLOPinStatementBlock(candidate));
            }
            ret = next;
        }
        return ret;
    }

    /** Recursively rewrites the child statement blocks of a control-flow or function block. */
    private void rewriteNestedBodies(StatementBlock sb) {
        if (sb instanceof FunctionStatementBlock) {
            FunctionStatement fstmt = (FunctionStatement) sb.getStatement(0);
            fstmt.setBody(this.rRewriteLops(fstmt.getBody()));
        } else if (sb instanceof WhileStatementBlock) {
            WhileStatement wstmt = (WhileStatement) sb.getStatement(0);
            wstmt.setBody(this.rRewriteLops(wstmt.getBody()));
        } else if (sb instanceof IfStatementBlock) {
            IfStatement istmt = (IfStatement) sb.getStatement(0);
            istmt.setIfBody(this.rRewriteLops(istmt.getIfBody()));
            istmt.setElseBody(this.rRewriteLops(istmt.getElseBody()));
        } else if (sb instanceof ForStatementBlock) {
            ForStatement forStmt = (ForStatement) sb.getStatement(0);
            forStmt.setBody(this.rRewriteLops(forStmt.getBody()));
        }
        // Any other block type has no nested bodies to descend into.
    }
}

