#!/usr/bin/env python3 

import argparse
from fastaUtils.fasta import parse_fasta, parse_header, iterate_sequences
from scipy.spatial.distance import cdist
import numpy as np

def encode_alignment(records):
    """Encode aligned sequences as a 2-D integer array of character code points.

    Parameters
    ----------
    records : iterable
        Objects exposing a ``.seq`` attribute whose items are single characters
        (e.g. the records returned by ``parse_fasta``).

    Returns
    -------
    numpy.ndarray
        Integer array of shape ``(n_sequences, alignment_length)``.

    Raises
    ------
    RuntimeError
        If ``records`` is empty or the sequences do not all have the same length.
    """
    rows = [[ord(ch) for ch in record.seq] for record in records]
    if not rows:
        raise RuntimeError("No sequences found in input")
    if len({len(row) for row in rows}) != 1:
        raise RuntimeError("Not all sequences have the same length")
    return np.array(rows, dtype=int)


def mean_hamming(X, C, batch_size=1024, exclude_self=False):
    """Return, for each row of ``X``, its mean normalized Hamming distance to the rows of ``C``.

    The per-pair distance is the fraction of columns that differ
    (``scipy.spatial.distance.cdist(..., metric='hamming')``).

    Parameters
    ----------
    X : numpy.ndarray
        2-D array of encoded query sequences, shape ``(n, L)``.
    C : numpy.ndarray
        2-D array of encoded reference sequences, shape ``(m, L)``.
    batch_size : int
        Maximum number of rows of ``X`` compared against ``C`` at once. This
        bounds the intermediate distance matrix to ``(batch_size, m)``.
    exclude_self : bool
        Set when ``C`` is ``X`` itself. Each row's zero distance to itself is
        then left out of the average (the sum is divided by ``m - 1``, not ``m``).

    Returns
    -------
    numpy.ndarray
        1-D float array of length ``n``.

    Raises
    ------
    ValueError
        If ``batch_size`` < 1, the arrays have different numbers of columns,
        or there is no other sequence to compare against.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size}")
    if X.shape[1] != C.shape[1]:
        raise ValueError(
            f"Alignments have different lengths ({X.shape[1]} vs {C.shape[1]})"
        )
    n_ref = C.shape[0] - 1 if exclude_self else C.shape[0]
    if n_ref < 1:
        # Previously this silently produced 0/0 = NaN for every sequence.
        raise ValueError("Need at least one other sequence to compute distances")

    # Fixed-size slices guarantee no chunk exceeds batch_size rows.
    # Flooring n / batch_size as a chunk count could yield chunks up to ~2x larger.
    sums = [
        cdist(X[start:start + batch_size], C, metric='hamming').sum(axis=1)
        for start in range(0, X.shape[0], batch_size)
    ]
    if not sums:
        return np.empty(0, dtype=float)
    return np.concatenate(sums) / n_ref


def weighted_mean(values, weights):
    """Return the average of ``values`` weighted by ``weights``.

    Raises
    ------
    ValueError
        If ``weights`` does not have one entry per value, or sums to zero.
    """
    values = np.asarray(values, dtype=float)
    # atleast_1d: a single-number weights file loads as a 0-d array. It would
    # otherwise broadcast and turn the average into a plain sum.
    weights = np.atleast_1d(np.asarray(weights, dtype=float))
    if weights.shape != values.shape:
        raise ValueError(
            f"Expected {values.shape[0]} sample weights, got {weights.shape[0]}"
        )
    total = np.sum(weights)
    if total == 0:
        raise ValueError("Sample weights sum to zero")
    return np.multiply(values, weights).sum() / total


def main(argv=None):
    """Command-line entry point.

    Prints, for each sequence of ``infile``, its mean Hamming distance (or
    sequence identity with ``--seqid``) to every other sequence of ``infile``.
    When ``-to`` is given, the comparison is against every sequence of that
    second file instead. With ``--aggregate``, a single (optionally weighted)
    average over the sequences of ``infile`` is printed. ``argv`` defaults to
    ``sys.argv[1:]``.
    """
    parser = argparse.ArgumentParser(
        prog='fst-distance',
        description="Compute pairwise distances or average pairwise distances between sequences in a MSA",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument('infile', nargs='?', default=None, help='Input file in fasta format')
    parser.add_argument('-to', dest='C', default=None, type=str, help='Second input file in fasta format')
    parser.add_argument('-bs', dest='batch_size', type=int, default=1024, help='Batch size')
    parser.add_argument('--aggregate', dest='aggregate', default=False, action='store_true', help='Return average distance in group instead than sequence by sequence')
    parser.add_argument('--sample-weights', '-sw', dest='sample_weights', default=None, type=str, help='Weight the contribution of each sequence')
    parser.add_argument('--seqid', dest='seqid', default=False, action='store_true', help='Compute sequence identity instead of hamming distance')
    args = parser.parse_args(argv)

    if args.batch_size < 1:
        parser.error("-bs must be a positive integer")

    X = encode_alignment(parse_fasta(args.infile))
    compare_to_self = args.C is None
    C = X if compare_to_self else encode_alignment(parse_fasta(args.C))

    distances = mean_hamming(
        X, C, batch_size=args.batch_size, exclude_self=compare_to_self
    )
    if args.seqid:
        # Identity is 1 - normalized Hamming distance; the mean commutes with 1 - x.
        distances = 1. - distances

    if args.aggregate:
        if args.sample_weights is None:
            weights = np.ones((X.shape[0],))
        else:
            weights = np.loadtxt(args.sample_weights, ndmin=1)
        print(weighted_mean(distances, weights))
    else:
        for d in distances:
            print(d)


if __name__ == "__main__":
    main()