#!python

import sys
import argparse
import json
import math

import numpy as np

from sklearn.cluster import DBSCAN
from distance import hamming

def process_arguments(args):
    """Parse the command-line options of the clustering script.

    Args:
        args: The full argument vector with the program name first
            (typically ``sys.argv``). ``args[0]`` is used as the program
            name in usage/help output; the remaining items are parsed as
            options.

    Returns:
        An ``argparse.Namespace`` with the attributes ``input_filename``,
        ``slice_filename`` and ``output_filename``.

    Raises:
        SystemExit: Raised by argparse when a required option is missing
            or an unknown option is supplied.
    """

    parser = argparse.ArgumentParser(prog="{}".format(args[0]),
        description='Clusters a collection of mementos based on Simhash and slice')

    parser.add_argument('-i', '--input', dest='input_filename',
        required=True,
        help="A JSON file produced by the detect_off_topic command."
    )

    parser.add_argument('-s', '--slice-file', 
        dest='slice_filename', required=True,
        help="A file containing the URI-Ms and their slices,\n"
            "as produced by the slice_by_datetime command."
    )

    parser.add_argument('-o', '--output', dest='output_filename',
        required=True,
        help="The output file listing the URI-Ms of all mementos"
            " with their slices and clusters."
    )

    # Parse the vector we were given instead of implicitly re-reading
    # sys.argv; skip the program name, as argparse expects options only.
    return parser.parse_args(args[1:])

def shdist(a, b, **oo):
    """Return the ``hamming`` distance between ``a`` and ``b``.

    Any extra keyword arguments are accepted and ignored, so the function
    can be used as a callable metric that may be invoked with additional
    parameters.
    """
    dist = hamming(a, b)
    return dist

if __name__ == '__main__':

    # Script driver: read the Simhash JSON and the slice file, run DBSCAN
    # separately on the mementos of every slice, and write one
    # "slice<TAB>cluster<TAB>URI-M" line per memento to the output file.

    args = process_arguments(sys.argv)

    with open(args.input_filename) as f:
        jsondata = json.load(f)

    considered_urims = []   # URI-Ms in slice-file order; drives output order
    slice_numbers = {}      # URI-M -> slice number
    slices = {}             # slice number -> list of URI-Ms in that slice
    clusters = {}           # URI-M -> DBSCAN label within its slice

    with open(args.slice_filename) as f:
        for line in f:
            line = line.strip()

            # A blank line (e.g. at the end of the file) carries no record
            # and would otherwise break the two-field unpacking below.
            if not line:
                continue

            slice_number, urim = line.split('\t')
            slice_numbers[urim] = slice_number
            considered_urims.append(urim)
            slices.setdefault(slice_number, []).append(urim)

    simhashes = {}

    for urit in jsondata:

        for urim in jsondata[urit]:

            # slice_numbers has exactly the URI-Ms of considered_urims as
            # keys, so this is the same membership test with a constant-time
            # dict lookup instead of a list scan per memento.
            if urim in slice_numbers:

                simhashes[urim] = jsondata[urit][urim]["raw memento simhash value"]

    for slice_number in slices:

        slice_urims = slices[slice_number]

        # Raises KeyError if a URI-M from the slice file has no Simhash in
        # the input JSON.
        simhash_list = [simhashes[urim] for urim in slice_urims]

        # One row per memento. A plain ndarray is used because np.matrix is
        # discouraged by NumPy and not accepted by recent scikit-learn.
        X = np.atleast_2d(np.asarray(simhash_list)).T

        # NOTE(review): each sample reaches shdist as a one-element vector;
        # confirm that hamming() on such vectors yields the intended
        # distance for the eps=0.3 threshold.
        db = DBSCAN(eps=0.3, min_samples=2, metric=shdist).fit(X)

        for index, label in enumerate(db.labels_):
            clusters[slice_urims[index]] = label

    with open(args.output_filename, 'w') as f:
        for urim in considered_urims:
            myslice = slice_numbers[urim]
            mycluster = clusters[urim]

            f.write('{}\t{}\t{}\n'.format(
                myslice, mycluster, urim
            ))

