#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# 2020-11-11 Timo Sturm <timo.sturm@netknights.it>
# 2018-02-21 Cornelius Kölbel <cornelius.koelbel@netknights.it>
# 2017-11-21 Cornelius Kölbel <cornelius.koelbel@netknights.it>
# 2017-10-18 Cornelius Kölbel <cornelius.koelbel@netknights.it>
# 2017-05-02 Friedrich Weber <friedrich.weber@netknights.it>
# 2017-04-25 Cornelius Kölbel <cornelius.koelbel@netknights.it>
#
# Copyright (c) 2017, Cornelius Kölbel
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice,
# this list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from this
# software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
# POSSIBILITY OF SUCH DAMAGE.
import re
import sys

from dateutil import parser
from dateutil.tz import tzlocal, tzutc
from yaml import safe_dump as yaml_safe_dump
from yaml import safe_load as yaml_safe_load

from edumfa.app import create_app
from edumfa.lib.importotp import export_pskc
from edumfa.lib.policy import ACTION
from edumfa.lib.realm import get_default_realm
from edumfa.lib.token import (
    enable_token,
    get_tokens,
    get_tokens_paginated_generator,
    import_token,
    remove_token,
    unassign_token,
)
from edumfa.lib.user import User, split_user
from edumfa.lib.utils import parse_legacy_time, to_unicode
from edumfa.models import Token

__version__ = "0.1"

# Values accepted by the ``--action`` option of the ``find`` command.
# If no action is given, the matching tokens are only listed.
ALLOWED_ACTIONS = [
    "disable",
    "delete",
    "unassign",
    "mark",
    "export",
    "listuser",
    "tokenrealms",
]

# Module documentation. The ``{0!s}`` placeholder is filled with the allowed
# actions so that the usage example always agrees with ALLOWED_ACTIONS.
__doc__ = """
This script can be used to clean up the token database.

It can list, export, disable, delete, unassign or mark tokens that match
the given conditions.

Conditions:

* last_auth
* orphaned tokens
* any tokeninfo
* unassigned token / token has no user
* tokentype

Actions:

* (no action given: the matching tokens are listed)
* disable
* delete
* unassign
* mark
* export
* listuser
* tokenrealms


    edumfa-token-janitor find
        --last_auth=10h|7d|2y
        --tokeninfo-key=<key>
        --has-not-tokeninfo-key=<key>
        --has-tokeninfo-key=<key>
        --tokeninfo-value=<value>
        --tokeninfo-value-greater-than=<value>
        --tokeninfo-value-less-than=<value>
        --tokeninfo-value-after=<value>
        --tokeninfo-value-before=<value>
        --assigned false
        --orphaned true
        --tokentype=<type>
        --username=<username@realm>
        --serial=<regexp>
        --description=<regexp>

        --action={0!s}

        --set-description="new description"
        --set-tokeninfo-key=<key>
        --set-tokeninfo-value=<value>
        --limit=<value>

.. note:: If redirecting the output of this command (e.g. to a file) fails
   with a UnicodeEncodeError, you need to set the environment variable
   like this:
       export PYTHONIOENCODING=UTF-8
   Here is a good read about it:
   https://stackoverflow.com/questions/4545661/unicodedecodeerror-when-redirecting-to-file

""".format("|".join(ALLOWED_ACTIONS))

import click
from flask import current_app as app
from flask.cli import FlaskGroup

# Context settings passed to the click command group below: the help page is
# reachable via "-h" as well as "--help".
CONTEXT_SETTINGS = dict(help_option_names=["-h", "--help"])


def create_prod_app():
    """
    Application factory handed to the click ``FlaskGroup``.

    :return: whatever ``create_app`` returns for the "production"
             configuration, requested with ``silent=True`` and ``script=True``
    """
    prod_app = create_app("production", silent=True, script=True)
    return prod_app


# Root command group of this script. It is a ``FlaskGroup`` whose application
# is built by ``create_prod_app``; the default Flask commands are not
# registered (``add_default_commands=False``). The sub-commands of this file
# attach themselves via ``@cli.command(...)``.
# NOTE(review): deliberately documented with comments only - a docstring on
# this function would presumably show up in the generated help output.
@click.group(
    cls=FlaskGroup,
    add_default_commands=False,
    create_app=create_prod_app,
    context_settings=CONTEXT_SETTINGS,
    epilog="Check out our docs at https://edumfa.readthedocs.io/ for more details",
)
def cli():
    pass


def _try_convert_to_integer(given_value_string):
    try:
        return int(given_value_string)
    except ValueError:
        raise click.UsageError("Not an integer: {}".format(given_value_string))


def _compare_greater_than(_key, given_value_string):
    """
    Build a comparator for "value > given integer".

    :param _key: unused, only present so all comparator factories share one signature
    :param given_value_string: the user-supplied threshold, converted to an
        integer right away (an invalid value raises a click.UsageError)
    :return: a function which returns True if its parameter (converted to an
             integer) is greater than the threshold. Values for which ``int()``
             raises a ValueError never match.
    """
    threshold = _try_convert_to_integer(given_value_string)

    def comparator(value):
        try:
            is_greater = int(value) > threshold
        except ValueError:
            is_greater = False
        return is_greater

    return comparator


def _compare_less_than(_key, given_value_string):
    """
    Build a comparator for "value < given integer".

    :param _key: unused, only present so all comparator factories share one signature
    :param given_value_string: the user-supplied threshold, converted to an
        integer right away (an invalid value raises a click.UsageError)
    :return: a function which returns True if its parameter (converted to an
             integer) is less than the threshold. Values for which ``int()``
             raises a ValueError never match.
    """
    threshold = _try_convert_to_integer(given_value_string)

    def comparator(value):
        try:
            is_less = int(value) < threshold
        except ValueError:
            is_less = False
        return is_less

    return comparator


def _parse_datetime(key, value):
    """
    Parse a stored tokeninfo value into a datetime.

    :param key: the tokeninfo key the value belongs to
    :param value: the stored value as string
    :return: the parsed datetime. For the ``last_auth`` key the result is
             always timezone-aware.
    :raises ValueError: if dateutil cannot parse the value
    """
    if key != ACTION.LASTAUTH:
        # Other values are given in local time; ``parse_legacy_time`` is
        # expected to take care of the conversion of legacy values.
        return parser.parse(parse_legacy_time(value))

    # Special case for last_auth: Legacy values are given in UTC time!
    last_auth = parser.parse(value)
    if last_auth.tzinfo is None:
        # ``tzinfos=`` of ``parser.parse`` only resolves timezone names/offsets
        # found in the string, it is not a default timezone. Attach UTC
        # explicitly so the result can be compared with aware datetimes.
        last_auth = last_auth.replace(tzinfo=tzutc())
    return last_auth


def _try_convert_to_datetime(given_value_string):
    """
    Convert a user-supplied string to a timezone-aware datetime.

    :param given_value_string: the datetime as given on the command line
    :return: the parsed datetime; if the string carries no timezone, the
             local timezone is attached
    :raises click.UsageError: if the string can not be parsed
    """
    try:
        moment = parser.parse(given_value_string, dayfirst=False)
        if moment.tzinfo:
            return moment
        # Without an explicit timezone the timestamp is assumed to be local time
        return moment.replace(tzinfo=tzlocal())
    except ValueError:
        raise click.UsageError(
            "Not a valid datetime format: {}".format(given_value_string)
        )


def _compare_after(key, given_value_string):
    """
    Build a comparator for "value is later than the given datetime".

    :param key: the tokeninfo key, passed on to ``_parse_datetime``
    :param given_value_string: the user-supplied datetime, converted right
        away (an invalid value raises a click.UsageError)
    :return: a function which returns True if its parameter (converted to a
             datetime) occurs after the given datetime. Values which raise a
             ValueError during conversion never match.
    """
    reference = _try_convert_to_datetime(given_value_string)

    def comparator(value):
        try:
            is_later = _parse_datetime(key, value) > reference
        except ValueError:
            is_later = False
        return is_later

    return comparator


def _compare_before(key, given_value_string):
    """
    Build a comparator for "value is earlier than the given datetime".

    :param key: the tokeninfo key, passed on to ``_parse_datetime``
    :param given_value_string: the user-supplied datetime, converted right
        away (an invalid value raises a click.UsageError)
    :return: a function which returns True if its parameter (converted to a
             datetime) occurs before the given datetime. Values which raise a
             ValueError during conversion never match.
    """
    reference = _try_convert_to_datetime(given_value_string)

    def comparator(value):
        try:
            is_earlier = _parse_datetime(key, value) < reference
        except ValueError:
            is_earlier = False
        return is_earlier

    return comparator


def _compare_regex_or_equal(_key, given_regex):
    def comparator(value):
        if type(value) in (int, bool):
            # If the value from the database is an integer, we compare "euqals integer"
            given_value = _try_convert_to_integer(given_regex)
            return given_value == value
        else:
            # if the value from the database is a string, we compare regex
            return re.search(given_regex, value)

    return comparator


def build_tokenvalue_filter(
    key,
    tokeninfo_value,
    tokeninfo_value_greater_than,
    tokeninfo_value_less_than,
    tokeninfo_value_after,
    tokeninfo_value_before,
):
    """
    Build and return a value filter, which is a list of comparator functions.
    Each comparator function takes a stored value and returns a true value if
    the user-defined criterion matches. Callers treat the filter as a
    conjunction: a record matches if *all* comparators match.

    Criteria which are not given (None or empty) do not contribute a comparator.

    :param key: the name of the tokeninfo key (or token attribute) the values belong to
    :param tokeninfo_value: regular expression or integer the value has to match
    :param tokeninfo_value_greater_than: integer the value has to exceed
    :param tokeninfo_value_less_than: integer the value has to stay below
    :param tokeninfo_value_after: datetime the value has to occur after
    :param tokeninfo_value_before: datetime the value has to occur before
    :return: list of functions
    """
    # (user-supplied criterion, comparator factory), in the order in which
    # the comparators are built and later evaluated.
    criteria = (
        (tokeninfo_value, _compare_regex_or_equal),
        (tokeninfo_value_greater_than, _compare_greater_than),
        (tokeninfo_value_less_than, _compare_less_than),
        (tokeninfo_value_after, _compare_after),
        (tokeninfo_value_before, _compare_before),
    )
    return [factory(key, criterion) for criterion, factory in criteria if criterion]


def _get_tokenlist(
    last_auth,
    assigned,
    active,
    tokeninfo_key,
    tokeninfo_value_filter,
    tokenattribute,
    tokenattribute_filter,
    orphaned,
    orphan_exception_value,
    tokentype,
    serial,
    description,
    username,
    chunksize,
    has_not_tokeninfo_key,
    has_tokeninfo_key,
):
    """
    Generator which reads tokens from the database and yields lists of those
    token objects that match all given conditions.

    With *chunksize* set, the tokens are read via
    ``get_tokens_paginated_generator`` and one filtered list is yielded per
    chunk; otherwise all tokens are read with ``get_tokens`` and a single
    filtered list is yielded. Progress information is written to stderr.

    :param last_auth: if set, tokens for which ``check_last_auth_newer(last_auth)``
        is true are skipped
    :param assigned: "true"/"false" (case-insensitive) or None for no filtering
    :param active: "true"/"false" (case-insensitive) or None for no filtering
    :param tokeninfo_key: tokeninfo key whose value is checked by *tokeninfo_value_filter*
    :param tokeninfo_value_filter: list of comparator functions for the tokeninfo value
    :param tokenattribute: token attribute whose value is checked by *tokenattribute_filter*
    :param tokenattribute_filter: list of comparator functions for the token attribute
    :param orphaned: "1"/"true" to keep only orphaned tokens, "0"/"false" to keep
        only non-orphaned tokens (case-insensitive), anything else disables the check
    :param orphan_exception_value: passed on to ``is_orphaned``
    :param tokentype: token type passed to the database query
    :param serial: regular expression the serial has to match
    :param description: regular expression the description has to match
    :param username: user (optionally with "@realm") the tokens have to belong to.
        The script exits if the user does not exist.
    :param chunksize: page size for reading the tokens in chunks, or None
    :param has_not_tokeninfo_key: tokeninfo key which must not be present
    :param has_tokeninfo_key: tokeninfo key which has to be present
    :return: generator of lists of token objects
    """
    filter_active = None
    filter_assigned = None
    user_obj = None
    # Normalise the orphaned switch once instead of once per token.
    orphaned_flag = (orphaned or "").upper()
    only_orphaned = orphaned_flag in ["1", "TRUE"]
    only_not_orphaned = orphaned_flag in ["0", "FALSE"]

    if assigned is not None:
        filter_assigned = assigned.lower() == "true"
    if active is not None:
        filter_active = active.lower() == "true"
    if username is not None:
        user_login, user_realm = split_user(username)
        if not user_realm:
            user_realm = get_default_realm()
        user_obj = User(login=user_login, realm=user_realm)
        if not user_obj.exist():
            sys.exit(
                "User {0!s} does not exist. Please check username and realm.".format(
                    user_obj
                )
            )

    if chunksize is not None:
        iterable = get_tokens_paginated_generator(
            tokentype=tokentype,
            active=filter_active,
            assigned=filter_assigned,
            psize=chunksize,
            user=user_obj,
        )
    else:
        # Wrap the complete result in a list, so both cases are "a sequence of chunks"
        iterable = [
            get_tokens(
                tokentype=tokentype,
                active=filter_active,
                assigned=filter_assigned,
                user=user_obj,
            )
        ]
    for tokenobj_list in iterable:
        filtered_list = []
        sys.stderr.write("++ Creating token object list.\n")
        tok_count = 0
        tok_found = 0
        for token_obj in tokenobj_list:
            sys.stderr.write(
                "{0} Tokens processed / {1} Tokens found\r".format(tok_count, tok_found)
            )
            sys.stderr.flush()
            tok_count += 1
            if last_auth and token_obj.check_last_auth_newer(last_auth):
                continue
            if serial and not re.search(serial, token_obj.token.serial):
                continue
            # A token without a description (None) is treated like an empty
            # description instead of crashing re.search with a TypeError.
            if description and not re.search(
                description, token_obj.token.description or ""
            ):
                continue
            if (
                has_not_tokeninfo_key
                and has_not_tokeninfo_key in token_obj.get_tokeninfo()
            ):
                continue
            if has_tokeninfo_key and has_tokeninfo_key not in token_obj.get_tokeninfo():
                continue
            if tokeninfo_value_filter and tokeninfo_key:
                value = token_obj.get_tokeninfo(tokeninfo_key)
                # if the tokeninfo key is not even set, it does not match the filter
                if value is None:
                    continue
                # suppose not all comparator functions return True
                # => at least one comparator function returns False
                # => at least one user-supplied criterion does not match
                # => the token object does not match the user-supplied criteria
                if not all(comparator(value) for comparator in tokeninfo_value_filter):
                    continue
            if tokenattribute_filter and tokenattribute:
                value = token_obj.token.get(tokenattribute)
                # an attribute without a value does not match the filter
                if value is None:
                    continue
                if not all(comparator(value) for comparator in tokenattribute_filter):
                    continue
            if only_orphaned and not token_obj.is_orphaned(orphan_exception_value):
                continue
            if only_not_orphaned and token_obj.is_orphaned(orphan_exception_value):
                continue

            tok_found += 1
            # if everything matched, we append the token object
            filtered_list.append(token_obj)

        sys.stderr.write(
            "{0} Tokens processed / {1} Tokens found\r\n".format(tok_count, tok_found)
        )
        sys.stderr.write("++ Token object list created.\n")
        sys.stderr.flush()
        yield filtered_list


def export_token_data(token_list, attributes=None):
    """
    Build one row of data per token.

    Each row is a list of strings: serial and tokentype, followed - if the
    token has a user - by username, givenname, surname, uid, resolver and
    realm of that user, followed by the requested additional user attributes.
    If anything fails while reading the user data, a note is written to
    stderr and the marker "**failed to resolve user**" is appended to what
    has been collected for this token so far.

    :param token_list: iterable of token objects
    :param attributes: comma separated names of additional user attributes
    :return: list of rows (lists of strings)
    """
    rows = []
    for token_obj in token_list:
        row = [str(token_obj.token.serial), str(token_obj.token.tokentype)]
        try:
            owner = token_obj.user
            if owner:
                for info_key in ("username", "givenname", "surname"):
                    row.append(str(owner.info.get(info_key, "")))
                for attr_name in ("uid", "resolver", "realm"):
                    row.append(str(getattr(owner, attr_name)))

            # Requested attributes are looked up even without a user; the
            # resulting error is handled below like any other failure.
            if attributes:
                for extra in attributes.split(","):
                    row.append(str(owner.info.get(extra, "")))
        except Exception:
            sys.stderr.write(
                f"Failed to determine user for token {token_obj.token.serial!s}.\n"
            )
            row.append("**failed to resolve user**")
        rows.append(row)
    return rows


def export_user_data(token_list, attributes=None):
    """
    Group the serials of the given tokens by their user.

    The key of the returned dictionary is a string of quoted, comma separated
    user data (username, givenname, surname, uid, resolver, realm and the
    requested additional attributes). Tokens without a user - or tokens whose
    user can not be determined - are collected under a "N/A" key.

    :param token_list: iterable of token objects
    :param attributes: comma separated names of additional user attributes
    :return: dictionary mapping the user string to the list of token serials
    """
    users = {}
    for token_obj in token_list:
        serial = token_obj.token.serial
        try:
            user = token_obj.user
        except Exception:
            sys.stderr.write(
                "Failed to determine user for token {0!s}.\n".format(serial)
            )
            # Without this reset the token would be attributed to the user of
            # the previous token (or raise an UnboundLocalError on the first one).
            user = None
        if user:
            uid = f"'{user.info.get('username', '')!s}','{user.info.get('givenname', '')!s}','{user.info.get('surname', '')!s}','{user.uid!s}','{user.resolver!s}','{user.realm!s}'"
            if attributes:
                for att in attributes.split(","):
                    uid += f",'{user.info.get(att, '')!s}'"
        else:
            uid = "N/A" + ", " * 5

        users.setdefault(uid, []).append(serial)

    return users


# "find" command: selects tokens by the given conditions and either lists
# them (no --action), exports them ("listuser", "export") or modifies them
# ("tokenrealms", "disable", "delete", "unassign", "mark").
# The docstring of the function is left untouched, since click presumably
# shows it as the help text of the command.
@cli.command()
@click.option("--set-description", help="set a new description")
@click.option("--set-tokeninfo-key", help="set a new tokeninfo-key")
@click.option("--set-tokeninfo-value", help="set a new tokeninfo-value")
@click.option(
    "--tokeninfo-value-before",
    metavar="DATETIME",
    help="Interpret tokeninfo values as datetimes, match only if they occur before the given ISO 8601 datetime",
)
@click.option(
    "--tokeninfo-value-after",
    metavar="DATETIME",
    help="Interpret tokeninfo values as datetimes, match only if they occur after the given ISO 8601 datetime",
)
@click.option(
    "--tokeninfo-value-greater-than",
    metavar="INTEGER",
    help="Interpret tokeninfo values as integers and match only if they are "
    "greater than the given integer",
)
@click.option(
    "--tokeninfo-value-less-than",
    metavar="INTEGER",
    help="Interpret tokeninfo values as integers and match only if they are "
    "smaller than the given integer",
)
@click.option(
    "--tokeninfo-value", metavar="REGEX|INTEGER", help="The tokeninfo value to match"
)
@click.option(
    "--has-not-tokeninfo-key",
    help="filters for tokens that have not given the specified tokeninfo-key",
)
@click.option(
    "--has-tokeninfo-key",
    help="filters for tokens that have given the specified tokeninfo-key.",
)
@click.option("--tokeninfo-key", help="The tokeninfo key to match. ")
@click.option(
    "--tokenattribute-value-greater-than",
    metavar="INTEGER",
    help="Match if the value of the token attribute is greater than the given value.",
)
@click.option(
    "--tokenattribute-value-less-than",
    metavar="INTEGER",
    help="Match if the value of the token attribute is less than the given value.",
)
@click.option(
    "--tokenattribute-value",
    metavar="REGEX|INTEGER",
    help="The value of the token attribute which should match.",
)
@click.option(
    "--tokenattribute", help="Match for a certain token attribute from the database."
)
@click.option("--tokentype", help="The tokentype to search.")
@click.option("--serial", help="A regular expression on the serial")
@click.option("--description", help="A regular expression on the description")
@click.option(
    "--username",
    help='An exact username including the userrealm concatenated by "@". If no realm is given, '
    "the default realm will be used.",
)
@click.option(
    "--attributes",
    help='Extends the "listuser" function to display additional user attributes',
)
@click.option(
    "--tokenrealms",
    help="A comma separated list of realms, to which the tokens should be assigned.",
)
@click.option(
    "--sum",
    "sum_tokens",
    help='In case of the action "listuser", this switch specifies if '
    "the output should only contain the number of tokens owned "
    "by the user.",
    is_flag=True,
)
@click.option(
    "--action",
    help="Which action should be performed on the "
    "found tokens. {0!s}. Exporting to PSKC only supports "
    "HOTP, TOTP and PW tokens!".format("|".join(ALLOWED_ACTIONS)),
)
@click.option(
    "--csv",
    is_flag=True,
    help="In case of a simple find, the output is written as CSV instead of the "
    "formatted output.",
)
@click.option(
    "--yaml",
    is_flag=True,
    help="In case of a simple find, the output is written as YAML instead of the "
    "formatted output.",
)
@click.option("--last_auth", help="Can be something like 10h, 7d, or 2y")
@click.option("--assigned", help="True|False|None")
@click.option("--active", help="True|False|None")
@click.option("--orphaned", help="Whether the token is an orphaned token. Set to 1")
@click.option(
    "--orphan_exception_value",
    default=False,
    type=bool,
    help="If an exception occures during orphan check use this value.",
)
@click.option(
    "--b32",
    is_flag=True,
    help="In case of exporting found tokens to CSV the seed is written base32 encoded instead of hex.",
)
@click.option(
    "--chunksize",
    default=None,
    help="Read tokens from the database in smaller chunks to perform operations.",
)
@click.option(
    "--limit",
    default=None,
    help="If an action which writes changes to any token is given, then only apply these changes if "
    "less than or equal LIMIT tokens are affected.",
)
def find(
    last_auth,
    assigned,
    active,
    tokeninfo_key,
    tokeninfo_value,
    tokeninfo_value_greater_than,
    tokeninfo_value_less_than,
    tokeninfo_value_after,
    tokeninfo_value_before,
    orphaned,
    orphan_exception_value,
    tokentype,
    serial,
    description,
    username,
    action,
    set_description,
    set_tokeninfo_key,
    set_tokeninfo_value,
    sum_tokens,
    tokenrealms,
    csv,
    chunksize,
    attributes,
    b32,
    has_not_tokeninfo_key,
    has_tokeninfo_key,
    tokenattribute,
    tokenattribute_value,
    tokenattribute_value_greater_than,
    tokenattribute_value_less_than,
    yaml,
    limit,
):
    """
    finds all tokens which match the conditions
    """
    # Validate the arguments which are restricted to a fixed set of values.
    tokenattributes = [col.key for col in Token.__table__.columns]
    if tokenattribute and tokenattribute not in tokenattributes:
        sys.stderr.write(
            "Unknown token attribute. Allowed attributes are {0!s}\n".format(
                ", ".join(tokenattributes)
            )
        )
        sys.exit(1)
    if action and action not in ALLOWED_ACTIONS:
        sys.stderr.write(
            "Unknown action. Allowed actions are {0!s}\n".format(
                ", ".join(["'{0!s}'".format(x) for x in ALLOWED_ACTIONS])
            )
        )
        sys.exit(1)
    if chunksize is not None:
        chunksize = int(chunksize)
    # filter for tokeninfo values
    tvfilter = build_tokenvalue_filter(
        tokeninfo_key,
        tokeninfo_value,
        tokeninfo_value_greater_than,
        tokeninfo_value_less_than,
        tokeninfo_value_after,
        tokeninfo_value_before,
    )
    # filter for token attribute (there are no datetime options for attributes)
    tafilter = build_tokenvalue_filter(
        tokenattribute,
        tokenattribute_value,
        tokenattribute_value_greater_than,
        tokenattribute_value_less_than,
        None,
        None,
    )
    generator = _get_tokenlist(
        last_auth,
        assigned,
        active,
        tokeninfo_key,
        tvfilter,
        tokenattribute,
        tafilter,
        orphaned,
        orphan_exception_value,
        tokentype,
        serial,
        description,
        username,
        chunksize,
        has_not_tokeninfo_key,
        has_tokeninfo_key,
    )
    if chunksize is not None:
        sys.stderr.write(
            "+ Reading tokens from database in chunks of {}...\n".format(chunksize)
        )
    else:
        sys.stderr.write("+ Reading tokens from database...\n")
    # Every list yielded by the generator is one chunk of matching tokens;
    # the action is applied chunk by chunk.
    for tlist in generator:
        sys.stderr.write("+ Tokens read. Starting action.\n")
        if not action:
            # Simple find: only print the matching tokens
            if not csv:
                click.echo("Token serial\tTokeninfo")
                click.echo("=" * 42)
                for token_obj in tlist:
                    click.echo(
                        f"{token_obj.token.serial!s} ({token_obj.token.tokentype!s})\n\t\t{token_obj.token.description!s}\n\t\t{token_obj.get_tokeninfo()!s}"
                    )
            else:
                for token_obj in tlist:
                    click.echo(
                        f"'{token_obj.token.serial!s}','{token_obj.token.tokentype!s}','{token_obj.token.description!s}','{token_obj.get_tokeninfo()!s}'"
                    )

        elif action == "listuser":
            if not sum_tokens:
                tokens = export_token_data(tlist, attributes)
                for token in tokens:
                    click.echo(",".join(["'{0!s}'".format(x) for x in token]))
            else:
                users = export_user_data(tlist, attributes)
                for user, tokens in users.items():
                    click.echo(f"{user!s},{len(tokens)!s}")

        elif action == "export":
            if csv:
                for tokenobj in tlist:
                    # The CSV export only handles HOTP and TOTP tokens
                    token_type = tokenobj.type.lower()
                    if token_type not in ["totp", "hotp"]:
                        continue
                    token_dict = tokenobj._to_dict(b32=b32)
                    owner = (
                        f"{tokenobj.user.login!s}@{tokenobj.user.realm!s}"
                        if tokenobj.user
                        else "n/a"
                    )
                    # TOTP tokens additionally carry their timestep. (This used
                    # to compare the builtin ``type`` with "totp", which is
                    # never true, so the timestep was never written.)
                    if token_type == "totp":
                        click.echo(
                            f"{owner!s}, {token_dict.get('serial')!s}, {token_dict.get('otpkey')!s}, {token_dict.get('type')!s}, {token_dict.get('otplen')!s}, {token_dict.get('timestep')!s}"
                        )
                    else:
                        click.echo(
                            f"{owner!s}, {token_dict.get('serial')!s}, {token_dict.get('otpkey')!s}, {token_dict.get('type')!s}, {token_dict.get('otplen')!s}"
                        )
            elif yaml:
                token_list = []
                for tokenobj in tlist:
                    try:
                        token_dict = tokenobj._to_dict(b32=b32)
                        token_dict["owner"] = (
                            "{!s}@{!s}".format(tokenobj.user.login, tokenobj.user.realm)
                            if tokenobj.user
                            else "n/a"
                        )
                        token_list.append(token_dict)
                    except Exception as exx:
                        # Take the serial from the token object: ``token_dict``
                        # is unbound or belongs to the previous token if
                        # ``_to_dict`` itself failed.
                        sys.stderr.write(
                            f"\nFailed to export token {tokenobj.token.serial!s}.\n"
                        )
                        sys.stderr.write("{0}\n".format(exx))
                print(yaml_safe_dump(token_list))
            else:
                key, token_num, soup = export_pskc(tlist)
                sys.stderr.write("\n{0!s} tokens exported.\n".format(token_num))
                sys.stderr.write(
                    "\nThis is the AES encryption key of the token seeds.\n"
                    "You need this key to import the "
                    "tokens again:\n\n\t{0!s}\n\n".format(key)
                )
                click.echo("{0!s}".format(soup))
        else:
            # All remaining actions modify tokens and are guarded by --limit
            # (checked per chunk).
            if limit is not None and int(limit) < len(tlist):
                click.echo(
                    f'{len(tlist)!s} tokens would be affected of given action "{action!s}", but limit was {limit!s}. Aborting the action!'
                )
            else:
                for token_obj in tlist:
                    try:
                        if action == "tokenrealms":
                            trealms = [r.strip() for r in tokenrealms.split(",") if r]
                            token_obj.set_realms(trealms)
                            click.echo(
                                f"Setting realms of token {token_obj.token.serial!s} to {trealms!s}."
                            )
                        elif action == "disable":
                            enable_token(serial=token_obj.token.serial, enable=False)
                            click.echo(f"Disabling token {token_obj.token.serial!s}")
                        elif action == "delete":
                            remove_token(serial=token_obj.token.serial)
                            click.echo(f"Deleting token {token_obj.token.serial!s}")
                        elif action == "unassign":
                            unassign_token(serial=token_obj.token.serial)
                            click.echo(f"Unassigning token {token_obj.token.serial!s}")
                        elif action == "mark":
                            if set_description:
                                click.echo(
                                    f"Setting description for token {token_obj.token.serial!s}: {set_description!s}"
                                )
                                token_obj.set_description(set_description)
                                token_obj.save()
                            if set_tokeninfo_value and set_tokeninfo_key:
                                click.echo(
                                    f"Setting tokeninfo for token {token_obj.token.serial!s}: {set_tokeninfo_key!s}={set_tokeninfo_value!s}"
                                )
                                token_obj.add_tokeninfo(
                                    set_tokeninfo_key, set_tokeninfo_value
                                )
                                token_obj.save()
                    except Exception as exx:
                        # A failing token must not stop the processing of the others
                        click.echo(f"Failed to process token {token_obj.token.serial}.")
                        click.echo("{0}".format(exx))
        # Drop the reference to the chunk before the next one is read
        del tlist


@cli.command("updatetokens")
@click.option(
    "--yaml", help="Specify the YAML file with the previously exported tokens."
)
def updatetokens(yaml):
    """
    This can update existing tokens in the eduMFA system. You can specify a yaml file with the tokendata.
    Can be used to reencrypt data, when changing the encryption key.
    """
    print("Loading YAML data. This may take a while.")
    # Close the file deterministically; an empty file yields None, which is
    # treated as "no tokens".
    with open(yaml, "r") as yaml_file:
        token_list = yaml_safe_load(yaml_file.read()) or []
    for tok in token_list:
        # The "owner" entry is added by the YAML export and is not part of the
        # data handed to ``update``; entries without it are accepted as well.
        tok.pop("owner", None)
        serial = tok.get("serial")
        tok_objects = get_tokens(serial=serial)
        if not tok_objects:
            sys.stderr.write(
                "\nCan not find token {0!s}. Not updating.\n".format(serial)
            )
        else:
            print("Updating token {0!s}.".format(serial))
            try:
                tok_objects[0].update(tok)
            except Exception as exx:
                # Continue with the next token, but report why this one failed
                sys.stderr.write(
                    "\nFailed to update token {0!s}: {1!s}\n".format(serial, exx)
                )


@cli.command("loadtokens")
@click.option("--pskc", help="Import this PSKC file.")
@click.option("--preshared_key_hex", help="The AES encryption key.")
@click.option(
    "--validate_mac",
    default="check_fail_hard",
    help="How the file should be validated.\n"
    "'no_check' : Every token is parsed, ignoring HMAC\n"
    "'check_fail_soft' : Skip tokens with invalid HMAC\n"
    "'check_fail_hard' : Only import tokens if all HMAC are valid.",
)
def loadtokens(pskc, preshared_key_hex, validate_mac):
    """
    Loads token data from a PSKC file.
    """
    from edumfa.lib.importotp import parsePSKCdata

    # Read the complete PSKC file and let the library parse it
    with open(pskc, "r") as handle:
        pskc_data = handle.read()

    tokens, not_parsed_tokens = parsePSKCdata(
        pskc_data, preshared_key_hex=preshared_key_hex, validate_mac=validate_mac
    )

    # Import token by token; a failing token is recorded and does not stop
    # the import of the remaining ones.
    imported = 0
    failures = 0
    failed_serials = []
    for serial in tokens:
        try:
            print("Importing token {0!s}".format(serial))
            import_token(serial, tokens[serial])
            imported += 1
        except Exception as exx:
            failures += 1
            failed_serials.append(serial)
            print("--- Failed to import token. {0!s}".format(exx))

    # Summary
    if not_parsed_tokens:
        print(
            "The following tokens were not read from the PSKC file"
            " because they could not be validated: {0!s}".format(not_parsed_tokens)
        )
    print("Successfully imported {0!s} tokens.".format(imported))
    print("Failed to import {0!s} tokens: {1!s}".format(failures, failed_serials))


# Run the click command group when the file is executed as a script.
if __name__ == "__main__":
    cli()
