/*
 * Decompiled with CFR 0.152.
 */
package org.apache.flink.table.planner.expressions.converter;

import java.math.BigDecimal;
import java.time.Duration;
import java.time.Instant;
import java.time.LocalDate;
import java.time.LocalDateTime;
import java.time.LocalTime;
import java.time.Period;
import java.time.ZoneOffset;
import java.time.temporal.ChronoField;
import java.util.Arrays;
import java.util.List;
import java.util.Optional;
import java.util.stream.Collectors;
import org.apache.calcite.avatica.util.ByteString;
import org.apache.calcite.avatica.util.TimeUnit;
import org.apache.calcite.avatica.util.TimeUnitRange;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.sql.type.SqlTypeName;
import org.apache.calcite.tools.RelBuilder;
import org.apache.calcite.util.DateString;
import org.apache.calcite.util.TimeString;
import org.apache.flink.table.api.TableException;
import org.apache.flink.table.catalog.DataTypeFactory;
import org.apache.flink.table.data.DecimalData;
import org.apache.flink.table.expressions.CallExpression;
import org.apache.flink.table.expressions.Expression;
import org.apache.flink.table.expressions.ExpressionVisitor;
import org.apache.flink.table.expressions.FieldReferenceExpression;
import org.apache.flink.table.expressions.LocalReferenceExpression;
import org.apache.flink.table.expressions.TimeIntervalUnit;
import org.apache.flink.table.expressions.TimePointUnit;
import org.apache.flink.table.expressions.TypeLiteralExpression;
import org.apache.flink.table.expressions.ValueLiteralExpression;
import org.apache.flink.table.planner.calcite.FlinkTypeFactory;
import org.apache.flink.table.planner.calcite.RexFieldVariable;
import org.apache.flink.table.planner.expressions.RexNodeExpression;
import org.apache.flink.table.planner.expressions.converter.CallExpressionConvertRule;
import org.apache.flink.table.planner.expressions.converter.CustomizedConvertRule;
import org.apache.flink.table.planner.expressions.converter.DirectConvertRule;
import org.apache.flink.table.planner.expressions.converter.FunctionDefinitionConvertRule;
import org.apache.flink.table.planner.expressions.converter.LegacyScalarFunctionConvertRule;
import org.apache.flink.table.planner.expressions.converter.OverConvertRule;
import org.apache.flink.table.planner.utils.ShortcutUtils;
import org.apache.flink.table.runtime.types.LogicalTypeDataTypeConverter;
import org.apache.flink.table.types.DataType;
import org.apache.flink.table.types.logical.LogicalType;
import org.apache.flink.table.types.logical.TimeType;
import org.apache.flink.table.util.TimestampStringUtils;

public class ExpressionConverter
implements ExpressionVisitor<RexNode> {
    private static final List<CallExpressionConvertRule> FUNCTION_CONVERT_CHAIN = Arrays.asList(new LegacyScalarFunctionConvertRule(), new FunctionDefinitionConvertRule(), new OverConvertRule(), new DirectConvertRule(), new CustomizedConvertRule());
    private final RelBuilder relBuilder;
    private final FlinkTypeFactory typeFactory;
    private final DataTypeFactory dataTypeFactory;

    public ExpressionConverter(RelBuilder relBuilder) {
        this.relBuilder = relBuilder;
        this.typeFactory = (FlinkTypeFactory)relBuilder.getRexBuilder().getTypeFactory();
        this.dataTypeFactory = ShortcutUtils.unwrapContext(relBuilder.getCluster()).getCatalogManager().getDataTypeFactory();
    }

    public RexNode visit(CallExpression call) {
        for (CallExpressionConvertRule rule : FUNCTION_CONVERT_CHAIN) {
            Optional<RexNode> converted = rule.convert(call, this.newFunctionContext());
            if (!converted.isPresent()) continue;
            return converted.get();
        }
        throw new RuntimeException("Unknown call expression: " + call);
    }

    public RexNode visit(ValueLiteralExpression valueLiteral) {
        LogicalType type = LogicalTypeDataTypeConverter.fromDataTypeToLogicalType((DataType)valueLiteral.getOutputDataType());
        RexBuilder rexBuilder = this.relBuilder.getRexBuilder();
        FlinkTypeFactory typeFactory = (FlinkTypeFactory)this.relBuilder.getTypeFactory();
        RelDataType relDataType = typeFactory.createFieldTypeFromLogicalType(type);
        if (valueLiteral.isNull()) {
            return rexBuilder.makeNullLiteral(relDataType);
        }
        Object value = null;
        switch (type.getTypeRoot()) {
            case DECIMAL: 
            case TINYINT: 
            case SMALLINT: 
            case INTEGER: 
            case BIGINT: 
            case FLOAT: 
            case DOUBLE: {
                value = ExpressionConverter.extractValue(valueLiteral, BigDecimal.class);
                break;
            }
            case VARCHAR: 
            case CHAR: {
                value = ExpressionConverter.extractValue(valueLiteral, String.class);
                break;
            }
            case BINARY: 
            case VARBINARY: {
                value = new ByteString(ExpressionConverter.extractValue(valueLiteral, byte[].class));
                break;
            }
            case INTERVAL_YEAR_MONTH: {
                value = BigDecimal.valueOf(ExpressionConverter.extractValue(valueLiteral, Period.class).toTotalMonths());
                break;
            }
            case INTERVAL_DAY_TIME: {
                value = BigDecimal.valueOf(ExpressionConverter.extractValue(valueLiteral, Duration.class).toMillis());
                break;
            }
            case DATE: {
                value = DateString.fromDaysSinceEpoch((int)ExpressionConverter.extractValue(valueLiteral, LocalDate.class).toEpochDay());
                break;
            }
            case TIME_WITHOUT_TIME_ZONE: {
                TimeType timeType = (TimeType)type;
                int precision = timeType.getPrecision();
                relDataType = typeFactory.createSqlType(SqlTypeName.TIME, Math.min(precision, 3));
                value = TimeString.fromMillisOfDay(ExpressionConverter.extractValue(valueLiteral, LocalTime.class).get(ChronoField.MILLI_OF_DAY));
                break;
            }
            case TIMESTAMP_WITHOUT_TIME_ZONE: {
                LocalDateTime datetime = ExpressionConverter.extractValue(valueLiteral, LocalDateTime.class);
                value = TimestampStringUtils.fromLocalDateTime(datetime);
                break;
            }
            case TIMESTAMP_WITH_LOCAL_TIME_ZONE: {
                Instant instant = ExpressionConverter.extractValue(valueLiteral, Instant.class);
                value = TimestampStringUtils.fromLocalDateTime(instant.atOffset(ZoneOffset.UTC).toLocalDateTime());
                break;
            }
            default: {
                value = ExpressionConverter.extractValue(valueLiteral, Object.class);
                if (value instanceof TimePointUnit) {
                    value = ExpressionConverter.timePointUnitToTimeUnit((TimePointUnit)value);
                    break;
                }
                if (!(value instanceof TimeIntervalUnit)) break;
                value = ExpressionConverter.intervalUnitToUnitRange((TimeIntervalUnit)value);
            }
        }
        return rexBuilder.makeLiteral(value, relDataType, true);
    }

    public RexNode visit(FieldReferenceExpression fieldReference) {
        return this.relBuilder.field(fieldReference.getName());
    }

    public RexNode visit(TypeLiteralExpression typeLiteral) {
        throw new UnsupportedOperationException();
    }

    public RexNode visit(Expression other) {
        if (other instanceof RexNodeExpression) {
            return ((RexNodeExpression)other).getRexNode();
        }
        if (other instanceof LocalReferenceExpression) {
            LocalReferenceExpression local = (LocalReferenceExpression)other;
            return new RexFieldVariable(local.getName(), this.typeFactory.createFieldTypeFromLogicalType(LogicalTypeDataTypeConverter.fromDataTypeToLogicalType((DataType)local.getOutputDataType())));
        }
        throw new UnsupportedOperationException(other.getClass().getSimpleName() + ":" + other.toString());
    }

    public static List<RexNode> toRexNodes(CallExpressionConvertRule.ConvertContext context, List<Expression> expr) {
        return expr.stream().map(context::toRexNode).collect(Collectors.toList());
    }

    private CallExpressionConvertRule.ConvertContext newFunctionContext() {
        return new CallExpressionConvertRule.ConvertContext(){

            @Override
            public RexNode toRexNode(Expression expr) {
                return (RexNode)expr.accept((ExpressionVisitor)ExpressionConverter.this);
            }

            @Override
            public RelBuilder getRelBuilder() {
                return ExpressionConverter.this.relBuilder;
            }

            @Override
            public FlinkTypeFactory getTypeFactory() {
                return ExpressionConverter.this.typeFactory;
            }

            @Override
            public DataTypeFactory getDataTypeFactory() {
                return ExpressionConverter.this.dataTypeFactory;
            }
        };
    }

    private static TimeUnit timePointUnitToTimeUnit(TimePointUnit unit) {
        switch (unit) {
            case YEAR: {
                return TimeUnit.YEAR;
            }
            case MONTH: {
                return TimeUnit.MONTH;
            }
            case DAY: {
                return TimeUnit.DAY;
            }
            case HOUR: {
                return TimeUnit.HOUR;
            }
            case MINUTE: {
                return TimeUnit.MINUTE;
            }
            case SECOND: {
                return TimeUnit.SECOND;
            }
            case QUARTER: {
                return TimeUnit.QUARTER;
            }
            case WEEK: {
                return TimeUnit.WEEK;
            }
            case MILLISECOND: {
                return TimeUnit.MILLISECOND;
            }
            case MICROSECOND: {
                return TimeUnit.MICROSECOND;
            }
        }
        throw new UnsupportedOperationException("TimePointUnit is: " + unit);
    }

    private static TimeUnitRange intervalUnitToUnitRange(TimeIntervalUnit intervalUnit) {
        switch (intervalUnit) {
            case YEAR: {
                return TimeUnitRange.YEAR;
            }
            case YEAR_TO_MONTH: {
                return TimeUnitRange.YEAR_TO_MONTH;
            }
            case QUARTER: {
                return TimeUnitRange.QUARTER;
            }
            case MONTH: {
                return TimeUnitRange.MONTH;
            }
            case WEEK: {
                return TimeUnitRange.WEEK;
            }
            case DAY: {
                return TimeUnitRange.DAY;
            }
            case DAY_TO_HOUR: {
                return TimeUnitRange.DAY_TO_HOUR;
            }
            case DAY_TO_MINUTE: {
                return TimeUnitRange.DAY_TO_MINUTE;
            }
            case DAY_TO_SECOND: {
                return TimeUnitRange.DAY_TO_SECOND;
            }
            case HOUR: {
                return TimeUnitRange.HOUR;
            }
            case SECOND: {
                return TimeUnitRange.SECOND;
            }
            case HOUR_TO_MINUTE: {
                return TimeUnitRange.HOUR_TO_MINUTE;
            }
            case HOUR_TO_SECOND: {
                return TimeUnitRange.HOUR_TO_SECOND;
            }
            case MINUTE: {
                return TimeUnitRange.MINUTE;
            }
            case MINUTE_TO_SECOND: {
                return TimeUnitRange.MINUTE_TO_SECOND;
            }
        }
        throw new UnsupportedOperationException("TimeIntervalUnit is: " + intervalUnit);
    }

    public static <T> T extractValue(ValueLiteralExpression literal, Class<T> clazz) {
        Optional possibleObject = literal.getValueAs(Object.class);
        if (!possibleObject.isPresent()) {
            throw new TableException("Invalid literal.");
        }
        Object object = possibleObject.get();
        if (clazz.equals(BigDecimal.class)) {
            Optional possibleDecimal = literal.getValueAs(BigDecimal.class);
            if (possibleDecimal.isPresent()) {
                return possibleDecimal.get();
            }
            if (object instanceof DecimalData) {
                return (T)((DecimalData)object).toBigDecimal();
            }
        }
        return literal.getValueAs(clazz).orElseThrow(() -> new TableException("Unsupported literal class: " + clazz));
    }
}

