#!/usr/bin/env python3

import scrapy
from scrapy.crawler import CrawlerProcess
import paramiko
import click, time, sys
from argparse import ArgumentParser
from cryptography import fernet
# import cryptography
# from cryptography import InvalidToken
import logging, os, getpass, pickle

# File (under the user's ~/.ssh directory) where the encrypted login
# credentials are saved by write_creds() and looked up by read_creds().
path = os.path.expanduser('~/.ssh/iitb_login_creds')


class IITB_Spider(scrapy.Spider):
    """Spider that logs in to, logs out of, or reports the login status of
    the portal at https://internet.iitb.ac.in.

    The portal page is fetched first; the action named by ``command`` is then
    carried out and the portal is fetched once more to report the outcome.
    A response whose URL contains 'logout' is treated as "logged in".
    """
    name = 'iitb_crawler'

    # Single source for the portal address used by every request below.
    portal_url = 'https://internet.iitb.ac.in'

    def __init__(self, command):
        """
        Args
        -----
        command:    Action to perform: 'login', 'logout' or 'status'.
        """
        self.command = command
        # Log verbosity is left to the crawler settings (LOG_LEVEL).
        super().__init__()

    def start_requests(self):
        """Start the crawl by requesting the portal page."""
        yield scrapy.Request(self.portal_url, callback=self.parse_page)

    def parse_page(self, response):
        """Dispatch the first portal response to the handler for ``self.command``."""
        if self.command == 'logout':
            return self.logout(response)
        if self.command == 'login':
            return self.login(response)
        if self.command == 'status':
            # confirm() copies meta['from'] onto the follow-up request.
            response.meta['from'] = 'status'
            return self.confirm(response)
        print('No valid argument Passed!!!')
        return None

    def logout(self, response):
        """Submit the 'auth' form with the Logout button if currently logged in."""
        if not self.is_logged_in(response):
            print('Seems to be logged out already!!')
            return
        request = scrapy.FormRequest.from_response(
            response,
            formname='auth',
            formdata={
                'ip': self.get_ip(response),
                'etype': 'pg',
                'button': 'Logout',
            },
            callback=self.confirm,
        )
        request.meta['from'] = 'logout'
        yield request

    def login(self, response):
        """Submit the 'auth' form with credentials unless already logged in.

        Credentials come from read_creds() and travel in the request meta so
        that confirm_response() can offer to save them after a successful login.
        """
        if self.is_logged_in(response):
            print(f'Already Logged In as:    {self.get_username(response)}')
            print(f'Your IP:                 {self.get_ip(response)}')
            return
        creds = read_creds()
        request = scrapy.FormRequest.from_response(
            response,
            formname='auth',
            formdata={
                'uname': creds['username'],
                'passwd': creds['password'],
                'button': 'Login',
            },
            callback=self.confirm,
        )
        request.meta['from'] = 'login'
        request.meta['creds'] = creds
        yield request

    def confirm(self, response):
        """Re-request the portal page so the result of the action can be checked.

        The 'from' marker (and 'creds', when present) is forwarded to the new
        request. dont_filter=True is needed because the same URL was already
        requested in start_requests().
        """
        request = response.follow(url=self.portal_url,
                                  callback=self.confirm_response,
                                  dont_filter=True)
        # Only login requests carry credentials.
        if 'creds' in response.meta:
            request.meta['creds'] = response.meta['creds']
        request.meta['from'] = response.meta['from']
        yield request

    def confirm_response(self, response):
        """Print the outcome of the action recorded in ``response.meta['from']``.

        After a successful login, when no credentials file exists yet, the
        credentials are passed to write_creds() together with a cipher built
        from an SSH agent key (if the agent has one).
        """
        origin = response.meta['from']
        logged_in = self.is_logged_in(response)

        if origin == 'login' and logged_in:
            print(f'Logged In as {self.get_username(response)}')
            print(f'Your IP:     {self.get_ip(response)}')
            if not os.path.exists(path):
                key = get_key(all=False)
                if key is not None:
                    crypto = get_crypto(key)
                    write_creds(crypto, response.meta['creds'])
        elif origin == 'logout' and not logged_in:
            print('Logged Out')
        elif origin == 'status':
            if logged_in:
                print(f'Logged in as {self.get_username(response)}')
                print(f'Your IP:     {self.get_ip(response)}')
            else:
                print('Not Logged In to internet.iitb.ac.in')
        else:
            # The action did not leave the portal in the expected state.
            print(response.url)
            print(f'{origin} Failed !!!!!')

    def is_logged_in(self, response):
        """Return True when the response URL contains 'logout'."""
        return 'logout' in response.url

    def get_username(self, response):
        """Return the stripped text of the first matching username cell.

        Raises IndexError when the page has no matching node.
        """
        return response.xpath('//div[@class="scrolling"]//tr[1]/td[1]/center/text()')[0].extract().strip()

    def get_ip(self, response):
        """Return the value of the checked 'ip' input on the page.

        Raises IndexError when the page has no matching node.
        """
        return response.xpath('//input[@name="ip" and @checked="checked"]/@value')[0].extract().strip()



def get_key(all=False)->paramiko.agent.AgentKey:
    """Fetch SSH key(s) from the running SSH agent.

    Args
    -----
    all:    When true, return every key the agent holds (possibly none)
            without prompting; read_creds() uses this to try each key.
            When false, return a single key, asking the user to choose
            when the agent holds more than one.

    Returns
    -----
    The agent's key collection when ``all`` is true; otherwise a single
    agent key, or None when the agent holds no keys.
    """
    # Agent that connects to SSH Agent (including forwarded agents)
    agent = paramiko.Agent()
    keys = agent.get_keys()
    if all:
        return keys
    # If no ssh key in agent is found, return None
    if not keys:
        return None

    labels = _agent_key_labels(len(keys))

    # If single key found, return it
    if len(keys) == 1:
        print(f'Encrypting with key:   {labels[0]}')
        return keys[0]

    # For multiple keys, prompt selection of keys
    print('Keys available:')
    for index, label in enumerate(labels):
        print(f"[{index}]   {label}")
    key_index = click.prompt('Select Key Index', type=click.Choice([str(i) for i in range(len(keys))]))
    return keys[int(key_index)]


def _agent_key_labels(count):
    """Return exactly ``count`` display labels read from ``ssh-add -l``.

    Each listing line is expected to look like
    ``<bits> <fingerprint> <comment> (<type>)``; the comment is used as the
    label and may itself contain spaces. Placeholder labels fill in when the
    command prints fewer usable lines than there are keys.
    """
    listing = os.popen('ssh-add -l').read()
    labels = []
    for line in listing.splitlines():
        fields = line.split(None, 2)
        if len(fields) < 3:
            continue
        comment = fields[2]
        # Drop the trailing "(TYPE)" marker, keeping the comment intact.
        if comment.endswith(')') and ' (' in comment:
            comment = comment.rsplit(' (', 1)[0]
        labels.append(comment)
    labels = labels[:count]
    labels.extend(f'<key {i}>' for i in range(len(labels), count))
    return labels
    

def get_crypto(key:paramiko.agent.AgentKey)-> fernet.Fernet:
    """
    Method that returns Fernet class,
    Initialized with signature of a fixed text below, signed by your ssh key

    Args
    -----
    key:    Agent key used to sign the fixed text.

    Returns
    -----
    A Fernet object keyed from the signature, or None when ``key`` is not a
    paramiko AgentKey.
    """
    # Use the stdlib encoder directly instead of reaching for it through the
    # cryptography package's module namespace.
    import base64

    if not isinstance(key, paramiko.agent.AgentKey):
        return None

    signature = key.sign_ssh_data('Random Text Whose SSH Key Signature Is Used As Key')
    # The derivation below must stay exactly as it is: credentials that were
    # already saved can only be decrypted with the same derived key.
    # NOTE(review): this relies on the key producing the same signature for
    # the same text on every call - confirm for the key types in use.
    # NOTE(review): only the first 32 base64 characters of the signature feed
    # the key; if the signature starts with a fixed header the effective
    # entropy is lower than 32 characters suggests - confirm.
    fernet_key = base64.b64encode(base64.b64encode(signature)[:32])
    # Init a Fernet with above key
    return fernet.Fernet(fernet_key)

def encrypt(crypto:fernet.Fernet, string:str)-> bytes:
    """
    Encrypt a text value and return the resulting token.

    Args
    -----
    crypto:     Fernet object initialized with ssh key based signature
                check `get_crypto()` method for initialization
    string:     Text to encrypt; it is converted to bytes first and the
                bytes are then encrypted with the Fernet object passed
    """
    plaintext = str.encode(string)
    token = crypto.encrypt(plaintext)
    return token

def decrypt(crypto:fernet.Fernet, encrypted:bytes)-> str:
    """
    Decrypt a token produced by `encrypt()` and return it as text.

    Args
    -----
    crypto:     Fernet object initialized with ssh key based signature
                check `get_crypto()` method for initialization
    encrypted:  Encrypted bytes to be decrypted with the Fernet object;
                the decrypted bytes are decoded to a string
    """
    plaintext = crypto.decrypt(encrypted)
    return plaintext.decode()

def read_creds()->dict:
    """Return the login credentials as {'username': ..., 'password': ...}.

    Saved credentials are decrypted by trying every key held by the SSH
    agent, since the key used when saving is not recorded. The user is
    prompted instead when the agent has no keys, when nothing has been saved
    yet, or when none of the keys can decrypt the saved file.

    Raises
    -----
    ValueError: the saved credentials file cannot be read back as the
                expected username/password mapping.
    """
    # get all keys as we don't know which key was used to encrypt
    keys = get_key(all=True)
    # No key found or no saved creds found, prompt username and password from user
    if not keys or not os.path.exists(path):
        return _prompt_creds()

    faulty = ("credentials read from ~/.ssh/iitb_login_creds are faulty. "
              "Please reset credentials with 'iitb reset' and try again")
    # The file does not depend on the key being tried, so read it once.
    # NOTE: pickle must only be used on trusted data; this file is expected
    # to be the one written by write_creds() in the user's own ~/.ssh.
    try:
        with open(path, 'rb') as f:
            creds = pickle.load(f)
    except (pickle.UnpicklingError, EOFError) as err:
        raise ValueError(faulty) from err
    # Explicit check instead of `assert`, which is skipped under `python -O`.
    if not isinstance(creds, dict) or set(creds) != {'username', 'password'}:
        raise ValueError(faulty)

    for key in keys:
        crypto = get_crypto(key)
        # Only right key will decrypt, else will raise InvalidToken
        try:
            return {'username': decrypt(crypto, creds['username']),
                    'password': decrypt(crypto, creds['password'])}
        except fernet.InvalidToken:
            continue

    # If all keys are exhausted with InvalidToken and no creds are returned
    print("Couldn't decrypt the saved credentials with current keys. "
          "Either reset credentials with 'iitb reset' or enter manually...")
    return _prompt_creds()


def _prompt_creds():
    """Ask for username and password on the terminal (password is not echoed)."""
    username = input('Enter username: ')
    password = getpass.getpass('Enter password')
    return {'username': username,
            'password': password}


def write_creds(crypto:fernet.Fernet, creds:dict)->None:
    """Encrypt the credentials and save them to the credentials file.

    When no credentials file exists yet the user is asked for confirmation
    first; declining leaves everything untouched. An existing file is
    overwritten without asking.

    Args
    -----
    crypto:     Fernet object initialized with ssh key based signature
                check `get_crypto()` method for initialization
    creds:      Mapping with plain-text 'username' and 'password' entries.
                It is not modified.
    """
    first_save = not os.path.exists(path)
    if first_save:
        print('This Script uses your ssh key to encrypt credentials')
        if not click.confirm('Save credentials: '):
            return

    # Build a new mapping so the caller's plain-text dict is left as it was.
    encrypted = {
        'username': encrypt(crypto, creds['username']),
        'password': encrypt(crypto, creds['password']),
    }

    directory = os.path.dirname(path)
    if directory:
        os.makedirs(directory, mode=0o700, exist_ok=True)
    # Create the file readable and writable by the owner only.
    fd = os.open(path, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
    with os.fdopen(fd, 'wb') as f:
        pickle.dump(encrypted, f)

    # Report success only once the file has actually been written.
    if first_save:
        print('Credentials written to file: ~/.ssh/iitb_login_creds and can only be decrypted with your ssh key')



if __name__ == '__main__':
    # Command line: a single optional positional command, 'status' by default.
    parser = ArgumentParser(
        description='Script for headless login and logout on all IITB servers')
    parser.add_argument('command', nargs='?', default='status',
            choices=['login', 'logout', 'status', 'reset'],
            help="\nyou can run: \n'iitb login'\t to login,\n'iitb logout'\t to logout,\
                'iitb status'\t to check login status\n'iitb reset'\t to reset the credentials")

    args = parser.parse_args()
    if args.command=='reset': # delete the creds and exit the script
        if os.path.exists(path):
            os.remove(path)
            print('Credentials Deleted....')
        else:
            print('No credentials to delete')
        # Nothing failed in either case, so report success to the shell;
        # a non-zero status would break chains such as `iitb reset && iitb login`.
        sys.exit(0)

    # Every other command is carried out by the spider; only errors are logged.
    crawler = CrawlerProcess(settings={'LOG_ENABLED': True, 'LOG_LEVEL':'ERROR'})
    crawler.crawl(IITB_Spider, command=args.command)
    crawler.start()

