#!/usr/bin/env python
# encoding: utf-8
from __future__ import unicode_literals, print_function
import sys
import argparse
import os
import subprocess
from collections import namedtuple
import datetime
import smtplib
import socket
from email.mime.text import MIMEText

# Exit codes of the script: main() returns one of them and the result is passed to sys.exit().
RETURN_CODE_OK = 0  # the certificate was read and is not close to its expiration date
RETURN_CODE_WRONG_PARAMS = 1  # bad input: certificate file not found or negative expiration interval
RETURN_CODE_WARNING = 2  # the certificate could not be read, has expired or is about to expire


def to_bytes(value):
    """
    Convert ``value`` to a byte string.

    Byte strings are returned unchanged, text is encoded as UTF-8 and any
    other object is converted with ``str()`` first.  Works on both
    Python 2 and Python 3.

    :param value: text, byte string or an arbitrary object
    :return: byte string (``str`` on Python 2, ``bytes`` on Python 3)
    """
    if isinstance(value, bytes):
        return value
    # type(u'') is ``unicode`` on Python 2 and ``str`` on Python 3, which
    # avoids referencing the Python 2 only ``unicode`` builtin.
    text_type = type(u'')
    if not isinstance(value, text_type):
        value = str(value)
        if isinstance(value, bytes):
            # Python 2: str() has already produced a byte string.
            return value
    return value.encode('utf-8')


def warning(message):
    """
    Write ``message`` followed by a line break to the standard error stream.

    :param message: text of the warning
    """
    line = message + '\n'
    stream = sys.stderr
    stream.write(line)


class EmailReporter(object):
    """
    Collects warning messages and mails them as one plain-text report.

    Messages are accumulated with :meth:`add_additional_info` and delivered
    by :meth:`send` through the SMTP server listening on ``localhost``.
    """

    def __init__(self):
        self._info = []  # pieces of the report body, in the order they were added

    def add_additional_info(self, content):
        """
        Append a piece of text to the body of the report.

        :param content: text to include in the e-mail body
        """
        self._info.append(content)

    @staticmethod
    def _current_user():
        """
        Return the name of the user running the script.

        ``os.getlogin()`` needs a controlling terminal and raises ``OSError``
        when there is none (for example when started from cron), so
        ``getpass.getuser()`` is used as a fallback in that case.
        """
        try:
            return os.getlogin()
        except OSError:
            import getpass  # local import: needed on this fallback path only
            return getpass.getuser()

    def send(self, to):
        """
        Send all collected messages as a single e-mail.

        The sender address is built from the current user name and the
        host name of the machine.

        :param to: recipient address
        :raise smtplib.SMTPException: if the message can not be delivered
        """
        subject = 'Warn! Certificate has expired or is about to expire'
        msg = MIMEText('\n'.join(self._info), _charset="utf-8")

        hostname = socket.gethostname()
        me = '{}@{}'.format(self._current_user(), hostname)

        # str(...) produces the native string type on both Python 2 (byte
        # string, as before) and Python 3 (text); bytes header names and
        # bytes.format() do not work on Python 3.
        msg[str('Subject')] = subject
        msg[str('From')] = me
        msg[str('To')] = to
        msg[str('list-id')] = str('certificates warnings <certificates-warnings.{}>').format(hostname)

        s = smtplib.SMTP(str('localhost'))  # TODO: additional smtp settings
        try:
            s.sendmail(me, [to], msg.as_string())
        finally:
            # Close the connection even when the delivery fails.
            s.quit()


def is_executable_found(bin):
    """
    Tell whether ``bin`` refers to an existing executable file.

    A value containing a directory part is checked as given; a bare name is
    looked up in every directory listed in the PATH environment variable.

    :param bin: path or bare name of the executable
    :return: True if an executable file was found, False otherwise
    """

    def _executable_at(candidate):
        if not os.path.isfile(candidate):
            return False
        return os.access(candidate, os.X_OK)

    directory, name = os.path.split(bin)
    if directory:
        return _executable_at(bin)

    search_dirs = os.environ.get('PATH', '').split(os.pathsep)
    # PATH entries may be wrapped in double quotes, drop them before joining
    return any(_executable_at(os.path.join(entry.strip('"'), name)) for entry in search_dirs)


# Outcome of call_process().
CallProcessResult = namedtuple(
    'CallProcessResult',
    (
        'bin_exists',   # False when the executable was not found
        'return_code',  # exit status of the process, None if it was not started
        'stdout',       # captured standard output, None if it was not started
        'stderr',       # captured standard error or the "not found" message
    ),
)


def call_process(bin, args=None, stdin=None):
    """
    Run an executable, wait for it to finish and capture its output.

    When the executable can not be found nothing is started and the result
    carries an explanation in ``stderr``.

    :param bin: path or bare name of the executable
    :param args: list of command-line arguments, None for no arguments
    :param stdin: data to feed to the standard input of the process, None to leave stdin untouched
    :return CallProcessResult
    """
    if not is_executable_found(bin):
        search_path = os.environ.get('PATH', '')
        message = '''not found "{}", check if it's installed and exists somewhere in PATH ({})'''.format(
            bin, search_path)
        return CallProcessResult(bin_exists=False, return_code=None, stdout=None, stderr=message)

    command = [bin] + (args or [])
    has_input = stdin is not None
    child = subprocess.Popen(command,
                             stdout=subprocess.PIPE,
                             stderr=subprocess.PIPE,
                             stdin=subprocess.PIPE if has_input else None)
    payload = to_bytes(stdin) if has_input else None
    out, err = child.communicate(input=payload)
    return CallProcessResult(bin_exists=True, return_code=child.returncode, stdout=out, stderr=err)


def openssl(args=None, stdin=None):
    """
    Run the ``openssl`` command-line tool.

    :param args: list of command-line arguments, None for no arguments
    :param stdin: data to feed to the standard input of the tool
    :return CallProcessResult
    """
    result = call_process('openssl', args, stdin)
    return result


def pkcs12(args, stdin=None):
    """
    Run ``openssl pkcs12``.

    :param args: list of arguments for the sub-command, None for no arguments
    :param stdin: data to feed to the standard input of the tool
    :return CallProcessResult
    """
    # The parentheses matter: a conditional expression binds weaker than "+",
    # so without them the sub-command itself is dropped when args is None.
    return openssl(args=['pkcs12'] + (args if args is not None else []), stdin=stdin)


def x509(args, stdin=None):
    """
    Run ``openssl x509``.

    :param args: list of arguments for the sub-command, None for no arguments
    :param stdin: data to feed to the standard input of the tool
    :return CallProcessResult
    """
    # The parentheses matter: a conditional expression binds weaker than "+",
    # so without them the sub-command itself is dropped when args is None.
    return openssl(args=['x509'] + (args if args is not None else []), stdin=stdin)


def s_client(args, stdin=b''):
    """
    Run ``openssl s_client``.

    :param args: list of arguments for the sub-command, None for no arguments
    :param stdin: data to feed to the standard input of the tool; empty by
                  default, so the tool gets a pipe with no input rather than
                  the stdin of this script
    :return CallProcessResult
    """
    # The parentheses matter: a conditional expression binds weaker than "+",
    # so without them the sub-command itself is dropped when args is None.
    return openssl(args=['s_client'] + (args if args is not None else []), stdin=stdin)


def parse_certificate_fields(content):
    """
    Split ``key=value`` lines into a dictionary.

    Blank lines are skipped, each remaining line is cut at the first ``=``;
    a line without ``=`` becomes a key with an empty value.  When a key
    occurs several times the last value wins.

    :param content: text made of ``key=value`` lines
    :return: dict mapping keys to values
    """
    non_blank = (row for row in content.split('\n') if row.strip())
    pairs = (row.partition('=') for row in non_blank)
    return {name: val for name, _, val in pairs}


def parse_date_field(value):
    """
    Parse a certificate date such as ``Aug 30 23:59:59 2016 GMT``.

    :param value: date in text form
    :return: naive ``datetime`` or None when the text does not match the expected format
    """
    date_format = '%b %d %H:%M:%S %Y %Z'
    try:
        parsed = datetime.datetime.strptime(value, date_format)
    except ValueError:
        parsed = None
    return parsed


def is_safe_expire_date(expire_date, expiration_interval, now=None):
    """
    Tell whether at least ``expiration_interval`` days remain until ``expire_date``.

    :param expire_date: naive ``datetime`` of the expiration
    :param expiration_interval: minimal number of days that has to remain
    :param now: naive ``datetime`` to measure from; defaults to the current
                UTC time (the example date format parsed by this script is GMT)
    :return: True if the remaining time is not shorter than the interval
    """
    if now is None:
        try:
            now = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
        except AttributeError:  # Python 2 has no datetime.timezone
            now = datetime.datetime.utcnow()
    remaining = expire_date - now
    return remaining >= datetime.timedelta(days=expiration_interval)


def parse_and_check_certificate_info(content, env, certificate_description, warning_func=warning):
    """
    Check the expiration date found in a certificate description.

    :param content: text made of ``key=value`` lines; the ``notAfter``,
                    ``notBefore`` and ``subject`` keys are used
    :param env: parsed arguments; ``expiration_interval``, ``now`` and ``info`` are read
    :param certificate_description: human readable origin of the certificate, used in messages
    :param warning_func: callable receiving the text of a warning
    :return: RETURN_CODE_OK if the certificate is valid long enough,
             RETURN_CODE_WARNING if it is not or if the date can not be read
    """
    info_fields = parse_certificate_fields(content)
    if 'notAfter' not in info_fields:
        warning_func('Unable to read expire date for {}'.format(certificate_description))
        return RETURN_CODE_WARNING

    raw_expire_date = info_fields['notAfter']
    expire_date = parse_date_field(raw_expire_date)
    if expire_date is None:
        # parse_date_field() returns None for a date it does not understand;
        # report it instead of failing on the date arithmetic below.
        warning_func('Unable to parse expire date "{}" for {}'.format(raw_expire_date, certificate_description))
        return RETURN_CODE_WARNING
    cert_subject = info_fields.get('subject', '<unknown>')

    if not is_safe_expire_date(expire_date, env.expiration_interval, now=env.now):
        warning_func(
            (
                'Certificate has expired or is about to expire!\n\n'
                'Certificate: {description}\n'
                'Certificate subject: {subj}\n'
                'Expire date: {expire}'
            ).format(description=certificate_description, subj=cert_subject, expire=raw_expire_date)
        )
        return RETURN_CODE_WARNING

    if env.info:
        # Informational output was requested for valid certificates.
        print(
            (
                'Certificate: {description}\n'
                'Certificate subject: {subj}\n'
                'Expire date: {expire}\n'
                'Valid after: {after}'
            ).format(description=certificate_description, subj=cert_subject, expire=raw_expire_date,
                     after=info_fields.get('notBefore', '<unknown>'))
        )
    return RETURN_CODE_OK


class P12Check(object):
    """
    ``p12`` command: checks the certificate stored in a PKCS #12 container.
    """

    @staticmethod
    def init_argparser(subparsers):
        """
        Register the ``p12`` sub-command and its options.

        :param subparsers: object returned by ``ArgumentParser.add_subparsers()``
        :return: parser created for the sub-command
        """
        # str(...) gives the native string type on both Python 2 and 3;
        # argparse on Python 3 can not match a bytes name against the command line.
        parser = subparsers.add_parser(str('p12'), help='check a certificate in a PKCS #12 container (.p12)')
        parser.add_argument('-c', '--cert', type=str, help='certificate file path', required=True, metavar='file.p12')
        parser.add_argument('-p', '--password', type=str, help='certificate password')
        return parser

    @staticmethod
    def _error_text(result):
        """Return ``result.stderr`` as text so that it can be embedded in a message."""
        stderr = result.stderr
        if isinstance(stderr, bytes):
            return stderr.decode('utf-8', 'replace')
        return stderr

    def __call__(self, env, warning_func=warning):
        """
        Export the container with ``openssl pkcs12``, read the certificate
        with ``openssl x509`` and check its expiration date.

        :param env: parsed arguments; ``cert`` and ``password`` are read here
        :param warning_func: callable receiving the text of a warning
        :return: one of the RETURN_CODE_* values
        """
        certificate_abs_path = os.path.abspath(env.cert)
        if not os.path.isfile(certificate_abs_path):
            warning_func('Certificate not found at {}'.format(certificate_abs_path))
            return RETURN_CODE_WRONG_PARAMS

        certificate_export_res = pkcs12([
            '-in', certificate_abs_path,
            '-nodes',  # don't encrypt private keys on exporting
            '-passin', 'pass:{}'.format(env.password if env.password is not None else '')
        ])
        if certificate_export_res.return_code != 0:
            warning_func('Unable to load certificate from {p}\n{e}'
                         .format(p=certificate_abs_path, e=self._error_text(certificate_export_res)))
            return RETURN_CODE_WARNING

        # The exported PEM data is piped into x509 to extract subject and dates.
        info_res = x509(['-subject', '-dates', '-noout'], stdin=certificate_export_res.stdout)
        if info_res.return_code != 0:
            warning_func('Unable to read certificate info for {p}\n{e}'
                         .format(p=certificate_abs_path, e=self._error_text(info_res)))
            return RETURN_CODE_WARNING

        return parse_and_check_certificate_info(info_res.stdout.decode('utf-8'), env,
                                                certificate_description='file {}'.format(certificate_abs_path),
                                                warning_func=warning_func)


class HttpsCheck(object):
    """
    ``https`` command: checks the certificate presented by a server on a SSL connection.
    """

    @staticmethod
    def init_argparser(subparsers):
        """
        Register the ``https`` sub-command and its options.

        :param subparsers: object returned by ``ArgumentParser.add_subparsers()``
        :return: parser created for the sub-command
        """
        # str(...) gives the native string type on both Python 2 and 3;
        # argparse on Python 3 can not match a bytes name against the command line.
        parser = subparsers.add_parser(str('https'),
                                       help='check a HTTPS certificate by receiving it from a SSL connection')
        parser.add_argument('-s', '--host', type=str, required=True, metavar='example.com')
        parser.add_argument('--sni', type=str, default=None, metavar='example.com',
                            help='request SNI (for multiple certificats on one host:port)')
        parser.add_argument('-p', '--port', type=int, default=443, metavar='443')
        return parser

    @staticmethod
    def _error_text(result):
        """Return ``result.stderr`` as text so that it can be embedded in a message."""
        stderr = result.stderr
        if isinstance(stderr, bytes):
            return stderr.decode('utf-8', 'replace')
        return stderr

    def __call__(self, env, warning_func=warning):
        """
        Connect with ``openssl s_client``, read the received certificate with
        ``openssl x509`` and check its expiration date.

        :param env: parsed arguments; ``host``, ``port`` and ``sni`` are read here
        :param warning_func: callable receiving the text of a warning
        :return: one of the RETURN_CODE_* values
        """
        host_fqn = '{}:{}'.format(env.host, env.port)
        certificate_description = 'host {}'.format(host_fqn)
        if env.sni is not None:
            certificate_description += ' (SNI {})'.format(env.sni)

        request_args = ['-connect', host_fqn]
        if env.sni is not None:
            request_args.extend(('-servername', env.sni))

        request_res = s_client(request_args)
        if request_res.return_code != 0:
            warning_func('Unable to read certificate from {d}\n{e}'.format(d=certificate_description,
                                                                           e=self._error_text(request_res)))
            return RETURN_CODE_WARNING

        # The s_client output is piped into x509 to extract subject and dates.
        info_res = x509(['-subject', '-dates', '-noout'], stdin=request_res.stdout)
        if info_res.return_code != 0:
            warning_func('Unable to read certificate info for {d}\n{e}'.format(d=certificate_description,
                                                                               e=self._error_text(info_res)))
            return RETURN_CODE_WARNING

        return parse_and_check_certificate_info(info_res.stdout.decode('utf-8'), env, certificate_description,
                                                warning_func)


class PemCheck(object):
    """
    ``pem`` command: checks a certificate stored in a PEM file.
    """

    @staticmethod
    def init_argparser(subparsers):
        """
        Register the ``pem`` sub-command and its options.

        :param subparsers: object returned by ``ArgumentParser.add_subparsers()``
        :return: parser created for the sub-command
        """
        # str(...) gives the native string type on both Python 2 and 3;
        # argparse on Python 3 can not match a bytes name against the command line.
        parser = subparsers.add_parser(str('pem'), help='check a PEM certificate')
        parser.add_argument('-c', '--cert', type=str, help='certificate file path', required=True, metavar='file.crt')
        return parser

    @staticmethod
    def _error_text(result):
        """Return ``result.stderr`` as text so that it can be embedded in a message."""
        stderr = result.stderr
        if isinstance(stderr, bytes):
            return stderr.decode('utf-8', 'replace')
        return stderr

    def __call__(self, env, warning_func=warning):
        """
        Read the certificate file with ``openssl x509`` and check its expiration date.

        :param env: parsed arguments; ``cert`` is read here
        :param warning_func: callable receiving the text of a warning
        :return: one of the RETURN_CODE_* values
        """
        certificate_abs_path = os.path.abspath(env.cert)
        if not os.path.isfile(certificate_abs_path):
            warning_func('Certificate not found at {}'.format(certificate_abs_path))
            return RETURN_CODE_WRONG_PARAMS

        info_res = x509(['-in', certificate_abs_path, '-subject', '-dates', '-noout'])
        if info_res.return_code != 0:
            warning_func('Unable to read certificate info for {d}\n{e}'.format(d=certificate_abs_path,
                                                                               e=self._error_text(info_res)))
            return RETURN_CODE_WARNING

        return parse_and_check_certificate_info(info_res.stdout.decode('utf-8'), env,
                                                certificate_description='file {}'.format(certificate_abs_path),
                                                warning_func=warning_func)


# Classes implementing the available sub-commands: each one registers its own
# argument parser through init_argparser() and performs the check when an
# instance of it is called.
_COMMANDS = [P12Check, HttpsCheck, PemCheck]


def _parse_aguments(args):
    """
    Parse the command line.

    :param args: full argument vector; the item at index 0 (program name) is skipped
    :return: tuple ``(env, selected_command)``: the parsed arguments and the
             class from _COMMANDS that implements the chosen sub-command
    """
    main_parser = argparse.ArgumentParser()
    main_parser.add_argument('-i', '--info', action='store_true',
                             help="print certificate info if it's valid")
    # NOTE: the default is never applied as long as the option is required.
    main_parser.add_argument('-x', '--expiration-interval', required=True, type=int, default=7, metavar='7',
                             help='expiration interval in days')
    main_parser.add_argument('-e', '--email', type=str, metavar='ops@example.com',
                             help='email for sending report if certificate has expired or is about to expire.')

    subparsers = main_parser.add_subparsers(title='commands', dest='command',
                                            help="use 'cmd -h' to see additional command specific help")
    # Python 3 treats sub-commands as optional unless told otherwise; without
    # this a missing command would surface as a KeyError in the lookup below.
    subparsers.required = True

    # The example certificate date handled by this script is GMT and is parsed
    # into a naive datetime, so "now" is a naive UTC datetime, not local time.
    try:
        now = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
    except AttributeError:  # Python 2 has no datetime.timezone
        now = datetime.datetime.utcnow()
    main_parser.set_defaults(now=now)

    command_to_class = {}
    for command in _COMMANDS:
        parser = command.init_argparser(subparsers)
        # prog of a sub-parser ends with the name of the sub-command
        command_to_class[parser.prog.split(' ')[-1]] = command
    env = main_parser.parse_args(args[1:])
    selected_command = command_to_class[env.command]
    return env, selected_command


def main(args):
    """
    Run the check selected on the command line.

    Warnings are written to stderr; when an e-mail address was given they are
    also collected and mailed if the check does not succeed.

    :param args: full argument vector, including the program name at index 0
    :return: one of the RETURN_CODE_* values, to be used as the exit status
    """
    env, selected_command = _parse_aguments(args)

    email_reporter = None
    warning_func = warning
    if env.email:
        email_reporter = EmailReporter()

        def warning_func(message):
            email_reporter.add_additional_info(message)
            warning(message)

    if env.expiration_interval < 0:
        warning_func("Expiration interval should be >= 0")
        return RETURN_CODE_WRONG_PARAMS
    handler = selected_command()
    code = handler(env, warning_func=warning_func)
    if code != RETURN_CODE_OK and email_reporter is not None:
        try:
            email_reporter.send(env.email)
        except (smtplib.SMTPException, EnvironmentError) as error:
            # A delivery problem must not hide the result of the check: an
            # uncaught exception would end the script with exit status 1,
            # which is the value of RETURN_CODE_WRONG_PARAMS.
            warning('Unable to send report to {}: {}'.format(env.email, error))
    return code

# Script entry point: the value returned by main() becomes the exit status.
if __name__ == '__main__':
    exit_code = main(sys.argv)
    sys.exit(exit_code)
