#!python


import datetime
import time
import icmpSender
import rsa
import os
import threading
import redis

# One-byte packet-type flag -> symbolic name. Client requests use the
# 0xaN codes; the server answers each with the matching 0xbN code.
PACKET_TYPE = dict([
    (b"\xa0", "PING"),
    (b"\xb0", "PONG"),
    (b"\xa1", "SEND_KEY"),
    (b"\xb1", "SEND_KEY_ACK"),
    (b"\xa2", "GET_KEY"),
    (b"\xb2", "USER_KEY"),
    (b"\xa5", "GET_MESSAGE"),
    (b"\xb5", "MESSAGE"),
    (b"\xa6", "SEND_MESSAGE"),
    (b"\xb6", "MESSAGE_ACK"),
])


def calculate_size(packet):
    """Return the 5-byte size field for ``packet``.

    The length is written as a run of 0xff bytes (each standing for 255)
    followed by one remainder byte, right-aligned and left-padded with
    zero bytes. For example, 300 bytes encodes as 00 00 00 ff 2d.
    Any length of 1275 or more encodes as five 0xff bytes, so such
    lengths cannot be told apart.
    """
    full_chunks, remainder = divmod(len(packet), 255)
    if full_chunks >= 5:
        # No room left for a remainder byte: the field saturates.
        return b"\xff" * 5
    encoded = b"\xff" * full_chunks + bytes([remainder])
    return encoded.rjust(5, b"\x00")


def add_packet_size(packet):
    """Write the size field of a packet that is still a list of byte chunks.

    Element 2 of ``packet`` is the size placeholder (callers in this file
    pass ``b"\\x00" * 5``). It is replaced in place by
    ``calculate_size`` of all chunks joined together. Returns None.
    """
    packet[2] = calculate_size(b"".join(packet))


class Client:
    """iceChat client that talks to one server over ICMP.

    On construction it loads (or creates) the user's RSA key pair. It then
    starts a thread running ``icmpSender.receive_data``, which passes every
    incoming packet to :meth:`receiver`. Each request method sends one
    packet, sleeps briefly, and scans ``received_data`` for a reply of the
    expected type that arrived after the request was sent.
    """

    # Offset of the iceChat type byte in the data handed to receiver().
    # The 29 leading bytes are presumably transport headers
    # (NOTE(review): confirm against icmpSender). The body follows the
    # 7-byte iceChat header: type, start flag, 5-byte size.
    _TYPE_OFFSET = 29
    _BODY_OFFSET = _TYPE_OFFSET + 7

    # Class-level defaults kept for backward compatibility. __init__ gives
    # every instance its own values, so clients never share one dict.
    received_data = ()
    contacts_key = {}

    def __init__(self, ip, username) -> None:
        self.ip = ip
        self.username = username
        # (data, received_time) pairs appended by receiver() on the
        # background thread. It is a tuple rebound on every append, so
        # readers always iterate a stable snapshot.
        self.received_data = ()
        # username -> public key; filled by get_key(), used by send_message().
        self.contacts_key = {}
        self.key = self._get_key()
        tr = threading.Thread(target=icmpSender.receive_data,
                              args=(self.ip, self.receiver))
        tr.start()

    def receiver(self, addr, data, received_time):
        """Callback for icmpSender.receive_data: record one incoming packet."""
        self.received_data += ((data, received_time),)

    def _ack(self):
        """Placeholder; not implemented."""
        pass

    def _make_key(self, size=2048) -> tuple:
        """Generate an RSA key pair of ``size`` bits and save it in the home directory.

        The private key file is created with mode 0o600, so that a newly
        created file is readable only by its owner.
        Returns ``(private_key, public_key)``.
        """
        home = os.path.expanduser("~")
        public_key, private_key = rsa.newkeys(size)
        private_path = os.path.join(home, ".private_key.pem")
        fd = os.open(private_path, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
        with os.fdopen(fd, "wb") as f:
            f.write(private_key.save_pkcs1("PEM"))
        with open(os.path.join(home, ".public_key.pem"), "wb") as f:
            f.write(public_key.save_pkcs1("PEM"))
        return (private_key, public_key)

    def _load_key(self):
        """Load ``(private_key, public_key)`` from the home directory.

        Returns False if either file is missing, unreadable or malformed.
        """
        home = os.path.expanduser("~")
        try:
            with open(os.path.join(home, ".private_key.pem"), "rb") as f:
                private_key = rsa.PrivateKey.load_pkcs1(f.read())
            with open(os.path.join(home, ".public_key.pem"), "rb") as f:
                public_key = rsa.PublicKey.load_pkcs1(f.read())
        except Exception:
            # Best effort: the caller falls back to generating a new pair.
            return False
        return (private_key, public_key)

    def _get_key(self, key='', size=2048) -> tuple:
        """Return ``key`` if given, else the stored pair, else a new ``size``-bit pair."""
        if key:
            return key
        return self._load_key() or self._make_key(size)

    @staticmethod
    def _build_packet(type_flag, body):
        """Frame ``body`` as an iceChat packet of type ``type_flag``.

        Layout: type byte, start flag 0x01, 5-byte size, body, end flag 0x00.
        """
        packet = [type_flag, b"\x01", b"\x00" * 5, body, b"\x00"]
        add_packet_size(packet)
        return b"".join(packet)

    @staticmethod
    def _pem_to_wire(pem):
        """Convert a PEM public key to the on-the-wire key format.

        The BEGIN/END lines are dropped, and the base64 lines are joined by
        the two bytes backslash and ``n``. Assumes ``pem`` ends with a
        trailing newline, as ``save_pkcs1("PEM")`` output does.
        """
        return b"\\n".join(pem.split(b"\n")[1:-2])

    @staticmethod
    def _wire_to_pem(body):
        """Rebuild an RSA PUBLIC KEY PEM document from the wire key format."""
        return (b"-----BEGIN RSA PUBLIC KEY-----\n"
                + body.replace(b"\\n", b"\n")
                + b"\n-----END RSA PUBLIC KEY-----\n")

    def _replies(self, type_code, since):
        """Yield received packets of ``type_code`` that arrived after ``since``.

        A packet qualifies when all of the following hold:
        - it is long enough to hold the type byte and start flag;
        - its type byte equals ``type_code``;
        - its start flag is 0x01;
        - its last byte (the end flag) is 0x00.

        Short packets are skipped instead of raising IndexError.
        """
        start = self._TYPE_OFFSET
        for data, received_time in self.received_data:
            if (received_time > since
                    and len(data) > start + 1
                    and data[start] == type_code
                    and data[start + 1] == 0x01
                    and data[-1] == 0x00):
                yield data

    def ping(self) -> bool:
        '''
        Check whether the iceChat server is up.

        Sends a PING and returns True if a PONG arrives within about 0.3 s.
        Any error while sending or scanning is reported as False.
        '''
        try:
            send_time = datetime.datetime.now()
            icmpSender.send_data(self.ip, self._build_packet(b"\xa0", b"PING"))
            time.sleep(.3)
            return any(data[-5:-1] == b"PONG"
                       for data in self._replies(0xb0, send_time))
        except Exception:
            return False

    def send_key(self) -> bool:
        """Publish this user's public key to the server.

        Returns True when a SEND_KEY_ACK arrives whose body decrypts, with
        our private key, to our username. Returns False otherwise, including
        on any error while sending.
        """
        try:
            body = (self._pem_to_wire(self.key[1].save_pkcs1("PEM"))
                    + b"\x00\x00" + self.username.encode())
            send_time = datetime.datetime.now()
            icmpSender.send_data(self.ip, self._build_packet(b"\xa1", body))
            time.sleep(.3)
            for data in self._replies(0xb1, send_time):
                try:
                    plain = rsa.decrypt(data[self._BODY_OFFSET:-1], self.key[0])
                    if plain.decode() == self.username:
                        return True
                except (rsa.DecryptionError, OverflowError, ValueError):
                    # Not an ack for our key; keep scanning.
                    continue
            return False
        except Exception:
            return False

    def get_key(self, username):
        """Fetch ``username``'s public key from the server into contacts_key.

        Returns True once a USER_KEY reply for that user has been parsed and
        cached. Returns False if no such reply arrived within about 0.3 s.
        Errors raised by ``rsa.PublicKey.load_pkcs1`` for a malformed key
        propagate to the caller.
        """
        send_time = datetime.datetime.now()
        icmpSender.send_data(self.ip,
                             self._build_packet(b"\xa2", username.encode()))
        time.sleep(.3)
        wanted = username.encode()
        for data in self._replies(0xb2, send_time):
            # Reply body: wire-format key, two NUL bytes, owner username.
            key_body, separator, owner = (
                data[self._BODY_OFFSET:-1].partition(b"\x00\x00"))
            if not separator or owner != wanted:
                continue
            self.contacts_key[username] = rsa.PublicKey.load_pkcs1(
                self._wire_to_pem(key_body))
            return True
        return False

    def get_message(self, username):
        """Fetch and decrypt the messages stored for this user from ``username``.

        Returns a non-empty list of decrypted strings, or False if none
        arrived. Packets that fail to decrypt with our private key are
        skipped.
        """
        body = self.username.encode() + b"\x00\x00" + username.encode()
        send_time = datetime.datetime.now()
        icmpSender.send_data(self.ip, self._build_packet(b"\xa5", body))
        # The server answers with one MESSAGE packet per stored message.
        # Wait once for all of them, then scan a single time.
        time.sleep(.8)
        messages = []
        for data in self._replies(0xb5, send_time):
            try:
                messages.append(
                    rsa.decrypt(data[self._BODY_OFFSET:-1], self.key[0]).decode())
            except (rsa.DecryptionError, OverflowError, ValueError):
                continue
        return messages or False

    def send_message(self, username, message):
        """Encrypt ``message`` for ``username`` and hand it to the server.

        Requires ``username``'s key in contacts_key (see get_key); raises
        KeyError otherwise. Returns the server-assigned message number taken
        from the MESSAGE_ACK. Returns False if no acknowledgement arrived
        within about 0.5 s.
        """
        ciphertext = rsa.encrypt(message.encode(), self.contacts_key[username])
        target = username.encode()
        body = (self.username.encode() + b"\x00\x00"
                + target + b"\x00\x00" + ciphertext)
        send_time = datetime.datetime.now()
        icmpSender.send_data(self.ip, self._build_packet(b"\xa6", body))
        time.sleep(.5)
        for data in self._replies(0xb6, send_time):
            ack = data[self._BODY_OFFSET:-1]
            if target in ack and b"error" not in ack:
                try:
                    # Ack body: recipient, sender, number, separated by NUL pairs.
                    return int(ack.split(b"\x00\x00")[-1].decode())
                except ValueError:
                    continue
        return False


class Server:
    def __init__(self) -> None:
        self.db = redis.Redis(host='localhost', port=6379, db=0)

    def _verify_packet(self, packet) -> bool:
        try:
            packet_code = packet[1]
            packet_len_data = packet[2:7]
            packet_end_flag = packet[-1:]
            if packet_len_data == calculate_size(packet) and packet_end_flag == b"\x00":
                return True
            else:
                return False
        except:
            return False

    def _get_type_packet(self, packet) -> int:
        try:
            packet_type_code = packet[0]
            return PACKET_TYPE[(bytes([packet_type_code]))]

        except:
            pass

    def _key_verify(self, packet):
        pass

    def run_server(self):
        receiver = icmpSender.receive_data_server("*")
        for packet in receiver:
            data = packet[1][29:]
            if packet[1][-1:] == b"\x00" and self._verify_packet(data):
                type_packet = self._get_type_packet(data)
                match type_packet:
                    case "PING":
                        pong_packet = []
                        # flag of ping type
                        pong_packet.append(b"\xb0")
                        # start packet flag
                        pong_packet.append(b"\x01")
                        # make empty size
                        pong_packet.append((b"\x00")*5)
                        # add body of packet
                        pong_packet.append(b"PONG")
                        # add end of packet flag
                        pong_packet.append(b"\x00")
                        # fix packet size
                        add_packet_size(pong_packet)
                        # list to string pong_packet
                        pong_packet = b"".join(pong_packet)
                        icmpSender.send_custom_answer(packet[0], pong_packet)
                    case "SEND_KEY":
                        try:
                            body = ((data[7:-1]))
                            data = body.split(b"\x00\x00")
                            if data[1]:
                                username = data[1].decode()
                                data[0] = data[0].replace(b"\\n", b"\n")
                                self.db.setex(
                                    f"user:publicKey:{username}",3600, data[0])
                            pub_key = (
                                f"""-----BEGIN RSA PUBLIC KEY-----\n{data[0].decode()}\n-----END RSA PUBLIC KEY-----\n""").encode()
                            pub_key = rsa.PublicKey.load_pkcs1(pub_key)
                            body = (rsa.encrypt(username.encode(), pub_key))
                            # flag of send_key_ack type
                            send_key_ack_packet = [b"\xb1"]
                            # start packet flag
                            send_key_ack_packet.append(b"\x01")
                            # make empty size
                            send_key_ack_packet.append((b"\x00")*5)
                            # add body of packet
                            send_key_ack_packet.append(body)
                            # add end of packet flag
                            send_key_ack_packet.append(b"\x00")
                            # fix packet size
                            add_packet_size(send_key_ack_packet)
                            # list to string send_key_ack_packet
                            send_key_ack_packet = b"".join(send_key_ack_packet)
                            icmpSender.send_custom_answer(
                                packet[0], send_key_ack_packet)
                        except:
                            icmpSender.send_custom_answer(packet[0], b"error")

                    case "GET_KEY":
                        try:
                            username = data[7:-1].decode()
                            pub_key = self.db.get(f"user:publicKey:{username}")
                            body = pub_key + b"\x00\x00" + username.encode()
                            # flag of send_key_ack type
                            send_key_ack_packet = [b"\xb2"]
                            # start packet flag
                            send_key_ack_packet.append(b"\x01")
                            # make empty size
                            send_key_ack_packet.append((b"\x00")*5)
                            # add body of packet
                            send_key_ack_packet.append(body)
                            # add end of packet flag
                            send_key_ack_packet.append(b"\x00")
                            # fix packet size
                            add_packet_size(send_key_ack_packet)
                            # list to string send_key_ack_packet
                            send_key_ack_packet = b"".join(send_key_ack_packet)
                            icmpSender.send_custom_answer(
                                packet[0], send_key_ack_packet)
                        except:
                            icmpSender.send_custom_answer(packet[0], b"error")

                    case "GET_MESSAGE":
                        try :
                            usernames = data[7:-1]
                            usernames = usernames.split(b"\x00\x00")
                            title_of_messages = self.db.keys(f"chat:{usernames[0].decode()}:{usernames[1].decode()}:*")
                            list_of_messages = []
                            for i in title_of_messages:
                                list_of_messages.append(self.db.get(i))
                                self.db.delete(i)
                            #flag of message 
                            message_packet = [b"\xb5"]
                            #start packet flag
                            message_packet.append(b"\x01")
                            #make empty size
                            message_packet.append((b"\x00")*5)
                            # add empty body
                            message_packet.append(b"")
                            #add end of packet flag
                            message_packet.append(b"\x00")
                            for i in list_of_messages:
                                message_packet[3] = i
                                add_packet_size(message_packet)
                                message_packet_str = b"".join(message_packet)
                                icmpSender.send_custom_answer(packet[0], message_packet_str)

                        except:
                            icmpSender.send_custom_answer(packet[0], b"error")

                    case "SEND_MESSAGE":
                        try:
                            data = data[7:-1]
                            data = data.split(b"\x00\x00")
                            user = data[1]
                            target_user = data[0]
                            message = data[-1]
                            if self.db.get(f"counter:chat:{user.decode()}:{target_user.decode()}"):
                                counter = int(self.db.get(f"counter:chat:{user.decode()}:{target_user.decode()}").decode())
                                self.db.incr(f"counter:chat:{user.decode()}:{target_user.decode()}")
                            else:
                                self.db.setex(f"counter:chat:{user.decode()}:{target_user.decode()}",3600, 1)
                                counter = 0
                            self.db.setex(f"chat:{user.decode()}:{target_user.decode()}:{counter+1}",3600, message)
                            # flag of message_ack type
                            message_ack_packet = [b"\xb6"]
                            # start packet flag
                            message_ack_packet.append(b"\x01")
                            # make empty size
                            message_ack_packet.append((b"\x00")*5)
                            # add empty body
                            message_ack_packet.append(f"{user.decode()}\x00\x00{target_user.decode()}\x00\x00{counter+1}".encode())
                            # add end of packet flag
                            message_ack_packet.append(b"\x00")
                            # fix packet size
                            add_packet_size(message_ack_packet)
                            # list to string message_ack_packet
                            message_ack_packet = b"".join(message_ack_packet)
                            icmpSender.send_custom_answer(packet[0], message_ack_packet)
                        except:
                            icmpSender.send_custom_answer(packet[0], b"error")


if __name__ == "__main__":
    # Start the server; run_server() loops over incoming packets.
    Server().run_server()
