amachine.am_vis
"""Graphviz helpers for drawing the states and transitions of an ``aM`` object."""

from html import escape
from io import BytesIO
from pathlib import Path

import graphviz
import matplotlib.image as mpimg
import matplotlib.pyplot as plt
from IPython.display import Image, display  # noqa: F401 - not used in this module
from matplotlib import colormaps
from matplotlib.colors import to_hex


def create_digraph(engine="dot"):
    """Return an empty ``graphviz.Digraph`` with preset styling for ``engine``.

    Graph-level layout attributes come from a per-engine preset for
    ``"dot"``, ``"neato"`` and ``"fdp"``. Any other engine name gets no graph
    attributes, so Graphviz defaults apply. Node and edge defaults are the
    same for every engine.

    Args:
        engine: Graphviz layout engine name forwarded to ``graphviz.Digraph``.

    Returns:
        A new ``graphviz.Digraph`` with no nodes or edges.
    """
    # NOTE(review): 'constraint' and 'len' are documented by Graphviz as
    # edge (and node) attributes; set at graph level they may have no effect.
    # Confirm the intent before relying on them.
    engine_configs = {
        "dot": {
            'rankdir': 'LR',
            'ranksep': '1.0',
            'nodesep': '0.4',
            'splines': 'spline',
            'constraint': 'true',
            'concentrate': 'false',
        },
        "neato": {
            'overlap': 'scale',
            'overlap_scaling': '-4',
            'esep': '+2.5',
            'sep': '+1.75',
            'model': 'shortpath',
            'damping': '0.85',
            'epsilon': '0.00001',
            'maxiter': '1000000',
            'start': '5',
        },
        "fdp": {
            'overlap': 'prism',
            'sep': '+1.5',
            'K': '1.0',
            'splines': 'true',
            'len': '3.0',
            'maxiter': '5000',
        },
    }

    graph_attr = engine_configs.get(engine, {})

    node_attr = {
        'shape': 'box',
        'style': 'rounded, filled',
        'fillcolor': 'lightblue',
        'fontname': 'Helvetica',
    }

    edge_attr = {
        'penwidth': '1.2',
        'color': 'gray40',
    }

    return graphviz.Digraph(
        engine=engine,
        graph_attr=graph_attr,
        node_attr=node_attr,
        edge_attr=edge_attr,
    )


def draw_graph(
    aM,
    output_dir: Path | None = None,
    title="am_graph",
    view=True,
    subgraphs=None,
    engine="dot",
):
    """Draw the states and transitions of ``aM`` with Graphviz.

    Each state ``i`` becomes a circular node labelled with
    ``aM.states[i].name``, or with ``str(i)`` when the name is empty. Each
    transition becomes an edge labelled ``symbol(weight)``. The weight is
    ``str(tr.pq)`` when ``aM.is_q_weighted`` is true, and otherwise
    ``tr.prob`` rounded to 4 decimals.

    Args:
        aM: Object exposing ``states`` (items with ``.name``), ``transitions``
            (items with ``origin_state_idx``, ``target_state_idx``,
            ``symbol_idx`` and ``prob`` / ``pq``), ``alphabet`` (indexed by
            ``symbol_idx``) and ``is_q_weighted``.
        output_dir: If not ``None``, a PNG named after ``title`` is rendered
            into this directory with a transparent background.
        title: Base file name used when rendering into ``output_dir``.
        view: If truthy, a 300-dpi PNG is rendered and shown via matplotlib.
        subgraphs: Optional sized sequence of collections of state indices.
            Nodes in the i-th collection get colour ``i % 12`` of the
            ``Set3`` colormap. Nodes in none of them are drawn red. When a
            node is in several collections, the last one wins. Without
            ``subgraphs``, every node gets ``Set3`` colour 8.
        engine: Graphviz layout engine, forwarded to ``create_digraph``.
    """
    GV = create_digraph(engine=engine)

    cmap = colormaps['Set3']

    if subgraphs:
        graph_colors = [to_hex(cmap(i % 12)) for i in range(len(subgraphs))]

    for node in range(len(aM.states)):

        if subgraphs:
            node_subgraph = -1
            for i, sg in enumerate(subgraphs):
                if node in sg:
                    node_subgraph = i  # no break: the last matching subgraph wins

            node_color = graph_colors[node_subgraph] if node_subgraph >= 0 else 'red'
        else:
            node_color = to_hex(cmap(8))

        node_tex = aM.states[node].name
        if not node_tex:
            node_tex = str(node)

        GV.node(
            str(node),
            label=node_tex,
            shape='circle',
            style='bold,filled',
            fillcolor=node_color,
            color='black',
            width='0.5',
        )

    for tr in aM.transitions:

        u = tr.origin_state_idx
        v = tr.target_state_idx

        pr_str = str(tr.pq) if aM.is_q_weighted else str(round(tr.prob, 4))
        # The label is Graphviz HTML-like markup, so '<', '>' and '&' in the
        # symbol or weight must be escaped or the label becomes malformed.
        label_text = escape(f"{aM.alphabet[tr.symbol_idx]}({pr_str})", quote=False)

        html_label = (
            f'<<TABLE BORDER="0" CELLBORDER="0" CELLSPACING="0" CELLPADDING="2">'
            f'<TR><TD BGCOLOR="#FFFFFF">{label_text}</TD></TR>'
            f'</TABLE>>'
        )

        GV.edge(
            str(u), str(v),
            label=html_label,
            fontsize='10',
            fontname='Times-Italic',
            labelfloat="true",
        )

    if view:
        # NOTE: dpi is set on GV itself, so a render into output_dir later in
        # this same call also uses 300 dpi (it does not when view is falsy).
        GV.graph_attr.update(dpi='300')
        png_bytes = GV.pipe(format="png")
        img = mpimg.imread(BytesIO(png_bytes))

        plt.figure(figsize=(8, 6))
        plt.imshow(img)
        plt.axis('off')
        plt.tight_layout()
        plt.show()

    if output_dir is not None:
        GV.attr(bgcolor='transparent')
        GV.render(title, directory=output_dir, view=False, format='png', cleanup=True)
def
create_digraph(engine='dot'):
12def create_digraph(engine="dot"): 13 14 engine_configs = { 15 "dot": { 16 'rankdir': 'LR', 17 'ranksep': '1.0', 18 'nodesep': '0.4', 19 'splines': 'spline', 20 'constraint' : 'true', 21 'concentrate': 'false' 22 }, 23 "neato": { 24 'overlap': 'scale', 25 'overlap_scaling': '-4', 26 'esep': '+2.5', 27 'sep': '+1.75', 28 'model': 'shortpath', 29 'damping': '0.85', 30 'epsilon': '0.00001', 31 'maxiter': '1000000', 32 'start': '5', 33 }, 34 "fdp": { 35 'overlap': 'prism', 36 'sep': '+1.5', 37 'K': '1.0', 38 'splines': 'true', 39 'len' : '3.0', 40 'maxiter': '5000' 41 } 42 } 43 44 graph_attr = engine_configs.get(engine, {}) 45 46 node_attr = { 47 'shape': 'box', 48 'style': 'rounded, filled', 49 'fillcolor': 'lightblue', 50 'fontname': 'Helvetica' 51 } 52 53 edge_attr = { 54 'penwidth': '1.2', 55 'color': 'gray40' 56 } 57 58 return graphviz.Digraph( 59 engine=engine, 60 graph_attr=graph_attr, 61 node_attr=node_attr, 62 edge_attr=edge_attr 63 )
def
draw_graph( aM, output_dir: pathlib.Path | None = None, title='am_graph', view=True, subgraphs=None, engine='dot'):
def draw_graph(
    aM,
    output_dir: Path | None = None,
    title="am_graph",
    view=True,
    subgraphs=None,
    engine="dot",
):
    """Draw the states and transitions of ``aM`` with Graphviz.

    Each state ``i`` becomes a circular node labelled with
    ``aM.states[i].name``, or with ``str(i)`` when the name is empty. Each
    transition becomes an edge labelled ``symbol(weight)``. The weight is
    ``str(tr.pq)`` when ``aM.is_q_weighted`` is true, and otherwise
    ``tr.prob`` rounded to 4 decimals.

    Args:
        aM: Object exposing ``states`` (items with ``.name``), ``transitions``
            (items with ``origin_state_idx``, ``target_state_idx``,
            ``symbol_idx`` and ``prob`` / ``pq``), ``alphabet`` (indexed by
            ``symbol_idx``) and ``is_q_weighted``.
        output_dir: If not ``None``, a PNG named after ``title`` is rendered
            into this directory with a transparent background.
        title: Base file name used when rendering into ``output_dir``.
        view: If truthy, a 300-dpi PNG is rendered and shown via matplotlib.
        subgraphs: Optional sized sequence of collections of state indices.
            Nodes in the i-th collection get colour ``i % 12`` of the
            ``Set3`` colormap. Nodes in none of them are drawn red. When a
            node is in several collections, the last one wins. Without
            ``subgraphs``, every node gets ``Set3`` colour 8.
        engine: Graphviz layout engine, forwarded to ``create_digraph``.
    """
    from html import escape

    GV = create_digraph(engine=engine)

    cmap = colormaps['Set3']

    if subgraphs:
        graph_colors = [to_hex(cmap(i % 12)) for i in range(len(subgraphs))]

    for node in range(len(aM.states)):

        if subgraphs:
            node_subgraph = -1
            for i, sg in enumerate(subgraphs):
                if node in sg:
                    node_subgraph = i  # no break: the last matching subgraph wins

            node_color = graph_colors[node_subgraph] if node_subgraph >= 0 else 'red'
        else:
            node_color = to_hex(cmap(8))

        node_tex = aM.states[node].name
        if not node_tex:
            node_tex = str(node)

        GV.node(
            str(node),
            label=node_tex,
            shape='circle',
            style='bold,filled',
            fillcolor=node_color,
            color='black',
            width='0.5',
        )

    for tr in aM.transitions:

        u = tr.origin_state_idx
        v = tr.target_state_idx

        pr_str = str(tr.pq) if aM.is_q_weighted else str(round(tr.prob, 4))
        # The label is Graphviz HTML-like markup, so '<', '>' and '&' in the
        # symbol or weight must be escaped or the label becomes malformed.
        label_text = escape(f"{aM.alphabet[tr.symbol_idx]}({pr_str})", quote=False)

        html_label = (
            f'<<TABLE BORDER="0" CELLBORDER="0" CELLSPACING="0" CELLPADDING="2">'
            f'<TR><TD BGCOLOR="#FFFFFF">{label_text}</TD></TR>'
            f'</TABLE>>'
        )

        GV.edge(
            str(u), str(v),
            label=html_label,
            fontsize='10',
            fontname='Times-Italic',
            labelfloat="true",
        )

    if view:
        # NOTE: dpi is set on GV itself, so a render into output_dir later in
        # this same call also uses 300 dpi (it does not when view is falsy).
        GV.graph_attr.update(dpi='300')
        png_bytes = GV.pipe(format="png")
        img = mpimg.imread(BytesIO(png_bytes))

        plt.figure(figsize=(8, 6))
        plt.imshow(img)
        plt.axis('off')
        plt.tight_layout()
        plt.show()

    if output_dir is not None:
        GV.attr(bgcolor='transparent')
        GV.render(title, directory=output_dir, view=False, format='png', cleanup=True)