#! /usr/bin/env python3
# -*- coding: utf-8 -*-
# vim:fenc=utf-8
#
# Copyright © 2017 Adrian Perez <aperez@igalia.com>
#
# Distributed under terms of the GPLv3 license.

# Version number of this script (a plain integer).
__version__ = 1

import attr
import os
import re
import requests
import string

from datetime import datetime, timedelta
from delorean import Delorean
from memoized_property import memoized_property
from urllib.parse import quote as urlquote


def _create_requests_session():
    """Return a new requests session that sends a JSON ``Content-Type``."""
    session = requests.Session()
    session.headers["Content-Type"] = "application/json"
    return session


def _get_bool_from_env(varname):
    def env_getter():
        value = os.environ.get(varname, None)
        if value is None:
            return False
        value = value.strip().lower()
        return value and value not in ("0", "no", "false")
    return env_getter


class ApiError(Exception):
    """Raised when the Matrix API answers a request with a non-200 status."""


class PurgeTimeout(Exception):
    """Raised when a purge request times out; carries the room ID or alias."""


_time_unit_map = dict(second="seconds", seconds="seconds",
                      minute="minutes", minutes="minutes",
                      hour="hours", hours="hours",
                      day="days", days="days",
                      week="weeks", weeks="weeks")

def _string_to_timedelta(s):
    parts = s.strip().split()
    if len(parts) == 1:
        return timedelta(days=int(parts[0]))
    if len(parts) != 2:
        raise ValueError(s)
    amount = int(parts[0])
    unit = parts[1].strip().lower()
    if unit in ("month", "months"):
        amount *= 30
        unit = "days"
    if unit in ("year", "years"):
        amount *= 365
        unit = "days"
    unit = _time_unit_map.get(unit, None)
    if unit is None:
        raise ValueError("Invalid unit: {!r}".format(parts[1].strip()))
    return timedelta(**{ unit: amount })


@attr.s(frozen=True)
class Config(object):
    """Global synpurge settings, loaded from the ``[synpurge]`` INI section.

    ``keep`` and ``purge_request_timeout`` are durations in the format
    accepted by ``_string_to_timedelta``.  ``rooms`` holds one RoomConfig
    per additional INI section.  It is excluded from hashing: a set is
    unhashable, and it is filled in only after the instance is created.
    """
    # TODO: Properly validate the URL.
    homeserver = attr.ib(validator=attr.validators.instance_of(str))
    keep = attr.ib(validator=attr.validators.instance_of(timedelta),
                   converter=_string_to_timedelta)
    token = attr.ib(validator=attr.validators.instance_of(str))
    rooms = attr.ib(validator=attr.validators.instance_of(set), hash=False)
    # The default is a string because attrs also passes defaults through the
    # converter, which parses duration strings like the INI values.
    purge_request_timeout = \
            attr.ib(validator=attr.validators.instance_of(timedelta),
                    converter=_string_to_timedelta,
                    default="180 seconds")

    @staticmethod
    def load(path):
        """Read the INI file at ``path`` and return the resulting Config.

        The ``[synpurge]`` section supplies the Config fields.  Every other
        section becomes a RoomConfig named after the section and is added to
        ``rooms``.

        Raises OSError if the file cannot be opened, and KeyError if it has
        no ``[synpurge]`` section.
        """
        from configparser import ConfigParser
        ini = ConfigParser(default_section=None, interpolation=None)
        # Opening the file explicitly reports a missing or unreadable path;
        # ConfigParser.read() would silently skip it.
        with open(path) as fd:
            ini.read_file(fd, source=path)
        rooms = set()
        cfg = Config(**ini["synpurge"], rooms=rooms)
        for section in ini.sections():
            if section != "synpurge":
                rooms.add(RoomConfig(**ini[section], name=section, config=cfg))
        return cfg


def _string_to_bool(value):
    """Interpret an INI-style boolean such as "yes"/"no", "true"/"false", "on"/"off" or "1"/"0".

    Non-string values go through bool().  Empty or whitespace-only strings
    are False.  Raises ValueError for any other unrecognized string.
    """
    if not isinstance(value, str):
        return bool(value)
    value = value.strip().lower()
    if not value:
        return False
    from configparser import ConfigParser
    try:
        return ConfigParser.BOOLEAN_STATES[value]
    except KeyError:
        raise ValueError("Not a boolean: {!r}".format(value)) from None


@attr.s(frozen=True)
class RoomConfig(object):
    """Per-room settings, taken from an INI section other than ``[synpurge]``.

    ``name`` is the section name.  It is a room ID or alias or, when
    ``pattern`` is true, a regular expression matched against room aliases.
    ``keep`` and ``token`` fall back to the global Config when unset.
    """
    _config = attr.ib(validator=attr.validators.instance_of(Config),
                      repr=False)
    name = attr.ib(validator=attr.validators.instance_of(str))
    _keep = attr.ib(validator=attr.validators.optional(attr.validators.instance_of(timedelta)),
                    converter=lambda s: None if s is None else
                    _string_to_timedelta(s),
                    default=None)
    _token = attr.ib(validator=attr.validators.optional(attr.validators.instance_of(str)),
                     default=None)
    # INI values are strings and bool("false") is True, so parse them.
    pattern = attr.ib(validator=attr.validators.instance_of(bool),
                      default=False, converter=_string_to_bool)

    @property
    def keep(self):
        """How much history to keep: the room's own value, else the global one."""
        return self._config.keep if self._keep is None else self._keep

    @property
    def token(self):
        """Access token: the room's own value, else the global one."""
        return self._config.token if self._token is None else self._token

    def build_alias_matcher(self):
        """Return a predicate telling whether an alias fully matches ``name``.

        fullmatch() anchors the whole expression.  An alternation such as
        ``a|b`` therefore cannot match on a prefix or suffix alone, and a
        trailing newline is not accepted the way ``$`` would accept it.
        """
        re_match = re.compile(self.name).fullmatch
        return lambda s: bool(re_match(s))


@attr.s
class Purger(object):
    """Minimal Matrix client-server API client used to purge room history.

    ``url`` is the homeserver base URL.  ``token`` is the default access
    token, read from ``SYNPURGE_ACCESS_TOKEN`` when not given.  Two
    environment variables are read when an instance is created:
    ``SYNPURGE_PRETEND`` skips the actual purge request, and
    ``SYNPURGE_DEBUG_REQUESTS`` prints every request.  ``chunk_size`` is the
    page size used when walking a room timeline.
    """

    url = attr.ib(validator=attr.validators.instance_of(str))

    token = attr.ib(validator=attr.validators.instance_of(str),
                    default=attr.Factory(lambda: os.environ.get("SYNPURGE_ACCESS_TOKEN", None)))

    _session = attr.ib(validator=attr.validators.instance_of(requests.Session),
                       default=attr.Factory(_create_requests_session))

    _pretend = attr.ib(validator=attr.validators.instance_of(bool),
                       default=attr.Factory(_get_bool_from_env("SYNPURGE_PRETEND")),
                       init=False, repr=False)

    _debug_requests = attr.ib(validator=attr.validators.instance_of(bool),
                              default=attr.Factory(_get_bool_from_env("SYNPURGE_DEBUG_REQUESTS")),
                              init=False, repr=False)

    chunk_size = attr.ib(validator=attr.validators.instance_of(int), default=5000)

    _API_BASE = "/_matrix/client/r0"

    # Origin of Matrix "origin_server_ts" values, as a naive UTC datetime.
    _EPOCH = datetime(1970, 1, 1)

    def request(self, method, url_pattern, params=None, raw_response=False,
                raw_body=False, timeout=None, **kw):
        """Perform an API request and return its result.

        ``url_pattern`` is a string.Template path relative to the client API
        base.  Its ``${name}`` placeholders are filled from ``kw`` after
        percent-encoding.  ``params`` become the query string; the caller's
        dict is not modified, and ``access_token`` defaults to ``self.token``.

        Returns the requests.Response when ``raw_response`` is set.
        Otherwise it returns the body text (``raw_body``) or the decoded JSON
        of a 200 response.  Raises ApiError for any other status code.
        Network errors from requests (e.g. requests.exceptions.Timeout)
        propagate.
        """
        # Copy, so per-request parameters never leak into the caller's dict.
        params = dict(params) if params else {}
        params.setdefault("access_token", self.token)
        # safe="" so that "/" inside identifiers (e.g. event IDs) is escaped
        # instead of being taken as a path separator.
        values = {key: urlquote(value, safe="") for key, value in kw.items()}
        path = string.Template(url_pattern).substitute(**values)
        url = self.url.rstrip("/") + self._API_BASE + path
        # TODO: Handle rate-limiting and retries.
        req = self._session.prepare_request(requests.Request(method, url,
                                                             params=params))
        r = self._session.send(req, timeout=timeout)

        if self._debug_requests:
            print("\x1b[1;35m *\x1b[0;0m\x1b[32m", req.url,
                  "\x1b[0m →\x1b[33m", r, "\x1b[0m")

        if raw_response:
            return r
        if r.status_code != 200:
            raise ApiError(r.text)
        return r.text if raw_body else r.json()

    @memoized_property
    def public_rooms(self):
        """Mapping of room ID to public room directory entry, fetched once."""
        return {room["room_id"]: room for room in self.__iter_public_rooms()}

    def __iter_public_rooms(self, params=None):
        """Yield every entry of the public room directory, following pagination."""
        next_batch = None
        while True:
            data = self.__get_public_rooms_chunk(next_batch, params)
            yield from data.get("chunk", ())
            prev_batch, next_batch = next_batch, data.get("next_batch")
            # The last page has no pagination token (or an unchanged one).
            if next_batch is None or next_batch == prev_batch:
                return

    def __get_public_rooms_chunk(self, next_batch=None, params=None):
        """Fetch one page of the public room directory.

        ``next_batch`` is the token returned with the previous page.  The
        client-server API expects it in the ``since`` query parameter.
        """
        params = dict(params) if params else {}
        if next_batch is not None:
            params["since"] = next_batch
        return self.request("GET", "/publicRooms", params=params)

    def get_room_id(self, room_alias, params=None):
        """Resolve a room alias to its room ID."""
        data = self.request("GET", "/directory/room/${alias}",
                            alias=room_alias,
                            params=params)
        return data["room_id"]

    def get_room_messages(self, room_id, start=None, end=None, limit=None,
                          forward=False, params=None):
        """Fetch a page of room events (backwards unless ``forward`` is set).

        ``start``, ``end`` and ``limit`` map to the ``from``, ``to`` and
        ``limit`` query parameters and are only sent when not None.
        """
        params = dict(params) if params else {}
        params["dir"] = "f" if forward else "b"
        optional = (("from", start), ("to", end), ("limit", limit))
        params.update((key, value) for key, value in optional if value is not None)
        return self.request("GET", "/rooms/${room_id}/messages",
                            room_id=room_id, params=params)

    def purge_room(self, room_id_or_alias, when, timeout=None, params=None):
        """Purge room history older than ``when`` (a Delorean).

        ``room_id_or_alias`` is a room ID (starting with "!") or an alias to
        resolve first.  Returns False when no event older than ``when`` is
        found.  Otherwise returns True after issuing the purge request, or
        after skipping it when SYNPURGE_PRETEND is enabled.  Raises
        PurgeTimeout when the purge request times out.
        """
        if room_id_or_alias.startswith("!"):
            room_id = room_id_or_alias
        else:
            print("Resolving room alias:", room_id_or_alias)
            room_id = self.get_room_id(room_id_or_alias, params=params)
        print("Purging events older than", when.humanize(), "for room", room_id)

        event_id, event_time = self._find_event_id_before(room_id, when, params=params)
        if event_id is None:
            print("No history to trim")
            return False

        print("First event to purge:", event_id, "-", event_time.humanize())
        if self._pretend:
            print(" ** SKIPPING: SYNPURGE_PRETEND ENABLED **")
            return True
        try:
            self.request("POST",
                         "/admin/purge_history/${room_id}/${event_id}",
                         room_id=room_id,
                         event_id=event_id,
                         timeout=timeout,
                         params=params)
        except requests.exceptions.Timeout as err:
            raise PurgeTimeout(room_id_or_alias) from err
        return True

    def _find_event_id_before(self, room_id, when, params=None):
        """Find the first event, paging backwards, that is older than ``when``.

        Walks the room timeline from the most recent events, ``chunk_size``
        events per request.  Returns ``(event_id, event_time)``, where
        ``event_time`` is a UTC Delorean, or ``(None, None)`` when no such
        event is found.
        """
        data = self.get_room_messages(room_id, limit=self.chunk_size,
                                      params=params)
        while True:
            for event in data.get("chunk", ()):
                # origin_server_ts is in milliseconds since the Unix epoch.
                # Build a naive *UTC* datetime; datetime.fromtimestamp()
                # would yield local time, wrongly tagged as UTC below.
                ts = self._EPOCH + timedelta(milliseconds=event["origin_server_ts"])
                event_time = Delorean(ts, timezone="UTC")
                if event_time < when:
                    return event["event_id"], event_time
            start, end = data.get("start"), data.get("end")
            # No further (or no advancing) token: history is exhausted.
            if end is None or end == start:
                return None, None
            data = self.get_room_messages(room_id, start=end,
                                          limit=self.chunk_size,
                                          params=params)


if __name__ == "__main__":
    import sys

    def _purge(purger, room_id_or_alias, purge_upto, timeout, token):
        """Purge one room, reporting request timeouts without aborting the run."""
        # A fresh params dict per call, so no request state is shared
        # between rooms.
        params = dict(access_token=token)
        try:
            purger.purge_room(room_id_or_alias,
                              purge_upto,
                              timeout=timeout,
                              params=params)
        except PurgeTimeout as e:
            print("Timeout: {!r} (continuing)".format(e))

    try:
        c = Config.load(sys.argv[1] if len(sys.argv) > 1 else "/etc/synpurge.conf")
        p = Purger(url=c.homeserver, token=c.token)
        timeout = c.purge_request_timeout.total_seconds()
        now = Delorean()
        for room in c.rooms:
            purge_upto = now - room.keep
            print("Purging:", room.name, "keeping history up to", purge_upto.humanize())
            if not room.pattern:
                _purge(p, room.name, purge_upto, timeout, room.token)
                continue
            room_alias_matches = room.build_alias_matcher()
            for room_id, room_info in p.public_rooms.items():
                for room_alias in room_info.get("aliases", []):
                    if room_alias_matches(room_alias):
                        print(" -", room_alias, "→", room_id)
                        _purge(p, room_id, purge_upto, timeout, room.token)
                        # One purge per room, even if several aliases match.
                        break
    except ApiError as e:
        raise SystemExit("Matrix API error: {!s}".format(e))
