#!/usr/bin/env python3

import os
import sys
import argparse
import getpass
import pkg_resources
import re

import ldap3
from ldap3.core.exceptions import LDAPSocketOpenError, LDAPInvalidDnError, \
    LDAPBindError, LDAPInvalidFilterError
import yaml
import ezldap
from ezldap.terminal import fmt


def main():
    '''
    Entry point for the ezldap command-line interface.

    Builds the argument parser (one subparser per command), packs any
    undeclared "--key=value" arguments into argv.replacements, makes sure
    ezldap has been configured (unless the command is "config", the help
    message, or EZLDAP_CONFIG is set in the environment), and then calls the
    function registered for the chosen command.

    Exits through fail() on malformed extra arguments and on the LDAP/LDIF
    errors caught at the bottom of this function.
    '''
    parser = argparse.ArgumentParser(
        description='ezldap CLI - Perform various options on an LDAP directory.',
        # not sure why the formatter width can't be resized...
        formatter_class=lambda prog: argparse.HelpFormatter(prog, indent_increment=1))
    parser.add_argument('-v', '--version', action='version', version=
        '%(prog)s version {}'.format(ezldap.__version__))
    subparsers = parser.add_subparsers(title='Valid commands', metavar='')

    def help_msg(_):
        # default action when no subcommand is given
        parser.print_help()
        print()

    parser.set_defaults(func=help_msg)

    # subparsers follow
    config_parser = subparsers.add_parser('config',
        help='Configure ezldap (configs are stored in ~/.ezldap/).',
        description='Create a set of configs for ezldap scripts under ~/.ezldap/. '
        'To delete your ezldap config, just delete the ~/.ezldap folder.')
    config_parser.add_argument('-f', '--force', default=False, const=True, action='store_const',
        help='Whether to force overwriting existing config files in ~/.ezldap '
        '(namely LDIF templates).')
    config_parser.set_defaults(func=config)

    search_parser = subparsers.add_parser('search',
        help='Search for entities by LDAP filter.',
        description='Search for objects in a directory using LDAP filters. '
        'Example: "(objectClass=*)".')
    # NOTE: a positional with nargs=1 is always required by argparse, so the
    # default below is never actually used.
    search_parser.add_argument('filter', nargs=1, type=str, default='(objectClass=*)',
        help='LDAP filter to search by, for example: (objectClass=*)')
    search_parser.set_defaults(func=search)

    search_dn_parser = subparsers.add_parser('search_dn',
        help='Search for and print DNs in a directory that match a keyword.',
        description='Search LDAP tree for a DN keyword and print a list of matching DNs.')
    search_dn_parser.add_argument('keyword', nargs='?', type=str, default='',
        help='Keyword to search DNs by (case-sensitive).')
    search_dn_parser.set_defaults(func=search_dn)

    add_user_parser = subparsers.add_parser('add_user', help='Add a user.',
        description="Creates an LDAP user and add to a group. "
        "Creates a same-named LDAP group as well if no group specified. "
        "Prints the randomly generated password to stdout.")
    add_user_parser.add_argument('username', nargs=1, type=str,
        help="Username to create.")
    add_user_parser.add_argument('groupname', nargs='?', type=str, default=None,
        help="Groupname to add user to. If omitted, creates same-named LDAP group. "
        "If the group is a local group, a new LDAP group will not be created "
        "and an ldap add group operation will not be performed.")
    add_user_parser.add_argument('--ldif-user', nargs=1, type=str,
        default=['~/.ezldap/add_user.ldif'],
        help='Path of LDIF template to use when adding a user.')
    add_user_parser.add_argument('--ldif-group', nargs=1, type=str,
        default=['~/.ezldap/add_group.ldif'],
        help='Path of LDIF template to use when adding a group.')
    add_user_parser.add_argument('--ldif-add-to-group', nargs=1, type=str,
        default=['~/.ezldap/add_to_group.ldif'],
        help='Path of LDIF template to use when adding a user to a group.')
    add_user_parser.set_defaults(func=add_user)

    add_group_parser = subparsers.add_parser('add_group', help='Add a group.',
        description="Creates an LDAP group using a presupplied LDIF template.")
    add_group_parser.add_argument('groupname', nargs=1, type=str,
        help='LDAP group name to create.')
    add_group_parser.add_argument('gid', nargs='?', type=str, default=None,
        help='GID number of group to create. '
        'If not supplied, uses next available GID number.')
    add_group_parser.add_argument('--ldif', nargs=1, type=str,
        default=['~/.ezldap/add_group.ldif'],
        help='Path of LDIF template to use when performing this operation.')
    add_group_parser.set_defaults(func=add_group)

    add_to_group_parser = subparsers.add_parser('add_to_group',
        help='Add a user to a group.',
        description='Add a user to an LDAP group using a presupplied LDIF template.')
    add_to_group_parser.add_argument('username', nargs=1, type=str,
        help='An LDAP username to add to a group.')
    add_to_group_parser.add_argument('groupname', nargs=1, type=str,
        help='An LDAP groupname to add the user to.')
    add_to_group_parser.add_argument('--ldif', nargs=1, type=str,
        default=['~/.ezldap/add_to_group.ldif'],
        help='Path of LDIF template to use when performing this operation.')
    add_to_group_parser.set_defaults(func=add_to_group)

    add_host_desc = 'Add a host.'
    add_host_parser = subparsers.add_parser('add_host',
        help=add_host_desc, description=add_host_desc)
    add_host_parser.add_argument('hostname', nargs=1, type=str,
        help='Hostname of machine to add. Can be either the fully-qualified '
        'or short hostname. If the short hostname (hostname -s) is used, '
        'a fully-qualified alias will be created based on this directory\'s '
        'suffix ("host1" would become "host1.ezldap.io", for a naming context '
        'of dc=ezldap,dc=io).')
    add_host_parser.add_argument('ip_address', nargs=1, type=str,
        help='IP address of machine to add.')
    add_host_parser.add_argument('--ldif', nargs=1, type=str,
        default=['~/.ezldap/add_host.ldif'],
        help='Path of LDIF template to use when performing this operation.')
    add_host_parser.set_defaults(func=add_host)

    add_ldif_desc = 'Add a generic LDIF template to a directory.'
    add_ldif_parser = subparsers.add_parser('add_ldif',
        description=add_ldif_desc, help=add_ldif_desc)
    add_ldif_parser.add_argument('ldif', nargs=1, type=str,
        help='LDIF file to use as template.')
    add_ldif_parser.set_defaults(func=add_ldif)

    modify_desc = 'Add, replace, or delete an attribute from an entity.'
    modify_parser = subparsers.add_parser('modify',
        help=modify_desc, description=modify_desc)
    modify_parser.add_argument('dn', nargs=1, type=str,
        help='Distinguished Name (DN) of object to modify.')
    modify_parser.add_argument('operation', nargs=1, type=str,
        choices=['add', 'replace', 'delete'],
        help='Type of operation to perform. Can be one of: add, replace, delete.')
    modify_parser.add_argument('attribute', nargs=1, type=str,
        help='Attribute to modify.')
    modify_parser.add_argument('value', nargs=1, type=str,
        help='Value to add, replace, or delete. '
        'When performing a delete operation, passing "-" will delete all values for that attribute.')
    modify_parser.add_argument('replace_with', nargs='?', type=str, default=None,
        help='Value to replace an attribute with when performing a replace operation.')
    modify_parser.set_defaults(func=modify)

    modify_ldif_desc = 'Modify an entry using an LDIF template.'
    modify_ldif_parser = subparsers.add_parser('modify_ldif',
        description=modify_ldif_desc, help=modify_ldif_desc)
    modify_ldif_parser.add_argument('ldif', nargs=1, type=str,
        help='LDIF file to use as template.')
    modify_ldif_parser.set_defaults(func=modify_ldif)

    modify_dn_desc = 'Rename the DN of and/or move an entry.'
    modify_dn_parser = subparsers.add_parser('modify_dn', help=modify_dn_desc,
        description=modify_dn_desc + ' Note that this tool allows you to rename '
        'and move an entry in a single command.')
    modify_dn_parser.add_argument('dn', nargs=1, type=str, help='DN of entry to be moved.')
    modify_dn_parser.add_argument('new_dn', nargs=1, type=str, help='New DN of entry.')
    modify_dn_parser.set_defaults(func=modify_dn)

    delete_desc = 'Delete an entry from an LDAP directory.'
    delete_parser = subparsers.add_parser('delete', help=delete_desc,
        description=delete_desc+' Will print the entry and prompt for confirmation.')
    delete_parser.add_argument('-f', '--force', default=False, const=True, action='store_const',
        help='Do not print entry and do not prompt for confirmation.')
    delete_parser.add_argument('dn', nargs=1, type=str,
        help='Distinguished Name (DN) of object to delete.')
    delete_parser.set_defaults(func=delete)

    ch_home_desc="Change a user's home directory."
    change_home_parser = subparsers.add_parser('change_home',
        help=ch_home_desc, description=ch_home_desc)
    change_home_parser.add_argument('username', nargs=1, type=str,
        help='User to change homeDirectory for.')
    change_home_parser.add_argument('homedir', nargs=1, type=str,
        help='New home directory.')
    change_home_parser.set_defaults(func=change_home)

    ch_shell_desc = "Change a user's default shell."
    change_shell_parser = subparsers.add_parser('change_shell',
        help=ch_shell_desc, description=ch_shell_desc)
    change_shell_parser.add_argument('username', nargs=1, type=str,
        help='User to change loginShell for.')
    change_shell_parser.add_argument('shell', nargs=1, type=str,
        help='Desired shell for user.')
    change_shell_parser.set_defaults(func=change_shell)

    ch_pw_desc="Change or reset a user's password."
    change_pw_parser = subparsers.add_parser('change_pw',
        help=ch_pw_desc, description=ch_pw_desc)
    change_pw_parser.add_argument('username', nargs=1, type=str,
        help='User to reset password for.')
    change_pw_parser.add_argument('-s', '--specify-password',
        default=False, const=True, action='store_const',
        help='Specify a password instead of randomizing it.')
    change_pw_parser.set_defaults(func=change_pw)

    check_pw_parser = subparsers.add_parser('check_pw',
        help="Check a user's password.",
        description='Verify that an LDAP password is correct. '
        'Currently only works with SSHA-encrypted passwords.')
    check_pw_parser.add_argument('username', nargs=1, type=str,
        help='Username to check password for')
    check_pw_parser.set_defaults(func=check_pw)

    bind_info_desc = "Print info about ezldap's connection to your server."
    bind_info_parser = subparsers.add_parser('bind_info', help=bind_info_desc,
        description=bind_info_desc + ' Note that "cleartext" refers to whether '
        'a bind was performed using SSL (using an ldaps:// URI). '
        'TLS will be started before binding if your server supports it.')
    bind_info_parser.add_argument('-a', '--anonymous-bind',
        default=False, const=True, action='store_const',
        help='Show connection info for anonymous binds.')
    bind_info_parser.set_defaults(func=bind_info)

    server_info_desc = 'Print information about the LDAP server you are using.'
    server_info_parser = subparsers.add_parser('server_info',
        help=server_info_desc, description=server_info_desc)
    server_info_parser.set_defaults(func=server_info)

    class_info_desc = 'Print information about a specific LDAP objectClass.'
    class_info_parser = subparsers.add_parser('class_info',
        help=class_info_desc, description=class_info_desc)
    class_info_parser.add_argument('objectClass', nargs=1, type=str,
        help='objectClass to investigate.')
    class_info_parser.add_argument('-n', '--no-superior',
        default=False, const=True, action='store_const',
        help='Do not display superior class info from what this objectClass is '
        'derived from.')
    class_info_parser.set_defaults(func=class_info)

    # allow using extra "--key=value"-style args not explicitly defined, and
    # pack them into argv.replacements as a string dict
    argv, argv_undefined = parser.parse_known_args()
    argv.replacements = {}
    for extra in argv_undefined:
        if not extra.startswith('--'):
            # Anything not falling into '--key=value' or '--key value' format
            # is an invalid argument and triggers program exit.
            fail('Unrecognized argument: "{}"'.format(extra))

        # split on the first '=' only, so values may themselves contain '='
        key, has_value, value = extra[2:].partition('=')
        if not has_value:
            # handle key of a '--key value' argument
            fail('argument "--{}" has no value. Use the format "--key=value" to '
                'specify additional arguments to ezldap.'.format(key))

        argv.replacements[key] = value

    # force users to configure package
    if argv.func not in [config, help_msg] and 'EZLDAP_CONFIG' not in os.environ:
        assert_config_exists()

    # bind to directory and perform subparser function
    try:
        argv.func(argv)
    except LDAPSocketOpenError:
        fail('Could not reach LDAP server at {}'.format(ezldap.config()['host']))
    except LDAPBindError:
        fail('Bind failed: invalid credentials.')
    except ezldap.LDIFTemplateError as e:
        fail(e.args[0])


def exists(path):
    '''
    Return True if path exists, after expanding a leading "~" to the user's
    home directory.
    '''
    expanded = os.path.expanduser(path)
    return os.path.exists(expanded)


def assert_config_exists():
    '''
    Exit through fail() unless the config file and every LDIF template are
    present under ~/.ezldap/.

    Checked explicitly rather than with assert statements, because asserts
    are removed when Python runs with -O and the check would silently vanish.
    '''
    required_files = (
        'config.yml',
        'add_user.ldif',
        'add_group.ldif',
        'add_to_group.ldif',
        'add_host.ldif',
    )
    if not all(exists('~/.ezldap/' + name) for name in required_files):
        fail('You must run "ezldap config" before using this script.')


def add_group(argv):
    '''
    Create an LDAP group using the LDIF template at argv.ldif[0].

    Uses argv.gid when it was supplied; otherwise asks the connection for a
    GID number via next_gidn(). Extra "--key=value" arguments are forwarded
    as keyword arguments.
    '''
    requested_gid = argv.gid
    with ezldap.auto_bind(server_info=False) as con:
        gid_number = con.next_gidn() if requested_gid is None else requested_gid
        outcome = con.add_group(
            argv.groupname[0],
            gid=gid_number,
            ldif_path=argv.ldif[0],
            **argv.replacements)
        op_summary_ldif_add(outcome)


def add_to_group(argv):
    '''
    Add the user argv.username[0] to the group argv.groupname[0] using the
    LDIF template at argv.ldif[0], then report the outcome.
    '''
    with ezldap.auto_bind(server_info=False) as con:
        outcome = con.add_to_group(
            argv.username[0],
            argv.groupname[0],
            ldif_path=argv.ldif[0],
            **argv.replacements)
        op_summary_ldif_add(outcome)


def add_user(argv):
    '''
    Create a user with a random password and add it to a group.

    When no group name is given, a group named after the user is used and is
    created first if get_group() does not find it. When a group name is given
    and get_group() does not find it, the program exits through fail().
    The generated password is printed at the end.
    '''
    user = argv.username[0]

    with ezldap.auto_bind(server_info=False) as con:
        if argv.groupname is not None:
            group = argv.groupname
            if con.get_group(group) is None:
                fail('Group does not exist.')
        else:
            # No group specified: fall back to a group named after the user,
            # creating it when it is not already there.
            group = user
            if con.get_group(group) is None:
                print('Creating LDAP group {}... '.format(group), end='')
                created = con.add_group(group, ldif_path=argv.ldif_group[0],
                    **argv.replacements)
                op_summary_ldif_add(created)

        print('Creating user {}... '.format(user), end='')
        passwd = ezldap.random_passwd()
        user_result = con.add_user(user, group, passwd,
            ldif_path=argv.ldif_user[0], **argv.replacements)
        op_summary_ldif_add(user_result)

        print('Adding {} to LDAP group {}... '.format(user, group), end='')
        membership_result = con.add_to_group(user, group,
            ldif_path=argv.ldif_add_to_group[0], **argv.replacements)
        op_summary_ldif_add(membership_result)

        print('Password: {}'.format(passwd))


def add_host(argv):
    '''
    Add a host entry using the LDIF template at argv.ldif[0].

    argv.hostname[0] may be a short or a fully-qualified hostname. A short
    name (no ".") gets its fully-qualified form built from the directory's
    base DN; a fully-qualified name has its short form taken from the text
    before the first ".".

    Exits through fail() when the hostname is empty or starts with ".",
    because no short hostname can be derived from it.
    '''
    hostname = argv.hostname[0]
    # text before the first '.', or the whole string when there is no '.'
    short_name = hostname.partition('.')[0]
    if short_name == '':
        fail('Invalid hostname: "{}"'.format(hostname))

    with ezldap.auto_bind(server_info=False) as con:
        if '.' not in hostname:
            # we're working with the short hostname
            fq_name = short_name + '.' + ezldap.dn_address(con.base_dn())
        else:
            # we're working with the fully qualified hostname
            fq_name = hostname

        res = con.add_host(short_name, argv.ip_address[0], hostname_fq=fq_name,
            ldif_path=argv.ldif[0], **argv.replacements)
        op_summary_ldif_add(res)


def add_ldif(argv):
    '''
    Read the LDIF template at argv.ldif[0] and add its contents to the
    directory.

    Template values come from the connection's conf, updated in place with
    any extra "--key=value" arguments.
    '''
    template_path = argv.ldif[0]
    with ezldap.auto_bind(server_info=False) as con:
        template_values = con.conf
        template_values.update(argv.replacements)
        entries = ezldap.ldif_read(template_path, template_values)
        op_summary_ldif_add(con.ldif_add(entries))


def modify(argv):
    '''
    Add, replace, or delete a value of one attribute on the entry argv.dn[0].

    For a delete, a value of "-" is passed on as None, which requests removal
    of all values of the attribute. Exits through assert_dn_exists() when the
    DN is not found.
    '''
    operation = argv.operation[0]
    target_dn = argv.dn[0]
    attribute = argv.attribute[0]
    value = argv.value[0]

    with ezldap.auto_bind(server_info=False) as con:
        assert_dn_exists(con, target_dn)

        if operation == 'add':
            outcome = con.modify_add(target_dn, attribute, value)
        elif operation == 'replace':
            outcome = con.modify_replace(target_dn, attribute, value,
                replace_with=argv.replace_with)
        elif operation == 'delete':
            # "-" means: delete all values
            outcome = con.modify_delete(target_dn, attribute,
                None if value == '-' else value)
        else:
            # argparse restricts the choices, so this is only a safety net.
            fail('Unrecognized command.')

        op_summary_ezldap(outcome)


def modify_ldif(argv):
    '''
    Read the LDIF template at argv.ldif[0] and apply it as a modification.

    Template values come from the connection's conf, updated in place with
    any extra "--key=value" arguments.
    '''
    template_path = argv.ldif[0]
    with ezldap.auto_bind(server_info=False) as con:
        template_values = con.conf
        template_values.update(argv.replacements)
        changes = ezldap.ldif_read(template_path, template_values)
        op_summary_ldif_add(con.ldif_modify(changes))


def separate_dn(dn):
    '''
    Split a DN into (relative DN, superior DN) at the first unescaped comma.

    A comma preceded by a backslash escape (e.g. "cn=Smith\\, John") is part
    of the value and is not treated as a separator. If the DN holds no
    separator, the superior part is an empty string.
    '''
    escaped = False
    for index, char in enumerate(dn):
        if escaped:
            # this character was escaped by the preceding backslash
            escaped = False
        elif char == '\\':
            escaped = True
        elif char == ',':
            return (dn[:index], dn[index + 1:])

    return (dn, '')


def modify_dn(argv):
    '''
    Rename and/or move the entry argv.dn[0] so that it becomes argv.new_dn[0].

    The rename (relative DN change) is done first; if the superior DN also
    differs, the entry is then moved in a second operation. Exits through
    fail() when neither part differs, and through op_summary() as soon as an
    operation fails.
    '''
    dn = argv.dn[0]
    relative_old, superior_old = separate_dn(dn)

    new_dn = argv.new_dn[0]
    relative_new, superior_new = separate_dn(new_dn)

    rename = relative_old != relative_new
    move = superior_old != superior_new
    # Compare the separated parts rather than the raw strings: "a" and "a,"
    # differ as strings yet describe no change, which would otherwise leave
    # no result to summarize below.
    if not (rename or move):
        fail('DNs are the same.')

    with ezldap.auto_bind(server_info=False) as con:
        assert_dn_exists(con, dn)

        if rename:
            res = con.modify_dn(dn, relative_new)
            if not res:
                # op_summary exits the program on failure
                op_summary(res, con.result)

            # the entry now lives under its new relative DN
            dn = relative_new + ',' + superior_old

        if move:
            res = con.modify_dn(dn, relative_new, new_superior=superior_new)

        op_summary(res, con.result)


def delete(argv):
    '''
    Delete the entry argv.dn[0].

    Unless argv.force is set, the entry is printed first and the user is asked
    to confirm. Any answer that does not start with "y" (including just
    pressing Enter) aborts the operation.
    '''
    dn = argv.dn[0]
    with ezldap.auto_bind(server_info=False) as con:
        assert_dn_exists(con, dn)
        if not argv.force:
            query = con.search_list(search_base=dn, search_scope=ldap3.BASE)
            ezldap.ldif_print(query)

            confirm = input('Delete object? (y/N) ')
            # startswith() instead of confirm[0]: an empty answer must mean
            # "no", not an IndexError.
            if not confirm.startswith('y'):
                sys.exit('Operation aborted.')

        success = con.delete(dn)
        op_summary(success, con.result)


def change_home(argv):
    '''
    Replace the homeDirectory attribute of user argv.username[0] with
    argv.homedir[0]. Exits through fail() when the user lookup does not yield
    an entry.
    '''
    username = argv.username[0]
    with ezldap.auto_bind(server_info=False) as con:
        try:
            user_dn = con.get_user(username)['dn'][0]
        except TypeError:
            # subscripting the lookup result failed: treated as "no such user"
            fail('User "{}" not found.'.format(username))

        outcome = con.modify_replace(user_dn, 'homeDirectory', argv.homedir[0])
        op_summary_ezldap(outcome)


def change_shell(argv):
    '''
    Replace the loginShell attribute of user argv.username[0] with
    argv.shell[0]. Exits through fail() when the user lookup does not yield
    an entry.
    '''
    username = argv.username[0]
    with ezldap.auto_bind(server_info=False) as con:
        try:
            user_dn = con.get_user(username)['dn'][0]
        except TypeError:
            # subscripting the lookup result failed: treated as "no such user"
            fail('User "{}" not found.'.format(username))

        outcome = con.modify_replace(user_dn, 'loginShell', argv.shell[0])
        op_summary_ezldap(outcome)


def change_pw(argv):
    '''
    Set a new userPassword for user argv.username[0].

    With argv.specify_password, the password is prompted for twice and must
    match and be non-empty; otherwise a random password is generated and
    printed after the change succeeds. Exits through fail() on a mismatch,
    an empty password, or an unknown user.
    '''
    user = argv.username[0]
    if argv.specify_password:
        passwd = getpass.getpass('New password for {}: '.format(user))
        confirm = getpass.getpass('Confirm password: ')
        if passwd != confirm:
            fail('Passwords do not match.')

        if passwd == '':
            # never store an empty password for an account
            fail('Password cannot be empty.')
    else:
        passwd = ezldap.random_passwd()

    with ezldap.auto_bind(server_info=False) as con:
        try:
            dn = con.get_user(user)['dn'][0]
        except TypeError:
            # subscripting the lookup result failed: treated as "no such user"
            fail('User "{}" does not exist.'.format(user))

        ssha = ezldap.ssha_passwd(passwd)
        res = con.modify_replace(dn, 'userPassword', ssha)
        op_summary_ezldap(res)
        if not argv.specify_password:
            print('New password for {}: {}'.format(user, passwd))


def check_pw(argv):
    '''
    Prompt for a password and check it against the stored userPassword of
    user argv.username[0] using ezldap.ssha_check().

    Exits through fail() when the user is not found, has no userPassword
    value, or the entered password does not match.
    '''
    with ezldap.auto_bind(server_info=False) as con:
        query = con.get_user(argv.username[0])
        if query is None:
            fail('User not found.')

        try:
            ssha = query['userPassword'][0]
        except (KeyError, IndexError):
            # entry exists but carries no password to compare against
            fail('User has no password set.')

    print('Enter password to verify...')
    passwd = getpass.getpass()

    if ezldap.ssha_check(ssha, passwd):
        print(fmt('Passwords match!', 'green'))
    else:
        fail("Passwords do not match.")


def config(argv):
    '''
    Interactively create or update the ezldap configuration in ~/.ezldap/.

    Starts from the existing config when ~/.ezldap/config.yml is present and
    from ezldap.guess_config() otherwise, prompts for each setting, writes
    config.yml, and copies the packaged LDIF templates. Existing templates
    are only overwritten when argv.force is set.
    '''
    config_dir = os.path.expanduser('~/.ezldap')
    config_path = os.path.join(config_dir, 'config.yml')

    if os.path.exists(config_path):
        conf = ezldap.config()
    else:
        conf = ezldap.guess_config()

    prompts = {
        'host': 'LDAP host',
        'binddn': 'Bind DN (leave blank for anonymous bind)',
        'bindpw': 'Bind password (leave blank to prompt for password)',
        'peopledn': 'User base dn',
        'groupdn': 'Group base dn',
        'hostsdn': 'Host base dn',
        'homedir': 'Default home directory for new users'
    }

    print('Configuring ezldap...')
    print('Default values are in [brackets] - to accept, press Enter.\n')
    for key, prompt in prompts.items():
        # .get(): an existing config may not contain every key prompted for
        conf[key] = interactive_conf(prompt, conf.get(key))

    print(fmt('\nWriting configs to ~/.ezldap/', 'green'))
    os.makedirs(config_dir, mode=0o700, exist_ok=True)
    # context manager so the file is flushed and closed right away
    with open(config_path, 'w') as fout:
        yaml.dump(conf, stream=fout, default_flow_style=False)

    for template in pkg_resources.resource_listdir('ezldap', 'templates'):
        template_path = os.path.join(config_dir, template)
        if argv.force or not os.path.exists(template_path):
            content = pkg_resources.resource_string('ezldap', 'templates/' + template).decode()
            with open(template_path, 'w') as fout:
                fout.write(content)

    print(fmt("Edit config.yml and the LDIF templates in ~/.ezldap/ to "
        "configure ezldap's behavior.", 'green'))


def interactive_conf(prompt, default=None):
    '''
    Prompt the user for a config value, press ENTER for default val.

    When a string default contains the text "None", that text is replaced
    with the placeholder "dc=example,dc=com" before it is shown. Returns the
    entered value, or the (possibly rewritten) default when the input is empty.
    '''
    if default is None:
        val = input(fmt('{}: '.format(prompt), bold=True))
    else:
        # Only strings can contain the "None" marker; other types (such as a
        # number loaded from YAML) are shown as they are.
        if isinstance(default, str) and 'None' in default:
            default = default.replace('None', 'dc=example,dc=com')

        val = input('{} [{}]: '.format(fmt(prompt, bold=True), default))

    return default if val == '' else val


def search(argv):
    '''
    Search the directory with the LDAP filter argv.filter[0] and print the
    matching entries as LDIF.

    The filter is wrapped in parentheses when they are missing. An empty or
    whitespace-only filter falls back to "(objectClass=*)". Exits through
    fail() when the filter is rejected as invalid.
    '''
    # Make command always wrap searches in parentheses for convenience
    # i.e. cn=someuser is 4 less characters than '(cn=someuser)',
    # and more closely mimics ldapsearch's behavior
    search_filter = argv.filter[0].strip()
    if search_filter == '':
        # nothing to wrap: match everything instead of indexing an empty string
        search_filter = '(objectClass=*)'

    if not search_filter.startswith('('):
        search_filter = '(' + search_filter

    if not search_filter.endswith(')'):
        search_filter = search_filter + ')'

    conf = ezldap.config()
    with ezldap.Connection(conf['host']) as con:
        try:
            ezldap.ldif_print(con.search_list(search_filter=search_filter))
        except LDAPInvalidFilterError:
            fail('Invalid LDAP filter.')


def search_dn(argv):
    '''
    Print every DN in the directory that contains argv.keyword.

    The match is a plain, case-sensitive substring test done on this side.
    '''
    host = ezldap.config()['host']
    with ezldap.Connection(host) as con:
        # A DN cannot be used as a search filter, so fetch the entries without
        # attributes and filter the DNs locally.
        listing = con.search_list_t(attributes=None)

    for candidate in listing['dn']:
        if argv.keyword not in candidate:
            continue

        print(candidate)


def bind_info(argv):
    '''
    Print the connection object, either for a plain Connection to the
    configured host (argv.anonymous_bind) or for ezldap.auto_bind().
    '''
    if argv.anonymous_bind:
        connection = ezldap.Connection(ezldap.config()['host'])
    else:
        connection = ezldap.auto_bind(server_info=False)

    with connection as con:
        print(con)


def server_info(argv):
    '''
    Connect to the configured host and print the connection's server.info
    without a trailing newline.
    '''
    host = ezldap.config()['host']
    with ezldap.Connection(host) as con:
        info = con.server.info
        print(info, end='')


def class_info(argv):
    '''
    Print the schema definition of objectClass argv.objectClass[0].

    Unless argv.no_superior is set, the chain of superior classes is followed
    and each of them is printed as well. Exits through fail() when a schema
    lookup raises KeyError.
    '''
    requested = argv.objectClass[0]
    host = ezldap.config()['host']
    with ezldap.Connection(host) as con:
        try:
            known_classes = con.server.schema.object_classes
            current = known_classes[requested]
            chain = [current]
            # follow the first superior of each class until there is none
            while not argv.no_superior and current.superior is not None:
                current = known_classes[current.superior[0]]
                chain.append(current)

            for entry in chain:
                print(entry)

        except KeyError:
            fail('objectClass "{}" not found.'.format(requested))


def assert_dn_exists(con, dn):
    '''
    Exit through fail() unless con.exists(dn) reports the DN as present.

    A DN that is rejected as invalid is reported the same way as one that is
    not found.
    '''
    try:
        found = con.exists(dn)
    except LDAPInvalidDnError:
        found = False

    if not found:
        fail('DN "{}" not found.'.format(dn))


def abort(msg):
    '''
    Like a failure, but yellow text.

    A blank (empty or whitespace-only) msg is replaced with
    'Operation aborted.' before exiting through sys.exit().
    '''
    # strip() must be called: comparing the bound method itself to '' is
    # always False, so the default message would never be used.
    if msg.strip() == '':
        msg = 'Operation aborted.'

    sys.exit(fmt(msg, 'yellow'))


def fail(msg):
    '''
    Stop the program with a failure message in red text.

    A blank (empty or whitespace-only) msg is replaced with
    'Operation failed.' before exiting through sys.exit().
    '''
    text = msg if msg.strip() != '' else 'Operation failed.'
    sys.exit(fmt(text, 'red'))


def op_summary(success, result):
    '''
    Report the outcome of an operation.

    Prints a green success message when success is truthy. Otherwise exits
    through fail() with result['message'], or with result['description'] when
    the message is an empty string.
    '''
    if success:
        print(fmt('Success!', 'green'))
        return

    # any operation that fails *should* exit; the description is the fallback
    # for results that carry no error message
    has_message = result['message'] != ''
    fail(result['message'] if has_message else result['description'])


def op_summary_ezldap(result):
    '''
    Summarize a single result mapping; a 'result' code of 0 counts as success.
    '''
    succeeded = result['result'] == 0
    op_summary(succeeded, result)


def op_summary_ldif_add(result):
    '''
    Summarize the first item of a sequence of results; a 'result' code of 0
    counts as success.
    '''
    first = result[0]
    succeeded = first['result'] == 0
    op_summary(succeeded, first)


# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == '__main__':
    main()
