#!/usr/local/opt/python/bin/python3.7

import hashlib

import boto3

from sneks.config import load_stack_info, SneksParser

def parse_args():
    """Parse the command-line options for attaching a custom domain to an API.

    Returns the namespace produced by ``SneksParser.parse_args()`` with:
      - ``domain``: the domain name (required).
      - ``rest_api``: APIGateway API ID, ``None`` when not given.
      - ``api_stage``: API stage name, default ``"Prod"``.
      - ``api_path``: base path mapping, default ``"/"``.
      - ``sans``: comma-separated Subject Alternative Names, default ``""``.
    """
    parser = SneksParser()
    parser.add_argument("--domain-name",
                        help="Domain name that you own.  If not registered in Route53, it will be, but you'll need to manually update the DNS settings at your registrar.",
                        dest="domain",
                        required=True)
    parser.add_argument("--rest-api",
                        help="ID of an APIGateway API that you would like this to be associated with.  Default will be the one created by your repo's SAM stack.",
                        dest="rest_api",
                        default=None)
    parser.add_argument("--api-stage",
                        help="Name of the API stage that you would like this to be associated with.  Default is 'Prod'",
                        dest="api_stage",
                        default="Prod")
    parser.add_argument("--api-path",
                        help="Base path mapping for the API stage.  Default is '/'.  If this is not changed, you'll only be able to have that one API stage using this domain name.",
                        dest="api_path",
                        default="/")
    parser.add_argument("--sans",
                        help="Comma-separated list of Subject Alternative Names to attach to the certificate.",
                        dest="sans",
                        default="")
    return parser.parse_args()

def _strip_scheme(name):
    """Return *name* without surrounding whitespace, a leading ``https://``
    or ``http://`` scheme, or trailing slashes."""
    name = name.strip()
    for prefix in ("https://", "http://"):
        if name.startswith(prefix):
            name = name[len(prefix):]
    return name.rstrip("/")


def _parse_sans(raw_sans, domain_name):
    """Split the comma-separated ``--sans`` value into a set of bare host names.

    Empty entries (from the default ``""`` or a stray comma) are dropped
    because ACM rejects empty names.  The primary domain is dropped as well,
    since ACM always includes it in a certificate's SAN list.
    """
    names = {_strip_scheme(san) for san in raw_sans.split(",")}
    names.discard("")
    names.discard(domain_name)
    return names


def _idempotency_token(*parts):
    """Derive a deterministic ACM ``IdempotencyToken`` from *parts*.

    ACM only accepts 1-32 word characters, while the raw values contain dots,
    slashes and commas.  They are therefore hashed, and the first 32 hex
    characters of the digest are used as the token.
    """
    raw = ",".join(str(part) for part in parts)
    return hashlib.sha256(raw.encode("utf-8")).hexdigest()[:32]


def _find_matching_certificate(acm, domain_name, sans):
    """Return the ``Certificate`` detail dict of an existing ACM certificate
    for *domain_name* whose SANs are exactly ``{domain_name} | sans``.

    Returns ``None`` when no such certificate exists.  Every page of
    ``list_certificates`` is scanned.
    """
    wanted = {domain_name} | set(sans)
    paginator = acm.get_paginator("list_certificates")
    pages = paginator.paginate(
        CertificateStatuses=["PENDING_VALIDATION", "ISSUED", "INACTIVE"])
    for page in pages:
        for summary in page["CertificateSummaryList"]:
            if summary["DomainName"] != domain_name:
                continue
            # describe_certificate wraps the details under "Certificate".
            detail = acm.describe_certificate(
                CertificateArn=summary["CertificateArn"])["Certificate"]
            if set(detail.get("SubjectAlternativeNames", [])) == wanted:
                return detail
    return None


def main():
    """Find or request an ACM certificate for the API's custom domain.

    When ``--rest-api`` is not given, the API ID is taken from the
    ``RestApi`` output in ``load_stack_info()``.  Scheme prefixes are
    stripped from the domain and from each SAN.  An existing certificate
    with the same primary domain and the same SAN set is reused.
    Otherwise a new DNS-validated certificate is requested.
    """
    args = parse_args()

    domain_name = _strip_scheme(args.domain)
    rest_api = args.rest_api
    api_stage = args.api_stage
    api_path = args.api_path

    if not rest_api:
        # Fall back to the API created by this repo's SAM stack.
        stack_info = load_stack_info()
        rest_api = stack_info["deploy_stack_outputs"]["RestApi"]

    # route53 and apigw are not used yet; presumably they are meant for the
    # unfinished DNS-validation / base-path-mapping steps.
    route53 = boto3.client("route53")
    acm = boto3.client("acm")
    apigw = boto3.client("apigateway")

    sans = _parse_sans(args.sans, domain_name)

    acm_cert = _find_matching_certificate(acm, domain_name, sans)
    if acm_cert is None:
        request = {
            "DomainName": domain_name,
            "ValidationMethod": "DNS",
            "IdempotencyToken": _idempotency_token(
                domain_name, rest_api, api_stage, api_path, *sorted(sans)),
            "Options": {"CertificateTransparencyLoggingPreference": "ENABLED"},
        }
        if sans:
            # ACM rejects an empty SubjectAlternativeNames list, so only
            # send it when there are extra names.
            request["SubjectAlternativeNames"] = sorted(sans)
        cert_arn = acm.request_certificate(**request)["CertificateArn"]
        acm_cert = acm.describe_certificate(
            CertificateArn=cert_arn)["Certificate"]

    # TODO: if acm_cert["Status"] == "PENDING_VALIDATION", the DNS
    # validation records still need to be created.

if __name__ == "__main__":
    # Run the workflow only when executed as a script, not when imported.
    main()
