#!/usr/bin/env python

libnames = [('mdtraj', 'md'), ('numpy', 'np'), ('keras', 'krs'), ('argparse', 'arg'), ('datetime', 'dt'), ('sys', 'sys')]

for (name, short) in libnames:
  try:
    lib = __import__(name)
  except ImportError:
    print("Library %s is cannot be loaded, exiting" % name)
    exit(0)
  else:
    globals()[short] = lib

import anncolvar

# Parsing command line arguments
parser = arg.ArgumentParser(description='Artificial neural network learning of collective variables of molecular systems, requires numpy, keras and mdtraj')

parser.add_argument('-i', dest='infile', default='traj.xtc',
help='Input trajectory in pdb, xtc, trr, dcd, netcdf or mdcrd, WARNING: the trajectory must be 1. centered in the PBC box, 2. fitted to a reference structure and 3. must contain only atoms to be analysed!')

parser.add_argument('-p', dest='intop', default='top.pdb',
help='Input topology in pdb, WARNING: the structure must be 1. centered in the PBC box and 2. must contain only atoms to be analysed!')

parser.add_argument('-c', dest='colvar', default='colvar.txt',
help='Input collective variable file in text formate, must contain the same number of lines as frames in the trajectory')

parser.add_argument('-col', dest='col', default=2, type=int,
help='The index of the column containing collective variables in the input collective variable file')

parser.add_argument('-boxx', dest='boxx', default=0.0, type=float,
help='Size of x coordinate of PBC box (from 0 to set value in nm)')

parser.add_argument('-boxy', dest='boxy', default=0.0, type=float,
help='Size of y coordinate of PBC box (from 0 to set value in nm)')

parser.add_argument('-boxz', dest='boxz', default=0.0, type=float,
help='Size of z coordinate of PBC box (from 0 to set value in nm)')

parser.add_argument('-nofit', dest='nofit', default='False',
help='Disable fitting, the trajectory must be properly fited (default False)')

parser.add_argument('-testset', dest='testset', default=0.10, type=float,
help='Size of test set (fraction of the trajectory, default = 0.1)')

parser.add_argument('-shuffle', dest='shuffle', default='True',
help='Shuffle trajectory frames to obtain training and test set (default True)')

parser.add_argument('-layers', dest='layers', default=1, type=int,
help='Number of hidden layers (allowed values 1-3, default = 1)')

parser.add_argument('-layer1', dest='layer1', default=256, type=int,
help='Number of neurons in the first encoding layer (default = 256)')

parser.add_argument('-layer2', dest='layer2', default=256, type=int,
help='Number of neurons in the second encoding layer (default = 256)')

parser.add_argument('-layer3', dest='layer3', default=256, type=int,
help='Number of neurons in the third encoding layer (default = 256)')

parser.add_argument('-actfun1', dest='actfun1', default='sigmoid',
help='Activation function of the first layer (default = sigmoid, for options see keras documentation)')

parser.add_argument('-actfun2', dest='actfun2', default='linear',
help='Activation function of the second layer (default = linear, for options see keras documentation)')

parser.add_argument('-actfun3', dest='actfun3', default='linear',
help='Activation function of the third layer (default = linear, for options see keras documentation)')

parser.add_argument('-optim', dest='optim', default='adam',
help='Optimizer (default = adam, for options see keras documentation)')

parser.add_argument('-loss', dest='loss', default='mean_squared_error',
help='Loss function (default = mean_squared_error, for options see keras documentation)')

parser.add_argument('-epochs', dest='epochs', default=100, type=int,
help='Number of epochs (default = 100, >1000 may be necessary for real life applications)')

parser.add_argument('-batch', dest='batch', default=256, type=int,
help='Batch size (0 = no batches, default = 256)')

parser.add_argument('-o', dest='ofile', default='',
help='Output file with original and approximated collective variables (txt, default = no output)')

parser.add_argument('-model', dest='modelfile', default='',
help='Prefix for output model files (experimental, default = no output)')

parser.add_argument('-plumed', dest='plumedfile', default='plumed.dat',
help='Output file for Plumed (default = plumed.dat)')

parser.add_argument('-plumed2', dest='plumedfile2', default='plumed2.dat',
help='Output file for Plumed with ANN module (default = plumed2.dat)')

args = parser.parse_args()

infilename = args.infile
intopname = args.intop
colvarname = args.colvar
column = args.col
boxx = args.boxx
boxy = args.boxy
boxz = args.boxz
if args.testset < 0.0 or args.testset > 0.5:
  print("ERROR: -testset must be 0.0 - 0.5")
  exit(0)
atestset = float(args.testset)

# Shuffling the trajectory before splitting
if args.shuffle == "True":
  shuffle = 1
elif args.shuffle == "False":
  shuffle = 0
else:
  print("ERROR: -shuffle %s not understood" % args.shuffle)
  exit(0)
if args.nofit == "True":
  nofit = 1
elif args.nofit == "False":
  nofit = 0
else:
  print("ERROR: -nofit %s not understood" % args.nofit)
  exit(0)
if args.layers < 1 or args.layers > 3:
  print("ERROR: -layers must be 1-3, for deeper learning contact authors")
  exit(0)
if args.layer1 > 1024:
  print("WARNING: You plan to use %i neurons in the first layer, could be slow")
if args.layers == 2:
  if args.layer2 > 1024:
    print("WARNING: You plan to use %i neurons in the second layer, could be slow")
if args.layers == 3:
  if args.layer3 > 1024:
    print("WARNING: You plan to use %i neurons in the third layer, could be slow")
if args.actfun1 not in ['softmax','elu','selu','softplus','softsign','relu','tanh','sigmoid','hard_sigmoid','linear']:
  print("ERROR: cannot understand -actfun1 %s" % args.actfun1)
  exit(0)
if args.layers == 2:
  if args.actfun2 not in ['softmax','elu','selu','softplus','softsign','relu','tanh','sigmoid','hard_sigmoid','linear']:
    print("ERROR: cannot understand -actfun2 %s" % args.actfun1)
    exit(0)
if args.layers == 3:
  if args.actfun3 not in ['softmax','elu','selu','softplus','softsign','relu','tanh','sigmoid','hard_sigmoid','linear']:
    print("ERROR: cannot understand -actfun3 %s" % args.actfun3)
    exit(0)
if args.layers == 1 and args.actfun2!='linear':
  print("ERROR: actfun2 must be linear for -layers 1")
  exit(0)
if args.layers == 2 and args.actfun3!='linear':
  print("ERROR: actfun3 must be linear for -layers 2")
  exit(0)
layers = args.layers
layer1 = args.layer1
layer2 = args.layer2
layer3 = args.layer3
actfun1 = args.actfun1
actfun2 = args.actfun2
actfun3 = args.actfun3
epochs = args.epochs
optim = args.optim
batch = args.batch
loss = args.loss
if args.ofile[-4:] == '.txt':
  ofilename = args.ofile
elif len(args.ofile)>0:
  ofilename = args.ofile + '.txt'
else:
  ofilename = ''
modelfile = args.modelfile
plumedfile = args.plumedfile
plumedfile2 = args.plumedfile2
if plumedfile[-4:] != '.dat':
  plumedfile = plumedfile + '.dat'
if plumedfile2[-4:] != '.dat':
  plumedfile2 = plumedfile2 + '.dat'
if plumedfile2 !="":
  if layers == 1:
    if layers1 !="tanh":
      print("ERROR: only tanh and linear functions are currently supported in ANN module of Plumed")
      exit(0)
  elif layers == 2:
    if layers1 !="tanh" or layers2 !="tanh":
      print("ERROR: only tanh and linear functions are currently supported in ANN module of Plumed")
      exit(0)
  elif layers == 3:
    if layers1 !="tanh" or layers2 !="tanh" or layers3 !="tanh":
      print("ERROR: only tanh and linear functions are currently supported in ANN module of Plumed")
      exit(0)
anncolvar.anncollectivevariable(infilename, intopname, colvarname, column,
                                boxx, boxy, boxz, atestset, shuffle, nofit,
                                layers, layer1, layer2, layer3,
                                actfun1, actfun2, actfun3,
                                optim, loss, epochs, batch,
                                ofilename, modelfile, plumedfile, plumedfile2)


