#!/usr/bin/env python3

import csv
import json
import os.path
import sys

from hivclustering import *
from hivclustering.networkbuild import *


def _match_prior_clusters(network, prior_nodes):
    """Work out which prior cluster (if any) each current cluster continues.

    Args:
        network: the freshly built network; only ``network.nodes`` is read,
            using each node's ``id`` and ``cluster_id``.
        prior_nodes: the ``"Nodes"`` list of a previously emitted JSON
            network; entries are indexed by ``"cluster"`` and, when that is
            not ``None``, by ``"id"``.

    Returns:
        A ``(cluster_remap, max_prior_id)`` tuple. ``cluster_remap`` maps a
        current cluster ID to the prior cluster ID it matches, extends or
        merges into; current clusters made only of new nodes are absent from
        it. ``max_prior_id`` is the largest prior cluster ID seen (0 if the
        prior network had no clustered nodes).

    Raises:
        ValueError: if nodes that were clustered in the prior network are now
            unclustered, or if a prior cluster would be claimed by more than
            one current cluster.
    """
    prior_cluster_of = {}  # node ID -> prior cluster ID
    prior_sizes = {}       # prior cluster ID -> number of nodes it had
    max_prior_id = 0

    for entry in prior_nodes:
        prior_id = entry["cluster"]
        if prior_id is None:
            continue
        prior_cluster_of[entry["id"]] = prior_id
        max_prior_id = max(max_prior_id, prior_id)
        prior_sizes[prior_id] = prior_sizes.get(prior_id, 0) + 1

    # current cluster ID -> set of prior cluster IDs of its members;
    # None inside the set tags a member that was not in any prior cluster.
    origins = {}
    for node in network.nodes:
        origins.setdefault(node.cluster_id, set()).add(prior_cluster_of.get(node.id))

    claimed_prior_ids = set()
    cluster_remap = {}

    def claim(prior_ids):
        # A prior cluster may be continued by at most one current cluster.
        overlap = prior_ids & claimed_prior_ids
        if overlap:
            raise ValueError("Cluster %d from the existing network is mapped to multiple new clusters" % min(overlap))
        claimed_prior_ids.update(prior_ids)

    for current_id, sources in origins.items():
        prior_ids = sources - {None}

        if current_id is None:  # singletons
            if prior_ids:
                raise ValueError("Incostistent singletons: nodes previously clustered became unclustered")
            continue

        if not prior_ids:
            print("Cluster %d is a new cluster" % current_id, file=sys.stderr)
            continue

        if len(prior_ids) == 1:
            (target,) = prior_ids
            if None in sources:
                print("Cluster %d extends a previous cluster %d" % (current_id, target), file=sys.stderr)
            else:
                print("Cluster %d matches a previous cluster %d" % (current_id, target), file=sys.stderr)
        else:
            print("Cluster %d MERGES %s" % (current_id, ", ".join(str(k) for k in sources)), file=sys.stderr)
            # Inherit the ID of the largest merged prior cluster; ties go to
            # the smallest prior cluster ID.
            target = min(prior_ids, key=lambda k: (-prior_sizes[k], k))

        claim(prior_ids)
        cluster_remap[current_id] = target

    return cluster_remap, max_prior_id


def _collect_nodes(cluster_info):
    """Build the JSON node records and the node -> list-index lookup.

    Args:
        cluster_info: mapping of cluster ID -> iterable of node objects, as
            returned by ``network.sort_clusters``.

    Returns:
        A ``(nodes, node_idx)`` tuple: ``nodes`` is the list of node dicts to
        emit and ``node_idx`` maps each emitted node object to its position
        in that list. Nodes filed under a ``None`` cluster ID are not emitted
        and therefore have no entry in ``node_idx``.
    """
    nodes = []
    node_idx = {}
    for cluster_id, members in cluster_info.items():
        if cluster_id is None:
            # Unclustered nodes are left out of the node list; they must not
            # receive an index, or edges would resolve to an unrelated node.
            continue
        for n in members:
            node_idx[n] = len(nodes)
            nodes.append({'id': n.id, 'cluster': cluster_id, 'attributes': list(n.attributes),
                          'edi': n.get_edi(), 'baseline': n.get_baseline_date(True)})
    return nodes, node_idx


def _collect_edges(network, node_idx):
    """Build the JSON edge records for every visible edge of ``network``.

    Args:
        network: the network whose ``reduce_edge_set()`` and ``distances``
            are read.
        node_idx: mapping of node object -> index in the emitted node list.

    Returns:
        A list of edge dicts whose ``source``/``target`` are indices into the
        emitted node list. An edge with an endpoint that was not emitted is
        skipped with a warning on stderr, because it cannot be expressed as a
        pair of node indices.
    """
    edges = []
    for e in network.reduce_edge_set():
        if not e.visible:
            continue

        source_node = e.compute_direction()
        directed = source_node is not None
        if directed:
            # The recipient is whichever endpoint is not the source.
            target_node = e.p2 if source_node != e.p2 else e.p1
        else:
            source_node, target_node = e.p1, e.p2

        if source_node not in node_idx or target_node not in node_idx:
            print("Skipping edge %s -- %s: an endpoint is not part of the reported node list" % (e.p1.id, e.p2.id),
                  file=sys.stderr)
            continue

        edges.append({'source': node_idx[source_node], 'target': node_idx[target_node],
                      'directed': directed, 'length': network.distances[e],
                      'support': e.edge_reject_p, 'removed': not e.has_support(),
                      'sequences': e.sequences, 'attributes': list(e.attribute)})
    return edges


def _report_bridges(network, output):
    """Print the non-trivial bridges of ``network`` and per-node bridge counts.

    A bridge is reported only when neither endpoint has degree 1.

    Args:
        network: the network to analyse.
        output: writable text stream receiving the report lines.
    """
    print("Finding all non-trivial (not ending in a terminal node) bridges in the network", file=sys.stderr)
    # Ask the network to tag what it finds with the "bridge" attribute; the
    # loop below reads that tag back from the edges.
    network.find_all_bridges(attr="bridge")

    bridges_per_node = {}
    for e in network.reduce_edge_set():
        if e.has_attribute("bridge") and e.p1.degree != 1 and e.p2.degree != 1:
            print("%s [degree %d] -- %s [degree %d]" % (e.p1.id, e.p1.degree, e.p2.id, e.p2.degree), file=output)
            for n in (e.p1, e.p2):
                bridges_per_node[n] = bridges_per_node.get(n, 0) + 1

    if bridges_per_node:
        for n, count in bridges_per_node.items():
            print("Node %s (degree %d) is involved in %d bridges" % (n.id, n.degree, count), file=output)
    else:
        print("No non-trivial bridges found", file=sys.stderr)


def make_hiv_network():
    """Build the network and write the outputs requested in the run settings.

    With the ``json`` setting the network description, nodes and edges are
    printed to stdout as JSON; when a ``prior`` network is supplied, cluster
    IDs are aligned with it so that continuing clusters keep their previous
    IDs and new clusters are numbered after the largest prior ID. Without
    ``json``, ``describe_network`` is called on the network alone. The
    ``dot``, ``cluster``, ``centralities`` and ``bridges`` settings each
    trigger their corresponding extra output.

    Returns:
        The network object produced by ``build_a_network()``.

    Exits the process with status 1 if processing the prior network fails.
    """
    network = build_a_network()
    # NOTE(review): assumes settings() keeps returning the same settings
    # object once build_a_network() has run -- confirm against networkbuild.
    cfg = settings()

    if cfg.json:
        network_info = describe_network(network, True, cfg.singletons)

        run_info = {'threshold': cfg.threshold,
                    'edge-filtering': cfg.edge_filtering,
                    'contaminants': cfg.contaminants}
        if cfg.contaminant_file:
            run_info['contaminant-ids'] = list(cfg.contaminant_file)
        if cfg.singletons:
            run_info['singletons'] = True
        network_info['Settings'] = run_info

        if cfg.prior is not None:
            try:
                prior_nodes = json.load(cfg.prior)["Nodes"]
                cluster_remap, max_prior_id = _match_prior_clusters(network, prior_nodes)

                for n in network.nodes:
                    if n.cluster_id in cluster_remap:
                        n.cluster_id = cluster_remap[n.cluster_id]
                    elif n.cluster_id is not None:
                        # Negative IDs mark brand-new clusters; only those
                        # pass the filter given to sort_clusters below.
                        n.cluster_id = -n.cluster_id

                cluster_info = network.sort_clusters(
                    filter=lambda cluster_id, cluster_data: cluster_id < 0,
                    start_id=max_prior_id + 1)
            except Exception as e:
                # Top-level boundary for the command-line tool: report and stop.
                print("Error with prior network processing: %s" % str(e), file=sys.stderr)
                sys.exit(1)
        else:
            cluster_info = network.sort_clusters()

        nodes, node_idx = _collect_nodes(cluster_info)
        edges = _collect_edges(network, node_idx)

        # do NOT sort the nodes, otherwise referencing them by index from edges will be broken.
        network_info['Nodes'] = nodes
        # OK to sort edges, because their order is not used downstream
        network_info['Edges'] = sorted(edges, key=lambda edge: ''.join(edge['sequences']))

        print(json.dumps(network_info, indent=4, sort_keys=True))
    else:
        describe_network(network)

    if cfg.dot:
        network.generate_dot(cfg.dot)

    if cfg.cluster:
        network.write_clusters(cfg.cluster)

    if cfg.centralities:
        network.write_centralities(cfg.centralities)

    if cfg.bridges:
        _report_bridges(network, cfg.output)

    return network

# Script entry point: build the network only when run directly, not on import.
if __name__ == "__main__":
    make_hiv_network()

