#!python
import argparse as ap
import nmrpeaklists as npl

def argparser():
    """Build the command-line parser for the CARA -> NMRPipe .tab converter.

    The long description below is shown verbatim by ``--help`` (the parser
    uses ``RawDescriptionHelpFormatter``), so its layout is significant.

    Returns an ``argparse.ArgumentParser`` with the options ``--in``
    (required), ``--spinID``, ``--ft``, ``--order``, ``--cal``, ``-e``,
    ``-c`` and ``--out``.
    """
    description = """
    Convert a CARA anchor or strip peak list to NMRPipe .tab file

    Assignment
    ----------
    CARA uses the Xeasy format for its peak lists. This format does not
    contain assignment data directly. Instead, it includes a spin ID
    number for each chemical shift in the peak list, where each spin ID
    number corresponds to a spin within CARA.

    Assignment data must be added to STRIP peak lists using a spin ID file.
    Spin ID files are generated using the CARA Lua script included in the
    nmrpeaklists library. Use the --spinID option to provide this spin ID
    file.

    If no spin ID file is provided, it will be assumed that the peak list is
    an ANCHOR peak list. CARA anchor peak lists include the assignment data
    for each data line as a comment immediately following it. In this case,
    the assignments can be extracted directly from the peak list. No spin ID
    file is required.

    Correspondence to an NMRPipe spectrum
    -------------------------------------
    If you plan to use the NMRPipe .tab file with nlinLS to fit data, then
    the peak list will correspond to a particular NMRPipe spectrum. The
    script needs access to the header of that spectrum to correctly
    calculate the location of each peak in units of points. Use the --ft
    option to provide the spectrum. It can be either a monolithic file or
    the first plane from a series.

    Beware, the order of the columns in the peak list may not match the
    order of dimensions in the spectrum.  Use the option --order to specify
    which column of the peak list (from left to right) corresponds to which
    dimension in the NMRPipe spectrum (XYZA).

    Use the --cal option to enter the calibration for each dimension in PPM.
    The order here should match the order of dimensions in the NMRPipe
    spectrum. These values should be the same as the calibration values in
    CARA.

    nlinLS Columns
    --------------
    If you plan to use the .tab file for fitting data with nlinLS, then you
    will need to add the appropriate fitting columns. Columns can be added
    in pre-set groups based on common experiments, or they can be added
    individually.

    Use the '-e' option to add all the fitting parameters needed for one of
    the experiment types listed. Use the '-c' option to specify individual
    parameter columns. Columns specified with the '-c' option will
    overwrite corresponding columns set with '-e'.

    You must provide three values when specifying a column with '-c'. The
    first two are the values that appear in the VARS and FORMAT lines of
    the NMRPipe .tab file, respectively. The final value will be the default
    value used for each peak in the column.

    In some cases, several columns are related to one another. These fall
    into two categories, either there is one column for each dimension
    (e.g. XW, YW, ZW) or there are multiple columns for a single dimension
    (e.g. Z_A0, Z_A1 ...). In these cases, use '%s' or '%d' in the VARS
    string to indicate that the string should be expanded with dimension
    labels (X, Y, Z ...) or with list indices (0, 1, 2 ...), respectively.
    The default values for these columns must be specified as a Python list.
    List multiplication and addition syntax is supported here
    (e.g. [1]+[0.5]*64), but the list MUST NOT include any spaces.

    Supported Experiments
    ---------------------
    The following pre-made experiment types are available. Use the given
    experiment name with the '-e' option to add the corresponding nlinLS
    fitting columns to the .tab file. Some experiments take a second
    argument to the '-e' option. See the notes for an explanation.

    The columns HEIGHT and XW, YW, etc. are added for all experiments.

        Experiment    Columns         Notes
        ----------    -------         -----
        Volume        VOL
        R1            Z_A
        R2            Z_A
        Het-NOE       Z_A0, Z_A1
        CEST N        Z_A0 - Z_A(N-1) N = profile points+1 (for reference)
        RD N          Z_A0 - Z_A(N-1) N = profile points+1 (for reference)

    The default value for each column added is provided in the table below.
    These values can be overridden with the '-c' option.

        Column                        Default
        ------                        -------
        HEIGHT                          5e7
        XW, YW, ZW, AW                   2
        VOL                              0
        Z_A (R1)                        -1
        Z_A (R2)                       -10
        Z_A0 (Het-NOE, CEST, RD)         1.0
        Z_A1 (Het-NOE)                   0.75
        Z_A1, Z_A2, ... (CEST, RD)       0.7

    Examples
    --------
    cara2tab --in strip.peaks --spinID cara.spins --ft NOESY.ft3
             --order ZXY --cal 0.01 -0.05 0.11 -e Volume

    cara2tab --in anchor.peaks --ft test.ft2 -e R1

    cara2tab --in anchor.peaks --ft test.ft2 -e R1 -c HEIGHT %12.5e 1e8

    cara2tab --in anchor.peaks --ft test.ft2 -e R2 -c %sW %6.3f [8,5]

    cara2tab --in anchor.peaks --ft test.ft2 -e CEST 33
             -c Z_A%d %8.5f [1]+[0.6]*32
    """
    # RawDescriptionHelpFormatter keeps the tables and examples above
    # exactly as written instead of re-wrapping them.
    parser = ap.ArgumentParser(description=description,
                               formatter_class=ap.RawDescriptionHelpFormatter)
    # A separate group so that --in is listed under "required arguments"
    # in the help output rather than among the optional ones.
    required = parser.add_argument_group('required arguments')
    required.add_argument('--in', dest='in_file', metavar='cara.peaks',
                          type=str, required=True, help='CARA peak list')
    parser.add_argument('--spinID', dest='spin_id_file', metavar='cara.spins',
                        type=str, default=None, help='CARA spin ID file')
    parser.add_argument('--ft', dest='ft_file', metavar='test.ft2',
                        type=str, help='NMRPipe spectrum')
    parser.add_argument('--order', dest='dim_order', metavar='ORDER', type=str,
                        help='correspondence between CARA peak list columns ' +
                        '(left to right) and spectrum dimensions (XYZA), ' +
                        'default XYZA')
    parser.add_argument('--cal', dest='cal', metavar='C', nargs='*',
                        type=float, help='calibration for each dimension ' +
                        'of the .ft file (in PPM)')
    # ValidateExp checks the experiment name and the optional count N.
    parser.add_argument('-e', metavar=('EXP', 'N'), dest='experiment',
                        action=ValidateExp, nargs='+',
                        help='Pre-made experiments: ' +
                        'Volume R1 R2 Het-NOE RD CEST')
    # action='append' with nargs=3 collects one [VAR, FMT, DEFAULT] triple
    # per occurrence of -c.
    parser.add_argument('-c', dest='custom', action='append', nargs=3,
                        metavar=('VAR', 'FMT', 'DEFAULT'),
                        help='Create custom column')
    parser.add_argument('--out', dest='out_file', metavar='cara.tab',
                        type=str, default='cara.tab',
                        help='NMRPipe .tab file (default cara.tab)')
    return parser

def main():
    """Command-line entry point: convert a CARA peak list to an NMRPipe .tab.

    Steps, in order: parse the arguments, read the Xeasy peak list, assign
    it (from the spin ID file if given, otherwise from the anchor comments
    in the peak list itself), sort and renumber it, optionally reorder its
    dimensions, calibrate it and compute shifts in points, add the
    experiment and custom columns with their default values, and write the
    result to ``--out``.

    An invalid ``--order`` value is reported through ``parser.error`` (usage
    message and exit) instead of surfacing as a ``ValueError`` traceback.
    """
    parser = argparser()
    args = parser.parse_args()

    # Validate --order before touching any file so that a typo fails fast.
    # Lower-case letters are accepted; each dimension may appear only once.
    new_indices = None
    if args.dim_order is not None:
        dim_order = args.dim_order.upper()
        has_unknown = any(d not in 'XYZA' for d in dim_order)
        has_repeat = len(set(dim_order)) != len(dim_order)
        if not dim_order or has_unknown or has_repeat:
            parser.error('--order must list each of the dimensions '
                         'X, Y, Z, A at most once, got {!r}'
                         .format(args.dim_order))
        new_indices = ['XYZA'.index(d) for d in dim_order]

    # Read peaklist
    peaklist = npl.XeasyFile().read_peaklist(args.in_file)

    # Assign the peak list
    if args.spin_id_file is not None:
        assignments = npl.CaraSpinsFile().read_file(args.spin_id_file)
    else:
        # No spin ID file: treat the input as an anchor peak list and take
        # the assignments from the peak list file itself.
        assignments = npl.CaraAnchorFile().read_file(args.in_file)
    peaklist = assignments.assign_peaklist(peaklist)
    peaklist = npl.sort_by_assignments(peaklist)
    peaklist = npl.renumber_peaklist(peaklist)

    # Permutate the dimensions to match the .ft file
    if new_indices is not None:
        peaklist = npl.reorder_dims(peaklist, new_indices=new_indices)

    # Calibrate the peak list
    if args.cal is not None:
        peaklist = npl.calibrate_peaklist(peaklist, args.cal)

    # Calculate the chemical shifts in points
    if args.ft_file is not None:
        ft_file = npl.PipeSpectrumHeader().read_file(args.ft_file)
        peaklist = ft_file.calc_shift_pts(peaklist)

    # Get the columns and default values for the selected experiment
    if args.experiment is not None:
        columns, defaults = experiment_columns(args.experiment, peaklist.dims)
    else:
        columns = []
        defaults = []

    # Process custom options
    if args.custom is not None:
        columns, defaults = add_custom(args.custom, columns, defaults)

    # Add the new columns and set the default values
    for column, default in zip(columns, defaults):
        for peak in peaklist:
            column.set_value(peak, default)

    # Create new file object, insert new columns in its template, and write
    pipe_file = npl.PipeFile()
    pipe_file.template.insert_default(columns)
    pipe_file.write_peaklist(peaklist, args.out_file)


class ValidateExp(ap.Action):
    """Custom argparse action for parsing the experiment type.

    Accepts one or two values: an experiment name and, for the profile
    experiments (RD, CEST), the number of profile columns N. On success the
    destination attribute is set to ``[name]`` or ``[name, N]`` with N
    converted to ``int``. Every validation failure is reported through
    ``parser.error`` (usage message, exit status 2).
    """
    # Experiments that take no second argument.
    _SIMPLE = ('Volume', 'R1', 'R2', 'Het-NOE')
    # Experiments that require the column count N as a second argument.
    _PROFILE = ('RD', 'CEST')

    def __call__(self, parser, args, values, option_string=None):
        # Name the option as the user typed it; fall back to the dest name
        # if argparse did not supply an option string.
        flag = option_string or self.dest
        if not values or len(values) > 2:
            parser.error('%s takes exactly one or two arguments' % flag)
        exp = values[0]
        if exp not in self._SIMPLE + self._PROFILE:
            parser.error('invalid experiment: {!r}'.format(exp))
        if len(values) == 1:
            if exp in self._PROFILE:
                parser.error('%s requires a second argument' % exp)
        else:
            # Check this before converting N so that e.g. "-e R1 foo"
            # reports the real problem rather than a conversion failure.
            if exp in self._SIMPLE:
                parser.error('%s does not take a second argument' % exp)
            try:
                count = int(values[1])
            except ValueError:
                parser.error('%s: N must be an integer, got %r'
                             % (exp, values[1]))
            # N counts the reference column plus the profile points, so at
            # least one column is needed.
            if count < 1:
                parser.error('%s: N must be at least 1, got %d'
                             % (exp, count))
            values = [exp, count]
        setattr(args, self.dest, values)


def experiment_columns(experiment, dims):
    """Build the nlinLS fitting columns and defaults for an experiment.

    experiment -- ``[name]`` or ``[name, N]``. ``name`` is one of Volume,
                  R1, R2, Het-NOE, RD, CEST (CPMG is kept as an alias for
                  RD). ``N`` is the number of Z_A columns for RD/CEST.
    dims       -- number of dimensions; selects the first ``dims`` letters
                  of 'XYZA' for the width columns.

    Returns ``(columns, defaults)``: two parallel lists holding the column
    objects and the default value for each. HEIGHT and the per-dimension
    width columns always come first, followed by the experiment-specific
    columns.

    Raises ValueError for an unknown experiment name, or for RD/CEST
    without a profile length of at least 1.
    """
    columns = []
    defaults = []
    # Add HEIGHT and XW, YW, etc. columns (common to all experiments)
    columns.append(npl.PeakAttrColumn('HEIGHT', '%11.4e', 'height'))
    defaults.append(5e7)
    width_names = ['{}W'.format(d) for d in 'XYZA'[:dims]]
    width_group = npl.SpinAttrGroup(width_names, '%5.2f', 'width')
    columns.extend(width_group.generate_columns())
    defaults.extend([2]*dims)
    # Choose experiment
    if len(experiment) == 2:
        exp, profile_length = experiment
    else:
        exp, = experiment
        profile_length = None
    # Add the remaining, experiment-specific columns
    if exp == 'Volume':
        new_columns = [npl.PeakAttrColumn('VOL', '%11.4e', 'volume')]
        new_defaults = [0]
    elif exp == 'R1':
        new_columns = [npl.PeakAttrColumn('Z_A', '%6.2f', 'R1')]
        new_defaults = [-1]
    elif exp == 'R2':
        new_columns = [npl.PeakAttrColumn('Z_A', '%6.2f', 'R2')]
        new_defaults = [-10]
    elif exp == 'Het-NOE':
        group = npl.PeakAttrListGroup(['Z_A0', 'Z_A1'], '%8.5f', 'het_noe')
        new_columns = group.generate_columns()
        new_defaults = [1, 0.75]
    elif exp in ('RD', 'CPMG', 'CEST'):
        # 'RD' is the name the command line accepts; it must be handled
        # here or the experiment would leave new_columns undefined.
        if profile_length is None or profile_length < 1:
            raise ValueError('%s requires a profile length of at least 1'
                             % exp)
        names = ['Z_A%d' % i for i in range(profile_length)]
        # The lower-cased experiment name is passed as the group's third
        # argument, as for the other experiments' attribute names.
        group = npl.PeakAttrListGroup(names, '%8.5f', exp.lower())
        new_columns = group.generate_columns()
        # First column is the reference (1); the rest are profile points.
        new_defaults = [1] + [0.7]*(profile_length - 1)
    else:
        raise ValueError('invalid experiment: {!r}'.format(exp))
    columns.extend(new_columns)
    defaults.extend(new_defaults)
    return columns, defaults


def add_custom(custom, columns, defaults):
    """Merge user-defined columns (from '-c') into the column list.

    custom   -- iterable of ``(var, fmt, default)`` string triples.
    columns  -- list of existing column objects; modified in place.
    defaults -- list of default values parallel to ``columns``; modified
                in place.

    For each triple:
      * a ``'%s'`` in ``var`` is expanded with the dimension letters
        X, Y, Z, A, one per element of the parsed default list;
      * a ``'%d'`` in ``var`` is expanded with the indices 0, 1, 2, ...,
        one per element of the parsed default list;
      * otherwise a single column is created and ``default`` is converted
        to float or int according to the conversion type at the end of
        ``fmt`` (left as a string for any other format).

    A new column whose name matches an existing column replaces it (and
    its default); otherwise it is appended. Returns ``(columns, defaults)``.
    """
    # Conversion types of a printf-style format that need a numeric value.
    float_types = ('e', 'E', 'f', 'F', 'g', 'G')
    int_types = ('d', 'i')
    for var, fmt, default in custom:
        # Parse the custom options and make the Columns
        if '%s' in var:
            new_defaults = npl.parse_list_literal(default)
            num_spins = len(new_defaults)
            names = [var % d for d in 'XYZA'[:num_spins]]
            group = npl.SpinAttrGroup(names, fmt, var)
            new_columns = group.generate_columns()
        elif '%d' in var:
            new_defaults = npl.parse_list_literal(default)
            names = [var % i for i in range(len(new_defaults))]
            group = npl.PeakAttrListGroup(names, fmt, var)
            new_columns = group.generate_columns()
        else:
            # The default arrives as a command-line string; convert it so
            # that it can later be written with a numeric format.
            if fmt.endswith(float_types):
                default = float(default)
            elif fmt.endswith(int_types):
                default = int(default)
            new_columns = [npl.PeakAttrColumn(var, fmt, var)]
            new_defaults = [default]
        # For each new Column, overwrite if it already exists
        column_names = [column.name for column in columns]
        for new_col, new_def in zip(new_columns, new_defaults):
            try:
                index = column_names.index(new_col.name)
            except ValueError:
                columns.append(new_col)
                defaults.append(new_def)
            else:
                columns[index] = new_col
                defaults[index] = new_def
    return columns, defaults


# Run the conversion only when this file is executed as a script, so that
# importing the module has no side effects.
if __name__ == '__main__':
    main()
