#!/usr/bin/env python2

from __future__ import print_function

import argparse
import calendar
import grp
import os
import pwd
import re
import shutil
import stat
import sys
import time

import hvac
import OpenSSL
import requests
import urllib3

import vault_certificate_deploy
from vault_certificate_deploy import base

# Fallback values for settings that may also come from arguments or config
default_role_name = 'default'
default_secret_mount_point = 'cert'
default_config_file = '/etc/vault-certificate-deploy/config.cnf'

# Options that must exist in the configuration file, keyed by section name
config_map = dict(
    vault=["address", "verify_tls"],
    storage=["path"],
)

# Number of errors collected during the run; evaluated at the end of the script
errors_count = 0

# Suppress urllib3's InsecureRequestWarning for the whole process
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)

#
# Config Validation
#
def validate_configuration(parsed_config):
    """Check that the configuration has every required section and option.

    Each missing section or option is reported through ``base.perr``.  When
    at least one problem was found, ``base.eexit(1, ...)`` is called.

    ``parsed_config`` must expose the underlying parser as ``.parser``.
    """
    cfg = parsed_config.parser
    problems = 0

    # Required sections
    for name in ('vault', 'approle', 'storage'):
        if cfg.has_section(name):
            continue
        base.perr('No section %s in configuration file' % (name))
        problems += 1

    # Required options of every section listed in config_map
    for name, required in config_map.items():
        for opt in required:
            if cfg.has_option(name, opt):
                continue
            base.perr('No options %s in section %s' % (opt, name))
            problems += 1

    if problems:
        base.eexit(1, "Configuration errors")


#
# Prepare variables
#
def prepare_variables():
    """Resolve runtime settings from the command line and the config file.

    Reads the module-level ``args`` and ``config`` objects and publishes the
    results as module globals:

    * ``role_id`` / ``secret_id`` -- taken from the arguments, otherwise from
      the ``approle`` config section; ``base.eexit`` is called when neither
      source provides a value.
    * ``role_name`` -- ``--cert-role`` or ``default_role_name``.
    * ``verify`` -- False only when ``vault.verify_tls`` is exactly ``"no"``.
    * ``deploy_user`` / ``deploy_group`` and their numeric ids
      ``deploy_user_id`` / ``deploy_group_id`` -- read from the ``vault``
      section, falling back to root (id 0) when unset or unknown.
    * ``vault_mount_point`` -- argument, then ``vault.mount_point``, then
      ``default_secret_mount_point``.
    * ``cert_list_dict`` -- ``{mount_point: {cert_name: {'role': ...,
      'extra': ...}}}`` built from ``--cert-list`` and/or ``-n``.
    * ``storage_core_path`` / ``storage_paths`` -- the storage root and one
      sub-path per mount point present in ``cert_list_dict``.
    """
    global storage_core_path
    global storage_paths
    global role_id
    global secret_id
    global verify
    global vault_mount_point
    global cert_list_dict
    global deploy_user
    global deploy_group
    global deploy_user_id
    global deploy_group_id
    global role_name

    # Role ID
    if args.role_id:
        base.pdeb("Role id set from argument", args.debug)
        role_id = args.role_id
    else:
        base.pdeb("Role id set from config file", args.debug)
        try:
            role_id = config.parser.get('approle', 'role_id')
        except Exception:
            base.perr("Unable to determine role-id")
            base.eexit(1, "You have to provide role-id in configuration file or as argument")

    # Secret ID
    if args.secret_id:
        base.pdeb("Secret id set from argument", args.debug)
        secret_id = args.secret_id
    else:
        base.pdeb("Secret id set from config file", args.debug)
        try:
            secret_id = config.parser.get('approle', 'secret_id')
        except Exception:
            base.perr("Unable to determine secret-id")
            base.eexit(1, "You have to provide secret-id in configuration file or as argument")

    # Role Name
    if args.role_name:
        base.pdeb("Role name set from argument", args.debug)
        role_name = args.role_name
    else:
        role_name = default_role_name

    # TLS verification is only switched off by the literal value "no"
    if config.parser.get('vault', 'verify_tls') == "no":
        base.pdeb("TLS Verify disabled", args.debug)
        verify = False
    else:
        base.pdeb("TLS Verify enabled", args.debug)
        verify = True

    # Deploy User and Group (optional settings, root when not configured)
    base.pdeb("Deploy User set from config file", args.debug)
    try:
        deploy_user = config.parser.get('vault', 'deploy_user')
    except Exception:
        base.pdeb("Deploy User not find in config file, settings root", args.debug)
        deploy_user = 'root'

    base.pdeb("Deploy Group set from config file", args.debug)
    try:
        deploy_group = config.parser.get('vault', 'deploy_group')
    except Exception:
        base.pdeb("Deploy Group not find in config file, settings root", args.debug)
        deploy_group = 'root'

    # Translate the names into numeric ids, root (0) when unknown
    try:
        deploy_user_id = pwd.getpwnam(deploy_user).pw_uid
    except KeyError:
        base.pwrn("Unable to find user %s in the system, fallback to root" % (deploy_user))
        deploy_user_id = 0
    try:
        deploy_group_id = grp.getgrnam(deploy_group).gr_gid
    except KeyError:
        base.pwrn("Unable to find group %s in the system, fallback to root" % (deploy_group))
        deploy_group_id = 0

    # Vault mount point
    if args.mount_point:
        vault_mount_point = args.mount_point
    elif "mount_point" in config.parser.options('vault'):
        vault_mount_point = config.parser.get('vault', 'mount_point')
    else:
        vault_mount_point = default_secret_mount_point
    base.pdeb("Vault secret mount point: " + vault_mount_point, args.debug)

    # Collect certificate definitions from the list file and from -n
    cert_list_lines = []
    if args.cert_list:
        with open(args.cert_list, 'r') as fh:
            raw_lines = fh.read().split('\n')
        # Strip surrounding whitespace so that blank-looking or indented
        # lines do not turn into entries with empty names; drop comments
        stripped = [line.strip() for line in raw_lines]
        cert_list_lines = [x for x in stripped if x and not x.startswith('#')]

    if args.cert_name:
        cert_list_lines.append(args.cert_name)

    if not args.cert_name and not args.cert_list:
        base.eexit(1, "Unable to determine certificate to deploy. Use -n or --cert-list")
    base.pdeb("Merged cert_list: %s" % (str(cert_list_lines)), args.debug)

    # Line format: NAME [MOUNT_POINT [ROLE [EXTRA_OPTIONS]]], space separated.
    # Missing fields fall back to the global mount point / role / no extras.
    # NOTE(review): a fifth field (text after another space) is ignored.
    cert_list_dict = {}
    for line in cert_list_lines:
        fields = line.split(' ', 4)
        name = fields[0]
        mount = fields[1] if len(fields) > 1 else vault_mount_point
        role = fields[2] if len(fields) > 2 else role_name
        extra = fields[3] if len(fields) > 3 else False
        cert_list_dict.setdefault(mount, {})[name] = {'role': role, 'extra': extra}

    base.pdeb("Cert Dict: %s" % (str(cert_list_dict)), args.debug)

    # Storage Path
    storage_core_path = config.parser.get('storage', 'path')
    storage_paths = [storage_core_path + "/" + x for x in cert_list_dict.keys()]



#
# Clean certificates
#
def clean_certificates(storage, certificates):
    """Remove certificate directories that are no longer wanted.

    Walks ``storage`` and collects every directory that directly contains at
    least one file, is not located below a ``custom`` or ``global`` path
    component, and is not the ``certs``/``private`` directory of a
    certificate listed in ``certificates``.  The collected directories are
    then removed recursively.

    Args:
        storage: Root directory to scan.
        certificates: Mapping ``{ca_name: {cert_name: ...}}`` of the
            certificates that must be kept; only the keys are used.  When
            the mapping is empty nothing is removed.

    Raises:
        OSError: Propagated from ``shutil.rmtree`` when a removal fails.
    """
    if not certificates:
        return

    # Custom & global paths are excluded; global path is done by a
    # different script
    exclude = re.compile(r'.*/(custom|global)/.*')

    # One pattern per wanted certificate.  A directory is kept when ANY of
    # them matches, regardless of which CA it belongs to.  Both name parts
    # are escaped so that characters such as '.' are matched literally.
    wanted = [
        re.compile(".*/%s/(certs|private)/%s$" % (re.escape(ca_name), re.escape(cert)))
        for ca_name in certificates
        for cert in certificates[ca_name]
    ]

    dirs_to_delete = []
    for root, _dirnames, filenames in os.walk(storage):
        if not filenames or exclude.match(root):
            continue
        if any(pattern.match(root) for pattern in wanted):
            continue
        dirs_to_delete.append(root)

    if dirs_to_delete:
        base.pout('There are some old cert dirs to be deleted')
        for delete_dir in dirs_to_delete:
            if not os.path.isdir(delete_dir):
                # Already gone together with a parent removed earlier
                continue
            base.pout("Removing directory " + delete_dir)
            shutil.rmtree(delete_dir)

#
# Validate Certificates
#
def certificate_validate(cert_t):
    """Validate a certificate tuple received from Vault.

    Args:
        cert_t: Tuple ``(cert_name, ca_name, secret)`` where
            ``secret['data']['crt']`` holds the PEM certificate and
            ``secret['data']['key']`` the PEM private key.

    Returns:
        True when the certificate parses and the private key loads and
        passes its consistency check, False otherwise (the reason is
        reported through ``base.perr``).  A certificate that expires in
        less than four days only produces a warning.
    """
    pem_parts = cert_t[2]['data']

    # Check that the certificate part is a parsable PEM certificate
    try:
        x509 = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, pem_parts['crt'])
    except OpenSSL.crypto.Error as e:
        base.perr("Certificate %s:%s not valid format: %s" % (str(cert_t[0]), cert_t[1], str(e)))
        return False

    # Check Private key
    try:
        private_key = OpenSSL.crypto.load_privatekey(OpenSSL.crypto.FILETYPE_PEM, pem_parts['key'])
        private_key.check()
    except TypeError as e:
        base.perr("Private key in bad format for %s: %s" % (str(cert_t[0]), str(e)))
        return False
    except OpenSSL.crypto.Error as e:
        base.perr("Private key inconsistent for %s: %s" % (str(cert_t[0]), str(e)))
        return False

    # Check expiration.  notAfter carries a 'Z' (UTC) timestamp, so it has
    # to be converted with timegm(); mktime() would treat it as local time
    # and shift the result by the host's UTC offset.
    not_after = time.strptime(x509.get_notAfter().decode(), '%Y%m%d%H%M%SZ')
    seconds_expire = calendar.timegm(not_after) - time.time()
    if seconds_expire < 4 * 24 * 60 * 60:
        base.pwrn("Certificate %s is about to expire (%s seconds)" % (str(cert_t[0]), str(seconds_expire)))

    return True

#
# Check Certificate
#
def certificate_check(cert_file, cert_priv_file):
    """Check that a deployed certificate exists and is not close to expiry.

    Args:
        cert_file: Path of the PEM certificate file.
        cert_priv_file: Path of the private key file; only its existence is
            checked.

    Returns:
        False when either file is missing, the certificate cannot be parsed,
        or its remaining lifetime is below ``args.cert_min_ttl`` seconds;
        True otherwise.
    """
    if not os.path.isfile(cert_file):
        return False

    if not os.path.isfile(cert_priv_file):
        return False

    # Name used in messages: the file name without its extension, so the
    # function does not depend on a loop variable of the caller
    name = os.path.splitext(os.path.basename(cert_file))[0]

    # Read inside a context manager so the handle is closed on every path
    with open(cert_file, 'r') as f:
        pem = f.read()

    # Check that the file holds a parsable PEM certificate
    try:
        x509 = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, pem)
    except OpenSSL.crypto.Error as e:
        base.perr("Certificate %s not valid format: %s" % (str(name), str(e)))
        return False

    # Check certificate expiration.  notAfter carries a 'Z' (UTC) timestamp,
    # so timegm() is the matching conversion; mktime() assumes local time.
    not_after = time.strptime(x509.get_notAfter().decode(), '%Y%m%d%H%M%SZ')
    seconds_expire = calendar.timegm(not_after) - time.time()
    seconds_min_ttl = int(args.cert_min_ttl)

    if seconds_expire < seconds_min_ttl:
        base.pwrn("Certificate %s is about to expire (%s < %s seconds)" % (name, seconds_expire, str(seconds_min_ttl)))
        return False

    base.pdeb("Certificate %s is ok. TTL to expire is %s seconds" % (name, seconds_expire), args.debug)
    return True


#
# MAIN PROGRAM LOOP
#
def parse_extra_options(raw):
    """Parse ``KEY=VALUE`` pairs delimited by semicolons into a dict.

    Args:
        raw: Option string such as ``"alt_names=a.example.com;ttl=3600"``,
            or a false value when no extra options were given.

    Returns:
        Dict mapping each key to its value.  Empty segments (for example a
        trailing semicolon) are skipped and only the first ``=`` separates
        key from value, so values may contain ``=`` themselves.

    Raises:
        ValueError: When a non-empty segment contains no ``=``.
    """
    options = {}
    if not raw:
        return options
    for segment in raw.split(';'):
        segment = segment.strip()
        if not segment:
            continue
        if '=' not in segment:
            raise ValueError("Invalid extra option %r, expected KEY=VALUE" % (segment,))
        key, value = segment.split('=', 1)
        options[key] = value
    return options


def prepare_directory(path, mode, user_id, group_id, debug):
    """Make sure ``path`` exists with the given mode and ownership.

    The directory is created (including parents) when missing; owner, group
    and mode are then applied in every case.  On ``OSError`` the problem is
    reported through ``base.perr`` and ``base.eexit(1, ...)`` is called.
    """
    try:
        if not os.path.isdir(path):
            os.makedirs(path, mode)
        base.pdeb("CHOWN %s -> %d:%d" % (path, user_id, group_id), debug)
        os.chown(path, user_id, group_id)
        # Applied explicitly because makedirs() honours the process umask
        os.chmod(path, mode)
    except OSError as e:
        base.perr("Unable to create directory %s: %s" % (path, str(e)))
        base.eexit(1, "Error occured")


#
# MAIN PROGRAM LOOP
#
if __name__ == '__main__':
    # Parsing arguments
    parser = argparse.ArgumentParser(
        description='Certificate deploy script for HashiCorp Vault',
        epilog='Created by Robert Vojcik <robert@vojcik.net>')

    parser.add_argument(
        '-c',
        dest='config_file',
        default=default_config_file,
        help='configuration file (default: %s)' % (default_config_file))
    parser.add_argument(
        '--role-id',
        dest='role_id',
        default=False,
        help='Role id, as argument instead in config file')
    parser.add_argument(
        '--secret-id',
        dest='secret_id',
        default=False,
        help='Secret id, as argument instead in config file')
    parser.add_argument(
        '--vault-pki',
        dest='mount_point',
        default=False,
        help='Vault PKI secrets mount point')
    parser.add_argument(
        '--cert-ttl',
        dest='cert_ttl',
        default='7776000',
        help='Certificate TTL when issuing new certificate (seconds)')
    parser.add_argument(
        '--cert-role',
        dest='role_name',
        default='default',
        help='Certificate role in Vault')
    parser.add_argument(
        '--cert-min-ttl',
        dest='cert_min_ttl',
        default='1209600',
        help='Certificate minimum TTL when checking certificate (seconds)')
    parser.add_argument(
        '--cert-extra-options',
        dest='cert_extra_options',
        default=False,
        help='Certificate extra options as KEY1=VALUE delimited by semicolon. More info on: https://www.vaultproject.io/api/secret/pki/index.html#generate-certificate')
    parser.add_argument(
        '-n',
        dest='cert_name',
        default=False,
        help='Certificate name to deploy')
    parser.add_argument(
        '--cert-list',
        dest='cert_list',
        default=False,
        help='File containing list of certificates to deploy')
    parser.add_argument(
        '--ignore-ssl-check',
        dest='ignore_ssl_check',
        default=False,
        action='store_true',
        help='Skip certificate check')
    parser.add_argument(
        '--version',
        dest='version_print',
        default=False,
        action='store_true',
        help='Show version of the script')
    parser.add_argument(
        '-d',
        dest='debug',
        default=False,
        action='store_true',
        help='debug mode')

    args = parser.parse_args()

    if args.version_print:
        print(vault_certificate_deploy.__version__)
        sys.exit(0)

    # Config file parsing
    base.pdeb("Loading configuration from %s" % (args.config_file), args.debug)
    config = base.ConfigParse(args.config_file)
    validate_configuration(config)

    # Prepare variables from configuration and arguments
    prepare_variables()

    # Prepare the basic certs/private directories of every storage path
    for storage_path in storage_paths:
        for dir_path, dir_mode in [(storage_path + '/certs', 0o0755),
                                   (storage_path + '/private', 0o0750)]:
            prepare_directory(dir_path, dir_mode, deploy_user_id, deploy_group_id, args.debug)

    # Empty certificates are error state
    if not cert_list_dict:
        base.eexit(1, "There are no certificates to deploy.")

    # Vault Auth, with approle
    # NOTE(review): newer hvac releases replace auth_approle() with
    # auth.approle.login() -- confirm against the installed hvac version.
    vault = hvac.Client(
        url=config.parser.get('vault', 'address'),
        verify=verify)
    try:
        auth_token = vault.auth_approle(role_id, secret_id)

    except requests.ConnectTimeout as e:
        base.perr("Connection Timeout: %s" % (str(e)))
        base.eexit(1, "Connection Timeout")

    except requests.ConnectionError as e:
        base.perr("Connection Error: %s" % (str(e)))
        base.eexit(1, "Connection Error")

    except hvac.exceptions.InvalidRequest as e:
        base.eexit(1, "VAULT: %s" % (str(e)))

    # Debug output
    base.pdeb("Vault auth connection: " + str(auth_token), args.debug)

    # Go through certificates and issue or renew
    certificates = []
    for ca_name in cert_list_dict.keys():
        for cert_name in cert_list_dict[ca_name].keys():

            cert_options = cert_list_dict[ca_name][cert_name]['extra']
            cert_file = "%s/%s/certs/%s/%s.crt" % (storage_core_path, ca_name, cert_name, cert_name)
            cert_priv_file = "%s/%s/private/%s/%s.key" % (storage_core_path, ca_name, cert_name, cert_name)
            cert_role_name = cert_list_dict[ca_name][cert_name]['role']

            if certificate_check(cert_file, cert_priv_file):
                # Certificate is OK and Valid
                base.pdeb("Certificate %s is ok and valid" % (cert_name), args.debug)
                continue

            # Certificate does not exist or it's not OK

            # Extra options: per-certificate ones win over the command line
            extra_options = parse_extra_options(cert_options or args.cert_extra_options)

            # Merge basic parameters with extra options
            params = {'ttl': args.cert_ttl}
            params.update(extra_options)

            # Obtain the certificate
            try:
                response = vault.secrets.pki.generate_certificate(
                    mount_point=ca_name,
                    name=cert_role_name,
                    common_name=cert_name,
                    extra_params=params
                )
            except hvac.exceptions.InvalidRequest as e:
                base.perr("Unable to issue %s certificate against vault server. Get %s" % (cert_name, str(e)))
                # Without a response there is nothing to deploy; falling
                # through would reuse the previous certificate's response
                errors_count += 1
                continue

            if response.status_code == 200:
                issued = response.json()['data']
                secret = {
                    'data': {
                        'key': issued['private_key'],
                        'ica': issued['issuing_ca'],
                        'crt': issued['certificate']
                    }
                }
                certificates.append((cert_name, ca_name, secret))
            else:
                base.perr("Unable to issue %s certificate against vault server. Get %s" % (cert_name, response.text))
                errors_count += 1

    # Deploy certificates
    for certificate_t in certificates:
        cert_name = certificate_t[0]
        ca_name = certificate_t[1]
        certificate = certificate_t[2]
        cert_dir_cert = "%s/%s/certs/%s" % (storage_core_path, ca_name, cert_name)
        cert_dir_private = "%s/%s/private/%s" % (storage_core_path, ca_name, cert_name)

        # Check Certificate. Avoid deploying wrong certificates
        if args.ignore_ssl_check is False:
            test_result = certificate_validate(certificate_t)

            if test_result is not True:
                base.pwrn("Certificate %s not pass the checks, skipping" % (str(cert_name)))
                errors_count += 1
                # Skip all and move to next cert
                continue

        # Create the per-certificate directories
        for dir_path, dir_mode in [(cert_dir_cert, 0o0755), (cert_dir_private, 0o0750)]:
            prepare_directory(dir_path, dir_mode, deploy_user_id, deploy_group_id, args.debug)

        # Write one file per part: the key goes to the private directory,
        # everything else next to the certificate
        for key in certificate['data'].keys():
            if key == "key":
                file_path = cert_dir_private + "/" + cert_name + "." + key
                file_mode = 0o640
            else:
                file_path = cert_dir_cert + "/" + cert_name + "." + key
                file_mode = 0o644

            base.pdeb("Writing " + file_path, args.debug)
            # Create new files owner-only so the content is never readable
            # more widely than intended before the final mode is applied
            fd = os.open(file_path, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
            with os.fdopen(fd, 'w') as fh:
                fh.write(certificate['data'][key])

            # Change file permissions and owner
            os.chmod(file_path, file_mode)
            os.chown(file_path, deploy_user_id, deploy_group_id)

    # Clean unwanted certificates
    clean_certificates(storage=storage_core_path, certificates=cert_list_dict)

    # Report collected errors through the exit status
    if errors_count > 0:
        base.eexit(1, "There was %d errors during process" % (errors_count))
