#!/usr/bin/env python

import os
import sys
# Make the directory containing this script importable. If this was installed
# with pip, this directory guess is wrong, but silent_tools is importable anyway.
silent_tools_dir = os.path.dirname(os.path.realpath(__file__))
if silent_tools_dir not in sys.path:
    sys.path.append(silent_tools_dir)

import silent_tools
from silent_tools import eprint
import math
import re

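# Restore the default SIGPIPE behaviour so that piping into a tool such as
# `head` (which closes the pipe early) does not make this script throw an error.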
from signal import signal, SIGPIPE, SIG_DFL
signal(SIGPIPE, SIG_DFL) 

# --- Command-line handling -------------------------------------------------
# Exactly two arguments are expected: the input silent file and the maximum
# number of structures (tags) to place in each output file.
if len(sys.argv) != 3:
    eprint("")
    eprint("silentsplitbylength by Nate - a tool to split a silentfile into parts with the same sequence length. This is crucial for getting good performance when running AF2.")
    eprint("Usage:")
    eprint("        silentsplitbylength myfile.silent TAGS_PER_SILENT_FILE")
    sys.exit(1)

silent_file = sys.argv[1]

# Fail with an explicit exit rather than `assert False`: asserts are removed
# when Python runs with -O, which would let the script continue without input.
if not os.path.exists(silent_file):
    eprint("silentsplitbylength: File not found " + silent_file)
    sys.exit(1)

# Only a failed integer conversion is caught here; the range check is done
# explicitly so that it is still enforced under -O.
try:
    tags_per = int(sys.argv[2])
except ValueError:
    tags_per = None

if tags_per is None or tags_per <= 0:
    eprint("silentsplitbylength: Second argument must be a number greater than 0")
    sys.exit(1)

# Letters used as base-26 digits when building output file name suffixes.
alpha = "abcdefghijklmnopqrstuvwxyz"

# Index of the input file as returned by silent_tools.get_silent_index; it is
# used below to write each output header and to count the structures.
silent_index = silent_tools.get_silent_index( silent_file )

def get_split_str(width, number):
    """Build a file-name suffix from `number` using `width` letters.

    The result is "x", followed by `width - 2` copies of "z", followed by
    `number` written as exactly `width` base-26 digits (a=0 ... z=25), most
    significant digit first. Digits that do not fit in `width` positions
    are discarded.
    """
    digits = []
    remaining = number
    for _ in range(width):
        # Peel off the least significant base-26 digit on every pass.
        remaining, digit = divmod(remaining, 26)
        digits.append(alpha[digit])
    # Digits were collected least-significant first; flip them for output.
    digits.reverse()
    padding = "z" * (width - 2)
    return "x" + padding + "".join(digits)

def get_split_str2(number):
    """Map a zero-based file counter to its suffix string.

    Suffix widths are tried in order starting at 2. A given width holds
    25 * 26**(width - 1) names, so the leading letter of the numeric part is
    never "z" for a value that fits. Once a width is full, its capacity is
    subtracted and the next width is tried.

    Returns the string from get_split_str, or None if no width below 10000
    can hold `number`.
    """
    offset = number
    width = 2
    while width < 10000:
        capacity = 25 * 26 ** (width - 1)
        if offset < capacity:
            return get_split_str(width, offset)
        # This width is exhausted; carry the remainder to the next one.
        offset -= capacity
        width += 1
    return None

# Matches one bracketed annotation such as "[MET:NtermProteinFull]".
# Raw string: "\[" is not a valid escape in a normal string literal.
_ANNOTATION_PATTERN = re.compile( r"\[.*?\]" )

def get_sequence_length( structure ):
    """Return the number of residues in one silent-file structure.

    The first line starting with "ANNOTATED_SEQUENCE" is used. Its second
    whitespace-separated field is taken as the sequence, every bracketed
    "[...]" annotation is removed, and the length of what remains is returned.

    structure -- text of a single structure, lines separated by newlines.

    Raises Exception if no ANNOTATED_SEQUENCE line is present, and IndexError
    if that line carries no sequence field.
    """
    for line in structure.split('\n'):
        if not line.startswith( "ANNOTATED_SEQUENCE" ):
            continue
        # split() with no argument tolerates repeated spaces or tabs after
        # the label, which split(' ') would turn into empty fields.
        fields = line.split()
        return len( _ANNOTATION_PATTERN.sub( "", fields[1] ) )
    raise Exception( "No ANNOTATED_SEQUENCE line found in silent file. Are you sure the file is correctly formatted?" )

def newfile( seqlen, idx ):
    """Create an output silent file and return it open for writing.

    The file is named "len<seqlen><suffix>.silent", where the suffix comes
    from get_split_str2(idx). The header produced by
    silent_tools.silent_header for the input index is written before the
    handle is returned; the caller is responsible for closing it.
    """
    out_name = f"len{seqlen}{get_split_str2(idx)}.silent"
    handle = open( out_name, "w" )
    header = silent_tools.silent_header(silent_index)
    handle.write( header )
    return handle

def open_new_file( file_dict, seqlen ):
    """Start a fresh output file for `seqlen` and record it in `file_dict`.

    file_dict maps a sequence length to [open file, structures written,
    file index]. If an entry already exists, its file is closed and the
    index is advanced by one; otherwise numbering starts at 0. The entry is
    then replaced with the new file and a structure count of 0.
    """
    if seqlen not in file_dict:
        next_idx = 0
    else:
        # Retire the file currently in use for this length.
        old_file, _, old_idx = file_dict[seqlen]
        old_file.close()
        next_idx = old_idx + 1

    fresh_file = newfile(seqlen, next_idx)
    file_dict[seqlen] = [ fresh_file, 0, next_idx ]

def add2length_file( structure, seqlen, file_dict ):
    """Append `structure` to the current output file for `seqlen`.

    A new file is opened first when this length has no file yet, or when
    the current file already holds `tags_per` structures. The per-file
    structure counter in file_dict is incremented after the write.
    """
    needs_new_file = (
        seqlen not in file_dict
        or file_dict[ seqlen ][1] >= tags_per
    )
    if needs_new_file:
        open_new_file( file_dict, seqlen )

    entry = file_dict[ seqlen ]
    entry[0].write(structure)
    entry[1] += 1 # number of structures now in this file


# --- Main pass --------------------------------------------------------------
# Walk the input in chunks of `seek_size` structures and append each structure
# to the output file that matches its sequence length.
total_structures = len(silent_index['tags'])

seek_size = 1000
cur_struct = 0
file_dict = {}  # seqlen -> [open output file, structures written to it, file index]

try:
    # `with` releases the input handle even if a structure turns out malformed.
    with open( silent_file ) as open_silent_file:
        while cur_struct < total_structures:

            # Clamp the last chunk so it never reaches past the final structure.
            chunk_end = min( cur_struct + seek_size, total_structures )

            # NOTE(review): the trailing True is passed through unchanged; its
            # meaning is defined by silent_tools - confirm there.
            structures = silent_tools.get_silent_structures_true_slice( open_silent_file,
                                silent_index, cur_struct, chunk_end, True )

            for structure in structures:
                seqlen = get_sequence_length( structure )
                add2length_file( structure, seqlen, file_dict )

            cur_struct = chunk_end
finally:
    # Close every output file that is still open, on success and on error
    # alike, so that nothing written so far is left in an unclosed handle.
    for entry in file_dict.values():
        entry[0].close()




