#!/usr/bin/env python3

import argparse
import os
import sys

from DIRAC.Interfaces.API.Dirac import Dirac
from DIRAC.Interfaces.API.Job   import Job
from DIRAC                      import initialize

from logzero import logger as log
from tqdm    import trange

#---------------------------------------
class data:
    '''
    Namespace holding the command-line settings as module-level shared state.

    The attributes start as None and are filled in by get_args():

    njobs : number of grid jobs to submit
    nfits : number of fits per job
    mode  : DIRAC submission mode, either 'local' or 'wms'
    '''
    njobs = nfits = mode = None
#---------------------------------------
def get_banned_sites():
    '''
    Return the list of DIRAC site names that jobs must not be sent to.

    A new list is built on every call, so callers may modify it freely.
    '''
    sites = '''
    LCG.ECDF.uk
    LCG.Manchester.uk
    LCG.NIPNE-07.ro
    LCG.Krakow.pl
    LCG.PNPI.ru
    LCG.MIT.us
    LCG.UKI-LT2-IC-HEP.uk
    LCG.USC.es
    LCG.JINR.ru
    '''.split()

    return sites
#---------------------------------------
def get_job(jobid, cpu_time=36000):
    '''
    Build (but do not submit) the DIRAC job that runs the toy fits for one seeds file.

    Parameters
    ----------
    jobid : int
        Index of the job. Selects the seeds file ``<cwd>/seeds/<jobid>.sd`` and
        names the job ``job_<jobid>`` zero-padded to three digits.
    cpu_time : int, optional
        Value passed to ``Job.setCPUTime``; defaults to 36000 as before.

    Returns
    -------
    Job
        Job configured with banned sites, executable, input and output sandboxes.

    Raises
    ------
    FileNotFoundError
        If the seeds file for ``jobid`` does not exist.
    '''
    seeds_file = os.path.join(os.getcwd(), 'seeds', f'{jobid}.sd')
    if not os.path.isfile(seeds_file):
        log.error(f'Cannot find: {seeds_file}')
        raise FileNotFoundError(seeds_file)

    j = Job()
    j.setCPUTime(cpu_time)
    j.setBannedSites(get_banned_sites())

    #For tests
    #j.setExecutable('/usr/bin/touch for_tests.out')
    #j.setOutputSandbox(['for_tests.out'])

    # The runner scripts are looked up under the Python environment prefix;
    # presumably they are installed there together with the rk_extractor package.
    shell_path = os.path.join(sys.prefix, 'rk_extractor', 'run_toys.sh')
    pytho_path = os.path.join(sys.prefix, 'rk_extractor', 'run_toys.py')

    #For real jobs
    j.setExecutable(shell_path)
    j.setInputSandbox([seeds_file, pytho_path])
    j.setOutputSandbox(['output.json', 'fits.tar'])

    j.setName(f'job_{jobid:03}')

    return j
#---------------------------------------
def make_seeds():
    '''
    Write one seeds file per job, ``seeds/<ijob>.sd``, with one integer seed per line.

    Reads ``data.njobs`` and ``data.nfits`` (set by get_args). Job ``ijob`` receives
    ``data.nfits`` consecutive seeds starting at ``ijob * stride``, with
    ``stride = max(1000, data.nfits)``:

    - For ``nfits <= 1000`` the layout is unchanged (blocks starting at multiples of 1000).
    - For larger ``nfits``, the stride grows so that no seed is assigned to two jobs.
    '''
    log.info('Making seeds')
    os.makedirs('seeds', exist_ok=True)

    # With a fixed stride of 1000, nfits > 1000 made consecutive jobs share seeds;
    # presumably each seed defines one toy, so those toys would be fitted twice.
    stride = max(1000, data.nfits)
    for ijob in range(data.njobs):
        log.debug(f'Writting seeds/{ijob}.sd')
        first_seed = stride * ijob
        with open(f'seeds/{ijob}.sd', 'w') as ofile:
            ofile.writelines(f'{iseed}\n' for iseed in range(first_seed, first_seed + data.nfits))
#---------------------------------------
def get_args():
    '''
    Parse the command line and copy the options onto the ``data`` class.

    All options are required:

    - ``-j/--njobs`` (int) and ``-f/--nfits`` (int).
    - ``-m/--mode``, one of 'local' or 'wms'.

    Nothing is returned; results are stored as ``data.njobs``, ``data.nfits``
    and ``data.mode``.
    '''
    parser = argparse.ArgumentParser(description='Used to send toy fit jobs to the grid')

    l_option = [
            (('-j', '--njobs'), {'type' : int, 'help' : 'Number of grid jobs'}),
            (('-f', '--nfits'), {'type' : int, 'help' : 'Number of fits per job'}),
            (('-m', '--mode' ), {'type' : str, 'help' : 'Run locally or in the grid', 'choices' : ['local', 'wms']}),
            ]

    for t_flag, d_kwarg in l_option:
        parser.add_argument(*t_flag, required=True, **d_kwarg)

    args = parser.parse_args()

    for name in ['njobs', 'nfits', 'mode']:
        setattr(data, name, getattr(args, name))
#---------------------------------------
def main():
    '''
    Parse the arguments, write the seeds files and submit one DIRAC job per seeds file.

    The IDs of successfully submitted jobs are written to ``jobids.out``, one per
    line. A failed submission is logged and skipped instead of aborting the loop.
    '''
    # Parse arguments before touching DIRAC, so --help and invalid options fail fast
    get_args()

    initialize()
    dirac = Dirac()

    make_seeds()

    nfailed = 0
    # Write each ID as soon as it is known, so that an exception part-way through
    # does not lose the IDs of jobs that were already submitted.
    with open('jobids.out', 'w') as ofile:
        for ijob in trange(data.njobs):
            job    = get_job(ijob)
            d_info = dirac.submitJob(job, mode=data.mode)
            # DIRAC calls return S_OK/S_ERROR dictionaries; an error result carries
            # 'Message' and no job ID, which previously crashed with KeyError.
            if not d_info.get('OK', False):
                nfailed += 1
                log.error(f'Submission of job {ijob} failed: {d_info.get("Message")}')
                continue

            # NOTE(review): this script relied on 'JobID'; 'Value' is the generic
            # S_OK payload and is used only as a fallback - confirm for 'local' mode.
            dirac_id = d_info.get('JobID', d_info.get('Value'))
            ofile.write(f'{dirac_id}\n')
            ofile.flush()

    if nfailed > 0:
        log.warning(f'{nfailed}/{data.njobs} job submissions failed')
#---------------------------------------
if __name__ == '__main__':
    # Only submit jobs when executed as a script, not when imported
    main()

