#!/usr/bin/env python3
"""
DICOM to JSON converter for sequence types in common use at CRIC Bristol.
The purpose is to provide plain-text access to DICOM tags that are relevant in
fMRI, DTI, ASL etc. analysis, to facilitate scripting in bash rather than Python.

R. Hartley-Davies 2016

As dcm2niix now produces a BIDS style json sidecar file, we need to be
consistent with that.
R. Hartley-Davies 2018
"""

from pydicom import dcmread

from dcmextras.siemenscsa import csa
from dcmextras.siemensphoenix import phoenix

__version__ = '0.2'

#
# TODO: Need to get lots of example images including ones with wip parameters.
#


# Access routines for phoenix fields including 'vectors'
def get_phoenix_int(protocol, tag):
    """Return the Phoenix protocol entry `tag` converted to int.

    An absent tag yields 0 instead of raising: the original author observed
    that a missing field may simply mean the value is zero.
    """
    return int(protocol[tag]) if tag in protocol else 0


def get_phoenix_bool(protocol, tag):
    """Return the Phoenix protocol entry `tag` as a bool.

    An absent tag yields False: the original author observed that false
    boolean fields may be left out of the protocol entirely.
    NOTE(review): conversion is plain truthiness, so a present string value
    such as '0' would map to True; confirm how the phoenix parser encodes
    boolean fields.
    """
    return tag in protocol and bool(protocol[tag])


def get_phoenix_float(protocol, tag):
    """Return the Phoenix protocol entry `tag` converted to float.

    Raises KeyError when the tag is absent from the protocol.
    """
    raw = protocol[tag]
    return float(raw)


def get_phoenix_str(protocol, tag):
    """Return the Phoenix protocol entry `tag` converted to str.

    Raises KeyError when the tag is absent from the protocol.
    """
    raw = protocol[tag]
    return str(raw)


def get_phoenix_string_vector(protocol, pattern, max_len=256):
    """Reassemble a string stored in the Phoenix protocol one character per entry.

    `pattern` contains a '*' placeholder (only the first one is substituted)
    that is replaced by each index 0 .. max_len-1. Each entry present holds
    the character code as text of the form "0xNN   # C" (WTF); the leading
    hex token is decoded and the characters are joined in index order.
    Absent indices are skipped. Surrounding whitespace is stripped from the
    result.

    max_len: number of indices probed (default 256, the former hard limit).
    """
    chars = []
    for i in range(max_len):
        tag = pattern.replace('*', '%d' % i, 1)
        if tag in protocol:
            # split() with no argument tolerates leading or repeated blanks;
            # split(' ') would yield '' for a value with a leading space.
            code = protocol[tag].split()[0]
            chars.append(chr(int(code, 16)))
    return ''.join(chars).strip()


def get_phoenix_numeric_vector(protocol, pattern, max_len=256):
    """Collect numeric Phoenix entries addressed by a '*' index placeholder.

    `pattern` contains a '*' placeholder (only the first one is substituted)
    that is replaced by each index 0 .. max_len-1. Each entry present is
    converted to float; absent indices are skipped. Values are returned in
    index order.

    max_len: number of indices probed (default 256, the former hard limit).
    Raise it for vectors that may have more elements.
    """
    tags = (pattern.replace('*', '%d' % i, 1) for i in range(max_len))
    return [float(protocol[tag]) for tag in tags if tag in protocol]


class Baseline(object):
    """Tag tables shared by every sequence type; also the catch-all class.

    Each table is a private (name-mangled) class attribute. __init__ copies
    it into a per-instance dict of the same name without the leading
    underscores. Subclasses declare their own private tables and merge them
    on top in their own __init__, so get_dcm_info only reads the instance
    dicts.
    """

    # Standard DICOM attributes: keyword -> converter applied to the value
    # of getattr(dobj, keyword).
    __standard_tags = {
        'PatientID':                     str,
        'StudyID':                       str,
        'SeriesNumber':                  int,
        'SequenceName':                  str,
        # pydicom typically returns a single value as a str and several as
        # a sequence. str is itself iterable in Python 3, so it must be
        # excluded explicitly or a single value is split into characters.
        'SequenceVariant':
            lambda item_: (
                str(item_)
                if isinstance(item_, str) or not hasattr(item_, '__iter__')
                else list(map(str, item_))
            ),
        'ScanningSequence':              str,
        # Always emitted as a list, even if only a single value is present.
        'ImageType':
            lambda list_: (
                [str(list_)] if isinstance(list_, str) else list(map(str, list_))
            ),
        'Manufacturer':                  str,
        'RepetitionTime':                float,
        'EchoTime':                      float,
        'NumberOfPhaseEncodingSteps':    int,
        'EchoTrainLength':               int,
        'PercentSampling':               float,
        'PercentPhaseFieldOfView':       float,
        'AcquisitionMatrix':             lambda list_: list(map(int, list_)),
        'InPlanePhaseEncodingDirection': str,
        'SliceLocation':                 float,
        'Rows':                          int,
        'Columns':                       int,
    }

    # Siemens private tags: name -> [(group, element offset within the
    # private block), converter]; get_dcm_info adds each private creator's
    # block offset to the element.
    __siemens_tags = {
        'SliceMeasurementDuration':     [(0x019, 0x0b), float],
        'RealDwellTime':                [(0x019, 0x18), int],
    }

    # Image CSA header: name -> (CSA field name, converter).
    __csa_image_tags = {
        'EchoLinePosition':               ('EchoLinePosition', int),
        'EchoColumnPosition':             ('EchoColumnPosition', int),
        'EchoPartitionPosition':          ('EchoPartitionPosition', int),
        'PhaseEncodingDirectionPositive': ('PhaseEncodingDirectionPositive', int),
    }

    # Series CSA header: name -> (CSA field name, converter).
    __csa_series_tags = {}

    # Phoenix protocol: name -> (protocol key, accessor(protocol, key)).
    __phoenix_tags = {
        'AccelerationPE': ('sPat.lAccelFactPE', get_phoenix_int),
        'Acceleration3D': ('sPat.lAccelFact3D', get_phoenix_int),
    }

    # WIP memory block: name -> (position in the block, converter type).
    __wip_tags = {}

    def __init__(self):
        super(Baseline, self).__init__()
        self.standard_tags   = {}
        self.siemens_tags    = {}
        self.csa_image_tags  = {}
        self.csa_series_tags = {}
        self.phoenix_tags    = {}
        self.wip_tags        = {}

        # TODO: Ideally this block would be a base-class method called from
        # the derived classes' __init__, but name mangling means such a
        # method would see Baseline's private tables rather than the
        # derived class ones.
        self.standard_tags.update(self.__standard_tags)
        self.siemens_tags.update(self.__siemens_tags)
        self.csa_image_tags.update(self.__csa_image_tags)
        self.csa_series_tags.update(self.__csa_series_tags)
        self.phoenix_tags.update(self.__phoenix_tags)
        self.wip_tags.update(self.__wip_tags)


class EPI(Baseline):
    """Tag tables for Siemens 2D EPI series.

    Adds mosaic, timing, dwell-time and phase-encode bandwidth private tags,
    plus the RF and gradient echo-train lengths from the series CSA header.
    """

    # No extra standard DICOM attributes.
    __standard_tags = {}

    # Siemens private tags: name -> [(group, element offset within the
    # private block), converter].
    __siemens_tags = {
        'NumberOfImagesInMosaic':       [(0x019, 0x0a), int],
        'SliceMeasurementDuration':     [(0x019, 0x0b), float],
        'RealDwellTime':                [(0x019, 0x18), int],
        'BandwidthPerPixelPhaseEncode': [(0x019, 0x28), float],
        'MosaicRefAcqTimes':
            [(0x019, 0x29), lambda times: [float(t) for t in times]],
    }

    # No extra image CSA fields.
    __csa_image_tags = {}

    # Series CSA header: name -> (CSA field name, converter).
    __csa_series_tags = {
        'RFEchoTrainLength': ('RFEchoTrainLength', int),
        'GradientEchoTrainLength': ('GradientEchoTrainLength', int),
    }

    __phoenix_tags = {}
    __wip_tags = {}

    def __init__(self):
        super(EPI, self).__init__()
        # Layer this class's tables over those merged by the base classes.
        layers = (
            (self.standard_tags, self.__standard_tags),
            (self.siemens_tags, self.__siemens_tags),
            (self.csa_image_tags, self.__csa_image_tags),
            (self.csa_series_tags, self.__csa_series_tags),
            (self.phoenix_tags, self.__phoenix_tags),
            (self.wip_tags, self.__wip_tags),
        )
        for merged, own in layers:
            merged.update(own)


class DTI(EPI):
    """Tag tables for Siemens diffusion EPI.

    Adds the b-value, directionality and gradient-direction private tags,
    plus the diffusion weighting count and the
    sDiffusion.sFreeDiffusionData protocol fields.
    """

    # No extra standard DICOM attributes.
    __standard_tags = {}

    # Siemens private tags: name -> [(group, element offset within the
    # private block), converter].
    __siemens_tags = {
        'B_value':                      [(0x019, 0x0c), int],
        'DiffusionDirectionality':      [(0x019, 0x0d), str],
        'DiffusionGradientDirection':
            [(0x019, 0x0e), lambda vec: [float(c) for c in vec]],
    }

    __csa_image_tags = {}
    __csa_series_tags = {}

    # Phoenix protocol: name -> (key, or key pattern with a '*' index
    # placeholder, accessor(protocol, key)).
    __phoenix_tags = {
        'DiffusionWeightings':
            ('sDiffusion.lDiffWeightings', get_phoenix_int),
        'FreeDiffusionFile':
            ('sDiffusion.sFreeDiffusionData.sComment.*', get_phoenix_string_vector),
        'FreeDiffusionNumberOfDirections':
            ('sDiffusion.sFreeDiffusionData.lDiffDirections', get_phoenix_int),
        'FreeDiffusionCoordinateSystem':
            ('sDiffusion.sFreeDiffusionData.ulCoordinateSystem', get_phoenix_int),
        'FreeDiffusionNormalization':
            ('sDiffusion.sFreeDiffusionData.ulNormalization', get_phoenix_int),
        'FreeDiffusionSagittalComponents':
            ('sDiffusion.sFreeDiffusionData.asDiffDirVector[*].dSag', get_phoenix_numeric_vector),
        'FreeDiffusionTransverseComponents':
            ('sDiffusion.sFreeDiffusionData.asDiffDirVector[*].dTra', get_phoenix_numeric_vector),
        'FreeDiffusionCoronalComponents':
            ('sDiffusion.sFreeDiffusionData.asDiffDirVector[*].dCor', get_phoenix_numeric_vector),
    }

    __wip_tags = {}

    def __init__(self):
        super(DTI, self).__init__()
        # Layer this class's tables over those merged by the base classes.
        layers = (
            (self.standard_tags, self.__standard_tags),
            (self.siemens_tags, self.__siemens_tags),
            (self.csa_image_tags, self.__csa_image_tags),
            (self.csa_series_tags, self.__csa_series_tags),
            (self.phoenix_tags, self.__phoenix_tags),
            (self.wip_tags, self.__wip_tags),
        )
        for merged, own in layers:
            merged.update(own)


class BOLD(EPI):
    """Tag tables for BOLD fMRI EPI.

    Currently adds nothing to EPI; the empty per-class tables are the place
    to add BOLD-specific tags.
    """

    __standard_tags = {}
    __siemens_tags = {}
    __csa_image_tags = {}
    __csa_series_tags = {}
    __phoenix_tags = {}
    __wip_tags = {}

    def __init__(self):
        super(BOLD, self).__init__()
        # Layer this class's tables over those merged by the base classes.
        layers = (
            (self.standard_tags, self.__standard_tags),
            (self.siemens_tags, self.__siemens_tags),
            (self.csa_image_tags, self.__csa_image_tags),
            (self.csa_series_tags, self.__csa_series_tags),
            (self.phoenix_tags, self.__phoenix_tags),
            (self.wip_tags, self.__wip_tags),
        )
        for merged, own in layers:
            merged.update(own)


class MB(Baseline):
    """Tag tables for multiband sequences (mixed into BOLDMB and DTIMB).

    Only records VersionString, read from the WIP memory block free-text
    field. TODO: work out where the multiband parameters themselves are
    stored.
    """

    __standard_tags = {}
    __siemens_tags = {}
    __csa_image_tags = {}
    __csa_series_tags = {}
    __phoenix_tags = {}

    # WIP memory block: name -> (position, converter type). A str converter
    # reads the free-text field, so no position is needed.
    __wip_tags = {
        'VersionString': (None,  str)
    }

    def __init__(self):
        super(MB, self).__init__()
        # Layer this class's tables over those merged by the base classes.
        layers = (
            (self.standard_tags, self.__standard_tags),
            (self.siemens_tags, self.__siemens_tags),
            (self.csa_image_tags, self.__csa_image_tags),
            (self.csa_series_tags, self.__csa_series_tags),
            (self.phoenix_tags, self.__phoenix_tags),
            (self.wip_tags, self.__wip_tags),
        )
        for merged, own in layers:
            merged.update(own)


class BOLDMB(BOLD, MB):
    """Tag tables for multiband BOLD: the union of BOLD and MB.

    Adds no tags of its own; the empty per-class tables are the place to add
    any that are specific to this combination.
    """

    __standard_tags = {}
    __siemens_tags = {}
    __csa_image_tags = {}
    __csa_series_tags = {}
    __phoenix_tags = {}
    __wip_tags = {}

    def __init__(self):
        super(BOLDMB, self).__init__()
        # Layer this class's tables over those merged by the base classes.
        layers = (
            (self.standard_tags, self.__standard_tags),
            (self.siemens_tags, self.__siemens_tags),
            (self.csa_image_tags, self.__csa_image_tags),
            (self.csa_series_tags, self.__csa_series_tags),
            (self.phoenix_tags, self.__phoenix_tags),
            (self.wip_tags, self.__wip_tags),
        )
        for merged, own in layers:
            merged.update(own)


class DTIMB(DTI, MB):
    """Tag tables for multiband diffusion: the union of DTI and MB.

    Adds no tags of its own; the empty per-class tables are the place to add
    any that are specific to this combination.
    """

    __standard_tags = {}
    __siemens_tags = {}
    __csa_image_tags = {}
    __csa_series_tags = {}
    __phoenix_tags = {}
    __wip_tags = {}

    def __init__(self):
        super(DTIMB, self).__init__()
        # Layer this class's tables over those merged by the base classes.
        layers = (
            (self.standard_tags, self.__standard_tags),
            (self.siemens_tags, self.__siemens_tags),
            (self.csa_image_tags, self.__csa_image_tags),
            (self.csa_series_tags, self.__csa_series_tags),
            (self.phoenix_tags, self.__phoenix_tags),
            (self.wip_tags, self.__wip_tags),
        )
        for merged, own in layers:
            merged.update(own)


class PASL2D(BOLD):
    """Tag tables for 2D pulsed ASL (ep2d_pasl).

    Adds TI1/TI2, taken from elements 3 and 4 of the image CSA 'QCData'
    field, and the sAsl.fFlowLimit protocol value.
    """

    __standard_tags = {}
    __siemens_tags = {}
    __csa_series_tags = {}

    # Image CSA header: name -> (CSA field name, converter).
    __csa_image_tags = {
        'AslTI1': ('QCData', lambda qc: float(qc[3])),
        'AslTI2': ('QCData', lambda qc: float(qc[4]))
    }

    # Phoenix protocol: name -> (protocol key, accessor(protocol, key)).
    __phoenix_tags = {
        'AslFlowLimit': ('sAsl.fFlowLimit', get_phoenix_float)
    }

    __wip_tags = {}

    def __init__(self):
        super(PASL2D, self).__init__()
        # Layer this class's tables over those merged by the base classes.
        layers = (
            (self.standard_tags, self.__standard_tags),
            (self.siemens_tags, self.__siemens_tags),
            (self.csa_image_tags, self.__csa_image_tags),
            (self.csa_series_tags, self.__csa_series_tags),
            (self.phoenix_tags, self.__phoenix_tags),
            (self.wip_tags, self.__wip_tags),
        )
        for merged, own in layers:
            merged.update(own)


class CASL2D(BOLD):
    """Tag tables for 2D continuous ASL (ep2d_casl).

    All sequence-specific parameters come from the WIP memory block.
    """

    __standard_tags = {}
    __siemens_tags = {}
    __csa_image_tags = {}
    __csa_series_tags = {}
    __phoenix_tags = {}

    # WIP memory block: name -> (position, converter type). get_dcm_info
    # reads float entries from sWipMemBlock.adFree[pos] and bool/int entries
    # from sWipMemBlock.alFree[pos]; absent bool/int entries default to
    # False/0.
    __wip_tags = {
        'PerformMppcASL': (0,  bool),
        'NumPhases':      (1,  int),
        'RunSpoil':       (2,  bool),
        'RFFlip':         (4,  int),
        'RFDur':          (5,  int),
        'RFSep':          (6,  int),
        'MeanGrad':       (7,  float),
        'TagGrad':        (8,  float),
        'TagDur':         (9,  float),
        'PLD':            (10, float),
        'T1Opt':          (11, float),
        'PreSat':         (12, bool),
        'DInv':           (13, bool)
    }

    def __init__(self):
        super(CASL2D, self).__init__()
        # Layer this class's tables over those merged by the base classes.
        layers = (
            (self.standard_tags, self.__standard_tags),
            (self.siemens_tags, self.__siemens_tags),
            (self.csa_image_tags, self.__csa_image_tags),
            (self.csa_series_tags, self.__csa_series_tags),
            (self.phoenix_tags, self.__phoenix_tags),
            (self.wip_tags, self.__wip_tags),
        )
        for merged, own in layers:
            merged.update(own)


class ZShim(BOLD):
    """Tag tables for z-shimmed EPI; currently only the enable flag."""

    __standard_tags = {}
    __siemens_tags = {}
    __csa_image_tags = {}
    __csa_series_tags = {}
    __phoenix_tags = {}

    # WIP memory block: name -> (position, converter type).
    # NOTE(review): the output key spelling 'ZSHimEnabled' is kept as-is
    # because downstream scripts may rely on it.
    __wip_tags = {
        'ZSHimEnabled':     (0,  bool),
        #
        # TODO: The remaining parameters are difficult: the values form an
        # integer array with zero entries dropped, so the WIP access code
        # may need generalising. The WIP memblock positions are probably not
        # stable either - we need versioning on the sequence.
        # 'ZShimDelayed':     (1,  bool),
        # 'ZShimCalibration': (2,  bool),
        # 'ZShimNumber':      (3,  int),
        # 'ZShimMinimum':     (0,  float),
        # 'ZShimMaximum':     (1,  float),
        # 'ZShimValues':      (4,  int),
    }

    def __init__(self):
        super(ZShim, self).__init__()
        # Layer this class's tables over those merged by the base classes.
        layers = (
            (self.standard_tags, self.__standard_tags),
            (self.siemens_tags, self.__siemens_tags),
            (self.csa_image_tags, self.__csa_image_tags),
            (self.csa_series_tags, self.__csa_series_tags),
            (self.phoenix_tags, self.__phoenix_tags),
            (self.wip_tags, self.__wip_tags),
        )
        for merged, own in layers:
            merged.update(own)


class MPRAGE(Baseline):
    """Tag tables for MPRAGE (tfl3d) series: adds the inversion time."""

    # Standard DICOM attributes: keyword -> converter. This must be a dict
    # like every other tag table (keyword -> converter), not a set of pairs.
    __standard_tags = {
        'InversionTime': float,
    }
    __siemens_tags = {}
    __csa_image_tags = {}
    __csa_series_tags = {}
    __phoenix_tags = {}
    __wip_tags = {}

    def __init__(self):
        super(MPRAGE, self).__init__()
        # Layer this class's tables over those merged by the base classes.
        layers = (
            (self.standard_tags, self.__standard_tags),
            (self.siemens_tags, self.__siemens_tags),
            (self.csa_image_tags, self.__csa_image_tags),
            (self.csa_series_tags, self.__csa_series_tags),
            (self.phoenix_tags, self.__phoenix_tags),
            (self.wip_tags, self.__wip_tags),
        )
        for merged, own in layers:
            merged.update(own)


class PCFlow(Baseline):
    """Tag tables for phase-contrast flow (fl2d, phase images).

    Adds the velocity encoding and flow-encoding direction from the image
    CSA header.
    """

    __standard_tags = {}
    __siemens_tags = {}

    # Image CSA header: name -> (CSA field name, converter).
    __csa_image_tags = {
        'FlowVenc': ('FlowVenc', float),
        'FlowEncodingDirection': ('FlowEncodingDirection', int)
    }

    __csa_series_tags = {}
    __phoenix_tags = {}
    __wip_tags = {}

    def __init__(self):
        super(PCFlow, self).__init__()
        # Layer this class's tables over those merged by the base classes.
        layers = (
            (self.standard_tags, self.__standard_tags),
            (self.siemens_tags, self.__siemens_tags),
            (self.csa_image_tags, self.__csa_image_tags),
            (self.csa_series_tags, self.__csa_series_tags),
            (self.phoenix_tags, self.__phoenix_tags),
            (self.wip_tags, self.__wip_tags),
        )
        for merged, own in layers:
            merged.update(own)


class SPACE(Baseline):
    """Tag tables for SPACE (tse_vfl) series: adds the inversion time."""

    # Standard DICOM attributes: keyword -> converter.
    __standard_tags = {
        'InversionTime': float,
    }
    __siemens_tags = {}
    __csa_image_tags = {}
    __csa_series_tags = {}
    __phoenix_tags = {}
    __wip_tags = {}

    def __init__(self):
        super(SPACE, self).__init__()
        # Layer this class's tables over those merged by the base classes.
        layers = (
            (self.standard_tags, self.__standard_tags),
            (self.siemens_tags, self.__siemens_tags),
            (self.csa_image_tags, self.__csa_image_tags),
            (self.csa_series_tags, self.__csa_series_tags),
            (self.phoenix_tags, self.__phoenix_tags),
            (self.wip_tags, self.__wip_tags),
        )
        for merged, own in layers:
            merged.update(own)


class Fieldmap(Baseline):
    """Tag tables for gre_field_mapping series.

    Currently adds nothing to Baseline; the empty per-class tables are the
    place to add fieldmap-specific tags.
    """

    __standard_tags = {}
    __siemens_tags = {}
    __csa_image_tags = {}
    __csa_series_tags = {}
    __phoenix_tags = {}
    __wip_tags = {}

    def __init__(self):
        super(Fieldmap, self).__init__()
        # Layer this class's tables over those merged by the base classes.
        layers = (
            (self.standard_tags, self.__standard_tags),
            (self.siemens_tags, self.__siemens_tags),
            (self.csa_image_tags, self.__csa_image_tags),
            (self.csa_series_tags, self.__csa_series_tags),
            (self.phoenix_tags, self.__phoenix_tags),
            (self.wip_tags, self.__wip_tags),
        )
        for merged, own in layers:
            merged.update(own)


class Localizer(Baseline):
    """Tag tables for localizer (fl2d gre) series.

    Currently adds nothing to Baseline; the empty per-class tables are the
    place to add localizer-specific tags.
    """

    __standard_tags = {}
    __siemens_tags = {}
    __csa_image_tags = {}
    __csa_series_tags = {}
    __phoenix_tags = {}
    __wip_tags = {}

    def __init__(self):
        super(Localizer, self).__init__()
        # Layer this class's tables over those merged by the base classes.
        layers = (
            (self.standard_tags, self.__standard_tags),
            (self.siemens_tags, self.__siemens_tags),
            (self.csa_image_tags, self.__csa_image_tags),
            (self.csa_series_tags, self.__csa_series_tags),
            (self.phoenix_tags, self.__phoenix_tags),
            (self.wip_tags, self.__wip_tags),
        )
        for merged, own in layers:
            merged.update(own)


class TOF3D(Baseline):
    """Tag tables for 3D time-of-flight (fl_tof) series.

    Currently adds nothing to Baseline; the empty per-class tables are the
    place to add TOF-specific tags.
    """

    __standard_tags = {}
    __siemens_tags = {}
    __csa_image_tags = {}
    __csa_series_tags = {}
    __phoenix_tags = {}
    __wip_tags = {}

    def __init__(self):
        super(TOF3D, self).__init__()
        # Layer this class's tables over those merged by the base classes.
        layers = (
            (self.standard_tags, self.__standard_tags),
            (self.siemens_tags, self.__siemens_tags),
            (self.csa_image_tags, self.__csa_image_tags),
            (self.csa_series_tags, self.__csa_series_tags),
            (self.phoenix_tags, self.__phoenix_tags),
            (self.wip_tags, self.__wip_tags),
        )
        for merged, own in layers:
            merged.update(own)


def _sequence_file_name(dobj):
    '''Return the Phoenix 'tSequenceFileName' entry of dobj, or '' if absent.'''
    protocol = phoenix(dobj)
    return protocol['tSequenceFileName'] if 'tSequenceFileName' in protocol else ''


# Sequence classifiers: tag-table class -> predicate on the dataset.
# Ordered from specific sequence types to more general ones; the last item
# is a catchall and the first match takes priority. This relies on dicts
# preserving insertion order (guaranteed from Python 3.7).
# Each predicate must use its own argument `d`, never a global dataset.
SEQTYPES = {
    MPRAGE:    lambda d: 'tfl3d'   in d.SequenceName and 'tfl' in _sequence_file_name(d),
    ZShim:     lambda d: 'epfid2d' in d.SequenceName and 'zshim' in _sequence_file_name(d),
    CASL2D:    lambda d: 'epfid2d' in d.SequenceName and 'ep2d_casl' in _sequence_file_name(d),
    PASL2D:    lambda d: 'epfid2d' in d.SequenceName and 'ep2d_pasl' in _sequence_file_name(d) and 'ASL' in d.ImageType,
    DTIMB:     lambda d: 'ep_b'    in d.SequenceName and 'cmrr_mbep2d_diff' in _sequence_file_name(d),
    DTI:       lambda d: ('ep_b' in d.SequenceName or 'ez_b' in d.SequenceName) and 'ep2d_diff' in _sequence_file_name(d),
    BOLDMB:    lambda d: 'epfid2d' in d.SequenceName and 'cmrr_mbep2d_bold' in _sequence_file_name(d),
    BOLD:      lambda d: 'ep2d'    in d.SequenceName,  # TODO: and something else
    EPI:       lambda d: 'ep2d'    in d.SequenceName,
    PCFlow:    lambda d: 'fl2d'    in d.SequenceName and 'P' in d.ImageType,
    SPACE:     lambda d: 'spcir'   in d.SequenceName and 'tse_vfl' in _sequence_file_name(d),
    Fieldmap:  lambda d: 'fm2d'    in d.SequenceName and 'gre_field_mapping' in _sequence_file_name(d),
    Localizer: lambda d: 'fl2d'    in d.SequenceName and 'gre' in _sequence_file_name(d),
    TOF3D:     lambda d: 'fl3d'    in d.SequenceName and 'fl_tof' in _sequence_file_name(d),
    Baseline:  lambda d: True,
}


def private_creators(dobj, group):
    '''List the element offsets of the private blocks reserved in `group`.

    A private creator element (gggg,00xx) with 0x10 <= xx <= 0xFF reserves
    the block of elements (gggg,xx00)-(gggg,xxFF). The block bases
    (xx << 8) are returned in the order dobj.keys() yields the creators.
    Elements 0x0000-0x000F (group length / reserved) are not creators and
    are ignored.
    '''
    return [
        tag.element << 8
        for tag in dobj.keys()
        if tag.group == group and 0x0010 <= tag.element <= 0x00ff
    ]


def get_sequence_type(dobj):
    '''Classify the sequence, returning the appropriate tag-table class.

    SEQTYPES is scanned in order, most specific first, and the class whose
    predicate first returns True is returned. SEQTYPES may be an ordered
    mapping of class -> predicate or a sequence of (class, predicate) pairs.
    A predicate that fails because the header lacks a field it inspects
    (KeyError / AttributeError) counts as a non-match rather than aborting.
    Datasets without both SequenceName and ImageType are classed as Baseline.
    '''
    if 'SequenceName' not in dobj or 'ImageType' not in dobj:
        return Baseline

    candidates = SEQTYPES.items() if hasattr(SEQTYPES, 'items') else SEQTYPES
    for seqt, fn in candidates:
        try:
            if fn(dobj):
                return seqt
        except (KeyError, AttributeError):
            # e.g. the Phoenix protocol has no 'tSequenceFileName' entry
            continue
    return Baseline


def get_dcm_info(dobj):
    '''Extract meta information based on an identification of the sequence type.

    The dataset is classified with get_sequence_type() and the tag tables of
    the resulting class drive extraction from: standard DICOM attributes,
    Siemens private tags, the image and series CSA headers, the Phoenix
    protocol and the Phoenix WIP memory block.

    Returns a tuple (info, missing): info maps output names to converted
    values; missing is the set of output names that could not be found.

    Raises TypeError if a WIP table entry uses a converter other than bool,
    int, float or str.
    '''
    seq = get_sequence_type(dobj)()

    info = {}
    missing = set()

    # Standard Tags
    for k, fn in seq.standard_tags.items():
        try:
            info[k] = fn(getattr(dobj, k))
        except AttributeError:
            missing.add(k)

    # Siemens Private Tags
    # The table stores each element's offset within a private block. Every
    # private creator in the group reserves its own block, so probe each
    # block in turn (adding the block base afresh each time, not
    # cumulatively) and keep the first hit.
    for k, ((group, element), fn) in seq.siemens_tags.items():
        for offset in private_creators(dobj, group):
            try:
                value = dobj[(group, offset + element)].value
            except KeyError:
                continue
            info[k] = fn(value)
            break
        else:
            # No private block in the group holds the element, or the group
            # has no private creators at all.
            missing.add(k)

    # Image CSA
    image_csa = csa(dobj, 'image')
    for k, (tag, fn) in seq.csa_image_tags.items():
        try:
            info[k] = fn(image_csa[tag])
        except KeyError:
            missing.add(k)

    # Series CSA
    series_csa = csa(dobj, 'series')
    for k, (tag, fn) in seq.csa_series_tags.items():
        try:
            info[k] = fn(series_csa[tag])
        except KeyError:
            missing.add(k)

    # Phoenix Protocol (parsed once; also reused for the WIP parameters)
    protocol = phoenix(dobj)
    for k, (tag, fn) in seq.phoenix_tags.items():
        try:
            info[k] = fn(protocol, tag)
        except KeyError:
            missing.add(k)

    # WIP parameters: floats live in adFree, bools/ints in alFree and the
    # free-text string in tFree (which has no position).
    for k, (wip_pos, fn) in seq.wip_tags.items():
        if fn is float:
            key = 'sWipMemBlock.adFree[%d]' % wip_pos
        elif fn is str:
            key = 'sWipMemBlock.tFree'
        elif fn is bool or fn is int:
            key = 'sWipMemBlock.alFree[%d]' % wip_pos
        else:
            raise TypeError(
                'unsupported WIP converter for %s: %r' % (k, fn))
        try:
            info[k] = fn(protocol[key])
        except KeyError:
            # Bizarre: if a boolean is False it just disappears (and,
            # presumably, a zero integer too), so default those.
            if fn is bool:
                info[k] = False
            elif fn is int:
                info[k] = 0
            else:
                missing.add(k)

    return info, missing


if __name__ == "__main__":
    import sys
    from argparse import ArgumentParser

    parser = ArgumentParser(description='Extract DICOM tags to json')
    parser.add_argument('-y', '--yaml',  action="store_true", default=False, help='write yaml instead of json')
    parser.add_argument('--warn',  action="store_true", default=False, help='warn about missing tags')
    parser.add_argument('--seqtype',  action="store_true", default=False, help='identify sequence type only')
    parser.add_argument('-v', '--version', action='version', version=__version__)
    parser.add_argument('infile',  action="store",      default='-',   help='input dicom file',         nargs='?')
    parser.add_argument('outfile', action="store",      default='-',   help='output json or yaml file', nargs='?')

    args = parser.parse_args()

    # DICOM to python dict
    if args.infile == '-':
        # pydicom requires a seekable stream and stdin is not, so buffer the
        # raw bytes in memory. sys.stdin.buffer yields bytes; sys.stdin
        # would yield decoded text.
        from io import BytesIO
        dobj = dcmread(BytesIO(sys.stdin.buffer.read()))
    else:
        with open(args.infile, 'rb') as f:
            dobj = dcmread(f)

    if args.seqtype:
        print('Sequence Type:', get_sequence_type(dobj).__name__, file=sys.stderr)
        sys.exit(0)

    info, missing = get_dcm_info(dobj)

    # Python dict to json or yaml
    if args.yaml:
        import yaml
        if args.outfile == '-':
            yaml.dump(info, sys.stdout, indent=4, encoding=None, default_flow_style=False)
            print()
        else:
            with open(args.outfile, 'w') as f:
                yaml.dump(info, f, indent=4, encoding=None, default_flow_style=False)
                print(file=f)
    else:
        import json
        if args.outfile == '-':
            json.dump(info, sys.stdout, sort_keys=True, indent=4, separators=(',', ': '))
            print()
        else:
            with open(args.outfile, 'w') as f:
                json.dump(info, f, sort_keys=True, indent=4, separators=(',', ': '))
                print(file=f)

    # Exit status 1 signals missing tags, but only when --warn is given.
    if missing and args.warn:
        print('The following tags were not found:', missing, file=sys.stderr)
        sys.exit(1)
    else:
        sys.exit(0)
