#!/usr/bin/env python3

from fastaUtils.fasta import parse_fasta, parse_header, generate_header
from functools import partial

def build_regex(regexpr,allowed_subs=0):
  """Expand the fst-grep shorthand symbols of `regexpr` into a regex string.

  Spaces are dropped. The following characters are replaced by a character
  class; every other character is copied through untouched:

    "_" -> [a-zA-Z]            "@" -> [AILMFWYVailmfwyv]
    "#" -> [STNQstnq]          "+" -> [RHKrhk]
    "-" -> [DEde]              "." -> [-.]

  The result is wrapped in a capturing group. When `allowed_subs` is
  positive, a "{s<N}" suffix with N = allowed_subs+1 is appended to that
  group; otherwise only the group is returned.
  """
  symbol_classes={
    "_":"[a-zA-Z]",
    "@":"[AILMFWYVailmfwyv]",
    "#":"[STNQstnq]",
    "+":"[RHKrhk]",
    "-":"[DEde]",
    ".":"[-.]",
  }
  # Unknown symbols map to themselves, so ordinary regex syntax passes through.
  body="".join(symbol_classes.get(ch,ch) for ch in regexpr if ch!=" ")
  if allowed_subs>0:
    return "({}){{s<{}}}".format(body,allowed_subs+1)
  return r"({})".format(body)

def grep(regexpr,sequences,invert_match=False,only_matching=False,match_header=False,begin=None,end=None,allowed_subs=0):
  """Lazily filter sequence records with a pattern expanded by build_regex.

  Parameters:
    regexpr       pattern in fst-grep syntax (passed to build_regex).
    sequences     iterable of records exposing `.header` and `.seq`.
    invert_match  yield the records that do NOT match.
    only_matching for plain (non-inverted, non-header) matching, yield one
                  record per match whose `.seq` is the matched text and
                  whose header is regenerated with the match coordinates.
    match_header  search `.header` instead of `.seq`; `only_matching`,
                  `begin` and `end` are ignored in this mode.
    begin, end    slice bounds applied to `.seq` before searching.
    allowed_subs  when positive, the pattern is compiled with the
                  third-party `regex` module (fuzzy matching).

  Yields:
    The original record objects, except in `only_matching` mode where a
    shallow copy is yielded per match so that results do not alias each
    other and the input record is left unmodified. Coordinates written to
    the header are expressed relative to the full `.seq`, i.e. they include
    the offset introduced by `begin`.
  """
  import copy
  pattern=build_regex(regexpr,allowed_subs)
  if allowed_subs>0:
    # Fuzzy patterns are handled by the `regex` module; plain ones by `re`.
    import regex
    rule=regex.compile(pattern)
  else:
    import re
    rule=re.compile(pattern)

  # Extracting matched regions only applies to plain sequence matching.
  extract_hits=only_matching and not invert_match and not match_header

  for seq in sequences:
    target=seq.header if match_header else seq.seq[begin:end]

    if not extract_hits:
      # Only existence matters here, so stop at the first match.
      found=rule.search(target) is not None
      if found!=bool(invert_match):
        yield seq
      continue

    matches=list(rule.finditer(target))
    if not matches:
      continue
    # Match spans are relative to the slice; translate them back to
    # positions in the full sequence (handles None/negative `begin`).
    offset=slice(begin,end).indices(len(seq.seq))[0]
    # Parsed only for records that matched, once per record.
    db,uid,name,descr,os,ox,gn,pe,sv,_beg,_e=parse_header(seq.header.strip())
    for match in matches:
      start,stop=match.span()
      hit=copy.copy(seq)
      hit.header=generate_header(db,uid,name,descr,os,ox,gn,pe,sv,offset+start,offset+stop)
      hit.seq=match.group()
      yield hit

import argparse
if __name__=="__main__":
  # Command-line entry point: fst-grep [options] PATTERN [FASTA]
  description='Perform regex filtering on fasta file. \
  In addition to the standard aminoacids, the following symbols are defined: \
  "_": match all amino acids, not gaps, \
  ".": match gaps, \
  "@": match hydrophobics, \
  "+"/"-": match charged, \
  "#": match polar, not charged'
  parser = argparse.ArgumentParser(
    prog='fst-grep',
    description=description,
    formatter_class=argparse.ArgumentDefaultsHelpFormatter,
  )
  parser.add_argument('-v','--invert-match', action='store_true', default=False, help='Return only non-matching sequences')
  parser.add_argument('-o','--only-matching', action='store_true', default=False, help='Return only the matching regions of each sequence')
  parser.add_argument('-b','--begin', type=int, default=None, help='Do not match anything before this residue (first residue is 0)')
  parser.add_argument('-e','--end', type=int, default=None, help='Do not match anything after this residue (first residue is 0)')
  parser.add_argument('-s','--allowed-subs', type=int, default=0, help='Allow this number of substitutions in the match (fuzzy match)')
  parser.add_argument('--match-header', action='store_true', default=False, help='Try to match header, not sequence. -o, -b and -e are ignored in this case')
  parser.add_argument('-c','--count', action='store_true', default=False, help='Count matches')

  # Everything argparse does not recognise is treated as positional:
  # the pattern, optionally followed by the FASTA source.
  args,unkargs=parser.parse_known_args()
  n_positional=len(unkargs)
  if n_positional not in (1,2):
    raise RuntimeError("Incorrect number of argument passed. Expected 2, found {}. Run fst-grep for usage instructions.".format(n_positional))
  pattern=unkargs[0]
  # With a single positional, None is handed to parse_fasta as the source.
  fasta_source=unkargs[1] if n_positional==2 else None

  grep_options=dict(
    invert_match=args.invert_match,
    only_matching=args.only_matching,
    match_header=args.match_header,
    begin=args.begin,
    end=args.end,
    allowed_subs=args.allowed_subs,
  )
  hits=grep(pattern,parse_fasta(fasta_source),**grep_options)

  if args.count:
    print(sum(1 for _ in hits))
  else:
    for record in hits:
      print(record)
