/*
 * Decompiled with CFR 0.152.
 */
package tigase.io;

import java.io.IOException;
import java.net.Inet6Address;
import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.nio.ByteBuffer;
import java.nio.channels.SocketChannel;
import java.nio.charset.StandardCharsets;
import java.util.function.BiConsumer;
import java.util.logging.Level;
import java.util.logging.Logger;
import java.util.stream.IntStream;
import tigase.io.IOInterface;
import tigase.stats.StatisticsList;
import tigase.util.Algorithms;

/**
 * {@link IOInterface} decorator that detects and strips a PROXY protocol header (text based
 * version 1 or binary version 2) from the beginning of the incoming byte stream.
 *
 * <p>All operations other than {@link #read(ByteBuffer)} are delegated unchanged to the wrapped
 * {@link IOInterface}. While the header is being detected and parsed, incomplete input is kept
 * in an internal byte array and prepended to the data returned by the next read. Once the header
 * has been handled (or the stream turned out not to start with one) every subsequent read is a
 * plain pass-through.
 *
 * <p>When a header carrying addresses is parsed, the supplied consumer is called with the
 * destination (local) address as the first argument and the source (remote) address as the
 * second one.
 */
public class ProxyIO
implements IOInterface {
    private static final Logger log = Logger.getLogger(ProxyIO.class.getName());
    /** Leading bytes of a version 1 (text) header. */
    private static final byte[] SIGNATURE_1 = "PROXY".getBytes(StandardCharsets.UTF_8);
    /** The 12 byte signature opening a version 2 (binary) header. */
    private static final byte[] SIGNATURE_2 = new byte[]{13, 10, 13, 10, 0, 13, 10, 81, 85, 73, 84, 10};
    private static final int PROXY_1_FIELD_PROTOCOL = 1;
    private static final int PROXY_1_FIELD_SRC_IP = 2;
    private static final int PROXY_1_FIELD_DST_IP = 3;
    private static final int PROXY_1_FIELD_SRC_PORT = 4;
    private static final int PROXY_1_FIELD_DST_PORT = 5;
    /** Number of space separated fields in a complete version 1 line. */
    private static final int PROXY_1_MAX_FIELDS = 6;
    /** Longest version 1 line (CRLF included) allowed by the PROXY protocol specification. */
    private static final int PROXY_1_MAX_LENGTH = 107;
    /** Bytes following the v2 signature: version/command, family/transport, 2 byte length. */
    private static final int PROXY_2_FIXED_FIELDS_LENGTH = 4;
    /** Largest accepted value of the version 2 length field. */
    private static final int PROXY_2_MAX_LENGTH = 1024;
    /** Size of a TLV prefix: 1 byte type followed by a 2 byte length. */
    private static final int PROXY_2_TLV_PREFIX_LENGTH = 3;
    private final IOInterface io;
    private final BiConsumer<String, String> addressConsumer;
    private State state = State.NEW;
    /** Bytes received so far which could not be interpreted yet; {@code null} when there are none. */
    private byte[] partialData = null;
    /** Fixed part of a version 2 header, kept between reads while the rest is still incomplete. */
    private Proxy2Header proxy2Header;

    /**
     * Creates a PROXY protocol aware wrapper.
     *
     * @param io the underlying IO all calls are delegated to
     * @param addressConsumer called with (destination address, source address) taken from the
     *        header; may be {@code null}, in which case the addresses are discarded
     */
    public ProxyIO(IOInterface io, BiConsumer<String, String> addressConsumer) {
        this.io = io;
        this.addressConsumer = addressConsumer;
    }

    @Override
    public int bytesRead() {
        return this.io.bytesRead();
    }

    /**
     * Reports the {@code "PROXY"} capability on top of whatever the wrapped IO supports.
     */
    @Override
    public boolean checkCapabilities(String caps) {
        return caps.contains("PROXY") || this.io.checkCapabilities(caps);
    }

    @Override
    public int getInputPacketSize() throws IOException {
        return this.io.getInputPacketSize();
    }

    @Override
    public SocketChannel getSocketChannel() {
        return this.io.getSocketChannel();
    }

    @Override
    public void getStatistics(StatisticsList list, boolean reset) {
        this.io.getStatistics(list, reset);
    }

    @Override
    public long getBytesSent(boolean reset) {
        return this.io.getBytesSent(reset);
    }

    @Override
    public long getTotalBytesSent() {
        return this.io.getTotalBytesSent();
    }

    @Override
    public long getBytesReceived(boolean reset) {
        return this.io.getBytesReceived(reset);
    }

    @Override
    public long getTotalBytesReceived() {
        return this.io.getTotalBytesReceived();
    }

    @Override
    public long getBuffOverflow(boolean reset) {
        return this.io.getBuffOverflow(reset);
    }

    @Override
    public long getTotalBuffOverflow() {
        return this.io.getTotalBuffOverflow();
    }

    @Override
    public boolean isConnected() {
        return this.io.isConnected();
    }

    @Override
    public boolean isRemoteAddress(String addr) {
        return this.io.isRemoteAddress(addr);
    }

    /**
     * Reads from the wrapped IO and, until the PROXY header has been dealt with, consumes the
     * header bytes from the data.
     *
     * @param buff buffer handed to the wrapped IO
     * @return the buffer returned by the wrapped IO when nothing had to be re-assembled;
     *         otherwise either a newly allocated buffer positioned at the first byte after the
     *         header, or the (already consumed) buffer of the wrapped IO when more header bytes
     *         are still needed
     * @throws IOException if the wrapped IO fails or a version 2 header is malformed or
     *         describes an unsupported mode
     */
    @Override
    public ByteBuffer read(ByteBuffer buff) throws IOException {
        if (this.state == State.DONE) {
            return this.io.read(buff);
        }
        ByteBuffer tmpBuffer = this.io.read(buff);
        ByteBuffer buf;
        if (this.partialData == null) {
            buf = tmpBuffer;
        } else {
            // Glue the bytes kept from previous reads in front of the freshly read ones.
            buf = ByteBuffer.allocate(this.partialData.length + tmpBuffer.remaining());
            buf.put(this.partialData);
            buf.put(tmpBuffer);
            buf.flip();
            tmpBuffer.clear();
            this.partialData = null;
        }
        // Every "need more data" / "not a PROXY header" path rewinds to this mark.
        buf.mark();
        if (this.state == State.NEW) {
            if (startsWith(buf, SIGNATURE_1)) {
                this.state = State.PROXY_1;
            }
            if (buf.remaining() < SIGNATURE_2.length) {
                // Too short to rule the binary signature in or out - wait for more input.
                return stash(buf, tmpBuffer);
            }
            if (startsWith(buf, SIGNATURE_2)) {
                this.state = State.PROXY_2;
            }
            if (this.state == State.NEW) {
                this.state = State.DONE;
                return buf;
            }
        }
        return switch (this.state) {
            case PROXY_1 -> readProxy1(buf, tmpBuffer);
            case PROXY_2 -> readProxy2(buf, tmpBuffer);
            default -> buf;
        };
    }

    /**
     * Tells whether the readable part of {@code buf} begins with {@code signature}; the buffer
     * position is left untouched.
     */
    private static boolean startsWith(ByteBuffer buf, byte[] signature) {
        if (buf.remaining() < signature.length) {
            return false;
        }
        int base = buf.position();
        for (int i = 0; i < signature.length; i++) {
            if (buf.get(base + i) != signature[i]) {
                return false;
            }
        }
        return true;
    }

    /**
     * Rewinds {@code buf} to its mark, keeps everything from there for the next read and
     * returns {@code tmpBuffer} to the caller.
     */
    private ByteBuffer stash(ByteBuffer buf, ByteBuffer tmpBuffer) {
        buf.reset();
        this.partialData = new byte[buf.remaining()];
        buf.get(this.partialData);
        return tmpBuffer;
    }

    /**
     * Gives up on PROXY parsing: rewinds {@code buf} to its mark and returns it so that the
     * caller receives the data unmodified.
     */
    private ByteBuffer passThrough(ByteBuffer buf) {
        this.state = State.DONE;
        buf.reset();
        return buf;
    }

    /**
     * Parses a version 1 header line: space separated fields terminated by CRLF.
     *
     * <p>Input that cannot be a valid line (control characters, too many fields, CR without LF,
     * no terminator within {@link #PROXY_1_MAX_LENGTH} bytes) is handed over untouched.
     */
    private ByteBuffer readProxy1(ByteBuffer buf, ByteBuffer tmpBuffer) {
        String[] fields = new String[PROXY_1_MAX_FIELDS];
        int fieldIndex = 0;
        StringBuilder builder = new StringBuilder();
        boolean awaitLineFeed = false;
        int start = buf.position();
        while (buf.hasRemaining() && buf.position() - start < PROXY_1_MAX_LENGTH) {
            byte b = buf.get();
            if (awaitLineFeed) {
                if (b != '\n') {
                    // CR not followed by LF can never become a valid header; buffering it
                    // again would only repeat the same failure on every following read.
                    return passThrough(buf);
                }
                return finishProxy1(buf, fields);
            }
            if (b == ' ' || b == '\r') {
                if (fieldIndex >= fields.length) {
                    return passThrough(buf);
                }
                fields[fieldIndex++] = builder.toString();
                builder.setLength(0);
                awaitLineFeed = b == '\r';
                continue;
            }
            if (b < ' ') {
                // Control characters (and, bytes being signed, anything above 0x7F).
                return passThrough(buf);
            }
            builder.append((char) b);
        }
        if (buf.position() - start >= PROXY_1_MAX_LENGTH) {
            // Longer than any valid line and still no CRLF: stop buffering.
            return passThrough(buf);
        }
        return stash(buf, tmpBuffer);
    }

    /**
     * Handles a complete version 1 line; {@code buf} is positioned right after the CRLF.
     */
    private ByteBuffer finishProxy1(ByteBuffer buf, String[] fields) {
        if (!"PROXY".equals(fields[0])) {
            log.log(Level.FINE, () -> "Not a PROXY protocol!");
            return passThrough(buf);
        }
        // The header is consumed; without this the payload of the next read would be parsed
        // as yet another header line.
        this.state = State.DONE;
        if (!"UNKNOWN".equals(fields[PROXY_1_FIELD_PROTOCOL]) && this.addressConsumer != null) {
            this.addressConsumer.accept(fields[PROXY_1_FIELD_DST_IP], fields[PROXY_1_FIELD_SRC_IP]);
        }
        return buf;
    }

    /**
     * Parses a version 2 header: signature, 4 fixed bytes, then {@code length} bytes holding the
     * addresses and optional TLV entries.
     *
     * @throws IOException on an unknown command, family or transport, an unsupported mode, or a
     *         length field inconsistent with the header content
     */
    private ByteBuffer readProxy2(ByteBuffer buf, ByteBuffer tmpBuffer) throws IOException {
        if (this.proxy2Header == null) {
            buf.position(buf.position() + SIGNATURE_2.length);
            if (buf.remaining() < PROXY_2_FIXED_FIELDS_LENGTH) {
                return stash(buf, tmpBuffer);
            }
            int verAndCmd = 0xFF & buf.get();
            if ((verAndCmd & 0xF0) != 0x20) {
                log.log(Level.FINE, () -> "Bad Proxy v2 version");
                return passThrough(buf);
            }
            // High nibble is the version (checked above), low nibble the command:
            // 0 = LOCAL, 1 = PROXY, everything else is unassigned.
            int command = verAndCmd & 0x0F;
            if (command > 1) {
                throw new IOException("Bad Proxy Command value");
            }
            boolean isLocal = command == 0;
            int transportAndFamily = 0xFF & buf.get();
            Family family = switch (transportAndFamily >> 4) {
                case 0 -> Family.UNSPECIFIED;
                case 1 -> Family.INET;
                case 2 -> Family.INET6;
                case 3 -> Family.UNIX;
                default -> throw new IOException("Bad Proxy Family value");
            };
            Transport transport = switch (transportAndFamily & 0xF) {
                case 0 -> Transport.UNSPECIFIED;
                case 1 -> Transport.STREAM;
                case 2 -> Transport.DATAGRAM;
                default -> throw new IOException("Bad Proxy Transport value");
            };
            if (!(isLocal || family != Family.UNSPECIFIED && transport == Transport.STREAM)) {
                throw new IOException("Unsupported Proxy mode");
            }
            int length = Short.toUnsignedInt(buf.getShort());
            if (length > PROXY_2_MAX_LENGTH) {
                throw new IOException("Unsupported Proxy header length - too long");
            }
            this.proxy2Header = new Proxy2Header(transport, family, isLocal, length);
            // From now on only the variable part has to be kept between reads.
            buf.mark();
        }
        Proxy2Header header = this.proxy2Header;
        if (buf.remaining() < header.length()) {
            return stash(buf, tmpBuffer);
        }
        // Number of bytes in buf that belong to the payload following the header.
        int nonProxyRemaining = buf.remaining() - header.length();
        if (header.isLocal()) {
            // LOCAL: the connection was opened by the proxy itself, skip the whole block.
            buf.position(buf.position() + header.length());
        } else {
            Family family = header.family();
            if (family != Family.INET && family != Family.INET6) {
                throw new IOException("Unsupported socket address");
            }
            int dataLength = family.getAddressLength();
            // Two addresses plus two 2 byte ports must fit into the declared length.
            if (header.length() < 2 * dataLength + 4) {
                throw new IOException("Proxy header too short for address block");
            }
            byte[] data = new byte[dataLength];
            buf.get(data);
            InetAddress srcAddr = family.getByAddress(data);
            buf.get(data);
            InetAddress dstAddr = family.getByAddress(data);
            int srcPort = Short.toUnsignedInt(buf.getShort());
            int dstPort = Short.toUnsignedInt(buf.getShort());
            InetSocketAddress local = new InetSocketAddress(dstAddr, dstPort);
            InetSocketAddress remote = new InetSocketAddress(srcAddr, srcPort);
            if (this.addressConsumer != null) {
                this.addressConsumer.accept(local.getHostString(), remote.getHostString());
            }
        }
        // Whatever is left of the header consists of TLV entries; they are only logged.
        while (buf.remaining() > nonProxyRemaining) {
            int available = buf.remaining() - nonProxyRemaining;
            if (available < PROXY_2_TLV_PREFIX_LENGTH) {
                throw new IOException("Truncated Proxy TLV");
            }
            int type = 0xFF & buf.get();
            int tlvLength = Short.toUnsignedInt(buf.getShort());
            if (tlvLength > available - PROXY_2_TLV_PREFIX_LENGTH) {
                throw new IOException("Proxy TLV exceeds header length");
            }
            byte[] value = new byte[tlvLength];
            buf.get(value);
            log.log(Level.FINEST, () -> String.format("Proxy v2 T=%x L=%d V=%s for %s", type, tlvLength, Algorithms.bytesToHex(value), this));
        }
        this.state = State.DONE;
        return buf;
    }

    @Override
    public void stop() throws IOException {
        this.io.stop();
    }

    @Override
    public boolean waitingToSend() {
        return this.io.waitingToSend();
    }

    @Override
    public int waitingToSendSize() {
        return this.io.waitingToSendSize();
    }

    @Override
    public int write(ByteBuffer buff) throws IOException {
        return this.io.write(buff);
    }

    @Override
    public void setLogId(String logId) {
        this.io.setLogId(logId);
    }

    /** Progress of the header detection. */
    private static enum State {
        /** Nothing recognised yet. */
        NEW,
        /** Version 1 signature seen, line not complete yet. */
        PROXY_1,
        /** Version 2 signature seen, header not complete yet. */
        PROXY_2,
        /** Header handled or absent; reads are passed through. */
        DONE;

    }

    /**
     * Fixed part of a version 2 header; {@code length} is the number of bytes following it.
     */
    record Proxy2Header(Transport transport, Family family, boolean isLocal, int length) {
    }

    /** Address family of a version 2 header together with the size of a single address. */
    static enum Family {
        UNSPECIFIED(-1),
        INET(4),
        INET6(16),
        UNIX(108);

        private final int addrLen;

        private Family(int addrLen) {
            this.addrLen = addrLen;
        }

        /**
         * @return size of one address in bytes, negative when the family carries no address
         */
        public int getAddressLength() {
            return this.addrLen;
        }

        /**
         * Converts raw address bytes into an {@link InetAddress}.
         *
         * @throws IOException for families other than {@link #INET} and {@link #INET6}, or when
         *         the array length is not valid for an IP address
         */
        public InetAddress getByAddress(byte[] addr) throws IOException {
            return switch (this) {
                case INET -> InetAddress.getByAddress(addr);
                case INET6 -> Inet6Address.getByAddress(addr);
                default -> throw new IOException("Unsupported socket address");
            };
        }
    }

    /** Transport protocol of a version 2 header. */
    static enum Transport {
        UNSPECIFIED,
        STREAM,
        DATAGRAM;

    }
}

