/*
 * Decompiled with CFR 0.152.
 */
package org.apache.ignite3.internal.sql.engine.prepare;

import com.google.common.collect.ImmutableRangeSet;
import com.google.common.collect.RangeSet;
import java.math.BigDecimal;
import java.util.List;
import java.util.stream.Collectors;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexCall;
import org.apache.calcite.rex.RexLiteral;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.rex.RexShuttle;
import org.apache.calcite.rex.RexUnknownAs;
import org.apache.calcite.rex.RexUtil;
import org.apache.calcite.rex.RexVisitor;
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.sql.SqlOperator;
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
import org.apache.calcite.sql.type.SqlTypeFamily;
import org.apache.calcite.util.Sarg;
import org.apache.ignite3.internal.sql.engine.sql.fun.IgniteSqlOperatorTable;
import org.apache.ignite3.internal.sql.engine.util.RexUtils;
import org.apache.ignite3.internal.sql.engine.util.TypeUtils;

public class OutOfRangeLiteralComparisonReductionShuttle
extends RexShuttle {
    private final RexBuilder builder;
    private int comparisonOpTracker;

    public OutOfRangeLiteralComparisonReductionShuttle(RexBuilder builder) {
        this.builder = builder;
    }

    /*
     * WARNING - Removed try catching itself - possible behaviour change.
     */
    public RexNode visitCall(RexCall call) {
        RexNode ref;
        if (RexUtils.isBinaryComparison((RexNode)call)) {
            ++this.comparisonOpTracker;
            try {
                BigDecimal lower;
                RelDataType probingType;
                RexNode firstOp = (RexNode)((RexNode)call.getOperands().get(0)).accept((RexVisitor)this);
                RexNode secondOp = (RexNode)((RexNode)call.getOperands().get(1)).accept((RexVisitor)this);
                RexNode expression = null;
                RexNode possibleLiteral = null;
                SqlKind op = call.op.getKind();
                if (RexUtils.isLosslessCast(firstOp)) {
                    expression = (RexNode)((RexCall)firstOp).getOperands().get(0);
                    possibleLiteral = secondOp;
                } else if (RexUtils.isLosslessCast(secondOp)) {
                    expression = (RexNode)((RexCall)secondOp).getOperands().get(0);
                    possibleLiteral = firstOp;
                    op = op.reverse();
                }
                RelDataType relDataType = probingType = expression != null ? expression.getType() : null;
                if (probingType != null && probingType.getFamily() == SqlTypeFamily.NUMERIC) {
                    assert (expression != null && possibleLiteral != null);
                    if (possibleLiteral instanceof RexLiteral) {
                        BigDecimal current;
                        lower = TypeUtils.lowerBoundFor(probingType);
                        BigDecimal upper = TypeUtils.upperBoundFor(probingType);
                        try {
                            current = (BigDecimal)((RexLiteral)possibleLiteral).getValueAs(BigDecimal.class);
                        }
                        catch (NumberFormatException ignored) {
                            current = null;
                        }
                        if (lower != null && upper != null && current != null) {
                            switch (op) {
                                case EQUALS: 
                                case IS_NOT_DISTINCT_FROM: {
                                    if (OutOfRangeLiteralComparisonReductionShuttle.inRange(current, lower, upper) && probingType.getScale() >= current.stripTrailingZeros().scale()) {
                                        RexCall rexCall = call.clone(call.getType(), List.of(expression, this.builder.makeLiteral((Object)current, probingType)));
                                        return rexCall;
                                    }
                                    RexLiteral rexLiteral = this.builder.makeLiteral(false);
                                    return rexLiteral;
                                }
                                case GREATER_THAN: 
                                case GREATER_THAN_OR_EQUAL: {
                                    if (OutOfRangeLiteralComparisonReductionShuttle.inRange(current, lower, upper)) {
                                        if (probingType.getScale() < current.stripTrailingZeros().scale()) {
                                            if (current.signum() < 0) {
                                                RexNode rexNode = this.builder.makeCall((SqlOperator)IgniteSqlOperatorTable.GREATER_THAN_OR_EQUAL, new RexNode[]{expression, this.builder.makeLiteral((Object)current, probingType)});
                                                return rexNode;
                                            }
                                            if (current.signum() > 0) {
                                                RexNode rexNode = this.builder.makeCall((SqlOperator)IgniteSqlOperatorTable.GREATER_THAN, new RexNode[]{expression, this.builder.makeLiteral((Object)current, probingType)});
                                                return rexNode;
                                            }
                                        }
                                        RexCall rexCall = call.clone(call.getType(), List.of(expression, this.builder.makeLiteral((Object)current, probingType)));
                                        return rexCall;
                                    }
                                    if (current.compareTo(upper) > 0) {
                                        RexLiteral rexLiteral = this.builder.makeLiteral(false);
                                        return rexLiteral;
                                    }
                                    RexNode rexNode = this.builder.makeCall((SqlOperator)SqlStdOperatorTable.IS_NOT_NULL, new RexNode[]{expression});
                                    return rexNode;
                                }
                                case LESS_THAN: 
                                case LESS_THAN_OR_EQUAL: {
                                    if (OutOfRangeLiteralComparisonReductionShuttle.inRange(current, lower, upper)) {
                                        if (probingType.getScale() < current.stripTrailingZeros().scale()) {
                                            if (current.signum() < 0) {
                                                RexNode rexNode = this.builder.makeCall((SqlOperator)IgniteSqlOperatorTable.LESS_THAN, new RexNode[]{expression, this.builder.makeLiteral((Object)current, probingType)});
                                                return rexNode;
                                            }
                                            if (current.signum() > 0) {
                                                RexNode rexNode = this.builder.makeCall((SqlOperator)IgniteSqlOperatorTable.LESS_THAN_OR_EQUAL, new RexNode[]{expression, this.builder.makeLiteral((Object)current, probingType)});
                                                return rexNode;
                                            }
                                        }
                                        RexCall rexCall = call.clone(call.getType(), List.of(expression, this.builder.makeLiteral((Object)current, probingType)));
                                        return rexCall;
                                    }
                                    if (current.compareTo(lower) < 0) {
                                        RexLiteral rexLiteral = this.builder.makeLiteral(false);
                                        return rexLiteral;
                                    }
                                    RexNode rexNode = this.builder.makeCall((SqlOperator)SqlStdOperatorTable.IS_NOT_NULL, new RexNode[]{expression});
                                    return rexNode;
                                }
                            }
                        }
                    }
                }
                lower = call.getOperands().get(0) != firstOp || call.getOperands().get(1) != secondOp ? call.clone(call.getType(), List.of(firstOp, secondOp)) : call;
                return lower;
            }
            finally {
                --this.comparisonOpTracker;
            }
        }
        if (call.isA(SqlKind.SEARCH) && RexUtils.isLosslessCast(ref = (RexNode)call.getOperands().get(0)) && ref.getType().getFamily() == SqlTypeFamily.NUMERIC) {
            ref = (RexNode)((RexCall)ref).getOperands().get(0);
            RelDataType probingType = ref.getType();
            Sarg values = (Sarg)((RexLiteral)call.getOperands().get(1)).getValueAs(Sarg.class);
            assert (values != null);
            if (values.isPoints()) {
                BigDecimal lower = TypeUtils.lowerBoundFor(probingType);
                BigDecimal upper = TypeUtils.upperBoundFor(probingType);
                assert (lower != null);
                assert (upper != null);
                ImmutableRangeSet normalized = ImmutableRangeSet.copyOf((Iterable)values.rangeSet.asRanges().stream().filter(r -> {
                    BigDecimal current = (BigDecimal)r.lowerEndpoint();
                    return OutOfRangeLiteralComparisonReductionShuttle.inRange(current, lower, upper) && probingType.getScale() >= current.stripTrailingZeros().scale();
                }).collect(Collectors.toList()));
                RexLiteral normalizedSargLiteral = this.builder.makeSearchArgumentLiteral(Sarg.of((RexUnknownAs)values.nullAs, (RangeSet)normalized), probingType);
                return this.builder.makeCall((SqlOperator)SqlStdOperatorTable.SEARCH, new RexNode[]{ref, normalizedSargLiteral});
            }
        }
        if (this.comparisonOpTracker > 0 && RexUtils.isLosslessCast((RexNode)call) && call.getOperands().get(0) instanceof RexLiteral) {
            return this.builder.makeLiteral((Object)((RexLiteral)call.getOperands().get(0)).getValue(), call.getType());
        }
        return RexUtil.flatten((RexBuilder)this.builder, (RexNode)super.visitCall(call));
    }

    private static boolean inRange(BigDecimal val, BigDecimal lower, BigDecimal upper) {
        return lower.compareTo(val) <= 0 && upper.compareTo(val) >= 0;
    }
}

