#!python
import argparse as ap
import nmrpeaklists as npl

# Help text handed to argparse as the parser description. It also serves as
# the user-facing specification of the cluster file format, so the layout of
# the examples below is significant and is shown verbatim in --help.
description = """
Read peak clusters from a file and add them to an NMRPipe .tab file

Cluster file format
-------------------
Each cluster is a comma separated list of peaks. Each peak is a list of
spin names. Each spin name should contain three parts: a one-letter code
for the amino acid type or '+' for un-assigned spin systems; an integer
residue number or system number; and finally a dash followed by an atom
name. Any of the three parts can be replaced by a '?' if the part is
unavailable. Two names can be combined into one if the spins form a spin
anchor (i.e. they are a directly attached hydrogen/heavy atom pair). In
such a case, list the residue type and residue number once and combine
the atom names with a slash. If spin anchor names are combined for one
peak, they must be combined for all peaks. You can add or remove
clusters to be added with the '#' symbol.

The nmrpeaklists script print_tab_clusters can be used with the option
'-c' to print a list of clusters compatible with this script.

2D Examples
-----------
   I29-H  I29-N,   K9-H     K9-N
 # K31-H  K31-N,  W37-HE1  W37-NE1

or

   I29-H/N,    K9-H/N
 # K31-H/N,   W37-HE1/NE1
  E102-H/N,  +117-H/N,     +118-H/N,    +143-H/N
   E17-H/N,    ??-?

3D Examples
-----------
   W37-HA  W37-CA  L87-?,  I29-HG2  I29-CG2  Y74-HE

or

   W37-HA/CA  L87-?,  I29-HG2/CG2  Y74-HE
 # S?-H/N  D58-C,  R64-H/N  T63-C
"""
# Command-line interface: two required input paths and an optional output
# path. The raw formatter keeps the description's line breaks and example
# layout intact in the --help output.
parser = ap.ArgumentParser(
    formatter_class=ap.RawDescriptionHelpFormatter,
    description=description,
)
parser.add_argument('tab_file', metavar='peaklist.tab', type=str,
                    help='.tab file')
parser.add_argument('cluster_file', metavar='clusters', type=str,
                    help='cluster file')
# nargs='?' makes the output path optional; it falls back to cluster.tab.
parser.add_argument('out_file', metavar='cluster.tab', type=str, nargs='?',
                    default='cluster.tab',
                    help='output file (default cluster.tab)')
args = parser.parse_args()

# Load the peak list. The PipeFile object is kept around because its
# template is consulted again when the result is written out.
pipe_file = npl.PipeFile()
peaklist = pipe_file.read_peaklist(args.tab_file)

# Parse the cluster file into nested lists: clusters -> peaks -> names.
# One cluster per line, peaks separated by commas, names separated by
# whitespace. Blank lines and lines whose first non-blank character is '#'
# are ignored.
clusters = []
with open(args.cluster_file, 'r') as clst:
    cluster_lines = clst.readlines()
for raw_line in cluster_lines:
    text = raw_line.strip()
    if text == '' or text.startswith('#'):
        continue
    # str.split() with no argument already discards surrounding whitespace
    clusters.append([field.split() for field in text.split(',')])

# Determine the number of names in each peak.
# Every peak in every cluster must carry the same number of names; that
# single count is stored in names_per_peak.
name_counts = {len(peak) for cluster in clusters for peak in cluster}
if not name_counts:
    # No usable cluster lines at all (empty file, or everything commented
    # out). Report that directly instead of the misleading "same number of
    # names" message that an empty set would otherwise trigger.
    raise ValueError('cluster file format error, no clusters found')
if len(name_counts) != 1:
    err = ('cluster file format error, '
           'each peak must have the same number of names')
    raise ValueError(err)
names_per_peak = name_counts.pop()

# Create the npl.Column types for reading the names.
# When the peaks carry fewer names than the peak list has dimensions, the
# first two dimensions are read from one combined spin-anchor name (the
# (0, 1) index pair); otherwise every dimension gets its own name.
num_dims = peaklist.dims
if names_per_peak > num_dims:
    # The names are later paired with the columns using zip(), which would
    # silently discard the surplus names. Fail loudly instead.
    err = ('cluster file format error, peaks have {} names '
           'but the peak list only has {} dimensions')
    raise ValueError(err.format(names_per_peak, num_dims))
if names_per_peak < num_dims:
    indices = [(0, 1)] + list(range(2, num_dims))
else:
    indices = [0] + list(range(1, num_dims))
columns = npl.PipeNameGroup().generate_columns(indices)

# Convert every parsed cluster (a list of peaks, each a list of names) into
# an npl.PeakList of npl.Peak objects.
# The result replaces `clusters` with a list of (peak list, cluster id)
# tuples; the id starts out as None and is filled in later.
parsed_clusters = []
for peak_specs in clusters:
    members = npl.PeakList()
    for spin_names in peak_specs:
        new_peak = npl.Peak()
        # One fresh Spin per dimension, then let each column write its name
        new_peak[:] = [npl.Spin() for _ in range(num_dims)]
        for col, spin_name in zip(columns, spin_names):
            col.set_string(new_peak, spin_name)
        members.append(new_peak)
    parsed_clusters.append((members, None))
clusters = parsed_clusters

# Assign a cluster id and cluster size to every peak in the peak list.
# - A peak found in no cluster becomes its own cluster of size 1, with its
#   1-based position in the peak list as the id.
# - A peak found in exactly one cluster takes that cluster's id. The first
#   such peak encountered donates its own position as the id. The size is
#   the number of peaks listed for that cluster in the cluster file.
# - A peak found in several clusters is an error.
for default_id, peak in enumerate(peaklist, 1):
    hits = [pos for pos, entry in enumerate(clusters) if peak in entry[0]]
    if len(hits) > 1:
        names = ' '.join(spin.name for spin in peak)
        raise ValueError('{} found in more than one cluster'.format(names))
    if not hits:
        peak.cluster_id = default_id
        peak.cluster_size = 1
        continue
    pos = hits[0]
    members, assigned_id = clusters[pos]
    if assigned_id is None:
        # First peak seen from this cluster: remember its id for the rest
        assigned_id = default_id
        clusters[pos] = (members, assigned_id)
    peak.cluster_id = assigned_id
    peak.cluster_size = len(members)

# Make sure the output template has the two cluster columns, adding
# whichever of them is missing, then write the annotated peak list.
existing_names = [column.name for column in pipe_file.template]
cluster_columns = (
    ('CLUSTID', '%4d', 'cluster_id'),
    ('MEMCNT', '%2d', 'cluster_size'),
)
columns_to_add = [npl.PeakAttrColumn(label, fmt, attr)
                  for label, fmt, attr in cluster_columns
                  if label not in existing_names]
pipe_file.template.insert_default(columns_to_add)
pipe_file.write_peaklist(peaklist, args.out_file)
