#!/usr/bin/env python3

import argparse
from fastaUtils.fasta import parse_fasta
from fastaUtils.profiles import parse_profile_line
from itertools import groupby
from operator import itemgetter

def rangetype(arg):
  """Parse one ``-c`` column specification into a 0-based half-open range.

  Columns on the command line are 1-based and both extrema are included, so:
    ``"3"``   -> ``(2, 3)``     (a single column)
    ``"3-5"`` -> ``(2, 5)``     (columns 3, 4 and 5)
    ``"3-"``  -> ``(2, None)``  (column 3 up to the end of the sequence)

  Raises:
    ValueError: a bound is not an integer (e.g. ``"-3"``, ``"a"``, ``"3-5-7"``).
    argparse.ArgumentTypeError: a bound is smaller than 1, or the end precedes
      the start. argparse turns either exception into a usage error.
  """
  first,sep,last=arg.partition('-')
  start=int(first)
  # Column 0 would map to index -1, which later slices from the END of the
  # sequence and silently corrupts the output.
  if start<1:
    raise argparse.ArgumentTypeError("column numbers start at 1: %r" % arg)
  if not sep:
    return (start-1,start)
  if not last:
    # right-open range: everything from `start` to the end
    return (start-1,None)
  stop=int(last)
  if stop<start:
    raise argparse.ArgumentTypeError("range end precedes its start: %r" % arg)
  return (start-1,stop)


if __name__=="__main__":
  # Script flow:
  #   1. parse the command line;
  #   2. collect the 0-based columns to cut from -R/-F and from the closed -c ranges;
  #   3. merge every right-open -c range into one range starting at the smallest start;
  #   4. rewrite each fasta record: drop the selected columns, or keep only them with -r.
  parser = argparse.ArgumentParser(prog='fst-cut',description="Remove columns from msa",formatter_class=argparse.ArgumentDefaultsHelpFormatter)
  parser.add_argument('infile', nargs='?', default=None, help='Input file in fasta format')
  parser.add_argument('-c', dest="columns", default=[], nargs='+', type=rangetype, help='Columns to be deleted. Syntax: `int` or `int-int` or `int-`. Extrema are part of the range. First column is 1')
  parser.add_argument('-r', dest="reverse", action="store_true", default=False, help='Reverse mode: cut all other columns')
  parser.add_argument('-R',dest="rule",default=None,type=str,help='python expression returning True for columns that need to be removed. Available variables: `freq[AA]` and `col`')
  parser.add_argument('-F',dest="freqfile",default=None,type=str,help='file containing occurrence frequencies for every column')
  args=parser.parse_args()

  # A rule without a frequency file used to be ignored silently, so the user's
  # -R had no effect at all; report it as a usage error instead.
  if args.rule is not None and args.freqfile is None:
    parser.error("-R requires -F (the rule is evaluated on the per-column frequencies)")

  # 0-based indices of the columns to cut (before the right-open range).
  todel=set()

  # add all columns matched by rule: line i of the frequency file describes column i
  if args.rule is not None:
    with open(args.freqfile,'r') as freqfile:
      for col,line in enumerate(freqfile):
        freq=parse_profile_line(line.strip())
        # NOTE: -R is executed as arbitrary Python with this module's globals
        # (including `freq` and `col`) in scope; only pass trusted expressions.
        if eval(args.rule):
          todel.add(col)

  # Closed -c ranges are expanded into todel; right-open ones only record their start.
  open_starts=[]
  for start,stop in args.columns:
    if stop is None:
      open_starts.append(start)
    else:
      todel.update(range(start,stop))

  # The union of right-open ranges is [min start, end). Columns it already covers
  # are dropped from todel so the closed ranges never overlap it; an overlap would
  # otherwise emit the same columns twice in reverse mode.
  open_start=min(open_starts) if open_starts else None
  if open_start is not None:
    todel={c for c in todel if c<open_start}

  # Collapse the sorted indices into maximal half-open runs [first, last+1):
  # consecutive indices share the same (position - value) groupby key.
  ranges=[]
  for _key,run in groupby(enumerate(sorted(todel)),lambda x:x[0]-x[1]):
    run=list(map(itemgetter(1),run))
    ranges.append((run[0],run[-1]+1))

  # process sequences; `ranges` is ascending and every run ends at or before open_start
  sequences=parse_fasta(args.infile)
  for seq in sequences:
    s=seq.seq
    if not args.reverse:
      # keep the gaps between deleted runs, stopping where the right-open range begins
      pieces=[]
      prev=0
      for start,stop in ranges:
        pieces.append(s[prev:start])
        prev=stop
      pieces.append(s[prev:open_start])  # open_start None -> up to the end
    else:
      # keep only the selected runs, plus everything from the right-open start
      pieces=[s[start:stop] for start,stop in ranges]
      if open_start is not None:
        pieces.append(s[open_start:])
    seq.seq="".join(pieces)
    print(seq)
