#!/usr/bin/env python3

import argparse
import logomaker
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np

def dataframeFreq(data, aa, transform=lambda x: x):
    """Build a position-by-symbol DataFrame from per-position value dicts.

    Args:
        data: Sequence of dicts, one per position, mapping symbol -> value.
        aa: Iterable of symbols; each becomes one column, in this order.
        transform: Callable applied to every value that is present in a
            position's dict. Symbols missing from a position are filled
            with 0.0 without calling ``transform``.

    Returns:
        A pandas DataFrame with one row per entry of ``data`` and one column
        per symbol in ``aa``.
    """
    def column_for(symbol):
        # Missing symbols are a plain 0.0 and are deliberately not transformed.
        return [transform(position[symbol]) if symbol in position else 0.
                for position in data]

    return pd.DataFrame.from_dict({symbol: column_for(symbol) for symbol in aa})

def loadFreq(filename):
    """Read a per-position symbol profile from a whitespace-separated text file.

    Each line describes one position and holds whitespace-separated
    ``symbol:value`` tokens, e.g. ``A:0.25 C:0.75``. A blank line yields an
    empty dict (a position with no listed symbols). If a symbol repeats on a
    line, the last value wins.

    Args:
        filename: Path of the profile file to read.

    Returns:
        A list with one dict per line, mapping each symbol (str) to its value
        converted with ``float``.

    Raises:
        ValueError: If a token does not contain exactly one ``:`` or its value
            is not a valid float. The message names the file, the 1-based line
            number and the offending token.
    """
    data = []
    with open(filename) as infile:
        for lineno, line in enumerate(infile, start=1):
            pos = {}
            for pair in line.split():
                fields = pair.split(':')
                if len(fields) != 2:
                    raise ValueError('{}:{}: expected symbol:value, got {!r}'.format(
                        filename, lineno, pair))
                key, value = fields
                try:
                    pos[key] = float(value)
                except ValueError as err:
                    raise ValueError('{}:{}: invalid numeric value in {!r}'.format(
                        filename, lineno, pair)) from err
            data.append(pos)
    return data

def colorschemes():
    """Return logomaker's built-in color schemes as a dict.

    Keys are the index labels of ``logomaker.list_color_schemes()`` and
    values are the entries of its ``color_scheme`` column, in table order.
    """
    table = logomaker.list_color_schemes()
    return dict(zip(table.index, table['color_scheme']))

def entropy(x):
    """Return the Shannon entropy term ``-x * ln(x)`` for a single value, in nats.

    Used as the ``transform`` passed to ``dataframeFreq`` when plotting
    entropy instead of frequencies.

    Args:
        x: A value, presumably a probability in [0, 1] (not validated here).

    Returns:
        ``-x * ln(x)`` as a float; exactly 0.0 when ``x == 0``.
    """
    if x == 0.:
        # The limit of x*ln(x) as x->0+ is 0. Returning 0.0 (not int 0) keeps the
        # result a float and avoids evaluating log(0) = -inf.
        return 0.
    return -x * np.log(x)

def main(argv=None):
    """Parse command-line options, render a sequence logo from a profile and save it.

    The profile is read with ``loadFreq`` and tabulated with ``dataframeFreq``.
    With ``--entropy``, each value is passed through ``entropy`` first. The
    figure is saved to ``outfile`` at 600 dpi and then shown with ``plt.show()``.

    Args:
        argv: Argument list to parse; ``None`` (the default) means ``sys.argv[1:]``.
    """
    # Query logomaker once; the same table feeds the --help text and the lookup below.
    schemes = colorschemes()

    parser = argparse.ArgumentParser(prog='fst-logo', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('profile', type=str, help='Profile to be used during encoding')
    parser.add_argument('outfile', type=str, help='Output file')
    parser.add_argument('--aalist', type=str, default="GPAVLIMCFYWHKRQNEDST", help='List of symbols in logo')
    parser.add_argument('--colorscheme', type=int, default=9, help='Color scheme: {}'.format(schemes))
    # Previously this option was parsed but ignored (the font was hard-coded).
    # The default is the formerly hard-coded font, so default output is unchanged.
    parser.add_argument('--fontname', type=str, default='Arial Rounded MT Bold', help='Font name')
    parser.add_argument('--width', type=float, default=0.8, help='Column width')
    parser.add_argument('--vpad', type=float, default=0., help='vertical pad')
    parser.add_argument('--fadeprob', default=False, action='store_true', help='Fade')
    parser.add_argument('--entropy', default=False, action='store_true', help='Plot entropy, not frequencies')
    parser.add_argument('--xlabel', type=str, default="Position from end of J domain - Class A", help='X-axis label')

    args = parser.parse_args(argv)

    # Report an unknown scheme as a usage error rather than a bare KeyError traceback.
    if args.colorscheme not in schemes:
        parser.error('unknown --colorscheme {}; choose one of {}'.format(
            args.colorscheme, sorted(schemes)))
    color_scheme = schemes[args.colorscheme]

    freq = loadFreq(args.profile)
    aa = list(args.aalist)
    if args.entropy:
        crp_df = dataframeFreq(freq, aa=aa, transform=entropy)
    else:
        crp_df = dataframeFreq(freq, aa=aa)

    crp_logo = logomaker.Logo(crp_df, fade_probabilities=args.fadeprob, vpad=args.vpad,
                              width=args.width, font_name=args.fontname,
                              color_scheme=color_scheme, stack_order='small_on_top')

    # style using Logo methods
    crp_logo.style_spines(visible=False)
    crp_logo.style_spines(spines=['left', 'bottom'], visible=True)
    crp_logo.style_xticks(rotation=90, fmt='%d', anchor=0)

    # style using Axes methods; the y label follows what is actually plotted
    crp_logo.ax.set_ylabel("Entropy" if args.entropy else "Frequency", labelpad=2)
    crp_logo.ax.set_xlabel(args.xlabel, labelpad=5)
    crp_logo.ax.xaxis.set_tick_params(pad=-1)
    plt.savefig(args.outfile, dpi=600)
    plt.show()


if __name__ == "__main__":
    main()
