#!/usr/bin/env python
'''
Collect trajectory data from a tios simulation.
'''
from __future__ import print_function
import sys
import os
import datetime
import curses
import re
import argparse
from tios import utilities, receiver
from tios._version import __version__
from mdio import xtcio, dcdio, ncio
from mdio.base import Frame
def set_erase_to_end_of_line():
    '''
    Return the terminal's "erase to end of line" control sequence.

    Used to clear the remainder of the in-place progress line. Returns an
    empty string when stdout is not a terminal, when curses cannot set up
    the terminal, or when the terminal has no 'el' capability.
    '''
    # see stackoverflow.com/questions/5419389
    if not sys.stdout.isatty():
        return ''
    try:
        curses.setupterm(fd=sys.stdout.fileno())
    except curses.error:
        # e.g. TERM unset or not found in the terminfo database
        return ''
    capability = curses.tigetstr('el')
    if not capability:
        # tigetstr() returns None when the capability is absent
        return ''
    # Strip terminfo delay/padding specs such as $<5>, $<5*> or $<2.5/>;
    # the optional '*' and '/' flags live inside the angle brackets.
    return re.sub(r'\$<\d+(?:\.\d+)?[*/]*>', '', capability.decode())
    
# Terminal control sequence used to redraw the verbose progress line in place.
el = set_erase_to_end_of_line()

parser = argparse.ArgumentParser(
    description='Connect to a *tios* stream and download data')
parser.add_argument("id", help='tios id of the job to collect data from')
parser.add_argument('-o', '--trajfile', help='name of trajectory file')
parser.add_argument('-l', '--logfile', help='name of log file')
parser.add_argument('-p', '--pdb',
                    help='write a .pdb file to complement the trajectory')
parser.add_argument('-v', '--verbose', action='store_true',
                    help='display info as collection progresses')
parser.add_argument('-t', '--time_period', default=0.0,
                    help='time period over which to collect data '
                         '(in ps or ns, e.g. 500ps or 2ns); '
                         '0 (default) collects until interrupted')
parser.add_argument('-i', '--interval', default=0.0,
                    help='time interval between snapshots '
                         '(in ps or ns, e.g. 10ps); '
                         '0 (default) keeps every frame the stream provides')
# argparse yields strings unless told otherwise; an atom count is an integer.
parser.add_argument('-n', '--firstn', type=int, default=None,
                    help='include only these first atoms in the trajectory file')
parser.add_argument('-V', '--version', action='version', version=__version__)
args = parser.parse_args()

def parse_time_ps(value):
    '''
    Convert a time specification to picoseconds.

    *value* is either a number (taken as ps) or a string holding a number
    with an optional 'ps' or 'ns' suffix, e.g. '500ps', '2ns' or '250'.
    Raises ValueError if the numeric part cannot be parsed.
    '''
    if not isinstance(value, str):
        return float(value)
    text = value.strip()
    if text.endswith('ns'):
        return float(text[:-2]) * 1000
    if text.endswith('ps'):
        return float(text[:-2])
    return float(text)


# Parse user-supplied times before connecting, so bad input fails fast.
time_period = parse_time_ps(args.time_period)
interval = parse_time_ps(args.interval)

tr = receiver.TiosReceiver(args.id, firstn=args.firstn)

# The stream's native time between frames, in ps (reported by the receiver).
sample_interval = float(tr._te.trate)
if sample_interval <= 0.0:
    print('Error: the stream reports a non-positive sampling interval')
    sys.exit(1)
if interval == 0.0:
    # No interval requested: keep every frame the stream provides.
    interval = sample_interval

# Number of stream frames per saved frame. The requested interval must be a
# whole multiple (>= 1) of the native one. Compare in floating point so that
# sub-picosecond sampling intervals do not truncate to 0 and non-multiples
# such as 20.5 vs 10 are rejected rather than silently rounded down.
nskip = int(round(interval / sample_interval))
if nskip < 1 or abs(interval / sample_interval - nskip) > 1e-6:
    print('Error: your sampling interval must be equal to,', end='')
    print(' or a multiple of, {}ps'.format(sample_interval))
    sys.exit(1)

if args.pdb is not None:
    with open(args.pdb, 'w') as pdb_file:
        pdb_file.write(tr.pdb)

# Nothing more to do unless a trajectory file was requested.
if args.trajfile is None:
    sys.exit(0)

ext = os.path.splitext(args.trajfile)[1].lower()
if ext not in ('.xtc', '.dcd'):
    print('Error: only .xtc and .dcd formats are supported')
    sys.exit(1)
writer = xtcio.xtc_open if ext == '.xtc' else dcdio.dcd_open
# The DCD writer does not convert the stream's nm to angstroms, so scale here.
scale = 10.0 if ext == '.dcd' else 1.0


def make_frame(source, scale):
    '''
    Build a Frame from the receiver's current coordinates, box and time.

    Coordinates and box are multiplied by *scale* into new arrays. The
    receiver's own attributes are never modified, so a box or coordinate
    array that is not refreshed between steps cannot be rescaled twice.
    '''
    xyz, box = source.xyz, source.box
    if scale != 1.0:
        xyz = xyz * scale
        if box is not None:
            box = box * scale
    return Frame(xyz, box=box, time=source.timepoint)


def current_status(source):
    '''
    Return the receiver's status, or '*Stalled*' for a job reporting
    'Running' whose last update is older than 10 * 60 / frame_rate seconds.
    '''
    status = source.status
    if status == 'Running':
        # NOTE(review): last_update is subtracted from a naive utcnow(), so
        # it is presumably a naive UTC datetime - confirm.
        delta_t = datetime.datetime.utcnow() - source._te.last_update
        if delta_t.total_seconds() > 10 * 60 / source._te.frame_rate:
            status = '*Stalled*'
    return status


out = 'Frames so far: {} Time point (ps): {:.2f} status: {}'


def report_progress(n_saved, timepoint, status, prefix=''):
    '''
    Write a progress line to the log file and/or stdout, as requested.

    The log file is rewritten on each call, so it holds only the latest
    line. *prefix* is prepended to the stdout copy (used to redraw the
    line in place).
    '''
    line = out.format(n_saved, timepoint, status)
    if args.logfile is not None:
        with open(args.logfile, 'w') as log:
            log.write(line + '\n')
    if args.verbose:
        sys.stdout.write(prefix + line)
        sys.stdout.flush()


traj_file = writer(args.trajfile, 'w')
try:
    traj_file.write_frame(make_frame(tr, scale))
    status = current_status(tr)
    timepoint = tr.timepoint
    n_frames = 1    # frames received from the stream
    n_saved = 1     # frames written to the trajectory file
    unlimited = time_period == 0
    final_timepoint = timepoint + time_period
    report_progress(n_saved, timepoint, status)

    killer = utilities.GracefulKiller()
    while not killer.kill_now:
        # A running job is stepped with wait=False; otherwise step()'s
        # default waiting behaviour is used.
        if status == 'Running':
            tr.step(wait=False, killer=killer)
        else:
            tr.step(killer=killer)
        if killer.kill_now:
            break
        timepoint = tr.timepoint
        status = current_status(tr)
        n_frames += 1

        if not unlimited and timepoint > final_timepoint:
            break
        # Keep the first received frame and every nskip-th one after it,
        # so consecutive saved frames are exactly `interval` apart.
        if (n_frames - 1) % nskip == 0:
            traj_file.write_frame(make_frame(tr, scale))
            n_saved += 1
            report_progress(n_saved, timepoint, status, prefix='\r' + el)
finally:
    traj_file.close()
    if args.verbose:
        # Terminate the in-place progress line.
        sys.stdout.write('\n')
        sys.stdout.flush()
