amachine.am_vis
"""Graphviz-based drawing helpers for machine objects (states + transitions)."""

from html import escape
from io import BytesIO
from pathlib import Path

import graphviz
from IPython.display import Image, display
from matplotlib import colormaps
from matplotlib.colors import to_hex
import matplotlib.image as mpimg
import matplotlib.pyplot as plt


def create_digraph(engine="dot"):
    """Return an empty ``graphviz.Digraph`` carrying this module's default styling.

    Graph-level attributes are chosen per layout engine: ``"dot"``,
    ``"neato"`` and ``"fdp"`` each have a preset; any other engine name gets
    no graph-level attributes. Node and edge defaults are the same for every
    engine.

    Args:
        engine: Layout engine name, forwarded to ``graphviz.Digraph``.

    Returns:
        A new ``graphviz.Digraph`` with graph, node and edge attributes set.
    """
    layout_presets = {
        "dot": {
            'rankdir': 'TB',
            'ranksep': '1.0',
            'nodesep': '0.4',
            'splines': 'spline',
            'constraint': 'true',
            'concentrate': 'false',
        },
        "neato": {
            'overlap': 'scale',
            'overlap_scaling': '-4',
            'esep': '+2.5',
            'sep': '+1.75',
            'model': 'shortpath',
            'damping': '0.85',
            'epsilon': '0.00001',
            'maxiter': '1000000',
            'start': '5',
        },
        "fdp": {
            'overlap': 'prism',
            'sep': '+1.5',
            'K': '1.0',
            'splines': 'true',
            'len': '3.0',
            'maxiter': '5000',
        },
    }

    node_attr = {
        'shape': 'box',
        'style': 'rounded, filled',
        'fillcolor': 'lightblue',
        'fontname': 'Helvetica',
    }

    edge_attr = {
        'penwidth': '1.2',
        'color': 'gray40',
    }

    return graphviz.Digraph(
        engine=engine,
        # Engines without a preset fall back to no graph-level attributes.
        graph_attr=layout_presets.get(engine, {}),
        node_attr=node_attr,
        edge_attr=edge_attr,
    )


def draw_graph(
        aM,
        output_dir: Path | None = None,
        title="am_graph",
        view=True,
        subgraphs=None,
        engine="dot",
        symbols_only=False):
    """Draw the states and transitions of ``aM`` as a Graphviz digraph.

    One circular node is created per entry of ``aM.states`` (labelled with the
    state's ``name``, or its index when the name is empty) and one edge per
    entry of ``aM.transitions``. Edge labels show the transition symbol and,
    unless ``symbols_only`` is set, its weight in parentheses: ``tr.pq`` when
    ``aM.is_q_weighted`` is true, otherwise ``tr.prob`` rounded to 4 decimals.

    Args:
        aM: Object exposing ``states``, ``transitions``, ``alphabet`` and
            ``is_q_weighted`` as used above.
        output_dir: When not ``None``, the graph is also rendered to a PNG in
            this directory with a transparent background.
        title: Base filename passed to ``render`` for the saved PNG.
        view: When truthy, the graph is rendered to PNG in memory at 300 dpi
            and shown with matplotlib.
        subgraphs: Optional sequence of node-index containers. Each subgraph
            gets its own ``Set3`` colour (cycling every 12); nodes found in no
            subgraph are drawn red. Without it every node shares one colour.
        engine: Layout engine name forwarded to ``create_digraph``.
        symbols_only: When true, labels omit the weight and use a larger font.

    Returns:
        None.
    """
    GV = create_digraph(engine=engine)

    cmap = colormaps['Set3']
    default_color = to_hex(cmap(8))

    if subgraphs:
        graph_colors = [to_hex(cmap(i % 12)) for i in range(len(subgraphs))]

    for node, state in enumerate(aM.states):

        if subgraphs:
            # No early exit: when a node appears in several subgraphs the
            # last one listed decides its colour.
            node_subgraph = -1
            for i, sg in enumerate(subgraphs):
                if node in sg:
                    node_subgraph = i
            node_color = graph_colors[node_subgraph] if node_subgraph >= 0 else 'red'
        else:
            node_color = default_color

        node_tex = state.name
        if not node_tex:
            node_tex = str(node)

        GV.node(
            str(node),
            label=node_tex,
            shape='circle',
            style='bold,filled',
            fillcolor=node_color,
            color='black',
            width='0.5',
        )

    # Symbol-only labels are short, so they are drawn much larger.
    fontsize = '30' if symbols_only else '12'

    for tr in aM.transitions:

        u = tr.origin_state_idx
        v = tr.target_state_idx

        if symbols_only:
            label_text = f"{aM.alphabet[tr.symbol_idx]}"
        else:
            pr_str = str(tr.pq) if aM.is_q_weighted else str(round(tr.prob, 4))
            label_text = f"{aM.alphabet[tr.symbol_idx]}({pr_str})"

        # The label is a Graphviz HTML-like string, so its text must be valid
        # XML: escape '&', '<' and '>' coming from symbols or weights.
        safe_text = escape(label_text, quote=False)

        html_label = (
            '<<TABLE BORDER="0" CELLBORDER="0" CELLSPACING="0" CELLPADDING="2">'
            f'<TR><TD BGCOLOR="#FFFFFF">{safe_text}</TD></TR>'
            '</TABLE>>'
        )

        GV.edge(
            str(u), str(v),
            label=html_label,
            fontsize=fontsize,
            fontname='Times-Italic',
            labelfloat="true",
            fontcolor="#0042ad",
        )

    if view:
        # NOTE: the dpi attribute stays on the graph, so a PNG saved below
        # after viewing is rendered at 300 dpi as well.
        GV.graph_attr.update(dpi='300')
        png_bytes = GV.pipe(format="png")
        img = mpimg.imread(BytesIO(png_bytes))

        plt.figure(figsize=(8, 6))
        plt.imshow(img)
        plt.axis('off')
        plt.tight_layout()
        plt.show()

    if output_dir is not None:
        GV.attr(bgcolor='transparent')
        GV.render(title, directory=output_dir, view=False, format='png', cleanup=True)
def
create_digraph(engine='dot'):
def create_digraph(engine="dot"):
    """Return an empty ``graphviz.Digraph`` carrying this module's default styling.

    Graph-level attributes are chosen per layout engine: ``"dot"``,
    ``"neato"`` and ``"fdp"`` each have a preset; any other engine name gets
    no graph-level attributes. Node and edge defaults are the same for every
    engine.

    Args:
        engine: Layout engine name, forwarded to ``graphviz.Digraph``.

    Returns:
        A new ``graphviz.Digraph`` with graph, node and edge attributes set.
    """
    node_defaults = dict(
        shape='box',
        style='rounded, filled',
        fillcolor='lightblue',
        fontname='Helvetica',
    )
    edge_defaults = dict(
        penwidth='1.2',
        color='gray40',
    )

    layout_presets = dict(
        dot=dict(
            rankdir='TB',
            ranksep='1.0',
            nodesep='0.4',
            splines='spline',
            constraint='true',
            concentrate='false',
        ),
        neato=dict(
            overlap='scale',
            overlap_scaling='-4',
            esep='+2.5',
            sep='+1.75',
            model='shortpath',
            damping='0.85',
            epsilon='0.00001',
            maxiter='1000000',
            start='5',
        ),
        fdp=dict(
            overlap='prism',
            sep='+1.5',
            K='1.0',
            splines='true',
            len='3.0',
            maxiter='5000',
        ),
    )

    try:
        graph_defaults = layout_presets[engine]
    except KeyError:
        # Engines without a preset get no graph-level attributes.
        graph_defaults = {}

    return graphviz.Digraph(
        engine=engine,
        graph_attr=graph_defaults,
        node_attr=node_defaults,
        edge_attr=edge_defaults,
    )
def
draw_graph( aM, output_dir: pathlib.Path | None = None, title='am_graph', view=True, subgraphs=None, engine='dot', symbols_only=False):
def draw_graph(
        aM,
        output_dir: Path | None = None,
        title="am_graph",
        view=True,
        subgraphs=None,
        engine="dot",
        symbols_only=False):
    """Draw the states and transitions of ``aM`` as a Graphviz digraph.

    One circular node is created per entry of ``aM.states`` (labelled with the
    state's ``name``, or its index when the name is empty) and one edge per
    entry of ``aM.transitions``. Edge labels show the transition symbol and,
    unless ``symbols_only`` is set, its weight in parentheses: ``tr.pq`` when
    ``aM.is_q_weighted`` is true, otherwise ``tr.prob`` rounded to 4 decimals.

    Args:
        aM: Object exposing ``states``, ``transitions``, ``alphabet`` and
            ``is_q_weighted`` as used above.
        output_dir: When not ``None``, the graph is also rendered to a PNG in
            this directory with a transparent background.
        title: Base filename passed to ``render`` for the saved PNG.
        view: When truthy, the graph is rendered to PNG in memory at 300 dpi
            and shown with matplotlib.
        subgraphs: Optional sequence of node-index containers. Each subgraph
            gets its own ``Set3`` colour (cycling every 12); nodes found in no
            subgraph are drawn red. Without it every node shares one colour.
        engine: Layout engine name forwarded to ``create_digraph``.
        symbols_only: When true, labels omit the weight and use a larger font.

    Returns:
        None.
    """
    # Imported here so this function does not depend on a module-level import.
    from html import escape

    GV = create_digraph(engine=engine)

    cmap = colormaps['Set3']
    default_color = to_hex(cmap(8))

    if subgraphs:
        graph_colors = [to_hex(cmap(i % 12)) for i in range(len(subgraphs))]

    for node, state in enumerate(aM.states):

        if subgraphs:
            # No early exit: when a node appears in several subgraphs the
            # last one listed decides its colour.
            node_subgraph = -1
            for i, sg in enumerate(subgraphs):
                if node in sg:
                    node_subgraph = i
            node_color = graph_colors[node_subgraph] if node_subgraph >= 0 else 'red'
        else:
            node_color = default_color

        node_tex = state.name
        if not node_tex:
            node_tex = str(node)

        GV.node(
            str(node),
            label=node_tex,
            shape='circle',
            style='bold,filled',
            fillcolor=node_color,
            color='black',
            width='0.5',
        )

    # Symbol-only labels are short, so they are drawn much larger.
    fontsize = '30' if symbols_only else '12'

    for tr in aM.transitions:

        u = tr.origin_state_idx
        v = tr.target_state_idx

        if symbols_only:
            label_text = f"{aM.alphabet[tr.symbol_idx]}"
        else:
            pr_str = str(tr.pq) if aM.is_q_weighted else str(round(tr.prob, 4))
            label_text = f"{aM.alphabet[tr.symbol_idx]}({pr_str})"

        # The label is a Graphviz HTML-like string, so its text must be valid
        # XML: escape '&', '<' and '>' coming from symbols or weights.
        safe_text = escape(label_text, quote=False)

        html_label = (
            '<<TABLE BORDER="0" CELLBORDER="0" CELLSPACING="0" CELLPADDING="2">'
            f'<TR><TD BGCOLOR="#FFFFFF">{safe_text}</TD></TR>'
            '</TABLE>>'
        )

        GV.edge(
            str(u), str(v),
            label=html_label,
            fontsize=fontsize,
            fontname='Times-Italic',
            labelfloat="true",
            fontcolor="#0042ad",
        )

    if view:
        # NOTE: the dpi attribute stays on the graph, so a PNG saved below
        # after viewing is rendered at 300 dpi as well.
        GV.graph_attr.update(dpi='300')
        png_bytes = GV.pipe(format="png")
        img = mpimg.imread(BytesIO(png_bytes))

        plt.figure(figsize=(8, 6))
        plt.imshow(img)
        plt.axis('off')
        plt.tight_layout()
        plt.show()

    if output_dir is not None:
        GV.attr(bgcolor='transparent')
        GV.render(title, directory=output_dir, view=False, format='png', cleanup=True)