#!/usr/bin/env python3

import argparse
import json
import os


DESCRIPTION = """firm-convertminer - convert MINER input files to a FIRM input directory"""


def _load_mappings(path):
   """Parse the tab-separated identifier mappings file at *path*.

   The first line is a header and is skipped. Every other row is expected
   to hold three columns: ``name1``, ``name2`` and ``source``.

   Returns a ``(ens2entrez, altens)`` tuple of dictionaries:

   * ``ens2entrez`` maps ``name1`` to ``name2`` for rows whose source is
     ``'Entrez Gene ID'``.
   * ``altens`` maps ``name2`` to ``name1`` for rows whose source is
     ``'Ensembl Gene ID'``.

   Rows with fewer than three columns (blank lines, rows with an empty
   trailing column) cannot contribute a mapping and are skipped.

   Raises:
      ValueError: if a row has more than three columns.
   """
   ens2entrez = {}
   altens = {}
   with open(path) as infile:
      infile.readline()  # skip header
      # Data rows start at line 2 because the header occupied line 1.
      for lineno, line in enumerate(infile, start=2):
         fields = line.strip().split('\t')
         if len(fields) < 3:
            # Blank or incomplete row: nothing usable, do not crash on it.
            continue
         if len(fields) > 3:
            raise ValueError(
               "%s, line %d: expected 3 tab-separated columns, found %d"
               % (path, lineno, len(fields)))
         name1, name2, source = fields
         if source == 'Entrez Gene ID':
            ens2entrez[name1] = name2
         elif source == 'Ensembl Gene ID':
            altens[name2] = name1
   return ens2entrez, altens


def _write_regulons(regulons, ens2entrez, altens, outfile):
   """Write the ``Gene<TAB>Group`` table for *regulons* to *outfile*.

   *regulons* maps a regulon id to an iterable of gene ids. Each gene is
   looked up in *ens2entrez* directly; if that fails, it is translated
   through *altens* and looked up again.

   Returns a ``(skipped, allgenes)`` tuple of sets: the genes that could
   not be mapped, and every gene that was encountered.
   """
   skipped = set()
   allgenes = set()
   outfile.write('Gene\tGroup\n')
   for reg_id, genes in regulons.items():
      for gene in genes:
         allgenes.add(gene)
         entrez = ens2entrez.get(gene)
         if entrez is None and gene in altens:
            # try alternative EnsEMBL gene id
            entrez = ens2entrez.get(altens[gene])
         if entrez is None:
            skipped.add(gene)
         else:
            # Only failed lookups count as "skipped"; a failing write
            # propagates instead of being silently swallowed.
            outfile.write('%s\t%s\n' % (entrez, reg_id))
   return skipped, allgenes


def main(argv=None):
   """Run the converter.

   *argv* is the list of command line arguments (without the program
   name); ``None`` means ``sys.argv[1:]``. The positional arguments are
   the regulons JSON file, the mappings file and the output directory.
   The result is written to ``regulons.sgn`` inside the output directory,
   which is created if necessary, and a summary of unmapped genes is
   printed to standard output.

   Missing input files are reported through ``parser.error``, i.e. a usage
   message on stderr and exit status 2.
   """
   parser = argparse.ArgumentParser(formatter_class=argparse.RawDescriptionHelpFormatter,
                                    description=DESCRIPTION)
   parser.add_argument('regulons', help='regulons file (JSON format)')
   parser.add_argument('mappings', help='mappings file')
   parser.add_argument('outdir', help='output directory')
   args = parser.parse_args(argv)

   if not os.path.exists(args.regulons):
      parser.error('Regulons file does not exist: %s' % args.regulons)
   if not os.path.exists(args.mappings):
      parser.error('Mappings file does not exist: %s' % args.mappings)

   # Parse both inputs before touching the output location so that a bad
   # input file does not leave a truncated regulons.sgn behind.
   ens2entrez, altens = _load_mappings(args.mappings)
   with open(args.regulons) as infile:
      regulons = json.load(infile)

   # exist_ok avoids the check-then-create race of a separate exists() test.
   os.makedirs(args.outdir, exist_ok=True)
   with open(os.path.join(args.outdir, 'regulons.sgn'), 'w') as outfile:
      skipped, allgenes = _write_regulons(regulons, ens2entrez, altens, outfile)

   print("Conversion done and result placed into '%s'" % args.outdir)
   print("The following %d genes (out of %d) could not be mapped: " % (len(skipped), len(allgenes)))
   # Sorted so the report is reproducible between runs; key=str keeps the
   # sort safe should the JSON contain non-string gene ids.
   for g in sorted(skipped, key=str):
      print(g)


if __name__ == '__main__':
   main()
