GitLab Repo

amachine.am_vis

import html
from io import BytesIO
from pathlib import Path

import graphviz
import matplotlib.image as mpimg
import matplotlib.pyplot as plt
from IPython.display import Image, display
from matplotlib import colormaps
from matplotlib.colors import to_hex
  9
 10def create_digraph(engine="dot"):
 11
 12    engine_configs = {
 13        "dot": {
 14            'rankdir': 'LR',
 15            'ranksep': '1.0',
 16            'nodesep': '0.4',
 17            'splines': 'spline',
 18			'constraint' : 'true',
 19            'concentrate': 'false'
 20        },
 21		"neato": {
 22			'overlap': 'scale',
 23			'overlap_scaling': '-4',  
 24			'esep': '+2.5',           
 25			'sep': '+1.75',           
 26			'model': 'shortpath',
 27			'damping': '0.85',
 28			'epsilon': '0.00001',
 29			'maxiter': '1000000',
 30			'start': '5',
 31		},
 32        "fdp": {
 33            'overlap': 'prism',
 34            'sep': '+1.5',
 35            'K': '1.0',
 36            'splines': 'true',
 37			'len' : '3.0',
 38            'maxiter': '5000'
 39        }
 40    }
 41
 42    graph_attr = engine_configs.get(engine, {})
 43
 44    node_attr = {
 45        'shape': 'box',
 46        'style': 'rounded, filled',
 47        'fillcolor': 'lightblue',
 48        'fontname': 'Helvetica'
 49    }
 50    
 51    edge_attr = {
 52        'penwidth': '1.2',
 53        'color': 'gray40'
 54    }
 55
 56    return graphviz.Digraph(
 57        engine=engine,
 58        graph_attr=graph_attr,
 59        node_attr=node_attr,
 60        edge_attr=edge_attr
 61    )
 62
 63
 64def draw_graph( 
 65	aM, 
 66	output_dir : Path | None = None,
 67	title="am_graph", 
 68	view=True, 
 69	subgraphs=None,
 70	engine="dot" ):
 71
 72	GV = create_digraph(engine=engine)
 73
 74	cmap = colormaps['Set3']
 75
 76	if subgraphs :
 77		graph_colors = [to_hex(cmap(i % 12)) for i in range(len(subgraphs))]
 78
 79	for node in range( len(aM.states) ) :
 80		
 81		if subgraphs :
 82			
 83			node_subgraph = -1
 84			for i, sg in enumerate( subgraphs ) :
 85				if node in sg :
 86					node_subgraph = i
 87
 88			node_color = graph_colors[ node_subgraph ] if node_subgraph >= 0 else 'red'
 89
 90		else :
 91			node_color = to_hex(cmap(8))
 92
 93		node_tex = aM.states[ node ].name
 94		if not node_tex :
 95			node_tex = str(node)
 96
 97		GV.node(
 98			str(node), 
 99			label=node_tex, 
100			shape='circle',
101			style='bold,filled',
102        	fillcolor=node_color,
103        	color='black',
104			width='0.5' )
105
106
107	for tr in aM.transitions :
108
109		u = tr.origin_state_idx
110		v = tr.target_state_idx
111
112		pr_str = str( tr.pq ) if aM.is_q_weighted else str( round( tr.prob, 4 ) )
113		label_text = f"{aM.alphabet[tr.symbol_idx]}({pr_str})"
114
115		html_label = (
116            f'<<TABLE BORDER="0" CELLBORDER="0" CELLSPACING="0" CELLPADDING="2">'
117            f'<TR><TD BGCOLOR="#FFFFFF">{label_text}</TD></TR>'
118            f'</TABLE>>'
119        )
120
121		GV.edge(
122			str(u), str(v), 
123			label=html_label,
124			fontsize='10',
125			fontname='Times-Italic',
126			labelfloat="true",
127		)
128
129	if view == True :
130
131		GV.graph_attr.update(dpi='300')
132		png_bytes = GV.pipe(format="png")
133		buf = BytesIO( png_bytes )
134		img = mpimg.imread( buf )
135
136		plt.figure(figsize=(8, 6))
137		plt.imshow(img)
138		plt.axis('off')
139		plt.tight_layout()
140		plt.show()
141
142	if output_dir is not None :
143		
144		GV.attr(bgcolor='transparent')
145		GV.render(title, directory=output_dir, view=False, format='png', cleanup=True)
def create_digraph(engine="dot"):
    """Create an empty ``graphviz.Digraph`` preconfigured for ``engine``.

    Engine-specific layout settings exist for ``"dot"``, ``"neato"`` and
    ``"fdp"``; any other engine name gets no engine-specific settings. Node
    and edge styling defaults are shared by all engines.

    Args:
        engine: Graphviz layout engine name, forwarded to ``graphviz.Digraph``.

    Returns:
        A new, empty ``graphviz.Digraph``.
    """
    # Graph-level layout attributes per engine.
    engine_graph_attrs = {
        "dot": {
            'rankdir': 'LR',
            'ranksep': '1.0',
            'nodesep': '0.4',
            'splines': 'spline',
            'concentrate': 'false',
        },
        "neato": {
            'overlap': 'scale',
            'overlap_scaling': '-4',
            'esep': '+2.5',
            'sep': '+1.75',
            'model': 'shortpath',
            'damping': '0.85',
            'epsilon': '0.00001',
            'maxiter': '1000000',
            'start': '5',
        },
        "fdp": {
            'overlap': 'prism',
            'sep': '+1.5',
            'K': '1.0',
            'splines': 'true',
            'maxiter': '5000',
        },
    }

    # Graphviz defines `constraint` and `len` as *edge* attributes. Set on the
    # graph they are ignored, so they are applied as edge defaults instead.
    engine_edge_attrs = {
        "dot": {'constraint': 'true'},
        "fdp": {'len': '3.0'},
    }

    graph_attr = engine_graph_attrs.get(engine, {})

    node_attr = {
        'shape': 'box',
        'style': 'rounded, filled',
        'fillcolor': 'lightblue',
        'fontname': 'Helvetica',
    }

    edge_attr = {
        'penwidth': '1.2',
        'color': 'gray40',
        **engine_edge_attrs.get(engine, {}),
    }

    return graphviz.Digraph(
        engine=engine,
        graph_attr=graph_attr,
        node_attr=node_attr,
        edge_attr=edge_attr,
    )
def draw_graph(
    aM,
    output_dir: Path | None = None,
    title="am_graph",
    view=True,
    subgraphs=None,
    engine="dot",
):
    """Draw the states and transitions of ``aM`` with Graphviz.

    Nodes and edges:
        Every state becomes a circular node. The node is labelled with the
        state's ``name``, or with its index when the name is empty. Every
        transition becomes an edge from ``origin_state_idx`` to
        ``target_state_idx``, labelled ``symbol(weight)``. The weight is
        ``tr.pq`` when ``aM.is_q_weighted`` is true; otherwise it is
        ``tr.prob`` rounded to 4 decimals.

    Args:
        aM: The machine to draw. Attributes read:

            - ``states``: items with a ``name``.
            - ``transitions``: items with ``origin_state_idx``,
              ``target_state_idx``, ``symbol_idx`` and ``prob``/``pq``.
            - ``alphabet``: indexed by ``symbol_idx``.
            - ``is_q_weighted``.
        output_dir: When not None, a PNG named ``title`` is rendered into
            this directory with a transparent background. The intermediate
            source file is cleaned up.
        title: Base file name used when rendering to ``output_dir``.
        view: When truthy, the graph is rasterised at 300 dpi and shown in
            a matplotlib figure.
        subgraphs: Optional sized iterable of containers of state indices.
            A state takes the colour of the last container that holds it.
            States in none of them are drawn red. Without subgraphs, every
            state gets the same Set3 colour.
        engine: Graphviz layout engine, forwarded to ``create_digraph``.
    """
    GV = create_digraph(engine=engine)

    cmap = colormaps['Set3']
    default_color = to_hex(cmap(8))

    if subgraphs:
        # Cycle through the qualitative palette when there are more groups
        # than colours.
        graph_colors = [to_hex(cmap(i % cmap.N)) for i in range(len(subgraphs))]

    for node in range(len(aM.states)):
        if subgraphs:
            # Forward scan without break: the last matching subgraph wins.
            node_subgraph = -1
            for i, sg in enumerate(subgraphs):
                if node in sg:
                    node_subgraph = i
            node_color = graph_colors[node_subgraph] if node_subgraph >= 0 else 'red'
        else:
            node_color = default_color

        GV.node(
            str(node),
            label=aM.states[node].name or str(node),
            shape='circle',
            style='bold,filled',
            fillcolor=node_color,
            color='black',
            width='0.5',
        )

    for tr in aM.transitions:
        u = tr.origin_state_idx
        v = tr.target_state_idx

        pr_str = str(tr.pq) if aM.is_q_weighted else str(round(tr.prob, 4))
        label_text = f"{aM.alphabet[tr.symbol_idx]}({pr_str})"

        # HTML-like labels are parsed as markup. Escape &, < and > so that
        # symbols containing them do not produce invalid DOT.
        safe_text = html.escape(label_text, quote=False)
        html_label = (
            f'<<TABLE BORDER="0" CELLBORDER="0" CELLSPACING="0" CELLPADDING="2">'
            f'<TR><TD BGCOLOR="#FFFFFF">{safe_text}</TD></TR>'
            f'</TABLE>>'
        )

        GV.edge(
            str(u), str(v),
            label=html_label,
            fontsize='10',
            fontname='Times-Italic',
            labelfloat="true",
        )

    if view:
        # NOTE(review): dpi is set on GV itself, so a later render to
        # output_dir in this same call is also produced at 300 dpi.
        GV.graph_attr.update(dpi='300')
        png_bytes = GV.pipe(format="png")
        img = mpimg.imread(BytesIO(png_bytes))

        plt.figure(figsize=(8, 6))
        plt.imshow(img)
        plt.axis('off')
        plt.tight_layout()
        plt.show()

    if output_dir is not None:
        GV.attr(bgcolor='transparent')
        GV.render(title, directory=output_dir, view=False, format='png', cleanup=True)