#!/usr/bin/env python

# requepl -- http repl with sessions
# Copyright (C) 2011  John Krauss
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.

try:
    import requests # currently depends on requests library.  eventually curl?
    from requests.exceptions import RequestException
except ImportError:
    print "You must install the requests library to use requepl."
    print "http://docs.python-requests.org/en/latest/user/install/#install"
    exit(0)
from functools import wraps  # wrrapppppidooooooo
import urlparse
import cmd
import sys
import ast
import pprint
import readline

# Standard HTTP request header names, in alphabetical-ish order.  Used as
# the tab-completion candidates for the `set_header` command.  Kept as a
# list (not a tuple) on purpose: it is passed through `filter`.
REQUEST_HEADERS = [
    'Accept', 'Accept-Charset', 'Accept-Encoding', 'Accept-Language',
    'Authorization',
    'Cache-Control', 'Connection', 'Cookie',
    'Content-Length', 'Content-MD5', 'Content-Type',
    'Date',
    'Expect',
    'From',
    'Host',
    'If-Match', 'If-Modified-Since', 'If-None-Match', 'If-Range',
    'If-Unmodified-Since',
    'Max-Forwards',
    'Pragma', 'Proxy-Authorization',
    'Range', 'Referer',
    'TE',
    'Upgrade', 'User-Agent',
    'Via',
    'Warning',
]

# Host of the initial session: the first command-line argument when one
# is given, otherwise the local machine.
DEFAULT_HOST = sys.argv[1] if len(sys.argv) > 1 else 'localhost'

def parse_request_line(line):
    """Split a command argument line into a URL and optional request data.

    The text before the first space is the URL.  Whatever follows is the
    request data: it is evaluated as a Python literal when possible
    (``{"key": "value"}`` becomes a dict, ``'text'`` becomes a str);
    otherwise it is kept as the raw string and a notice is printed.

    Returns a ``(url, request_data)`` tuple.  ``request_data`` is ``''``
    when the line holds nothing but the URL.
    """
    url, _, raw = line.partition(' ')
    # Surrounding whitespace (e.g. a doubled space after the URL) can make
    # literal_eval reject an otherwise valid literal, so drop it first.
    raw = raw.strip()
    if not raw:
        return url, raw
    try:
        request_data = ast.literal_eval(raw)
    except Exception:
        # Anything that is not a valid literal is sent verbatim.  Unlike a
        # bare ``except``, this lets KeyboardInterrupt/SystemExit through.
        print("Request data is not a python object, sending as '%s'" % raw)
        request_data = raw
    return url, request_data


def safe_request(func):
    """Decorate an HTTP-verb command so a failed request cannot kill the repl.

    The decorated method must have the signature
    ``func(obj, url, request_data)`` and return a response object.  The
    wrapper that replaces it takes the raw command argument ``line``
    instead, and:

    * parses ``line`` with ``parse_request_line`` and joins the URL part
      onto the current session's host,
    * calls ``func``, appends the response to the current session's
      history and prints a one-line summary,
    * on a 3xx response asks whether to follow the ``location`` header
      and, unless the answer starts with ``n``/``N``, repeats the same
      request against the redirect target,
    * reports request errors and Ctrl-C instead of propagating them.

    The wrapper always returns ``None``.
    """
    @wraps(func)
    def wrapped(obj, line):
        partial, request_data = parse_request_line(line)

        url = urlparse.urljoin(obj.hosts[obj.session_name], partial)
        try:
            while True:  # do/while construct for interactive redirects.
                response = func(obj, url, request_data)
                obj.histories[obj.session_name].append(response)
                print("request for '%s' with data '%s' responded with `status` %s, %s `headers`, and `content` of length %s"
                      % (url, request_data, response.status_code,
                         len(response.headers), len(response.content)))

                if not 300 <= response.status_code < 400:
                    break  # not a redirect

                if 'location' not in response.headers:
                    print("No location header from 3xx response, can't follow.")
                    break

                # The Location header may be relative; resolve it against
                # the URL that produced it.  Absolute values pass through
                # urljoin unchanged.
                url = urlparse.urljoin(url, response.headers['location'])
                follow = raw_input("Follow redirect to '%s'? Y/n > " % url)
                # Slicing (not indexing) keeps an empty or one-character
                # answer from raising IndexError; empty input means "yes".
                if follow[:1].lower() == 'n':
                    break  # break out early
        except (RequestException, ValueError):
            print("Error requesting url '%s' with data '%s'" % (url, request_data))
        except KeyError:
            print("Invalid scheme for url '%s'" % url)
        except EOFError:
            # End-of-input at the redirect prompt: decline the redirect
            # rather than letting the error escape into the repl.
            print("")
        except KeyboardInterrupt:
            print("")
            print("Interrupted request for url '%s' with data '%s'" % (url, request_data))

    return wrapped

class Requepl(cmd.Cmd):
    """Requepl provides an HTTP repl with multiple sessions.

    Every session has a name and bundles three things, each stored in a
    dict keyed by that name: a ``requests`` session, a base host URL that
    request paths are joined onto, and a list of the responses received.
    ``session_name`` holds the name of the session currently in use; it
    is first set when ``preloop`` creates the default session.
    """

    # NOTE: these dicts live on the class, so they are shared by every
    # Requepl instance in the process.
    sessions = {}   # session name -> requests session object
    histories = {}  # session name -> list of responses, oldest first
    hosts = {}      # session name -> base URL for that session
    pp = pprint.PrettyPrinter(indent=2)
    intro = ("***An HTTP repl with multiple sessions. Type `help` for"
             " a list of commands.")

    def preloop(self):
        """Initialize default session.
        """
        # OS X Lion whyyyy
        # readline.__doc__ can be None (e.g. when docstrings are stripped),
        # so fall back to an empty string before searching it.
        if 'libedit' in (readline.__doc__ or '') or sys.platform == 'darwin':
            self.completekey = None # disable cmd.Cmd's default
                                    # complete binding
            self.old_completer = readline.get_completer()
            readline.set_completer(self.complete)
            readline.parse_and_bind("bind ^I rl_complete")

        self.do_session(DEFAULT_HOST)

    def default(self, line):
        """Show help on unknown line.
        """
        print('Unrecognized command: "%s"' % line)
        self.do_help('help')

    def emptyline(self):
        """Show nothing on empty line.
        """
        pass

    @property
    def prompt(self):
        """Generate prompt with the current session.
        """
        return "(%s)%s> " % (self.session_name, self.hosts[self.session_name])

    @property
    def session(self):
        """Get the current session.
        """
        return self.sessions[self.session_name]

    @property
    def history(self):
        """Get the current history.
        """
        return self.histories[self.session_name]

    @property
    def host(self):
        """Get the current host.
        """
        return self.hosts[self.session_name]

    def _complete(self, text, options):
        """Return the list of `options` that start with `text`, for
        autocomplete.
        """
        return [option for option in options if option.startswith(text)]

    # def _http_www_complete(self, text):
    #     """Completions with [http://[www.]].
    #     """
    #     return self._complete(text, ['http://', 'http://www.'])

    def do_status(self, line):
        """
        status

        Show the status code of the last request.
        """
        if self.history:
            print(self.history[-1].status_code)
        else:
            print("No requests made in this session yet.")

    def do_headers(self, line):
        """
        headers

        Show response headers from the last request.
        """
        if self.history:
            # pprint() writes the output itself and returns None, so its
            # result must not be printed as well.
            self.pp.pprint(self.history[-1].headers)
        else:
            print("No requests made in this session yet.")

    def do_content(self, line):
        """
        content

        Show the content of the last request.
        """
        if self.history:
            print(self.history[-1].content)
        else:
            print("No requests made in this session yet.")

    def do_set_header(self, line):
        """
        set_header <key> <value>

        Set a request header for the current session.  Value will be encoded.
        """
        key, _, value = line.partition(' ')
        # Extra spaces between key and value are not part of the value.
        value = value.strip()
        if key and value:
            self.session.headers[key] = value
        else:
            print('You must specify a key and value for the header.')
        self.do_request_headers('')

    def complete_set_header(self, text, line, begidx, endidx):
        """Provide standard headers for autocomplete.
        """
        return self._complete(text, REQUEST_HEADERS)

    def do_unset_header(self, line):
        """
        unset_header <key>

        Unset a request header for the current session.
        """
        if line in self.session.headers:
            self.session.headers.pop(line)
        else:
            print("Header '%s' is not set for this session." % line)
        self.do_request_headers('')

    def complete_unset_header(self, text, line, begidx, endidx):
        """Provide existing headers for autocomplete.
        """
        return self._complete(text, self.session.headers.keys())

    def do_request_headers(self, line):
        """
        Display the current request headers for this session.
        """
        self.pp.pprint(self.session.headers)

    def _create_and_enter_session(self, host, name):
        """ Create a new session for <host> with <name>. Overwrites existing session.

        The stored host is always an ``http`` URL: any other scheme in
        <host> is replaced, and a <host> without a scheme is treated
        entirely as the network location with ``/`` as its path.
        """
        self.session_name = name
        self.sessions[name] = requests.session()
        self.hosts[name] = ''
        self.histories[name] = []

        split = urlparse.urlsplit(host)

        # replace the scheme with http no matter what
        if split.scheme != 'http' and split.scheme != '':
            print('Only HTTP scheme allowed.  Replacing scheme with HTTP.')

        # otherwise urlparse starts with path
        if split.scheme == '':
            netloc = split.path
            path = '/'
        else:
            netloc = split.netloc
            path = split.path

        host = urlparse.urlunsplit((
            'http', netloc, path, split.query, split.fragment))
        self.hosts[name] = host

    def do_session(self, line):
        """
        session <host> [name]
        session <name>

        Enters a new session with <host>, named for the host or with
        the optional [name].  If [name] already exists, replaces it.

        If only one parameter is specified and it matches an existing
        name, that session is reentered.
        """
        # Split on any run of whitespace so doubled spaces do not create
        # empty arguments; an empty line yields no arguments at all.
        args = line.split()

        if len(args) == 2: # force enter a new session with specified
                           # name.  Will overwrite whatever already
                           # has that name.
            self._create_and_enter_session(args[0], args[1])
        elif len(args) == 1: # try to reenter existing session, create otherwise
            name = args[0]
            if name in self.sessions:
                self.session_name = name
            else:
                self._create_and_enter_session(name, name)
        else: # fail.
            print("You should specify only a host and an optional name.")
            self.do_help('session')

    def complete_session(self, text, line, begidx, endidx):
        """Complete with an available session.
        """
        return self._complete(text, self.sessions.keys())

    def do_sessions(self, line):
        """
        sessions

        List the available sessions.
        """
        for session_name in self.sessions:
            print(session_name)

    # def do_clear(self, line):
    #     """
    #     clear

    #     Clear the current session's cookies, headers, host, and history.
    #     """
    #     self.sessions[self.session_name] = requests.session()
    #     self.hosts[self.session_name] = ''
    #     self.histories[self.session_name] = []

    #     print "Cleared current session."

    def do_destroy(self, line):
        """
        destroy

        Destroy the current session.  Drops into another session if there are
        more, exits otherwise.
        """

        self.sessions.pop(self.session_name)
        self.hosts.pop(self.session_name)
        self.histories.pop(self.session_name)

        names = list(self.sessions)
        if names:
            # The name still exists, so do_session re-enters that session.
            self.do_session(names[-1])
        else:
            print("No more sessions, exiting.")
            return True  # a true return value ends cmdloop()
            # self.hosts[self.session_name] = ''
            # return 0

    def do_cookies(self, line):
        """
        cookies

        Show the cookies in the current session's cookie jar.
        """
        self.pp.pprint(self.session.cookies)

    # def do_host(self, line):
    #     """
    #     host [<url>]

    #     Set a host prefix for all requests in this session.  `host`
    #     alone will remove the prefix.
    #     """
    #     if line.find(' ') != -1:
    #         print 'Host cannot be broken with a space.'
    #     else:
    #         split = urlparse.urlsplit(line)

    #         if not split.scheme:
    #             split = urlparse.urlsplit('//' + line, 'http')
    #         self.hosts[self.session_name] = urlparse.urlunsplit(split)

    # def complete_host(self, text, line, begidx, endidx):
    #     return self._http_www_complete(text)

    @safe_request
    def do_head(self, url, request_data):
        """
        head <url>

        Synchronously HEAD url in relation to the session's host.
        You will be prompted to follow redirects.
        """
        return self.session.head(url, allow_redirects=False)

    # def complete_head(self, text, line, begidx, endidx):
    #     return self._http_www_complete(text)

    @safe_request
    def do_get(self, url, request_data):
        """
        get <url>

        Synchronously GET url in relation to the session's host.
        You will be prompted to follow redirects.
        """
        return self.session.get(url, allow_redirects=False)

    # def complete_get(self, text, line, begidx, endidx):
    #     return self._http_www_complete(text)

    @safe_request
    def do_post(self, url, request_data):
        """
        post <url> [request_data]

        Synchronously POST to url in relation to the session's host.
        If there is [request_data], it is parsed as a Python object
        and then form encoded.  If it cannot be parsed as a Python
        object, it is still URL encoded.

        Sending form encoded:
        post / {"key":"value"}

        Sending JSON:
        post / '{"key":"value"}'

        Sending arbitrary string:
        post / 'arbitrary post data'
        """
        return self.session.post(url, data=request_data)

    # def complete_post(self, text, line, begidx, endidx):
    #     return self._http_www_complete(text)

    @safe_request
    def do_put(self, url, request_data):
        """
        put <url> [request_data]

        Synchronously PUT to url in relation to the session's host.
        If there is [request_data], it is parsed as a Python object
        and then form encoded.  If it cannot be parsed as a Python
        object, it is still URL encoded.

        Sending form encoded:
        put / {"key":"value"}

        Sending JSON:
        put / '{"key":"value"}'

        Sending arbitrary string:
        put / 'arbitrary post data'
        """
        return self.session.put(url, data=request_data)

    # def complete_put(self, text, line, begidx, endidx):
    #     return self._http_www_complete(text)

    @safe_request
    def do_delete(self, url, request_data):
        """
        delete <url> [request_data]

        Synchronously DELETE url in relation to the session's host.
        If there is [request_data], it is parsed as a Python object
        and then form encoded.  If it cannot be parsed as a Python
        object, it is still URL encoded.

        Sending form encoded:
        delete / {"key":"value"}

        Sending JSON:
        delete / '{"key":"value"}'

        Sending arbitrary string:
        delete / 'arbitrary post data'
        """
        return self.session.delete(url, data=request_data)

    # def complete_delete(self, text, line, begidx, endidx):
    #     return self._http_www_complete(text)

    def help_help(self):
        """Take `help` off the list of undocumented commands.
        """
        self.do_help('')

    def do_exit(self, line):
        """
        Exit Requepl.
        """
        return True

    def do_quit(self, line):
        """
        Exit Requepl.
        """
        return True

    def do_EOF(self, line):
        """Exit the repl.
        """
        print('')
        return True

if __name__ == '__main__':
    # Run the repl until a command asks to exit.  Ctrl-C at the prompt
    # ends the program quietly; the empty print moves the shell prompt
    # onto a fresh line.
    try:
        repl = Requepl()
        repl.cmdloop()
    except KeyboardInterrupt:
        print('')
