#!/usr/bin/env python3

import argparse
from fastaUtils.fasta import parse_fasta
from collections import defaultdict
import math
import string

# Gap symbol first, followed by the 20 standard amino acids in a fixed order.
list21aa = "-ARNDCEQGHILKMFPSTWYV"
# Zero-count templates keyed by symbol; key order follows list21aa.
alphabet20aa = dict.fromkeys(list21aa[1:], 0)
alphabet20aaplusgap = dict.fromkeys(list21aa, 0)

def compute_profile(records, alphabet=None, exclude=frozenset()):
  """Count symbol occurrences in every column of an alignment.

  Args:
    records: iterable of objects exposing the sequence as a ``.seq`` string
      (e.g. the records yielded by ``parse_fasta``).
    alphabet: ordered iterable of symbols to count, or ``None`` for an
      adaptive alphabet made of whatever symbols appear. With a fixed
      alphabet every column starts with all of its symbols at zero and any
      character outside the alphabet is ignored.
    exclude: symbols to skip; only consulted for the adaptive alphabet.

  Returns:
    ``(counts, totals)``: ``counts[i]`` maps symbol -> occurrences in column
    ``i`` and ``totals[i]`` is the number of counted symbols in that column.
    Both lists are as long as the longest sequence (empty for no input).
  """
  allowed = None if alphabet is None else frozenset(alphabet)
  template = {} if alphabet is None else dict.fromkeys(alphabet, 0)
  counts = []
  totals = []
  for record in records:
    seq = record.seq
    # Grow per-column storage on demand instead of guessing a capacity,
    # so sequences of any length are supported.
    for _ in range(len(seq) - len(counts)):
      counts.append(dict(template))
      totals.append(0)
    for col, symbol in enumerate(seq):
      if allowed is not None:
        # Fixed alphabet: count only its own symbols (gaps, lowercase,
        # punctuation etc. are skipped unless part of the alphabet).
        if symbol not in allowed:
          continue
      elif symbol in exclude:
        continue
      counts[col][symbol] = counts[col].get(symbol, 0) + 1
      totals[col] += 1
  return counts, totals


def column_entropy(counts, total):
  """Return the Shannon entropy (natural log) of one column's counts.

  ``total`` is the number of counted symbols in the column; a column with
  nothing counted has entropy 0.0 instead of raising ZeroDivisionError.
  """
  if not total:
    return 0.
  s = 0.
  for n in counts.values():
    p = n / total
    if p > 0:
      s -= p * math.log(p)
  return s


def profile_lines(records, alphabet=None, exclude=frozenset(), entropy=False):
  """Yield one output line per alignment column.

  With ``entropy`` the line is the column entropy; otherwise it is a
  space-separated list of ``symbol:frequency`` pairs sorted by symbol, where
  frequency is the symbol count divided by the number of counted symbols in
  the column (0.0 when nothing was counted). ``alphabet`` and ``exclude``
  are as in ``compute_profile``.
  """
  counts, totals = compute_profile(records, alphabet, exclude)
  for col_counts, total in zip(counts, totals):
    if entropy:
      yield str(column_entropy(col_counts, total))
    else:
      yield " ".join("{}:{}".format(sym, n / total if total else 0.)
                     for sym, n in sorted(col_counts.items()))


def main(argv=None):
  """Command-line entry point: parse arguments and print the profile."""
  parser = argparse.ArgumentParser(prog='fst-profile',description="Generate a profile/logo from a fasta file",formatter_class=argparse.ArgumentDefaultsHelpFormatter)
  parser.add_argument('infile', nargs='?', default=None, help='Input file in fasta format')
  parser.add_argument('--alphabet', default='adaptive',choices=['adaptive','20aa','20aa+gap'], type=str, help='Select alphabet')
  parser.add_argument('--exclude', default=[], nargs='+', type=str, help='Exclude some symbols from adaptive alphabet')
  parser.add_argument('--entropy', default=False, action='store_true', help='Print entropy profile instead of frequencies')
  args = parser.parse_args(argv)

  # 'adaptive' maps to None (symbols discovered from the data); --exclude is
  # only honoured in that mode, fixed alphabets count exactly their symbols.
  fixed_alphabets = {'20aa': alphabet20aa, '20aa+gap': alphabet20aaplusgap}
  alphabet = fixed_alphabets.get(args.alphabet)

  for line in profile_lines(parse_fasta(args.infile), alphabet,
                            frozenset(args.exclude), args.entropy):
    print(line)


if __name__ == "__main__":
  main()
