#!python
"""
WAFLABS CLI Tool
"""

from __future__ import print_function

import os
import sys
import time
import json
import argparse
import requests

from datetime import datetime
from yaspin import yaspin
from waflabs import waflabsapi

def main():
    """
    Main function for WAFLABS CLI Tool.

    Parses the command line, reads API credentials from the
    ``WAFLABS_ACCESS_ID`` / ``WAFLABS_ACCESS_KEY`` environment variables and
    runs the selected sub-command (``test``, ``results``, ``domain`` or
    ``report``). API responses are printed to stdout as JSON.

    Exit status: 0 on success; 1 for missing credentials, an unreadable
    definition file, an API error message or a ``ValueError``; 2 for a
    command-line usage error.
    """
    ### Parse CLI arguments
    parser = argparse.ArgumentParser(description="WAFLABS CLI")
    subparsers = parser.add_subparsers(title="commands", dest="command",
                                       help='show command help: waflabs <command> --help')

    # test command
    parser_test = subparsers.add_parser("test", help="test my WAF!")
    url_arg = parser_test.add_argument("--url",
                                       help="target url for running tests",
                                       type=str)
    block_code_arg = parser_test.add_argument("--block-code",
                                              help="expected block code (HTTP status code)",
                                              type=int)
    waf_arg = parser_test.add_argument("--waf",
                                       help="the WAF being tested",
                                       type=str)
    parser_test.add_argument("--scope",
                             help="the scope of the test (fingerprint webcve attack bypass)",
                             default=[],
                             nargs="+")
    parser_test.add_argument("--definition",
                             help="path to JSON file with test definition",
                             type=str)
    parser_test.add_argument("--poll",
                             help="poll for test results",
                             default=False,
                             action="store_true")

    # results command
    parser_results = subparsers.add_parser('results', help='retrieve test results')
    parser_results.add_argument("--history",
                                help="retrieve your test history",
                                default=False,
                                action="store_true")
    parser_results.add_argument("--history-all",
                                help="retrieve all (company) test history",
                                default=False,
                                action="store_true")
    results_run_id_arg = parser_results.add_argument("--run-id",
                                                     help="the run_id of a test",
                                                     type=str)

    # domain command
    parser_domain = subparsers.add_parser('domain', help='domain ownership operations')
    claim_arg = parser_domain.add_argument("--claim",
                                           help="domain to claim ownership of",
                                           type=str)
    parser_domain.add_argument("--list",
                               help="returns list of domain ownership records",
                               default=False,
                               action="store_true")
    parser_domain.add_argument("--validate",
                               help="validate ownership of a domain",
                               type=str)
    strategy_arg = parser_domain.add_argument("--strategy",
                                              help="validation strategy",
                                              choices=['http', 'dns'])

    # report command
    parser_report = subparsers.add_parser('report', help='generates a report')
    report_run_id_arg = parser_report.add_argument("--run-id",
                                                   help="the run_id of a test",
                                                   type=str)

    # Used to show the right help text when a usage error is raised below.
    command_parsers = {
        "test": parser_test,
        "results": parser_results,
        "domain": parser_domain,
        "report": parser_report,
    }

    args = parser.parse_args()

    # Python 3's argparse does not require a sub-command; without one there
    # is nothing to do, so show usage instead of silently exiting.
    if args.command is None:
        parser.print_help(sys.stderr)
        sys.exit(2)

    ### Get API keys
    try:
        access_id = os.environ["WAFLABS_ACCESS_ID"]
        access_key = os.environ["WAFLABS_ACCESS_KEY"]
    except KeyError as error:
        print("Environment variable not set {}".format(str(error)), file=sys.stderr)
        sys.exit(1)

    ### Execute
    try:
        waflabs = waflabsapi.WAFLabsApi(access_id=access_id, access_key=access_key)

        if args.command == "test":
            if args.definition is not None:
                # A JSON definition file replaces the individual CLI options.
                if not os.path.isfile(args.definition):
                    print("File not found {}".format(args.definition), file=sys.stderr)
                    sys.exit(1)
                with open(args.definition) as definition:
                    payload = json.load(definition)
            else:
                if args.url is None:
                    raise argparse.ArgumentError(url_arg, "You must provide a test URL.")

                if args.block_code is None:
                    raise argparse.ArgumentError(block_code_arg, "You must provide a test block code.")

                if args.waf is None:
                    raise argparse.ArgumentError(waf_arg, "You must provide the type of WAF being tested.")

                payload = {
                    "url": args.url,
                    "block_code": args.block_code,
                    "waf": args.waf,
                }

                if args.scope:
                    payload["scope"] = args.scope

            response = waflabs.testmywaf(payload)

            # A "message" key means no run was started (there is no run_id to
            # report or poll), so treat it as a failure.
            # NOTE(review): assumes "message" only appears on API errors — confirm.
            if "message" in response:
                print(json.dumps(response), file=sys.stderr)
                sys.exit(1)

            if args.poll:
                print(json.dumps(poll_results(waflabs, response["run_id"])))
            else:
                print(json.dumps(response))

        elif args.command == "results":
            if args.history:
                results = waflabs.history()
            elif args.history_all:
                results = waflabs.company_history()
            else:
                if args.run_id is None:
                    raise argparse.ArgumentError(results_run_id_arg, "You must provide a valid run-id.")
                results = waflabs.results(args.run_id)
            print(json.dumps(results))

        elif args.command == "domain":
            if args.list:
                results = waflabs.domain_list()
            elif args.validate:
                if args.strategy is None:
                    raise argparse.ArgumentError(strategy_arg, "You must provide a domain validation strategy.")
                results = waflabs.domain_validate({"domain": args.validate, "type": args.strategy})
            else:
                if args.claim is None:
                    raise argparse.ArgumentError(claim_arg, "You must provide a domain to claim.")
                results = waflabs.domain({"domain": args.claim})
            print(json.dumps(results))

        elif args.command == "report":
            if args.run_id is None:
                raise argparse.ArgumentError(report_run_id_arg, "You must provide a valid run-id.")

            report(waflabs.results(args.run_id))

    except argparse.ArgumentError as error:
        # Usage error: explain it, show the sub-command's help and exit with
        # argparse's conventional usage-error status.
        print("{}\n".format(error), file=sys.stderr)
        command_parsers[args.command].print_help(sys.stderr)
        sys.exit(2)

    except ValueError as error:
        # Includes malformed JSON definition files and report() refusing an
        # incomplete run.
        print(error, file=sys.stderr)
        sys.exit(1)

@yaspin(text="Polling...")
def poll_results(waflabs, run_id, interval=10, timeout=None):
    """
    Poll the WAFLABS API until a test run reports ``"status": "Complete"``.

    A terminal spinner is shown while waiting.

    :param waflabs: API client; ``waflabs.results(run_id)`` is called each poll.
    :param run_id: identifier of the test run to poll.
    :param interval: seconds to sleep between polls (default 10).
    :param timeout: maximum seconds to wait before giving up. ``None`` (the
        default) waits indefinitely.
    :return: the final results dict returned by ``waflabs.results``.
    :raises RuntimeError: if *timeout* elapses before the run completes.
    :raises KeyError: if a results response has no ``"status"`` key.
    """
    deadline = None if timeout is None else time.time() + timeout
    results = waflabs.results(run_id)

    while results["status"] != "Complete":
        # Bound the wait so a run that never completes cannot hang the CLI.
        if deadline is not None and time.time() >= deadline:
            raise RuntimeError(
                "Timed out after {} seconds waiting for run {} (last status: {})".format(
                    timeout, run_id, results["status"]))
        time.sleep(interval)
        results = waflabs.results(run_id)

    return results

# HTTP timeout (seconds) for the report template download and CVE lookups.
_HTTP_TIMEOUT = 30

# Background colour for each letter grade in the scorecard.
_GRADE_COLORS = {
    "A": "#DDFFDD",
    "B": "#EEFFFF",
    "C": "#FFFFEE",
    "D": "#FFEEEE",
    "F": "#FFCCCC",
}

# Colour for out-of-scope categories and for grades missing from _GRADE_COLORS.
_NEUTRAL_COLOR = "#bfbfbf"

_NOT_IN_SCOPE = "Not in scope."

_TABLE_OPEN = ('<div class="table-responsive">'
               '<table class="table table-striped table-hover">'
               '<thead><tr>{}</tr></thead><tbody>')
_TABLE_CLOSE = '</tbody></table></div>'

# Scorecard categories, in the order their placeholders are filled.
_SCORE_CATEGORIES = ("fingerprint", "webcve", "attack", "bypass")

# (heading, key) pairs rendered, in order, under "Attack > Payloads".
_ATTACK_PAYLOAD_SECTIONS = (
    ("XSS", "xss"),
    ("SQLi", "sqli"),
    ("Command Injection", "commandinject"),
    ("Code Injection", "codeinject"),
    ("Traversal", "traversal"),
    ("Response Splitting", "responsesplit"),
)

# (heading, key, column title, entry field) for the three-column attack tables.
_ATTACK_LIST_SECTIONS = (
    ("Methods", "methods", "Method", "method"),
    ("Locations", "locations", "Location", "location"),
    ("Quick", "quick", "Location", "location"),
)

# (heading, key, payload cell template) rendered, in order, under "Bypass".
_BYPASS_SECTIONS = (
    ("XSS", "xss", '<td><input type="text" value="{}"></td>'),
    ("SQLi", "sqli", '<td><code>{}</code></td>'),
    ("Content-Length", "content-length", '<td>{}</td>'),
)


def _esc(value):
    """Return *value* as text with HTML special characters (including quotes) escaped.

    Report data includes live attack payloads and third-party text. Inserted
    raw, an XSS payload would execute in the browser of whoever opens the report.
    """
    try:
        from html import escape
    except ImportError:  # Python 2
        from cgi import escape
    return escape("{}".format(value), quote=True)


def _status_cell(status_code):
    """Render the non-wrapping status-code cell shared by every results table."""
    return '<td style="white-space:nowrap">{}</td>'.format(_esc(status_code))


def _html_table(headers, rows):
    """Render a responsive, striped table.

    :param headers: column titles (plain text).
    :param rows: iterable of rows, each a sequence of already-rendered ``<td>`` strings.
    :return: the table HTML as one string.
    """
    head = "".join("<th>{}</th>".format(_esc(title)) for title in headers)
    body = "".join("<tr>{}</tr>".format("".join(cells)) for cells in rows)
    return _TABLE_OPEN.format(head) + body + _TABLE_CLOSE


def _fill_grade(page, prefix, entry):
    """Fill the PREFIX_GRADE, PREFIX_COLOR and PREFIX_SCORE placeholders in *page*.

    *entry* is a dict with ``grade`` and ``score`` keys (the score is
    formatted as a whole percentage). ``None`` means the category is out of
    scope; it renders as N/A, the neutral colour and 0%.
    """
    if entry is None:
        grade, color, score = "N/A", _NEUTRAL_COLOR, 0
    else:
        grade, score = entry['grade'], entry['score']
        color = _GRADE_COLORS.get(grade, _NEUTRAL_COLOR)
    page = page.replace(prefix + "_GRADE", _esc(grade))
    page = page.replace(prefix + "_COLOR", color)
    return page.replace(prefix + "_SCORE", "{0:.0%}".format(score))


def _fingerprint_section(fingerprint):
    """Describe the detected WAF and any alternative detections."""
    output = "The WAF was detected as {}.".format(_esc(fingerprint['DetectedAs']))
    also = fingerprint['AlsoDetectedAs']
    if also:
        output += " It could also be {}.".format(_esc(",".join(also)))
    return output


def _lookup_cve(cve_id):
    """Return ``(cvss, cwe, summary)`` for *cve_id* from cve.circl.lu.

    This is best-effort: network errors, non-JSON bodies and empty answers
    yield blank fields instead of aborting the whole report.
    """
    try:
        response = requests.get("https://cve.circl.lu/api/cve/{}".format(cve_id),
                                timeout=_HTTP_TIMEOUT)
        cveinfo = response.json()
    except (requests.RequestException, ValueError):
        cveinfo = None
    if not isinstance(cveinfo, dict):
        return "", "", ""
    return cveinfo.get('cvss', ""), cveinfo.get('cwe', ""), cveinfo.get('summary', "")


def _webcve_section(webcves):
    """Render the Web CVE table, with one CVE lookup per entry."""
    rows = []
    for webcve in webcves:
        cve_id = _esc(webcve['cve'])
        cvss, cwe, summary = _lookup_cve(webcve['cve'])
        rows.append([
            '<td style="white-space:nowrap"><a href="https://nvd.nist.gov/vuln/detail/{0}" '
            'target="_blank">{0}</a></td>'.format(cve_id),
            '<td style="white-space:nowrap">{}</td>'.format(_esc(cvss)),
            '<td style="white-space:nowrap">{}</td>'.format(_esc(cwe)),
            '<td><small>{}</small></td>'.format(_esc(summary)),
        ])
    return _html_table(("Web CVE", "CVSS", "CWE", "Summary"), rows)


def _attack_section(attacks):
    """Render the Attack tables: payload categories, then methods, locations and quick."""
    parts = []
    if "payloads" in attacks:
        payloads = attacks['payloads']
        for heading, key in _ATTACK_PAYLOAD_SECTIONS:
            rows = [
                [_status_cell(attack['status_code']),
                 '<td><small>{}</small></td>'.format(_esc(attack['payload']))]
                for payload in payloads
                for attack in payload.get(key, ())
            ]
            parts.append("Attack &gt; Payloads &gt; " + heading)
            parts.append(_html_table(("Status Code", "Payload"), rows))

    for heading, key, column, field in _ATTACK_LIST_SECTIONS:
        if key in attacks:
            rows = [
                [_status_cell(entry['status_code']),
                 '<td>{}</td>'.format(_esc(entry[field])),
                 '<td>{}</td>'.format(_esc(entry['attack']))]
                for entry in attacks[key]
            ]
            parts.append("Attack &gt; " + heading)
            parts.append(_html_table(("Status Code", column, "Attack"), rows))
    return "".join(parts)


def _bypass_section(bypasses):
    """Render the Bypass tables (XSS, SQLi, Content-Length)."""
    parts = []
    for heading, key, cell in _BYPASS_SECTIONS:
        rows = [[_status_cell(bypass['status_code']), cell.format(_esc(bypass['payload']))]
                for bypass in bypasses.get(key, ())]
        parts.append("Bypass &gt; " + heading)
        parts.append(_html_table(("Status Code", "Payload"), rows))
    return "".join(parts)


def report(results, destination=None, report_type="scorecard"):
    """
    Render an HTML report for a completed test run and print it to stdout.

    The page template is downloaded from https://waflabs.com/report.html. Its
    upper-case placeholders (TIMESTAMP, URL, ..., BYPASS_RESULTS) are then
    replaced with values from *results*. All run data and third-party text is
    HTML-escaped.

    :param results: run results dict, as returned by ``WAFLabsApi.results``.
    :param destination: currently unused; kept for interface compatibility.
    :param report_type: currently unused; kept for interface compatibility.
    :raises ValueError: if the run's status is not ``"Complete"`` (a subclass
        of ``Exception``, so existing broad handlers still apply).
    :raises requests.HTTPError: if the template cannot be downloaded.
    """
    if results['status'] != "Complete":
        raise ValueError("Testing incomplete, cannot generate report.")

    template = requests.get("https://waflabs.com/report.html", timeout=_HTTP_TIMEOUT)
    # Fail loudly rather than filling placeholders into an error page.
    template.raise_for_status()

    scope = results['scope']
    scores = results['results']['scores']
    details = results['results'].get('details', {})

    # Header
    page = template.text.replace(
        "TIMESTAMP", datetime.fromtimestamp(results['timestamp']).strftime("%m/%d/%Y"))
    page = page.replace("URL", _esc(results['url']))
    page = page.replace("BLOCK_CODE", _esc(results['block_code']))
    page = page.replace("WAF_NAME", _esc(results['waf']))
    page = page.replace("SCOPE", _esc(",".join(scope)))
    page = _fill_grade(page, "OVERALL", scores['overall'])

    # Scorecard
    for category in _SCORE_CATEGORIES:
        page = _fill_grade(page, category.upper(),
                           scores[category] if category in scope else None)

    # Detail sections (each detail key is only read when its category is in scope)
    page = page.replace(
        "FINGERPRINT_RESULTS",
        _fingerprint_section(details['fingerprint']) if "fingerprint" in scope else _NOT_IN_SCOPE)
    page = page.replace(
        "WEBCVE_RESULTS",
        _webcve_section(details['webcve']) if "webcve" in scope else _NOT_IN_SCOPE)
    page = page.replace(
        "ATTACK_RESULTS",
        _attack_section(details['attacks']) if "attack" in scope else _NOT_IN_SCOPE)
    page = page.replace(
        "BYPASS_RESULTS",
        _bypass_section(details['bypass']) if "bypass" in scope else _NOT_IN_SCOPE)

    # Write
    print(page)


if __name__ == "__main__":
    # Script entry point. main() returns None, so sys.exit(None) exits with
    # status 0 unless main() has already called sys.exit itself.
    sys.exit(main())
