#!python

import argparse
import bisect
import code
from collections import Counter, defaultdict
from datetime import datetime
from enum import Enum
from io import IOBase
import operator
import re
import readline
import rlcompleter
import sys

from lxml.etree import iterparse
import phonenumbers as pn


# Interactive-shell niceties, applied as soon as this file is imported:
# tab completion (the completer itself is installed by the rlcompleter
# import above) and vi-style line editing for the code.interact() shell.
readline.parse_and_bind("tab: complete")
readline.parse_and_bind("set editing-mode vi")

"""To run, do `python3 silbacre.py /path/to/messages.xml`

You will be presented with a Python shell with the following variables defined:

    messages:   A list of dictionaries where each dictionary has the keys
                address, body, etc. as defined in the xml file; in addition,
                date is re-formatted as a datetime object. Note that both list
                and dictionary are actually subclasses with some added benefits.

    types:      A string that when printed tells you what the raw SMS `type`
                codes in the backup file mean. Once parsed, a message's `type`
                is stored as Type.received or Type.sent.

    SearchTerm: A class for easy searching. A term can be created initially with
                a callable or a simple tuple: the key of the message dictionary
                to be searched, the expected value, and how to match (either
                Search.lowercase, a case-insensitive substring test, or
                Search.regex, a regular expression applied with re.match and
                therefore anchored at the start of the field). If a callable,
                it will be given the message object as its only argument and is
                expected to return a boolean.
                The regular expressions are by default case sensitive; to make
                it insensitive, put (?i) at the beginning of the RegExp.
                (`pydoc re` has more information on what flags are allowed.)

                To combine search terms, simply use the & and | operators. For
                example:

                    >>> term = SearchTerm(('body', 'morning', Search.lowercase))
                    >>> combined = term & ('address', '555', Search.lowercase)

                To check a message by a term, use the * operator:

                    >>> does_match = message * term

                though the `messages` object has a convenient `filter` method.
                """


# How a SearchTerm leaf compares its expected value with a message field:
#   lowercase -- case-insensitive substring test
#   regex     -- re.match of the expected value against the field
#   function  -- the expected value is a callable applied to the field
# (SearchTerm.evaluate already dispatches on Search.function, so the member
# must exist for that branch to be usable.)
Search = Enum('Search', 'lowercase regex function')
# Direction of a message; Message maps the backup's type codes onto these.
Type = Enum('Type', 'sent received')


class SearchTerm:
    """A composable predicate for filtering messages.

    A term wraps a "train": a tuple alternating operands and binary
    operators, ``(operand, op, operand, op, ...)``, evaluated strictly left
    to right.  Each operand is one of

    * a callable, called with the message and treated as a boolean;
    * a leaf tuple ``(key, value[, search_type])`` compared against
      ``message[key]``; ``search_type`` defaults to Search.lowercase;
    * a nested train (a tuple or list whose first item is a tuple, list or
      callable), which acts as a parenthesised group.

    Combine terms with ``&`` and ``|``.  An empty term (``SearchTerm()``)
    matches every message and is the identity for ``&`` and ``|``.
    """

    def __init__(self, train=None):
        """Build a term from a callable, a leaf tuple, a train, another
        SearchTerm, or nothing (the empty term)."""
        if isinstance(train, SearchTerm):
            train = train.train
        if callable(train):
            self.train = (train,)
        elif not train:
            # None or an empty sequence: the empty term.
            self.train = None
        elif self._is_train(train):
            self.train = train
        else:
            # A single leaf tuple becomes a one-operand train.
            self.train = (train,)

    @staticmethod
    def _is_train(obj):
        """Return True if *obj* is a train (or nested group) rather than a leaf."""
        return (
            isinstance(obj, (tuple, list))
            and len(obj) > 0
            and (isinstance(obj[0], (tuple, list)) or callable(obj[0]))
        )

    @classmethod
    def evaluate(cls, train, message):
        """Return True if *message* satisfies *train*.

        Operands are combined strictly left to right with the operator that
        precedes them; there is no precedence between ``&`` and ``|``.  An
        empty train (None or empty) matches every message.

        Raises ValueError for a train that ends with an operator or a leaf
        with an unknown search type.
        """
        if not train:
            return True
        items = iter(train)
        current = cls._evaluate_operand(next(items), message)
        for op in items:
            try:
                operand = next(items)
            except StopIteration:
                raise ValueError(
                    'search train ends with an operator: {!r}'.format(train)
                ) from None
            current = op(current, cls._evaluate_operand(operand, message))
        return current

    @classmethod
    def _evaluate_operand(cls, term, message):
        """Evaluate one operand of a train (callable, nested train or leaf)."""
        if callable(term):
            return bool(term(message))
        if cls._is_train(term):
            return cls.evaluate(term, message)
        key, expected = term[0], term[1]
        search_type = (term[2] if len(term) > 2 else None) or Search.lowercase
        value = message[key]
        # Non-string fields (datetime, Type, ...) are matched on their str();
        # a missing value (None) is treated as an empty string.
        text = '' if value is None else str(value)
        if search_type == Search.regex:
            return re.match(expected, text) is not None
        if search_type == Search.lowercase:
            return str(expected).lower() in text.lower()
        if callable(expected):
            # Search.function: the expected value is a predicate on the raw field.
            return bool(expected(value))
        raise ValueError('unknown search type: {!r}'.format(search_type))

    def _combine(self, op, other):
        """Return a new term joining this one and *other* with *op*."""
        if isinstance(other, SearchTerm):
            other = other.train
        if self.train is None:
            return SearchTerm(other)
        if SearchTerm(other).train is None:
            # Combining with an empty term leaves this term unchanged.
            return SearchTerm(self.train)
        return SearchTerm((self.train, op, other))

    def __or__(self, train):
        return self._combine(operator.or_, train)

    def __and__(self, train):
        return self._combine(operator.and_, train)

    def __repr__(self, train=None):
        """Render the term; with *train* given, render that bare train instead."""
        if train:
            return self._format_train(train)
        if not self.train:
            return "SearchTerm()"
        return "SearchTerm({})".format(self._format_train(self.train))

    @classmethod
    def _format_train(cls, train):
        """Render *train* with | and &, parenthesising multi-operand groups."""
        symbols = {operator.or_: '|', operator.and_: '&'}
        pieces = []
        for index, item in enumerate(train):
            if index % 2:
                pieces.append(
                    symbols.get(item) or getattr(item, '__name__', repr(item)))
            elif cls._is_train(item):
                inner = cls._format_train(item)
                pieces.append('({})'.format(inner) if len(item) > 1 else inner)
            elif callable(item):
                pieces.append(getattr(item, '__qualname__', repr(item)))
            else:
                pieces.append(str(item))
        return ' '.join(pieces)


class Messages(list):
    """A list of Message objects kept in date order, without duplicates.

    Two messages count as duplicates when they share a ``timestamp`` value;
    the timestamps already present are tracked in ``self.timestamps``.
    Besides integer and slice indexing, a Messages can be indexed with a
    SearchTerm, a search tuple or a callable, which is the same as calling
    ``filter`` with it.

    Note: only ``append`` (used by the constructor and ``import_messages``)
    keeps ``timestamps`` and the date order up to date; other list methods
    such as ``extend`` or ``insert`` bypass them.
    """

    def __init__(self, messages=None, *message_files):
        """Create the collection.

        messages: a path (str) or file object of a backup to import, an
            iterable of Message objects, or None for an empty collection.
        message_files: further backups (paths or file objects) to import.
        """
        list.__init__(self)
        self.timestamps = set()
        if isinstance(messages, (str, IOBase)):
            self.import_messages(messages)
        elif messages is not None:
            for message in messages:
                self.append(message)
        for message_file in message_files:
            self.import_messages(message_file)

    def __contains__(self, message):
        return message.timestamp in self.timestamps

    def __getitem__(self, item):
        # Plain indices come first: bisect.insort calls this for every probe.
        if isinstance(item, int):
            return list.__getitem__(self, item)
        if isinstance(item, slice):
            return self.__class__(list.__getitem__(self, item))
        if isinstance(item, tuple) or callable(item):
            item = SearchTerm(item)
        if isinstance(item, SearchTerm):
            return self.filter(item)
        return list.__getitem__(self, item)

    def append(self, item):
        """Insert *item* in date order unless its timestamp is already present."""
        if item not in self:
            bisect.insort(self, item)
            self.timestamps.add(item.timestamp)

    def filter(self, term=None, **kwargs):
        """Return the messages that match a search term, as a new Messages.

        term: a SearchTerm, a search tuple, a callable taking a message, or
            None (the default) for no constraint.
        kwargs: extra constraints ANDed onto *term*, keyed by message field.
            A tuple value supplies ``(expected, search_type)``; any other
            value is matched with Search.lowercase.
        """
        if term is None or callable(term) or isinstance(term, tuple):
            term = SearchTerm(term)
        for key, value in kwargs.items():
            if isinstance(value, tuple):
                term &= (key, *value)
            else:
                term &= (key, value, Search.lowercase)
        if term.train is None:
            # An empty term matches everything.
            return self.__class__(self)
        return self.__class__(message for message in self if message.matches(term))

    def group(self, key, use_dict=True):
        """Split the messages by the value of ``message[key]``.

        Returns a dict mapping each value to a Messages of the matching
        messages or, if *use_dict* is false, a tuple of those Messages ordered
        by key.  Enum keys (such as Type members) are ordered by their
        ``value``, since enum members are not orderable themselves.
        """
        by_key = defaultdict(self.__class__)
        for message in self:
            by_key[message[key]].append(message)
        if use_dict:
            return dict(by_key)

        def sort_key(entry):
            value = entry[0]
            return value.value if isinstance(value, Enum) else value

        return tuple(bucket for _, bucket in sorted(by_key.items(), key=sort_key))

    def import_messages(self, backup):
        """Add every supported ``sms``/``mms`` element from an XML backup.

        backup: a path or file object handed to lxml's iterparse (run with
            recover=True and huge_tree=True).
        Elements that Message rejects with ValueError are skipped, as are
        messages whose timestamp is already present.
        """
        for _, elem in iterparse(backup, recover=True, huge_tree=True):
            if elem.tag not in ("sms", "mms"):
                continue
            try:
                message = Message(elem)
            except ValueError:
                continue
            self.append(message)


class Message(dict):
    """A single SMS/MMS entry, stored as a dictionary.

    Built either from an existing dict (copied as-is) or from an ``sms`` /
    ``mms`` XML element, in which case:

    * the element's attributes become the dictionary items;
    * ``type`` is set to Type.received or Type.sent by the tag's parser;
    * ``timestamp`` keeps the raw ``date`` attribute string, and ``date`` is
      replaced by a local datetime built from it with the last three
      characters (presumably milliseconds) dropped;
    * ``address`` is normalised with normalise_phonenumber.

    Element parsing raises ValueError for unsupported entries (unknown SMS
    type code, group MMS, MMS without a text/plain part).

    Keys are also readable as attributes (``message.body``), and messages
    order by ``date``.
    """

    def __init__(self, message):
        if isinstance(message, dict):
            dict.__init__(self, message)
            self.element = None
            return
        dict.__init__(self, message.attrib)
        self.element = message
        parsers = {'sms': self._parse_sms, 'mms': self._parse_mms}
        parsers[message.tag](message)
        # The raw date string doubles as the identity Messages uses for
        # duplicate detection.
        self['timestamp'] = self['date']
        self['date'] = datetime.fromtimestamp(int(self['date'][:-3]))
        self['address'] = normalise_phonenumber(self['address'])

    def _parse_sms(self, message):
        """Map the raw SMS ``type`` code ('1' received, '2' sent) to a Type.

        Any other (or missing/non-numeric) code raises ValueError so that the
        entry is skipped like other unsupported entries, instead of an
        IndexError aborting the import or '0' wrapping round to Type.sent.
        """
        sms_types = {1: Type.received, 2: Type.sent}
        try:
            self['type'] = sms_types[int(self['type'])]
        except (KeyError, ValueError):
            raise ValueError(
                'Unsupported SMS type: {!r}'.format(self.get('type'))
            ) from None

    def _parse_mms(self, message):
        """Fill in ``type`` and ``body`` from an MMS element.

        Expects exactly two children, the parts container followed by the
        addresses container (unpacking raises ValueError otherwise).  The
        address entry with type '137' (NOTE(review): presumably the sender —
        confirm) decides the direction: received if it equals this message's
        ``address`` attribute, sent otherwise.  If no entry has that type,
        ``type`` keeps whatever the element's attributes provided.

        Raises ValueError for more than two distinct addresses (group text)
        or when no part has content type text/plain.
        """
        parts, addrs = list(message)
        numbers = set()
        for child in addrs:
            address = child.get('address')
            if child.get('type') == '137':
                if address == self['address']:
                    self['type'] = Type.received
                else:
                    self['type'] = Type.sent
            numbers.add(address)
        if len(numbers) > 2:
            raise ValueError('Group text')
        for part in parts:
            if part.get('ct') == 'text/plain':
                self['body'] = part.get('text')
                break
        else:
            raise ValueError('Not text message')

    def __hash__(self):
        # Cached on first use; built from the values only.
        if getattr(self, '_hash', None) is None:
            self._hash = hash(frozenset(self.values()))
        return self._hash

    def __lt__(self, other):
        return self.date < other.date

    def __getattr__(self, attr):
        """Fall back to dictionary items for attribute reads."""
        try:
            return self[attr]
        except KeyError:
            raise AttributeError(attr) from None

    def __gt__(self, other):
        return self.date > other.date

    def __str__(self):
        """Return the message body, or '' when there is none."""
        body = self.get('body')
        return '' if body is None else str(body)

    def __sub__(self, other):
        return self.date - other.date

    def matches(self, term):
        """Return whether this message satisfies the SearchTerm *term*."""
        return term.evaluate(term.train, self)


def normalise_phonenumber(number, region='US'):
    """Return *number* formatted in phonenumbers' NATIONAL format.

    number: the raw address string from the backup.
    region: region code passed to phonenumbers.parse as the default region,
        i.e. how numbers written without an international prefix are read.
    Returns *number* unchanged when phonenumbers cannot parse it.
    """
    try:
        parsed = pn.parse(number, region)
    except pn.phonenumberutil.NumberParseException:
        return number
    return pn.format_number(parsed, pn.PhoneNumberFormat.NATIONAL)



if __name__ == '__main__':
    # Command-line entry point: load one or more backups, then open an
    # interactive Python shell with `messages` and this module's names in scope.
    parser = argparse.ArgumentParser(
        description="API for reading Silence backup files",
    )

    parser.add_argument(
        'backup-path', type=argparse.FileType('rb'), nargs='+',
        help='Path to messages backup file',
    )
    parser.add_argument(
        '-q', '--quiet', action='store_true',
        help="""Do not output prompts, etc. Can be helpful for scripts, for example `cat script.py | silbacre` will output only what is explicitly printed in script.py. Note that silbacre follows the rules of the Python console when it comes to blank lines: you will need two newlines after every unnested block and avoid blank lines inside blocks.""",
    )

    args = vars(parser.parse_args())
    backup_files = args['backup-path']
    try:
        messages = Messages(*backup_files)
    finally:
        # Everything is parsed into memory by now, so release the handles
        # rather than holding them for the whole shell session.  stdin ('-')
        # stays open because the interactive shell reads from it.
        stdin_buffer = getattr(sys.stdin, 'buffer', None)
        for backup_file in backup_files:
            if backup_file is not stdin_buffer:
                backup_file.close()

    types = "1 -> Received\n2 -> Sent"
    prompt = '' if args['quiet'] else "Types of messages:\n" + types
    # In quiet mode the prompt argument is ignored, so no ">>> " is printed.
    readfunc = (lambda _: input()) if args['quiet'] else None
    code.interact(prompt, local=locals(), readfunc=readfunc)

