# /*
#  * @Author: John Lyu
#  * @Date: 2020-10-10 19:44:59
#  * @Last Modified by:   John Lyu
#  * @Last Modified time: 2020-10-10 19:44:59
#  */
"""Create dynamic network layers of
Application Layer
Transport Layer
Network Layer
Data Link Layer
Physical Layer (No Need to implement)
"""

import ipaddress


class LayerParser(object):
    """
    Common base class for every protocol-layer parser in this module.

    ``data`` is cut into consecutive pieces whose sizes come from
    ``get_header_struct``; the pieces are stored by field name in
    ``sliced_data``. ``pretty_print`` and the subclasses treat ``data`` as a
    string of "0"/"1" characters, so one position is one bit.
    """

    def __init__(self, data):
        """
        Keep ``data``, look up the header layout and slice the header fields.
        """
        self.data = data
        self.header_struct = self.get_header_struct()
        self.sliced_data = {}
        self.analysis()

    def get_header_struct(self):
        """
        Return a mapping of field name -> field width, in header order.

        The base layout is empty; subclasses override this method.
        """
        # Widths are expressed in bits.
        # analysis() relies on the mapping keeping its insertion order, which
        # plain dicts only do on recent Python versions (newer than 3.5).
        return dict()

    def pretty_print(self):
        """Print each sliced field as ``name: <hex value>``."""
        for name, bits in self.sliced_data.items():
            value = int(bits, 2)
            print(f"{name}: {hex(value)}")

    def analysis(self):
        """
        Fill ``sliced_data`` by walking the header layout from position 0.

        A field that starts beyond the end of ``data`` ends up as an empty
        (or shortened) slice; no error is raised here.
        """
        offset = 0
        fields = self.sliced_data
        for name, width in self.header_struct.items():
            fields[name] = self.data[offset:offset + width]
            offset += width

    def to_upper_layer(self):
        """Return the payload for the next layer; the base class has none."""
        return None


class EthernetParser(LayerParser):
    """
    Data link layer parser for Ethernet II frames.
    See https://en.wikipedia.org/wiki/Ethernet_frame
    Hosts are identified by their MAC address at this layer.
    """

    def get_header_struct(self):
        """Return the 14-byte Ethernet II header layout (widths in bits)."""
        mac_bits = 6 * 8
        ether_type_bits = 2 * 8
        # Extended headers that may follow are not implemented yet.
        return {
            "DMAC": mac_bits,
            "SMAC": mac_bits,
            "Type": ether_type_bits,
        }

    def next_protocal(self):
        """
        Map the ``Type`` field to the name of the encapsulated protocol.

        Returns "ipv4", "arp" or "ipv6"; any other value yields ``None``.
        """
        known_types = {
            0x0800: "ipv4",
            0x0806: "arp",
            0x86DD: "ipv6",
        }
        ether_type = int(self.sliced_data["Type"], 2)
        return known_types.get(ether_type)

    def to_upper_layer(self):
        """Return everything after the fixed 14-byte header."""
        header_bits = 14 * 8
        return self.data[header_bits:]


class LoopbackParser(LayerParser):
    """
    Stand-in for the Ethernet layer on localhost captures.

    Such frames are recognised by their first 4 bytes:
    0000   1e 00 00 00
    No header fields are sliced (the base class layout is empty).
    """
    # Bit-string form of the bytes 1e 00 00 00.
    HINT = format(0x1E, "08b") + "0" * 24

    def next_protocal(self):
        """Return the encapsulated protocol, which is always "ipv6" here."""
        return "ipv6"

    def to_upper_layer(self):
        """Return everything after the 4-byte loopback header."""
        header_bits = 4 * 8
        return self.data[header_bits:]


class IPv6Parser(LayerParser):
    """
    Network layer parser for IPv6 packets.
    See https://en.wikipedia.org/wiki/IPv6_packet
    """

    def get_header_struct(self):
        """Return the fixed 40-byte IPv6 header layout (widths in bits)."""
        address_bits = 128
        # Extension headers that may follow are not implemented yet.
        return {
            "protocol_version": 4,
            "traffic_class": 8,
            "flow_label": 20,
            "payload_length": 16,
            "next_protocol": 8,  # 0x06 is TCP and 0x11 is UDP
            "hop_limit": 8,
            "source_address": address_bits,
            "dst_address": address_bits,
        }

    def to_upper_layer(self):
        """Return everything after the fixed 40-byte header."""
        header_bits = 40 * 8
        return self.data[header_bits:]

    @property
    def s_ip(self):
        """Source address as an ``ipaddress.IPv6Address``."""
        raw = int(self.sliced_data["source_address"], 2)
        return ipaddress.IPv6Address(raw)

    @property
    def d_ip(self):
        """Destination address as an ``ipaddress.IPv6Address``."""
        raw = int(self.sliced_data["dst_address"], 2)
        return ipaddress.IPv6Address(raw)


class IPv4Parser(LayerParser):
    """
    Network layer parser for IPv4 packets.
    See https://en.wikipedia.org/wiki/IPv4#Header
    """

    def get_header_struct(self):
        """Return the fixed 20-byte IPv4 header layout (widths in bits)."""
        flag_bits = 3
        # Options that may follow the fixed part are not implemented yet.
        return {
            "protocol_version": 4,
            "IHL": 4,
            "DSCP": 6,
            "ECN": 2,
            "Total_Length": 16,
            "Identification": 16,
            "Flags": flag_bits,
            "Fragment Offset": 16 - flag_bits,
            "Time To Live": 8,
            "next_protocol": 8,
            "Header_Checksum": 16,
            "source_address": 32,
            "dst_address": 32,
        }

    def analysis(self):
        """
        Slice the header, then derive ``header_length`` (in bits) from IHL.

        The IHL value is multiplied by 32, i.e. it is read as a count of
        32-bit words.
        """
        super().analysis()
        word_count = int(self.sliced_data["IHL"], 2)
        self.header_length = word_count * 32

    def to_upper_layer(self):
        """Return everything after the header, whose size IHL announces."""
        payload_start = self.header_length
        return self.data[payload_start:]

    @property
    def s_ip(self):
        """Source address as an ``ipaddress.IPv4Address``."""
        raw = int(self.sliced_data["source_address"], 2)
        return ipaddress.IPv4Address(raw)

    @property
    def d_ip(self):
        """Destination address as an ``ipaddress.IPv4Address``."""
        raw = int(self.sliced_data["dst_address"], 2)
        return ipaddress.IPv4Address(raw)


class TCPParser(LayerParser):
    """
    Transport layer parser for a single TCP segment.
    See https://en.wikipedia.org/wiki/Transmission_Control_Protocol

    Only the fixed 20-byte part of the header is sliced into fields; any
    options between it and the payload are skipped, not parsed.
    """
    # todo get full tcp request instead of one packet
    def get_header_struct(self):
        """Return the fixed TCP header layout (widths in bits)."""
        return {
            "source_port": 8 * 2,
            "dst_port": 8 * 2,
            "sequence_number": 8 * 4,
            "ack_number": 8 * 4,
            "data_offset": 4,
            "reserved": 3,
            "ns": 1,
            "cwr": 1,
            "ece": 1,
            "urg": 1,
            "ack": 1,
            "psh": 1,
            "rst": 1,
            "syn": 1,
            "fin": 1,
            "window_size": 8 * 2,
            "checksum": 8 * 2,
            "urgent_pointer": 8 * 2,
        }

    def analysis(self):
        """
        Slice the header, then derive ``header_length`` (in bits).

        ``data_offset`` is multiplied by 32, i.e. it is read as a count of
        32-bit words.
        """
        super().analysis()
        self.header_length = int(self.sliced_data["data_offset"], 2) * 32

    def to_upper_layer(self):
        """Return the segment payload, i.e. everything after the header."""
        # The length lives on the instance; the bare name ``header_length``
        # used previously was undefined and raised NameError on every call.
        return self.data[self.header_length:]


class MyShark(object):
    """
    Oh My WireShark!

    Holds a collection of parsed packets. Each packet is expected to expose
    ``s_ip``, ``d_ip``, ``s_port``, ``d_port`` and ``protocol`` attributes.
    """

    def __init__(self, packets):
        """Keep a reference to the iterable of parsed packets."""
        self.packets = packets

    def get_stream(self, packet):
        """
        Collect the packets that travel in the same direction as ``packet``.

        A packet belongs to the stream when its source/destination address
        and port all equal those of ``packet``.

        Returns a list that starts with ``packet`` itself, followed by the
        other matching packets in capture order. ``packet`` is listed once
        even when it is also part of ``self.packets``.
        """
        s_ip = packet.s_ip
        d_ip = packet.d_ip
        s_port = packet.s_port
        d_port = packet.d_port
        stream = [packet]
        # todo detect FIN
        for p in self.packets:
            if p is packet:
                # Already the head of the stream; do not add it twice.
                continue
            if (s_ip == p.s_ip and d_ip == p.d_ip
                    and s_port == p.s_port and d_port == p.d_port):
                stream.append(p)
        return stream

    def summary(self):
        """Print one line per distinct (source, destination, protocol)
        combination together with the number of packets seen for it."""
        # could use tree here for better performance
        tag_dict = {}
        # Iterate this instance's packets rather than a module-level global,
        # so the method works for any MyShark object.
        for p in self.packets:
            src = str(p.s_ip) + ':' + str(p.s_port)
            dst = str(p.d_ip) + ':' + str(p.d_port)
            tag = f"source: {src:<23} dst: {dst:<23} protocol: {p.protocol}"
            tag_dict[tag] = tag_dict.get(tag, 0) + 1
        for t, c in tag_dict.items():
            print(t, "    Count: ", c)


class MyPacket(object):
    """
    Store all layer info of one captured packet in this class.

    The raw bytes are turned into a string of "0"/"1" characters and handed
    down through a link layer parser, an IP layer parser and a TCP parser.
    """

    def __init__(self, pyshark_packet):
        """
        Parse ``pyshark_packet`` layer by layer.

        ``pyshark_packet`` must provide ``get_raw_packet()`` returning the
        raw bytes of the frame.

        Raises:
            ValueError: when the link layer announces a protocol other than
                IPv4 or IPv6 (for example ARP or an unknown type).
        """
        self.packet = pyshark_packet
        rb = self.packet.get_raw_packet()
        raw_hex = rb.hex()
        # zfill restores the leading zero bits that bin() drops.
        raw_bin = bin(int(raw_hex, base=16))[2:].zfill(len(rb) * 8)

        # I am not sure how to identify the ethernet layer protocol, so the
        # loopback marker is checked first and Ethernet II is the fallback.
        if raw_bin.startswith(LoopbackParser.HINT):
            self.ethernet_layer = LoopbackParser(raw_bin)
        else:
            self.ethernet_layer = EthernetParser(raw_bin)

        # switch ip layer version; resolve the protocol name once so the
        # error message below reports the value instead of a bound method.
        next_protocol = self.ethernet_layer.next_protocal()
        if next_protocol == "ipv6":
            ip_parser = IPv6Parser
        elif next_protocol == "ipv4":
            ip_parser = IPv4Parser
        else:
            raise ValueError("protrocal {} is not implemented".format(
                next_protocol))
        self.ip_layer = ip_parser(self.ethernet_layer.to_upper_layer())

        # The IP payload is always parsed as TCP, whatever ``protocol`` says.
        tcp_data = self.ip_layer.to_upper_layer()
        self.tcp_layer = TCPParser(tcp_data)

    @property
    def s_ip(self):
        """Source IP address, as reported by the IP layer parser."""
        return self.ip_layer.s_ip

    @property
    def d_ip(self):
        """Destination IP address, as reported by the IP layer parser."""
        return self.ip_layer.d_ip

    @property
    def s_port(self):
        """Source port as an integer."""
        return int(self.tcp_layer.sliced_data["source_port"], 2)

    @property
    def d_port(self):
        """Destination port as an integer."""
        return int(self.tcp_layer.sliced_data["dst_port"], 2)

    @property
    def protocol(self):
        """Transport protocol name: "TCP", "UDP" or "Unknown"."""
        code = int(self.ip_layer.sliced_data["next_protocol"], 2)
        if code == 0x06:
            return "TCP"
        if code == 0x11:
            return "UDP"
        return "Unknown"


if __name__ == "__main__":
    # Demo entry point: load (or record) a loopback capture and summarise it.
    import os
    from pathlib import Path

    import pyshark

    capture_path = str(Path.home()) + '/Downloads/lo0_1.pcapng.gz'
    if not os.path.exists(capture_path):
        # No saved capture yet: sniff lo0 for 10 seconds and write the
        # result to capture_path.
        capture = pyshark.LiveCapture(interface='lo0',
                                      output_file=capture_path)
        capture.set_debug()
        capture.sniff(timeout=10)
    else:
        # include_raw gives MyPacket access to the raw frame bytes.
        capture = pyshark.FileCapture(capture_path,
                                      use_json=True,
                                      include_raw=True)

    packets = []
    for raw_packet in capture:
        try:
            parsed = MyPacket(raw_packet)
        except ValueError:
            # Packets that cannot be parsed are skipped on purpose.
            continue
        packets.append(parsed)
    # Keep the module-level name ``shark``: other code in this file may
    # refer to it.
    shark = MyShark(packets)
    shark.summary()
