#!/usr/bin/env python

# Copyright 2016 NeuroData (http://neurodata.io)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

# ndmg_bids.py
# Created by Greg Kiar on 2016-07-25.
# Email: gkiar@jhu.edu

from argparse import ArgumentParser
from subprocess import Popen, PIPE
from os.path import expanduser
from ndmg.scripts.ndmg_setup import get_files
import ndmg.utils as mgu
from ndmg.scripts.ndmg_pipeline import ndmg_pipeline
import os.path as op

import os
import sys


# Tags and file extensions used to locate inputs inside a BIDS-style tree
subj_tag, sesh_tag = 'sub-', 'ses-'
anat_tag, dwi_tag = 'anat', 'dwi'
bval_ext, bvec_ext = '.bval', '.bvec'
imgs_ext = ('*.nii', '*.nii.gz')

# Atlas resources are looked up under <home directory>/atlases
atlas_dir = op.join(expanduser("~"), 'atlases')
atlas = op.join(atlas_dir, 'atlas/MNI152_T1_1mm.nii.gz')
atlas_mask = op.join(atlas_dir, 'atlas/MNI152_T1_1mm_brain_mask.nii.gz')

# Names of the label volumes; each one lives at labels/<name>.nii.gz
_label_names = (
    'AAL', 'desikan', 'HarvardOxford', 'CPAC200', 'Talairach', 'JHU',
    'slab907', 'slab1068',
    'DS00071', 'DS00096', 'DS00108', 'DS00140', 'DS00195', 'DS00278',
    'DS00350', 'DS00446', 'DS00583', 'DS00833', 'DS01216', 'DS01876',
    'DS03231', 'DS06481', 'DS16784', 'DS72784',
)
labels = [op.join(atlas_dir, 'labels/%s.nii.gz' % name)
          for name in _label_names]

# Data structure:
# sub-<subject id>/
#   ses-<session id>/
#     anat/
#       sub-<subject id>_ses-<session id>_T1w.nii.gz
#     dwi/
#       sub-<subject id>_ses-<session id>_dwi.nii.gz
#   *   sub-<subject id>_ses-<session id>_dwi.bval
#   *   sub-<subject id>_ses-<session id>_dwi.bvec
#
# * These files may sit in the dwi directory or in any directory above it;
#   the dwi data inherits them from the nearest enclosing level.


def _first_or_exit(items, message):
    """Return items[0], or terminate via sys.exit(message) if items is empty."""
    if len(items) == 0:
        sys.exit(message)
    return items[0]


def _find_inherited(start_dir, stop_dir, ext):
    """Return the closest file ending in `ext` at or above `start_dir`.

    Directories are searched from `start_dir` upwards, one parent at a
    time, until `stop_dir` has been searched or the filesystem root is
    reached. Returns None when no matching file is found.
    """
    stop_dir = op.abspath(stop_dir)
    current = op.abspath(start_dir)
    while True:
        # sorted() makes the choice deterministic if several files match
        matches = sorted(f for f in os.listdir(current) if f.endswith(ext))
        if matches:
            return op.join(current, matches[0])
        parent = op.dirname(current)
        # parent == current only at the filesystem root; this guards against
        # looping forever when start_dir is not located below stop_dir
        if current == stop_dir or parent == current:
            return None
        current = parent


def driver(inDir, outDir, subj, sesh):
    """Locate one subject's inputs in a BIDS-style tree and run ndmg_pipeline.

    Arguments:
        inDir: base directory of the input dataset
        outDir: base directory in which derivatives are stored
        subj: subject id, without the 'sub-' prefix
        sesh: session id, without the 'ses-' prefix; only needed when the
              subject directory contains session directories

    Exits through sys.exit() with an explanatory message when a required
    argument is missing or when any of the input files cannot be located.
    """
    if inDir is None or outDir is None or subj is None:
        sys.exit("Error: Missing input, output directory or subject id."
                 "\n Try 'ndmg_bids -h' for help")

    subj_dir = op.join(inDir, subj_tag + subj)
    if not op.isdir(subj_dir):
        # os.walk() yields nothing for a missing directory, which would
        # otherwise surface as a bare StopIteration from next()
        sys.exit("Error: Subject directory not found: " + subj_dir +
                 "\n Try 'ndmg_bids -h' for help")

    # Immediate sub-directories of the subject directory
    subj_dir_contents = next(os.walk(subj_dir))[1]
    loc = subj_dir
    if any(sesh_tag in item for item in subj_dir_contents):
        if sesh is None:
            sys.exit("Error: Multiple sessions found, no session provided."
                     "\n Try 'ndmg_bids -h' for help")
        elif (sesh_tag + sesh) not in subj_dir_contents:
            sys.exit("Error: Session ID provided does not match those found."
                     "\nSessions found: " + ", ".join(subj_dir_contents) +
                     "\n Try 'ndmg_bids -h' for help")
        loc = op.join(subj_dir, sesh_tag + sesh)

    # Match the modality tags against the path *relative to loc*, so that a
    # dataset stored under e.g. /data/dwi_study/ does not match everywhere
    anat = [x[0] for x in os.walk(loc)
            if anat_tag in op.relpath(x[0], loc)]
    anat_dir = _first_or_exit(
        anat, "Error: No '" + anat_tag + "' directory found in " + loc)
    anat_file = _first_or_exit(
        get_files(imgs_ext, anat_dir),
        "Error: No anatomical image found in " + anat_dir)

    dwi = [x[0] for x in os.walk(loc)
           if dwi_tag in op.relpath(x[0], loc)]
    dwi_dir = _first_or_exit(
        dwi, "Error: No '" + dwi_tag + "' directory found in " + loc)
    dwi_file = _first_or_exit(
        get_files(imgs_ext, dwi_dir),
        "Error: No diffusion image found in " + dwi_dir)

    # b-values / b-vectors are inherited: take the nearest file at or above
    # the dwi directory, searching no higher than the parent of inDir
    search_limit = op.join(inDir, os.pardir)
    bval_file = _find_inherited(dwi_dir, search_limit, bval_ext)
    bvec_file = _find_inherited(dwi_dir, search_limit, bvec_ext)
    if bval_file is None or bvec_file is None:
        sys.exit("Error: No b-values or b-vectors found for this subject."
                 "\nPlease review BIDS spec (bids.neuroimaging.io).")

    # Fetch the atlases when any expected atlas, mask or label file is absent
    ope = op.exists
    if any(not ope(l) for l in labels) or not (ope(atlas) and ope(atlas_mask)):
        print("Cannot find atlas information; downloading...")
        # NOTE(review): paths are concatenated into the command string
        # unquoted -- confirm how execute_cmd treats paths with spaces
        mgu().execute_cmd('mkdir -p ' + atlas_dir)
        cmd = " ".join(['wget -rnH --cut-dirs=3 --no-parent -P ' + atlas_dir,
                        'http://openconnecto.me/mrdata/share/atlases/'])
        mgu().execute_cmd(cmd)

    print("T1 file: " + anat_file)
    print("DWI file: " + dwi_file)
    print("Bval file: " + bval_file)
    print("Bvec file: " + bvec_file)
    print("Atlas file: " + atlas)
    print("Mask file: " + atlas_mask)
    print("Label file(s): " + ", ".join(labels))

    # Create the output directory and its tmp/ sub-directory before running
    mgu().execute_cmd("mkdir -p " + outDir + " " + outDir + "/tmp")
    ndmg_pipeline(dwi_file, bval_file, bvec_file, anat_file,
                  atlas, atlas_mask, labels, outDir, clean=False)


def main():
    """Parse the command-line arguments and pass them on to driver().

    The parser marks no option as required, so any option that is left out
    reaches driver() as None.
    """
    parser = ArgumentParser(description="This is an end-to-end connectome "
                            "estimation pipeline from sMRI and DTI images")
    parser.add_argument("-i", "--bids-dir", action="store",
                        help="Base directory for input data")
    parser.add_argument("-o", "--output-dir", action="store",
                        help="Base directory to store derivatives")
    parser.add_argument("-p", "--participant-label", action="store",
                        help="Subject ID to be analyzed")
    parser.add_argument("-s", "--session-label", action="store",
                        help="Session ID to be analyzed (if multiple exist)")
    result = parser.parse_args()

    # argparse exposes "--bids-dir" as the attribute "bids_dir", and so on
    inDir = result.bids_dir
    outDir = result.output_dir
    subj = result.participant_label
    sesh = result.session_label

    driver(inDir, outDir, subj, sesh)

# Run the command-line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
