/*
 * Decompiled with CFR 0.152.
 */
package com.pushtechnology.diffusion.comms.websocket;

import com.pushtechnology.diffusion.comms.connection.ConnectionInfo;
import com.pushtechnology.diffusion.comms.websocket.MaxMessageSizeException;
import com.pushtechnology.diffusion.io.bytebuffer.serialisation.DeserialisationException;
import com.pushtechnology.diffusion.message.Message;
import com.pushtechnology.diffusion.utils.ConfigurationUtils;
import com.pushtechnology.diffusion.utils.unsafe.UnsafeDirectByteBuffer;
import java.nio.BufferUnderflowException;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.ThreadLocalRandom;
import java.util.function.IntFunction;
import net.jcip.annotations.Immutable;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Base codec for reading and writing WebSocket frames (RFC 6455).
 *
 * <p>Reading: {@link #readFrame} parses a single frame, or a complete fragmented message, from an
 * input buffer, unmasks the payload in place and passes it to the {@link Delegate}. PING and PONG
 * frames are discarded; CLOSE frames are passed to a caller-supplied close handler.
 *
 * <p>Writing: {@link #writeFrame} writes a FIN + binary frame header followed by the output of
 * {@link #writeMessage}. A server never masks its frames. A client masks with a random key, or with
 * an all-zero key when the {@code diffusion.connection.websockets.zeroClientMask} property is
 * enabled.
 *
 * @param <R> the result type produced when reading a frame
 * @param <M> the message type written by this codec
 * @param <T> the checked exception type the delegate may throw
 */
public abstract class AbstractWebSocketFrameCodec<R, M extends Message, T extends Exception> {
    private static final Logger LOG = LoggerFactory.getLogger(AbstractWebSocketFrameCodec.class);

    /** System property that, when enabled, makes a server reject unmasked frames from clients. */
    public static final String REQUIRE_WEBSOCKET_MASK_PROPERTY = "diffusion.connection.websockets.requireMask";
    private static final boolean REQUIRE_WEBSOCKET_MASK =
        ConfigurationUtils.getBooleanSystemProperty(REQUIRE_WEBSOCKET_MASK_PROPERTY);

    /** System property that, when enabled, makes a client mask its frames with an all-zero key. */
    public static final String DISABLE_RANDOM_MASK_PROPERTY = "diffusion.connection.websockets.zeroClientMask";

    /** System property giving the per-fragment weight used to bound fragmented-message complexity. */
    public static final String FRAGMENT_WEIGHT_PROPERTY = "diffusion.ws.fragmentweight";
    private static final int FRAGMENT_WEIGHT =
        ConfigurationUtils.getIntegerSystemProperty(FRAGMENT_WEIGHT_PROPERTY, 50);

    // Frame opcodes (RFC 6455 section 5.2).
    private static final byte OPCODE_CONTINUATION = 0x0;
    private static final byte OPCODE_TEXT = 0x1;
    private static final byte OPCODE_BINARY = 0x2;
    private static final byte OPCODE_CLOSE = 0x8;
    private static final byte OPCODE_PING = 0x9;
    private static final byte OPCODE_PONG = 0xA;

    /** First header byte of every outgoing frame: FIN bit set, binary opcode. */
    private static final byte FIN_BINARY_FRAME = (byte) 0x82;

    /** Mask flag in the second header byte. */
    private static final byte MASK_FLAG = (byte) 0x80;

    /** Bit 7 of a header byte: FIN in the first byte, MASK in the second. */
    private static final int HIGH_BIT = 0x80;

    /** Status reported for a CLOSE frame that carries no status code (RFC 6455 section 7.1.5). */
    private static final int CLOSE_STATUS_NOT_PRESENT = 1005;

    private final Delegate<R, T> delegate;
    private final int maximumMessageSize;
    private final boolean isClient;
    private final WriteFrameDelegate<M> writeFrameDelegate;

    /**
     * Creates a codec.
     *
     * @param delegate receives the unmasked payload of each complete data message
     * @param maximumMessageSize size limit applied to incoming frames and fragmented messages
     * @param isClient {@code true} for the client side of a connection; selects masking behaviour
     */
    public AbstractWebSocketFrameCodec(Delegate<R, T> delegate, int maximumMessageSize, boolean isClient) {
        this.delegate = delegate;
        this.maximumMessageSize = maximumMessageSize;
        this.isClient = isClient;
        if (!isClient) {
            this.writeFrameDelegate = this::writeFrameServer;
        }
        else if (ConfigurationUtils.getBooleanSystemProperty(DISABLE_RANDOM_MASK_PROPERTY)) {
            this.writeFrameDelegate = this::writeFrameZeroMask;
        }
        else {
            this.writeFrameDelegate = this::writeFrameRandomMask;
        }
    }

    /** Returns the result signalling that a frame was consumed but produced nothing. */
    protected abstract R discarded();

    /** Returns the result signalling that more input is needed before a frame can be read. */
    protected abstract R insufficientData();

    /**
     * Writes the serialised form of {@code message} into {@code buffer} at its current position.
     */
    protected abstract void writeMessage(ByteBuffer buffer, M message, ConnectionInfo connectionInfo);

    /**
     * Passes a complete, unmasked payload to the delegate.
     *
     * @param payload the payload, between its position and limit
     * @param overhead number of framing bytes consumed in addition to the payload
     * @return the delegate's result
     * @throws T if the delegate fails
     */
    protected R callDelegate(ByteBuffer payload, int overhead) throws T {
        return this.delegate.apply(payload, overhead);
    }

    /**
     * Reads the next frame, or the next complete fragmented message, from {@code buffer}.
     *
     * <p>When a frame is consumed, the buffer is left at the start of the following frame with its
     * original limit. If the delegate returns {@link #insufficientData()} for the last frame in the
     * buffer, the buffer is cleared instead. If the buffer does not yet hold a complete frame, it is
     * rewound to the start of that frame and compacted so more bytes can be appended.
     *
     * @param buffer input buffer positioned at the start of a frame; payloads are unmasked in place
     * @param closeFrameHandler invoked with the status code of a received CLOSE frame
     * @return the delegate's result, the close handler's result, {@link #discarded()} for PING/PONG
     *     frames, or {@link #insufficientData()}
     * @throws DeserialisationException if a frame is malformed or violates the protocol
     * @throws T if the delegate fails
     */
    public final R readFrame(ByteBuffer buffer, IntFunction<R> closeFrameHandler) throws DeserialisationException, T {
        final int start = buffer.position();
        try {
            final WebSocketFrameData frame = parseFrame(buffer);
            if (frame == null) {
                return rewindForMoreData(buffer, start);
            }
            frame.checkInitialOpcode();
            if (!frame.isFin()) {
                return handleFragment(frame, buffer, start, closeFrameHandler);
            }
            frame.maskPayload(buffer);
            final int outerLimit = buffer.limit();
            final int nextFrame = frame.nextFrame();
            // Restrict the visible bytes to this frame's payload while it is processed.
            buffer.limit(nextFrame);
            final R controlResult = handleControlFrame(frame.opcode(), buffer, closeFrameHandler);
            if (controlResult != null) {
                buffer.limit(outerLimit).position(nextFrame);
                return controlResult;
            }
            // The position is still at the payload start, so this difference is the header size.
            return deliverPayload(buffer, buffer, nextFrame, outerLimit, buffer.position() - start);
        }
        catch (BufferUnderflowException e) {
            // NOTE(review): this also catches underflows thrown by the delegate, after the payload
            // has been unmasked and the limit narrowed; kept as-is to preserve existing behaviour.
            return rewindForMoreData(buffer, start);
        }
    }

    /** Rewinds to the start of an incomplete frame and compacts so more bytes can be appended. */
    private R rewindForMoreData(ByteBuffer buffer, int start) {
        buffer.position(start);
        buffer.compact();
        return insufficientData();
    }

    /**
     * Invokes the delegate and then repositions {@code source}, the buffer that frames are being
     * read from.
     *
     * <p>{@code payload} is either {@code source} itself, or a view of the reassembled payload of a
     * fragmented message.
     *
     * @param source the input buffer passed to {@link #readFrame}
     * @param payload the payload to hand to the delegate
     * @param nextFrame absolute index of the first byte after the consumed frame(s)
     * @param outerLimit the limit {@code source} had before the frame was read
     * @param overhead framing bytes consumed in addition to the payload
     */
    private R deliverPayload(ByteBuffer source, ByteBuffer payload, int nextFrame, int outerLimit, int overhead) throws T {
        final R result = callDelegate(payload, overhead);
        if (result == insufficientData()) {
            if (nextFrame < outerLimit) {
                // More frames follow; drop this one and continue with the next.
                source.limit(outerLimit).position(nextFrame);
                return discarded();
            }
            source.clear();
            return insufficientData();
        }
        source.limit(outerLimit).position(nextFrame);
        return result;
    }

    /**
     * Handles control opcodes.
     *
     * @return the close handler's result for CLOSE, {@link #discarded()} for PING/PONG, or
     *     {@code null} if {@code opcode} is not a control opcode
     */
    private R handleControlFrame(byte opcode, ByteBuffer payloadData, IntFunction<R> closeFrameHandler) {
        switch (opcode) {
            case OPCODE_CLOSE:
                return closeFrameHandler.apply(readResponseCode(payloadData));
            case OPCODE_PING:
                LOG.trace("Ignoring PING frame");
                return discarded();
            case OPCODE_PONG:
                LOG.trace("Ignoring PONG frame");
                return discarded();
            default:
                return null;
        }
    }

    /**
     * Collects the continuation frames of a fragmented message that starts with {@code first}.
     *
     * <p>Throws {@link BufferUnderflowException} if the buffer ends before the final fragment.
     * {@link #readFrame} treats that as insufficient data. No payload is unmasked until the whole
     * message is present, so the buffer can be safely re-read later.
     *
     * @param start absolute index of the first byte of {@code first}'s header
     */
    private R handleFragment(WebSocketFrameData first, ByteBuffer buffer, int start, IntFunction<R> closeFrameHandler) throws DeserialisationException, T {
        if (isControl(first.opcode())) {
            throw new DeserialisationException("A control frame of opcode " + Integer.toHexString(first.opcode()) + " was fragmented");
        }
        final List<WebSocketFrameData> frames = new ArrayList<>();
        frames.add(first);
        // long, so that adding a frame length to a value near the maximum cannot overflow.
        long combinedLength = first.length();
        int fragments = 1;
        buffer.position(first.nextFrame());
        while (buffer.hasRemaining()) {
            final WebSocketFrameData frame = parseFrame(buffer);
            if (frame == null) {
                throw new BufferUnderflowException();
            }
            frames.add(frame);
            buffer.position(frame.nextFrame());
            final byte opcode = frame.opcode();
            if (opcode == OPCODE_CONTINUATION) {
                fragments++;
                combinedLength += frame.length();
                if (combinedLength > this.maximumMessageSize) {
                    throw MaxMessageSizeException.logWithoutStackTrace("Received fragmented WebSocket data of " + combinedLength + " bytes exceeding the configured maximum message size of " + this.maximumMessageSize + " bytes");
                }
                // Bound the number of fragments as well as their total size.
                if (fragments > 3 && (long) FRAGMENT_WEIGHT * fragments > this.maximumMessageSize) {
                    throw new DeserialisationException("Processed fragments exceed allowed complexity: " + combinedLength + " bytes split into " + fragments + " fragments");
                }
                if (frame.isFin()) {
                    return handleCompleteFragmentList(frames, buffer, start, closeFrameHandler);
                }
            }
            else if (isControl(opcode)) {
                // Control frames may be interleaved with fragments but must not be fragmented themselves.
                if (!frame.isFin()) {
                    throw new DeserialisationException("A control frame of opcode " + Integer.toHexString(opcode) + " was fragmented");
                }
            }
            else {
                throw new DeserialisationException("A frame of opcode " + Integer.toHexString(opcode) + " has interleaved a fragmented frame message.");
            }
        }
        throw new BufferUnderflowException();
    }

    /**
     * Unmasks every frame of a complete fragmented message and processes any interleaved control
     * frames. The data payloads are concatenated and delivered to the delegate.
     *
     * <p>Data payloads are copied towards index 0 of the buffer. Every frame has a header of at
     * least two bytes (see {@link #parseFrame}), so the write index always stays behind the bytes
     * still to be read.
     *
     * @param start absolute index of the first byte of the first frame's header
     */
    private R handleCompleteFragmentList(List<WebSocketFrameData> frames, ByteBuffer buffer, int start, IntFunction<R> closeFrameHandler) throws DeserialisationException, T {
        final ByteBuffer combinedPayload = buffer.duplicate();
        combinedPayload.position(0);
        final ByteBuffer framePayload = buffer.duplicate();
        for (WebSocketFrameData frame : frames) {
            framePayload.limit(frame.nextFrame()).position(frame.payloadStart());
            frame.maskPayload(framePayload);
            final R controlResult = handleControlFrame(frame.opcode(), framePayload, closeFrameHandler);
            if (controlResult == null) {
                combinedPayload.put(framePayload);
            }
            else if (controlResult != discarded()) {
                return controlResult;
            }
        }
        combinedPayload.flip();
        final int nextFrame = frames.get(frames.size() - 1).nextFrame();
        // Overhead = bytes consumed from this read minus the delivered payload, matching the
        // header-only overhead reported for unfragmented frames.
        final int overhead = (nextFrame - start) - combinedPayload.limit();
        return deliverPayload(buffer, combinedPayload, nextFrame, buffer.limit(), overhead);
    }

    private static boolean isControl(int opCode) {
        switch (opCode) {
            case OPCODE_CLOSE:
            case OPCODE_PING:
            case OPCODE_PONG:
                return true;
            default:
                return false;
        }
    }

    /**
     * Parses a frame header and leaves the buffer positioned at the start of the payload.
     *
     * @return the parsed frame, or {@code null} if the payload is not yet fully in the buffer
     * @throws BufferUnderflowException if the header itself is incomplete
     * @throws DeserialisationException if the length or mask is invalid
     */
    private WebSocketFrameData parseFrame(ByteBuffer buffer) throws DeserialisationException {
        final byte frameByte = buffer.get();
        final byte lengthAndMask = buffer.get();
        final int length = readMessageLength(buffer, (byte) (0x7F & lengthAndMask));
        final byte opcode = (byte) (frameByte & 0xF);
        final int mask = readMask(buffer, lengthAndMask);
        if (buffer.remaining() < length) {
            return null;
        }
        return new WebSocketFrameData(frameByte, length, opcode, mask, buffer.position());
    }

    /** Reads the 2-byte close status, or reports that none was present. */
    private static int readResponseCode(ByteBuffer payloadData) {
        if (payloadData.remaining() >= 2) {
            return payloadData.getShort() & 0xFFFF;
        }
        return CLOSE_STATUS_NOT_PRESENT;
    }

    /**
     * Decodes the payload length, including 16-bit and 64-bit extended lengths.
     *
     * <p>A single frame must be strictly smaller than the configured maximum message size.
     */
    private int readMessageLength(ByteBuffer buffer, byte lengthLength) throws DeserialisationException {
        final long result;
        if (lengthLength == 126) {
            result = buffer.getShort() & 0xFFFF;
        }
        else if (lengthLength == 127) {
            result = buffer.getLong();
            if (result < 0L) {
                throw new DeserialisationException("Invalid length");
            }
        }
        else {
            result = lengthLength;
        }
        if (result >= (long) this.maximumMessageSize) {
            throw MaxMessageSizeException.logWithoutStackTrace("Received WebSocket data frame of " + result + " bytes exceeding the configured maximum message size of " + this.maximumMessageSize + " bytes");
        }
        return (int) result;
    }

    /**
     * Reads the masking key, enforcing the masking rules for this side of the connection.
     *
     * @return the key, or 0 when the frame is unmasked
     */
    private int readMask(ByteBuffer buffer, byte lengthAndMask) throws DeserialisationException {
        final boolean masked = hasMask(lengthAndMask);
        if (this.isClient) {
            if (masked) {
                throw new DeserialisationException("Received masked message from server");
            }
            return 0;
        }
        if (masked) {
            return buffer.getInt();
        }
        if (REQUIRE_WEBSOCKET_MASK) {
            throw new DeserialisationException("Received unmasked message from client");
        }
        return 0;
    }

    private static boolean hasMask(byte lengthAndMaskFlag) {
        return (lengthAndMaskFlag & HIGH_BIT) != 0;
    }

    /**
     * Returns the total encoded size of a frame carrying {@code payloadSize} bytes. This counts the
     * 2-byte base header, any extended length field and, on the client side, the 4-byte mask.
     */
    public int wsFrameLength(int payloadSize) {
        final int maskSize = this.isClient ? 4 : 0;
        if (payloadSize < 126) {
            return payloadSize + 2 + maskSize;
        }
        if (payloadSize < 65536) {
            return payloadSize + 4 + maskSize;
        }
        return payloadSize + 10 + maskSize;
    }

    /**
     * Writes {@code message} as a single FIN + binary frame into {@code outputBuffer}.
     *
     * @param messageSize the number of bytes {@link #writeMessage} will write for {@code message}
     */
    public void writeFrame(ConnectionInfo connectionInfo, M message, int messageSize, ByteBuffer outputBuffer) {
        outputBuffer.put(FIN_BINARY_FRAME);
        this.writeFrameDelegate.write(connectionInfo, message, messageSize, outputBuffer);
    }

    private void writeFrameServer(ConnectionInfo connectionInfo, M message, int messageSize, ByteBuffer outputBuffer) {
        writeWSFrameLength(outputBuffer, messageSize, (byte) 0);
        writeMessage(outputBuffer, message, connectionInfo);
    }

    private void writeFrameRandomMask(ConnectionInfo connectionInfo, M message, int messageSize, ByteBuffer outputBuffer) {
        final int mask = ThreadLocalRandom.current().nextInt();
        writeFrameLengthAndMask(outputBuffer, messageSize, mask);
        writeMessageWithMask(outputBuffer, message, connectionInfo, mask);
    }

    private void writeFrameZeroMask(ConnectionInfo connectionInfo, M message, int messageSize, ByteBuffer outputBuffer) {
        // XOR with an all-zero key is the identity, so the payload can be written as-is.
        writeFrameLengthAndMask(outputBuffer, messageSize, 0);
        writeMessage(outputBuffer, message, connectionInfo);
    }

    private static void writeFrameLengthAndMask(ByteBuffer outputBuffer, int messageSize, int mask) {
        writeWSFrameLength(outputBuffer, messageSize, MASK_FLAG);
        outputBuffer.putInt(mask);
    }

    /** Writes the second header byte (mask flag | length code) and any extended length field. */
    private static void writeWSFrameLength(ByteBuffer buffer, int length, byte maskFlag) {
        if (length < 126) {
            buffer.put((byte) (maskFlag | (byte) length));
        }
        else if (length < 65536) {
            buffer.put((byte) (maskFlag | 0x7E));
            buffer.putShort((short) length);
        }
        else {
            buffer.put((byte) (maskFlag | 0x7F));
            buffer.putLong(length);
        }
    }

    /** Writes {@code message} and then XOR-masks the bytes just written with {@code intMask}. */
    void writeMessageWithMask(ByteBuffer buffer, M message, ConnectionInfo connectionInfo, int intMask) {
        final int startOffset = buffer.position();
        writeMessage(buffer, message, connectionInfo);
        final int endOffset = buffer.position();
        maskRange(buffer, intMask, startOffset, endOffset);
    }

    /**
     * XORs the absolute range [startOffset, endOffset) with the repeating 4-byte key
     * {@code intMask}, first byte taken from the key's most-significant byte.
     *
     * <p>Works a long at a time over the 8-byte-aligned middle of the range and byte by byte at the
     * unaligned edges. For heap buffers, the alignment is estimated from the array offset.
     */
    private static void maskRange(ByteBuffer buffer, int intMask, int startOffset, int endOffset) {
        final long address = buffer.hasArray() ? (long) buffer.arrayOffset() : UnsafeDirectByteBuffer.getAddress(buffer);
        // Key repeated twice so one long covers two full key periods.
        final long mask = (long) intMask << 32 | (long) intMask & 0xFFFFFFFFL;
        final int startAlignmentOffset = (int) ((address + (long) startOffset) % 8L);
        final int endAlignmentOffset = (int) ((address + (long) endOffset) % 8L);
        final int numLeadingUnalignedBytes = (8 - startAlignmentOffset) % 8;
        final int alignedBytesStartOffset = Math.min(startOffset + numLeadingUnalignedBytes, endOffset);
        maskByteByByte(buffer, mask, startOffset, alignedBytesStartOffset);
        final int alignedBytesEndOffset = Math.max(endOffset - endAlignmentOffset, alignedBytesStartOffset);
        // Advance the key phase past the leading bytes already masked.
        final long alignedMask = Long.rotateLeft(mask, 8 * numLeadingUnalignedBytes);
        maskAlignedRegion(buffer, alignedMask, alignedBytesStartOffset, alignedBytesEndOffset);
        maskByteByByte(buffer, alignedMask, alignedBytesEndOffset, endOffset);
    }

    /** XORs fewer than 8 bytes, consuming {@code mask} from its most-significant byte. */
    private static void maskByteByByte(ByteBuffer buffer, long mask, int startOffset, int endOffset) {
        assert (endOffset - startOffset < 8);
        int shift = 56;
        for (int i = startOffset; i < endOffset; i++) {
            buffer.put(i, (byte) (buffer.get(i) ^ (byte) (0xFFL & (mask >> shift))));
            shift -= 8;
        }
    }

    /**
     * XORs a whole number of 8-byte words in native byte order. The buffer's original byte order is
     * restored afterwards, even if an exception is thrown.
     */
    private static void maskAlignedRegion(ByteBuffer buffer, long mask, int startOffset, int endOffset) {
        if (endOffset <= startOffset) {
            return;
        }
        final ByteOrder originalOrder = buffer.order();
        final ByteOrder nativeOrder = ByteOrder.nativeOrder();
        // mask stores key bytes most-significant first; reverse it so a native little-endian read
        // lines up byte for byte with the key.
        final long maskLong = nativeOrder == ByteOrder.LITTLE_ENDIAN ? Long.reverseBytes(mask) : mask;
        buffer.order(nativeOrder);
        try {
            for (int i = startOffset; i < endOffset; i += 8) {
                buffer.putLong(i, buffer.getLong(i) ^ maskLong);
            }
        }
        finally {
            buffer.order(originalOrder);
        }
    }

    /**
     * Receives the unmasked payload of each complete data message.
     *
     * @param <R> result type
     * @param <T> checked exception type
     */
    @FunctionalInterface
    public interface Delegate<R, T extends Exception> {
        /**
         * @param payload the payload, between position and limit
         * @param overhead framing bytes consumed in addition to the payload
         */
        R apply(ByteBuffer payload, int overhead) throws T;
    }

    /** Strategy for writing the length/mask portion of a frame and the message body. */
    @FunctionalInterface
    private interface WriteFrameDelegate<M> {
        void write(ConnectionInfo connectionInfo, M message, int messageSize, ByteBuffer outputBuffer);
    }

    /** Parsed header of one frame, with absolute indices into the input buffer. */
    @Immutable
    private static final class WebSocketFrameData {
        private final byte frameByte;
        private final int length;
        private final byte opcode;
        private final int mask;
        private final int payloadStart;

        WebSocketFrameData(byte frameByte, int length, byte opcode, int mask, int payloadStart) {
            this.frameByte = frameByte;
            this.length = length;
            this.opcode = opcode;
            this.mask = mask;
            this.payloadStart = payloadStart;
        }

        int length() {
            return this.length;
        }

        byte opcode() {
            return this.opcode;
        }

        int payloadStart() {
            return this.payloadStart;
        }

        /** Absolute index of the first byte after this frame's payload. */
        int nextFrame() {
            return this.payloadStart + this.length;
        }

        boolean isFin() {
            return (this.frameByte & HIGH_BIT) != 0;
        }

        /** Rejects opcodes that cannot begin a message (including continuation and reserved ones). */
        void checkInitialOpcode() throws DeserialisationException {
            switch (this.opcode) {
                case OPCODE_TEXT:
                case OPCODE_BINARY:
                case OPCODE_CLOSE:
                case OPCODE_PING:
                case OPCODE_PONG:
                    break;
                default:
                    throw new DeserialisationException("A websocket frame [" + Integer.toHexString(this.frameByte) + "||" + this.frameByte + "] with unknown opcode " + Integer.toHexString(this.opcode) + " was received, likely data corruption");
            }
        }

        /** Unmasks this frame's payload in place; a zero key leaves the data unchanged. */
        void maskPayload(ByteBuffer payloadData) {
            if (this.mask != 0) {
                maskRange(payloadData, this.mask, this.payloadStart, this.payloadStart + this.length);
            }
        }
    }
}

