#! /usr/bin/env python

"""\
Plot events from an HepMC ASCII file as Graphviz edge-node graphs.

TODO:
 * Use pypdt if available for particle labels
 * Colour parton lines
"""

import optparse, os
## Command line: exactly one HepMC file, plus an optional comma-separated list
## of (1-based, sequential) event numbers to draw
op = optparse.OptionParser(usage="%prog [-n EVTNUMS] <hepmcfile>")
op.add_option("-n", dest="EVTNUMS", default="1", help="event numbers to graph")
opts, args = op.parse_args()

## Report misuse through the parser (usage message + exit status 2) rather than
## an assert, which is skipped under `python -O` and only yields a traceback
if len(args) != 1:
    op.error("exactly one HepMC input file must be given")
INFILE = args[0]
## Output files are named after the input file, minus directory and extension
BASENAME = os.path.splitext(os.path.basename(INFILE))[0]

## Get a list of event numbers (unique, via a set comprehension)
try:
    opts.EVTNUMS = list({int(nstr) for nstr in opts.EVTNUMS.split(",")})
except ValueError:
    op.error("-n expects a comma-separated list of integers, got %r" % opts.EVTNUMS)


import hepmcio
## Open the input file with the hepmcio reader; the main loop below pulls
## events from it one at a time via reader.next()
reader = hepmcio.HepMCReader(INFILE)

def nodelabel(x):
    """Build the graph node name for a particle or vertex.

    The name is a one-letter type tag followed by the absolute value of the
    object's barcode: "P" for an object whose exact type is hepmcio.Particle,
    "V" for hepmcio.Vertex, and "?" for anything else. A barcode of None is
    rendered as 0.
    """
    prefix = "?"
    for cls, tag in ((hepmcio.Particle, "P"), (hepmcio.Vertex, "V")):
        if type(x) is cls:
            prefix = tag
            break
    if x.barcode is None:
        number = 0
    else:
        number = abs(x.barcode)
    return "{}{}".format(prefix, number)

def ist(p):
    """Return True if the decimal form of |p.pid| is exactly "6" (top quark PID)."""
    abs_pid_digits = str(abs(p.pid))
    return abs_pid_digits == "6"

def isb(p):
    """Return True if the decimal form of |p.pid| has leading digit 5.

    This is a purely textual test on the PID, used by the plotting code as a
    cheap bottom-flavour tag.
    """
    abs_pid_digits = str(abs(p.pid))
    return abs_pid_digits[:1] == "5"

def isc(p):
    """Return True if the decimal form of |p.pid| has leading digit 4.

    This is a purely textual test on the PID, used by the plotting code as a
    cheap charm-flavour tag.
    """
    abs_pid_digits = str(abs(p.pid))
    return abs_pid_digits[:1] == "4"

def iss(p):
    """Return True if the decimal form of |p.pid| has leading digit 3.

    This is a purely textual test on the PID, used by the plotting code as a
    cheap strange-flavour tag.
    """
    abs_pid_digits = str(abs(p.pid))
    return abs_pid_digits[:1] == "3"


import pydot

## PDG IDs drawn with dashed lines: the three neutrinos and 1000022
## (the lightest neutralino in the PDG numbering scheme)
_INVISIBLE_PIDS = (12, 14, 16, 1000022)

## Edge colour by |PID| for leptons with status 1
_STATUS1_LEPTON_COLOURS = {
    11: "gold", 12: "gold",
    13: "orange", 14: "orange",
    15: "olivedrab", 16: "olivedrab",
}

## Edge colour by |PID| for leptons with status 2
_STATUS2_LEPTON_COLOURS = {
    11: "goldenrod", 12: "goldenrod",
    13: "orange3", 14: "orange3",
    15: "darkolivegreen", 16: "darkolivegreen",
}


def _edge_colours(p):
    """Return the (line colour, label colour) pair for the edge of particle p.

    The label colour is None when it should be left at the Graphviz default
    (only the case for otherwise-unclassified status-1 particles).
    """
    apid = abs(p.pid)
    if p.status == 1:
        if p.pid == 22:
            return "lightblue", "lightblue"
        colour = _STATUS1_LEPTON_COLOURS.get(apid)
        if colour:
            return colour, colour
        return "black", None
    if p.status == 2:
        colour = _STATUS2_LEPTON_COLOURS.get(apid)
        if colour:
            return colour, colour
        for flavour_test, colour in ((isb, "purple"), (isc, "slateblue"), (iss, "steelblue")):
            if flavour_test(p):
                return colour, colour
        return "gray", "dimgray"
    if p.status == 3:
        return "red", "red"
    if p.status == 4:
        return "blue", "blue"
    ## Any other status code: colour by leading PID digit; the top test must
    ## come first, and the gluon test is deliberately on the signed PID
    for flavour_test, colour in ((ist, "purple4"), (isb, "purple1"),
                                 (isc, "slateblue1"), (iss, "steelblue1")):
        if flavour_test(p):
            return colour, colour
    if p.pid == 21:
        return "seagreen", "seagreen"
    return "gray", "dimgray"


## Walk through the file sequentially, counting events from 1, and draw each
## requested event; stop as soon as every requested event has been drawn
evtnum = 0
while opts.EVTNUMS:

    evtnum += 1
    print("Reading event {}".format(evtnum))
    evt = reader.next()
    if not evt:
        print("No event {}, exiting...".format(evtnum))
        break

    if evtnum not in opts.EVTNUMS:
        continue
    opts.EVTNUMS.remove(evtnum)

    g = pydot.Dot()

    ## Add all "real" nodes, one styled point per vertex. The emptiness test
    ## works whether get_node signals a miss with None or with an empty list.
    for v in evt.vertices.values():
        nodeId = nodelabel(v)
        if not g.get_node(nodeId):
            n = pydot.Node(nodeId)
            n.set_color("dimgray")
            n.set_fontcolor("gray")
            n.set_style("filled")
            n.set_shape("point")
            n.set_label(nodeId)
            g.add_node(n)

    ## Groups for in/out nodes: synthetic endpoints for particles that have
    ## no production vertex ("IN_<k>") or no end vertex ("OUT_<k>")
    NUM_V_IN = 0
    NUM_V_OUT = 0
    V_IN_GROUP = pydot.Subgraph("IN")
    V_OUT_GROUP = pydot.Subgraph("OUT")
    g.add_subgraph(V_IN_GROUP)
    g.add_subgraph(V_OUT_GROUP)

    for p in evt.particles.values():

        ## Production vertices
        vstart = p.vtx_start()
        if vstart:
            startNodeId = nodelabel(vstart)
        else:
            startNodeId = "IN_" + str(NUM_V_IN)
            startNode = pydot.Node(startNodeId)
            startNode.set_color("blue")
            startNode.set_fontcolor("blue")
            startNode.set_style("filled")
            startNode.set_width("0.05")
            NUM_V_IN += 1
            V_IN_GROUP.add_node(startNode)

        ## End vertices
        vend = p.vtx_end()
        if vend:
            endNodeId = nodelabel(vend)
        else:
            endNodeId = "OUT_" + str(NUM_V_OUT)
            endNode = pydot.Node(endNodeId)
            endNode.set_color("red")
            endNode.set_fontcolor("red")
            endNode.set_style("filled")
            endNode.set_shape("point")
            NUM_V_OUT += 1
            V_OUT_GROUP.add_node(endNode)

        ## Particles: one edge per particle, labelled with its PID
        e = pydot.Edge(startNodeId, endNodeId)
        e.set_label(str(p.pid))
        e.set_penwidth("2")
        ## Line style: dotted for every status other than 1, 2 and 4 (this
        ## takes precedence), otherwise dashed for invisible particles
        if p.status not in (1, 2, 4):
            e.set_style("dotted")
        elif abs(p.pid) in _INVISIBLE_PIDS:
            e.set_style("dashed")
        colour, fontcolour = _edge_colours(p)
        e.set_color(colour)
        if fontcolour is not None:
            e.set_fontcolor(fontcolour)
        g.add_edge(e)

    ## Write both the Graphviz source and a rendered PDF for this event
    OUTNAME = "%s-%04d" % (BASENAME, evtnum)
    print("Writing {0}.dot and {0}.pdf".format(OUTNAME))
    g.write(OUTNAME+".dot", prog="dot")
    g.write(OUTNAME+".pdf", format="pdf", prog="dot")
