#!python

from cmath import phase
import os
import sys
import argparse
import re
import ssl
import socket
import OpenSSL
import glob
from datetime import datetime
from pprint import pprint
from cryptography import x509
from cryptography.hazmat.backends import default_backend
from collections import namedtuple



# Tool version, shown in the --help description.
version = '0.0.14'

# One step of a scripted STARTTLS dialogue (see conversation()): text to send
# (None = send nothing), text the reply must contain, and an optional further
# string the reply must also contain.
phrase = namedtuple('Phrase', ['say', 'wait', 'expect'])

class CertException(Exception):
    """Root of this tool's own errors; main() catches it and reports the message."""

class InvalidCertificate(CertException):
    """A certificate was loaded but some part of it (e.g. subjectAltName) could not be parsed."""

class ServerError(CertException):
    """The server's reply during the plain-text STARTTLS dialogue was not the expected one."""

def get_args(argv=None):
    """Build the command-line parser and parse the arguments.

    argv: list of argument strings to parse; None (the default) makes argparse
        read sys.argv[1:], exactly as before.

    Returns the argparse.Namespace with CERT, name, insecure, quiet, warn and
    starttls attributes. --starttls is lower-cased and restricted to the
    supported methods, so a typo is reported by argparse up front instead of
    failing later with a KeyError.
    """
    epilog = """Examples:  
  # just check remote certificate
  {me} example.com

  # check cert for example.com on new.example.com, do not verify
  {me} new.example.com -n example.com -i

  # dump info from local certificate file(s)
  {me} *.pem

  # look for expiring letsencrypt certificates (:le is alias for /etc/letsencrypt/live/*/cert.pem)
  {me} :le -q -w
    """.format(me=sys.argv[0])
    parser = argparse.ArgumentParser(
        description='Show local/remote SSL certificate info v{version}'.format(version=version),
        formatter_class=argparse.RawTextHelpFormatter, epilog=epilog)
    parser.add_argument('CERT', nargs='+', help='/path/cert.pem, glob pattern, :le, google.com or google.com:443')
    parser.add_argument('-n', '--name', help='name for SNI (if not same as CERT host)')
    parser.add_argument('-i', '--insecure', default=False, action='store_true', help='Do not verify remote certificate')
    parser.add_argument('-q', '--quiet', default=False, action='store_true', help='Print only warning/problems')
    # nargs='?' + const: "-w" alone means 20 days, "-w N" means N days, absent means None (no check)
    parser.add_argument('-w', '--warn', default=None, metavar='DAYS', nargs='?', type=int, const=20, help='Warn about expiring certificates (def: 20 days)')
    # type is applied before the choices check, so "IMAP" is accepted as "imap"
    parser.add_argument('-t', '--starttls', default='auto', metavar='METHOD',
                        type=str.lower, choices=('auto', 'no', 'imap', 'smtp', 'pop3'),
                        help='starttls method: auto (default, and OK almost always), no, imap, smtp, pop3')
    return parser.parse_args(argv)


def conversation(s, script):
    """Run a scripted plain-text exchange with a server over socket *s*.

    Each step of *script* is a ``phrase(say, wait, expect)``:
      * say    -- text sent to the server first (nothing is sent when None);
      * wait   -- text that must appear in the reply; data is read until it
                  shows up or the peer closes the connection (a single read
                  when None);
      * expect -- text that must also be present in the collected reply.

    Raises ServerError when ``wait`` or ``expect`` is missing from the reply.
    """
    for ph in script:
        if ph.say is not None:
            s.sendall(ph.say.encode())

        data = b''
        while True:
            chunk = s.recv(2048)
            if not chunk:
                # Peer closed the connection: check whatever was received.
                break
            data += chunk
            # One recv() may hold only part of a reply, so keep reading
            # until the awaited text has arrived.
            if ph.wait is None or ph.wait.encode() in data:
                break

        # Banners are not guaranteed to be UTF-8; never crash while decoding.
        reply = data.decode('utf8', errors='replace')
        if ph.wait is not None and ph.wait not in reply:
            raise ServerError('Not found {!r} in server reply {!r} to {!r}'.format(ph.wait, reply, ph.say))
        if ph.expect is not None and ph.expect not in reply:
            raise ServerError('Not found {!r} in server reply {!r} to {!r}'.format(ph.expect, reply, ph.say))


def starttls_imap(s):
    """Switch an IMAP connection on socket *s* to TLS via STARTTLS.

    Reads the greeting, requires STARTTLS in the CAPABILITY reply, then sends
    STARTTLS and requires a tagged OK. IMAP (RFC 3501) mandates CRLF line
    endings; bare LF is rejected by hardened servers.
    Raises ServerError (via conversation) when a reply is not as expected.
    """
    script = (
        phrase(None, '\n', None),
        phrase('a1 CAPABILITY\r\n', '\n', 'STARTTLS'),
        phrase('a2 STARTTLS\r\n', '\n', 'a2 OK'),
    )
    conversation(s, script)

def starttls_smtp(s):
    """Switch an SMTP connection on socket *s* to TLS via STARTTLS.

    Reads the greeting, requires STARTTLS in the EHLO reply, then sends
    STARTTLS and requires the 220 "ready" reply. SMTP (RFC 5321) mandates
    CRLF line endings; many servers reject commands ending in a bare LF.
    Raises ServerError (via conversation) when a reply is not as expected.
    """
    script = (
        phrase(None, '\n', None),
        phrase('EHLO www-security.com\r\n', '\n', 'STARTTLS'),
        phrase('STARTTLS\r\n', '\n', '220'),
    )
    conversation(s, script)

def starttls_pop3(s):
    """Switch a POP3 connection on socket *s* to TLS via STLS.

    Reads the greeting, then sends STLS and requires a "+OK" reply, so a
    refusal ("-ERR") is reported instead of failing later in the TLS
    handshake. POP3 (RFC 1939) mandates CRLF line endings.
    Raises ServerError (via conversation) when a reply is not as expected.
    """
    script = (
        phrase(None, '\n', None),
        phrase('STLS\r\n', '\n', '+OK'),
    )
    conversation(s, script)


def start_tls(s, method, port):
    """Run the plain-text STARTTLS negotiation on socket *s* when needed.

    method: 'no' (never), 'auto' (choose by *port*: 25/587 SMTP, 110 POP3,
        143 IMAP; any other port needs no negotiation), or explicitly 'imap',
        'smtp' or 'pop3'. Matching is case-insensitive.
    port: integer port the socket is connected to (used only by 'auto').

    Returns None. Raises CertException for an unknown method, so the CLI
    reports it instead of dying with a KeyError traceback.
    """
    method = method.lower()

    if method == 'no':
        return None

    if method == 'auto':
        # 587 is SMTP submission, which only offers TLS via STARTTLS.
        method = {25: 'smtp', 587: 'smtp', 110: 'pop3', 143: 'imap'}.get(port)
        if method is None:
            # no special handling needed
            return None

    handlers = {
        'imap': starttls_imap,
        'smtp': starttls_smtp,
        'pop3': starttls_pop3,
    }
    try:
        handler = handlers[method]
    except KeyError:
        raise CertException(
            'Unknown starttls method {!r} (use auto, no, imap, smtp or pop3)'.format(method)) from None
    return handler(s)



def get_certificate(host, name=None, port=443, timeout=10, insecure=False, starttls='auto'):
    """Connect to host:port and return the server certificate as a PEM string.

    name: SNI / verification hostname (defaults to *host*).
    timeout: seconds allowed for each blocking socket operation.
    insecure: when True, the certificate chain and hostname are not verified.
    starttls: method passed to start_tls before the TLS handshake.

    Raises ssl.SSLCertVerificationError on verification failure, and other
    OSError subclasses (timeouts, DNS, refused connection) on network errors.
    """
    name = name or host
    if insecure:
        context = ssl._create_unverified_context()
        context.verify_mode = ssl.CERT_NONE
    else:
        context = ssl.create_default_context()

    # Apply the timeout from the start: connecting, the STARTTLS dialogue and
    # the TLS handshake (performed inside wrap_socket) can all block.
    # The outer `with` also closes the raw socket if start_tls() raises.
    with socket.create_connection((host, port), timeout=timeout) as conn:
        start_tls(conn, starttls, port)
        with context.wrap_socket(conn, server_hostname=name) as sock:
            der_cert = sock.getpeercert(True)
    return ssl.DER_cert_to_PEM_cert(der_cert)

def is_local(cert):
    """Return True when *cert* names an existing filesystem path, otherwise False."""
    return os.path.exists(cert)


def get_remote_cert(location, name=None, insecure=False, starttls='auto'):
    """Fetch the certificate served at *location* and parse it.

    location: ``host`` or ``host:port`` (port defaults to 443). A leading
        ``scheme://`` and anything from a ``/`` onwards are ignored, so
        ``https://example.com:8443/path`` also works.
    name: hostname for SNI/verification; defaults to the parsed host.
    insecure, starttls: passed through to get_certificate.

    Returns an ``OpenSSL.crypto.X509`` object.
    Raises CertException if no host can be parsed from *location*.
    """
    # Strip a URL scheme, otherwise e.g. "https" would be taken as the host.
    target = location.split('://', 1)[1] if '://' in location else location

    m = re.search(r'(?P<host>[^/:]+):?(?P<port>\d+)?', target)
    if m is None:
        raise CertException("Can not parse {}".format(location))
    host = m.group('host')
    port = int(m.group('port') or '443')

    # Derive the SNI name from the parsed host, never from the raw argument
    # (which may contain a port, path or scheme).
    name = name or host

    certificate = get_certificate(host, name=name, port=port, insecure=insecure, starttls=starttls)
    return OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, certificate)

def get_local_cert(CERT):
    """Load a PEM certificate from the file at path *CERT*.

    Returns an ``OpenSSL.crypto.X509`` object.
    """
    # Read raw bytes (independent of the locale encoding) and close the file
    # promptly instead of leaking the handle.
    with open(CERT, 'rb') as f:
        data = f.read()
    return OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, data)


def show_cert(CERT, name=None, insecure=False, quiet=False, warn=False, starttls='auto'):
    """Print information about one certificate and check its expiry.

    CERT is treated as a local PEM file when that path exists, otherwise as a
    remote ``host[:port]`` location; name, insecure and starttls are passed to
    get_remote_cert.

    Unless *quiet*, prints the names (subject CN first, then subjectAltName
    entries), notBefore/notAfter with day counts, and the issuer. When *warn*
    is not None and fewer than *warn* days are left, a warning goes to stderr.
    The default ``warn=False`` compares like 0, so only certificates that have
    already expired trigger it.

    Returns 0 when everything is fine, 1 on a verification failure, when no
    certificate was obtained, or on an expiry warning.
    Raises InvalidCertificate if subjectAltName cannot be parsed.
    """
    from datetime import timezone  # only `datetime` is imported at module level

    def tlist2str(tlist):
        # "K=V K=V ..." from (bytes, bytes) name components
        return ' '.join('{}={}'.format(k.decode(errors='replace'), v.decode(errors='replace')) for k, v in tlist)

    def tlist2value(tlist, key):
        # value of the first component named *key*, or None
        for k, v in tlist:
            if k.decode() == key:
                return v.decode(errors='replace')
        return None

    def get_SAN(cert):

        def safestr(x):
            # str() of an extension may fail; treat that as empty
            try:
                return str(x)
            except Exception:
                return ''

        extensions = (cert.get_extension(i) for i in range(cert.get_extension_count()))
        extension_data = {e.get_short_name().decode(): safestr(e) for e in extensions}

        try:
            san = extension_data['subjectAltName']
        except KeyError:
            return []  # No subjectAltName
        try:
            # Entries look like "DNS:a.example, IP Address:2001:DB8:0:0:0:0:0:1";
            # split on the first ':' only so IPv6 addresses are kept whole.
            return [entry.split(':', 1)[1] for entry in san.split(',')]
        except IndexError:
            raise InvalidCertificate('Unusual certificate, cannot parse SubjectAltName')

    try:
        if is_local(CERT):
            cert = get_local_cert(CERT)
        else:
            cert = get_remote_cert(location=CERT, name=name, insecure=insecure, starttls=starttls)
    except ssl.SSLCertVerificationError as e:
        print("{CERT} Certificate verification error (use -i): {exception}".format(CERT=CERT, exception=e),
            file=sys.stderr)
        return 1

    if cert is None:
        # No certificate obtained (e.g. a location that could not be parsed).
        return 1

    subject = tlist2value(cert.get_subject().get_components(), 'CN')
    names = get_SAN(cert)

    if subject in names:
        names.remove(subject)

    if subject:
        # add only if Subject exists (yes, not always)
        names.insert(0, subject)

    time_format = '%Y%m%d%H%M%SZ'
    nbefore = datetime.strptime(cert.get_notBefore().decode(), time_format)
    nafter = datetime.strptime(cert.get_notAfter().decode(), time_format)
    # Certificate times are UTC ('Z' suffix): compare them with the current
    # UTC time (kept naive to match), not local time, and use one "now".
    now = datetime.now(timezone.utc).replace(tzinfo=None)
    daysold = (now - nbefore).days
    daysleft = (nafter - now).days
    issuer = tlist2str(cert.get_issuer().get_components())

    if not quiet:
        print("Names:", ' '.join(names))
        print("notBefore: {nbefore} ({days} days old)".format(nbefore=nbefore, days=daysold))
        print("notAfter: {nafter} ({days} days left)".format(nafter=nafter, days=daysleft))
        print("Issuer:", issuer)

    if warn is not None and daysleft < warn:
        print("{CERT} expires in {left} days".format(CERT=CERT, left=daysleft),
            file=sys.stderr)
        return 1

    return 0


def main():
    """Command-line entry point: report on every CERT argument, then exit.

    ':le' expands to every /etc/letsencrypt/live/*/cert.pem. The exit status
    is the highest value returned by show_cert; a certificate that cannot be
    fetched, loaded or parsed, or a ':le' that matches nothing, also yields 1.
    """
    args = get_args()
    maxrc = 0

    # Expand ':le' in place so any other CERT arguments are still checked.
    certs = []
    for item in args.CERT:
        if item != ':le':
            certs.append(item)
            continue
        found = sorted(glob.glob('/etc/letsencrypt/live/*/cert.pem'))
        if not found:
            # glob hides errors such as missing permissions; don't report
            # success when nothing was actually checked.
            print("{}: no certificates found in /etc/letsencrypt/live".format(item), file=sys.stderr)
            maxrc = 1
        certs.extend(found)

    for cert in certs:
        try:
            rc = show_cert(CERT=cert, name=args.name, insecure=args.insecure, warn=args.warn, quiet=args.quiet, starttls=args.starttls)
        except (CertException, OSError, OpenSSL.crypto.Error) as e:
            # Network/file/parse failure: report it, count it as a problem,
            # and continue with the remaining certificates.
            print("{}: {}".format(cert, e), file=sys.stderr)
            rc = 1
        maxrc = max(maxrc, rc)
    sys.exit(maxrc)


if __name__ == '__main__':  # run the CLI only when executed as a script, not on import
    main()
    