GitLab Repo

amachine.am_fast

  1from __future__ import annotations
  2
  3import sys
  4from pathlib import Path
  5import json
  6
  7import importlib.util
  8import os
  9
 10import pyarrow as pa
 11import pyarrow.parquet as pq
 12import pprint
 13import numpy as np
 14import numpy.typing as npt
 15
 16import networkx as nx
 17
 18from automata.fa.dfa import DFA
 19
 20
 21_build_dir = Path(__file__).resolve().parent / "build"
 22
 23if str(_build_dir) not in sys.path:
 24    sys.path.insert(0, str(_build_dir))
 25
 26try:
 27    from . import _am_fast
 28    from ._am_fast import *
 29except ImportError as e:
 30    print(f"\n--- DEBUG START ---")
 31    print(f"Error: {e}")
 32    print(f"Current Directory: {os.getcwd()}")
 33    print(f"File Location: {__file__}")
 34    # This helps see if the .so file is actually in the folder
 35    print(f"Contents of this folder: {os.listdir(os.path.dirname(__file__))}")
 36    print(f"--- DEBUG END ---\n")
 37    raise e
 38
 39# from ._am_fast import generate_cpp, block_entropy_convergence_cpp, strongly_connected_components_cpp, minify_dfa_cpp
 40from .json_utils import save_json
 41
 42def block_entropy_convergence(
 43    h_mu: float,
 44    n_states: int,
 45    n_symbols : int,
 46    convergence_tol: float,
 47    precision: float,
 48    eps: float,
 49    branches: list[tuple[float, list[float]]],
 50    trans: list[list[tuple[int, float, int]]],
 51    max_branches: int = 30_000_000
 52) -> any :
 53    return block_entropy_convergence_cpp(
 54        h_mu            = float( h_mu ),
 55        n_states        = n_states,
 56        n_symbols       = n_symbols,
 57        convergence_tol = convergence_tol,
 58        precision       = float(precision),
 59        eps             = eps,
 60        branches        = branches,
 61        trans           = trans,
 62        max_branches    = max_branches
 63    )
 64
 65def strongly_connected_components( T ) :
 66    return strongly_connected_components_cpp( T )
 67
 68def generate_data(
 69    n_gen        : int,
 70    start_state  : int,
 71    transitions  : list[list[tuple[int, float, int]]],
 72    alphabet     : list[str],
 73    include_states: bool = False,
 74    random_seed : int = 42 ) -> dict[str, npt.NDArray]:
 75
 76    n_states  = len(transitions)
 77    n_symbols = len(alphabet)
 78
 79    if n_states > 65536:
 80        raise ValueError(f"Max states supported is 65536, got {n_states}")
 81    if n_symbols > 65536:
 82        raise ValueError(f"Max alphabet size supported is 65536, got {n_symbols}")
 83
 84    symbol_indices, state_indices = generate_cpp(
 85        n_gen             = n_gen,
 86        start_state_index = start_state,
 87        transitions       = transitions,
 88        include_states    = include_states,
 89        random_seed       = random_seed
 90    )
 91
 92    res: dict[str, npt.NDArray] = {
 93        "symbol_index": np.from_dlpack(symbol_indices)
 94    }
 95    if include_states:
 96        res["state_index"] = np.from_dlpack(state_indices)
 97
 98    return res
 99
100def generate_data(
101    n_gen          : int,
102    start_state    : int,
103    transitions    : list[list[tuple[int, float, int]]],
104    alphabet       : list[str],
105    include_states : bool = False,
106    random_seed    : int  = 42,
107) -> dict[str, npt.NDArray]:
108
109    n_states  = len(transitions)
110    n_symbols = len(alphabet)
111
112    if n_states > 65536:
113        raise ValueError(f"Max states supported is 65536, got {n_states}")
114    if n_symbols > 65536:
115        raise ValueError(f"Max alphabet size supported is 65536, got {n_symbols}")
116
117    symbol_indices, state_indices = generate_cpp(
118        n_gen             = n_gen,
119        start_state_index = start_state,
120        transitions       = transitions,
121        include_states    = include_states,
122        random_seed       = random_seed,
123    )
124
125    res: dict[str, npt.NDArray] = {
126        "symbol_index": np.from_dlpack(symbol_indices)
127    }
128    if include_states:
129        res["state_index"] = np.from_dlpack(state_indices)
130
131    return res
132
133
def save_data(
    data             : dict[str, any],
    file_prefix      : str,
    alphabet         : list[str],
    n_states         : int,
    start_state      : int,
    random_seed      : int,
    machine_metadata : dict[str, any] | None = None,
) -> None:
    """Write generated data to ``<file_prefix>.parquet`` plus a JSON sidecar.

    Args:
        data: Must contain ``"symbol_index"``; may contain ``"state_index"``
            and ``"isomorphic_shifts"`` (a mapping of shift -> dict with
            ``"symbol_index"`` and optionally ``"state_index"``).
        file_prefix: Path prefix; ``.parquet`` and ``.json`` are appended.
        alphabet: Stored in the metadata; its length picks the symbol
            column width.
        n_states: Picks the state column width.
        start_state: Stored in the metadata.
        random_seed: Stored in the metadata.
        machine_metadata: Stored in the metadata (``{}`` when falsy).

    The last element of every state array is kept out of the column and
    recorded in the metadata instead (``final_state`` /
    ``isomorphic_final_states``).
    """

    def _narrowest_uint(count: int) -> pa.DataType:
        # uint8 when count <= 256, otherwise uint16.
        if count <= 256:
            return pa.uint8()
        return pa.uint16()

    symbol_dtype = _narrowest_uint(len(alphabet))
    state_dtype  = _narrowest_uint(n_states)

    columns: dict[str, pa.Array] = {}
    columns["symbol_index"] = pa.array(data["symbol_index"]).cast(symbol_dtype)

    has_states = "state_index" in data
    if has_states:
        states      = data["state_index"]
        final_state = int(states[-1])
        columns["state_index"] = pa.array(states[:-1]).cast(state_dtype)

    iso_shifts = data.get("isomorphic_shifts", {})
    # Keyed by the formatted shift (a string), matching the column suffixes.
    iso_final_states: dict[str, int] = {}

    for shift, shifted in iso_shifts.items():
        shifted_symbols = pa.array(shifted["symbol_index"])
        columns[f"symbol_index_isoshift_{shift}"] = shifted_symbols.cast(symbol_dtype)
        if "state_index" not in shifted:
            continue
        shifted_states = shifted["state_index"]
        iso_final_states[f"{shift}"] = int(shifted_states[-1])
        columns[f"state_index_isoshift_{shift}"] = pa.array(shifted_states[:-1]).cast(state_dtype)

    parquet_meta: dict[str, any] = {
        "alphabet"          : alphabet,
        "machine_metadata"  : machine_metadata or {},
        "start_state"       : start_state,
        "random_seed"       : random_seed,
        "isomorphic_shifts" : sorted(iso_shifts.keys()),
    }
    if has_states:
        parquet_meta["final_state"] = final_state
    if iso_final_states:
        parquet_meta["isomorphic_final_states"] = iso_final_states

    # Attach the metadata to the table schema under "am_metadata", keeping
    # whatever schema metadata the table already has.
    table = pa.table(columns)
    schema_meta = dict(table.schema.metadata or {})
    schema_meta["am_metadata"] = json.dumps(parquet_meta)
    table = table.replace_schema_metadata(schema_meta)

    pq.write_table(table, f"{file_prefix}.parquet")
    save_json(parquet_meta, f"{file_prefix}.json")
190
191
def minify_cpp(dfa: DFA, retain_names: bool = True) -> DFA:
    """Minimize ``dfa`` with the native ``minify_dfa_cpp`` and rebuild it.

    States and (sorted) input symbols are mapped to integer indices, the
    native routine is run on the indexed automaton, and the result is
    turned back into an instance of ``dfa``'s own class.

    Args:
        dfa: The automaton to minimize.
        retain_names: If true, each new state is the ``frozenset`` of the
            original states in its class; otherwise new states are the
            integer class ids.

    Returns:
        The rebuilt automaton, or ``empty_language(dfa.input_symbols)`` when
        the native result reports an empty language.
    """
    states = list(dfa.states)
    index_of_state = {state: pos for pos, state in enumerate(states)}
    symbols = sorted(dfa.input_symbols)
    index_of_symbol = {symbol: pos for pos, symbol in enumerate(symbols)}

    # Adjacency list of (symbol index, target index); transitions whose
    # target is not a known state are dropped.
    adjacency = [[] for _ in states]
    for source, outgoing in dfa.transitions.items():
        source_idx = index_of_state[source]
        for symbol, target in outgoing.items():
            target_idx = index_of_state.get(target, -1)
            if target_idx == -1:
                continue
            adjacency[source_idx].append((index_of_symbol[symbol], target_idx))

    accepting = [state in dfa.final_states for state in states]
    initial_idx = index_of_state[dfa.initial_state]

    # Attribute names on the result follow the C++ struct fields.
    res = minify_dfa_cpp(len(states), adjacency, initial_idx, accepting)

    if res.is_empty_language:
        return dfa.__class__.empty_language(dfa.input_symbols)

    if not retain_names:
        class_map = {cls: cls for cls in range(res.n_classes)}
    else:
        # Group original states by class id; negative ids are skipped
        # (presumably states the minimizer discarded -- confirm).
        members_by_class = {}
        for old_idx, cls in enumerate(res.eq_class):
            if nc_is_kept := (cls >= 0):
                members_by_class.setdefault(cls, []).append(states[old_idx])
        class_map = {
            cls: frozenset(members) for cls, members in members_by_class.items()
        }

    new_states = set(class_map.values())
    new_initial = class_map[res.new_initial]
    new_final = {
        class_map[cls]
        for cls, is_accepting in enumerate(res.class_is_final)
        if is_accepting
    }

    new_trans = {}
    for cls in range(res.n_classes):
        row = {
            symbols[edge[0]]: class_map[edge[1]]
            for edge in res.class_trans[cls]
        }
        new_trans[class_map[cls]] = row

    # Partial when at least one state lacks a transition for some symbol.
    symbol_count = len(symbols)
    is_partial = any(len(row) < symbol_count for row in new_trans.values())

    return dfa.__class__(
        states=new_states,
        input_symbols=dfa.input_symbols,
        transitions=new_trans,
        initial_state=new_initial,
        final_states=new_final,
        allow_partial=is_partial,
    )
def block_entropy_convergence( h_mu: float, n_states: int, n_symbols: int, convergence_tol: float, precision: float, eps: float, branches: list[tuple[float, list[float]]], trans: list[list[tuple[int, float, int]]], max_branches: int = 30000000) -> <built-in function any>:
def block_entropy_convergence(
    h_mu: float,
    n_states: int,
    n_symbols: int,
    convergence_tol: float,
    precision: float,
    eps: float,
    branches: list[tuple[float, list[float]]],
    trans: list[list[tuple[int, float, int]]],
    max_branches: int = 30_000_000,
) -> any:
    """Forward all arguments to the native ``block_entropy_convergence_cpp``.

    ``h_mu`` and ``precision`` are coerced with ``float()`` before the call;
    every other argument is passed through unchanged.  The return value is
    whatever the native routine returns.

    NOTE(review): the return annotation is the builtin ``any`` function;
    ``typing.Any`` was probably intended -- confirm before changing.
    """
    native_kwargs = dict(
        h_mu=float(h_mu),
        n_states=n_states,
        n_symbols=n_symbols,
        convergence_tol=convergence_tol,
        precision=float(precision),
        eps=eps,
        branches=branches,
        trans=trans,
        max_branches=max_branches,
    )
    return block_entropy_convergence_cpp(**native_kwargs)
def strongly_connected_components(T):
def strongly_connected_components(T):
    """Return the result of the native ``strongly_connected_components_cpp(T)``.

    ``T`` is handed to the native routine unchanged.
    """
    components = strongly_connected_components_cpp(T)
    return components
def generate_data( n_gen: int, start_state: int, transitions: list[list[tuple[int, float, int]]], alphabet: list[str], include_states: bool = False, random_seed: int = 42) -> dict[str, numpy.ndarray[tuple[typing.Any, ...], numpy.dtype[~_ScalarT]]]:
def generate_data(
    n_gen          : int,
    start_state    : int,
    transitions    : list[list[tuple[int, float, int]]],
    alphabet       : list[str],
    include_states : bool = False,
    random_seed    : int  = 42,
) -> dict[str, npt.NDArray]:
    """Generate a symbol sequence with the native ``generate_cpp`` routine.

    Args:
        n_gen: Forwarded to ``generate_cpp`` as ``n_gen``.
        start_state: Forwarded to ``generate_cpp`` as ``start_state_index``.
        transitions: Per-state transition lists; its length is the number
            of states.
        alphabet: Only its length is used here, for the size check.
        include_states: When true, the returned dict also carries the
            ``"state_index"`` array.
        random_seed: Forwarded to ``generate_cpp`` as ``random_seed``.

    Returns:
        A dict with ``"symbol_index"`` and, if ``include_states`` is set,
        ``"state_index"``; both are built with ``np.from_dlpack``.

    Raises:
        ValueError: If there are more than 65536 states or symbols.
    """
    state_count  = len(transitions)
    symbol_count = len(alphabet)

    # Reject inputs above the supported size before touching native code.
    if state_count > 65536:
        raise ValueError(f"Max states supported is 65536, got {state_count}")
    if symbol_count > 65536:
        raise ValueError(f"Max alphabet size supported is 65536, got {symbol_count}")

    native_result = generate_cpp(
        n_gen=n_gen,
        start_state_index=start_state,
        transitions=transitions,
        include_states=include_states,
        random_seed=random_seed,
    )
    symbol_indices, state_indices = native_result

    result: dict[str, npt.NDArray] = {}
    result["symbol_index"] = np.from_dlpack(symbol_indices)
    if include_states:
        result["state_index"] = np.from_dlpack(state_indices)
    return result
def save_data( data: dict[str, any], file_prefix: str, alphabet: list[str], n_states: int, start_state: int, random_seed: int, machine_metadata: dict[str, any] | None = None) -> None:
def save_data(
    data             : dict[str, any],
    file_prefix      : str,
    alphabet         : list[str],
    n_states         : int,
    start_state      : int,
    random_seed      : int,
    machine_metadata : dict[str, any] | None = None,
) -> None:
    """Write generated data to ``<file_prefix>.parquet`` plus a JSON sidecar.

    Args:
        data: Must contain ``"symbol_index"``; may contain ``"state_index"``
            and ``"isomorphic_shifts"`` (a mapping of shift -> dict with
            ``"symbol_index"`` and optionally ``"state_index"``).
        file_prefix: Path prefix; ``.parquet`` and ``.json`` are appended.
        alphabet: Stored in the metadata; its length picks the symbol
            column width.
        n_states: Picks the state column width.
        start_state: Stored in the metadata.
        random_seed: Stored in the metadata.
        machine_metadata: Stored in the metadata (``{}`` when falsy).

    The last element of every state array is kept out of the column and
    recorded in the metadata instead (``final_state`` /
    ``isomorphic_final_states``).
    """

    def _narrowest_uint(count: int) -> pa.DataType:
        # uint8 when count <= 256, otherwise uint16.
        if count <= 256:
            return pa.uint8()
        return pa.uint16()

    symbol_dtype = _narrowest_uint(len(alphabet))
    state_dtype  = _narrowest_uint(n_states)

    columns: dict[str, pa.Array] = {}
    columns["symbol_index"] = pa.array(data["symbol_index"]).cast(symbol_dtype)

    has_states = "state_index" in data
    if has_states:
        states      = data["state_index"]
        final_state = int(states[-1])
        columns["state_index"] = pa.array(states[:-1]).cast(state_dtype)

    iso_shifts = data.get("isomorphic_shifts", {})
    # Keyed by the formatted shift (a string), matching the column suffixes.
    iso_final_states: dict[str, int] = {}

    for shift, shifted in iso_shifts.items():
        shifted_symbols = pa.array(shifted["symbol_index"])
        columns[f"symbol_index_isoshift_{shift}"] = shifted_symbols.cast(symbol_dtype)
        if "state_index" not in shifted:
            continue
        shifted_states = shifted["state_index"]
        iso_final_states[f"{shift}"] = int(shifted_states[-1])
        columns[f"state_index_isoshift_{shift}"] = pa.array(shifted_states[:-1]).cast(state_dtype)

    parquet_meta: dict[str, any] = {
        "alphabet"          : alphabet,
        "machine_metadata"  : machine_metadata or {},
        "start_state"       : start_state,
        "random_seed"       : random_seed,
        "isomorphic_shifts" : sorted(iso_shifts.keys()),
    }
    if has_states:
        parquet_meta["final_state"] = final_state
    if iso_final_states:
        parquet_meta["isomorphic_final_states"] = iso_final_states

    # Attach the metadata to the table schema under "am_metadata", keeping
    # whatever schema metadata the table already has.
    table = pa.table(columns)
    schema_meta = dict(table.schema.metadata or {})
    schema_meta["am_metadata"] = json.dumps(parquet_meta)
    table = table.replace_schema_metadata(schema_meta)

    pq.write_table(table, f"{file_prefix}.parquet")
    save_json(parquet_meta, f"{file_prefix}.json")
def minify_cpp( dfa: automata.fa.dfa.DFA, retain_names: bool = True) -> automata.fa.dfa.DFA:
def minify_cpp(dfa: DFA, retain_names: bool = True) -> DFA:
    """Minimize ``dfa`` with the native ``minify_dfa_cpp`` and rebuild it.

    States and (sorted) input symbols are mapped to integer indices, the
    native routine is run on the indexed automaton, and the result is
    turned back into an instance of ``dfa``'s own class.

    Args:
        dfa: The automaton to minimize.
        retain_names: If true, each new state is the ``frozenset`` of the
            original states in its class; otherwise new states are the
            integer class ids.

    Returns:
        The rebuilt automaton, or ``empty_language(dfa.input_symbols)`` when
        the native result reports an empty language.
    """
    states = list(dfa.states)
    index_of_state = {state: pos for pos, state in enumerate(states)}
    symbols = sorted(dfa.input_symbols)
    index_of_symbol = {symbol: pos for pos, symbol in enumerate(symbols)}

    # Adjacency list of (symbol index, target index); transitions whose
    # target is not a known state are dropped.
    adjacency = [[] for _ in states]
    for source, outgoing in dfa.transitions.items():
        source_idx = index_of_state[source]
        for symbol, target in outgoing.items():
            target_idx = index_of_state.get(target, -1)
            if target_idx == -1:
                continue
            adjacency[source_idx].append((index_of_symbol[symbol], target_idx))

    accepting = [state in dfa.final_states for state in states]
    initial_idx = index_of_state[dfa.initial_state]

    # Attribute names on the result follow the C++ struct fields.
    res = minify_dfa_cpp(len(states), adjacency, initial_idx, accepting)

    if res.is_empty_language:
        return dfa.__class__.empty_language(dfa.input_symbols)

    if not retain_names:
        class_map = {cls: cls for cls in range(res.n_classes)}
    else:
        # Group original states by class id; negative ids are skipped
        # (presumably states the minimizer discarded -- confirm).
        members_by_class = {}
        for old_idx, cls in enumerate(res.eq_class):
            if cls >= 0:
                members_by_class.setdefault(cls, []).append(states[old_idx])
        class_map = {
            cls: frozenset(members) for cls, members in members_by_class.items()
        }

    new_states = set(class_map.values())
    new_initial = class_map[res.new_initial]
    new_final = {
        class_map[cls]
        for cls, is_accepting in enumerate(res.class_is_final)
        if is_accepting
    }

    new_trans = {}
    for cls in range(res.n_classes):
        row = {
            symbols[edge[0]]: class_map[edge[1]]
            for edge in res.class_trans[cls]
        }
        new_trans[class_map[cls]] = row

    # Partial when at least one state lacks a transition for some symbol.
    symbol_count = len(symbols)
    is_partial = any(len(row) < symbol_count for row in new_trans.values())

    return dfa.__class__(
        states=new_states,
        input_symbols=dfa.input_symbols,
        transitions=new_trans,
        initial_state=new_initial,
        final_states=new_final,
        allow_partial=is_partial,
    )