#!/usr/bin/python3
from samson.utilities.cli import start_repl, HASHES, EC_CURVES, ED_CURVES, PKI, ENCODING_MAPPING
from samson.utilities.bytes import Bytes
import argparse
import sys


def _joined_keys(mapping):
    """Return the keys of ``mapping`` as a single comma-separated string (used in help text)."""
    return ', '.join(key for key, _ in mapping.items())


# Root parser: a global --eval switch plus one sub-command per tool; the chosen
# sub-command name is stored in ``arguments.command``.
parser = argparse.ArgumentParser()
subparsers = parser.add_subparsers(dest='command')
parser.add_argument('--eval', action='store_true')

# `hash` sub-command: hash type, optional text, optional "k=v,..." constructor args.
hash_parser = subparsers.add_parser(
    'hash',
    description=f"Available hash types: {_joined_keys(HASHES)}",
)
hash_parser.add_argument('type')
hash_parser.add_argument('text', nargs="?")
hash_parser.add_argument('--args', nargs="?")

# `pki` sub-command: the description is pre-formatted, hence the raw formatter.
pki_parser = subparsers.add_parser(
    'pki',
    formatter_class=argparse.RawDescriptionHelpFormatter,
    description='\n\n'.join([
        f"Available PKI types: {_joined_keys(PKI)}",
        f"Available ECDSA curves: {_joined_keys(EC_CURVES)}",
        f"Available EdDSA curves: {_joined_keys(ED_CURVES)}",
        f"Available encodings: {_joined_keys(ENCODING_MAPPING)}",
    ]) + '\n',
)
pki_parser.add_argument('action')
pki_parser.add_argument('type')
pki_parser.add_argument('--args', nargs="?")
pki_parser.add_argument('filename', nargs="?")
pki_parser.add_argument('--pub', action='store_true')
pki_parser.add_argument('--encoding', nargs="?")
pki_parser.add_argument('--encoding-args', nargs="?")

# Parsed once at module level; the functions and entry point below read this namespace.
arguments = parser.parse_args()

def try_parse(val):
    """Return ``int(val)`` when that conversion succeeds, otherwise ``val`` unchanged.

    Only ``ValueError`` is absorbed; any other exception raised by ``int``
    (for example ``TypeError`` for ``None``) propagates to the caller.
    """
    try:
        return int(val)
    except ValueError:
        return val


# Converter applied to each --encoding-args value in export_pki. Defaults to the
# int-or-string try_parse; the __main__ block rebinds it to eval when --eval is given.
PARSE_METHOD = try_parse


def export_pki(pki_obj, arguments):
    """Print the key held by ``pki_obj`` in the form selected by the CLI ``arguments``.

    Args:
        pki_obj: Object providing ``export_public_key`` and ``export_private_key``.
            The selected method's return value is ``.decode()``d and printed.
        arguments: Parsed argparse namespace. The attributes read are ``encoding``,
            ``encoding_args`` and ``pub``.

    ``encoding_args`` is a comma-separated list of ``key=value`` pairs. A value
    wrapped in a pair of ``#`` characters (``key=#a,b=c#``) is taken as one literal,
    so it may itself contain ``,`` or ``=``. Every value is passed through the
    module-level ``PARSE_METHOD`` before use. The parsed pairs are only forwarded
    to the export call when ``arguments.encoding`` is set.

    Raises:
        KeyError: ``arguments.encoding`` (upper-cased) is not in ``ENCODING_MAPPING``.
        ValueError: an ``encoding_args`` pair contains no ``=``.
    """
    encoding = ENCODING_MAPPING[arguments.encoding.upper()] if arguments.encoding else None

    encoding_args = {}
    if arguments.encoding_args:
        segments = arguments.encoding_args.split('#')
        literals = {}

        # Odd-indexed segments are the text between a pair of '#'. Swap each for a
        # placeholder so the ',' / '=' splitting below cannot cut through it.
        for idx in range(1, len(segments), 2):
            placeholder = f'#{idx}'
            literals[placeholder] = segments[idx]
            segments[idx] = placeholder

        # Split on the first '=' only, so plain values may contain '='.
        pairs = [pair.split('=', 1) for pair in ''.join(segments).split(',')]

        # NOTE(security): PARSE_METHOD is eval when the CLI is run with --eval.
        # dict.get with a default (rather than `or`) keeps an empty '##' literal
        # as '' instead of leaking the placeholder text.
        encoding_args = {key: PARSE_METHOD(literals.get(value, value)) for key, value in pairs}

    export = pki_obj.export_public_key if arguments.pub else pki_obj.export_private_key

    if arguments.encoding:
        exported = export(encoding=encoding, **encoding_args)
    else:
        exported = export()

    print(exported.decode())


if __name__ == '__main__':
    # Command-line entry point. With no arguments at all the interactive REPL is
    # started; otherwise the parsed `arguments` select the `hash` or `pki` command.
    if len(sys.argv) == 1:
        start_repl()
    else:
        if arguments.eval:
            # NOTE(security): --eval passes user-supplied text (the hash input and
            # --encoding-args values) to Python's eval(); use only with trusted input.
            PARSE_METHOD = eval

        if arguments.command == 'hash':
            hash_cls = HASHES[arguments.type.lower()]

            text = arguments.text

            # Fall back to stdin only when the positional was omitted, so an explicit
            # empty string is hashed rather than blocking on stdin.
            if text is None:
                text = sys.stdin.buffer.read()

            if arguments.eval:
                text = PARSE_METHOD(text)

            if isinstance(text, str):
                text = text.encode('utf-8')

            if arguments.args:
                # --args is "k1=v1,k2=v2"; every value is converted with int().
                dict_args = {
                    k: int(v)
                    for k, v in (arg.split('=', 1) for arg in arguments.args.split(','))
                }
            else:
                dict_args = {}

            hash_obj = hash_cls(**dict_args)
            print(hash_obj.hash(text).hex().decode())

        elif arguments.command == 'pki':
            pki_type = arguments.type.lower()
            pki_cls = PKI[pki_type]

            if arguments.args:
                # --args is "k1=v1,k2=v2"; values become ints where possible.
                dict_args = {
                    k: try_parse(v)
                    for k, v in (arg.split('=', 1) for arg in arguments.args.split(','))
                }
            else:
                dict_args = {}

            # Translate a curve *name* into what the constructor is given:
            # ECDSA receives the curve's `G` attribute as `G`, EdDSA the curve object.
            if "curve" in dict_args:
                if pki_type == 'ecdsa':
                    dict_args["G"] = EC_CURVES[dict_args.pop("curve")].G
                elif pki_type == 'eddsa':
                    dict_args["curve"] = ED_CURVES[dict_args["curve"]]

            if arguments.action == 'generate':
                pki_obj = pki_cls(**dict_args)
                export_pki(pki_obj, arguments)

            elif arguments.action == 'parse':
                if arguments.filename:
                    with open(arguments.filename, 'rb') as f:
                        key_to_parse = f.read()
                else:
                    # Read raw bytes, matching the 'rb' file branch above, so
                    # non-UTF-8 (binary) key data piped on stdin is not rejected.
                    key_to_parse = sys.stdin.buffer.read()

                pki_obj = pki_cls.import_key(key_to_parse)

                if arguments.encoding:
                    export_pki(pki_obj, arguments)
                else:
                    print(pki_obj)

            else:
                # Previously an unrecognised action fell through silently.
                pki_parser.error(
                    f"unknown action {arguments.action!r} (expected 'generate' or 'parse')"
                )
