#!/usr/bin/python
import argparse
import base64
import hashlib
import os
import sys

import ldap
from ldif import LDIFWriter
from retrying import retry


def make_secret(password):
    """Encode the given password as a salted SHA-1 ``{SSHA}`` value.

    The result is ``{SSHA}`` followed by base64(sha1(password + salt) + salt)
    with a random 4-byte salt.

    :param password: the plaintext password (text or bytes), or a value that
        already carries the ``{SSHA}`` tag.
    :returns: the tagged hash as a native string; a value that is already
        tagged is returned unchanged.
    """
    if isinstance(password, bytes):
        if password.startswith(b'{SSHA}'):
            return password
        raw_password = password
    else:
        if password.startswith('{SSHA}'):
            return password
        # Hash functions operate on bytes; encode text explicitly so that
        # non-ASCII passwords do not fail on an implicit ASCII conversion.
        raw_password = password.encode('utf-8')

    salt = os.urandom(4)

    # hash the password followed by the salt
    sha = hashlib.sha1(raw_password)
    sha.update(salt)

    # base64 of the concatenated digest + salt; b64encode never inserts
    # line breaks, unlike the legacy 'base64' string codec
    digest_salt = base64.b64encode(sha.digest() + salt)

    # now tag the digest above with the {SSHA} tag
    tagged_digest_salt = b'{SSHA}' + digest_salt
    if not isinstance(tagged_digest_salt, str):
        # bytes and str differ on this interpreter: hand back native text
        tagged_digest_salt = tagged_digest_salt.decode('ascii')

    return tagged_digest_salt


def user_add(b, username, cn, sn, givenName, password, shell, groups):
    """Add a user, plus a private group of the same name, to the LDAP.

    :param b: base DN under which ou=People and ou=Group live.
    :param username: login name; used as uid and as the private group's cn.
    :param cn: common name; falls back to ``username`` when empty.
    :param sn: surname; falls back to ``username`` when empty.
    :param givenName: optional given name; only stored when truthy.
    :param password: optional password; passed through make_secret before
        being stored as userPassword.
    :param shell: value for loginShell.
    :param groups: optional iterable of additional group names to join.
    """
    # The freshly allocated number is used as both uidNumber and gidNumber.
    uidNumber = increment_uid(b)

    # first add the group
    dn = 'cn=%s,ou=Group,%s' % (username, b)
    add_record = [
        ('objectclass', ['top', 'posixGroup']),
        ('cn', [username]),
        # memberuid holds login names: group_addusers adds names and
        # user_delete searches on (memberuid=<username>), so store the
        # username here rather than the numeric id.
        ('memberuid', [username]),
        ('gidNumber', [uidNumber]),
    ]
    conn.add_s(dn, add_record)

    # now add the user
    dn = 'uid=%s,ou=People,%s' % (username, b)

    if not cn:
        cn = username
    if not sn:
        sn = username

    add_record = [
        ('objectclass', ['top', 'person', 'organizationalPerson',
                         'inetOrgPerson', 'posixAccount', 'shadowAccount']),
        ('uid', [username]),
        ('cn', [cn]),
        ('sn', [sn]),
        ('loginShell', [shell]),
        ('uidNumber', [uidNumber]),
        ('gidNumber', [uidNumber]),
        ('homeDirectory', ['/home/%s' % username]),
    ]
    if givenName:
        add_record.append(('givenName', [givenName]))
    if password:
        password = make_secret(password)
        add_record.append(('userPassword', [password]))

    conn.add_s(dn, add_record)

    # finally join any additional groups that were requested
    if groups:
        for group in groups:
            group_addusers(b, group, [username])


def user_delete(b, username):
    """Delete a user, its group memberships and its private group.

    Every step is best effort: an error in one step is printed and the
    remaining steps are still attempted.

    :param b: base DN under which ou=People and ou=Group live.
    :param username: login name of the user to remove.
    """
    # First delete the user
    try:
        dn = 'uid=%s,ou=People,%s' % (username, b)
        conn.delete_s(dn)
    except Exception as e:
        print(e)

    # Then drop the user from every group that lists it as a member
    base_dn = 'ou=Group,%s' % b
    search_filter = '(memberuid=%s)' % username
    attrs = ['']
    groups = conn.search_s(base_dn, ldap.SCOPE_SUBTREE, search_filter, attrs)

    for group_dn, _ in groups:
        # Handle failures per group so that one bad group does not leave
        # the user listed in all the groups that come after it.
        try:
            mod_attrs = [(ldap.MOD_DELETE, 'memberuid', username)]
            conn.modify_s(group_dn, mod_attrs)
        except Exception as e:
            print(e)

    # Next delete the group
    try:
        dn = 'cn=%s,ou=Group,%s' % (username, b)
        conn.delete_s(dn)
    except Exception as e:
        print(e)


def user_list(b):
    """Print the uid of every person entry found under ou=People."""
    people_dn = 'ou=People,%s' % b
    entries = conn.search_s(people_dn, ldap.SCOPE_SUBTREE,
                            '(objectclass=person)', ['uid'])
    for _entry_dn, entry in entries:
        # only the first uid value of each entry is shown
        print(entry['uid'][0])


def user_modify(b, username, **kwargs):
    """Modify attributes of an existing user.

    :param b: base DN under which ou=People lives.
    :param username: login name of the user to modify.
    :param kwargs: attribute name -> new value. ``shell`` is stored as
        ``loginShell``; entries with a falsy value are left untouched.
    """
    dn = 'uid=%s,ou=People,%s' % (username, b)

    mod_attrs = []
    for key, value in kwargs.items():
        if not value:
            continue
        if key == 'shell':
            key = 'loginShell'
        # MOD_REPLACE sets the value whether or not the attribute is
        # already present; a delete-then-add pair fails when it is absent.
        mod_attrs.append((ldap.MOD_REPLACE, key, value))

    if not mod_attrs:
        # nothing to change, so do not send an empty modification
        return
    conn.modify_s(dn, mod_attrs)


def user_reset(b, username, password):
    """Reset a user's password.

    :param b: base DN under which ou=People lives.
    :param username: login name of the user whose password is reset.
    :param password: the new password; passed through make_secret.
    :raises ValueError: if no password is given.
    """
    if password is None:
        # fail with a clear message instead of an AttributeError from
        # inside make_secret
        raise ValueError('a new password is required to reset the '
                         'password of %s' % username)

    secret = make_secret(password)
    dn = 'uid=%s,ou=People,%s' % (username, b)
    # NOTE(review): passwd_s receives the already {SSHA}-tagged value;
    # whether the server stores it verbatim or hashes it again presumably
    # depends on the server's password settings - confirm.
    conn.passwd_s(dn, None, secret)


def user_show(b, username):
    """Write every entry under ou=People matching the uid to stdout as LDIF."""
    people_dn = 'ou=People,%s' % b
    uid_filter = '(uid=%s)' % username
    ldif_out = LDIFWriter(sys.stdout)
    matches = conn.search_s(people_dn, ldap.SCOPE_SUBTREE, uid_filter)
    for entry_dn, entry in matches:
        ldif_out.unparse(entry_dn, entry)


def user_uidNumber(b, username):
    """Look up the numeric id that belongs to a username.

    Returns the first uidNumber value of the first matching entry, or
    None when the search yields no entries.
    """
    people_dn = 'ou=People,%s' % b
    uid_filter = '(uid=%s)' % username
    matches = conn.search_s(people_dn, ldap.SCOPE_SUBTREE, uid_filter,
                            ['uidNumber'])
    for _entry_dn, entry in matches:
        # only the first hit is of interest
        return entry['uidNumber'][0]
    return None


def group_list(b):
    """Print the gidNumber and cn of every posixGroup under ou=Group."""
    groups_dn = 'ou=Group,%s' % b
    entries = conn.search_s(groups_dn, ldap.SCOPE_SUBTREE,
                            '(objectclass=posixGroup)')
    for _entry_dn, entry in entries:
        gid = entry['gidNumber'][0]
        name = entry['cn'][0]
        print(gid, name)


def group_delete(b, groupname):
    """Delete a group from the system.

    Best effort: an error raised while deleting is printed rather than
    propagated.
    """
    try:
        group_dn = 'cn=%s,ou=Group,%s' % (groupname, b)
        conn.delete_s(group_dn)
    except Exception as err:
        print(err)


def group_add(b, groupname):
    """Create a posixGroup entry with a newly allocated gidNumber."""
    new_gid = increment_gid(b)

    group_dn = 'cn=%s,ou=Group,%s' % (groupname, b)
    conn.add_s(group_dn, [
        ('objectclass', ['top', 'posixGroup']),
        ('cn', [groupname]),
        ('gidNumber', [new_gid]),
    ])


def group_addusers(b, groupname, username):
    """Add each of the given usernames to a group as a memberuid value.

    ``username`` is an iterable of names. Every name is sent as its own
    modification, so a failure for one name is reported and the remaining
    names are still processed.
    """
    group_dn = 'cn=%s,ou=Group,%s' % (groupname, b)
    for member in username:
        try:
            change = [(ldap.MOD_ADD, 'memberuid', member)]
            conn.modify_s(group_dn, change)
        except Exception as error:
            print("Error adding %s to %s: %s" % (member, groupname, error))


def group_delusers(b, groupname, username):
    """Remove the given usernames from a group in a single modification.

    ``username`` is an iterable of names; all memberuid deletions are sent
    to the server in one modify call.
    """
    group_dn = 'cn=%s,ou=Group,%s' % (groupname, b)
    removals = [(ldap.MOD_DELETE, 'memberuid', member) for member in username]
    conn.modify_s(group_dn, removals)


def group_show(b, groupname):
    """Write every entry under ou=Group matching the cn to stdout as LDIF."""
    groups_dn = 'ou=Group,%s' % b
    cn_filter = '(cn=%s)' % groupname
    ldif_out = LDIFWriter(sys.stdout)
    matches = conn.search_s(groups_dn, ldap.SCOPE_SUBTREE, cn_filter)
    for entry_dn, entry in matches:
        ldif_out.unparse(entry_dn, entry)


@retry(stop_max_delay=10000)
def increment_uid(b):
    """Hand out the next user id and advance the stored counter.

    The counter is the uidNumber attribute of the ``cn=uid`` entry below
    the base DN. The value read is returned; the entry is updated to that
    value plus one.
    """
    counter_dn = 'cn=uid,%s' % b
    found = conn.search_s(counter_dn, ldap.SCOPE_SUBTREE,
                          'objectclass=*', ['uidNumber'])
    current = found[0][1]['uidNumber'][0]
    following = str(int(current) + 1)

    # Delete the exact value that was read and add its successor in one
    # modify call - presumably so that a concurrent update makes the call
    # fail and the retry decorator re-reads the counter.
    conn.modify_s(counter_dn, [
        (ldap.MOD_DELETE, 'uidNumber', current),
        (ldap.MOD_ADD, 'uidNumber', following),
    ])
    return current


@retry(stop_max_delay=10000)
def increment_gid(b):
    """Hand out the next group id and advance the stored counter.

    The counter is the uidNumber attribute of the ``cn=gid`` entry below
    the base DN. The value read is returned; the entry is updated to that
    value plus one. Any error is printed and then re-raised.
    """
    counter_dn = 'cn=gid,%s' % b

    try:
        found = conn.search_s(counter_dn, ldap.SCOPE_SUBTREE,
                              'objectclass=*', ['uidNumber'])
        current = found[0][1]['uidNumber'][0]
        following = str(int(current) + 1)

        # swap the value that was read for its successor in one modify call
        conn.modify_s(counter_dn, [
            (ldap.MOD_DELETE, 'uidNumber', current),
            (ldap.MOD_ADD, 'uidNumber', following),
        ])
    except Exception as err:
        print(err)
        raise
    return current


def csep(s):
    """Split a comma separated string into a list of names.

    Whitespace around each name is stripped and empty entries are dropped,
    so ``"a, b,"`` yields ``['a', 'b']``.

    :raises argparse.ArgumentTypeError: if ``s`` is not a string.
    """
    try:
        parts = s.split(',')
    except AttributeError:
        # not a string: report it as a bad argument value to argparse
        raise argparse.ArgumentTypeError("Illegal groups value")
    stripped = (part.strip() for part in parts)
    return [part for part in stripped if part]


# Command line definition: global connection options followed by a
# "<target> <command>" pair, e.g. "user add" or "group delusers".
parser = argparse.ArgumentParser(prog='obol',
                                 description='useradd for ldap.')

parser.add_argument('-D', metavar="BIND DN", default='cn=Manager,dc=local')
parser.add_argument('-w', metavar="BIND PASSWORD", required=True)
parser.add_argument('-H', metavar="HOST", default='ldap://localhost')
parser.add_argument('-b', metavar="BASE_DN", default='dc=local')
subparsers = parser.add_subparsers(help='commands', dest='target')

users = subparsers.add_parser('user', help='user commands')
user_commands = users.add_subparsers(dest='command')

groups = subparsers.add_parser('group', help='group commands')
group_commands = groups.add_subparsers(dest='command')

# Do the user commands
command = user_commands.add_parser('add', help='add a user')
command.add_argument('username')
command.add_argument('--password')
command.add_argument('--cn', metavar="COMMON NAME")
command.add_argument('--sn', metavar="SURNAME")
command.add_argument('--givenName')
command.add_argument('--shell', default='/bin/bash')
command.add_argument('--groups', type=csep,
                     help="a comma separated list of groups")

command = user_commands.add_parser('delete', help='delete a user')
command.add_argument('username')

command = user_commands.add_parser('show', help='show user details')
command.add_argument('username')

command = user_commands.add_parser('reset', help='reset user password')
command.add_argument('username')
# user_reset always hashes the password, so it has to be supplied
command.add_argument('--password', required=True)

command = user_commands.add_parser('modify', help='modify a user attribute')
command.add_argument('username')
command.add_argument('--cn', metavar="COMMON NAME")
command.add_argument('--sn', metavar="SURNAME")
command.add_argument('--givenName')
# No default here: user_modify skips options that were not given, whereas
# a default would overwrite the stored loginShell on every modify.
command.add_argument('--shell')

command = user_commands.add_parser('list', help='list users')

# Now do the group commands
command = group_commands.add_parser('add', help='add a group')
command.add_argument('groupname')

command = group_commands.add_parser('show', help='show a group')
command.add_argument('groupname')

command = group_commands.add_parser('addusers', help='add users to a group')
command.add_argument('groupname')
command.add_argument('username', nargs='+')

command = group_commands.add_parser('delete', help='delete a group')
command.add_argument('groupname')

command = group_commands.add_parser('delusers',
                                    help='delete users from a group')
command.add_argument('groupname')
command.add_argument('username', nargs='+')

command = group_commands.add_parser('list', help='list groups')


if __name__ == '__main__':
    # Entry point: parse the command line, bind to the server and dispatch
    # to the module-level function named "<target>_<command>".
    args = parser.parse_args()
    target = args.target
    command = args.command

    # conn is the module-level connection used by every command function
    conn = ldap.initialize(args.H)
    try:
        conn.simple_bind_s(args.D, args.w)
        try:
            ldapuser = sys.modules[__name__]
            fun = getattr(ldapuser, '%s_%s' % (target, command))
            # Everything left after removing the dispatch and connection
            # options is passed to the command function as keywords.
            args_d = vars(args)
            for key in ('target', 'command', 'D', 'w', 'H'):
                del args_d[key]
            fun(**args_d)
        except Exception as e:
            print(e)
            raise
    finally:
        # always release the connection, also when the command failed
        conn.unbind_s()
