#!/usr/bin/env python3

"""
NanoVar

This is the main executable file of the program NanoVar.

Copyright (C) 2019 Tham Cheng Yong


NanoVar is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.

NanoVar is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
GNU General Public License for more details.

You should have received a copy of the GNU General Public License
along with NanoVar.  If not, see <https://www.gnu.org/licenses/>.
"""

__author__ = 'CY Tham'
import os
import sys
import time
import logging
import nanovar
import random
import threading
from datetime import datetime
from nanovar import __version__, input_parser, gzip_check, fastx_valid, bed_valid, check_exe
from progress.spinner import Spinner


def main():
    """Run the NanoVar pipeline from the command line.

    Steps, in order: parse the CLI arguments, resolve the external
    executables through ``check_exe``, configure console output, create the
    working directory, start the log file, validate the read, reference and
    filter inputs, align the reads with ``master_align`` and finally run
    ``VariantDetect`` (parse/detect, cluster, VCF report).

    Raises:
        Exception: If the read file fails validation, if the reference file
            is gzipped, or if the genome filter BED file cannot be found.
    """
    # Parse arguments
    args = input_parser()
    read_path = args.reads
    ref_path = args.ref
    ref_name = os.path.basename(ref_path).rsplit('.', 1)[0]
    wk_dir = args.dir
    genome_filter = args.filter_bed
    minlen = args.minlen
    splitpct = args.splitpct
    minalign = args.minalign
    buff = args.buffer
    score_threshold = args.score
    threads = args.threads
    # Keep the requested thread count; `threads` is reduced by one below when
    # a spinner thread is started.
    true_threads = threads
    quiet = args.quiet
    force = args.force
    package_dir = os.path.dirname(nanovar.__file__)
    model_path = os.path.join(package_dir, 'model', 'ANN.E63B400L2N12-5D0.4-0.3SGDsee31_het_v1.h5')
    filter_bed_dir = os.path.join(package_dir, 'gaps')

    # Check for required executables
    mdb = check_exe(args.mdb, 'makeblastdb')
    wmk = check_exe(args.wmk, 'windowmasker')
    hsb = check_exe(args.hsb, 'hs-blastn')

    # Observe verbosity
    if quiet:
        # Redirect stdout to the null device for the remainder of the process.
        sys.stdout = open(os.devnull, 'w')
    else:
        # Print initiation message
        print(datetime.now().strftime("[%d/%m/%Y %H:%M:%S]"), "- NanoVar started")
        if threads > 1:

            def run_spinner(spin):
                """Advance the spinner forever, pausing ~0.15 s (+/- 20 %) between frames."""
                while True:
                    spin.next()
                    delay = 0.15
                    delay += delay * random.uniform(-0.2, 0.2)
                    time.sleep(delay)

            # Spinner progress. The thread is a daemon so that it never keeps
            # the interpreter alive once main() returns or raises.
            spinner = Spinner('Mapping reads and calling SVs - ')
            thread_spin = threading.Thread(target=run_spinner, args=(spinner,), daemon=True)
            thread_spin.start()
            # One thread is now used by the spinner; pass one fewer downstream.
            threads -= 1
        else:
            print('Mapping reads and calling SVs -')

    # Setup working directory ('fig' lives inside wk_dir, so a single call
    # creates both; exist_ok avoids a check-then-create race).
    os.makedirs(os.path.join(wk_dir, 'fig'), exist_ok=True)

    # Set up logging
    log_file = os.path.join(wk_dir, 'NanoVar-{:%d%m%y-%H%M}.log'.format(datetime.now()))
    logging.basicConfig(filename=log_file, level=logging.DEBUG, format='[%(asctime)s] - %(levelname)s - %(message)s',
                        datefmt='%d/%m/%Y %H:%M:%S')
    logging.info('Initialize NanoVar log file')
    logging.info('Version: NanoVar-%s', __version__)
    logging.info('Command: %s', ' '.join(sys.argv))

    # Test gzip compression and validate read file
    # NOTE(review): read_path is interpolated unquoted into a shell process
    # substitution; a path containing spaces or shell metacharacters will
    # presumably break if master_align runs this through a shell - consider
    # shlex.quote after confirming how read_para is consumed.
    if gzip_check(read_path):
        read_para = "<(zcat " + read_path + ")"
        fastx_check, rlen_dict = fastx_valid(read_path, "gz")
    else:
        read_para = "<(cat " + read_path + ")"
        fastx_check, rlen_dict = fastx_valid(read_path, "txt")
    read_name = os.path.basename(read_path).rsplit('.fa', 1)[0]
    if fastx_check[0] == "Fail":
        message = "Error: Input FASTQ/FASTA file is corrupted around line %s +/- 4" % str(fastx_check[1])
        logging.critical(message)
        raise Exception(message)
    logging.debug("Input FASTQ/FASTA file passed")
    if gzip_check(ref_path):
        message = "Error: Input reference file is gzipped, please unpack it"
        logging.critical(message)
        raise Exception(message)
    logging.info('Reads: %s', read_path)
    logging.info('Reference genome: %s', ref_path)
    logging.info('Working directory: %s', wk_dir)
    logging.info('Filter file: %s', genome_filter)
    logging.info('Minimum SV len: %s', minlen)
    logging.info('Mapping percent for split-read: %s', splitpct)
    logging.info('Length buffer for clustering: %s', buff)
    logging.info('Score threshold: %s', score_threshold)
    # This is the thread count after the spinner adjustment above.
    logging.info('Number of threads: %s\n', threads)
    logging.info('Total number of reads in FASTQ/FASTA: %s\n', fastx_check[1])
    logging.info('NanoVar started')

    # NOTE(review): the imports below are function-local in the original,
    # presumably to defer heavier dependencies until the inputs have been
    # validated - confirm before hoisting them to module level.
    from Bio import SeqIO
    from collections import OrderedDict
    # Process reference genome: record each contig's length (in file order)
    # and the total genome size.
    contig_len_dict = OrderedDict()
    total_gsize = 0
    for seq_record in SeqIO.parse(ref_path, "fasta"):
        contig_length = len(seq_record)
        contig_len_dict[seq_record.id] = contig_length
        total_gsize += contig_length

    # Validate filter BED file
    if genome_filter is None:
        filter_path = genome_filter
    else:
        if genome_filter in ('hg38', 'hg19', 'mm10'):
            # Named genome builds map to BED files under the package's 'gaps' directory.
            filter_path = os.path.join(filter_bed_dir, genome_filter + '_filter.bed')
        else:
            filter_path = genome_filter
        if not os.path.isfile(filter_path):
            message = "Error: Genome filter BED %s is not found" % filter_path
            logging.critical(message)
            raise Exception(message)
        if bed_valid(filter_path, contig_len_dict):
            logging.debug("Genome filter BED passed")

    # Indexing and alignment
    from nanovar import master_align
    blast_cmd_tab = master_align(ref_path, wk_dir, ref_name, read_para, read_name, threads, force, mdb, wmk, hsb)

    # Parse and detect SVs
    logging.info('Parsing and detecting SVs')
    # The root logger is at DEBUG; raise matplotlib's logger to WARNING
    # before the characterization module is imported.
    logging.getLogger("matplotlib").setLevel(logging.WARNING)
    from nanovar.nv_characterize import VariantDetect
    run = VariantDetect(wk_dir, blast_cmd_tab[1], rlen_dict, splitpct, minalign, filter_path, minlen, buff, model_path,
                        total_gsize, contig_len_dict, score_threshold, read_path, read_name, ref_path, ref_name, blast_cmd_tab[0])
    run.parse_detect()
    run.cluster_nn()
    run.vcf_report()
    logging.info('NanoVar ended')
    # With more than one requested thread the end message starts on a new
    # line - presumably because the spinner leaves the cursor mid-line.
    if true_threads > 1:
        end_format = "\n[%d/%m/%Y %H:%M:%S]"
    else:
        end_format = "[%d/%m/%Y %H:%M:%S]"
    print(datetime.now().strftime(end_format), "- NanoVar ended")


# Run the pipeline only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == "__main__":
    main()
