#!/usr/bin/env python3

import sys, getpass
from ldaptor.protocols.ldap import (
    ldapclient,
    distinguishedname,
    ldapconnector,
    ldapsyntax,
)
from ldaptor import usage, generate_password
from twisted.internet import defer, reactor, protocol


class PasswdClient(ldapclient.LDAPClient):
    """LDAP client that binds and then sets a new password on each target DN.

    Settings are read from the owning factory: ``binddn``, ``bindPassword``,
    ``dnlist`` and ``generatePasswords``.  Failures are reported on stderr
    and recorded in the module-level ``exitStatus``.  The reactor is stopped
    once every DN has been processed, or as soon as the bind fails.
    """

    def connectionMade(self):
        """Start the bind as soon as the connection is established."""
        # Chain to the base class so its own connection bookkeeping runs
        # before any request is sent.
        # NOTE(review): ldaptor's LDAPClient is understood to flag itself as
        # connected here and to refuse requests otherwise -- confirm.
        super().connectionMade()
        # maybeDeferred converts an exception raised while prompting (for
        # example EOFError from getpass on a closed stdin) into a failure, so
        # _handle_bind_fail still runs and the reactor gets stopped.
        d = defer.maybeDeferred(self._bind)
        d.addCallbacks(self._handle_bind_success, self._handle_bind_fail)

    def _bind(self):
        """Send the bind request; prompt for the password when none was given.

        Binds without arguments when the factory has no bind DN.  Returns
        the value returned by ``self.bind``.
        """
        binddn = self.factory.binddn
        if not binddn:
            return self.bind()
        pwd = self.factory.bindPassword
        if pwd is None:
            pwd = getpass.getpass("Password for %s: " % binddn)
        return self.bind(binddn, pwd)

    def _report_ldap_error(self, fail):
        """Print the failure on stderr, flag a failing exit status, pass it on."""
        global exitStatus
        print("fail:", fail.getErrorMessage(), file=sys.stderr)
        exitStatus = 1
        return fail

    def _note_failure(self, fail):
        """Flag a failing exit status for any failure and pass it on unchanged."""
        global exitStatus
        exitStatus = 1
        return fail

    def getPassword(self, dn):
        """Return a Deferred firing with a one-element sequence holding the new password.

        The password is read interactively unless the factory asks for
        generated passwords, in which case ``generate_password.generate`` is
        used.
        """
        if not self.factory.generatePasswords:
            pwd = getpass.getpass("NEW Password for %s: " % dn)
            return defer.succeed((pwd,))
        else:
            return generate_password.generate(reactor)

    def _handle_bind_success(self, x):
        """Change the password of every DN, then disconnect and stop the reactor."""
        pending = []
        for dn in self.factory.dnlist:
            # maybeDeferred keeps a synchronous exception from the prompt
            # inside this DN's chain instead of aborting the whole loop and
            # leaving the reactor running forever.
            d = defer.maybeDeferred(self.getPassword, dn)
            d.addCallbacks(
                callback=self._got_password,
                callbackArgs=(dn,),
                errback=self._report_pwgen_error,
            )
            # Whatever went wrong for this DN (password generation, LDAP
            # error or an unexpected exception) must yield a non-zero exit.
            d.addErrback(self._note_failure)
            pending.append(d)
        dl = defer.DeferredList(pending)
        dl.addBoth(lambda _: self.transport.loseConnection())
        dl.addErrback(defer.logError)
        dl.addBoth(lambda _: reactor.stop())
        return x

    def _handle_bind_fail(self, fail):
        """Report the bind failure, drop the connection and stop the reactor."""
        self._report_ldap_error(fail)
        self.transport.loseConnection()
        reactor.stop()

    def _report_new_password(self, dummy, dn, password):
        """Print ``dn`` and its new password when passwords are generated."""
        if self.factory.generatePasswords:
            print(dn, password)

    def _got_password(self, password, dn):
        """Set the received password on ``dn``.

        ``password`` must be a sequence containing exactly one password.
        Returns the Deferred of the set-password operation.
        """
        # Unpacking enforces "exactly one element" even under ``python -O``,
        # where an assert would be stripped.
        (password,) = password
        o = ldapsyntax.LDAPEntry(client=self, dn=dn)
        d = o.setPassword(newPasswd=password)
        d.addCallbacks(
            callback=self._report_new_password,
            callbackArgs=(dn, password),
            errback=self._report_ldap_error,
        )
        return d

    def _report_pwgen_error(self, fail):
        """Report a password-generation failure on stderr and pass it on.

        Failures that are not ``PwgenException`` are re-raised by ``trap``
        without being printed here.
        """
        fail.trap(generate_password.PwgenException)
        print("pwgen:", fail.getErrorMessage(), file=sys.stderr)
        return fail


class PasswdClientFactory(protocol.ClientFactory):
    """Client factory carrying the settings that ``PasswdClient`` reads."""

    protocol = PasswdClient

    def __init__(self, binddn, bindPassword=None, dnlist=(), generatePasswords=0):
        """Store the bind credentials, the target DNs and the generate switch."""
        # What to change, and how the new passwords are obtained.
        self.dnlist = dnlist
        self.generatePasswords = generatePasswords
        # Who to bind as.
        self.binddn, self.bindPassword = binddn, bindPassword


# Process exit code passed to sys.exit() after the reactor stops; set to 1 by
# PasswdClient._report_ldap_error when an LDAP operation fails.
exitStatus = 0


class MyOptions(
    usage.Options, usage.Options_service_location, usage.Options_bind_mandatory
):
    """LDAPtor command line password change utility"""

    # Usage line; the program name comes from argv.
    synopsis = "Usage: %s --binddn=DN [OPTION..] [DN..]" % sys.argv[0]

    # The single boolean switch understood by this tool.
    optFlags = [("generate", None, "Generate random passwords")]

    def parseArgs(self, *dnlist):
        """Store the positional DNs under ``opts["dnlist"]``.

        When no DN is given on the command line, the bind DN itself becomes
        the only target.
        """
        targets = dnlist if dnlist else (self.opts["binddn"],)
        self.opts["dnlist"] = targets


if __name__ == "__main__":
    # Script entry point: parse the command line, optionally read the bind
    # password from a file descriptor, connect, run the reactor and exit
    # with the status recorded in ``exitStatus``.
    import os
    import sys
    from twisted.python import log

    log.startLogging(sys.stderr, setStdout=0)

    try:
        config = MyOptions()
        config.parseOptions()
    except usage.UsageError as ue:
        print("%s:" % sys.argv[0], ue, file=sys.stderr)
        sys.exit(1)

    bindPassword = None
    auth_fd = config.opts["bind-auth-fd"]
    # Compare against None instead of relying on truthiness: descriptor 0
    # (stdin) is a legitimate source and must not be skipped.
    # NOTE(review): assumes the option is None when not given and an integer
    # descriptor otherwise -- confirm against ldaptor.usage.
    if auth_fd is not None:
        # The context manager closes the descriptor even if reading fails.
        with os.fdopen(auth_fd) as f:
            line = f.readline()
        # The password line must be newline-terminated.  An explicit check
        # also covers an empty read and survives ``python -O``, unlike the
        # assert it replaces.
        if not line.endswith("\n"):
            print(
                "%s: bind password read from file descriptor %s "
                "must be terminated by a newline" % (sys.argv[0], auth_fd),
                file=sys.stderr,
            )
            sys.exit(1)
        bindPassword = line[:-1]

    s = PasswdClientFactory(
        dnlist=config.opts["dnlist"],
        binddn=config.opts["binddn"],
        bindPassword=bindPassword,
        generatePasswords=config.opts["generate"],
    )
    dn = distinguishedname.DistinguishedName(stringValue=config.opts["binddn"])
    c = ldapconnector.LDAPConnector(
        reactor, dn, s, overrides=config.opts["service-location"]
    )
    c.connect()
    reactor.run()
    sys.exit(exitStatus)
