#!/usr/bin/env python
# Copyright European Organization for Nuclear Research (CERN)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# You may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#                       http://www.apache.org/licenses/LICENSE-2.0
#
# Authors:
# - Vincent Garonne, <vincent.garonne@cern.ch>, 2012-2013
# - Cedric Serfon, <cedric.serfon@cern.ch>, 2014-2018
# - David Cameron, <david.cameron@cern.ch>, 2015
# - Dimitrios Christidis, <dimitrios.christidis@cern.ch>, 2019
#
# PY3K COMPATIBLE

from __future__ import print_function

import os
import sys
import time

import ldap  # pylint: disable=import-error

from rucio.db.sqla.session import get_session
from rucio.client import Client
from rucio.common.config import config_get
from rucio.common.exception import Duplicate

from VOMSAdmin.VOMSCommands import VOMSAdminProxy  # pylint: disable=import-error


# Exit statuses handed to sys.exit() by the script section at the bottom of
# this file (the numbering follows the Nagios plugin return-code convention).
UNKNOWN = 3
CRITICAL = 2
WARNING = 1
OK = 0

# LDAP server URIs; get_ldap_identities() comma-joins them into the single
# string given to ldap.initialize().
LDAP_HOSTS = ['ldaps://xldap.cern.ch']
# (option, value) pairs applied globally with ldap.set_option() before the
# connection is opened.
LDAP_OPTIONS = [
    # default configuration comes from /etc/openldap/ldap.conf
    # (ldap.OPT_X_TLS_CACERTDIR, '/etc/pki/tls/certs'),
    # (ldap.OPT_X_TLS_CACERTFILE, '/etc/pki/tls/certs/ca-bundle.crt'),
    # (ldap.OPT_X_TLS_REQUIRE_CERT, ldap.OPT_X_TLS_NEVER),
    (ldap.OPT_X_TLS_REQUIRE_CERT, ldap.OPT_X_TLS_DEMAND),
    (ldap.OPT_REFERRALS, 0),
]
# Search base and scope: the whole subtree below the CERN users OU.
LDAP_BASE = 'OU=Users,OU=Organic Units,DC=cern,DC=ch'
LDAP_SCOPE = ldap.SCOPE_SUBTREE
# Select user objects that are not members of the cern-accounts-service group.
LDAP_FILTER = '(&(objectClass=user)(!(memberOf=CN=cern-accounts-service,OU=e-groups,OU=Workgroups,DC=cern,DC=ch)))'
# Attributes requested for every entry; these are the only ones read by
# get_ldap_identities().
LDAP_ATTRS = ['cn', 'mail', 'proxyAddresses', 'cernExternalMail']
# Number of entries requested per page of the paged search.
LDAP_PAGE_SIZE = 1000


def get_accounts_identities():
    """
    Map each X509 identity of a USER account to its account name and email.

    The mapping is read from the atlas_rucio.accounts and
    atlas_rucio.account_map tables through a database session.

    :returns: Dictionary {identity: (account, email)}, or None when the
              query fails (the error is printed, not raised).
    """
    db_session = get_session()
    statement = '''select b.identity, a.account, a.email from atlas_rucio.accounts a,  atlas_rucio.account_map b where a.account=b.account and identity_type='X509' and account_type='USER' '''
    try:
        rows = db_session.execute(statement)
        return {identity: (name, mail) for identity, name, mail in rows}
    except Exception as error:
        print(error)
    return None


def _ldap_text(value):
    """
    Return an LDAP attribute value as text.

    python-ldap on Python 3 hands attribute values back as bytes, which can
    neither be compared with the str account names and emails used by the
    rest of this script nor be tested with str prefixes. On Python 2 bytes
    is str, so values pass through untouched.

    :param value: Attribute value as returned by python-ldap.
    :returns: The value decoded as UTF-8 if it was bytes, else unchanged.
    """
    if isinstance(value, bytes) and not isinstance(value, str):
        return value.decode('utf-8', 'replace')
    return value


def get_ldap_identities():
    """
    Get user account info from CERN AD/LDAP.

    Runs a paged subtree search (LDAP_BASE / LDAP_FILTER, LDAP_PAGE_SIZE
    entries per page) and collects the mail information of every entry.

    :returns: Dictionary keyed by the entry's cn. Each value is a dictionary
              holding 'mails' (list of distinct addresses, compared
              case-insensitively, taken from the smtp: proxyAddresses first
              and then from mail), plus 'mail' and 'cernExternalMail' when
              the entry has those attributes.
    :raises ldap.LDAPError: If the bind or the search fails.
    """
    for opt_key, opt_val in LDAP_OPTIONS:
        ldap.set_option(opt_key, opt_val)

    conn = ldap.initialize(",".join(LDAP_HOSTS))
    try:
        conn.simple_bind_s()

        # python-ldap changed the SimplePagedResultsControl API in 2.4.
        old_paged_search = [int(x) for x in ldap.__version__.split('.')] < [2, 4, 0]
        if old_paged_search:
            page_ctrl = ldap.controls.SimplePagedResultsControl(ldap.LDAP_CONTROL_PAGE_OID, True, (LDAP_PAGE_SIZE, ''))
        else:
            page_ctrl = ldap.controls.SimplePagedResultsControl(True, size=LDAP_PAGE_SIZE, cookie='')
        paged_serverctrls = [page_ctrl]

        ret = {}
        while True:
            msgid = conn.search_ext(LDAP_BASE, LDAP_SCOPE, filterstr=LDAP_FILTER, attrlist=LDAP_ATTRS, serverctrls=paged_serverctrls)
            _rtype, rdata, _rmsgid, serverctrls = conn.result3(msgid=msgid)
            for dn, attrs in rdata:
                # Search references come back with dn None and a list instead
                # of an attribute dictionary; entries without cn cannot be
                # keyed. Neither describes a usable account.
                if dn is None or not isinstance(attrs, dict) or not attrs.get('cn'):
                    continue

                cn = _ldap_text(attrs['cn'][0])
                user = {
                    'mails': [],
                }

                for attr in ('mail', 'cernExternalMail'):
                    if attrs.get(attr):
                        user[attr] = _ldap_text(attrs[attr][0])

                # Candidate addresses in priority order: SMTP proxy addresses
                # first, then the plain mail attribute.
                candidates = []
                for pmail in attrs.get('proxyAddresses', []):
                    pmail = _ldap_text(pmail)
                    if pmail.lower().startswith('smtp:'):
                        candidates.append(pmail[len('smtp:'):])
                candidates.extend(_ldap_text(mail) for mail in attrs.get('mail', []))

                # Keep the first spelling of each address, ignoring case.
                seen = set()
                for mail in candidates:
                    lowered = mail.lower()
                    if lowered not in seen:
                        seen.add(lowered)
                        user['mails'].append(mail)

                ret[cn] = user

            # Find the paging control in the response; an empty cookie means
            # the server has no further page.
            cookie = None
            for serverctrl in serverctrls:
                if old_paged_search:
                    if serverctrl.controlType == ldap.LDAP_CONTROL_PAGE_OID:
                        _unused_est, cookie = serverctrl.controlValue
                        if cookie:
                            serverctrl.controlValue = (LDAP_PAGE_SIZE, cookie)
                        break
                else:
                    if serverctrl.controlType == ldap.controls.SimplePagedResultsControl.controlType:
                        cookie = serverctrl.cookie
                        if cookie:
                            serverctrl.size = LDAP_PAGE_SIZE
                        break

            if not cookie:
                break

            # The response controls, now carrying the cookie, become the
            # request controls of the next page.
            paged_serverctrls = serverctrls
        return ret
    finally:
        # Always release the connection; a failure while closing must not
        # hide the search result or the original error.
        try:
            conn.unbind_s()
        except ldap.LDAPError:
            pass


if __name__ == '__main__':
    # Probe entry point. Steps:
    #   1. point X509_USER_PROXY at the proxy configured in rucio.cfg;
    #   2. align the email of every Rucio account with CERN LDAP;
    #   3. for every VOMS user, make sure the Rucio identity, account and
    #      'user.<account>' scope exist.
    # The process exits with OK, WARNING or CRITICAL.
    try:
        PROXY = config_get('nagios', 'proxy')
        os.environ["X509_USER_PROXY"] = PROXY
        # The same proxy path is handed to VOMS as both certificate and key.
        CERT, KEY = PROXY, PROXY
    except Exception as error:
        print("Failed to get proxy from rucio.cfg")
        print(error)
        sys.exit(CRITICAL)

    status = OK
    nbusers = 0
    nonicknames = []
    client = Client()
    admin = VOMSAdminProxy(vo='atlas', host='voms2.cern.ch', port=8443,
                           user_cert=CERT, user_key=KEY)
    res = admin.call_method('list-users')
    accounts_email = {account['account']: account['email'] for account in client.list_accounts()}
    # Sets give O(1) membership tests and can be kept current as accounts
    # and scopes get created below.
    accounts = set(accounts_email)
    scopes = set(client.list_scopes())
    dns = get_accounts_identities()
    if dns is None:
        # The database query failed (the error was already printed). Without
        # the DN map every VOMS user would look unknown.
        print('Failed to get the account identities from the database')
        sys.exit(CRITICAL)
    ldap_accounts = get_ldap_identities()

    # Step 2: replace Rucio emails that LDAP does not know for the account.
    for rucio_account, rucio_email in accounts_email.items():
        if rucio_account not in ldap_accounts:
            print('%s might not be in ATLAS anymore' % rucio_account)
            continue
        ldap_entry = ldap_accounts[rucio_account]
        if rucio_email and rucio_email.lower() not in [mail.lower() for mail in ldap_entry['mails']]:
            print('Bad email for %s : %s vs %s' % (rucio_account, rucio_email, ldap_entry))
            try:
                # Entries without a 'mail' attribute raise KeyError here,
                # which is reported like any other update failure.
                client.update_account(account=rucio_account, key='email', value=ldap_entry['mail'])
            except Exception as error:
                print(error)

    # Step 3: synchronise the VOMS users.
    if not isinstance(res, list):
        print('Failed to list the users from VOMS')
        sys.exit(CRITICAL)

    voms_identities = {}
    for user in res:
        dn = user._DN
        account = None
        scope = None
        if dn in dns:
            # DN already mapped to an account: only the scope may be missing.
            account = dns[dn][0]
            scope = 'user.' + account
        else:
            nbusers += 1
            totattemps = 3
            for attempts in range(0, totattemps):
                if attempts == totattemps - 1:
                    # Every earlier pass ended without a 'break': report it.
                    try:
                        print('ERROR getting info for %s' % (user._DN))
                    except UnicodeEncodeError:
                        print('ERROR getting info for %s' % (repr(user._DN)))
                    status = WARNING
                    break
                try:
                    ca = user._CA
                    email = user._mail
                    result = admin.call_method('list-user-attributes', dn, ca)
                    if result is None:
                        print("Failed to list-user-attributes for dn: %s" % dn)
                        continue
                    nickname = None
                    try:
                        nickname = result[0]._value
                    except TypeError as error:
                        print('ERROR : Failed to process DN: %s' % dn)
                    if not nickname:
                        # NOTE(review): this case is retried and ends in a
                        # WARNING, exactly as before; confirm whether a user
                        # without nickname should really degrade the status.
                        if user._DN not in nonicknames:
                            nonicknames.append(user._DN)
                        continue

                    if nickname in accounts:
                        # Known account: attach the DN if it is missing.
                        if nickname not in voms_identities:
                            voms_identities[nickname] = [identity['identity'] for identity in client.list_identities(account=nickname) if identity['type'] == 'X509']
                        if dn not in voms_identities[nickname]:
                            try:
                                client.add_identity(account=nickname, identity=dn, authtype='X509', email=email, default=True)
                                print('Identity %s added' % dn)
                            except Duplicate:
                                pass
                        # The scope belongs to the nickname's account, not to
                        # whatever account an earlier iteration left behind.
                        account = nickname
                        scope = 'user.' + account
                        break

                    if nickname not in ldap_accounts:
                        # NOTE(review): unknown to LDAP; retried and reported
                        # as a WARNING, as before.
                        continue

                    if email.lower() not in [mail.lower() for mail in ldap_accounts[nickname]['mails']]:
                        # VOMS email not confirmed by LDAP: leave the account
                        # creation to an operator.
                        print('Account %s does not exist. To create it : rucio-admin account add --type USER --email %s %s' % (nickname, email, nickname))
                        break

                    try:
                        client.add_account(account=nickname, type='USER', email=email)
                        client.add_identity(account=nickname, identity=dn, authtype='X509', email=email, default=True)
                    except Exception as error:
                        # Report the failure and let the next attempt retry.
                        print(error)
                        continue
                    account = nickname
                    accounts.add(account)
                    scope = 'user.' + account
                    print('Account %s added' % account)
                    # Done: another pass would only hit a duplicate account
                    # and end in a spurious WARNING.
                    break
                except Exception as error:
                    print(error)

        if scope and scope not in scopes:
            try:
                client.add_scope(account, scope)
                scopes.add(scope)
                print('Scope %s added' % scope)
            except Duplicate:
                pass

    print('%i users extracted from VOMS' % nbusers)
    if nonicknames:
        print('Users with no nickname : %s' % str(nonicknames))

    sys.exit(status)
