#!python

import argparse
import os
import re
import socket
import ssl
import subprocess
import sys
from datetime import datetime, timezone
from pprint import pprint

import OpenSSL
from cryptography import x509
from cryptography.hazmat.backends import default_backend

def get_args():
    """Parse the command line.

    Returns an argparse.Namespace with attributes:
      CERT     -- certificate file path or remote address (positional)
      name     -- optional SNI name (-n/--name), None when not given
      insecure -- True when -i/--insecure was passed
    """
    usage_examples = """Examples:  
  # just check remote certificate
  {me} example.com

  # check cert for example.com on new.example.com, do not verify
  {me} new.example.com -n example.com -i

  # dump info from local certificate file
  {me} /etc/letsencrypt/live/example.com/fullchain.pem
    """
    parser = argparse.ArgumentParser(
        description='Show local/remote SSL certificate info',
        formatter_class=argparse.RawTextHelpFormatter,
        epilog=usage_examples.format(me=sys.argv[0]),
    )
    parser.add_argument('CERT', help='/path/cert.pem or google.com or google.com:443')
    parser.add_argument('-n', '--name', help='name for SNI (if not same as CERT host)')
    parser.add_argument(
        '-i', '--insecure',
        action='store_true',
        default=False,
        help='Do not verify remote certificate',
    )
    parsed = parser.parse_args()
    return parsed

def get_certificate(host, name=None, port=443, timeout=10, insecure=False):
    """Fetch the certificate presented by a TLS server, as a PEM string.

    host: address to connect to.
    name: SNI / hostname to verify against; defaults to host.
    port: TCP port to connect to.
    timeout: seconds allowed for each blocking socket operation, including
        the TCP connect and the TLS handshake.
    insecure: if True, skip certificate chain and hostname verification.

    Raises ssl.SSLCertVerificationError when verification fails, and other
    OSError subclasses (e.g. connection refused, socket timeout) on network
    errors.
    """
    name = name or host
    if insecure:
        # Unverified context: no chain validation and no hostname check.
        context = ssl._create_unverified_context()
        context.verify_mode = ssl.CERT_NONE
    else:
        context = ssl.create_default_context()

    # The timeout must be set BEFORE connecting: wrap_socket performs the TLS
    # handshake immediately and inherits the socket's timeout, so setting it
    # afterwards (as before) left connect + handshake able to block forever.
    # The `with` blocks close the sockets even if the handshake fails.
    with socket.create_connection((host, port), timeout=timeout) as conn:
        with context.wrap_socket(conn, server_hostname=name) as sock:
            der_cert = sock.getpeercert(binary_form=True)
    return ssl.DER_cert_to_PEM_cert(der_cert)

def is_local(cert):
    """Guess whether `cert` names a local file rather than a remote address.

    Anything starting with 'https://' is treated as remote; otherwise it is
    local exactly when the path exists on disk.
    """
    looks_like_url = cert.startswith('https://')
    return (not looks_like_url) and os.path.exists(cert)


def get_remote_cert(CERT, name=None, insecure=False):
    """Fetch and parse the certificate served at a remote address.

    CERT: "host", "host:port" or "https://host[:port][/path]"; the port
        defaults to 443.
    name: SNI name to send; defaults to the host parsed from CERT.
    insecure: if True, do not verify the certificate.

    Returns an OpenSSL.crypto.X509, or None (after printing a message) when
    no host can be parsed from CERT. Errors raised by get_certificate
    (verification failures, network errors) propagate to the caller.
    """
    # Parse every accepted form with one regex. Splitting on ':' first (as
    # before) raised ValueError for "https://..." because the scheme itself
    # contains a colon. Raw string avoids invalid-escape warnings.
    m = re.search(r'(https?://)?(?P<host>[^/:]+):?(?P<port>\d+)?', CERT)
    if m is None:
        print(f"Can not parse {CERT}")
        return None

    host = m.group('host')
    port = int(m.group('port') or '443')
    name = name or host

    certificate = get_certificate(host, name=name, port=port, insecure=insecure)
    return OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, certificate)

def get_local_cert(CERT):
    """Load an X.509 certificate from the file at path CERT.

    PEM encoding is tried first; if that fails, the same bytes are parsed
    as binary DER. Raises OSError if the file cannot be read and
    OpenSSL.crypto.Error if the contents are neither PEM nor DER.
    """
    # Read raw bytes (load_certificate accepts bytes) inside `with` so the
    # handle is closed promptly and binary files don't trip text decoding.
    with open(CERT, 'rb') as f:
        data = f.read()
    try:
        return OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, data)
    except OpenSSL.crypto.Error:
        # Not PEM: fall back to DER (e.g. .der / .cer files).
        return OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_ASN1, data)


def show_cert(CERT, addr, insecure):
    """Print names, validity period and issuer of a certificate.

    CERT: path to a local certificate file, or a remote address handled by
        get_remote_cert.
    addr: SNI name for remote lookups (None: use the host from CERT); not
        used for local files.
    insecure: skip verification of a remote certificate.

    Errors while obtaining or parsing the certificate are reported on stdout
    and the function returns without printing certificate details.
    """

    def tlist2str(tlist):
        # X509Name components are (bytes, bytes) pairs -> "K1=V1 K2=V2 ...".
        return ' '.join(f'{key.decode()}={value.decode()}' for key, value in tlist)

    def tlist2value(tlist, key):
        # Value of the first component named `key`, or None if absent.
        for k, v in tlist:
            if k.decode() == key:
                return v.decode()
        return None

    def get_SAN(cert):
        # subjectAltName values without their type prefix ("DNS:",
        # "IP Address:", ...); [] when the extension is absent.
        extensions = (cert.get_extension(i) for i in range(cert.get_extension_count()))
        extension_data = {e.get_short_name().decode(): str(e) for e in extensions}
        san = extension_data.get('subjectAltName')
        if san is None:
            return []
        # Split each entry on the FIRST ':' only, so values that themselves
        # contain ':' (IPv6 addresses, URIs) are kept whole.
        return [entry.split(':', 1)[-1].strip() for entry in san.split(',')]

    try:
        if is_local(CERT):
            cert = get_local_cert(CERT)
        else:
            cert = get_remote_cert(CERT, addr, insecure)
    except ssl.SSLCertVerificationError as e:
        # Must precede `except OSError`: it is an OSError subclass.
        print("Certificate verification error (use -i):", e)
        return
    except OSError as e:
        # Connection refused, DNS failure, timeout, unreadable file, ...
        print(f"Can not get certificate {CERT}:", e)
        return
    except OpenSSL.crypto.Error as e:
        print(f"Can not parse certificate {CERT}:", e)
        return
    if cert is None:
        # get_remote_cert returns None after printing why CERT was unparsable.
        return

    subject = tlist2value(cert.get_subject().get_components(), 'CN')
    names = get_SAN(cert)
    # List the CN first (once). Certificates without a CN just list their
    # SANs; inserting None would break the join below.
    if subject is not None:
        if subject in names:
            names.remove(subject)
        names.insert(0, subject)

    asn1_time_fmt = '%Y%m%d%H%M%SZ'
    nbefore = datetime.strptime(cert.get_notBefore().decode(), asn1_time_fmt)
    nafter = datetime.strptime(cert.get_notAfter().decode(), asn1_time_fmt)
    # The trailing 'Z' means the validity times are UTC; compare them with a
    # naive UTC "now" so day counts are not skewed by the local timezone.
    now = datetime.now(timezone.utc).replace(tzinfo=None)
    daysold = (now - nbefore).days
    daysleft = (nafter - now).days
    issuer = tlist2str(cert.get_issuer().get_components())

    print("Names:", ' '.join(names))
    print("notBefore: {nbefore} ({days} days old)".format(nbefore=nbefore, days=daysold))
    print("notAfter: {nafter} ({days} days left)".format(nafter=nafter, days = daysleft))
    print("Issuer:", issuer)



def main():
    """Script entry point: parse the command line and print certificate info."""
    options = get_args()
    show_cert(options.CERT, options.name, options.insecure)


if __name__ == "__main__":
    # Run only when executed as a script, not when imported as a module.
    main()
    