#!/usr/bin/env python
"""
Build a fasta database from a metapeptide database. Either nucleotide or amino acid.
"""

import argparse
import logging
from datetime import datetime
from sixgill import metapeptides
from Bio import bgzf
from Bio import SeqIO
from Bio.SeqRecord import SeqRecord
from Bio.Alphabet import IUPAC
from Bio.Seq import Seq
from sixgill import __version__

__author__ = "Damon May"
__copyright__ = "Copyright (c) 2016 Damon May"
__license__ = "Apache 2.0"

# module-level logger; main() raises its level to DEBUG when --debug is given
logger = logging.getLogger(__name__)

# options for the type of output fasta file (the values accepted by --type)
FASTA_TYPE_AA = 'aa'  # amino acid: one fasta entry per metapeptide sequence
FASTA_TYPE_PEPTIDE = 'peptide'  # peptide, with specified missed cleavages. This mode requires holding
                                # all peptides in the database in memory, as each is only written once

# every valid output type; used as the argparse choices for --type
FASTA_TYPES = [FASTA_TYPE_AA, FASTA_TYPE_PEPTIDE]

# defaults for the --missedcleavages and --minpeptidelength options
DEFAULT_MISSED_CLEAVAGES = 0
DEFAULT_MIN_PEPTIDE_LENGTH = 7


def declare_gather_args():
    """
    Define the command-line interface of this script and parse the command line.
    Only the checks built into argparse (types, choices, required options) are applied.
    return: the argparse.Namespace holding the parsed argument values
    """
    arg_parser = argparse.ArgumentParser(description=__doc__)

    # (flags, keyword settings) for every argument, listed in declaration order
    argument_specs = (
        (('metapeptidedbfile',),
         dict(type=argparse.FileType('r'),
              help='input metapeptide database file')),
        (('--out',),
         dict(required=True, type=argparse.FileType('w'),
              help='output file')),
        (('--type',),
         dict(required=True, choices=FASTA_TYPES,
              help='database type')),
        (('--missedcleavages',),
         dict(type=int, default=DEFAULT_MISSED_CLEAVAGES,
              help='missed cleavages (for type peptide only)')),
        (('--minpeptidelength',),
         dict(type=int, default=DEFAULT_MIN_PEPTIDE_LENGTH,
              help='minimum peptide length (for type peptide only)')),
        (('--debug',),
         dict(action="store_true", help='Enable debug logging')),
        (('--version',),
         dict(action='version',
              version='%(prog)s {version}'.format(version=__version__))),
    )
    for flags, settings in argument_specs:
        arg_parser.add_argument(*flags, **settings)

    return arg_parser.parse_args()


def _iter_candidate_peptides(fragments, max_missed_cleavages, min_length):
    """
    Yield peptides built by joining runs of consecutive cleavage fragments.
    :param fragments: list of fragment strings, in sequence order
    :param max_missed_cleavages: largest number of missed cleavages to allow; a peptide with
        n missed cleavages is the concatenation of n + 1 adjacent fragments
    :param min_length: peptides shorter than this are skipped
    :return: generator of peptide strings, ordered by number of missed cleavages and then by
        start position. The same peptide may be yielded more than once.
    """
    # range (not xrange) so that this runs under both Python 2 and Python 3
    for n_missed in range(0, max_missed_cleavages + 1):
        for start in range(0, len(fragments) - n_missed):
            pep = ''.join(fragments[start:start + n_missed + 1])
            if len(pep) >= min_length:
                yield pep


def main():
    """
    Script entry point: parse the command line, read the metapeptide database and
    write the requested kind of fasta file.

    For type 'aa', every metapeptide sequence is written as one entry named after itself.
    For type 'peptide', each distinct peptide (with up to --missedcleavages missed cleavages
    and at least --minpeptidelength residues) is written once, named
    <metapeptide sequence>_<peptide> after the first metapeptide that produced it.

    The database reader and the output file are closed on every exit path.
    :raises SystemExit: if the fasta type is not one this function handles
    """
    args = declare_gather_args()
    # logging
    logging.basicConfig(level=logging.INFO,
                        format="%(asctime)s %(levelname)s: %(message)s")
    if args.debug:
        logger.setLevel(logging.DEBUG)

    script_start_time = datetime.now()
    logger.debug("Start time: %s", script_start_time)

    # Only the name of the handle that argparse opened is needed (the database is
    # re-opened below as Bgzip), so release the handle right away.
    db_path = args.metapeptidedbfile.name
    args.metapeptidedbfile.close()

    print("Reading metapeptide database %s..." % db_path)

    n_entries = 0
    try:
        # reader for the Bgzipped metapeptide database
        bgzf_reader = bgzf.BgzfReader(db_path)
        try:
            if args.type == FASTA_TYPE_AA:
                for metapeptide in metapeptides.read_metapeptides(bgzf_reader):
                    n_entries += 1
                    entry_name = metapeptide.sequence
                    write_protein_name_seq_to_fasta(entry_name, metapeptide.sequence, args.out)
                done_message = "Done. Wrote protein fasta with %d entries to %s" % (n_entries, args.out.name)
            elif args.type == FASTA_TYPE_PEPTIDE:
                # every peptide written so far, so that each is written only once
                peptides_written = set()
                for metapeptide in metapeptides.read_metapeptides(bgzf_reader):
                    tryptic_peptides = metapeptides.TRYPTIC_CLEAVAGE_REGEX.findall(metapeptide.sequence)
                    # peptides are written in generation order, which keeps the output
                    # reproducible from run to run
                    for pep in _iter_candidate_peptides(tryptic_peptides, args.missedcleavages,
                                                        args.minpeptidelength):
                        if pep in peptides_written:
                            continue
                        write_protein_name_seq_to_fasta(metapeptide.sequence + '_' + pep, pep, args.out)
                        n_entries += 1
                        peptides_written.add(pep)
                done_message = "Done. Wrote peptide fasta with %s entries to %s" % (n_entries, args.out.name)
            else:
                # defensive: argparse restricts --type to FASTA_TYPES. SystemExit is raised
                # directly because quit() is only provided by the site module.
                raise SystemExit("Invalid fasta type. Quitting.")
        finally:
            bgzf_reader.close()
        # flush first so that a write error surfaces before success is reported; print
        # before closing so that the message still works when the output is stdout
        args.out.flush()
        print(done_message)
    finally:
        args.out.close()

    end_time = datetime.now()
    logger.debug("End time: %s. Elapsed time: %s", end_time, end_time - script_start_time)


def write_dna_name_seq_to_fasta(name, seq, fasta_file, description=None):
    """
    create a single DNA record for a name and a sequence; write it to a fasta file.
    Should be able to call this multiple times on a single file
    :param name: identifier for the record
    :param seq: the nucleotide sequence, passed on to make_dna_seq_record
    :param fasta_file: output fasta file, passed on to write_record_to_fasta
    :param description: optional description for the record. None (the default) is treated
        as an empty string, matching the default of write_protein_name_seq_to_fasta
    :return: None
    """
    # Biopython expects a record description to be a string, so never pass None through
    if description is None:
        description = ""
    write_record_to_fasta(make_dna_seq_record(name, description, seq), fasta_file)


def write_protein_name_seq_to_fasta(name, seq, fasta_file, description=""):
    """
    Build one protein record from a name and a sequence, and write it to a fasta file.
    May be called repeatedly on the same file.
    :param name: identifier for the record
    :param seq: the amino acid sequence, passed on to make_protein_seq_record
    :param fasta_file: output fasta file, passed on to write_record_to_fasta
    :param description: description for the record (empty by default)
    :return: None
    """
    protein_record = make_protein_seq_record(name, description, seq)
    write_record_to_fasta(protein_record, fasta_file)


def write_record_to_fasta(record, fasta_file):
    """
    Write one sequence record to a fasta file via SeqIO.
    May be called repeatedly on the same file.
    :param record: the record to write
    :param fasta_file: the output passed to SeqIO.write
    :return: None
    """
    single_record_batch = [record]
    SeqIO.write(single_record_batch, fasta_file, "fasta")


def make_dna_seq_record(id, description, seq):
    """
    Build a sequence record that uses the IUPAC ambiguous DNA alphabet.
    :param id: identifier for the record
    :param description: description for the record
    :param seq: the sequence string
    :return: the record built by make_seq_record
    """
    dna_alphabet = IUPAC.ambiguous_dna
    return make_seq_record(id, description, seq, dna_alphabet)


def make_protein_seq_record(id, description, seq):
    """
    Build a sequence record that uses the IUPAC protein alphabet.
    :param id: identifier for the record
    :param description: description for the record
    :param seq: the sequence string
    :return: the record built by make_seq_record
    """
    protein_alphabet = IUPAC.protein
    return make_seq_record(id, description, seq, protein_alphabet)


def make_seq_record(id, description, seq, alphabet):
    """
    Wrap a sequence string in a Seq with the given alphabet, then in a SeqRecord.
    :param id: identifier for the record
    :param description: description for the record
    :param seq: the sequence string
    :param alphabet: the alphabet passed to Seq
    :return: the new SeqRecord
    """
    sequence = Seq(seq, alphabet)
    return SeqRecord(sequence, id=id, description=description)

# run only when executed as a script, so that importing this module has no side effects
if __name__ == "__main__":
    main()
