#!python

import argparse
import os
import re
import socket
import ssl
import sys
from datetime import datetime, timezone
from pprint import pprint

import OpenSSL
from cryptography import x509
from cryptography.hazmat.backends import default_backend

# Program version; interpolated into the argparse description shown by --help.
version = '0.0.9'

def get_args():
    """Build the command-line parser and return the parsed arguments.

    The ``--help`` epilog lists usage examples; ``{me}`` in them is replaced
    by ``sys.argv[0]`` so the examples show the name the script was run as.
    """
    example_lines = [
        "Examples:  ",
        "  # just check remote certificate",
        "  {me} example.com",
        "",
        "  # check cert for example.com on new.example.com, do not verify",
        "  {me} new.example.com -n example.com -i",
        "",
        "  # dump info from local certificate file",
        "  {me} /etc/letsencrypt/live/example.com/fullchain.pem",
        "",
        "  # look for expiring letsencrypt certificates",
        "  for cert in /etc/letsencrypt/live/*/cert.pem; do {me} -q -w 7 $cert; done",
        "    ",
    ]
    usage_examples = "\n".join(example_lines).format(me=sys.argv[0])

    # RawTextHelpFormatter keeps the hand-made layout of the epilog intact.
    parser = argparse.ArgumentParser(
        description=f'Show local/remote SSL certificate info v{version}',
        formatter_class=argparse.RawTextHelpFormatter,
        epilog=usage_examples,
    )
    parser.add_argument(
        'CERT',
        help='/path/cert.pem or google.com or google.com:443',
    )
    parser.add_argument(
        '-n', '--name',
        help='name for SNI (if not same as CERT host)',
    )
    parser.add_argument(
        '-i', '--insecure',
        action='store_true', default=False,
        help='Do not verify remote certificate',
    )
    parser.add_argument(
        '-q', '--quiet',
        action='store_true', default=False,
        help='Print only warning/problems',
    )
    # -w alone means 20 days (const); omitting -w leaves the check disabled (None).
    parser.add_argument(
        '-w', '--warn',
        metavar='DAYS', type=int, nargs='?', const=20, default=None,
        help='Warn about expiring certificates (def: 20 days)',
    )
    parsed = parser.parse_args()
    return parsed

def get_certificate(host, name=None, port=443, timeout=10, insecure=False):
    """Connect to ``host:port`` over TLS and return the peer certificate as PEM.

    :param host: host name or address to connect to.
    :param name: server name sent via SNI (and checked against the certificate
        when verifying); defaults to ``host``.
    :param port: TCP port.
    :param timeout: seconds allowed for the TCP connect, the TLS handshake and
        each subsequent socket operation.
    :param insecure: when True, do not verify the certificate chain or the
        host name (the certificate is still returned).
    :return: the server certificate as a PEM string.
    :raises OSError: on connection failure or timeout; ``ssl.SSLError``
        (including ``ssl.SSLCertVerificationError``) on TLS failures.
    """
    name = name or host
    context = ssl.create_default_context()
    if insecure:
        # check_hostname must be disabled before verify_mode can be CERT_NONE.
        context.check_hostname = False
        context.verify_mode = ssl.CERT_NONE

    # Passing the timeout to create_connection makes it cover the connect and,
    # because the SSL socket inherits it, the handshake too. The nested `with`
    # closes the raw socket even if wrap_socket (the handshake) raises.
    with socket.create_connection((host, port), timeout=timeout) as conn:
        with context.wrap_socket(conn, server_hostname=name) as sock:
            # binary_form=True returns the DER bytes even when not verified.
            der_cert = sock.getpeercert(True)
    return ssl.DER_cert_to_PEM_cert(der_cert)

def is_local(cert):
    """Guess whether ``cert`` names a local file rather than a remote address.

    Anything that exists on the filesystem and is not a directory counts as
    local: regular files, but also pipes or devices such as ``/dev/stdin``.
    A directory can never be read as a certificate, so a directory that merely
    shares the host's name (e.g. ``./example.com``) is treated as a remote
    address instead of being opened and failing.

    :param cert: the CERT argument from the command line.
    :return: True if it should be loaded from disk, False otherwise.
    """
    return os.path.exists(cert) and not os.path.isdir(cert)


def get_remote_cert(CERT, name=None, insecure=False):
    """Fetch and parse the certificate served at a remote address.

    :param CERT: ``host``, ``host:port`` or a URL such as
        ``https://host:port/path``; the port defaults to 443.
    :param name: name to send via SNI; defaults to the parsed host.
    :param insecure: skip certificate verification when True.
    :return: an ``OpenSSL.crypto.X509`` object, or None if CERT cannot be
        parsed (an error message is printed to stderr in that case).
    """
    # A single regex handles all accepted forms. The optional scheme prefix
    # is skipped, so "https://host:8443" no longer trips over its extra ':'.
    m = re.search(r'(https?://)?(?P<host>[^/:]+):?(?P<port>\d+)?', CERT)
    if m is None:
        print("Can not parse {}".format(CERT), file=sys.stderr)
        return None
    host = m.group('host')
    port = int(m.group('port') or '443')

    # Derive the SNI name from the *parsed* host. Taking it from the raw string
    # would send "https" as the SNI name for "https://example.com".
    name = name or host

    certificate = get_certificate(host, name=name, port=port, insecure=insecure)
    return OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, certificate)

def get_local_cert(CERT):
    """Load a PEM-encoded certificate from a local file.

    :param CERT: path of the certificate file.
    :return: an ``OpenSSL.crypto.X509`` object.
    :raises OSError: if the file cannot be opened or read.
    """
    # Read as bytes inside `with`. The handle is always closed, and the content
    # is passed through unchanged regardless of the locale's text encoding.
    with open(CERT, 'rb') as f:
        data = f.read()
    return OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, data)


def show_cert(CERT, name=None, insecure=False, quiet=False, warn=False):
    """Print a summary of a local or remote certificate and check its expiry.

    :param CERT: path to a certificate file, or a remote address
        (``host``, ``host:port`` or URL).
    :param name: SNI name for remote certificates (defaults to the host).
    :param insecure: skip verification of remote certificates.
    :param quiet: when True, print only warnings/problems.
    :param warn: report the certificate if fewer than this many days remain.
        None disables the check. The default False compares as 0, so only
        already-expired certificates are reported.
    :return: exit status: 0 when OK, 1 on error or expiry warning.
    """

    def tlist2str(tlist):
        # Render (key, value) byte pairs as "key=value key=value ...".
        return ' '.join('{}={}'.format(k.decode(), v.decode()) for k, v in tlist)

    def tlist2value(tlist, key):
        # First value for `key` in a list of (key, value) byte pairs, or None.
        for k, v in tlist:
            if k.decode() == key:
                return v.decode()
        return None

    def get_SAN(cert):
        # subjectAltName values with their type prefix ("DNS:", ...) removed.
        extensions = (cert.get_extension(i) for i in range(cert.get_extension_count()))
        extension_data = {e.get_short_name().decode(): str(e) for e in extensions}
        try:
            san = extension_data['subjectAltName']
        except KeyError:
            return []  # No subjectAltName
        # Split only on the first ':' so values that themselves contain colons
        # (e.g. IPv6 addresses, URIs) are kept intact.
        return [entry.strip().partition(':')[2] for entry in san.split(',')]

    try:
        if is_local(CERT):
            cert = get_local_cert(CERT)
        else:
            cert = get_remote_cert(CERT, name, insecure)
    except ssl.SSLCertVerificationError as e:
        print("{CERT} Certificate verification error (use -i): {exception}".format(CERT=CERT, exception=e),
            file=sys.stderr)
        return 1
    except (OSError, OpenSSL.crypto.Error) as e:
        # Connection failures, timeouts, other TLS errors, unreadable files or
        # files that are not a PEM certificate: report instead of a traceback.
        print("{CERT} error: {exception}".format(CERT=CERT, exception=e),
            file=sys.stderr)
        return 1

    if cert is None:
        # get_remote_cert could not parse the address and has printed why.
        return 1

    # List the CN first, without duplicating it. Some certificates have no CN,
    # and inserting None would break the join below.
    subject = tlist2value(cert.get_subject().get_components(), 'CN')
    names = get_SAN(cert)
    if subject is not None:
        if subject in names:
            names.remove(subject)
        names.insert(0, subject)

    # Certificate times are UTC (note the 'Z'), so compare them with the
    # current UTC time. Values stay naive so the printed format is unchanged.
    time_format = '%Y%m%d%H%M%SZ'
    nbefore = datetime.strptime(cert.get_notBefore().decode(), time_format)
    nafter = datetime.strptime(cert.get_notAfter().decode(), time_format)
    now = datetime.now(timezone.utc).replace(tzinfo=None)
    daysold = (now - nbefore).days
    daysleft = (nafter - now).days
    issuer = tlist2str(cert.get_issuer().get_components())

    if not quiet:
        print("Names:", ' '.join(names))
        print("notBefore: {nbefore} ({days} days old)".format(nbefore=nbefore, days=daysold))
        print("notAfter: {nafter} ({days} days left)".format(nafter=nafter, days = daysleft))
        print("Issuer:", issuer)

    if warn is not None and daysleft<warn:
        print("{CERT} expires in {left} days".format(CERT=CERT, left=daysleft),
            file=sys.stderr)
        return 1

    return 0


def main():
    """Command-line entry point: parse arguments, show the certificate, exit.

    The process exit status is whatever ``show_cert`` returns.
    """
    options = get_args()
    status = show_cert(
        CERT=options.CERT,
        name=options.name,
        insecure=options.insecure,
        warn=options.warn,
        quiet=options.quiet,
    )
    sys.exit(status)


if __name__ == '__main__':
    # Run the CLI only when executed as a script, not when imported.
    main()
    