#!/usr/bin/env python

import os
import sys
# Append the directory that holds this script (symlinks resolved) to sys.path
# so that a silent_tools module located alongside it can be imported below.
# If the package was pip-installed, this directory is not the right place to
# look, but silent_tools is then importable through the normal path anyway.
silent_tools_dir = os.path.dirname(os.path.realpath(__file__))
if silent_tools_dir not in sys.path:
    sys.path.append(silent_tools_dir)

import silent_tools
from silent_tools import eprint
import stat
import math
import shutil

# Don't throw an error when someone pipes our output into `head` (or any
# other reader that closes the pipe early): install the operating system's
# default action for SIGPIPE for the whole process.
from signal import signal, SIGPIPE, SIG_DFL
signal(SIGPIPE, SIG_DFL) 

# Called with no arguments at all: print the usage banner to stderr, one
# eprint() call per line, and exit with a non-zero status.
if len(sys.argv) == 1:
    usage_banner = (
        "",
        'silentsketchextract by bcov - extract pdbs WITHOUT rosetta',
        "Usage:",
        "        silentsketchextract myfile.silent",
        "Flags:",
        "        -j cpus",
    )
    for banner_line in usage_banner:
        eprint(banner_line)
    sys.exit(1)

# Parse the command line.
#
# Results (module-level names used by the rest of the script):
#   all_args -- positional arguments in their original order; the first one
#               is the silent file.
#   j_flag   -- worker count given with "-j N" or "-jN", or None if absent.
#
# Every argument is inspected, so a flag is recognised wherever it appears
# and an unknown flag can never be mistaken for the silent file. Problems
# are reported with sys.exit(message), which prints the message to stderr
# and exits with status 1. Unlike `assert`, it is not stripped when Python
# runs with -O.
all_args = []
j_flag = None

raw_args = sys.argv[1:]
pos = 0
while pos < len(raw_args):
    elem = raw_args[pos]
    pos += 1

    # startswith() is safe for an empty-string argument, where elem[0]
    # would raise IndexError.
    if not elem.startswith('-'):
        all_args.append(elem)
        continue

    if elem.startswith('-j'):
        j_value = elem[2:]
        if not j_value:
            # "-j N" form: the count is the next argument.
            if pos >= len(raw_args):
                sys.exit('silentsketchextract: Error, no argument for -j')
            j_value = raw_args[pos]
            pos += 1
        try:
            j_flag = int(j_value)
        except ValueError:
            sys.exit('silentsketchextract: Error, -j needs an integer, got ' + repr(j_value))
        if j_flag < 1:
            # Zero or negative workers would later divide by zero when the
            # tags are split into per-worker lists.
            sys.exit('silentsketchextract: Error, -j needs a positive integer')
        continue

    sys.exit('silentsketchextract: Unknown flag ' + elem)

if not all_args:
    sys.exit("silentsketchextract: You passed -j but no silent file??")


silent_file = all_args[0]

# Handshake between the -j launcher and the worker processes it spawns.
# The launcher pipes a list of tags into each worker's stdin and sets this
# environment variable; a worker that sees it extracts only those tags.
# Without the variable (a normal, direct invocation) stdin is never read and
# every structure in the file is extracted.
_TAGS_FROM_STDIN_ENV = 'SILENTSKETCHEXTRACT_TAGS_FROM_STDIN'


if j_flag is not None:
    # Parallel mode: split the tags into j_flag lists, then let GNU parallel
    # run one copy of this script per list.
    import shlex  # only needed here, to quote paths for the shell

    if j_flag < 1:
        sys.exit('silentsketchextract: Error, -j needs a positive integer')

    tags = silent_tools.get_silent_index( silent_file )['tags']

    this_file = os.path.realpath(__file__)
    tmp_dir = 'tmp_' + silent_tools.random_string(13)
    os.mkdir(tmp_dir)
    try:
        per_file = int(math.ceil( len(tags) / j_flag ))
        for i in range(j_flag):
            list_path = os.path.join(tmp_dir, 'tmp_%i.list'%i)
            with open(list_path, 'w') as f:
                # A list may be empty when there are fewer tags than
                # workers; the worker skips blank lines.
                f.write('\n'.join(tags[i*per_file:(i+1)*per_file]))
                f.write('\n')

        # `env VAR=1 cmd` is used instead of `VAR=1 cmd` so the worker
        # command does not depend on which shell parallel runs it with.
        # Paths are quoted so spaces or shell metacharacters in them survive
        # both the outer shell and the shell started by parallel.
        worker_cmd = 'cat {} | env %s=1 %s %s' % (
            _TAGS_FROM_STDIN_ENV,
            shlex.quote(this_file),
            shlex.quote(silent_file),
        )
        stdout, stderr, code = silent_tools.cmd2(
            f'ls {tmp_dir}/tmp* | parallel -j{j_flag} {shlex.quote(worker_cmd)}'
        )
    finally:
        # Best-effort cleanup, also when splitting or launching failed.
        shutil.rmtree(tmp_dir, ignore_errors=True)

    if len(stdout.strip()) > 0:
        print(stdout.strip())
    if len(stderr.strip()) > 0:
        eprint(stderr.strip())
    sys.exit(code)


silent_index = silent_tools.get_silent_index( silent_file )

if os.environ.get(_TAGS_FROM_STDIN_ENV):
    # Worker mode: restrict extraction to the tags handed over on stdin.
    tags_to_extract = [line.strip() for line in sys.stdin]
    tags_to_extract = [tag for tag in tags_to_extract if tag]
else:
    # Direct invocation: extract every structure listed in the index.
    tags_to_extract = silent_index['index']


# Write one "<tag>.pdb" file per structure into the current directory and
# report each file name on stderr.
with open(silent_file, errors='ignore') as sf:
    for tag in tags_to_extract:

        structure = silent_tools.get_silent_structure_file_open( sf, silent_index, tag )

        fname = tag + '.pdb'
        pdb = silent_tools.structure_to_pdb(structure)

        eprint(fname)
        with open(fname, 'w') as f:
            f.write(''.join(pdb))


