#!/usr/bin/env python

import os
import sys
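# Make this script's directory importable so `silent_tools` is found when run from a
# source checkout. For a pip install this path is wrong, but `silent_tools` is already
# importable in that case.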
silent_tools_dir = os.path.dirname(os.path.realpath(__file__))
if silent_tools_dir not in sys.path:
    sys.path.append(silent_tools_dir)

import silent_tools
from silent_tools import eprint
import math
import re

# Restore the default SIGPIPE action so that a closed pipe (e.g. piping into `head`)
# ends the script quietly instead of raising an error.
from signal import signal, SIGPIPE, SIG_DFL
signal(SIGPIPE, SIG_DFL) 

# --- Command-line handling -------------------------------------------------
# Exactly two positional arguments are required: the silent file to split and
# the maximum number of structures to write into each output file.
if len(sys.argv) != 3:
    eprint("")
    eprint("silentsplittargetlength by bcov - a tool to split a silentfile into parts with the same sequence length. This is crucial for getting good performance when running AF2.")
    eprint("                                    This makes an improvement over silentsplitbylength in checking the target length too (which af2 cares about)")
    eprint("Usage:")
    eprint("        silentsplittargetlength myfile.silent TAGS_PER_SILENT_FILE")
    sys.exit(1)

silent_file = sys.argv[1]

# Validate with explicit checks rather than `assert`: assertions are removed
# when Python runs with -O, which would let bad input through silently.
if not os.path.exists(silent_file):
    eprint("silentsplittargetlength: File not found " + silent_file)
    sys.exit(1)

try:
    tags_per = int(sys.argv[2])
    if tags_per <= 0:
        raise ValueError("TAGS_PER_SILENT_FILE must be greater than 0")
except ValueError:
    eprint("silentsplittargetlength: Second argument must be a number greater than 0")
    sys.exit(1)

# Alphabet used to build the letter suffixes of the output file names.
alpha = "abcdefghijklmnopqrstuvwxyz"
# Index of the input silent file; used below for the header and the tag count.
silent_index = silent_tools.get_silent_index( silent_file )

def get_split_str(width, number):
    """Return the file-name suffix for `number` using `width` letters.

    The suffix is "x", followed by (width - 2) copies of "z", followed by
    `number` written as `width` base-26 digits (a=0 ... z=25), most
    significant digit first. Digits beyond `width` are discarded.

    Examples: (2, 0) -> "xaa", (2, 1) -> "xab", (3, 0) -> "xzaaa".
    """
    letters = []
    remaining = number
    for _ in range(width):
        remaining, digit = divmod(remaining, 26)
        letters.append(alpha[digit])
    # Digits were collected least-significant first; flip them for output.
    encoded = "".join(reversed(letters))
    padding = "z" * (width - 2)
    return "x" + padding + encoded

def get_split_str2(number):
    """Return the file-name suffix for the 0-based file index `number`.

    Widths are tried in increasing order starting at 2. Each width holds
    25 * 26**(width - 1) indices, which keeps the leading letter of the
    encoded part within a-y. Once a width is full, its capacity is subtracted
    and the next width is used, so the suffix grows as more files are needed.

    Falls off the end (returning None) only if no width below 10000 fits.
    """
    remaining = number
    for width in range(2, 10000):
        capacity = 26**(width-1) * 25
        if remaining >= capacity:
            # This width is exhausted; carry the remainder to the next one.
            remaining -= capacity
            continue
        return get_split_str(width, remaining)

def get_length_str( structure ):
    """Return the chunk lengths of `structure` joined by underscores.

    `structure` is handed to silent_tools.get_sequence_chunks; the length of
    each chunk it returns is formatted as an integer, e.g. "65_120".
    """
    chunk_lengths = []
    for chunk in silent_tools.get_sequence_chunks(structure):
        chunk_lengths.append('%i' % len(chunk))
    return '_'.join(chunk_lengths)

def newfile( seqlen, idx ):
    """Create the output file for length key `seqlen` and file index `idx`.

    The file is named "len<seqlen><suffix>.silent" (relative to the current
    directory), where the suffix comes from get_split_str2(idx). The silent
    header built from the global `silent_index` is written first.

    Returns the open, writable file object; the caller must close it.
    """
    suffix = get_split_str2(idx)
    path = f"len{seqlen}{suffix}.silent"
    handle = open( path, "w" )
    header = silent_tools.silent_header(silent_index)
    handle.write( header )
    return handle

def open_new_file( file_dict, seqlen ):
    """Start a fresh output file for `seqlen` and record it in `file_dict`.

    `file_dict` maps a length key to [open file, structures written, file index].
    If an entry already exists, its file is closed and the index advances by
    one; otherwise numbering starts at 0. The structure counter is reset to 0.
    """
    previous = file_dict.get(seqlen)
    if previous is None:
        next_idx = 0
    else:
        # Finish the file currently in use before moving to the next one.
        old_file, _, old_idx = previous
        old_file.close()
        next_idx = old_idx + 1

    fresh_file = newfile(seqlen, next_idx)
    file_dict[seqlen] = [ fresh_file, 0, next_idx ]

def add2length_file( structure, seqlen, file_dict ):
    """Append `structure` to the current output file for `seqlen`.

    A new file is opened first when this length key has not been seen yet, or
    when its current file already holds `tags_per` (global) structures. The
    per-file structure counter in `file_dict` is then incremented.
    """
    is_new_length = seqlen not in file_dict
    if is_new_length or file_dict[ seqlen ][1] >= tags_per:
        open_new_file( file_dict, seqlen )

    entry = file_dict[ seqlen ]
    entry[0].write(structure)
    entry[1] += 1 # one more structure in the current file


# --- Main loop ---------------------------------------------------------------
# Read the structures in batches and route each one to the output file that
# matches its length string.
total_structures = len(silent_index['tags'])

seek_size = 1000  # number of structures fetched per batch
file_dict = {}    # length string -> [open file, structures written, file index]

# `with` and `finally` ensure the input file and every output file are closed
# (and buffered output flushed) even if reading or writing fails part-way.
try:
    with open( silent_file ) as open_silent_file:
        for cur_struct in range(0, total_structures, seek_size):
            # The final batch may be shorter than seek_size.
            batch_end = min(cur_struct + seek_size, total_structures)

            structures = silent_tools.get_silent_structures_true_slice( open_silent_file,
                                silent_index, cur_struct, batch_end, True )

            for structure in structures:
                lenstr = get_length_str( structure.split("\n") )
                add2length_file( structure, lenstr, file_dict )
finally:
    for entry in file_dict.values():
        entry[0].close()




