amachine.am_fast
"""Python front-end for the compiled ``_am_fast`` extension.

Thin wrappers around C++ routines (block-entropy convergence, strongly
connected components, sequence generation and DFA minimisation), plus
``save_data``, which writes generated sequences to Parquet together with a
JSON metadata sidecar.
"""

from __future__ import annotations

import sys
from pathlib import Path
import json

import importlib.util
import os
from typing import Any

import pyarrow as pa
import pyarrow.parquet as pq
import pprint
import numpy as np
import numpy.typing as npt

import networkx as nx

from automata.fa.dfa import DFA


_build_dir = Path(__file__).resolve().parent / "build"

# Put the local build directory on sys.path.
# NOTE(review): ``from . import _am_fast`` resolves through the package, not
# sys.path; confirm this insertion is still needed.
if str(_build_dir) not in sys.path:
    sys.path.insert(0, str(_build_dir))

try:
    from . import _am_fast
    # Re-export the extension's API. This module uses
    # block_entropy_convergence_cpp, strongly_connected_components_cpp,
    # generate_cpp and minify_dfa_cpp from it.
    from ._am_fast import *  # noqa: F403
except ImportError as e:
    # Diagnostics for a missing or mis-built extension. They go to stderr so
    # they do not mix with normal program output.
    print("\n--- DEBUG START ---", file=sys.stderr)
    print(f"Error: {e}", file=sys.stderr)
    print(f"Current Directory: {os.getcwd()}", file=sys.stderr)
    print(f"File Location: {__file__}", file=sys.stderr)
    # This helps see if the .so file is actually in the folder
    print(
        f"Contents of this folder: {os.listdir(os.path.dirname(__file__))}",
        file=sys.stderr,
    )
    print("--- DEBUG END ---\n", file=sys.stderr)
    # Bare ``raise`` re-raises the original exception with its traceback intact.
    raise

from .json_utils import save_json


def block_entropy_convergence(
    h_mu: float,
    n_states: int,
    n_symbols: int,
    convergence_tol: float,
    precision: float,
    eps: float,
    branches: list[tuple[float, list[float]]],
    trans: list[list[tuple[int, float, int]]],
    max_branches: int = 30_000_000,
) -> Any:
    """Forward all arguments to the C++ ``block_entropy_convergence_cpp``.

    Every argument is passed by keyword. ``h_mu`` and ``precision`` are
    coerced to ``float`` first; the other arguments are passed unchanged.

    Args:
        h_mu: Forwarded as a float (presumably the entropy rate; confirm).
        n_states: Number of states, forwarded unchanged.
        n_symbols: Alphabet size, forwarded unchanged.
        convergence_tol: Forwarded unchanged.
        precision: Forwarded as a float.
        eps: Forwarded unchanged.
        branches: ``(float, list[float])`` pairs; their meaning is defined by
            the C++ routine.
        trans: One list of ``(int, float, int)`` tuples per state.
        max_branches: Forwarded unchanged (default 30,000,000).

    Returns:
        Whatever ``block_entropy_convergence_cpp`` returns.
    """
    return block_entropy_convergence_cpp(
        h_mu=float(h_mu),
        n_states=n_states,
        n_symbols=n_symbols,
        convergence_tol=convergence_tol,
        precision=float(precision),
        eps=eps,
        branches=branches,
        trans=trans,
        max_branches=max_branches,
    )


def strongly_connected_components(T):
    """Return ``strongly_connected_components_cpp(T)``.

    ``T`` is passed unchanged. NOTE(review): its expected structure and the
    result type are defined by the extension binding; confirm there.
    """
    return strongly_connected_components_cpp(T)


def generate_data(
    n_gen: int,
    start_state: int,
    transitions: list[list[tuple[int, float, int]]],
    alphabet: list[str],
    include_states: bool = False,
    random_seed: int = 42,
) -> dict[str, npt.NDArray]:
    """Sample a symbol sequence with the C++ ``generate_cpp`` routine.

    Args:
        n_gen: Forwarded to ``generate_cpp`` as ``n_gen``.
        start_state: Forwarded as ``start_state_index``.
        transitions: One list of ``(int, float, int)`` tuples per state.
        alphabet: Symbol labels. Only its length is used here, for the size
            check.
        include_states: If true, the state-index sequence is also returned.
        random_seed: Seed forwarded to ``generate_cpp``.

    Returns:
        ``{"symbol_index": ndarray}``, plus ``"state_index"`` when
        ``include_states`` is true. Arrays are imported from the C++ buffers
        with ``np.from_dlpack``.

    Raises:
        ValueError: If there are more than 65536 states or symbols.
    """
    n_states = len(transitions)
    n_symbols = len(alphabet)

    # Presumably capped so indices fit the uint16 columns used by save_data.
    if n_states > 65536:
        raise ValueError(f"Max states supported is 65536, got {n_states}")
    if n_symbols > 65536:
        raise ValueError(f"Max alphabet size supported is 65536, got {n_symbols}")

    symbol_indices, state_indices = generate_cpp(
        n_gen=n_gen,
        start_state_index=start_state,
        transitions=transitions,
        include_states=include_states,
        random_seed=random_seed,
    )

    res: dict[str, npt.NDArray] = {
        "symbol_index": np.from_dlpack(symbol_indices)
    }
    if include_states:
        res["state_index"] = np.from_dlpack(state_indices)

    return res


def save_data(
    data: dict[str, Any],
    file_prefix: str,
    alphabet: list[str],
    n_states: int,
    start_state: int,
    random_seed: int,
    machine_metadata: dict[str, Any] | None = None,
) -> None:
    """Write generated sequences to ``{file_prefix}.parquet`` and ``.json``.

    Index columns are stored as ``uint8`` when the alphabet size (or state
    count) is at most 256, otherwise as ``uint16``. For every state sequence,
    the main one and each isomorphic shift, the last entry is left out of the
    column and recorded in the metadata instead (``final_state`` /
    ``isomorphic_final_states``). The metadata is embedded in the Parquet
    schema under ``"am_metadata"`` and also written to ``{file_prefix}.json``
    via ``save_json``.

    Args:
        data: Must contain ``"symbol_index"``. May contain ``"state_index"``
            and ``"isomorphic_shifts"``, a mapping from shift to a dict with
            the same two keys.
        file_prefix: Output path without extension.
        alphabet: Symbol labels. Stored in the metadata and used to choose
            the symbol column type.
        n_states: Used to choose the state column type.
        start_state: Stored in the metadata.
        random_seed: Stored in the metadata.
        machine_metadata: Extra metadata; ``{}`` when omitted.
    """

    def _uint_type(n: int) -> pa.DataType:
        # uint8 covers indices 0..255; anything larger uses uint16.
        return pa.uint8() if n <= 256 else pa.uint16()

    sym_type = _uint_type(len(alphabet))
    state_type = _uint_type(n_states)

    columns: dict[str, pa.Array] = {
        "symbol_index": pa.array(data["symbol_index"]).cast(sym_type)
    }

    # The last state is kept out of the column and stored as metadata
    # (presumably the state reached after the final symbol; confirm against
    # generate_cpp).
    final_state: int | None = None
    if "state_index" in data:
        state_data = data["state_index"]
        final_state = int(state_data[-1])
        columns["state_index"] = pa.array(state_data[:-1]).cast(state_type)

    iso_shifts = data.get("isomorphic_shifts", {})
    # Keys are the stringified shifts.
    iso_final_states: dict[str, int] = {}

    for shift, shifted in iso_shifts.items():
        columns[f"symbol_index_isoshift_{shift}"] = (
            pa.array(shifted["symbol_index"]).cast(sym_type)
        )
        if "state_index" in shifted:
            iso_state_data = shifted["state_index"]
            iso_final_states[f"{shift}"] = int(iso_state_data[-1])
            columns[f"state_index_isoshift_{shift}"] = (
                pa.array(iso_state_data[:-1]).cast(state_type)
            )

    parquet_meta: dict[str, Any] = {
        "alphabet": alphabet,
        "machine_metadata": machine_metadata or {},
        "start_state": start_state,
        "random_seed": random_seed,
        "isomorphic_shifts": sorted(iso_shifts.keys()),
    }

    if final_state is not None:
        parquet_meta["final_state"] = final_state
    if iso_final_states:
        parquet_meta["isomorphic_final_states"] = iso_final_states

    table = pa.table(columns)
    # Merge with any existing schema metadata instead of replacing it.
    table = table.replace_schema_metadata({
        **(table.schema.metadata or {}),
        "am_metadata": json.dumps(parquet_meta),
    })

    pq.write_table(table, f"{file_prefix}.parquet")
    save_json(parquet_meta, f"{file_prefix}.json")


def minify_cpp(dfa: DFA, retain_names: bool = True) -> DFA:
    """Minimise ``dfa`` with the C++ ``minify_dfa_cpp`` routine.

    The DFA is encoded with integer indices (states in ``list(dfa.states)``
    order, symbols in sorted order), minimised in C++, and rebuilt as an
    instance of ``dfa``'s own class.

    Args:
        dfa: DFA to minimise.
        retain_names: If true, each new state is a ``frozenset`` of the
            original states in its equivalence class. Otherwise new states
            are the class numbers ``0 .. n_classes - 1``.

    Returns:
        The minimised DFA, or ``empty_language`` of the same class when the
        C++ result reports an empty language. ``allow_partial`` is set when
        some new state lacks a transition for some input symbol.
    """
    state_list = list(dfa.states)
    state_idx = {s: i for i, s in enumerate(state_list)}
    symbol_list = sorted(dfa.input_symbols)
    symbol_idx = {sym: i for i, sym in enumerate(symbol_list)}

    adj = [[] for _ in range(len(state_list))]
    for state, paths in dfa.transitions.items():
        u = state_idx[state]
        for sym, nxt in paths.items():
            # Transitions to states outside ``dfa.states`` are skipped.
            v = state_idx.get(nxt, -1)
            if v != -1:
                adj[u].append((symbol_idx[sym], v))

    is_final = [s in dfa.final_states for s in state_list]
    init_idx = state_idx[dfa.initial_state]

    # result names in res now match the C++ struct fields exactly
    res = minify_dfa_cpp(len(state_list), adj, init_idx, is_final)

    if res.is_empty_language:
        return dfa.__class__.empty_language(dfa.input_symbols)

    if retain_names:
        class_members = {}
        for old_idx, nc in enumerate(res.eq_class):
            # A negative class id means the state is not part of the result.
            if nc >= 0:
                class_members.setdefault(nc, []).append(state_list[old_idx])
        class_map = {nc: frozenset(m) for nc, m in class_members.items()}
    else:
        class_map = {nc: nc for nc in range(res.n_classes)}

    new_states = set(class_map.values())
    new_initial = class_map[res.new_initial]
    new_final = {class_map[c] for c, f in enumerate(res.class_is_final) if f}

    new_trans = {}
    for nc in range(res.n_classes):
        new_trans[class_map[nc]] = {
            symbol_list[trans[0]]: class_map[trans[1]]
            for trans in res.class_trans[nc]
        }

    return dfa.__class__(
        states=new_states,
        input_symbols=dfa.input_symbols,
        transitions=new_trans,
        initial_state=new_initial,
        final_states=new_final,
        allow_partial=any(len(t) < len(symbol_list) for t in new_trans.values()),
    )
from typing import Any


def block_entropy_convergence(
    h_mu: float,
    n_states: int,
    n_symbols: int,
    convergence_tol: float,
    precision: float,
    eps: float,
    branches: list[tuple[float, list[float]]],
    trans: list[list[tuple[int, float, int]]],
    max_branches: int = 30_000_000,
) -> Any:
    """Forward all arguments to the C++ ``block_entropy_convergence_cpp``.

    Every argument is passed by keyword. ``h_mu`` and ``precision`` are
    coerced to ``float`` first; the other arguments are passed unchanged.

    Args:
        h_mu: Forwarded as a float (presumably the entropy rate; confirm).
        n_states: Number of states, forwarded unchanged.
        n_symbols: Alphabet size, forwarded unchanged.
        convergence_tol: Forwarded unchanged.
        precision: Forwarded as a float.
        eps: Forwarded unchanged.
        branches: ``(float, list[float])`` pairs; their meaning is defined by
            the C++ routine.
        trans: One list of ``(int, float, int)`` tuples per state.
        max_branches: Forwarded unchanged (default 30,000,000).

    Returns:
        Whatever ``block_entropy_convergence_cpp`` returns. The annotation is
        ``typing.Any``; the builtin ``any`` is a function, not a type.
    """
    return block_entropy_convergence_cpp(
        h_mu=float(h_mu),
        n_states=n_states,
        n_symbols=n_symbols,
        convergence_tol=convergence_tol,
        precision=float(precision),
        eps=eps,
        branches=branches,
        trans=trans,
        max_branches=max_branches,
    )
def strongly_connected_components(T):
    """Compute strongly connected components with the C++ extension.

    ``T`` is handed unchanged to ``strongly_connected_components_cpp`` and
    that call's result is returned as-is.
    NOTE(review): the expected structure of ``T`` and of the result are
    defined by the extension binding; confirm there.
    """
    components = strongly_connected_components_cpp(T)
    return components
def generate_data(
    n_gen: int,
    start_state: int,
    transitions: list[list[tuple[int, float, int]]],
    alphabet: list[str],
    include_states: bool = False,
    random_seed: int = 42,
) -> dict[str, npt.NDArray]:
    """Draw a sequence from the machine using the C++ ``generate_cpp`` sampler.

    ``alphabet`` is only consulted for its length. Both the state count
    (``len(transitions)``) and the alphabet size are capped at 65536, and a
    ``ValueError`` is raised above that.

    Returns a dict holding ``"symbol_index"``, plus ``"state_index"`` when
    ``include_states`` is true. Both are NumPy views obtained with
    ``np.from_dlpack``.
    """
    size_checks = (
        ("states", len(transitions)),
        ("alphabet size", len(alphabet)),
    )
    for label, size in size_checks:
        if size > 65536:
            raise ValueError(f"Max {label} supported is 65536, got {size}")

    symbols_dl, states_dl = generate_cpp(
        n_gen=n_gen,
        start_state_index=start_state,
        transitions=transitions,
        include_states=include_states,
        random_seed=random_seed,
    )

    out: dict[str, npt.NDArray] = {"symbol_index": np.from_dlpack(symbols_dl)}
    if include_states:
        out["state_index"] = np.from_dlpack(states_dl)
    return out
from typing import Any


def save_data(
    data: dict[str, Any],
    file_prefix: str,
    alphabet: list[str],
    n_states: int,
    start_state: int,
    random_seed: int,
    machine_metadata: dict[str, Any] | None = None,
) -> None:
    """Write generated sequences to ``{file_prefix}.parquet`` and ``.json``.

    Index columns are stored as ``uint8`` when the alphabet size (or state
    count) is at most 256, otherwise as ``uint16``. For every state sequence,
    the main one and each isomorphic shift, the last entry is left out of the
    column and recorded in the metadata instead (``final_state`` /
    ``isomorphic_final_states``). The metadata is embedded in the Parquet
    schema under ``"am_metadata"`` and also written to ``{file_prefix}.json``
    via ``save_json``.

    Args:
        data: Must contain ``"symbol_index"``. May contain ``"state_index"``
            and ``"isomorphic_shifts"``, a mapping from shift to a dict with
            the same two keys.
        file_prefix: Output path without extension.
        alphabet: Symbol labels. Stored in the metadata and used to choose
            the symbol column type.
        n_states: Used to choose the state column type.
        start_state: Stored in the metadata.
        random_seed: Stored in the metadata.
        machine_metadata: Extra metadata; ``{}`` when omitted.
    """

    def _uint_type(n: int) -> pa.DataType:
        # uint8 covers indices 0..255; anything larger uses uint16.
        return pa.uint8() if n <= 256 else pa.uint16()

    sym_type = _uint_type(len(alphabet))
    state_type = _uint_type(n_states)

    columns: dict[str, pa.Array] = {
        "symbol_index": pa.array(data["symbol_index"]).cast(sym_type)
    }

    # The last state is kept out of the column and stored as metadata
    # (presumably the state reached after the final symbol; confirm against
    # generate_cpp). Bound up front so later code never reads an unbound name.
    final_state: int | None = None
    if "state_index" in data:
        state_data = data["state_index"]
        final_state = int(state_data[-1])
        columns["state_index"] = pa.array(state_data[:-1]).cast(state_type)

    iso_shifts = data.get("isomorphic_shifts", {})
    # Keys are the stringified shifts (the old ``dict[int, int]`` annotation
    # did not match the ``f"{shift}"`` keys actually stored).
    iso_final_states: dict[str, int] = {}

    for shift, shifted in iso_shifts.items():
        columns[f"symbol_index_isoshift_{shift}"] = (
            pa.array(shifted["symbol_index"]).cast(sym_type)
        )
        if "state_index" in shifted:
            iso_state_data = shifted["state_index"]
            iso_final_states[f"{shift}"] = int(iso_state_data[-1])
            columns[f"state_index_isoshift_{shift}"] = (
                pa.array(iso_state_data[:-1]).cast(state_type)
            )

    parquet_meta: dict[str, Any] = {
        "alphabet": alphabet,
        "machine_metadata": machine_metadata or {},
        "start_state": start_state,
        "random_seed": random_seed,
        "isomorphic_shifts": sorted(iso_shifts.keys()),
    }

    if final_state is not None:
        parquet_meta["final_state"] = final_state
    if iso_final_states:
        parquet_meta["isomorphic_final_states"] = iso_final_states

    table = pa.table(columns)
    # Merge with any existing schema metadata instead of replacing it.
    table = table.replace_schema_metadata({
        **(table.schema.metadata or {}),
        "am_metadata": json.dumps(parquet_meta),
    })

    pq.write_table(table, f"{file_prefix}.parquet")
    save_json(parquet_meta, f"{file_prefix}.json")
def minify_cpp(dfa: DFA, retain_names: bool = True) -> DFA:
    """Return a minimal DFA equivalent to ``dfa``, computed by ``minify_dfa_cpp``.

    States are numbered in ``list(dfa.states)`` order and symbols in sorted
    order before the DFA is handed to the extension. The result is rebuilt
    with ``dfa``'s own class. With ``retain_names`` the new states are
    ``frozenset``s of the merged original states; without it they are the
    class numbers ``0 .. n_classes - 1``. If the extension reports an empty
    language, ``empty_language`` of the same class is returned instead.
    """
    # Encode the DFA with integer indices for the extension.
    states = list(dfa.states)
    index_of = {state: pos for pos, state in enumerate(states)}
    symbols = sorted(dfa.input_symbols)
    symbol_pos = {symbol: pos for pos, symbol in enumerate(symbols)}

    adjacency = [[] for _ in range(len(states))]
    for src, edges in dfa.transitions.items():
        row = adjacency[index_of[src]]
        for symbol, dst in edges.items():
            # Targets that are not in ``dfa.states`` are silently skipped.
            if dst in index_of:
                row.append((symbol_pos[symbol], index_of[dst]))

    final_flags = [state in dfa.final_states for state in states]
    start = index_of[dfa.initial_state]

    result = minify_dfa_cpp(len(states), adjacency, start, final_flags)
    if result.is_empty_language:
        return dfa.__class__.empty_language(dfa.input_symbols)

    # Choose the name each equivalence class gets in the rebuilt DFA.
    if not retain_names:
        rename = {cls: cls for cls in range(result.n_classes)}
    else:
        groups = {}
        for old_pos, cls in enumerate(result.eq_class):
            if cls < 0:
                # Negative class id: the original state is not kept.
                continue
            groups.setdefault(cls, []).append(states[old_pos])
        rename = {cls: frozenset(members) for cls, members in groups.items()}

    new_states = set(rename.values())
    new_initial = rename[result.new_initial]
    new_final = {
        rename[cls]
        for cls, is_final in enumerate(result.class_is_final)
        if is_final
    }

    new_transitions = {}
    for cls in range(result.n_classes):
        new_transitions[rename[cls]] = {
            symbols[edge[0]]: rename[edge[1]] for edge in result.class_trans[cls]
        }

    partial = any(len(row) < len(symbols) for row in new_transitions.values())
    return dfa.__class__(
        states=new_states,
        input_symbols=dfa.input_symbols,
        transitions=new_transitions,
        initial_state=new_initial,
        final_states=new_final,
        allow_partial=partial,
    )