#!/usr/bin/env python
# Copyright European Organization for Nuclear Research (CERN) 2015
#
# Licensed under the Apache License, Version 2.0 (the "License");
# You may not use this file except in compliance with the License.
# You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
#
# Authors:
# - Tomas Javurek, Cedric Serfon, 2015

import os
import sys
import smtplib
import time
import requests
import calendar
from email.mime.text import MIMEText
from email.MIMEMultipart import MIMEMultipart
from email.MIMEBase import MIMEBase
from email import Encoders
# from rucio.common.config import config_get
from datetime import datetime
from datetime import date

from rucio.db.sqla.session import get_session
# from rucio.core import monitor

# Exit statuses
OK = 0
WARNING = 1
CRITICAL = 2
UNKNOWN = 3

# Switches: which kinds of report are produced, whether the script runs in
# test mode, and on which week days it is allowed to act.
users = groups = gdp = True
testmode = False
working_days = ['Wednesday']

# Locations of the per-day log file, the temporary input and the reports.
timestamp = datetime.today().strftime('%Y-%m-%d')
log_dir = '/var/log/rucio/lost_files/logs/'
log_path = '%s%s.log' % (log_dir, timestamp)
tmpdata_dir = '/var/log/rucio/lost_files/tmp/'
tmpdata_path = '%s%s' % (tmpdata_dir, 'rse-lost-files.txt')
reports_dir = '/var/log/rucio/lost_files/reports/'


# protection against running this script every day
def run_judger(working_days):
    """Tell whether today is one of the days on which the script may work.

    The decision is also written to the daily log file (``log_path``).

    :param working_days: collection of week-day names as found in
                         ``calendar.day_name`` (e.g. ``['Wednesday']``).
    :returns: True if today's day name is in ``working_days``, else False.
    """
    today_day = calendar.day_name[date.today().weekday()]
    is_working_day = today_day in working_days

    # 'with' closes (and flushes) the log on both branches
    with open(log_path, 'a') as flog:
        if is_working_day:
            flog.write('Today is %s.\n' % today_day)
            flog.write('I might try to work today.\n')
        else:
            flog.write('Today is %s. This is NOT my working day! I am working only on:\n' % today_day)
            flog.write(str(working_days) + '\n')

    return is_working_day


def merge_dicts(d1, d2):
    """Merge two ``{key: list}`` mappings into a new dictionary.

    Keys present in only one input keep that input's list (copied, so the
    result never shares list objects with the arguments). For keys present in
    both, the lists are concatenated and duplicates are dropped while the
    order of first appearance is kept, which makes the result deterministic.

    :param d1: first mapping; ``None`` is treated as an empty mapping.
    :param d2: second mapping; ``None`` is treated as an empty mapping.
    :returns: a new dict holding the merged lists.
    """
    merged = dict((key, list(values)) for key, values in (d1 or {}).items())

    for key, values in (d2 or {}).items():
        if key not in merged:
            merged[key] = list(values)
            continue
        # order-preserving de-duplication of both lists
        seen = set()
        combined = []
        for value in merged[key] + list(values):
            if value not in seen:
                seen.add(value)
                combined.append(value)
        merged[key] = combined

    return merged


# extracting mails of users from Rucio DB
def find_mails_users(account, session):
    """Return the e-mail addresses to notify for a given account.

    The addresses are read from the identities mapped to ``account``. For the
    ``ddmadmin`` and ``root`` accounts the DDM support list is used instead.
    ``tomas.javurek@cern.ch`` is always added to the result.

    On a database error the exception is written to the log file and the
    script exits with the CRITICAL status.

    :param account: account name.
    :param session: database session providing ``execute(statement, params)``.
    :returns: list of e-mail addresses (strings).
    """
    mails = []
    try:
        # the account name is handed over as a bind parameter instead of
        # being spliced into the SQL text
        query = ('select distinct a.email from atlas_rucio.identities a, atlas_rucio.account_map b '
                 'where a.identity=b.identity and b.account=:account')
        result = session.execute(query, {'account': account})
        for row in result:
            for col in row:
                # a NULL e-mail would otherwise become the address 'None'
                if col is not None:
                    mails.append(str(col))
    except Exception as e:
        with open(log_path, 'a') as flog:
            flog.write('find_mails_users\n')
            flog.write(str(e) + '\n')
        sys.exit(CRITICAL)

    if account == 'ddmadmin' or account == 'root':
        mails = ['atlas-adc-ddm-support@cern.ch']
    if 'tomas.javurek@cern.ch' not in mails:
        mails.append('tomas.javurek@cern.ch')

    return mails


# hardcoded, TODO
def find_mails_gdp():
    """Return a fresh list with the (hardcoded) recipients of the GDP report."""
    return list(('atlas-adc-dpa@cern.ch', 'tomas.javurek@cern.ch'))


# extracting mails of physgroups from Rucio DB
def find_mails_groups(rse, session):
    """Return the e-mail addresses to notify for a given RSE.

    The addresses belong to the identities of the accounts named by the
    ``physgroup`` attribute of the RSE. ``tomas.javurek@cern.ch`` is always
    added to the result.

    On a database error the exception is written to the log file and the
    script exits with the CRITICAL status.

    :param rse: RSE name.
    :param session: database session providing ``execute(statement, params)``.
    :returns: list of e-mail addresses (strings).
    """
    mails = []
    try:
        # the RSE name is handed over as a bind parameter instead of being
        # spliced into the SQL text
        query = ("select distinct email from atlas_rucio.identities where identity in "
                 "(select identity from atlas_rucio.account_map where account in "
                 "(select value from atlas_rucio.rse_attr_map where key = 'physgroup' "
                 "and rse_id = atlas_rucio.rse2id(:rse)))")
        result = session.execute(query, {'rse': rse})
        for row in result:
            for col in row:
                # a NULL e-mail would otherwise become the address 'None'
                if col is not None:
                    mails.append(str(col))
    except Exception as e:
        with open(log_path, 'a') as flog:
            flog.write('find_mails_groups\n')
            flog.write(str(e) + '\n')
        sys.exit(CRITICAL)

    if 'tomas.javurek@cern.ch' not in mails:
        mails.append('tomas.javurek@cern.ch')

    return mails


# find account for rule on given did
def get_rule_owners(scope, name, session):
    """Return the accounts that own a rule on the given DID.

    This is best effort: on a database error the exception is written to the
    log file and whatever was collected so far (possibly nothing) is returned.

    :param scope: scope of the DID.
    :param name: name of the DID.
    :param session: database session providing ``execute(statement, params)``.
    :returns: list of account names (strings).
    """
    rule_owners = []
    try:
        # scope and name are handed over as bind parameters instead of being
        # spliced into the SQL text
        query = 'select distinct(account) from atlas_rucio.rules where scope=:scope and name=:name'
        result = session.execute(query, {'scope': scope, 'name': name})
        for row in result:
            for col in row:
                rule_owners.append(str(col))
    except Exception as e:
        with open(log_path, 'a') as flog:
            flog.write('get_rule_owners:')
            flog.write(str(e) + '\n')

    if testmode:
        print('DEBUG:  %s %s' % (scope, name))
        print('DEBUG: rule owners %s' % (rule_owners,))

    return rule_owners


# collects reports for given email
def report_collector(rse, account, session):
    """Map every recipient of one report to the path of that report.

    Exactly one of ``rse`` / ``account`` is expected to be non-empty:

    * a non-empty ``rse`` selects the group report of that RSE,
    * the account ``'gdp'`` selects the GDP summary report,
    * any other non-empty ``account`` selects the report of that account.

    Each kind is only handled when its module-level switch (``groups``,
    ``gdp``, ``users``) is on.

    :param rse: RSE name or ``''``.
    :param account: account name or ``''``.
    :param session: database session used to look up the e-mail addresses.
    :returns: dict ``{e-mail: [report_path]}``; an empty dict when there is
              nobody to notify, so the result can always be merged.
    """
    mail_list = []
    report_path = ''
    if groups and rse != '':
        mail_list = find_mails_groups(rse, session)
        report_path = reports_dir + 'report_' + rse
    if users and account != '' and account != 'gdp':
        mail_list = find_mails_users(account, session)
        report_path = reports_dir + 'report_' + account
    if gdp and account == 'gdp':
        mail_list = find_mails_gdp()
        report_path = reports_dir + 'report_' + account

    mails_reports = {}
    if not mail_list or not report_path:
        return mails_reports

    for mail in mail_list:
        mails_reports.setdefault(mail, []).append(report_path)
    return mails_reports


# mailing agent
def send_report(mail, report_paths):
    """Mail the given report files to one recipient.

    Every report is attached to the message. When the reports read so far
    hold at most 20 lines, those lines are also put into the message body.
    In test mode the message is sent to ``tomas.javurek@cern.ch`` instead of
    ``mail``. A successful delivery is recorded in the log file.

    :param mail: e-mail address of the recipient.
    :param report_paths: list of paths of the report files to send.
    """
    if testmode:
        print('DEBUG: mailing agent is accessed.')
        print('DEBUG:  %s %s' % (mail, report_paths))

    # defining mailing list
    me = 'atlas-adc-ddm-support@cern.ch'
    if testmode:
        print('DEBUG: notification would be send to: %s' % mail)
        recepients = ['tomas.javurek@cern.ch']
    else:
        recepients = [mail]

    msg = MIMEMultipart()
    msg['Subject'] = 'DDMops: completely lost files that may affect you - last 7 days'
    msg['From'] = me
    msg['To'] = ", ".join(recepients)

    # email body
    msg.attach(MIMEText('Please check the attached list of files that have been lost and can not be recovered. These files may affect you. In case of questions contact DDMops.' + "\n\n"))

    # if report is short, lost files are reported in email body as well;
    # reading stops as soon as the limit of 20 lines is exceeded
    lines = []
    for report_path in report_paths:
        with open(report_path, 'r') as fr:
            lines.extend(fr.readlines())
        if len(lines) > 20:
            break
    if lines and len(lines) < 21:
        # one text part for the whole listing instead of one part per line
        msg.attach(MIMEText(''.join(lines)))

    # attachments
    for report_path in report_paths:
        part = MIMEBase('application', "octet-stream")
        with open(report_path, 'rb') as fr:
            part.set_payload(fr.read())
        Encoders.encode_base64(part)
        # advertise only the file name, not the directory layout of the server
        part.add_header('Content-Disposition', 'attachment; filename="%s"' % os.path.basename(report_path))
        msg.attach(part)

    # sending email, s=server; the connection is closed even if sending fails
    s = smtplib.SMTP('localhost')
    try:
        s.sendmail(me, recepients, msg.as_string())
    finally:
        s.quit()

    # logged only after the message was really handed over to the server
    with open(log_path, 'a') as flog:
        flog.write('Reports were sent to:\n')
        flog.write(mail)
        flog.write(str(report_paths))
        flog.write('\n\n')


# create report for gdp
# (the report is mailed later, from main)
def report_gdp():
    """Create the GDP summary report as a copy of the downloaded input file.

    Exits with the CRITICAL status when the input file is missing or the
    report can not be written.

    :returns: ``['gdp']``, the list of accounts a report was created for.
    """
    # local import: shutil is not among the imports at the top of the file
    import shutil

    # INIT
    if testmode:
        print('DEBUG: making report for GDP')
        print('||||||||||||||||||||||||||||')
    if not os.path.isfile(tmpdata_path):
        print('ERROR: lost files not downloaded')
        sys.exit(CRITICAL)

    # copy without going through a shell; os.path.join avoids the doubled
    # slash, so the path equals the one looked up for the 'gdp' account
    try:
        shutil.copy(tmpdata_path, os.path.join(reports_dir, 'report_gdp'))
    except (IOError, OSError) as e:
        print('ERROR: report for GDP can not be created: %s' % e)
        sys.exit(CRITICAL)

    return ['gdp']


# make one report per account owning a rule on an affected dataset
# (the reports are mailed later, from main)
def report_by_account(session):
    """Create one report file per account owning a rule on an affected dataset.

    The input file (``tmpdata_path``) holds one lost file per line, fields
    separated by single spaces: file scope, file name, dataset scope, dataset
    name, RSE, account and date. For every line the owners of rules on the
    dataset are looked up, and the lost file is added to the report of each
    owner (``reports_dir + 'report_' + account``).

    Exits with the CRITICAL status when the input file is missing.

    :param session: database session used to look up the rule owners.
    :returns: list of the accounts a report was written for.
    """
    # INIT
    if testmode:
        print('DEBUG: making report by account')
        print('|||||||||||||||||||||||||||||||')
    if not os.path.isfile(tmpdata_path):
        print('ERROR: lost files not downloaded')
        sys.exit(CRITICAL)

    data_per_account = {}
    owners_cache = {}  # (scope, dataset) -> rule owners, one DB query per dataset

    # loop over lost files from get_bad_files()
    with open(tmpdata_path, 'r') as fi:
        for line in fi:
            # drop the line terminator so that it does not end up in the date
            fields = line.rstrip('\r\n').split(' ')
            if len(fields) < 7:
                # blank or truncated line: nothing usable to report
                continue
            scope, data_name, dataset = fields[0], fields[1], fields[3]
            rse_name, account, updated_at = fields[4], fields[5], fields[6]

            # find owners of rule, they are contacted as well
            dataset_key = (scope, dataset)
            if dataset_key not in owners_cache:
                if testmode:
                    print('DEBUG: get rule owners')
                owners = []
                for own in get_rule_owners(scope, dataset, session):
                    if own not in owners:
                        owners.append(own)
                owners_cache[dataset_key] = owners
            accounts = owners_cache[dataset_key]

            if testmode:
                print('INFO: %s %s %s %s' % (rse_name, account, dataset, data_name))
                if not accounts:
                    print('DEBUG: there is no account to be notified.')
                else:
                    print('DEBUG: rule owners found: %s' % (accounts,))
                print('=======================')

            for acc in accounts:
                data_per_account.setdefault(acc, []).append(
                    {'scope': scope, 'name': data_name, 'dataset': dataset, 'rse': rse_name, 'time': updated_at})

    if testmode:
        print('DEBUG: creating reports and sending.')

    # create report per account; 'with' makes sure each report is complete
    # on disk before anybody reads it
    accs = []
    for account, bad_files in data_per_account.items():
        with open(reports_dir + 'report_' + account, 'w') as fo:
            for bad_file in bad_files:
                fo.write("%s %s %s %s\n" % (bad_file['scope'], bad_file['dataset'], bad_file['name'], bad_file['time']))
        if testmode:
            print('DEBUG: going to send the report.')
        accs.append(account)

    if testmode:
        if not data_per_account:
            print('DEBUG: nothing to send.')
        print('DEBUG: report by accnounts done.')

    return accs


# make one report per rse
# (the reports are mailed later, from main)
def report_by_rses(session):
    """Create one report file per RSE that lost files.

    The input file (``tmpdata_path``) holds one lost file per line, fields
    separated by single spaces: file scope, file name, dataset scope, dataset
    name, RSE, account and date. The lost files are grouped by RSE and
    written to ``reports_dir + 'report_' + rse``.

    Exits with the CRITICAL status when the input file is missing.

    :param session: unused; kept so all report makers share one signature.
    :returns: list of the RSEs a report was written for.
    """
    # INIT
    if not os.path.isfile(tmpdata_path):
        print('ERROR: lost files not downloaded')
        sys.exit(CRITICAL)

    data_per_rse = {}

    # loop over lost files from get_bad_files()
    with open(tmpdata_path, 'r') as fi:
        for line in fi:
            # drop the line terminator so that it does not end up in the date
            fields = line.rstrip('\r\n').split(' ')
            if len(fields) < 7:
                # blank or truncated line: nothing usable to report
                continue
            scope, data_name, dataset = fields[0], fields[1], fields[3]
            rse_name, account, updated_at = fields[4], fields[5], fields[6]

            data_per_rse.setdefault(rse_name, []).append(
                {'scope': scope, 'name': data_name, 'dataset': dataset, 'account': account, 'time': updated_at, 'rse': rse_name})

    # create report per rse; 'with' makes sure each report is complete on
    # disk before anybody reads it
    rses = []
    for rse, bad_files in data_per_rse.items():
        with open(reports_dir + 'report_' + rse, 'w') as fo:
            for bad_file in bad_files:
                fo.write("%s %s %s %s\n" % (bad_file['scope'], bad_file['dataset'], bad_file['name'], bad_file['time']))
        rses.append(rse)

    return rses


# the input
def get_bad_files(session):
    """Dump the files declared lost during the last 7 days into ``tmpdata_path``.

    The bad replicas in state 'L' are read from the database together with
    the datasets that contain them. Datasets whose name starts with
    ``panda.`` or contains ``_sub`` are skipped, and every file (scope:name)
    is written only once. Each output line holds the columns of one row, each
    followed by a single space.

    :param session: database session providing ``execute(statement)``.
    :returns: True on success; False if anything failed (the exception is
              written to the log file).
    """
    query = ''' select a.scope, a.name, b.scope, b.name, atlas_rucio.id2rse(a.rse_id), a.account, a.updated_at from atlas_rucio.bad_replicas a, atlas_rucio.contents_history b
 where a.state='L' and a.updated_at>sysdate-7 and b.did_type='D' and a.scope=b.child_scope and a.name=b.child_name '''

    try:
        # 'with' guarantees the dump is flushed and closed before it is read
        with open(tmpdata_path, 'w') as f:
            result = session.execute(query)
            processed_files = set()  # set: constant-time duplicate check
            for row in result:
                if row[3].startswith('panda.'):
                    continue
                if '_sub' in row[3]:
                    continue
                f_did = row[0] + ':' + row[1]
                if f_did in processed_files:
                    print('double counted: %s' % f_did)
                    continue
                processed_files.add(f_did)
                for col in row:
                    f.write('%s ' % col)
                f.write('\n')
    except Exception as e:
        with open(log_path, 'a') as flog:
            flog.write('get_bad_files\n')
            flog.write(str(e) + "\n")
        return False

    return True


def get_bad_files_from_dump(session):
    """Download the dump of lost files and store it in ``tmpdata_path``.

    Every usable line of the dump (tab separated, at least 7 fields, the 7th
    being an epoch timestamp) is rewritten as a space separated line whose
    last field is the date in ``YYYY-MM-DD`` format. Datasets whose name
    starts with ``panda.`` or contains ``_sub`` are skipped, and every file
    (scope:name) is written only once. Problems are recorded in the log file.

    :param session: unused; kept so both input readers share one signature.
    :returns: True when the dump was downloaded and converted; False when it
              could not be fetched, so the caller can fall back to another
              source.
    """
    url = 'https://rucio-hadoop.cern.ch/lost_files'

    with open(log_path, 'a') as flog:
        try:
            # NOTE(review): verify=False switches off the TLS certificate
            # check - confirm whether the proper CA bundle can be used here.
            # The timeout keeps an unresponsive server from blocking forever.
            dump7 = requests.get(url, verify=False, timeout=300)
        except requests.RequestException as e:
            flog.write('ERROR: dump of bad files not reachable on hadoop: %s\n' % e)
            return False
        # anything but 200 (not only 404) means there is no dump to parse
        if dump7.status_code != 200:
            flog.write('ERROR: dump of bad files not reachable on hadoop (HTTP status %i)\n' % dump7.status_code)
            return False

        line_counter = 0
        processed_files = set()  # set: constant-time duplicate check
        with open(tmpdata_path, 'w') as f:
            for l in dump7.text.split('\n'):
                line_counter += 1
                data = l.split('\t')
                if len(data) < 7:
                    flog.write('WARNING: line %i in dump does not contain full info \n' % line_counter)
                    continue
                if data[3].startswith('panda.'):
                    continue
                if '_sub' in data[3]:
                    continue
                try:
                    updated_at = time.strftime('%Y-%m-%d', time.localtime(float(data[6])))
                except ValueError:
                    flog.write('WARNING: line %i in dump does not contain a valid timestamp \n' % line_counter)
                    continue
                f_did = data[0] + ':' + data[1]
                if f_did in processed_files:
                    print('double counted: %s' % f_did)
                    continue
                processed_files.add(f_did)
                f.write('%s %s %s %s %s %s %s\n' % (data[0], data[1], data[2], data[3], data[4], data[5], updated_at))
        flog.write('INFO: dump contains %i lines\n' % line_counter)

    return True


def main():
    """Run the whole workflow: fetch the lost files, build reports, mail them.

    Steps:

    1. leave with OK when today is not a working day,
    2. fetch the lost files from the dump, falling back to the database,
    3. build the reports per RSE, per account and for GDP (as switched on),
    4. send every recipient one mail with all of its reports; in test mode
       only ``tomas.javurek@cern.ch`` gets a mail and the full recipient
       list is written to the log file,
    5. remove the temporary input file.

    The script always ends through ``sys.exit``: OK on success, CRITICAL when
    a working directory is missing or no input could be obtained.
    """
    # the log directory has to be checked first: run_judger() writes into it
    if not os.path.exists(log_dir):
        sys.exit(CRITICAL)

    if not run_judger(working_days):
        sys.exit(OK)

    session = get_session()

    # check the rest of the folder hierarchy
    for directory in (tmpdata_dir, reports_dir):
        if not os.path.exists(directory):
            sys.exit(CRITICAL)

    # get input
    if not get_bad_files_from_dump(session):
        print('WARNING: the dump is not accessible')
        if not get_bad_files(session):
            sys.exit(CRITICAL)

    # make the reports; every target is an (rse, account) pair
    targets = []
    if groups:
        targets.extend((rse, '') for rse in report_by_rses(session))
    if users:
        targets.extend(('', acc) for acc in report_by_account(session))
    if gdp:
        if testmode:
            print('DEBUG: summary report to gdp')
        targets.extend(('', acc) for acc in report_gdp())

    # collect the reports per recipient: {e-mail: [report paths]}
    mails = {}
    for rse, acc in targets:
        reps = report_collector(rse, acc, session)
        # there may be nobody to notify for a report
        if reps:
            mails = merge_dicts(mails, reps)

    # send the reports
    if testmode:
        own_reports = mails.get('tomas.javurek@cern.ch')
        if own_reports:
            send_report('tomas.javurek@cern.ch', own_reports)
        with open(log_path, 'a') as flog:
            for m in mails:
                flog.write(m)
                flog.write(str(mails[m]))
                flog.write('\n')
    else:
        for m in mails:
            send_report(m, mails[m])

    # clean tmp
    try:
        os.remove(tmpdata_path)
    except OSError as e:
        print('WARNING: temporary file was not removed: %s' % e)

    sys.exit(OK)


# Run the workflow only when the file is executed as a script.
if __name__ == '__main__':

    main()
