#!python
import sys
from argparse import ArgumentParser

from prody import parsePDB
import numpy as np
from scipy.spatial.distance import cdist

# Compute the RMSD between two PDB structures. Atoms are paired by their
# order in each file and no superposition is performed, so both structures
# must already share a frame of reference and a consistent atom order.


def parse_args(argv=None):
    """Parse and validate the command-line arguments.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        The parsed ``argparse.Namespace``.

    Exits with status 1 (message on stderr) if ``--interface-only`` is
    given without ``--rec``.
    """
    parser = ArgumentParser(
        description="Calculate the RMSD between two PDB files "
                    "(atoms paired by order, no superposition)."
    )
    parser.add_argument("--rec", default=None,
                        help="Receptor PDB file; required by --interface-only.")
    parser.add_argument("--only-CA", action="store_true", default=False,
                        help="Only use C alpha atoms (takes precedence over "
                             "--only-backbone).")
    parser.add_argument("--only-backbone", action="store_true", default=False,
                        help="Only use backbone atoms.")
    parser.add_argument("--interface-only", action="store_true", default=False,
                        help="Only use interface atoms. Requires --rec.")
    # The underscore spelling is kept for backward compatibility; the hyphen
    # spelling matches every other flag of this script.
    parser.add_argument("--interface_radius", "--interface-radius",
                        dest="interface_radius", type=float, default=10.0,
                        help="Distance cutoff from any receptor atom for an "
                             "atom to count as interface (default: 10.0).")
    parser.add_argument("pdb_files", nargs=2, metavar="pdb_file",
                        help="PDB files to calculate RMSD for.")
    args = parser.parse_args(argv)

    if args.interface_only and args.rec is None:
        sys.exit("--interface-only requires --rec")
    return args


def select_atoms(structure, only_CA=False, only_backbone=False):
    """Return the subset of ``structure`` requested by the CLI flags.

    ``only_CA`` takes precedence over ``only_backbone``. The result may be
    ``None`` when the selection matches no atoms (ProDy behavior), so callers
    must check it.
    """
    if only_CA:
        return structure.calpha
    if only_backbone:
        return structure.backbone
    return structure


def interface_indices(rec_coords, coords, radius):
    """Return indices of rows in ``coords`` near any receptor atom.

    A row is selected when its squared Euclidean distance to at least one
    row of ``rec_coords`` is strictly less than ``radius ** 2``.

    Returns:
        Sorted 1-D integer array of row indices into ``coords``.
    """
    # Compare squared distances to avoid a sqrt over the full distance matrix.
    sq_dists = cdist(rec_coords, coords, "sqeuclidean")
    return np.flatnonzero(np.any(sq_dists < radius * radius, axis=0))


def coordinate_rmsd(coords1, coords2):
    """Return the RMSD between two equally shaped coordinate arrays.

    Rows are paired by index and no superposition is applied:
    ``sqrt(sum(|c1_i - c2_i|^2) / N)`` with N the number of rows.

    Raises:
        ValueError: if the arrays differ in shape or contain no rows.
    """
    coords1 = np.asarray(coords1, dtype=float)
    coords2 = np.asarray(coords2, dtype=float)
    if coords1.shape != coords2.shape:
        raise ValueError("coordinate arrays must have the same shape, got "
                         "{} and {}".format(coords1.shape, coords2.shape))
    if len(coords1) == 0:
        raise ValueError("cannot compute RMSD of an empty selection")
    delta = coords1 - coords2
    return float(np.sqrt(np.sum(delta * delta) / len(coords1)))


def main(argv=None):
    """Run the CLI: print the RMSD between two PDB files to 4 decimals.

    Every failure exits with status 1 and a message on stderr, so stdout
    only ever carries the RMSD value.
    """
    args = parse_args(argv)

    pdb1, pdb2 = (parsePDB(f) for f in args.pdb_files)
    if pdb1 is None or pdb2 is None:
        sys.exit("Error parsing pdb files")

    pdb1 = select_atoms(pdb1, args.only_CA, args.only_backbone)
    pdb2 = select_atoms(pdb2, args.only_CA, args.only_backbone)
    if pdb1 is None or pdb2 is None:
        sys.exit("No atoms matched the requested atom selection.")

    coords1, coords2 = pdb1.getCoords(), pdb2.getCoords()

    # Must be checked before interface selection: indices found in one
    # structure are used to index the other, which would otherwise raise
    # IndexError when the atom counts differ.
    if len(coords1) != len(coords2):
        sys.exit("Unequal number of atoms selected for RMSD calculation.")

    if args.interface_only:
        rec = parsePDB(args.rec)
        if rec is None:
            sys.exit("Error parsing receptor file")
        rec_coords = rec.getCoords()
        # An atom is interface if it is near the receptor in either structure.
        interface = np.union1d(
            interface_indices(rec_coords, coords1, args.interface_radius),
            interface_indices(rec_coords, coords2, args.interface_radius),
        )
        coords1, coords2 = coords1[interface], coords2[interface]

    # Without this, an empty selection would divide 0/0 and print "nan".
    if len(coords1) == 0:
        sys.exit("No atoms selected for RMSD calculation.")

    print("{:.4f}".format(coordinate_rmsd(coords1, coords2)))


if __name__ == "__main__":
    main()
