#!/usr/bin/env python
"""
Downloads or uploads XNAT files.
"""
import sys
import os
import re
import argparse
from collections import defaultdict
import qixnat
from qixnat import command
from qixnat.helpers import path_hierarchy


class ArgumentError(Exception):
    """Signals that the command line path arguments are not valid."""


class SessionNotFoundError(Exception):
    """
    Signals that the upload target XNAT session does not exist and
    cannot be created from the given options.
    """


def main(argv=sys.argv):
    """
    Copies files to or from XNAT as directed by the command line paths.

    A path argument with a leading ``xnat:`` prefix designates an XNAT
    object path; any other path argument is a local file system path.
    The copy is an upload if and only if the last path has the prefix
    and at least one preceding path does not. Otherwise, the copy is a
    download. A download whose last path is also prefixed places the
    files in the current directory.

    :param argv: the command line arguments (currently unused; the
        arguments are parsed by :meth:`_parse_arguments`, which
        delegates to ``argparse``)
    :return: the exit status 0
    :raise ArgumentError: if the prefixed and unprefixed paths are
        mixed inconsistently, if there is no download source, or if
        the upload target does not start with
        /project/subject/session
    :raise SessionNotFoundError: if the upload target session does
        not exist and the ``--modality`` option is not set
    """
    # Parse the command line arguments.
    paths, opts = _parse_arguments()
    # The XNAT configuration.
    config = opts.pop('config', None)
    # Configure the logger.
    command.configure_log(**opts)

    # Flag each path argument which designates a XNAT object.
    xnat_prefix = 'xnat:'
    prefixed = [path.startswith(xnat_prefix) for path in paths]
    # The copy is an upload only if the target is a XNAT path and there
    # is at least one local source path. Every other combination is
    # treated as a download and validated below.
    is_upload = prefixed[-1] and not all(prefixed)

    # Determine the sources and destination.
    if is_upload:
        # Validate that only the target has a xnat: prefix.
        if any(prefixed[:-1]):
            raise ArgumentError("Upload sources cannot have a xnat: prefix")
        # The leading path arguments are the file sources.
        sources = paths[:-1]
        # The target XNAT path is the last path argument, with the
        # leading xnat: prefix removed and the trailing slash removed,
        # if necessary.
        dest = paths[-1][len(xnat_prefix):].rstrip('/')
        # Remove the XNAT path '/files' destination suffix, if necessary.
        if dest.endswith('/files'):
            dest = dest[:-len('/files')]
    else:
        # Validate that only the sources have a xnat: prefix.
        if not all(prefixed[:-1]):
            raise ArgumentError("Download sources must have a xnat: prefix")
        # The last argument can have a xnat: prefix, in which case the
        # target location is the current directory. In that case, the
        # paths with a xnat: prefix are all path arguments. Otherwise,
        # the prefixed paths are all but the last path arguments.
        prefixed_paths = paths if prefixed[-1] else paths[:-1]
        # A lone unprefixed path is a destination without a source.
        if not prefixed_paths:
            raise ArgumentError("There must be at least one XNAT download"
                                " source.")
        # Remove each source leading prefix and trailing slash.
        source_args = [path[len(xnat_prefix):].rstrip('/')
                       for path in prefixed_paths]
        # The XNAT sources, with a files suffix appended to each path
        # which does not already end in /files or /file/<name>.
        sources = [src if re.search(r"/file(s|/\w+)$", src)
                   else src + '/files'
                   for src in source_args]
        # If the last path argument is an XNAT directory, then
        # the destination is the current directory, otherwise
        # the destination is the last path argument.
        dest = '.' if prefixed[-1] else paths[-1]

    # Copy the files.
    with qixnat.connect(config) as xnat:
        if is_upload:
            # The path starts with /project/subject/experiment.
            prefix_types = ['project', 'subject', 'experiment']
            # Infer the XNAT hierarchy from the target XNAT path.
            hierarchy = path_hierarchy(dest)
            # The /project/subject/experiment prefix.
            prefix = hierarchy[:3]
            # Validate that each prefix is of the correct type. This
            # check also rejects a hierarchy with fewer than three
            # levels, so the unpacking below is safe.
            if [pair[0] for pair in prefix] != prefix_types:
                raise ArgumentError("Upload XNAT target must start with a"
                                    " /project/subject/session: %s" % dest)
            # Extract the prefix names.
            prj, sbj, sess = [pair[1] for pair in prefix]
            # Get the XNAT session.
            sess_obj = xnat.get_session(prj, sbj, sess)
            if not opts.get('modality') and not xnat.exists(sess_obj):
                raise SessionNotFoundError("The --modality option is required"
                                           " to create the XNAT session: %s %s"
                                           " %s" % (prj, sbj, sess))
            # The remaining (type, value) pairs.
            upload_opts = dict(hierarchy[3:])
            # Add in the general command options, e.g. force.
            upload_opts.update(opts)
            # Upload the files.
            xnat.upload(prj, sbj, sess, *sources, **upload_opts)
        else:
            # The download source XNAT file objects, flattened over
            # all of the source paths.
            src_objs = [src_obj for src in sources
                        for src_obj in xnat.expand_path(src)]
            # Download the files.
            for src_obj in src_objs:
                xnat.download_file(src_obj, dest, **opts)

    return 0


def _parse_arguments():
    """
    Parses the command line arguments.

    The options whose parsed value is None, i.e. those which were not
    set and have no default, are omitted from the result. The boolean
    flag options are always included, since they default to False.

    :return: the (paths, options) tuple, where *paths* is the list of
        path arguments and *options* is the {name: value} dictionary
        of the remaining parsed arguments
    """
    parser = argparse.ArgumentParser()
    # The common XNAT options.
    command.add_options(parser)

    # The --force and --skip-existing options are exclusive.
    existing_opts = parser.add_mutually_exclusive_group()
    existing_opts.add_argument('-f', '--force', action='store_true',
                               help='overwrite existing target file')
    existing_opts.add_argument('-s', '--skip-existing', action='store_true',
                               help="don't copy if the target file exists")

    # The scan modality option.
    parser.add_argument('-m', '--modality', help="the scan modality, e.g. MR")

    # The source file(s) or XNAT hierarchy path.
    parser.add_argument('paths', nargs='+', metavar="PATH",
                        help='the file(s) or xnat:/project/subject/... object path(s)')

    # Parse all arguments.
    args = vars(parser.parse_args())
    # Filter out the unset arguments. items() rather than iteritems()
    # keeps this portable across Python 2 and 3.
    nonempty_args = {k: v for k, v in args.items() if v is not None}
    # The source(s) and destination.
    paths = nonempty_args.pop('paths')

    # Return the paths and options.
    return paths, nonempty_args


if __name__ == '__main__':
    # Run the command and hand its return value to the interpreter as
    # the process exit status.
    exit_status = main()
    sys.exit(exit_status)
