#!/usr/bin/env python

import os
import sys
# Make the directory that holds this script importable so that
# `import silent_tools` below resolves when running from a source checkout.
# If the package was pip-installed, this guess at the directory is wrong, but
# silent_tools is then installed (and importable) anyway.
silent_tools_dir = os.path.dirname(os.path.realpath(__file__))
if silent_tools_dir not in sys.path:
    sys.path.append(silent_tools_dir)

import silent_tools
from silent_tools import eprint
import stat
import math
import shlex
import shutil

# Don't throw an error when someone uses head: restore the default SIGPIPE
# disposition so that a downstream reader closing the pipe early ends this
# script quietly instead of surfacing as a Python exception.
from signal import signal, SIGPIPE, SIG_DFL
signal(SIGPIPE, SIG_DFL) 


# Pull an optional "-j N" / "-jN" flag out of the command-line arguments.
#
# The arguments are scanned from last to first. The first dash-prefixed
# argument encountered must be -j (anything else is rejected); it is removed
# from all_args together with its value so that only the silent file and the
# tags remain, and scanning then stops.
#
# Bad input is reported through eprint followed by sys.exit(1) rather than
# `assert`, because assertions are stripped when Python runs with -O.
all_args = sys.argv[1:]
j_flag = None
for i in range(len(all_args) - 1, -1, -1):
    elem = all_args[i]

    # startswith() is safe on an empty argument, unlike elem[0]
    if not elem.startswith('-'):
        continue

    # Also covers a lone "-", which elem[1] would crash on
    if not elem.startswith('-j'):
        eprint('silentsketchextractspecific: Unknown flag ' + elem)
        sys.exit(1)

    if len(elem) == 2:
        # "-j N": the value is the following argument
        if i == len(all_args) - 1:
            eprint('silentsketchextractspecific: Error, no argument for -j')
            sys.exit(1)
        j_value = all_args[i + 1]
        del all_args[i:i + 2]
    else:
        # "-jN": the value is glued onto the flag
        j_value = elem[2:]
        del all_args[i]

    try:
        j_flag = int(j_value)
    except ValueError:
        eprint('silentsketchextractspecific: Error, -j expects an integer, got ' + j_value)
        sys.exit(1)
    break


# Nothing left after flag parsing means no silent file was given: emit the
# help text through eprint and exit with a failure status.
if not all_args:
    usage_lines = (
        "",
        'silentsketchextractspecific by bcov - extract specific pdbs WITHOUT rosetta',
        "Usage:",
        "        cat list_of_tags.list | silentsketchextractspecific myfile.silent",
        "                             or",
        "        silentsketchextractspecific myfile.silent tag1 tag2 tag3 > new.silent",
        "Flags:",
        "        -j cpus",
    )
    for usage_line in usage_lines:
        eprint(usage_line)
    sys.exit(1)

# The first remaining argument is the silent file; everything after it is
# treated as tag text.
silent_file = all_args[0]


# Tags may arrive through a pipe on stdin, on the command line, or both.
# stdin is only read when it is a FIFO, so an interactive terminal is never
# waited on.
tag_buffers = []
if stat.S_ISFIFO(os.stat("/dev/stdin").st_mode):
    tag_buffers += sys.stdin.readlines()
tag_buffers += all_args[1:]

# Each buffer entry may hold several whitespace-separated tags. split() with
# no argument splits on any run of whitespace (spaces, tabs, newlines) and
# never yields empty strings, so doubled spaces or blank lines cannot produce
# a bogus empty tag.
tags = []
for line in tag_buffers:
    tags.extend(line.split())

if len(tags) != len(set(tags)):
    eprint("silentsketchextractspecific: Warning: duplicate tags specified")



# When -j was given, fan the work out: write the tags into j_flag list files
# inside a temporary directory, re-run this script once per list file through
# the `parallel` command, relay whatever the workers printed, and exit with
# the status reported by cmd2.
if j_flag is not None:
    # A zero worker count would divide by zero below and a negative one would
    # create no list files at all, so reject both up front.
    if j_flag < 1:
        eprint('silentsketchextractspecific: Error, -j must be at least 1')
        sys.exit(1)

    this_file = os.path.realpath(__file__)
    tmp_dir = 'tmp_' + silent_tools.random_string(13)
    os.mkdir(tmp_dir)
    try:
        per_file = int(math.ceil(len(tags) / j_flag))
        for i in range(j_flag):
            list_path = os.path.join(tmp_dir, 'tmp_%i.list' % i)
            with open(list_path, 'w') as f:
                f.write('\n'.join(tags[i * per_file:(i + 1) * per_file]))
                f.write('\n')

        # Quote the paths so that spaces or shell metacharacters in the script
        # location or the silent file name cannot break (or inject into) the
        # command. The worker command is quoted once more as a whole because
        # it is handed to `parallel` as a single argument.
        # NOTE(review): assumes silent_tools.cmd2 runs the string through a
        # shell, as the pipe in the command implies -- confirm.
        worker_cmd = f'cat {{}} | {shlex.quote(this_file)} {shlex.quote(silent_file)}'
        stdout, stderr, code = silent_tools.cmd2(
            f'ls {tmp_dir}/tmp* | parallel -j{j_flag} {shlex.quote(worker_cmd)}'
        )
    finally:
        # Best-effort cleanup that also runs when the command above raises;
        # a failure to delete the directory is deliberately not fatal.
        shutil.rmtree(tmp_dir, ignore_errors=True)

    if len(stdout.strip()) > 0:
        print(stdout.strip())
    if len(stderr.strip()) > 0:
        eprint(stderr.strip())
    sys.exit(code)




# Look every requested tag up in the silent file's index and write each one
# that exists to "<tag>.pdb" in the current directory.
silent_index = silent_tools.get_silent_index(silent_file)

# Report all missing tags first and keep only the ones present in the index.
# Missing tags are skipped rather than aborting the run: the warning has
# already told the user, and the remaining tags should still be extracted.
existing_tags = []
for tag in tags:
    if tag not in silent_index['index']:
        eprint("silentsketchextractspecific: Unable to find tag: %s" % tag)
        continue
    existing_tags.append(tag)


with open(silent_file, errors='ignore') as sf:
    for tag in existing_tags:

        structure = silent_tools.get_silent_structure_file_open(sf, silent_index, tag)

        fname = tag + '.pdb'
        pdb = silent_tools.structure_to_pdb(structure)

        # Echo each output file name through eprint as it is written
        eprint(fname)
        with open(fname, 'w') as f:
            f.write(''.join(pdb))


