#! /usr/bin/env python
import logging as log
import argparse
import pyxnat
import pandas as pd
import pydicom
import json
import os.path as op

class readable_dir(argparse.Action):
    """argparse action that only accepts an existing, readable directory.

    On success the path is stored on ``namespace.<dest>``; otherwise an
    ``argparse.ArgumentTypeError`` is raised with an explanatory message.
    """

    def __call__(self, parser, namespace, values, option_string=None):
        # Only `os.path` is imported (as `op`) at module level, so `os` must be
        # imported here; without it every call failed with NameError.
        import os

        prospective_dir = values
        if not os.path.isdir(prospective_dir):
            msg = "readable_dir:{0} is not a valid path".format(prospective_dir)
            raise argparse.ArgumentTypeError(msg)
        if os.access(prospective_dir, os.R_OK):
            setattr(namespace, self.dest, prospective_dir)
        else:
            msg = "readable_dir:{0} is not a readable dir".format(prospective_dir)
            raise argparse.ArgumentTypeError(msg)

def check_xnat_item(a, x):
    """Tell whether `a` is a project label or an experiment ID on the server.

    Args:
        a: project label or experiment ID to look up.
        x: pyxnat Interface used for the queries.

    Returns:
        0 if `a` is the label of a project, 1 if it is the ID of an experiment
        of one of those projects, -1 otherwise.
    """
    projects = [p.label() for p in x.select.projects()]
    if a in projects:
        # No need to list any project's experiments in this case.
        return 0
    # Query projects one at a time and stop at the first match instead of
    # fetching the experiments of every project up front.
    for p in projects:
        experiments = x.array.experiments(project_id=p).data
        if any(e['ID'] == a for e in experiments):
            return 1
    return -1

def get_scandate(experiment_id, x, t1_scan_label='T1_ALFA1'):
    """Return the acquisition date of the highest-numbered T1 scan of an experiment.

    Scans of `experiment_id` whose type equals `t1_scan_label` are considered
    first. If there are none, a warning is logged and every scan whose ID does
    not start with 'OT-' or 'O-' is considered instead. The first DICOM file
    of the selected scan is downloaded, and its AcquisitionDate (or the first
    8 characters of AcquisitionDateTime) is returned.

    Args:
        experiment_id: XNAT experiment ID.
        x: pyxnat Interface used for the queries and the download.
        t1_scan_label: scan type that identifies T1 scans.

    Returns:
        The date string read from the DICOM header (DICOM DA value, expected
        to be YYYYMMDD).

    Raises:
        ValueError: if no candidate scan, or no DICOM file, is found.
    """
    import tempfile

    columns = ['xsiType', 'xnat:imagescandata/type', 'xnat:imagescandata/ID']
    scans = x.array.scans(experiment_id=experiment_id, columns=columns).data
    t1_scans = {e['xnat:imagescandata/id']: e for e in scans
                if e['xnat:imagescandata/type'] == t1_scan_label}

    if not t1_scans:
        msg = 'No T1 found for %s: %s. Trying with all of them.'\
            %(experiment_id, [e['xnat:imagescandata/id'] for e in scans])
        log.warning(msg)
        t1_scans = {e['xnat:imagescandata/id']: e for e in scans
                    if not e['xnat:imagescandata/id'].startswith(('OT-', 'O-'))}

    if not t1_scans:
        raise ValueError('No usable scan found for %s' % experiment_id)

    def scan_order(scan_id):
        # Numeric IDs are compared as numbers, so '10' ranks after '9' (a plain
        # string sort picked '9'). Non-numeric IDs keep string ordering and
        # rank after numeric ones, as letter-initial IDs did with string sort.
        if scan_id.isdigit():
            return (0, int(scan_id), scan_id)
        return (1, 0, scan_id)

    max_nb = max(t1_scans, key=scan_order)
    scan = x.select.experiment(experiment_id).scan(max_nb)

    dicom_files = list(scan.resource('DICOM').files())
    if not dicom_files:
        raise ValueError('No DICOM file found for %s (scan %s)'
                         % (experiment_id, max_nb))

    # A private temporary directory avoids clobbering a shared /tmp file when
    # several runs happen at once, and is removed once the header is read.
    with tempfile.TemporaryDirectory() as tmpdir:
        fp = op.join(tmpdir, 'scan.dcm')
        dicom_files[0].get(dest=fp)
        # dcmread replaces read_file, which was removed in pydicom 3.0.
        d = pydicom.dcmread(fp)
        if hasattr(d, 'AcquisitionDate'):
            acquisition_date = d.AcquisitionDate
        else:
            acquisition_date = d.AcquisitionDateTime[:8]
    return acquisition_date

def collect_mrscandates(x, project_id=None, experiment_id=None, max_rows=None):
    """Collect MRI scan dates for one experiment or for a project's experiments.

    Exactly one of `project_id` / `experiment_id` must be given; otherwise an
    error is logged and None is returned.

    Args:
        x: pyxnat Interface.
        project_id: collect dates for the experiments of this project.
        experiment_id: collect (and print) the date of this single experiment.
        max_rows: if given, only the first `max_rows` experiments of the
            project are processed; None processes all of them.

    Returns:
        For a project, a pandas DataFrame indexed by experiment ID with columns
        label, subject_label and scandate (as datetimes). Experiments whose
        date cannot be read are logged and skipped. For a single experiment,
        the date string (it is also printed).
    """
    if project_id is None and experiment_id is None:
        log.error('project_id and experiment_id cannot be both None')
        return None
    if project_id is not None and experiment_id is not None:
        log.error('project_id and experiment_id cannot be provided both')
        return None

    if experiment_id is not None:
        res = get_scandate(experiment_id, x)
        print(res)
        return res

    data = []
    columns = ['label', 'subject_ID', 'subject_label']
    rows = x.array.experiments(project_id=project_id, columns=columns).data
    # Slicing with max_rows=None keeps every experiment.
    for e in rows[:max_rows]:
        try:
            row = [e['ID'], e['label'], e['subject_label']]
            row.append(get_scandate(e['ID'], x))
        except Exception:
            # One unreadable experiment must not abort a whole-project export;
            # skip it as the message says, keeping the traceback in the log.
            log.exception('Failed with %s. Skipping it.'%e['ID'])
            continue
        data.append(row)
    df = pd.DataFrame(data, columns=('ID', 'label', 'subject_label', 'scandate'))
    df['scandate'] = pd.to_datetime(df['scandate'])
    return df.set_index('ID').sort_index()



def download_spm12(x, project_id=None, experiment_id=None, destdir='/tmp/', max_rows=None):
    """Download the SPM12_SEGMENT resource of one or several experiments.

    Exactly one of `project_id` / `experiment_id` must be given; otherwise an
    error is logged and nothing is downloaded. Experiments without the
    resource are logged and skipped.

    Args:
        x: pyxnat Interface.
        project_id: download for the experiments of this project.
        experiment_id: download for this single experiment.
        destdir: directory passed to the resource's ``get(dest_dir=...)``.
        max_rows: if given, only the first `max_rows` experiments of the
            project are processed; None processes all of them.
    """
    if project_id is None and experiment_id is None:
        log.error('project_id and experiment_id cannot be both None')
        return
    if project_id is not None and experiment_id is not None:
        log.error('project_id and experiment_id cannot be provided both')
        return

    if experiment_id is not None:
        experiments = [experiment_id]
    else:
        rows = x.array.experiments(project_id=project_id, columns=['label']).data
        # Slicing with max_rows=None keeps every experiment.
        experiments = [e['ID'] for e in rows[:max_rows]]

    for e in experiments:
        r = x.select.experiment(e).resource('SPM12_SEGMENT')
        if not r.exists():
            log.error('%s has no SPM12_SEGMENT resource'%e)
            continue
        r.get(dest_dir=destdir)


def spm12_volumes(x, project_id=None, experiment_id=None, max_rows=None):
    """Compute SPM12 tissue volumes from the SPM12_SEGMENT resources.

    For each experiment, the c1/c2/c3 tissue maps are downloaded. Each volume
    is the sum of the map's voxel values times the volume of one voxel.

    Exactly one of `project_id` / `experiment_id` must be given; otherwise an
    error is logged and None is returned.

    Args:
        x: pyxnat Interface.
        project_id: process the experiments of this project.
        experiment_id: process this single experiment.
        max_rows: if given, only the first `max_rows` experiments of the
            project are processed; None processes all of them.

    Returns:
        pandas DataFrame indexed by experiment ID with columns c1, c2, c3, or
        None when the arguments are invalid. Experiments lacking the
        SPM12_SEGMENT resource, or one of the maps, are logged and skipped.
    """
    import tempfile
    import nibabel as nib
    import numpy as np

    if project_id is None and experiment_id is None:
        log.error('project_id and experiment_id cannot be both None')
        return None
    if project_id is not None and experiment_id is not None:
        log.error('project_id and experiment_id cannot be provided both')
        return None

    if experiment_id is not None:
        experiments = [experiment_id]
    else:
        rows = x.array.experiments(project_id=project_id, columns=['label']).data
        # Slicing with max_rows=None keeps every experiment.
        experiments = [e['ID'] for e in rows[:max_rows]]

    classes = ['c1', 'c2', 'c3']
    table = []
    for e in experiments:
        r = x.select.experiment(e).resource('SPM12_SEGMENT')
        if not r.exists():
            log.error('%s has no SPM12_SEGMENT resource'%e)
            continue
        files = list(r.files())  # list the resource once, not once per class
        maps = {kls: next((f for f in files if f.id().startswith(kls)), None)
                for kls in classes}
        missing = [kls for kls in classes if maps[kls] is None]
        if missing:
            log.error('%s has no %s map in SPM12_SEGMENT'%(e, '/'.join(missing)))
            continue
        vols = [e]
        # Private temporary directory: no clash with concurrent runs, and it is
        # removed after the volumes have been computed.
        with tempfile.TemporaryDirectory() as tmpdir:
            for kls in classes:
                # Keep the original file name so nibabel infers the format and
                # compression from the real extension.
                fp = op.join(tmpdir, op.basename(maps[kls].id()))
                maps[kls].get(fp)
                img = nib.load(fp)
                # pixdim[0] is the qfac orientation flag (+1/-1, sometimes 0),
                # not a voxel size; multiplying it in gave negative or zero
                # volumes. Use only the three spatial zooms.
                voxel_volume = np.prod(img.header.get_zooms()[:3])
                vols.append(np.sum(img.dataobj) * voxel_volume)
        table.append(vols)
    df = pd.DataFrame(table, columns=['ID'] + classes).set_index('ID').sort_index()
    return df



def parse_args(command, args, x, destdir='/tmp/', test=False):
    """Dispatch a bx command with its positional arguments.

    Args:
        command: one of 'nifti', 'mrscandates', 'freesurfer', 'spm12'.
        args: positional arguments following the command. The last one is
            expected to be an XNAT project label or experiment ID.
        x: pyxnat Interface used for all XNAT queries.
        destdir: output directory for files written by the command; None
            falls back to '/tmp'.
        test: when True, the mrscandates command only resolves the item and
            does not collect or save anything.

    Raises:
        Exception: if `command` is not one of the known commands.
    """
    commands = ['nifti', 'mrscandates', 'freesurfer', 'spm12']
    if command not in commands:
        msg = '%s not found (valid commands: %s)'%(command, commands)
        log.error(msg)
        raise Exception(msg)

    # The CLI passes the config's optional "destination" entry, which may be
    # None. Normalize it once so every command (not only mrscandates) gets a
    # usable directory.
    if destdir is None:
        destdir = '/tmp'

    def _item_type(item):
        # 0 = project, 1 = experiment, -1 = unknown; the last case used to be
        # ignored silently, so report it.
        t = check_xnat_item(item, x)
        if t == -1:
            log.error('%s is neither a project nor an experiment'%item)
        return t

    if command == 'mrscandates':
        if len(args) == 0:
            msg = 'display help message for %s'%command
            print(msg)
        elif len(args) == 1:
            a = args[0]  # should be a project or an experiment_id
            t = _item_type(a)
            if t == 0:
                if not test:
                    df = collect_mrscandates(x, project_id=a)
                    from datetime import datetime
                    dt = datetime.today().strftime('%Y%m%d')
                    # .xlsx: pandas >= 2.0 can no longer write legacy .xls files
                    fn = 'bx_%s_%s.xlsx'%(a, dt)
                    fp = op.join(destdir, fn)
                    log.info('Saving it in %s'%fp)
                    df.to_excel(fp)
            elif t == 1:
                if not test:
                    collect_mrscandates(x, experiment_id=a)

    elif command == 'freesurfer':
        if len(args) == 0:
            msg = 'display help message for %s'%command
            print(msg)
        elif len(args) == 1:
            # error: missing arguments (at least a project)
            msg = 'missing argument(s)'
            print(msg)
        elif len(args) == 2:
            a = args[1]  # should be a project or an experiment_id
            print(a)
            _item_type(a)
            # TODO: subcommands 'thickness', 'aparc', 'aseg' and
            # 'hippoSfVolumes' (args[0]) are not implemented yet.

    elif command == 'spm12':
        if len(args) == 0:
            msg = 'display help message for %s'%command
            print(msg)
        elif len(args) == 1:
            # error: missing arguments (at least a project)
            msg = 'missing argument(s)'
            print(msg)
        elif len(args) == 2:
            subcommand = args[0]
            a = args[1]  # should be a project or an experiment_id
            print(a)
            t = _item_type(a)
            if subcommand == 'maps':
                if t == 0:
                    download_spm12(x, project_id=a, destdir=destdir)
                elif t == 1:
                    download_spm12(x, experiment_id=a, destdir=destdir)
            elif subcommand == 'volumes':
                df = None
                if t == 0:
                    df = spm12_volumes(x, project_id=a)
                elif t == 1:
                    df = spm12_volumes(x, experiment_id=a)
                # Show the result instead of computing it and discarding it.
                if df is not None:
                    print(df)

    elif command == 'nifti':
        if len(args) == 0:
            msg = 'display help message for %s'%command
            print(msg)
        elif len(args) == 1:
            # error: missing arguments (at least a project)
            msg = 'missing argument(s)'
            print(msg)
        elif len(args) == 2:
            a = args[1]  # should be a project or an experiment_id
            print(a)
            _item_type(a)
            # TODO: download of the requested NIfTI type (args[0]) is not
            # implemented yet.



def create_parser():
    """Build the command-line parser for bx.

    Positional arguments are the command and its (optional) arguments.
    ``--config`` is opened for reading by argparse (``FileType('r')``), and
    ``--verbose``/``-V`` is a boolean flag.
    """
    parser = argparse.ArgumentParser(description='bx')
    parser.add_argument('command', help='BX command')
    parser.add_argument('args', help='BX command arguments', nargs="*")
    # argparse passes a string default through `type`, so FileType would try
    # to open a literal '~/.xnat.cfg'. Expand the home directory first.
    parser.add_argument('--config', help='XNAT configuration file',
        required=False, type=argparse.FileType('r'),
        default=op.expanduser('~/.xnat.cfg'))
    parser.add_argument('--verbose', '-V', action='store_true', default=False,
        help='Display verbosal information (optional)', required=False)
    return parser

if __name__ == "__main__":
    parser = create_parser()
    args = parser.parse_args()
    log.basicConfig(level=log.INFO if args.verbose else log.ERROR)
    # argparse (FileType) has already opened the configuration file. Read the
    # JSON through that handle and close it, instead of opening the file a
    # second time and leaving both handles open.
    with args.config as config_file:
        dd = json.load(config_file).get('destination', None)
    x = pyxnat.Interface(config=args.config.name)
    parse_args(args.command, args.args, x, dd)
