#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""URL bruteforcer to locate existing and/or hidden files or directories.."""

from __future__ import print_function

import os
import re
import sys
import itertools
import argparse
import requests
from requests.packages.urllib3.exceptions import InsecureRequestWarning

requests.packages.urllib3.disable_warnings(InsecureRequestWarning)


# -------------------------------------------------------------------------------------------------
# GLOBALS
# -------------------------------------------------------------------------------------------------

VERSION = "0.4.0"

DEFAULT_USERAGENT = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/79.0.3945.123 Safari/537.36"  # noqa: E501
DEFAULT_SLASH = "no"
SUPPORTED_SLASHES = {
    "no": [""],
    "yes": ["/"],
    "both": ["", "/"],
}
DEFAULT_METHOD = "GET"
SUPPORTED_METHODS = ["GET", "POST", "PUT", "DELETE", "PATCH", "HEAD", "OPTIONS"]
DEFAULT_CODES = [
    "2..",
    "3..",
    "403",
    "407",
    "411",
    "426",
    "429",
    "500",
    "505",
    "511",
]
DEFAULT_TIMEOUT = 5.0
DEFAULT_RETRIES = 3


# -------------------------------------------------------------------------------------------------
# HELPER FUNCTIONS
# -------------------------------------------------------------------------------------------------


def get_status_text(curr, total, method, target, curr_try, retries):
    """Build the yellow one-line progress text for the request currently in flight.

    Args:
        curr: Number of the current request.
        total: Total number of requests.
        method: HTTP method being used.
        target: URL being requested.
        curr_try: Number of the current attempt for this request.
        retries: Number of attempts allowed per request.

    Returns:
        The status text wrapped in ANSI colour start/reset sequences.
    """
    yellow = "\033[93m"
    reset = "\033[00m"
    progress = "({0}/{1}): ({2}/{3})".format(curr, total, curr_try, retries)
    return "{0}{1} [{2}] {3}{4}".format(yellow, progress, method, target, reset)


def print_status(status):
    """Write a temporary status line to stdout.

    The text ends with a carriage return instead of a newline, so the next output starts
    at the beginning of the same line.
    """
    out = sys.stdout
    out.write("{0}\r".format(status))
    out.flush()


def clear_status(status):
    """Delete a temporary status line by overwriting it with spaces.

    As many spaces as ``status`` has characters are written, followed by a carriage return.
    """
    blanks = " " * len(status)
    out = sys.stdout
    out.write(blanks + "\r")  # clear line
    out.flush()


def print_succ(data):
    """Print a success message in green."""
    green, reset = "\033[92m", "\033[00m"
    print("{0}{1}{2}".format(green, data, reset))


def print_err(data):
    """Print an error message in red."""
    red, reset = "\033[91m", "\033[00m"
    print("{0}{1}{2}".format(red, data, reset))


# -------------------------------------------------------------------------------------------------
# FILE FUNCTIONS
# -------------------------------------------------------------------------------------------------


def read_file(filepath):
    """Read a file line by line and return the lines as a list of words.

    Leading and trailing whitespace (including the line break) is stripped from every
    line; blank lines are kept as empty strings.
    """
    words = []
    with open(filepath) as handle:
        for line in handle:
            words.append(line.strip())
    return words


# -------------------------------------------------------------------------------------------------
# URL FUNCTIONS
# -------------------------------------------------------------------------------------------------


def get_session(auth, headers, proxies):
    """Create a ``requests`` session used for persistent connections.

    Args:
        auth: Authentication object to attach, or None to leave the session default.
        headers: Headers merged into the session's headers.
        proxies: Proxy mapping to attach, or None to leave the session default.

    Returns:
        The configured ``requests.Session``.
    """
    session = requests.Session()
    session.headers.update(headers)
    if proxies is not None:
        session.proxies = proxies
    if auth is not None:
        session.auth = auth
    return session


def session_request(s, url, method, cookies, headers, timeout, verify):
    """Send one request over an existing (persistent) session.

    Redirects are not followed and an empty form body is always passed.

    Returns:
        ``(True, response)`` on success, otherwise ``(False, info)`` where ``info`` is a
        dict with the keys ``"type"`` ("timeout", "toomanyredirects" or "exception") and
        ``"err"`` (the exception caught).
    """
    # s.(get|post|delete|...)
    send = getattr(s, method.lower())
    options = {
        "data": {},
        "allow_redirects": False,
        "cookies": cookies,
        "headers": headers,
        "timeout": timeout,
        "verify": verify,
    }
    try:
        response = send(url, **options)
    except requests.exceptions.RequestException as err:
        # Classify the failure; anything that is neither a timeout nor a redirect
        # loop is reported as a generic exception.
        if isinstance(err, requests.exceptions.Timeout):
            kind = "timeout"
        elif isinstance(err, requests.exceptions.TooManyRedirects):
            kind = "toomanyredirects"
        else:
            kind = "exception"
        return False, {"type": kind, "err": err}
    return True, response


def request(url, method, auth, cookies, headers, proxies, timeout, verify):
    """Send one request without a session (a new connection per call).

    Redirects are not followed and an empty form body is always passed.

    Returns:
        ``(True, response)`` on success, otherwise ``(False, info)`` where ``info`` is a
        dict with the keys ``"type"`` ("timeout", "toomanyredirects" or "exception") and
        ``"err"`` (the exception caught).
    """
    # requests.(get|post|delete|...)
    send = getattr(requests, method.lower())
    options = {
        "data": {},
        "allow_redirects": False,
        "auth": auth,
        "cookies": cookies,
        "headers": headers,
        "proxies": proxies,
        "timeout": timeout,
        "verify": verify,
    }
    try:
        response = send(url, **options)
    except requests.exceptions.RequestException as err:
        # Classify the failure; anything that is neither a timeout nor a redirect
        # loop is reported as a generic exception.
        if isinstance(err, requests.exceptions.Timeout):
            kind = "timeout"
        elif isinstance(err, requests.exceptions.TooManyRedirects):
            kind = "toomanyredirects"
        else:
            kind = "exception"
        return False, {"type": kind, "err": err}
    return True, response


def check_code(code, codes):
    """Tell whether an HTTP status code matches any of the given patterns.

    Args:
        code: Status code to check; it is converted to a string first.
        codes: Iterable of regular expressions, each matched from the start of the code.

    Returns:
        True if at least one pattern matches, otherwise False.
    """
    status = str(code)
    return any(re.match(pattern, status) for pattern in codes)


# -------------------------------------------------------------------------------------------------
# ARGS
# -------------------------------------------------------------------------------------------------


def _args_check_codes(value):
    """Check argument for valid status codes."""
    strval = str(value)
    code = strval.replace(".", "1")
    try:
        code = int(code)
    except ValueError:
        raise argparse.ArgumentTypeError('Invalid status code "%s"', strval)
    if code < 100 or code >= 600:
        raise argparse.ArgumentTypeError('Invalid status code "%s"', strval)
    return strval


def _args_check_auth(value):
    """Check argument for valid methods."""
    strval = str(value)
    auth = strval.split(":")
    if len(auth) != 2:
        raise argparse.ArgumentTypeError('Invalid auth value "%s"', strval)
    return strval


def _args_check_method(value):
    """Check argument for a supported HTTP method.

    Returns the value as a string if it is listed in SUPPORTED_METHODS, otherwise raises
    argparse.ArgumentTypeError.
    """
    method = str(value)
    if method in SUPPORTED_METHODS:
        return method
    supported = ", ".join(SUPPORTED_METHODS)
    raise argparse.ArgumentTypeError('Invalid method "%s". Supported: %s' % (method, supported))


def _args_check_slash(value):
    """Check argument for a supported slash value.

    Returns the value as a string if it is a key of SUPPORTED_SLASHES, otherwise raises
    argparse.ArgumentTypeError.
    """
    choice = str(value)
    if choice in SUPPORTED_SLASHES:
        return choice
    supported = ", ".join(SUPPORTED_SLASHES)
    raise argparse.ArgumentTypeError(
        'Invalid slash value "%s". Supported: %s' % (value, supported)
    )


def _args_check_header(value):
    """Check argument for valid header value."""
    strval = str(value)
    if ":" not in strval:
        raise argparse.ArgumentTypeError('Invalid header value "%s".' % (strval))
    return strval


def _args_check_proxy(value):
    """Check argument for valid proxy value."""
    strval = str(value)
    if not re.match("(http(s)?|socks5)://(.+:.+@)?.+:[0-9]+", strval):
        raise argparse.ArgumentTypeError('Invalid proxy value "%s".' % (strval))
    print(strval)
    return strval


def _args_check_cookie(value):
    """Check argument for valid cookie value."""
    strval = str(value)
    if "=" not in strval:
        raise argparse.ArgumentTypeError('Invalid cookie value "%s".' % (strval))
    return strval


def _args_check_file(value):
    """Check argument for existing file."""
    strval = str(value)
    if not os.path.isfile(strval):
        raise argparse.ArgumentTypeError('File "%s" not found.' % value)
    return strval


def get_args():
    """Define the command line interface and parse the command line.

    Arguments are organised in four groups (required, optional global, optional
    mutating and misc) plus the positional BASE_URL. Exactly one of -w/--word and
    -W/--wordlist must be given; --auth-basic and --auth-digest exclude each other.

    Returns:
        The ``argparse.Namespace`` produced by parsing ``sys.argv``.
    """
    parser = argparse.ArgumentParser(
        formatter_class=argparse.RawTextHelpFormatter,
        # The built-in help is disabled so -h/--help can be listed in the misc group below.
        add_help=False,
        usage="""%(prog)s [options] -w <str>/-W <file> BASE_URL
       %(prog)s -h, --help
       %(prog)s -v, --version
""",
        description="""URL bruteforcer to locate existing and/or hidden files or directories.

Similar to dirb or gobuster, but also allows to iterate over multiple HTTP request methods,
multiple useragents and multiple host header values.
""",
        epilog="""examples

  %(prog)s -W /path/to/words http://example.com/
  %(prog)s -W /path/to/words http://example.com:8000/
  %(prog)s -k -W /path/to/words https://example.com:10000/""",
    )
    required = parser.add_argument_group("required arguments")
    optional = parser.add_argument_group("optional global arguments")
    mutating = parser.add_argument_group(
        title="optional mutating arguments",
        description="The following arguments will increase the total number of requests to be"
        + " made by\napplying various mutations and testing each mutation on a separate request.",
    )
    misc = parser.add_argument_group("misc arguments")

    # Word source: either a single word or a wordlist file, never both.
    word = required.add_mutually_exclusive_group(required=True)
    word.add_argument(
        "-w", "--word", metavar="str", type=str, help="Word to use.",
    )
    word.add_argument(
        "-W", "--wordlist", metavar="f", type=_args_check_file, help="Path to wordlist to use.",
    )

    # Global options applied to every request.
    optional.add_argument(
        "-n",
        "--new",
        required=False,
        default=False,
        action="store_true",
        help="Use a new connection for every request.\n"
        + "If not specified persistent http connection will be used for all requests.\n"
        + "Note, using a new connection will decrease performance,\n"
        + "but ensure to have a clean state on every request.",
    )
    optional.add_argument(
        "-k",
        "--insecure",
        required=False,
        default=False,
        action="store_true",
        help="Do not verify TLS certificates.",
    )
    optional.add_argument(
        "--code",
        nargs="+",
        metavar="str",
        required=False,
        default=DEFAULT_CODES,
        type=_args_check_codes,
        help="HTTP status code to treat as success.\n"
        + "You can use a '.' (dot) as a wildcard.\n"
        + "Default: "
        + " ".join(DEFAULT_CODES),
    )
    optional.add_argument(
        "--header",
        nargs="+",
        metavar="h",
        default=[],
        type=_args_check_header,
        help="Custom http header string to add to all requests.\n"
        + "Note, multiple values are allowed for multiple headers.\n"
        + "Note, if duplicates are specified, the last one will overwrite.\n"
        + "See --mheader for mutations.\n"
        + "Format: <key>:<val> [<key>:<val>]",
    )
    optional.add_argument(
        "--cookie",
        nargs="+",
        metavar="c",
        default=[],
        type=_args_check_cookie,
        help="Cookie string to add to all requests.\n" + "Format: <key>=<val> [<key>=<val>]",
    )
    optional.add_argument(
        "--proxy",
        metavar="str",
        required=False,
        default=None,
        type=_args_check_proxy,
        help="Use a proxy for all requests.\n"
        + "Format: http://<host>:<port>\nFormat: http://<user>:<pass>@<host>:<port>\n"
        + "Format: https://<host>:<port>\nFormat: https://<user>:<pass>@<host>:<port>\n"
        + "Format: socks5://<host>:<port>\nFormat: socks5://<user>:<pass>@<host>:<port>",
    )

    # Authentication: basic and digest cannot be combined.
    auth = optional.add_mutually_exclusive_group(required=False)
    auth.add_argument(
        "--auth-basic",
        metavar="str",
        required=False,
        default=None,
        type=_args_check_auth,
        help="Use basic authentication for all requests.\n" + "Format: <user>:<pass>",
    )
    auth.add_argument(
        "--auth-digest",
        metavar="str",
        default=None,
        type=_args_check_auth,
        help="Use digest authentication for all requests.\n" + "Format: <user>:<pass>",
    )
    optional.add_argument(
        "--timeout",
        metavar="sec",
        required=False,
        default=DEFAULT_TIMEOUT,
        type=float,
        help="Connection timeout in seconds for each request.\nDefault: " + str(DEFAULT_TIMEOUT),
    )
    optional.add_argument(
        "--retry",
        metavar="num",
        required=False,
        default=DEFAULT_RETRIES,
        type=int,
        help="Connection retries per request.\nDefault: " + str(DEFAULT_RETRIES),
    )
    optional.add_argument(
        "--delay",
        metavar="sec",
        required=False,
        type=float,
        help="Delay between requests to not flood the server.",
    )
    optional.add_argument(
        "--output",
        metavar="file",
        required=False,
        type=str,
        help="Output file to write results to.",
    )

    # Mutating options: every additional value multiplies the number of requests.
    mutating.add_argument(
        "--method",
        nargs="+",
        metavar="m",
        required=False,
        default=[DEFAULT_METHOD],
        type=_args_check_method,
        help="List of HTTP methods to test each request against.\n"
        + "Note, each supplied method will double the number of requests.\n"
        + "Supported methods: "
        + " ".join(SUPPORTED_METHODS)
        + "\n"
        + "Default: "
        + DEFAULT_METHOD,
    )
    mutating.add_argument(
        "--mheader",
        nargs="+",
        metavar="h",
        default=[],
        type=_args_check_header,
        help="Custom http header string to add to mutate all requests.\n"
        + "Note, multiple values are allowed for multiple headers.\n"
        + "Format: <key>:<val> [<key>:<val>]",
    )
    mutating.add_argument(
        "--ext",
        nargs="+",
        metavar="ext",
        default=[""],
        required=False,
        help="List of file extensions to add to words for testing.\n"
        + "Note, each supplied extension will double the number of requests.\n"
        + "Format: .zip [.pem]\n",
    )
    mutating.add_argument(
        "--slash",
        metavar="str",
        required=False,
        default=DEFAULT_SLASH,
        type=_args_check_slash,
        help="Append or omit a trailing slash to URLs to test.\n"
        + "Note, a slash will be added after the extensions if they are specified as well.\n"
        + "Note, using 'both' will double the number of requests.\n"
        + "Options: "
        + ", ".join(SUPPORTED_SLASHES.keys())
        + "\n"
        + "Default: "
        + DEFAULT_SLASH,
    )

    # Misc options and the positional base URL.
    misc.add_argument("-h", "--help", action="help", help="Show this help message and exit")
    misc.add_argument(
        "-v",
        "--version",
        action="version",
        version="%(prog)s " + VERSION + " by cytopia",
        help="Show version information",
    )
    parser.add_argument("BASE_URL", type=str, help="The base URL to scan.")
    return parser.parse_args()


# -------------------------------------------------------------------------------------------------
# PARAMETER GET FUNCTIONS
# -------------------------------------------------------------------------------------------------


def get_words(word, wordlist):
    """Return the list of words to test.

    A given ``word`` takes precedence and is returned as a one-element list; otherwise
    the words are read from the ``wordlist`` file.
    """
    return read_file(wordlist) if word is None else [word]


def get_headers(headers):
    """Get dict of HTTP headers.

    Starts from the default headers of ``requests`` and applies the user supplied ones
    on top; if a name is given more than once, the last value wins.

    Args:
        headers: List of ``<key>:<val>`` strings. Only the first ':' separates name and
            value, so values may contain ':' themselves (e.g. URLs).

    Returns:
        The header mapping created by ``requests.utils.default_headers()`` with the
        user supplied headers added.
    """
    data = requests.utils.default_headers()
    for header in headers:
        # partition() splits on the first ':' only and never fails to unpack.
        key, _, val = header.partition(":")
        key = key.strip()
        val = val.strip()
        # Drop any existing entry first so the supplied header replaces it.
        if key in data:
            del data[key]
        data[key] = val
    return data


def get_cookies(cookies):
    """Get dict of HTTP cookies.

    Args:
        cookies: List of ``<key>=<val>`` strings. Only the first '=' separates name and
            value, so values may contain '=' themselves (e.g. base64 padding).

    Returns:
        Dict mapping the stripped cookie names to their stripped values; if a name is
        given more than once, the last value wins.
    """
    data = {}
    for cookie in cookies:
        # partition() splits on the first '=' only and never fails to unpack.
        key, _, val = cookie.partition("=")
        data[key.strip()] = val.strip()
    return data


def get_proxies(proxy):
    """Get dict of proxies.

    Returns None when no proxy is given, otherwise a dict using the same proxy for the
    "http" and the "https" key.
    """
    if proxy is None:
        return None
    return dict.fromkeys(("http", "https"), proxy)


def _split_credentials(credentials):
    """Split credentials into a ``(user, password)`` pair."""
    # An already separated pair is used as is.
    if isinstance(credentials, (tuple, list)):
        return credentials[0], credentials[1]
    # Only the first ':' separates user and password.
    user, _, password = credentials.partition(":")
    return user, password


def get_auth_method(auth_basic, auth_digest):
    """Get authentication object.

    Args:
        auth_basic: ``<user>:<pass>`` string for basic authentication, or None.
        auth_digest: ``<user>:<pass>`` string for digest authentication, or None.

    Returns:
        A ``requests`` basic auth object if ``auth_basic`` is given, else a digest auth
        object if ``auth_digest`` is given, else None.
    """
    if auth_basic is not None:
        return requests.auth.HTTPBasicAuth(*_split_credentials(auth_basic))
    if auth_digest is not None:
        return requests.auth.HTTPDigestAuth(*_split_credentials(auth_digest))
    return None


def get_slash_values(slash):
    """Get the list of URL suffixes ("" and/or "/") for the given slash option.

    Falls back to the entry for DEFAULT_SLASH when ``slash`` is None.
    """
    key = DEFAULT_SLASH if slash is None else slash
    return SUPPORTED_SLASHES[key]


def merge_headers(headers, mheaders):
    """Get dict mapping each header name to the list of values to mutate over.

    Args:
        headers: Mapping of header name to its single base value.
        mheaders: List of ``<key>:<val>`` strings with additional values. Only the
            first ':' separates name and value, so values may contain ':' themselves.

    Returns:
        Dict of the form ``{key: [val1, val2, ...]}``: the base value of every header
        first (if it has one), followed by the mutation values in the order given.
    """
    data = {}
    for name in headers:
        data[name] = [headers[name]]

    # {key: [val1, val2]}
    for mheader in mheaders:
        # partition() splits on the first ':' only and never fails to unpack.
        key, _, val = mheader.partition(":")
        data.setdefault(key.strip(), []).append(val.strip())

    return data


def mutate_headers(**kwargs):
    """Get list of header dicts covering every combination of the given values.

    Each keyword is a header name mapped to a list of candidate values; the result holds
    one dict per element of the cartesian product of those lists.
    """
    names = list(kwargs.keys())
    candidates = list(kwargs.values())
    return [dict(zip(names, combo)) for combo in itertools.product(*candidates)]


# -------------------------------------------------------------------------------------------------
# MAIN ENTRYPOINT: BANNER
# -------------------------------------------------------------------------------------------------


def print_banner(url, words, args, h_mutations):
    """Print initial banner with the effective settings, mutations and request total.

    Args:
        url: Base URL of the scan.
        words: List of words to test.
        args: Parsed command line arguments.
        h_mutations: List of header dict mutations.
    """
    # http://www.patorjk.com/software/taag/
    print(
        """
   ██╗   ██╗██████╗ ██╗     ██████╗ ██╗   ██╗███████╗████████╗███████╗██████╗
   ██║   ██║██╔══██╗██║     ██╔══██╗██║   ██║██╔════╝╚══██╔══╝██╔════╝██╔══██╗
   ██║   ██║██████╔╝██║     ██████╔╝██║   ██║███████╗   ██║   █████╗  ██████╔╝
   ██║   ██║██╔══██╗██║     ██╔══██╗██║   ██║╚════██║   ██║   ██╔══╝  ██╔══██╗
   ╚██████╔╝██║  ██║███████╗██████╔╝╚██████╔╝███████║   ██║   ███████╗██║  ██║
    ╚═════╝ ╚═╝  ╚═╝╚══════╝╚═════╝  ╚═════╝ ╚══════╝   ╚═╝   ╚══════╝╚═╝  ╚═╝

                               {version} by cytopia
""".format(
            version=VERSION
        )
    )

    # An empty extension list still counts as one (extension-less) request per word.
    ext_factor = len(args.ext) if args.ext else 1
    total = (
        len(words)
        * len(h_mutations)
        * len(args.method)
        * (2 if args.slash == "both" else 1)
        * ext_factor
    )

    print("      SETTINGS")
    print("            Base URL:         {url}".format(url=url))
    print(
        "            Connection:       {conn}".format(
            conn="Non-persistent" if args.new else "Persistent"
        )
    )
    print("            Valid codes:      {codes}".format(codes=", ".join(args.code)))
    print("            Timeout:          {timeout}s".format(timeout=args.timeout))
    print("            Retries:          {retries}".format(retries=args.retry))
    print("            Delay:            {delay}".format(delay=args.delay))
    # The remaining settings are only shown when they were specified.
    if len(args.cookie) > 0:
        print("            Cookie:           {cookie}".format(cookie="&".join(args.cookie)))
    if args.proxy is not None:
        print("            Proxy:            {proxy}".format(proxy=args.proxy))
    if args.auth_basic is not None:
        print("            Basic auth:       {auth}".format(auth=args.auth_basic))
    if args.auth_digest is not None:
        print("            Digest auth:      {auth}".format(auth=args.auth_digest))
    print()

    print("      MUTATIONS")
    print("            Mutating headers: {num}".format(num=len(h_mutations)))
    print(
        "            Methods:          {num} ({m})".format(
            num=len(args.method), m=", ".join(args.method)
        )
    )
    print("            Slashes:          {slash}".format(slash=args.slash))
    # Comparing the whole list avoids indexing into a possibly empty extension list.
    print(
        "            Extensions:       {ext} ({val})".format(
            ext=len(args.ext),
            val="empty extension"
            if list(args.ext) == [""]
            else ", ".join('"' + item + '"' for item in args.ext),
        )
    )
    print("            Words:            {num}".format(num=len(words)))
    print()
    print("      TOTAL REQUESTS: {num}".format(num=total))
    print()


def print_round_header(headers):
    """Print round header with http headers.

    Prints a separator line of 80 dashes, then one ``<key>: <value>`` line per header
    with the values aligned in one column and cut off so that a line does not exceed
    80 characters where the key length allows it, followed by an empty line.

    Args:
        headers: Mapping of header name to header value; may be empty.
    """
    max_length = 80
    # Fall back to 0 so an empty mapping does not make max() fail.
    key_length = max([len(key) for key in headers] or [0])
    # Every value column starts after the longest key plus ": ". Clamp at zero so very
    # long header names cannot produce a negative width.
    value_width = max(max_length - key_length - 2, 0)
    print("-" * max_length)
    for key in headers:
        padding = key_length - len(key)
        value = headers[key][:value_width]
        print("{key}: {pad}{val}".format(key=key, pad=" " * padding, val=value))
    print()


# -------------------------------------------------------------------------------------------------
# MAIN ENTRYPOINT
# -------------------------------------------------------------------------------------------------


def main():
    """Start the program.

    Parses the command line, prints the banner and then requests every combination of
    header mutation, method, word, extension and slash suffix below the base URL.
    Responses whose status code matches one of the --code patterns are printed as hits
    and, if --output is given, also written to that file as plain text (one hit per
    line, without colour codes). Requests that still fail after all attempts are
    printed as errors.

    If --delay is given, the program sleeps that many seconds before every request
    except the first. At least one attempt is made per request, even for --retry
    values below 1.
    """
    # Local import: only needed to honour --delay.
    import time

    args = get_args()

    # Get base url
    url = args.BASE_URL
    # optional arguments
    headers_initial = get_headers(args.header)
    cookies = get_cookies(args.cookie)
    proxies = get_proxies(args.proxy)
    auth = get_auth_method(args.auth_basic, args.auth_digest)
    verify = not args.insecure
    # Without at least one attempt there would be no result to evaluate.
    attempts = max(args.retry, 1)
    delay = args.delay if args.delay and args.delay > 0 else 0
    # mutating arguments
    methods = args.method
    extensions = args.ext
    slashes = get_slash_values(args.slash)
    headers_mutated = mutate_headers(**merge_headers(headers_initial, args.mheader))
    # dictionary words
    words = get_words(args.word, args.wordlist)

    # A session is only needed for persistent connections.
    sess = None if args.new else get_session(auth, headers_initial, proxies)

    print_banner(url, words, args, headers_mutated)

    total = len(words) * len(headers_mutated) * len(methods) * len(slashes) * len(extensions)
    outfile = open(args.output, "w") if args.output else None
    curr = 1
    sent = 0  # number of requests sent so far; used to skip the delay before the first
    try:
        for headers in headers_mutated:
            print_round_header(headers)
            # product() iterates with the last iterable varying fastest:
            # method -> word -> extension -> slash.
            combinations = itertools.product(methods, words, extensions, slashes)
            for method, word, extension, slash in combinations:
                target = url + word + extension + slash
                for attempt in range(1, attempts + 1):
                    if delay and sent:
                        time.sleep(delay)
                    status = get_status_text(curr, total, method, target, attempt, attempts)
                    print_status(status)
                    if sess is not None:
                        succ, conn = session_request(
                            sess, target, method, cookies, headers, args.timeout, verify
                        )
                    else:
                        succ, conn = request(
                            target, method, auth, cookies, headers, proxies, args.timeout, verify
                        )
                    sent += 1
                    clear_status(status)
                    if succ:
                        break
                if not succ:
                    print_err(
                        "[ERR] [{m}] {target}: {msg}".format(
                            m=method, target=target, msg=conn["err"]
                        )
                    )
                elif check_code(conn.status_code, args.code):
                    hit = "[{code}] [{m}] {target}".format(
                        code=conn.status_code, m=method, target=target
                    )
                    print_succ(hit)
                    if outfile is not None:
                        # Flush per hit so results survive an interrupted scan.
                        outfile.write(hit + "\n")
                        outfile.flush()
                curr += 1
    finally:
        if outfile is not None:
            outfile.close()
    print()


if __name__ == "__main__":
    # Catch Ctrl+c and exit without error message
    try:
        main()
    except KeyboardInterrupt:
        print()
        sys.exit(1)
