/*
 * Decompiled with CFR 0.152.
 */
package org.apache.flink.table.types.inference;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.function.Predicate;
import java.util.regex.Pattern;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import java.util.stream.Stream;
import javax.annotation.Nullable;
import org.apache.flink.annotation.Internal;
import org.apache.flink.table.api.DataTypes;
import org.apache.flink.table.api.ValidationException;
import org.apache.flink.table.connector.Projection;
import org.apache.flink.table.functions.FunctionDefinition;
import org.apache.flink.table.functions.FunctionKind;
import org.apache.flink.table.functions.TableSemantics;
import org.apache.flink.table.types.DataType;
import org.apache.flink.table.types.inference.ArgumentCount;
import org.apache.flink.table.types.inference.CallContext;
import org.apache.flink.table.types.inference.InputTypeStrategies;
import org.apache.flink.table.types.inference.InputTypeStrategy;
import org.apache.flink.table.types.inference.Signature;
import org.apache.flink.table.types.inference.StaticArgument;
import org.apache.flink.table.types.inference.StaticArgumentTrait;
import org.apache.flink.table.types.inference.TypeInference;
import org.apache.flink.table.types.inference.TypeStrategy;

/**
 * Enriches the {@link TypeInference} declared by a function with system-level behavior that
 * depends on the {@link FunctionKind}.
 *
 * <ul>
 *   <li>Kinds other than {@link FunctionKind#PROCESS_TABLE}: declared static arguments (if any)
 *       must be scalar; static arguments and the input strategy are kept as declared.
 *   <li>{@link FunctionKind#PROCESS_TABLE}: a static signature is mandatory, the reserved {@code
 *       uid} scalar argument is appended to it, and the input strategy is wrapped with additional
 *       validation (unchanged signature, uid format, at most one table argument, table traits).
 *   <li>{@link FunctionKind#TABLE} and {@link FunctionKind#PROCESS_TABLE}: the output strategy is
 *       wrapped so that the result is a NOT NULL row made of pass-through columns followed by the
 *       function's own output fields, with duplicate field names made unique.
 * </ul>
 */
@Internal
public class SystemTypeInference {

    /** System arguments appended to every process table function signature. */
    private static final List<StaticArgument> PROCESS_TABLE_FUNCTION_SYSTEM_ARGS =
            List.of(StaticArgument.scalar("uid", DataTypes.STRING(), true));

    /** Accepted format of the {@code uid} argument. */
    private static final Predicate<String> UID_FORMAT =
            Pattern.compile("^[a-zA-Z_][a-zA-Z-_0-9]*$").asPredicate();

    /**
     * Derives the system type inference for a function of the given kind.
     *
     * @param functionKind kind of the function whose type inference is derived
     * @param origin type inference declared by the function
     * @return a new type inference with system arguments and strategies applied; state type
     *     strategies are copied from {@code origin} unchanged
     * @throws ValidationException if the declared static arguments violate the rules for the
     *     given function kind
     */
    public static TypeInference of(FunctionKind functionKind, TypeInference origin) {
        final TypeInference.Builder builder = TypeInference.newBuilder();

        final List<StaticArgument> systemArgs =
                deriveSystemArgs(functionKind, origin.getStaticArguments().orElse(null));
        if (systemArgs != null) {
            builder.staticArguments(systemArgs);
        }
        builder.inputTypeStrategy(
                deriveSystemInputStrategy(functionKind, systemArgs, origin.getInputTypeStrategy()));
        builder.stateTypeStrategies(origin.getStateTypeStrategies());
        builder.outputTypeStrategy(
                deriveSystemOutputStrategy(
                        functionKind, systemArgs, origin.getOutputTypeStrategy()));
        return builder.build();
    }

    /** Rejects any argument that is not declared with the {@link StaticArgumentTrait#SCALAR} trait. */
    private static void checkScalarArgsOnly(List<StaticArgument> defaultArgs) {
        defaultArgs.forEach(
                arg -> {
                    if (!arg.is(StaticArgumentTrait.SCALAR)) {
                        throw new ValidationException(
                                String.format(
                                        "Only scalar arguments are supported at this location. But argument '%s' declared the following traits: %s",
                                        arg.getName(), arg.getTraits()));
                    }
                });
    }

    /**
     * Returns the static arguments to register for the function.
     *
     * <p>For non-process-table functions the declared arguments are returned as-is (possibly
     * {@code null}) after checking they are scalar. For process table functions a declared
     * signature is required and the system arguments are appended to a copy of it.
     */
    @Nullable
    private static List<StaticArgument> deriveSystemArgs(
            FunctionKind functionKind, @Nullable List<StaticArgument> declaredArgs) {
        if (functionKind != FunctionKind.PROCESS_TABLE) {
            if (declaredArgs != null) {
                checkScalarArgsOnly(declaredArgs);
            }
            return declaredArgs;
        }
        if (declaredArgs == null) {
            throw new ValidationException(
                    "Function requires a static signature that is not overloaded and doesn't contain varargs.");
        }
        checkReservedArgs(declaredArgs);
        final List<StaticArgument> newStaticArgs = new ArrayList<>(declaredArgs);
        newStaticArgs.addAll(PROCESS_TABLE_FUNCTION_SYSTEM_ARGS);
        return newStaticArgs;
    }

    /** Rejects declared arguments whose name clashes with a reserved system argument name. */
    private static void checkReservedArgs(List<StaticArgument> staticArgs) {
        final Set<String> declaredArgs =
                staticArgs.stream().map(StaticArgument::getName).collect(Collectors.toSet());
        final Set<String> reservedArgs =
                PROCESS_TABLE_FUNCTION_SYSTEM_ARGS.stream()
                        .map(StaticArgument::getName)
                        .collect(Collectors.toSet());
        if (reservedArgs.stream().anyMatch(declaredArgs::contains)) {
            throw new ValidationException(
                    "Function signature must not declare system arguments. Reserved argument names are: "
                            + reservedArgs);
        }
    }

    /** Wraps the input strategy of process table functions; other kinds are passed through. */
    private static InputTypeStrategy deriveSystemInputStrategy(
            FunctionKind functionKind,
            @Nullable List<StaticArgument> staticArgs,
            InputTypeStrategy inputStrategy) {
        if (functionKind != FunctionKind.PROCESS_TABLE) {
            return inputStrategy;
        }
        return new SystemInputStrategy(staticArgs, inputStrategy);
    }

    /** Wraps the output strategy of table and process table functions; others pass through. */
    private static TypeStrategy deriveSystemOutputStrategy(
            FunctionKind functionKind,
            @Nullable List<StaticArgument> staticArgs,
            TypeStrategy outputStrategy) {
        if (functionKind != FunctionKind.TABLE && functionKind != FunctionKind.PROCESS_TABLE) {
            return outputStrategy;
        }
        return new SystemOutputStrategy(staticArgs, outputStrategy);
    }

    /**
     * Input strategy for process table functions.
     *
     * <p>It delegates to the declared strategy and then performs the system validations.
     */
    private static class SystemInputStrategy implements InputTypeStrategy {

        private final List<StaticArgument> staticArgs;
        private final InputTypeStrategy origin;

        private SystemInputStrategy(List<StaticArgument> staticArgs, InputTypeStrategy origin) {
            this.staticArgs = staticArgs;
            this.origin = origin;
        }

        @Override
        public ArgumentCount getArgumentCount() {
            return InputTypeStrategies.WILDCARD.getArgumentCount();
        }

        @Override
        public Optional<List<DataType>> inferInputTypes(
                CallContext callContext, boolean throwOnFailure) {
            final List<DataType> args = callContext.getArgumentDataTypes();

            // Give the declared strategy a chance to validate; it must not alter the argument
            // types, since process table functions require a fixed static signature.
            final List<DataType> inferredDataTypes =
                    origin.inferInputTypes(callContext, throwOnFailure).orElse(null);
            if (inferredDataTypes == null || !inferredDataTypes.equals(args)) {
                throw new ValidationException(
                        "Process table functions must declare a static signature that is not overloaded and doesn't contain varargs.");
            }

            checkUidColumn(callContext);
            checkMultipleTableArgs(callContext);
            checkTableArgTraits(staticArgs, callContext);

            return Optional.of(inferredDataTypes);
        }

        @Override
        public List<Signature> getExpectedSignatures(FunctionDefinition definition) {
            return origin.getExpectedSignatures(definition);
        }

        /**
         * Validates the {@code uid} argument, which is the last argument of the call.
         *
         * <p>A NULL uid is accepted. A value that cannot be read as a String literal is treated
         * as "" and therefore rejected.
         */
        private static void checkUidColumn(CallContext callContext) {
            final int uidPos = callContext.getArgumentDataTypes().size() - 1;
            if (callContext.isArgumentNull(uidPos)) {
                return;
            }
            final String uid = callContext.getArgumentValue(uidPos, String.class).orElse("");
            if (!UID_FORMAT.test(uid)) {
                throw new ValidationException(
                        "Invalid unique identifier for process table function. The 'uid' argument must be a string literal that follows the pattern [a-zA-Z_][a-zA-Z-_0-9]*. But found: "
                                + uid);
            }
        }

        /** Rejects calls that pass more than one table argument. */
        private static void checkMultipleTableArgs(CallContext callContext) {
            final int argCount = callContext.getArgumentDataTypes().size();
            final long tableArgCount =
                    IntStream.range(0, argCount)
                            .filter(pos -> callContext.getTableSemantics(pos).isPresent())
                            .count();
            if (tableArgCount > 1L) {
                throw new ValidationException(
                        "Currently, only signatures with at most one table argument are supported.");
            }
        }

        /**
         * For every argument declared with the {@link StaticArgumentTrait#TABLE} trait, checks
         * that a table was passed and that its clauses fit the declared row/set semantics.
         */
        private static void checkTableArgTraits(
                List<StaticArgument> staticArgs, CallContext callContext) {
            IntStream.range(0, staticArgs.size())
                    .forEach(
                            pos -> {
                                final StaticArgument staticArg = staticArgs.get(pos);
                                if (!staticArg.is(StaticArgumentTrait.TABLE)) {
                                    return;
                                }
                                final TableSemantics semantics =
                                        callContext.getTableSemantics(pos).orElse(null);
                                if (semantics == null) {
                                    throw new ValidationException(
                                            String.format(
                                                    "Table expected for argument '%s'.",
                                                    staticArg.getName()));
                                }
                                checkRowSemantics(staticArg, semantics);
                                checkSetSemantics(staticArg, semantics);
                            });
        }

        /** Row-semantic tables must not carry PARTITION BY or ORDER BY. */
        private static void checkRowSemantics(StaticArgument staticArg, TableSemantics semantics) {
            if (!staticArg.is(StaticArgumentTrait.TABLE_AS_ROW)) {
                return;
            }
            if (semantics.partitionByColumns().length > 0
                    || semantics.orderByColumns().length > 0) {
                throw new ValidationException(
                        "PARTITION BY or ORDER BY are not supported for table arguments with row semantics.");
            }
        }

        /**
         * Set-semantic tables need PARTITION BY, unless the argument also declares {@link
         * StaticArgumentTrait#OPTIONAL_PARTITION_BY}.
         */
        private static void checkSetSemantics(StaticArgument staticArg, TableSemantics semantics) {
            if (!staticArg.is(StaticArgumentTrait.TABLE_AS_SET)) {
                return;
            }
            if (semantics.partitionByColumns().length == 0
                    && !staticArg.is(StaticArgumentTrait.OPTIONAL_PARTITION_BY)) {
                throw new ValidationException(
                        String.format(
                                "Table argument '%s' requires a PARTITION BY clause for parallel processing.",
                                staticArg.getName()));
            }
        }
    }

    /**
     * Output strategy for table and process table functions.
     *
     * <p>The result is a NOT NULL row made of the pass-through columns followed by the function's
     * own output fields.
     */
    private static class SystemOutputStrategy implements TypeStrategy {

        private final @Nullable List<StaticArgument> staticArgs;
        private final TypeStrategy origin;

        private SystemOutputStrategy(
                @Nullable List<StaticArgument> staticArgs, TypeStrategy origin) {
            this.staticArgs = staticArgs;
            this.origin = origin;
        }

        @Override
        public Optional<DataType> inferType(CallContext callContext) {
            return origin.inferType(callContext)
                    .map(
                            functionDataType -> {
                                final List<DataTypes.Field> fields = new ArrayList<>();
                                fields.addAll(derivePassThroughFields(callContext));
                                fields.addAll(deriveFunctionOutputFields(functionDataType));
                                final List<DataTypes.Field> uniqueFields =
                                        makeFieldNamesUnique(fields);
                                return DataTypes.ROW(uniqueFields).notNull();
                            });
        }

        /**
         * Renames duplicate field names by appending a numeric suffix.
         *
         * <p>The first occurrence of a name is kept. Later occurrences get {@code name0}, {@code
         * name1}, and so on. A suffixed candidate that is already taken by an earlier field is
         * skipped, so the result is always unique. Field order is preserved.
         */
        private List<DataTypes.Field> makeFieldNamesUnique(List<DataTypes.Field> fields) {
            // Next suffix to try, per original field name.
            final Map<String, Integer> nextSuffix = new HashMap<>();
            // Every name emitted so far, including generated ones.
            final Set<String> usedNames = new HashSet<>();
            final List<DataTypes.Field> result = new ArrayList<>(fields.size());
            for (DataTypes.Field field : fields) {
                final String name = field.getName();
                String uniqueName = name;
                if (usedNames.contains(name)) {
                    int suffix = nextSuffix.getOrDefault(name, 0);
                    // Skip candidates that collide with a field already named e.g. "a0".
                    while (usedNames.contains(name + suffix)) {
                        suffix++;
                    }
                    nextSuffix.put(name, suffix + 1);
                    uniqueName = name + suffix;
                }
                usedNames.add(uniqueName);
                result.add(DataTypes.FIELD(uniqueName, field.getDataType()));
            }
            return result;
        }

        /**
         * Collects the columns that the function forwards from its table arguments.
         *
         * <p>Arguments with {@link StaticArgumentTrait#PASS_COLUMNS_THROUGH} contribute all their
         * columns. Set-semantic tables without it contribute only their PARTITION BY columns.
         */
        private List<DataTypes.Field> derivePassThroughFields(CallContext callContext) {
            if (staticArgs == null) {
                return List.of();
            }
            final List<DataType> argDataTypes = callContext.getArgumentDataTypes();
            return IntStream.range(0, staticArgs.size())
                    .mapToObj(
                            pos -> {
                                final StaticArgument arg = staticArgs.get(pos);
                                if (arg.is(StaticArgumentTrait.PASS_COLUMNS_THROUGH)) {
                                    return DataType.getFields(argDataTypes.get(pos)).stream();
                                }
                                if (!arg.is(StaticArgumentTrait.TABLE_AS_SET)) {
                                    return Stream.<DataTypes.Field>empty();
                                }
                                final TableSemantics semantics =
                                        callContext
                                                .getTableSemantics(pos)
                                                .orElseThrow(IllegalStateException::new);
                                final DataType projectedRow =
                                        Projection.of(semantics.partitionByColumns())
                                                .project(argDataTypes.get(pos));
                                return DataType.getFields(projectedRow).stream();
                            })
                    .flatMap(s -> s)
                    .collect(Collectors.toList());
        }

        /**
         * Flattens the function's output type into fields.
         *
         * <p>A type without fields becomes a single field named {@code EXPR$0}.
         */
        private List<DataTypes.Field> deriveFunctionOutputFields(DataType functionDataType) {
            final List<DataType> fieldTypes = DataType.getFieldDataTypes(functionDataType);
            final List<String> fieldNames = DataType.getFieldNames(functionDataType);
            if (fieldTypes.isEmpty()) {
                return List.of(DataTypes.FIELD("EXPR$0", functionDataType));
            }
            return IntStream.range(0, fieldTypes.size())
                    .mapToObj(pos -> DataTypes.FIELD(fieldNames.get(pos), fieldTypes.get(pos)))
                    .collect(Collectors.toList());
        }
    }
}

