/*
 * Decompiled with CFR 0.152.
 */
package io.trino.operator.scalar;

import io.airlift.joni.Matcher;
import io.airlift.joni.Regex;
import io.airlift.joni.Region;
import io.airlift.joni.exception.ValueException;
import io.airlift.slice.DynamicSliceOutput;
import io.airlift.slice.Slice;
import io.airlift.slice.SliceOutput;
import io.airlift.slice.SliceUtf8;
import io.airlift.slice.Slices;
import io.trino.spi.ErrorCodeSupplier;
import io.trino.spi.StandardErrorCode;
import io.trino.spi.TrinoException;
import io.trino.spi.block.Block;
import io.trino.spi.block.BlockBuilder;
import io.trino.spi.function.Description;
import io.trino.spi.function.LiteralParameters;
import io.trino.spi.function.ScalarFunction;
import io.trino.spi.function.SqlNullable;
import io.trino.spi.function.SqlType;
import io.trino.spi.type.VarcharType;
import io.trino.type.Constraint;
import io.trino.type.JoniRegexp;
import java.nio.charset.StandardCharsets;

/**
 * SQL regular-expression scalar functions backed by the joni regex engine.
 *
 * <p>Offsets handled internally are byte offsets into the UTF-8 {@link Slice}; they are converted
 * to code-point positions only where a function returns a position to SQL ({@code regexp_position}).
 */
public final class JoniRegexpFunctions {
    /** Option flags passed to {@code Matcher.search}; 0 means "no options" (joni's {@code Option.NONE}). */
    private static final int NO_OPTIONS = 0;

    private JoniRegexpFunctions() {
    }

    /**
     * Returns {@code true} if {@code pattern} matches anywhere in {@code source}.
     *
     * <p>When the slice is backed by a heap array, the matcher runs directly over that array to avoid
     * a copy, so the search window is shifted by the slice's offset within the array.
     */
    @Description(value="Returns whether the pattern is contained within the string")
    @ScalarFunction
    @LiteralParameters(value={"x"})
    @SqlType(value="boolean")
    public static boolean regexpLike(@SqlType(value="varchar(x)") Slice source, @SqlType(value="JoniRegExp") JoniRegexp pattern) {
        Matcher matcher;
        int offset;
        if (source.hasByteArray()) {
            offset = source.byteArrayOffset();
            matcher = pattern.regex().matcher(source.byteArray(), offset, offset + source.length());
        } else {
            offset = 0;
            matcher = pattern.matcher(source.getBytes());
        }
        return matcher.search(offset, offset + source.length(), NO_OPTIONS) != -1;
    }

    /**
     * Returns the byte offset from which to resume searching after the matcher's current match.
     *
     * <p>For a non-empty match this is simply the match end. For an empty match the position must
     * still advance, or the same empty match would be found again. It skips one whole UTF-8 code
     * point, so the search never resumes mid-character. At the end of input it moves one byte
     * past the end, so the next search starts beyond the input.
     */
    private static int getNextStart(Slice source, Matcher matcher) {
        int begin = matcher.getBegin();
        int end = matcher.getEnd();
        if (end != begin) {
            return end;
        }
        if (begin < source.length()) {
            return end + SliceUtf8.lengthOfCodePointFromStartByte(source.getByte(begin));
        }
        return end + 1;
    }

    /** Removes every match of {@code pattern} from {@code source} (replacement with the empty string). */
    @Description(value="Removes substrings matching a regular expression")
    @ScalarFunction
    @LiteralParameters(value={"x"})
    @SqlType(value="varchar(x)")
    public static Slice regexpReplace(@SqlType(value="varchar(x)") Slice source, @SqlType(value="JoniRegExp") JoniRegexp pattern) {
        return regexpReplace(source, pattern, Slices.EMPTY_SLICE);
    }

    /**
     * Replaces every match of {@code pattern} in {@code source} with {@code replacement}.
     *
     * <p>The replacement may reference groups of each match; see {@link #appendReplacement} for
     * the supported syntax.
     *
     * @throws TrinoException INVALID_FUNCTION_ARGUMENT if {@code replacement} is malformed or refers
     *     to an unknown group
     */
    @Description(value="Replaces substrings matching a regular expression by given string")
    @ScalarFunction
    @LiteralParameters(value={"x", "y", "z"})
    @Constraint(variable="z", expression="min(2147483647, x + max(x * y / 2, y) * (x + 1))")
    @SqlType(value="varchar(z)")
    public static Slice regexpReplace(@SqlType(value="varchar(x)") Slice source, @SqlType(value="JoniRegExp") JoniRegexp pattern, @SqlType(value="varchar(y)") Slice replacement) {
        Matcher matcher = pattern.matcher(source.getBytes());
        // Initial capacity is a heuristic: the input plus room for a handful of replacements.
        DynamicSliceOutput sliceOutput = new DynamicSliceOutput(source.length() + replacement.length() * 5);
        int lastEnd = 0;
        int nextStart = 0;
        while (matcher.search(nextStart, source.length(), NO_OPTIONS) != -1) {
            nextStart = getNextStart(source, matcher);
            // Copy the unmatched text between the previous match and this one, then the expansion.
            sliceOutput.appendBytes(source.slice(lastEnd, matcher.getBegin() - lastEnd));
            lastEnd = matcher.getEnd();
            appendReplacement(sliceOutput, source, pattern.regex(), matcher.getEagerRegion(), replacement);
        }
        // Trailing text after the last match (or the whole input if nothing matched).
        sliceOutput.appendBytes(source.slice(lastEnd, source.length() - lastEnd));
        return sliceOutput.slice();
    }

    /** Builds the generic error for a syntactically invalid replacement string. */
    private static TrinoException illegalReplacementSequence(Slice replacement) {
        return new TrinoException(StandardErrorCode.INVALID_FUNCTION_ARGUMENT, "Illegal replacement sequence: " + replacement.toStringUtf8());
    }

    /**
     * Appends {@code replacement} to {@code result}, expanding group references against the current match.
     *
     * <p>Supported syntax:
     * <ul>
     *   <li>{@code $n}: numbered group. Further digits are consumed only while the resulting number
     *       is still a group of the pattern, so {@code $10} with fewer than 11 registers means group 1
     *       followed by a literal {@code 0}.</li>
     *   <li><code>${name}</code>: named group.</li>
     *   <li>{@code \c}: the literal byte {@code c}, for example {@code \$}.</li>
     * </ul>
     * A group that did not participate in the match contributes nothing.
     *
     * @throws TrinoException INVALID_FUNCTION_ARGUMENT for any of the following:
     *     <ul>
     *       <li>a trailing {@code $} or {@code \};</li>
     *       <li>a {@code $} followed by neither a digit nor an opening brace;</li>
     *       <li>an unterminated named reference (<code>${name</code> with no closing brace);</li>
     *       <li>a reference to a group the pattern does not define.</li>
     *     </ul>
     */
    private static void appendReplacement(SliceOutput result, Slice source, Regex pattern, Region region, Slice replacement) {
        int idx = 0;
        while (idx < replacement.length()) {
            byte nextByte = replacement.getByte(idx);
            if (nextByte == '$') {
                idx++;
                if (idx == replacement.length()) {
                    throw illegalReplacementSequence(replacement);
                }
                nextByte = replacement.getByte(idx);
                int backref;
                if (nextByte == '{') {
                    // Named group reference.
                    idx++;
                    int nameStart = idx;
                    while (idx < replacement.length() && replacement.getByte(idx) != '}') {
                        idx++;
                    }
                    if (idx == replacement.length()) {
                        // No closing brace: reject rather than treating the rest of the string as the name.
                        throw illegalReplacementSequence(replacement);
                    }
                    byte[] groupName = replacement.getBytes(nameStart, idx - nameStart);
                    try {
                        backref = pattern.nameToBackrefNumber(groupName, 0, groupName.length, region);
                    } catch (ValueException e) {
                        throw new TrinoException(StandardErrorCode.INVALID_FUNCTION_ARGUMENT, "Illegal replacement sequence: unknown group { " + new String(groupName, StandardCharsets.UTF_8) + " }", e);
                    }
                    idx++; // skip the closing '}'
                } else {
                    // Numbered group reference: $n
                    backref = nextByte - '0';
                    if (backref < 0 || backref > 9) {
                        throw illegalReplacementSequence(replacement);
                    }
                    if (region.numRegs <= backref) {
                        throw new TrinoException(StandardErrorCode.INVALID_FUNCTION_ARGUMENT, "Illegal replacement sequence: unknown group " + backref);
                    }
                    idx++;
                    // Extend a multi-digit reference only while it still names an existing group.
                    while (idx < replacement.length()) {
                        int nextDigit = replacement.getByte(idx) - '0';
                        if (nextDigit < 0 || nextDigit > 9) {
                            break;
                        }
                        int newBackref = backref * 10 + nextDigit;
                        if (region.numRegs <= newBackref) {
                            break;
                        }
                        backref = newBackref;
                        idx++;
                    }
                }
                int beg = region.beg[backref];
                int end = region.end[backref];
                // -1 marks a group that did not take part in this match.
                if (beg != -1 && end != -1) {
                    result.appendBytes(source.slice(beg, end - beg));
                }
                continue;
            }
            if (nextByte == '\\') {
                // Escape: emit the following byte literally.
                idx++;
                if (idx == replacement.length()) {
                    throw illegalReplacementSequence(replacement);
                }
                nextByte = replacement.getByte(idx);
            }
            result.appendByte(nextByte);
            idx++;
        }
    }

    /** Returns every full match (group 0) of {@code pattern} in {@code source}. */
    @Description(value="String(s) extracted using the given pattern")
    @ScalarFunction
    @LiteralParameters(value={"x"})
    @SqlType(value="array(varchar(x))")
    public static Block regexpExtractAll(@SqlType(value="varchar(x)") Slice source, @SqlType(value="JoniRegExp") JoniRegexp pattern) {
        return regexpExtractAll(source, pattern, 0L);
    }

    /**
     * Returns group {@code groupIndex} of every match of {@code pattern} in {@code source}.
     *
     * <p>An element is NULL when the group did not participate in that match.
     *
     * @throws TrinoException INVALID_FUNCTION_ARGUMENT if {@code groupIndex} is negative or exceeds
     *     the pattern's group count
     */
    @Description(value="Group(s) extracted using the given pattern")
    @ScalarFunction
    @LiteralParameters(value={"x"})
    @SqlType(value="array(varchar(x))")
    public static Block regexpExtractAll(@SqlType(value="varchar(x)") Slice source, @SqlType(value="JoniRegExp") JoniRegexp pattern, @SqlType(value="bigint") long groupIndex) {
        Matcher matcher = pattern.matcher(source.getBytes());
        validateGroup(groupIndex, matcher.getEagerRegion());
        BlockBuilder blockBuilder = VarcharType.VARCHAR.createBlockBuilder(null, 32);
        int group = Math.toIntExact(groupIndex);
        int nextStart = 0;
        while (matcher.search(nextStart, source.length(), NO_OPTIONS) != -1) {
            nextStart = getNextStart(source, matcher);
            Region region = matcher.getEagerRegion();
            int beg = region.beg[group];
            int end = region.end[group];
            if (beg == -1 || end == -1) {
                blockBuilder.appendNull();
            } else {
                VarcharType.VARCHAR.writeSlice(blockBuilder, source.slice(beg, end - beg));
            }
        }
        return blockBuilder.build();
    }

    /** Returns the first full match (group 0) of {@code pattern} in {@code source}, or NULL if none. */
    @SqlNullable
    @Description(value="String extracted using the given pattern")
    @ScalarFunction
    @LiteralParameters(value={"x"})
    @SqlType(value="varchar(x)")
    public static Slice regexpExtract(@SqlType(value="varchar(x)") Slice source, @SqlType(value="JoniRegExp") JoniRegexp pattern) {
        return regexpExtract(source, pattern, 0L);
    }

    /**
     * Returns group {@code groupIndex} of the first match of {@code pattern} in {@code source}.
     *
     * @return the group's text, or {@code null} if there is no match or the group did not participate
     * @throws TrinoException INVALID_FUNCTION_ARGUMENT if {@code groupIndex} is negative or exceeds
     *     the pattern's group count
     */
    @SqlNullable
    @Description(value="Returns regex group of extracted string with a pattern")
    @ScalarFunction
    @LiteralParameters(value={"x"})
    @SqlType(value="varchar(x)")
    public static Slice regexpExtract(@SqlType(value="varchar(x)") Slice source, @SqlType(value="JoniRegExp") JoniRegexp pattern, @SqlType(value="bigint") long groupIndex) {
        Matcher matcher = pattern.matcher(source.getBytes());
        validateGroup(groupIndex, matcher.getEagerRegion());
        int group = Math.toIntExact(groupIndex);
        if (matcher.search(0, source.length(), NO_OPTIONS) == -1) {
            return null;
        }
        Region region = matcher.getEagerRegion();
        int beg = region.beg[group];
        int end = region.end[group];
        if (beg == -1 || end == -1) {
            return null;
        }
        return source.slice(beg, end - beg);
    }

    /**
     * Splits {@code source} around every match of {@code pattern}.
     *
     * <p>The result always contains the text after the last match, so it has one more element than
     * there are matches.
     */
    @ScalarFunction
    @LiteralParameters(value={"x"})
    @Description(value="Returns array of strings split by pattern")
    @SqlType(value="array(varchar(x))")
    public static Block regexpSplit(@SqlType(value="varchar(x)") Slice source, @SqlType(value="JoniRegExp") JoniRegexp pattern) {
        Matcher matcher = pattern.matcher(source.getBytes());
        BlockBuilder blockBuilder = VarcharType.VARCHAR.createBlockBuilder(null, 32);
        int lastEnd = 0;
        int nextStart = 0;
        while (matcher.search(nextStart, source.length(), NO_OPTIONS) != -1) {
            nextStart = getNextStart(source, matcher);
            VarcharType.VARCHAR.writeSlice(blockBuilder, source.slice(lastEnd, matcher.getBegin() - lastEnd));
            lastEnd = matcher.getEnd();
        }
        VarcharType.VARCHAR.writeSlice(blockBuilder, source.slice(lastEnd, source.length() - lastEnd));
        return blockBuilder.build();
    }

    /**
     * Checks that {@code group} is a valid group index for a pattern whose register count is
     * {@code region.numRegs}. Register 0 is the whole match, so there are {@code numRegs - 1} groups.
     *
     * @throws TrinoException INVALID_FUNCTION_ARGUMENT if {@code group} is negative or out of range
     */
    private static void validateGroup(long group, Region region) {
        if (group < 0L) {
            throw new TrinoException(StandardErrorCode.INVALID_FUNCTION_ARGUMENT, "Group cannot be negative");
        }
        int groupCount = region.numRegs - 1;
        if (group > groupCount) {
            throw new TrinoException(StandardErrorCode.INVALID_FUNCTION_ARGUMENT, String.format("Pattern has %d groups. Cannot access group %d", groupCount, group));
        }
    }

    /** Returns the 1-based code-point position of the first match, or -1 if there is none. */
    @ScalarFunction
    @Description(value="Returns the index of the matched substring")
    @LiteralParameters(value={"x"})
    @SqlType(value="integer")
    public static long regexpPosition(@SqlType(value="varchar(x)") Slice source, @SqlType(value="JoniRegExp") JoniRegexp pattern) {
        return regexpPosition(source, pattern, 1L);
    }

    /** Returns the 1-based code-point position of the first match at or after code point {@code start}, or -1. */
    @ScalarFunction
    @Description(value="Returns the index of the matched substring starting from the specified position")
    @LiteralParameters(value={"x"})
    @SqlType(value="integer")
    public static long regexpPosition(@SqlType(value="varchar(x)") Slice source, @SqlType(value="JoniRegExp") JoniRegexp pattern, @SqlType(value="integer") long start) {
        return regexpPosition(source, pattern, start, 1L);
    }

    /**
     * Returns the 1-based code-point position of the {@code occurrence}-th match found by searching
     * from code point {@code start} (1-based).
     *
     * @return the position, or -1 if {@code start} is past the end of {@code source} or there are
     *     fewer than {@code occurrence} matches
     * @throws TrinoException INVALID_FUNCTION_ARGUMENT if {@code start} or {@code occurrence} is less than 1
     */
    @ScalarFunction
    @Description(value="Returns the index of the n-th matched substring starting from the specified position")
    @LiteralParameters(value={"x"})
    @SqlType(value="integer")
    public static long regexpPosition(@SqlType(value="varchar(x)") Slice source, @SqlType(value="JoniRegExp") JoniRegexp pattern, @SqlType(value="integer") long start, @SqlType(value="integer") long occurrence) {
        if (start < 1L) {
            throw new TrinoException(StandardErrorCode.INVALID_FUNCTION_ARGUMENT, "start position cannot be smaller than 1");
        }
        if (occurrence < 1L) {
            throw new TrinoException(StandardErrorCode.INVALID_FUNCTION_ARGUMENT, "occurrence cannot be smaller than 1");
        }
        if (start > SliceUtf8.countCodePoints(source)) {
            return -1L;
        }
        Matcher matcher = pattern.matcher(source.getBytes());
        // start <= countCodePoints(source) here, so the narrowing cast cannot overflow.
        int nextStart = SliceUtf8.offsetOfCodePoint(source, (int) start - 1);
        long count = 0L;
        while (matcher.search(nextStart, source.length(), NO_OPTIONS) >= 0) {
            if (++count == occurrence) {
                // Convert the match's byte offset to a 1-based code-point position.
                return SliceUtf8.countCodePoints(source, 0, matcher.getBegin()) + 1;
            }
            nextStart = getNextStart(source, matcher);
        }
        return -1L;
    }

    /** Returns how many non-overlapping matches of {@code pattern} occur in {@code source}. */
    @ScalarFunction
    @Description(value="Returns the number of times that a pattern occurs in a string")
    @LiteralParameters(value={"x"})
    @SqlType(value="bigint")
    public static long regexpCount(@SqlType(value="varchar(x)") Slice source, @SqlType(value="JoniRegExp") JoniRegexp pattern) {
        Matcher matcher = pattern.matcher(source.getBytes());
        long count = 0;
        int nextStart = 0;
        while (matcher.search(nextStart, source.length(), NO_OPTIONS) >= 0) {
            nextStart = getNextStart(source, matcher);
            count++;
        }
        return count;
    }
}

