amachine.am_hmm
1from __future__ import annotations 2 3from abc import abstractmethod 4from typing import override 5import gc 6import copy 7from collections import deque 8from collections import defaultdict 9import json 10from fractions import Fraction 11from pathlib import Path 12import warnings 13import time 14from dataclasses import dataclass, asdict, field 15 16import networkx as nx 17 18from automata.fa.dfa import DFA 19 20import sympy 21import numpy as np 22 23from .am_machine import Machine 24from .am_symbol import Symbol 25 26from .am_msp import MSP, compute_msp, compute_msp_exact 27from .am_solve import solve_for_pi, solve_for_pi_fractional 28from . import am_vis 29from . import am_fast 30 31from .am_causal_state import CausalState 32from .am_transition import Transition 33 34from .am_fast.distance import jensenshannondivergence_gpu as af_jensenshannondivergence 35from .am_fast.distance import jensenshannondivergence_cpu as af_jensenshannondivergence_cpu 36 37class HMM(Machine) : 38 39 """Hidden Markov model implementing epsilon machines, mixed state presentations, 40 complexity measures, and data generation. 41 42 Args: 43 states (list[CausalState] | None): A list of causal states. 44 transitions (list[Transition] | None): A list of transitions between states. 45 start_state (int): Index of the start state. 46 alphabet (list[str]): List of symbols making up the alphabet. 47 name (str): Name of the model. 48 description (str): Description of the model. 
49 """ 50 51 def __init__( 52 self, 53 states : list[CausalState] | None = None, 54 transitions : list[Transition] | None = None, 55 start_state : int = 0, 56 alphabet : list[str] | None = None, 57 name : str = "", 58 description : str = "" ) : 59 60 self.alphabet : list[str] | None = alphabet or [] 61 self.states : list[CausalState] | None= states or [] 62 self.transitions : list[Transition] | None= transitions or [] 63 64 self.set_alphabet( self.alphabet ) 65 self.set_states( self.states ) 66 self.set_transitions( self.transitions ) 67 68 self.start_state : int = start_state 69 70 self.name : str = name 71 self.description : str = description 72 73 # To be depreciated 74 self.isoclass : str = None 75 76 # --- derived -------- 77 78 self.complexity : dict[str, any] = {} 79 80 self.symbol_idx_map : dict[str,int] = {} 81 self.state_idx_map : dict[str,int] = {} 82 83 self.pi_fractional = None 84 self.pi : np.ndarray | None = None 85 self.T : np.ndarray | None = None 86 self.msp : MSP | None = None 87 self.reverse_am : HMM | None = None 88 self.is_q_weighted : bool = False 89 self.is_minimal : bool = False 90 91 # --- const ---------- 92 93 self.EPS : float = 1e-12 94 95 #-------------------------------------------------------# 96 # Overrides / Getters # 97 #-------------------------------------------------------# 98 99 @override 100 def get_states(self) -> list[CausalState] : 101 return self.states 102 103 @override 104 def get_transitions(self) -> list[Transition] : 105 return self.transitions 106 107 @override 108 def get_alphabet(self) -> list[Symbol]: 109 return [ Symbol(a) for a in self.alphabet ] 110 111 #-------------------------------------------------------# 112 # Serialization # 113 #-------------------------------------------------------# 114 115 116 def to_dict(self) -> dict[str,any]: 117 118 """Create a dict representing the HMM configuration. 
119 120 Returns: 121 122 dict[str,any]: Dictionary containing, name, description, states, transitions, alphabet, and isoclass. 123 """ 124 125 return { 126 "name" : self.name, 127 "description" : self.description, 128 "states" : [ asdict(state) for state in self.states ], 129 "transitions" : [ asdict(transition) for transition in self.transitions ], 130 "alphabet" : self.alphabet, 131 "isoclass" : self.isoclass 132 } 133 134 def from_dict( self, config : dict[str,any] ) : 135 136 """Configure the HMM configuration from a dictionary. 137 138 Args: 139 config (dict[str,any]): The HMM configuration. 140 """ 141 142 self.name = config.name 143 self.description = config.description 144 145 self.set_states( states=[ 146 CausalState( 147 name=state[ "name" ], 148 classes=set( state[ "classes" ] ) 149 ) 150 for state in config.states 151 ] ) 152 153 self.set_transitions( transitions=[ 154 Transition( 155 origin_state_idx=tr[ "origin_state_idx" ], 156 target_state_idx=tr[ "target_state_idx" ], 157 prob=tr[ "prob" ], 158 symbol_idx=tr[ "symbol_idx" ] 159 ) 160 for tr in config.transitions 161 ] ) 162 163 self.set_alphabet( alphabet=config.alphabet ) 164 165 def save_config( 166 self, 167 path : Path, 168 with_complexity : bool = False, 169 with_block_convergence : bool = False, 170 with_structural_properties : bool = False, 171 with_causal_properties : bool = False ) : 172 173 config = self.to_dict() 174 175 if with_complexity : 176 177 complexity = self.get_complexities( 178 with_block_convergence=with_block_convergence 179 ) 180 181 config[ "complexity" ] = complexity 182 183 config[ "structural_properties" ] = { 184 "unifilar" : self.is_unifilar(), 185 "row_stochastic" : self.is_row_stochastic(), 186 "strongly_connected" : self.is_strongly_connected(), 187 "aperiodic" : self.is_aperiodic(), 188 "minimal" : self._is_minimal_as_dfa( topological_only=False ), 189 "topologically_minimal" : self._is_minimal_as_dfa( topological_only=True ), 190 "is_epsilon_machine" : 
self.is_epsilon_machine() 191 } 192 193 with open( path / "am_config.json", "w", encoding="utf-8" ) as f : 194 json.dump( config, f, ensure_ascii=False, indent=2, default=list ) 195 196 def from_file( self, path : Path ) : 197 with open( Path / "am_config.json", "r" ) as f: 198 config = json.load(f) 199 self.from_dict() 200 201 #-------------------------------------------------------# 202 # Setters and State Management # 203 #-------------------------------------------------------# 204 205 def _invalidate(self) : 206 207 """ 208 Reset all derived properties so they will be recomputed when requested later. 209 """ 210 211 self.complexity = {} 212 self.T = None 213 self.T_x = None 214 self.pi = None 215 self.pi_fractional = None 216 self.msp = None 217 self.reverse_am = None 218 self.is_q_weighted = False 219 self.is_minimal = False 220 221 def set_states( self, states : list[CausalState] ) : 222 self._invalidate() 223 self.states = states.copy() 224 self.state_idx_map = {} 225 for idx, state in enumerate( self.states ) : 226 self.state_idx_map[ state.name ] = idx 227 228 def set_alphabet( self, alphabet : list[str] ) : 229 230 self._invalidate() 231 232 old_alphabet = self.alphabet.copy() 233 234 aSet = set() 235 aSet.update( alphabet ) 236 237 self.alphabet = sorted(list(aSet)) 238 239 self.symbol_idx_map = {} 240 for idx, symbol in enumerate( self.alphabet ) : 241 self.symbol_idx_map[ symbol ] = idx 242 243 for i, tr in enumerate( self.transitions ) : 244 symbol = old_alphabet[ tr.symbol_idx ] 245 self.transitions[ i ] = Transition( 246 origin_state_idx=tr.origin_state_idx, 247 target_state_idx=tr.target_state_idx, 248 prob=tr.prob, 249 symbol_idx=self.symbol_idx_map[ symbol ] 250 ) 251 252 def set_transitions( self, transitions : list[Transition] ) : 253 self._invalidate() 254 self.transitions = transitions.copy() 255 256 def extend_states( self, states : list[CausalState] ) : 257 self.set_states( self.states + states ) 258 259 def extend_alphabet( self, alphabet 
: list[str] ) : 260 self.set_alphabet( self.alphabet + alphabet ) 261 262 def extend_transitions( self, transitions : list[Transition] ) : 263 self.set_transitions( self.transitions + transitions ) 264 265 def get_complexity_measure_if_exists(self, measure ) : 266 m = self.complexity.get( measure, None ) 267 return m 268 269 def set_complexity_measure(self, measure, value ) : 270 self.complexity[ measure ] = value 271 272 #-------------------------------------------------------# 273 # Get and Compute # 274 #-------------------------------------------------------# 275 276 def get_complexities( 277 self, 278 with_block_convergence=False ) : 279 280 directly_calculable = [ 281 self.C_mu, 282 self.h_mu, 283 self.H_1, 284 self.rho_mu 285 ] 286 287 requires_block_convergence = [ 288 self.E, 289 self.T_inf, 290 self.S, 291 self.chi 292 ] 293 294 complexities = { m.__name__ : m() for m in directly_calculable } 295 296 if with_block_convergence : 297 298 complexities |= { m.__name__ : m() for m in requires_block_convergence } 299 300 for key in [ 'H_L', 'T_L', 'h_mu_L', 'H_sync' ] : 301 if key in self.complexity : 302 complexities[ key ] = self.complexity[ key ] 303 304 return complexities 305 306 #-------------------------------------------------------# 307 308 def get_metadata(self) : 309 return { 310 "name" : self.name, 311 'complexity' : self.complexity, 312 "description" : self.description 313 } 314 315 def get_transition_matrix(self) : 316 317 if self.T is not None : 318 return self.T 319 320 n_states = len( self.states ) 321 T = np.zeros((n_states, n_states)) 322 323 for tr in self.transitions : 324 T[ tr.origin_state_idx, tr.target_state_idx ] = tr.prob 325 326 self.T = T 327 328 return self.T 329 330 #-------------------------------------------------------# 331 332 def get_T_X(self) : 333 334 if self.T_x is not None : 335 return self.T_x 336 337 n_states = len( self.states ) 338 n_symbols = len( self.alphabet ) 339 340 T_x = [ np.zeros((n_states, n_states)) for _ 
in range( n_symbols ) ] 341 342 for tr in self.transitions : 343 T_x[ tr.symbol_idx ][tr.origin_state_idx, tr.target_state_idx] = tr.prob 344 345 self.T_x = T_x 346 return self.T_x 347 348 #-------------------------------------------------------# 349 350 def get_msp_qw( 351 self, 352 exact_state_cap: int = 1000, 353 verbose: bool = True, 354 ): 355 if self.msp is not None: 356 return self.msp 357 358 try : 359 360 print( "\nTrying to Compute Mixed State Presentation using Exact Fractions\n" ) 361 362 self.msp = compute_msp_exact( 363 T_x=self.get_Tx_fractional(), 364 pi=self.get_fractional_stationary_distribution(), 365 n_states=len(self.states), 366 alphabet=self.alphabet, 367 exact_state_cap=1000, 368 bool = True 369 ) 370 371 return self.msp 372 373 except RuntimeError as e : 374 warnings.warn( f"Exact msp failed: {e} Falling back to msp approximation." ) 375 376 return self.get_msp() 377 378 def get_msp( 379 self, 380 exact_state_cap: int = 175_000, 381 jsd_eps: float = 1e-7, 382 k_ann: int = 50, 383 verbose = True, 384 ) : 385 386 if self.msp is not None: 387 return self.msp 388 389 T_x = self.get_T_X() 390 pi = self.get_stationary_distribution() 391 392 T_stacked = np.stack(T_x) 393 n_symbols = len(self.alphabet) 394 n_input_states = T_stacked.shape[1] 395 396 print( "\nComputing Mixed State Presentation..." 
) 397 398 self.msp = compute_msp( 399 T_x=T_x, 400 pi=pi, 401 n_states=len(self.states), 402 alphabet=self.alphabet, 403 exact_state_cap=exact_state_cap, 404 verbose=verbose 405 ) 406 407 return self.msp 408 409 def get_reverse_am(self) : 410 411 if self.reverse_am is not None: 412 return self.reverse_am 413 414 pi = self.get_stationary_distribution() 415 self.reverse_am = copy.deepcopy(self) 416 417 new_transitions = [] 418 for tr in self.transitions: 419 i = tr.target_state_idx 420 j = tr.origin_state_idx 421 422 p_reversed = (pi[j] * tr.prob) / pi[i] 423 424 new_transitions.append( 425 Transition( 426 origin_state_idx=i, 427 target_state_idx=j, 428 prob=p_reversed, 429 symbol_idx=tr.symbol_idx 430 ) 431 ) 432 433 self.reverse_am.set_transitions(new_transitions) 434 435 if self.reverse_am.is_epsilon_machine(): 436 return self.reverse_am 437 438 rmsp = self.reverse_am.get_msp_qw( exact_state_cap=len(self.states)*4 ) 439 440 self.reverse_am.set_states( rmsp.states ) 441 self.reverse_am.set_transitions( rmsp.transitions ) 442 self.reverse_am.msp = rmsp 443 self.reverse_am.start_state = 0 444 445 self.reverse_am.collapse_to_largest_strongly_connected_subgraph() 446 self.reverse_am.minimize() 447 448 return self.reverse_am 449 450 #-------------------------------------------------------# 451 452 def get_Tx_fractional(self) -> list[ list[ list[ Fraction ] ] ] : 453 454 self.to_q_weighted() 455 456 n_states = len( self.states ) 457 n_symbols = len( self.alphabet ) 458 459 T_x = [] 460 461 for x in range( n_symbols ) : 462 T_x.append( [] ) 463 for i in range( n_states ) : 464 T_x[ x ].append( [ 0 for _ in range( n_states ) ] ) 465 466 for tr in self.transitions : 467 T_x[ tr.symbol_idx ][ tr.origin_state_idx ][ tr.target_state_idx ] = tr.pq 468 469 return T_x 470 471 def get_T_sympy( self ) : 472 473 self.to_q_weighted() 474 475 n = len( self.states ) 476 T = sympy.zeros( n, n ) 477 478 for tr in self.transitions : 479 T[ tr.origin_state_idx, tr.target_state_idx ] = 
tr.pq 480 481 return T 482 483 def get_fractional_stationary_distribution(self) : 484 485 T = self.get_T_sympy() 486 487 if self.pi_fractional is not None : 488 return self.pi_fractional 489 490 G = self.as_digraph() 491 492 if not nx.is_strongly_connected(G): 493 raise ValueError( "Single stationary distribution requires strongly connected HMM." ) 494 495 self.pi_fractional = solve_for_pi_fractional( T ) 496 497 return self.pi_fractional 498 499 def get_stationary_distribution(self): 500 501 if self.pi is not None : 502 return self.pi 503 504 G = self.as_digraph() 505 506 if not nx.is_strongly_connected(G): 507 raise ValueError( "Single stationary distribution requires strongly connected HMM." ) 508 509 T = self.get_transition_matrix() 510 return solve_for_pi( T ) 511 512 #-------------------------------------------------------# 513 # Complexity Measures # 514 #-------------------------------------------------------# 515 516 def C_mu( self ) : 517 518 """The *statistical complexity* (aka *forecasting complexity*) : 519 520 .. math:: 521 522 C_{\\mu} = - \\sum_{\\sigma \\in \\mathcal{S}} \\Pr(\\sigma) \\log_2 \\Pr(\\sigma), 523 524 where :math:`\\mathcal{S}` are the machine's states [^crutchfield_exact_2016], p.2. 525 526 .. note:: 527 528 **Interpretations** 529 530 * The amount of historical information a process stores. 531 * The amount of structure in a process. 532 533 Returns: 534 535 float: :math:`C_{\\mu}`. 536 537 [^crutchfield_exact_2016]: Crutchfield et al., *Exact Complexity: The Spectral 538 Decomposition of Intrinsic Computation*, 2016. 
539 <https://arxiv.org/abs/1309.3792> 540 """ 541 542 m = self.get_complexity_measure_if_exists( "C_mu" ) 543 544 if m is not None : 545 return m 546 547 pi = self.get_stationary_distribution() 548 549 h = 0 550 for i, pr in enumerate( pi ) : 551 552 if pr < self.EPS : 553 continue 554 555 h += -pr * np.log2( pr ) 556 557 self.set_complexity_measure( "C_mu", h ) 558 559 return h 560 561 #-------------------------------------------------------# 562 563 def h_mu( self ) : 564 565 """The *entropy rate* : 566 567 .. math:: 568 569 h_{\\mu}(\\boldsymbol{\\mathcal{S}}) = - \\sum_{\\sigma \\in \\mathcal{S}} \\Pr(\\sigma) \\sum_{x \\in \\mathcal{A}} \\Pr(x|\\sigma) \\log_2 \\Pr(x|\\sigma), 570 571 where :math:`\\mathcal{A}` is the alphabet and :math:`\\mathcal{S}` are the machine's states [^crutchfield_exact_2016], p.2. 572 573 .. note:: 574 575 **Interpretations** 576 577 * The lower bound on achievable loss in bits. 578 * The irreducable randomness in the process. 579 * The intrinsic Randomness in the process. 580 581 Returns: 582 583 float: :math:`h_{\\mu}`. 584 585 [^crutchfield_exact_2016]: Crutchfield et al., *Exact Complexity: The Spectral 586 Decomposition of Intrinsic Computation*, 2016. 587 <https://arxiv.org/abs/1309.3792> 588 """ 589 590 m = self.get_complexity_measure_if_exists( "h_mu" ) 591 592 if m is not None : 593 return m 594 595 T = self.get_transition_matrix() 596 pi = self.get_stationary_distribution() 597 598 n_states = pi.size 599 600 h = 0 601 for i, pr in enumerate( pi ) : 602 603 if pr < self.EPS : 604 continue 605 606 row_entropy = 0 607 for j in range( len( pi ) ) : 608 609 if T[ i, j ] < self.EPS : 610 continue 611 612 row_entropy -= T[ i, j ] * np.log2( T[ i, j ] ) 613 614 h += pr * row_entropy 615 616 self.set_complexity_measure( "h_mu", h ) 617 618 return h 619 620 #-------------------------------------------------------# 621 622 def H_1(self) -> float : 623 624 """The *single symbol uncertainty*: 625 626 .. 
math:: 627 628 H(1)=-\\sum_{x\\in\\mathcal{A}} \\Pr(x) \\log_2{\\Pr(x)}, 629 630 where :math:`\\mathcal{A}` is the alphabet [^James_2018], p.2. 631 632 .. note:: 633 634 **Interpretations** 635 636 * How uncertain you are on average about a single measurement with no context. 637 638 Returns: 639 640 float: :math:`H(1)`. 641 642 [^James_2018]: James et al., Anatomy of a Bit: Information in a Time Series Observation, 2018. 643 <https://arxiv.org/abs/1105.2988> 644 """ 645 646 m = self.get_complexity_measure_if_exists("H_1") 647 if m is not None: 648 return m 649 650 pi = self.get_stationary_distribution() 651 T_X = self.get_T_X() # dict: symbol -> matrix 652 653 h = 0.0 654 for T_x in T_X: 655 # Pr(x) = sum_i pi[i] * sum_j T^(x)[i,j] 656 p_sym = 0.0 657 for i, pr in enumerate(pi): 658 if pr < self.EPS: 659 continue 660 p_sym += pr * T_x[i, :].sum() 661 662 if p_sym < self.EPS: 663 continue 664 h -= p_sym * np.log2(p_sym) 665 666 self.set_complexity_measure("H_1", h) 667 return h 668 669 #-------------------------------------------------------# 670 671 def rho_mu(self) -> float : 672 673 """The *anticipated information* [^James_2018], p.3.: 674 675 .. math:: 676 677 \\rho_{\\mu}= H(1) - h_{\\mu} 678 679 Returns: 680 681 float: :math:`\\rho_{\\mu}` 682 683 [^James_2018]: James et al., Anatomy of a Bit: Information in a Time Series Observation, 2018. 684 <https://arxiv.org/abs/1105.2988> 685 """ 686 687 m = self.get_complexity_measure_if_exists("rho_mu") 688 689 if m is not None: 690 return m 691 692 rho = self.H_1() - self.h_mu() 693 694 self.set_complexity_measure("rho_mu", rho) 695 696 return rho 697 698 #-------------------------------------------------------# 699 700 def block_convergence( self ) : 701 702 """ 703 "Run block entropy convergence and return a dict with $\\mathbf{E}, \\mathbf{S}, $\\mathbf{T},\\mathbf{T}(L), \\mathcal{H}(L),$ and $h_{\\mu}(L)$. 
704 """ 705 706 trs = [ [] for _ in range( len( self.states ) ) ] 707 for tr in self.transitions : 708 trs[ tr.origin_state_idx ].append( ( 709 tr.symbol_idx, 710 float( tr.prob ), 711 tr.target_state_idx ) ) 712 713 pi = self.get_stationary_distribution() 714 715 state_dist = [ float( pi[ i ] ) for i in range( len( self.states ) ) ] 716 branches = [(1.0, list(state_dist))] 717 718 print( "\nComputing Block Entropy\n" ) 719 720 C = am_fast.block_entropy_convergence( 721 h_mu = self.h_mu(), 722 n_states = len( self.states ), 723 n_symbols = len( self.alphabet ), 724 convergence_tol = 1e-6, 725 precision = 10, 726 eps = 1e-25, 727 branches = branches, 728 trans = trs, 729 max_branches = 30_000_000 730 ) 731 732 print( "Done\n" ) 733 734 self.set_complexity_measure( f"E", C.E ) 735 self.set_complexity_measure( f"S", C.S ) 736 self.set_complexity_measure( f"T_inf", C.T ) 737 self.set_complexity_measure( f"T_L", C.T_L.tolist() ) 738 self.set_complexity_measure( f"H_L", C.H_L.tolist() ) 739 self.set_complexity_measure( f"h_mu_L", C.h_mu_L.tolist() ) 740 self.set_complexity_measure( f"H_sync", C.H_sync.tolist() ) 741 742 return C 743 744 #-------------------------------------------------------# 745 746 def E( self ) -> float : 747 748 """The *excess entropy* [^crutchfield_exact_2016], p.4: 749 750 .. math:: 751 752 \\mathbf{E} \\equiv \\sum_{L=1}^{\\infty} I[X_{-\\infty:0}; X_{0:\\infty}] 753 754 Computed via :meth:`get_msp` and :meth:`amachine.am_msp.MSP.get_E_S_T`, or :meth:`amachine.am_fast.block_entropy_convergence` 755 756 .. note:: 757 758 **Interpretations** 759 760 * The information from the past that reduces uncertainty in the future [^crutchfield_exact_2016]. 761 * How much information an observer must extract to synchronize to the process. 762 * Measures how long the process appears more complex than it asymptotically is. 763 * Vanishes for immediately synchronizable processes. 
764 765 Returns: 766 767 float: :math:`\\mathbf{E}` 768 769 [^crutchfield_exact_2016]: Crutchfield et al., *Exact Complexity: The Spectral 770 Decomposition of Intrinsic Computation*, 2016. 771 <https://arxiv.org/abs/1309.3792> 772 """ 773 774 m = self.get_complexity_measure_if_exists( "E" ) 775 776 if m is not None : 777 return m 778 779 try : 780 msp = self.get_msp() 781 E, S, T = msp.get_E_S_T() 782 self.set_complexity_measure( "E", E ) 783 self.set_complexity_measure( "S", S ) 784 self.set_complexity_measure( "T_inf", T ) 785 786 except Exception as e : 787 788 print( f"MSP failed {e}" ) 789 790 C = self.block_convergence() 791 E = C.E 792 self.set_complexity_measure( "E", E ) 793 794 return E 795 796 #-------------------------------------------------------# 797 798 def S( self ) -> float : 799 800 """The *synchronization* information: 801 802 .. math:: 803 804 \\mathbf{S} \\equiv \\sum_{L=1}^{\\infty} \\mathcal{H}(L), 805 806 where :math:`\\mathcal{H}(L)` is the average state uncertainty having seen all length-L words [^crutchfield_exact_2016], p.4. 807 808 .. note:: 809 810 **Interpretations** 811 812 * The total amount of state information that an observer must extract to become synchronized [^crutchfield_exact_2016]. 813 814 Computed via :meth:`get_msp` and :meth:`amachine.am_msp.MSP.get_E_S_T`, or :meth:`amachine.am_fast.block_entropy_convergence` 815 816 Returns: 817 818 float: :math:`\\mathbf{S}` 819 820 [^crutchfield_exact_2016]: Crutchfield et al., *Exact Complexity: The Spectral 821 Decomposition of Intrinsic Computation*, 2016. 
822 <https://arxiv.org/abs/1309.3792> 823 """ 824 825 m = self.get_complexity_measure_if_exists( "S" ) 826 827 if m is not None : 828 return m 829 830 try : 831 msp = self.get_msp() 832 E, S, T = msp.get_E_S_T() 833 self.set_complexity_measure( "E", E ) 834 self.set_complexity_measure( "S", S ) 835 self.set_complexity_measure( "T_inf", T ) 836 837 except Exception as e : 838 print( f"{e} \nFalling back to iterative estimation.") 839 exit() 840 841 C = self.block_convergence() 842 S = C.S 843 self.set_complexity_measure( "S", S ) 844 845 return S 846 847 #-------------------------------------------------------# 848 849 def T_inf( self ) -> float : 850 851 """The *transient information*[^crutchfield_exact_2016], p.4: 852 853 .. math:: 854 855 \\mathbf{T} \\equiv \\sum_{L=1}^{\\infty} L \\left[ h_{\\mu}(L) - h_{\\mu} \\right] 856 857 Computed via :meth:`get_msp` and :meth:`amachine.am_msp.MSP.get_E_S_T`, or :meth:`amachine.am_fast.block_entropy_convergence` 858 859 .. note:: 860 861 **Interpretations** 862 863 * The amount of information one must extract from observations so that the block entropy converges to its linear asymptote[^crutchfield_exact_2016]. 864 865 Returns: 866 867 float: :math:`\\mathbf{T}` 868 869 [^crutchfield_exact_2016]: Crutchfield et al., *Exact Complexity: The Spectral 870 Decomposition of Intrinsic Computation*, 2016. 
871 <https://arxiv.org/abs/1309.3792> 872 """ 873 874 m = self.get_complexity_measure_if_exists( "T_inf" ) 875 876 if m is not None : 877 return m 878 879 try : 880 msp = self.get_msp() 881 E, S, T = msp.get_E_S_T() 882 self.set_complexity_measure( "E", E ) 883 self.set_complexity_measure( "S", S ) 884 self.set_complexity_measure( "T_inf", T ) 885 886 except Exception as e : 887 print( f"{e} \nFalling back to iterative estimation.") 888 C = self.block_convergence() 889 T_inf = C.T 890 self.set_complexity_measure( "T_inf", T_inf ) 891 892 return T_inf 893 894 #-------------------------------------------------------# 895 896 def chi( self ) -> float : 897 898 """Computes the foward crypticity[^crutchfield_crypticity_2009][^Mahoney_crypticity_2021], p.2: 899 900 .. math:: 901 902 \\chi = C_{\\mu} - \\mathbf{E} 903 904 :math:`C_{\\mu}` is trivially computed from the stationary distribution in :meth:`C_mu` and :math:`\\mathbf{E}` in :meth:`E`. 905 906 .. note:: 907 908 **Interpretations** 909 910 * Difference between internal stored information and apparent information to an observer. 911 * How muching information is hiding in the system. 912 913 Returns: 914 915 float: :math:`\\chi` 916 917 [^crutchfield_crypticity_2009]: Crutchfield et al., Time’s barbed arrow: Irreversibility, crypticity, and stored information, 2009. 918 <https://arxiv.org/abs/0902.1209> 919 920 [^Mahoney_crypticity_2021]: Mahoney et al., Information Accessibility and Cryptic Processes, 2021. 
921 <https://arxiv.org/abs/0905.4787> 922 """ 923 924 m = self.get_complexity_measure_if_exists( "chi" ) 925 926 if m is not None : 927 return m 928 929 chi = self.C_mu() - self.E() 930 931 if chi < 0 : 932 933 # if chi is 0, accumulated floating point error can result in small negative values 934 if chi < -1e-5: 935 warnings.warn(f"Crypticity is negative ({chi:.6e}).") 936 937 chi = np.clamp( chi, 0 ) 938 939 self.set_complexity_measure( "chi", chi ) 940 941 return chi 942 943 #-------------------------------------------------------# 944 # Properties # 945 #-------------------------------------------------------# 946 947 def is_row_stochastic(self) : 948 949 """ 950 Check that all states have outgoing transition probabilities that sum to 1. 951 """ 952 953 sums = np.zeros( len( self.states ) ) 954 for tr in self.transitions : 955 sums[ tr.origin_state_idx ] += tr.prob 956 return np.allclose( sums, 1.0 ) 957 958 #-------------------------------------------------------# 959 960 def is_unifilar(self) : 961 962 """ 963 Check that no state emits the same symbol on transitions to different states. 964 """ 965 966 symbol_trs = np.full( ( len( self.states ), len( self.alphabet) ), -1 ) 967 968 for tr in self.transitions : 969 970 if symbol_trs[ tr.origin_state_idx, tr.symbol_idx ] == -1 : 971 symbol_trs[ tr.origin_state_idx, tr.symbol_idx ] = tr.target_state_idx 972 973 elif symbol_trs[ tr.origin_state_idx, tr.symbol_idx ] != tr.target_state_idx : 974 return False 975 976 return True 977 978 #-------------------------------------------------------# 979 980 def is_strongly_connected(self) : 981 982 """ 983 Check if every state is reachable from every other state. Relies on [nx.is_strongly_connected](https://networkx.org/documentation/latest/reference/algorithms/generated/networkx.algorithms.components.is_strongly_connected.html). 
984 """ 985 986 return nx.is_strongly_connected( self.as_digraph() ) 987 988 #-------------------------------------------------------# 989 990 def is_aperiodic(self) : 991 992 """ 993 Checks if machine is periodic. Relies on [nx.is_aperiodic](https://networkx.org/documentation/latest/reference/algorithms/generated/networkx.algorithms.dag.is_aperiodic.html), "A strongly connected directed graph is aperiodic if there is no integer k > 1 that divides the length of every cycle in the graph." 994 """ 995 996 return nx.is_aperiodic( self.as_digraph() ) 997 998 #-------------------------------------------------------# 999 1000 def _is_minimal_as_dfa( self, topological_only : bool, verbose=True ) : 1001 1002 with_probs = not topological_only 1003 1004 # Construct the DFA 1005 dfa = self.as_dfa( with_probs=with_probs ) 1006 1007 # Minimize the DFA 1008 #dfa = dfa.minify(retain_names=True) 1009 dfa = am_fast.minify_cpp( dfa, retain_names=True ) 1010 1011 # check we have minimal number of states 1012 if len( dfa.states ) != len( self.states ) : 1013 if verbose : 1014 print( f"Not minimal reduces from {len( self.states )} to {len( dfa.states )} states" ) 1015 return False 1016 1017 return True 1018 1019 def is_topological_epsilon_machine( self, verbose=True ) : 1020 1021 """ 1022 Checks if the HMM is a topological $\\epsilon$-machine [^1]. 1023 1024 [^1]: Johnson et al, Enumerating Finitary Processes, 2024. 
1025 <https://arxiv.org/abs/1011.0036> 1026 """ 1027 1028 if not ( self.is_unifilar() and self.is_strongly_connected() ) : 1029 if verbose : 1030 print( f"Either non unifilar or not strongly connected" ) 1031 return False 1032 else : 1033 return self._is_minimal_as_dfa( topological_only=True, verbose=verbose ) 1034 1035 def is_epsilon_machine( self, verbose=True ) : 1036 1037 if not ( self.is_unifilar() and self.is_strongly_connected() ) : 1038 if verbose : 1039 print( f"Either non unifilar or not strongly connected" ) 1040 return False 1041 else : 1042 return self._is_minimal_as_dfa( topological_only=False, verbose=verbose ) 1043 1044 #-------------------------------------------------------# 1045 # Properties # 1046 #-------------------------------------------------------# 1047 1048 def minimize(self, retain_names: bool = True, verbose=False): 1049 1050 """ 1051 Minimizes the HMM, resulting in an :math:`\\epsilon-`machine if the HMM 1052 is unifilar and strongly connected. Converts the HMM to a DFA with symbols 1053 labeled jointly with symbols and probabilities, and uses Myhill-Nerode 1054 equivalence for minimization. Relies on `automata_lib` and uses 1055 `automata.fa.dfa.DFA.minify` with `allow_partial=True`, and all states 1056 final. 1057 1058 Args: 1059 retain_names (bool): If `True`, the merged states will be named by their union, e.g. `{s_0, s_1}`, and other states will retain their origion names. Otherwise, they will be relabled `{ '0', '1', ..., 'n-1' }`. 1060 1061 Returns: 1062 1063 automata.fa.dfa.DFA : the resulting DFA. 
1064 """ 1065 1066 if self.is_minimal : 1067 return 1068 1069 start = time.perf_counter() 1070 1071 if not self.is_unifilar(): 1072 raise ValueError( 1073 "DFA minimization is not valid for non-unifilar HMMs" 1074 ) 1075 1076 was_strongly_connected = self.is_strongly_connected() 1077 1078 was_row_stochastic = self.is_row_stochastic() 1079 n_states_before = len(self.states) 1080 1081 dfa = self.as_dfa(with_probs=True) 1082 1083 #min_dfa = self.as_dfa(with_probs=True).minify(retain_names=True) 1084 min_dfa = am_fast.minify_cpp( dfa, retain_names=True ) 1085 1086 # Build lookup from original state index -> CausalState object 1087 orig_state = {i: s for i, s in enumerate(self.states)} 1088 eq_list = list(min_dfa.states) 1089 1090 start_eq = min_dfa.initial_state 1091 1092 # Separate the start state, then sort the rest by the 1093 # smallest original state index inside each equivalence class. 1094 other_eqs = [eq for eq in eq_list if eq != start_eq] 1095 other_eqs.sort(key=lambda eq: min(eq)) 1096 1097 # Recombine so start eq comes first, followed by the sorted remaining classes 1098 eq_list = [start_eq] + other_eqs 1099 # ---------------------------------------------------------- 1100 1101 # Recompute eq_to_idx with the new ordering 1102 eq_to_idx = {eq: i for i, eq in enumerate(eq_list)} 1103 1104 # new_start is now guaranteed to be 0 1105 new_start = 0 1106 1107 # Map each original state index -> its equivalence class 1108 # Guard: minify() silently drops unreachable states 1109 orig_to_eq = {s: eq for eq in min_dfa.states for s in eq} 1110 1111 # Build lookup from original state index -> its transitions 1112 orig_trs = defaultdict(list) 1113 for t in self.transitions: 1114 orig_trs[t.origin_state_idx].append(t) 1115 1116 new_trs = [] 1117 for eq in min_dfa.states: 1118 rep = next(iter(eq)) 1119 origin_idx = eq_to_idx[eq] 1120 for t in orig_trs[rep]: 1121 target_eq = orig_to_eq[t.target_state_idx] 1122 target_idx = eq_to_idx[target_eq] 1123 new_trs.append(Transition( 
1124 origin_state_idx = origin_idx, 1125 target_state_idx = target_idx, 1126 prob = t.prob, 1127 symbol_idx = t.symbol_idx, 1128 )) 1129 1130 members_list = [[orig_state[i] for i in sorted(eq)] for eq in eq_list] # sorted for determinism 1131 1132 # Compute new names 1133 if retain_names: 1134 new_names = [ 1135 "{" + ",".join(str(m.name) for m in members) + "}" if len(members) > 1 1136 else members[0].name 1137 for members in members_list 1138 ] 1139 else: 1140 new_names = [str(j) for j in range(len(eq_list))] 1141 1142 old_name_to_new_name = { 1143 m.name: new_names[j] 1144 for j, members in enumerate(members_list) 1145 for m in members 1146 } 1147 1148 # Build the new states, preserving classes and isomorphs regardless of naming 1149 new_states = [] 1150 for j, (eq, members, name) in enumerate(zip(eq_list, members_list, new_names)): 1151 classes = set().union(*(m.classes for m in members)) 1152 isomorphs = { 1153 old_name_to_new_name.get(iso, iso) 1154 for m in members 1155 for iso in m.isomorphs 1156 if old_name_to_new_name.get(iso, iso) != name 1157 } 1158 new_states.append(CausalState( 1159 name = name, 1160 classes = classes, 1161 isomorphs = isomorphs, 1162 )) 1163 1164 self.set_states(new_states) 1165 self.set_transitions(new_trs) 1166 self.start_state = new_start 1167 1168 if n_states_before == len(new_states) and verbose : 1169 print( f"{n_states_before} state HMM was already minimal.\n" ) 1170 elif verbose : 1171 print( f"Minimized from {n_states_before} to {len(new_states)}\n" ) 1172 1173 if not ( was_strongly_connected == self.is_strongly_connected() ) : 1174 raise RuntimeError( 1175 f"Minimization broke strongly connected" 1176 ) 1177 1178 if not ( was_row_stochastic == self.is_row_stochastic() ) : 1179 raise RuntimeError( 1180 f"Minimization broke row stochasticity" 1181 ) 1182 1183 self.is_minimal = True 1184 1185 #-------------------------------------------------------# 1186 # Modifiers # 1187 
#-------------------------------------------------------# 1188 1189 1190 def collapse_to_largest_strongly_connected_subgraph( self, rename_states=True ) : 1191 1192 # get equivalent networkx graph 1193 G = self.as_digraph() 1194 1195 # if already strongly connected, nothing to do 1196 if not nx.is_strongly_connected( G ) : 1197 1198 start = time.perf_counter() 1199 subgraph_nodes = list( nx.strongly_connected_components( G ) ) 1200 1201 # decompose into strongly connected components and sort by length 1202 # subgraph_nodes = list(nx.strongly_connected_components( G )) 1203 subgraph_nodes.sort(key=len) 1204 component_state_set = subgraph_nodes[-1] 1205 1206 # Take the largest strongly connected component (as list of state names) 1207 component_states = sorted( list( component_state_set ) ) 1208 1209 # make temporary copies of the old transitions and states 1210 old_transitions = [ 1211 Transition( 1212 origin_state_idx=tr.origin_state_idx, 1213 target_state_idx=tr.target_state_idx, 1214 prob=tr.prob, 1215 symbol_idx=tr.symbol_idx 1216 ) 1217 1218 for tr in self.transitions 1219 ] 1220 1221 old_states = [ 1222 CausalState( 1223 name=s.name, 1224 classes=s.classes, 1225 isomorphs=s.isomorphs 1226 ) 1227 for s in self.states 1228 ] 1229 1230 self.set_states( 1231 states=[ 1232 state 1233 for i, state in enumerate( old_states ) if i in component_state_set 1234 ] 1235 ) 1236 1237 # we will build new transition list based on those belonging to the component 1238 self.set_transitions( transitions= [] ) 1239 1240 # for tracking which new transitions leave each state 1241 transitions_from_state = { state : set() for state in component_states } 1242 new_transitions = [] 1243 1244 for tr in old_transitions : 1245 1246 origin_state_name = old_states[ tr.origin_state_idx ].name 1247 target_state_name = old_states[ tr.target_state_idx ].name 1248 1249 # skip transitions that connect separate strongly connected components 1250 if not ( tr.origin_state_idx in component_state_set 
and tr.target_state_idx in component_state_set ) : 1251 continue 1252 1253 # track transitions (by index in new transitions list) that leave this state 1254 transitions_from_state[ tr.origin_state_idx ].add( len( new_transitions ) ) 1255 1256 my_origin_state_idx = self.state_idx_map[ origin_state_name ] 1257 my_target_state_idx = self.state_idx_map[ target_state_name ] 1258 1259 new_transitions.append( 1260 Transition( 1261 origin_state_idx=my_origin_state_idx, 1262 target_state_idx=my_target_state_idx, 1263 prob=tr.prob, 1264 symbol_idx=tr.symbol_idx 1265 ) ) 1266 1267 self.set_transitions( transitions=new_transitions ) 1268 1269 # if we removed an outgoing transition from a state, we need to distribute its probability 1270 # among the remaining outgoing transitions from the state 1271 for state in component_states : 1272 1273 # get the set of transitions leaving this state 1274 state_trs = transitions_from_state[ state ] 1275 1276 # sum the probabilities of the outgoing transitions from the state 1277 p_sum = np.sum( [ self.transitions[ i ].prob for i in state_trs ] ) 1278 1279 # how much probability is missing 1280 diff = 1.0 - p_sum 1281 1282 # if significant difference 1283 if abs( diff ) > self.EPS : 1284 1285 # calculate how much of the difference each transition gets 1286 adjustment = diff / len( state_trs ) 1287 1288 # update the transitions 1289 for i in state_trs : 1290 1291 # transition with origional probability 1292 tr = self.transitions[ i ] 1293 1294 # adjusted probability 1295 self.transitions[ i ] = Transition( 1296 origin_state_idx=tr.origin_state_idx, 1297 target_state_idx=tr.target_state_idx, 1298 prob=tr.prob + adjustment, 1299 symbol_idx=tr.symbol_idx ) 1300 1301 if rename_states : 1302 self.set_states( [ 1303 CausalState( name=f"{i}" ) 1304 for i, s in enumerate( self.states ) 1305 ] ) 1306 1307 1308 def to_q_weighted( self, denominator_limit=1000 ) : 1309 1310 """ 1311 Approximates the existing transition probabilities with exact fractions, 
stores 1312 the fractional probabilities as Fraction in Transition.pq, and sets the floating 1313 point probabilty to `float(pq)`. If `denominator_limit` is too small for a sane 1314 conversion, the function recurses with `denominator_limit=denominator_limit*10`. 1315 1316 Args: 1317 denominator_limit (int): The initial input to :meth:`Fraction.limit_denominator` in 1318 the conversion. 1319 """ 1320 1321 if self.is_q_weighted : 1322 return 1323 1324 if not self.is_row_stochastic() : 1325 raise ValueError( "Cannot convert to q-weighted because not row stochastic" ) 1326 1327 t_from = [[] for _ in range(len(self.states))] 1328 1329 for i, tr in enumerate( self.transitions ) : 1330 t_from[ tr.origin_state_idx ].append( i ) 1331 1332 new_transitions = [] 1333 for t_list in t_from : 1334 1335 if not t_list : 1336 continue 1337 1338 p_q_sum = Fraction(0,1) 1339 p_qs = [] 1340 1341 for t_idx in t_list : 1342 1343 p_q = Fraction( self.transitions[ t_idx ].prob ).limit_denominator( denominator_limit ) 1344 p_q_sum += p_q 1345 p_qs.append( p_q ) 1346 1347 if p_q_sum != Fraction(1,1) : 1348 1349 max_pq_i = np.argmax( p_qs ) 1350 max_oq = p_qs[ max_pq_i ] 1351 1352 diff = p_q_sum - Fraction(1,1) 1353 1354 # If recurse with higher resolution 1355 if diff > max_oq : 1356 return self.to_q_weighted( denominator_limit*10 ) 1357 else : 1358 p_qs[ max_pq_i ] -= diff 1359 1360 for i, t_idx in enumerate( t_list ) : 1361 new_transitions.append( 1362 Transition( 1363 origin_state_idx=self.transitions[ t_idx ].origin_state_idx, 1364 target_state_idx=self.transitions[ t_idx ].target_state_idx, 1365 prob=float(p_qs[ i ]), 1366 symbol_idx=self.transitions[ t_idx ].symbol_idx, 1367 pq=p_qs[ i ] 1368 ) 1369 ) 1370 1371 self.set_transitions( new_transitions ) 1372 self.is_q_weighted = True 1373 1374 #-------------------------------------------------------# 1375 # Data Generation # 1376 #-------------------------------------------------------# 1377 1378 def isomorphic_shift( 1379 self, 1380 
input_symbol_indices: np.ndarray, 1381 input_state_indices: np.ndarray, 1382 shift : int = 1 1383 ) -> dict[str, np.ndarray]: 1384 1385 """ 1386 Generates a new sequence of of symbols that are permuted with the symbols emitted by 1387 isomorphic states, if they exists. 1388 1389 :math:`\\sigma_o = \\mathcal{S}\\left[\\texttt{input\\_state\\_indices}[i]\\right]`<br> 1390 :math:`\\sigma_t = \\mathcal{S}\\left[\\texttt{input\\_state\\_indices}[i+1]\\right]` 1391 1392 :math:`\\mathcal{I}(\\sigma_o) = \\\\{\\sigma^0_o,\\, \\sigma^1_o,\\, \\dots,\\, \\sigma^{n-1}_o \\\\}`<br> 1393 :math:`\\mathcal{I}(\\sigma_t) = \\\\{\\sigma^0_t,\\, \\sigma^1_t,\\, \\dots,\\, \\sigma^{n-1}_t \\\\}` 1394 1395 :math:`k = \\bigl(\\mathcal{I}(\\sigma_o).\\texttt{index}(\\sigma_o) + \\texttt{shift}\\bigr) \\bmod n` 1396 1397 :math:`\\texttt{output\\_symbol\\_indices}[i] := T(\\sigma_o^k,\\, \\sigma_t^k).\\text{symbol\\_index}`<br> 1398 :math:`\\texttt{output\\_state\\_indices}[i] := \\mathcal{S}.\\texttt{index}(\\sigma_o^k)`<br> 1399 :math:`\\texttt{output\\_state\\_indices}[i+1] := \\mathcal{S}.\\texttt{index}(\\sigma_t^k)` 1400 1401 Where :math:`\\mathcal{I}(\\sigma)`` is the ordered set of states isomorphic to :math:`\\sigma` including :math:`\\sigma` itself. 1402 1403 Args: 1404 input_symbol_indices (np.ndarray): The sequence of generated symbols. 1405 input_state_indices (np.ndarray): The sequence of states that generated symbols with the final state at the end. 1406 shift : int: How much to shift the symbols across the isomorphic states. 
1407 """ 1408 1409 if not any( state.isomorphs for state in self.states ): 1410 raise ValueError("HMM has no states with isomorphs") 1411 1412 inputs = np.asarray(input_symbol_indices) 1413 states = np.asarray(input_state_indices) 1414 1415 n_states = len(self.states) 1416 1417 tr_sym_table = np.full((n_states, n_states), -1, dtype=np.int32) 1418 1419 for tr in self.transitions: 1420 tr_sym_table[ tr.origin_state_idx, tr.target_state_idx ] = tr.symbol_idx 1421 1422 # Build isomorph remapping: identity by default, overridden where isomorphs exist 1423 iso_table = np.arange(n_states, dtype=np.int32) 1424 1425 for i, state in enumerate(self.states): 1426 if state.isomorphs: 1427 isomorph = sorted(state.isomorphs)[0] 1428 iso_table[ i ] = self.state_idx_map[isomorph] 1429 1430 for i, state in enumerate(self.states): 1431 if state.isomorphs: 1432 1433 # extend the isomorph list to include the identity 1434 isormorphs_with_identity = sorted( [ i ] + [ self.state_idx_map[iso] for iso in state.isomorphs ] ) 1435 1436 # find the identity index 1437 pos = isormorphs_with_identity.index( i ) 1438 1439 # cyclical shift 1440 iso_table[ i ] = isormorphs_with_identity[ ( pos + shift ) % len( isormorphs_with_identity ) ] 1441 1442 origins = states[:-1] 1443 targets = states[1:] 1444 1445 out_origins = iso_table[origins] 1446 out_targets = iso_table[targets] 1447 1448 inv_sym = tr_sym_table[ out_origins, out_targets ] 1449 inv_sts = np.empty(states.size, dtype=states.dtype) 1450 1451 inv_sts[:-1] = out_origins 1452 inv_sts[-1] = out_targets[-1] 1453 1454 return { 1455 "symbol_index": inv_sym.astype(inputs.dtype), 1456 "state_index": inv_sts, 1457 } 1458 1459 def generate_data( 1460 self, 1461 file_prefix: str, 1462 n_gen: int, 1463 include_states: bool, 1464 isomorphic_shifts : set[int]=None, 1465 random_seed : int=42 ) -> dict[any] : 1466 1467 trs = [ [] for _ in range( len( self.states ) ) ] 1468 for tr in self.transitions : 1469 trs[ tr.origin_state_idx ].append( ( 1470 
tr.symbol_idx, 1471 float( tr.prob ), 1472 tr.target_state_idx ) ) 1473 1474 data = am_fast.generate_data( 1475 n_gen=n_gen, 1476 start_state=self.start_state, 1477 transitions=trs, 1478 alphabet=sorted(list(self.alphabet)), 1479 include_states=include_states, 1480 random_seed=random_seed 1481 ) 1482 1483 if isomorphic_shifts is not None : 1484 1485 if not include_states : 1486 raise ValueError( "Isomorphic inversion requires include_states=True" ) 1487 1488 data[ "isomorphic_shifts" ] = {} 1489 1490 for shift in isomorphic_shifts : 1491 1492 try : 1493 1494 shifted = self.isomorphic_shift( 1495 input_symbol_indices=data[ "symbol_index" ], 1496 input_state_indices=data[ "state_index" ], 1497 shift=shift 1498 ) 1499 1500 data[ "isomorphic_shifts" ][ shift ] = { 1501 "symbol_index" : shifted[ "symbol_index" ], 1502 "state_index" : shifted[ "state_index" ] 1503 } 1504 1505 except Exception as e : 1506 print( f"Exception {e}" ) 1507 1508 am_fast.save_data( 1509 data=data, 1510 file_prefix=file_prefix, 1511 alphabet=sorted(list(self.alphabet)), 1512 n_states=len( self.states ), 1513 start_state=self.start_state, 1514 random_seed=random_seed, 1515 machine_metadata=self.get_metadata() ) 1516 1517 return data 1518 1519 #-------------------------------------------------------# 1520 # Basic Visualization # 1521 #-------------------------------------------------------# 1522 1523 def draw_graph( 1524 self, 1525 engine : str = 'dot', 1526 output_dir : Path = Path('.'), 1527 show : bool = True 1528 ) -> None : 1529 1530 """ 1531 Draws the machine using [pygraphiviz](https://pygraphviz.github.io/documentation/stable/) and saves it. 1532 1533 Returns: 1534 1535 networkx.DiGraph : the resulting graph. 
1536 """ 1537 1538 G = self.as_digraph() 1539 1540 subgraphs = None if nx.is_strongly_connected( G ) else list( nx.strongly_connected_components( G ) ) 1541 1542 am_vis.draw_graph( 1543 self, 1544 output_dir=output_dir, 1545 title="am_graph", 1546 view=show, 1547 subgraphs=subgraphs, 1548 engine=engine ) 1549 1550 #-------------------------------------------------------# 1551 # Alternate Forms # 1552 #-------------------------------------------------------# 1553 1554 1555 def as_digraph( self ) -> nx.DiGraph : 1556 1557 """ 1558 Builds a [networkx.DiGraph](https://networkx.org/documentation/stable/reference/classes/digraph.html) constructed from the machine's transitions with no edge symbols or weights. 1559 1560 Returns: 1561 1562 networkx.DiGraph : the resulting graph. 1563 """ 1564 1565 G = nx.DiGraph() 1566 G.add_nodes_from( [ i for i, s in enumerate( self.states ) ] ) 1567 1568 for tr in self.transitions : 1569 G.add_edge( tr.origin_state_idx, tr.target_state_idx ) 1570 1571 return G 1572 1573 def as_dfa( self, with_probs : bool ) : 1574 1575 """ 1576 Builds an [automata.fa.dfa.DFA](https://caleb531.github.io/automata/api/fa/class-dfa/) constructed from the machine's transitions. 1577 1578 Args: 1579 with_probs (bool): If true the DFA transitions are labeled based on 1580 the symbol of the machines transition concatenated with its 1581 probability, othwise, the only the symbols. 1582 1583 Returns: 1584 1585 automata.fa.dfa.DFA : the resulting DFA. 
1586 """ 1587 1588 precision=8 1589 1590 def edge_label( symb, prob ) : 1591 return f"({symb},{round(prob, precision)})" 1592 1593 # Build states, symbols, and transitions 1594 dfa_states = { i for i, _ in enumerate( self.states ) } 1595 1596 if not with_probs : 1597 dfa_symbols = set( { t.symbol_idx for t in self.transitions } ) 1598 else : 1599 dfa_symbols = set( { edge_label( t.symbol_idx, t.prob ) for t in self.transitions } ) 1600 1601 dfa_transitions = defaultdict(dict) 1602 1603 if not with_probs : 1604 for t in self.transitions : 1605 dfa_transitions[ t.origin_state_idx ][ t.symbol_idx ] = t.target_state_idx 1606 else : 1607 for t in self.transitions : 1608 dfa_transitions[ t.origin_state_idx ][ edge_label( t.symbol_idx, t.prob ) ] = t.target_state_idx 1609 1610 # Construct the DFA 1611 return DFA( 1612 states=dfa_states, 1613 input_symbols=dfa_symbols, 1614 transitions=dfa_transitions, 1615 initial_state=self.start_state, 1616 allow_partial=True, 1617 final_states={ 1618 s for s in dfa_states 1619 } 1620 )
38class HMM(Machine) : 39 40 """Hidden Markov model implementing epsilon machines, mixed state presentations, 41 complexity measures, and data generation. 42 43 Args: 44 states (list[CausalState] | None): A list of causal states. 45 transitions (list[Transition] | None): A list of transitions between states. 46 start_state (int): Index of the start state. 47 alphabet (list[str]): List of symbols making up the alphabet. 48 name (str): Name of the model. 49 description (str): Description of the model. 50 """ 51 52 def __init__( 53 self, 54 states : list[CausalState] | None = None, 55 transitions : list[Transition] | None = None, 56 start_state : int = 0, 57 alphabet : list[str] | None = None, 58 name : str = "", 59 description : str = "" ) : 60 61 self.alphabet : list[str] | None = alphabet or [] 62 self.states : list[CausalState] | None= states or [] 63 self.transitions : list[Transition] | None= transitions or [] 64 65 self.set_alphabet( self.alphabet ) 66 self.set_states( self.states ) 67 self.set_transitions( self.transitions ) 68 69 self.start_state : int = start_state 70 71 self.name : str = name 72 self.description : str = description 73 74 # To be depreciated 75 self.isoclass : str = None 76 77 # --- derived -------- 78 79 self.complexity : dict[str, any] = {} 80 81 self.symbol_idx_map : dict[str,int] = {} 82 self.state_idx_map : dict[str,int] = {} 83 84 self.pi_fractional = None 85 self.pi : np.ndarray | None = None 86 self.T : np.ndarray | None = None 87 self.msp : MSP | None = None 88 self.reverse_am : HMM | None = None 89 self.is_q_weighted : bool = False 90 self.is_minimal : bool = False 91 92 # --- const ---------- 93 94 self.EPS : float = 1e-12 95 96 #-------------------------------------------------------# 97 # Overrides / Getters # 98 #-------------------------------------------------------# 99 100 @override 101 def get_states(self) -> list[CausalState] : 102 return self.states 103 104 @override 105 def get_transitions(self) -> list[Transition] : 106 
return self.transitions 107 108 @override 109 def get_alphabet(self) -> list[Symbol]: 110 return [ Symbol(a) for a in self.alphabet ] 111 112 #-------------------------------------------------------# 113 # Serialization # 114 #-------------------------------------------------------# 115 116 117 def to_dict(self) -> dict[str,any]: 118 119 """Create a dict representing the HMM configuration. 120 121 Returns: 122 123 dict[str,any]: Dictionary containing, name, description, states, transitions, alphabet, and isoclass. 124 """ 125 126 return { 127 "name" : self.name, 128 "description" : self.description, 129 "states" : [ asdict(state) for state in self.states ], 130 "transitions" : [ asdict(transition) for transition in self.transitions ], 131 "alphabet" : self.alphabet, 132 "isoclass" : self.isoclass 133 } 134 135 def from_dict( self, config : dict[str,any] ) : 136 137 """Configure the HMM configuration from a dictionary. 138 139 Args: 140 config (dict[str,any]): The HMM configuration. 141 """ 142 143 self.name = config.name 144 self.description = config.description 145 146 self.set_states( states=[ 147 CausalState( 148 name=state[ "name" ], 149 classes=set( state[ "classes" ] ) 150 ) 151 for state in config.states 152 ] ) 153 154 self.set_transitions( transitions=[ 155 Transition( 156 origin_state_idx=tr[ "origin_state_idx" ], 157 target_state_idx=tr[ "target_state_idx" ], 158 prob=tr[ "prob" ], 159 symbol_idx=tr[ "symbol_idx" ] 160 ) 161 for tr in config.transitions 162 ] ) 163 164 self.set_alphabet( alphabet=config.alphabet ) 165 166 def save_config( 167 self, 168 path : Path, 169 with_complexity : bool = False, 170 with_block_convergence : bool = False, 171 with_structural_properties : bool = False, 172 with_causal_properties : bool = False ) : 173 174 config = self.to_dict() 175 176 if with_complexity : 177 178 complexity = self.get_complexities( 179 with_block_convergence=with_block_convergence 180 ) 181 182 config[ "complexity" ] = complexity 183 184 config[ 
"structural_properties" ] = { 185 "unifilar" : self.is_unifilar(), 186 "row_stochastic" : self.is_row_stochastic(), 187 "strongly_connected" : self.is_strongly_connected(), 188 "aperiodic" : self.is_aperiodic(), 189 "minimal" : self._is_minimal_as_dfa( topological_only=False ), 190 "topologically_minimal" : self._is_minimal_as_dfa( topological_only=True ), 191 "is_epsilon_machine" : self.is_epsilon_machine() 192 } 193 194 with open( path / "am_config.json", "w", encoding="utf-8" ) as f : 195 json.dump( config, f, ensure_ascii=False, indent=2, default=list ) 196 197 def from_file( self, path : Path ) : 198 with open( Path / "am_config.json", "r" ) as f: 199 config = json.load(f) 200 self.from_dict() 201 202 #-------------------------------------------------------# 203 # Setters and State Management # 204 #-------------------------------------------------------# 205 206 def _invalidate(self) : 207 208 """ 209 Reset all derived properties so they will be recomputed when requested later. 210 """ 211 212 self.complexity = {} 213 self.T = None 214 self.T_x = None 215 self.pi = None 216 self.pi_fractional = None 217 self.msp = None 218 self.reverse_am = None 219 self.is_q_weighted = False 220 self.is_minimal = False 221 222 def set_states( self, states : list[CausalState] ) : 223 self._invalidate() 224 self.states = states.copy() 225 self.state_idx_map = {} 226 for idx, state in enumerate( self.states ) : 227 self.state_idx_map[ state.name ] = idx 228 229 def set_alphabet( self, alphabet : list[str] ) : 230 231 self._invalidate() 232 233 old_alphabet = self.alphabet.copy() 234 235 aSet = set() 236 aSet.update( alphabet ) 237 238 self.alphabet = sorted(list(aSet)) 239 240 self.symbol_idx_map = {} 241 for idx, symbol in enumerate( self.alphabet ) : 242 self.symbol_idx_map[ symbol ] = idx 243 244 for i, tr in enumerate( self.transitions ) : 245 symbol = old_alphabet[ tr.symbol_idx ] 246 self.transitions[ i ] = Transition( 247 origin_state_idx=tr.origin_state_idx, 248 
target_state_idx=tr.target_state_idx, 249 prob=tr.prob, 250 symbol_idx=self.symbol_idx_map[ symbol ] 251 ) 252 253 def set_transitions( self, transitions : list[Transition] ) : 254 self._invalidate() 255 self.transitions = transitions.copy() 256 257 def extend_states( self, states : list[CausalState] ) : 258 self.set_states( self.states + states ) 259 260 def extend_alphabet( self, alphabet : list[str] ) : 261 self.set_alphabet( self.alphabet + alphabet ) 262 263 def extend_transitions( self, transitions : list[Transition] ) : 264 self.set_transitions( self.transitions + transitions ) 265 266 def get_complexity_measure_if_exists(self, measure ) : 267 m = self.complexity.get( measure, None ) 268 return m 269 270 def set_complexity_measure(self, measure, value ) : 271 self.complexity[ measure ] = value 272 273 #-------------------------------------------------------# 274 # Get and Compute # 275 #-------------------------------------------------------# 276 277 def get_complexities( 278 self, 279 with_block_convergence=False ) : 280 281 directly_calculable = [ 282 self.C_mu, 283 self.h_mu, 284 self.H_1, 285 self.rho_mu 286 ] 287 288 requires_block_convergence = [ 289 self.E, 290 self.T_inf, 291 self.S, 292 self.chi 293 ] 294 295 complexities = { m.__name__ : m() for m in directly_calculable } 296 297 if with_block_convergence : 298 299 complexities |= { m.__name__ : m() for m in requires_block_convergence } 300 301 for key in [ 'H_L', 'T_L', 'h_mu_L', 'H_sync' ] : 302 if key in self.complexity : 303 complexities[ key ] = self.complexity[ key ] 304 305 return complexities 306 307 #-------------------------------------------------------# 308 309 def get_metadata(self) : 310 return { 311 "name" : self.name, 312 'complexity' : self.complexity, 313 "description" : self.description 314 } 315 316 def get_transition_matrix(self) : 317 318 if self.T is not None : 319 return self.T 320 321 n_states = len( self.states ) 322 T = np.zeros((n_states, n_states)) 323 324 for tr in 
self.transitions : 325 T[ tr.origin_state_idx, tr.target_state_idx ] = tr.prob 326 327 self.T = T 328 329 return self.T 330 331 #-------------------------------------------------------# 332 333 def get_T_X(self) : 334 335 if self.T_x is not None : 336 return self.T_x 337 338 n_states = len( self.states ) 339 n_symbols = len( self.alphabet ) 340 341 T_x = [ np.zeros((n_states, n_states)) for _ in range( n_symbols ) ] 342 343 for tr in self.transitions : 344 T_x[ tr.symbol_idx ][tr.origin_state_idx, tr.target_state_idx] = tr.prob 345 346 self.T_x = T_x 347 return self.T_x 348 349 #-------------------------------------------------------# 350 351 def get_msp_qw( 352 self, 353 exact_state_cap: int = 1000, 354 verbose: bool = True, 355 ): 356 if self.msp is not None: 357 return self.msp 358 359 try : 360 361 print( "\nTrying to Compute Mixed State Presentation using Exact Fractions\n" ) 362 363 self.msp = compute_msp_exact( 364 T_x=self.get_Tx_fractional(), 365 pi=self.get_fractional_stationary_distribution(), 366 n_states=len(self.states), 367 alphabet=self.alphabet, 368 exact_state_cap=1000, 369 bool = True 370 ) 371 372 return self.msp 373 374 except RuntimeError as e : 375 warnings.warn( f"Exact msp failed: {e} Falling back to msp approximation." ) 376 377 return self.get_msp() 378 379 def get_msp( 380 self, 381 exact_state_cap: int = 175_000, 382 jsd_eps: float = 1e-7, 383 k_ann: int = 50, 384 verbose = True, 385 ) : 386 387 if self.msp is not None: 388 return self.msp 389 390 T_x = self.get_T_X() 391 pi = self.get_stationary_distribution() 392 393 T_stacked = np.stack(T_x) 394 n_symbols = len(self.alphabet) 395 n_input_states = T_stacked.shape[1] 396 397 print( "\nComputing Mixed State Presentation..." 
) 398 399 self.msp = compute_msp( 400 T_x=T_x, 401 pi=pi, 402 n_states=len(self.states), 403 alphabet=self.alphabet, 404 exact_state_cap=exact_state_cap, 405 verbose=verbose 406 ) 407 408 return self.msp 409 410 def get_reverse_am(self) : 411 412 if self.reverse_am is not None: 413 return self.reverse_am 414 415 pi = self.get_stationary_distribution() 416 self.reverse_am = copy.deepcopy(self) 417 418 new_transitions = [] 419 for tr in self.transitions: 420 i = tr.target_state_idx 421 j = tr.origin_state_idx 422 423 p_reversed = (pi[j] * tr.prob) / pi[i] 424 425 new_transitions.append( 426 Transition( 427 origin_state_idx=i, 428 target_state_idx=j, 429 prob=p_reversed, 430 symbol_idx=tr.symbol_idx 431 ) 432 ) 433 434 self.reverse_am.set_transitions(new_transitions) 435 436 if self.reverse_am.is_epsilon_machine(): 437 return self.reverse_am 438 439 rmsp = self.reverse_am.get_msp_qw( exact_state_cap=len(self.states)*4 ) 440 441 self.reverse_am.set_states( rmsp.states ) 442 self.reverse_am.set_transitions( rmsp.transitions ) 443 self.reverse_am.msp = rmsp 444 self.reverse_am.start_state = 0 445 446 self.reverse_am.collapse_to_largest_strongly_connected_subgraph() 447 self.reverse_am.minimize() 448 449 return self.reverse_am 450 451 #-------------------------------------------------------# 452 453 def get_Tx_fractional(self) -> list[ list[ list[ Fraction ] ] ] : 454 455 self.to_q_weighted() 456 457 n_states = len( self.states ) 458 n_symbols = len( self.alphabet ) 459 460 T_x = [] 461 462 for x in range( n_symbols ) : 463 T_x.append( [] ) 464 for i in range( n_states ) : 465 T_x[ x ].append( [ 0 for _ in range( n_states ) ] ) 466 467 for tr in self.transitions : 468 T_x[ tr.symbol_idx ][ tr.origin_state_idx ][ tr.target_state_idx ] = tr.pq 469 470 return T_x 471 472 def get_T_sympy( self ) : 473 474 self.to_q_weighted() 475 476 n = len( self.states ) 477 T = sympy.zeros( n, n ) 478 479 for tr in self.transitions : 480 T[ tr.origin_state_idx, tr.target_state_idx ] = 
tr.pq 481 482 return T 483 484 def get_fractional_stationary_distribution(self) : 485 486 T = self.get_T_sympy() 487 488 if self.pi_fractional is not None : 489 return self.pi_fractional 490 491 G = self.as_digraph() 492 493 if not nx.is_strongly_connected(G): 494 raise ValueError( "Single stationary distribution requires strongly connected HMM." ) 495 496 self.pi_fractional = solve_for_pi_fractional( T ) 497 498 return self.pi_fractional 499 500 def get_stationary_distribution(self): 501 502 if self.pi is not None : 503 return self.pi 504 505 G = self.as_digraph() 506 507 if not nx.is_strongly_connected(G): 508 raise ValueError( "Single stationary distribution requires strongly connected HMM." ) 509 510 T = self.get_transition_matrix() 511 return solve_for_pi( T ) 512 513 #-------------------------------------------------------# 514 # Complexity Measures # 515 #-------------------------------------------------------# 516 517 def C_mu( self ) : 518 519 """The *statistical complexity* (aka *forecasting complexity*) : 520 521 .. math:: 522 523 C_{\\mu} = - \\sum_{\\sigma \\in \\mathcal{S}} \\Pr(\\sigma) \\log_2 \\Pr(\\sigma), 524 525 where :math:`\\mathcal{S}` are the machine's states [^crutchfield_exact_2016], p.2. 526 527 .. note:: 528 529 **Interpretations** 530 531 * The amount of historical information a process stores. 532 * The amount of structure in a process. 533 534 Returns: 535 536 float: :math:`C_{\\mu}`. 537 538 [^crutchfield_exact_2016]: Crutchfield et al., *Exact Complexity: The Spectral 539 Decomposition of Intrinsic Computation*, 2016. 
540 <https://arxiv.org/abs/1309.3792> 541 """ 542 543 m = self.get_complexity_measure_if_exists( "C_mu" ) 544 545 if m is not None : 546 return m 547 548 pi = self.get_stationary_distribution() 549 550 h = 0 551 for i, pr in enumerate( pi ) : 552 553 if pr < self.EPS : 554 continue 555 556 h += -pr * np.log2( pr ) 557 558 self.set_complexity_measure( "C_mu", h ) 559 560 return h 561 562 #-------------------------------------------------------# 563 564 def h_mu( self ) : 565 566 """The *entropy rate* : 567 568 .. math:: 569 570 h_{\\mu}(\\boldsymbol{\\mathcal{S}}) = - \\sum_{\\sigma \\in \\mathcal{S}} \\Pr(\\sigma) \\sum_{x \\in \\mathcal{A}} \\Pr(x|\\sigma) \\log_2 \\Pr(x|\\sigma), 571 572 where :math:`\\mathcal{A}` is the alphabet and :math:`\\mathcal{S}` are the machine's states [^crutchfield_exact_2016], p.2. 573 574 .. note:: 575 576 **Interpretations** 577 578 * The lower bound on achievable loss in bits. 579 * The irreducable randomness in the process. 580 * The intrinsic Randomness in the process. 581 582 Returns: 583 584 float: :math:`h_{\\mu}`. 585 586 [^crutchfield_exact_2016]: Crutchfield et al., *Exact Complexity: The Spectral 587 Decomposition of Intrinsic Computation*, 2016. 588 <https://arxiv.org/abs/1309.3792> 589 """ 590 591 m = self.get_complexity_measure_if_exists( "h_mu" ) 592 593 if m is not None : 594 return m 595 596 T = self.get_transition_matrix() 597 pi = self.get_stationary_distribution() 598 599 n_states = pi.size 600 601 h = 0 602 for i, pr in enumerate( pi ) : 603 604 if pr < self.EPS : 605 continue 606 607 row_entropy = 0 608 for j in range( len( pi ) ) : 609 610 if T[ i, j ] < self.EPS : 611 continue 612 613 row_entropy -= T[ i, j ] * np.log2( T[ i, j ] ) 614 615 h += pr * row_entropy 616 617 self.set_complexity_measure( "h_mu", h ) 618 619 return h 620 621 #-------------------------------------------------------# 622 623 def H_1(self) -> float : 624 625 """The *single symbol uncertainty*: 626 627 .. 
math:: 628 629 H(1)=-\\sum_{x\\in\\mathcal{A}} \\Pr(x) \\log_2{\\Pr(x)}, 630 631 where :math:`\\mathcal{A}` is the alphabet [^James_2018], p.2. 632 633 .. note:: 634 635 **Interpretations** 636 637 * How uncertain you are on average about a single measurement with no context. 638 639 Returns: 640 641 float: :math:`H(1)`. 642 643 [^James_2018]: James et al., Anatomy of a Bit: Information in a Time Series Observation, 2018. 644 <https://arxiv.org/abs/1105.2988> 645 """ 646 647 m = self.get_complexity_measure_if_exists("H_1") 648 if m is not None: 649 return m 650 651 pi = self.get_stationary_distribution() 652 T_X = self.get_T_X() # dict: symbol -> matrix 653 654 h = 0.0 655 for T_x in T_X: 656 # Pr(x) = sum_i pi[i] * sum_j T^(x)[i,j] 657 p_sym = 0.0 658 for i, pr in enumerate(pi): 659 if pr < self.EPS: 660 continue 661 p_sym += pr * T_x[i, :].sum() 662 663 if p_sym < self.EPS: 664 continue 665 h -= p_sym * np.log2(p_sym) 666 667 self.set_complexity_measure("H_1", h) 668 return h 669 670 #-------------------------------------------------------# 671 672 def rho_mu(self) -> float : 673 674 """The *anticipated information* [^James_2018], p.3.: 675 676 .. math:: 677 678 \\rho_{\\mu}= H(1) - h_{\\mu} 679 680 Returns: 681 682 float: :math:`\\rho_{\\mu}` 683 684 [^James_2018]: James et al., Anatomy of a Bit: Information in a Time Series Observation, 2018. 685 <https://arxiv.org/abs/1105.2988> 686 """ 687 688 m = self.get_complexity_measure_if_exists("rho_mu") 689 690 if m is not None: 691 return m 692 693 rho = self.H_1() - self.h_mu() 694 695 self.set_complexity_measure("rho_mu", rho) 696 697 return rho 698 699 #-------------------------------------------------------# 700 701 def block_convergence( self ) : 702 703 """ 704 "Run block entropy convergence and return a dict with $\\mathbf{E}, \\mathbf{S}, $\\mathbf{T},\\mathbf{T}(L), \\mathcal{H}(L),$ and $h_{\\mu}(L)$. 
705 """ 706 707 trs = [ [] for _ in range( len( self.states ) ) ] 708 for tr in self.transitions : 709 trs[ tr.origin_state_idx ].append( ( 710 tr.symbol_idx, 711 float( tr.prob ), 712 tr.target_state_idx ) ) 713 714 pi = self.get_stationary_distribution() 715 716 state_dist = [ float( pi[ i ] ) for i in range( len( self.states ) ) ] 717 branches = [(1.0, list(state_dist))] 718 719 print( "\nComputing Block Entropy\n" ) 720 721 C = am_fast.block_entropy_convergence( 722 h_mu = self.h_mu(), 723 n_states = len( self.states ), 724 n_symbols = len( self.alphabet ), 725 convergence_tol = 1e-6, 726 precision = 10, 727 eps = 1e-25, 728 branches = branches, 729 trans = trs, 730 max_branches = 30_000_000 731 ) 732 733 print( "Done\n" ) 734 735 self.set_complexity_measure( f"E", C.E ) 736 self.set_complexity_measure( f"S", C.S ) 737 self.set_complexity_measure( f"T_inf", C.T ) 738 self.set_complexity_measure( f"T_L", C.T_L.tolist() ) 739 self.set_complexity_measure( f"H_L", C.H_L.tolist() ) 740 self.set_complexity_measure( f"h_mu_L", C.h_mu_L.tolist() ) 741 self.set_complexity_measure( f"H_sync", C.H_sync.tolist() ) 742 743 return C 744 745 #-------------------------------------------------------# 746 747 def E( self ) -> float : 748 749 """The *excess entropy* [^crutchfield_exact_2016], p.4: 750 751 .. math:: 752 753 \\mathbf{E} \\equiv \\sum_{L=1}^{\\infty} I[X_{-\\infty:0}; X_{0:\\infty}] 754 755 Computed via :meth:`get_msp` and :meth:`amachine.am_msp.MSP.get_E_S_T`, or :meth:`amachine.am_fast.block_entropy_convergence` 756 757 .. note:: 758 759 **Interpretations** 760 761 * The information from the past that reduces uncertainty in the future [^crutchfield_exact_2016]. 762 * How much information an observer must extract to synchronize to the process. 763 * Measures how long the process appears more complex than it asymptotically is. 764 * Vanishes for immediately synchronizable processes. 
765 766 Returns: 767 768 float: :math:`\\mathbf{E}` 769 770 [^crutchfield_exact_2016]: Crutchfield et al., *Exact Complexity: The Spectral 771 Decomposition of Intrinsic Computation*, 2016. 772 <https://arxiv.org/abs/1309.3792> 773 """ 774 775 m = self.get_complexity_measure_if_exists( "E" ) 776 777 if m is not None : 778 return m 779 780 try : 781 msp = self.get_msp() 782 E, S, T = msp.get_E_S_T() 783 self.set_complexity_measure( "E", E ) 784 self.set_complexity_measure( "S", S ) 785 self.set_complexity_measure( "T_inf", T ) 786 787 except Exception as e : 788 789 print( f"MSP failed {e}" ) 790 791 C = self.block_convergence() 792 E = C.E 793 self.set_complexity_measure( "E", E ) 794 795 return E 796 797 #-------------------------------------------------------# 798 799 def S( self ) -> float : 800 801 """The *synchronization* information: 802 803 .. math:: 804 805 \\mathbf{S} \\equiv \\sum_{L=1}^{\\infty} \\mathcal{H}(L), 806 807 where :math:`\\mathcal{H}(L)` is the average state uncertainty having seen all length-L words [^crutchfield_exact_2016], p.4. 808 809 .. note:: 810 811 **Interpretations** 812 813 * The total amount of state information that an observer must extract to become synchronized [^crutchfield_exact_2016]. 814 815 Computed via :meth:`get_msp` and :meth:`amachine.am_msp.MSP.get_E_S_T`, or :meth:`amachine.am_fast.block_entropy_convergence` 816 817 Returns: 818 819 float: :math:`\\mathbf{S}` 820 821 [^crutchfield_exact_2016]: Crutchfield et al., *Exact Complexity: The Spectral 822 Decomposition of Intrinsic Computation*, 2016. 
823 <https://arxiv.org/abs/1309.3792> 824 """ 825 826 m = self.get_complexity_measure_if_exists( "S" ) 827 828 if m is not None : 829 return m 830 831 try : 832 msp = self.get_msp() 833 E, S, T = msp.get_E_S_T() 834 self.set_complexity_measure( "E", E ) 835 self.set_complexity_measure( "S", S ) 836 self.set_complexity_measure( "T_inf", T ) 837 838 except Exception as e : 839 print( f"{e} \nFalling back to iterative estimation.") 840 exit() 841 842 C = self.block_convergence() 843 S = C.S 844 self.set_complexity_measure( "S", S ) 845 846 return S 847 848 #-------------------------------------------------------# 849 850 def T_inf( self ) -> float : 851 852 """The *transient information*[^crutchfield_exact_2016], p.4: 853 854 .. math:: 855 856 \\mathbf{T} \\equiv \\sum_{L=1}^{\\infty} L \\left[ h_{\\mu}(L) - h_{\\mu} \\right] 857 858 Computed via :meth:`get_msp` and :meth:`amachine.am_msp.MSP.get_E_S_T`, or :meth:`amachine.am_fast.block_entropy_convergence` 859 860 .. note:: 861 862 **Interpretations** 863 864 * The amount of information one must extract from observations so that the block entropy converges to its linear asymptote[^crutchfield_exact_2016]. 865 866 Returns: 867 868 float: :math:`\\mathbf{T}` 869 870 [^crutchfield_exact_2016]: Crutchfield et al., *Exact Complexity: The Spectral 871 Decomposition of Intrinsic Computation*, 2016. 
872 <https://arxiv.org/abs/1309.3792> 873 """ 874 875 m = self.get_complexity_measure_if_exists( "T_inf" ) 876 877 if m is not None : 878 return m 879 880 try : 881 msp = self.get_msp() 882 E, S, T = msp.get_E_S_T() 883 self.set_complexity_measure( "E", E ) 884 self.set_complexity_measure( "S", S ) 885 self.set_complexity_measure( "T_inf", T ) 886 887 except Exception as e : 888 print( f"{e} \nFalling back to iterative estimation.") 889 C = self.block_convergence() 890 T_inf = C.T 891 self.set_complexity_measure( "T_inf", T_inf ) 892 893 return T_inf 894 895 #-------------------------------------------------------# 896 897 def chi( self ) -> float : 898 899 """Computes the foward crypticity[^crutchfield_crypticity_2009][^Mahoney_crypticity_2021], p.2: 900 901 .. math:: 902 903 \\chi = C_{\\mu} - \\mathbf{E} 904 905 :math:`C_{\\mu}` is trivially computed from the stationary distribution in :meth:`C_mu` and :math:`\\mathbf{E}` in :meth:`E`. 906 907 .. note:: 908 909 **Interpretations** 910 911 * Difference between internal stored information and apparent information to an observer. 912 * How muching information is hiding in the system. 913 914 Returns: 915 916 float: :math:`\\chi` 917 918 [^crutchfield_crypticity_2009]: Crutchfield et al., Time’s barbed arrow: Irreversibility, crypticity, and stored information, 2009. 919 <https://arxiv.org/abs/0902.1209> 920 921 [^Mahoney_crypticity_2021]: Mahoney et al., Information Accessibility and Cryptic Processes, 2021. 
922 <https://arxiv.org/abs/0905.4787> 923 """ 924 925 m = self.get_complexity_measure_if_exists( "chi" ) 926 927 if m is not None : 928 return m 929 930 chi = self.C_mu() - self.E() 931 932 if chi < 0 : 933 934 # if chi is 0, accumulated floating point error can result in small negative values 935 if chi < -1e-5: 936 warnings.warn(f"Crypticity is negative ({chi:.6e}).") 937 938 chi = np.clamp( chi, 0 ) 939 940 self.set_complexity_measure( "chi", chi ) 941 942 return chi 943 944 #-------------------------------------------------------# 945 # Properties # 946 #-------------------------------------------------------# 947 948 def is_row_stochastic(self) : 949 950 """ 951 Check that all states have outgoing transition probabilities that sum to 1. 952 """ 953 954 sums = np.zeros( len( self.states ) ) 955 for tr in self.transitions : 956 sums[ tr.origin_state_idx ] += tr.prob 957 return np.allclose( sums, 1.0 ) 958 959 #-------------------------------------------------------# 960 961 def is_unifilar(self) : 962 963 """ 964 Check that no state emits the same symbol on transitions to different states. 965 """ 966 967 symbol_trs = np.full( ( len( self.states ), len( self.alphabet) ), -1 ) 968 969 for tr in self.transitions : 970 971 if symbol_trs[ tr.origin_state_idx, tr.symbol_idx ] == -1 : 972 symbol_trs[ tr.origin_state_idx, tr.symbol_idx ] = tr.target_state_idx 973 974 elif symbol_trs[ tr.origin_state_idx, tr.symbol_idx ] != tr.target_state_idx : 975 return False 976 977 return True 978 979 #-------------------------------------------------------# 980 981 def is_strongly_connected(self) : 982 983 """ 984 Check if every state is reachable from every other state. Relies on [nx.is_strongly_connected](https://networkx.org/documentation/latest/reference/algorithms/generated/networkx.algorithms.components.is_strongly_connected.html). 
985 """ 986 987 return nx.is_strongly_connected( self.as_digraph() ) 988 989 #-------------------------------------------------------# 990 991 def is_aperiodic(self) : 992 993 """ 994 Checks if machine is periodic. Relies on [nx.is_aperiodic](https://networkx.org/documentation/latest/reference/algorithms/generated/networkx.algorithms.dag.is_aperiodic.html), "A strongly connected directed graph is aperiodic if there is no integer k > 1 that divides the length of every cycle in the graph." 995 """ 996 997 return nx.is_aperiodic( self.as_digraph() ) 998 999 #-------------------------------------------------------# 1000 1001 def _is_minimal_as_dfa( self, topological_only : bool, verbose=True ) : 1002 1003 with_probs = not topological_only 1004 1005 # Construct the DFA 1006 dfa = self.as_dfa( with_probs=with_probs ) 1007 1008 # Minimize the DFA 1009 #dfa = dfa.minify(retain_names=True) 1010 dfa = am_fast.minify_cpp( dfa, retain_names=True ) 1011 1012 # check we have minimal number of states 1013 if len( dfa.states ) != len( self.states ) : 1014 if verbose : 1015 print( f"Not minimal reduces from {len( self.states )} to {len( dfa.states )} states" ) 1016 return False 1017 1018 return True 1019 1020 def is_topological_epsilon_machine( self, verbose=True ) : 1021 1022 """ 1023 Checks if the HMM is a topological $\\epsilon$-machine [^1]. 1024 1025 [^1]: Johnson et al, Enumerating Finitary Processes, 2024. 
1026 <https://arxiv.org/abs/1011.0036> 1027 """ 1028 1029 if not ( self.is_unifilar() and self.is_strongly_connected() ) : 1030 if verbose : 1031 print( f"Either non unifilar or not strongly connected" ) 1032 return False 1033 else : 1034 return self._is_minimal_as_dfa( topological_only=True, verbose=verbose ) 1035 1036 def is_epsilon_machine( self, verbose=True ) : 1037 1038 if not ( self.is_unifilar() and self.is_strongly_connected() ) : 1039 if verbose : 1040 print( f"Either non unifilar or not strongly connected" ) 1041 return False 1042 else : 1043 return self._is_minimal_as_dfa( topological_only=False, verbose=verbose ) 1044 1045 #-------------------------------------------------------# 1046 # Properties # 1047 #-------------------------------------------------------# 1048 1049 def minimize(self, retain_names: bool = True, verbose=False): 1050 1051 """ 1052 Minimizes the HMM, resulting in an :math:`\\epsilon-`machine if the HMM 1053 is unifilar and strongly connected. Converts the HMM to a DFA with symbols 1054 labeled jointly with symbols and probabilities, and uses Myhill-Nerode 1055 equivalence for minimization. Relies on `automata_lib` and uses 1056 `automata.fa.dfa.DFA.minify` with `allow_partial=True`, and all states 1057 final. 1058 1059 Args: 1060 retain_names (bool): If `True`, the merged states will be named by their union, e.g. `{s_0, s_1}`, and other states will retain their origion names. Otherwise, they will be relabled `{ '0', '1', ..., 'n-1' }`. 1061 1062 Returns: 1063 1064 automata.fa.dfa.DFA : the resulting DFA. 
1065 """ 1066 1067 if self.is_minimal : 1068 return 1069 1070 start = time.perf_counter() 1071 1072 if not self.is_unifilar(): 1073 raise ValueError( 1074 "DFA minimization is not valid for non-unifilar HMMs" 1075 ) 1076 1077 was_strongly_connected = self.is_strongly_connected() 1078 1079 was_row_stochastic = self.is_row_stochastic() 1080 n_states_before = len(self.states) 1081 1082 dfa = self.as_dfa(with_probs=True) 1083 1084 #min_dfa = self.as_dfa(with_probs=True).minify(retain_names=True) 1085 min_dfa = am_fast.minify_cpp( dfa, retain_names=True ) 1086 1087 # Build lookup from original state index -> CausalState object 1088 orig_state = {i: s for i, s in enumerate(self.states)} 1089 eq_list = list(min_dfa.states) 1090 1091 start_eq = min_dfa.initial_state 1092 1093 # Separate the start state, then sort the rest by the 1094 # smallest original state index inside each equivalence class. 1095 other_eqs = [eq for eq in eq_list if eq != start_eq] 1096 other_eqs.sort(key=lambda eq: min(eq)) 1097 1098 # Recombine so start eq comes first, followed by the sorted remaining classes 1099 eq_list = [start_eq] + other_eqs 1100 # ---------------------------------------------------------- 1101 1102 # Recompute eq_to_idx with the new ordering 1103 eq_to_idx = {eq: i for i, eq in enumerate(eq_list)} 1104 1105 # new_start is now guaranteed to be 0 1106 new_start = 0 1107 1108 # Map each original state index -> its equivalence class 1109 # Guard: minify() silently drops unreachable states 1110 orig_to_eq = {s: eq for eq in min_dfa.states for s in eq} 1111 1112 # Build lookup from original state index -> its transitions 1113 orig_trs = defaultdict(list) 1114 for t in self.transitions: 1115 orig_trs[t.origin_state_idx].append(t) 1116 1117 new_trs = [] 1118 for eq in min_dfa.states: 1119 rep = next(iter(eq)) 1120 origin_idx = eq_to_idx[eq] 1121 for t in orig_trs[rep]: 1122 target_eq = orig_to_eq[t.target_state_idx] 1123 target_idx = eq_to_idx[target_eq] 1124 new_trs.append(Transition( 
1125 origin_state_idx = origin_idx, 1126 target_state_idx = target_idx, 1127 prob = t.prob, 1128 symbol_idx = t.symbol_idx, 1129 )) 1130 1131 members_list = [[orig_state[i] for i in sorted(eq)] for eq in eq_list] # sorted for determinism 1132 1133 # Compute new names 1134 if retain_names: 1135 new_names = [ 1136 "{" + ",".join(str(m.name) for m in members) + "}" if len(members) > 1 1137 else members[0].name 1138 for members in members_list 1139 ] 1140 else: 1141 new_names = [str(j) for j in range(len(eq_list))] 1142 1143 old_name_to_new_name = { 1144 m.name: new_names[j] 1145 for j, members in enumerate(members_list) 1146 for m in members 1147 } 1148 1149 # Build the new states, preserving classes and isomorphs regardless of naming 1150 new_states = [] 1151 for j, (eq, members, name) in enumerate(zip(eq_list, members_list, new_names)): 1152 classes = set().union(*(m.classes for m in members)) 1153 isomorphs = { 1154 old_name_to_new_name.get(iso, iso) 1155 for m in members 1156 for iso in m.isomorphs 1157 if old_name_to_new_name.get(iso, iso) != name 1158 } 1159 new_states.append(CausalState( 1160 name = name, 1161 classes = classes, 1162 isomorphs = isomorphs, 1163 )) 1164 1165 self.set_states(new_states) 1166 self.set_transitions(new_trs) 1167 self.start_state = new_start 1168 1169 if n_states_before == len(new_states) and verbose : 1170 print( f"{n_states_before} state HMM was already minimal.\n" ) 1171 elif verbose : 1172 print( f"Minimized from {n_states_before} to {len(new_states)}\n" ) 1173 1174 if not ( was_strongly_connected == self.is_strongly_connected() ) : 1175 raise RuntimeError( 1176 f"Minimization broke strongly connected" 1177 ) 1178 1179 if not ( was_row_stochastic == self.is_row_stochastic() ) : 1180 raise RuntimeError( 1181 f"Minimization broke row stochasticity" 1182 ) 1183 1184 self.is_minimal = True 1185 1186 #-------------------------------------------------------# 1187 # Modifiers # 1188 
#-------------------------------------------------------#


def collapse_to_largest_strongly_connected_subgraph( self, rename_states=True ) :

    """
    Reduce the machine, in place, to its largest strongly connected component.

    States outside the component and every transition touching them are
    dropped. Where that removes outgoing probability from a surviving state,
    the missing mass is shared equally (additively) among the state's
    remaining outgoing transitions. A machine that is already strongly
    connected is left untouched.

    Args:
        rename_states (bool): If `True`, the surviving states are replaced by
            fresh states named `"0", "1", ...` (classes and isomorphs are not
            carried over).
    """

    # NOTE(review): the extracted source lost its indentation; the rename step
    # is assumed to belong to the "not strongly connected" branch - confirm.

    # get equivalent networkx graph
    G = self.as_digraph()

    # if already strongly connected, nothing to do
    if nx.is_strongly_connected( G ) :
        return

    # Stable sort by size and take the last entry, so ties between equally
    # large components resolve exactly as before.
    components = sorted( nx.strongly_connected_components( G ), key=len )
    component_state_set = components[ -1 ]

    # Old state indices of the component, ascending. The surviving states keep
    # this order, so position in this list is the new state index.
    component_states = sorted( component_state_set )
    new_index = { old : new for new, old in enumerate( component_states ) }

    # transitions that stay entirely inside the component, in original order
    kept = [
        tr for tr in self.transitions
        if tr.origin_state_idx in component_state_set
        and tr.target_state_idx in component_state_set
    ]

    kept_states = [
        CausalState(
            name=s.name,
            classes=s.classes,
            isomorphs=s.isomorphs
        )
        for i, s in enumerate( self.states ) if i in component_state_set
    ]

    # outgoing probability mass and edge count that survive for each state
    out_mass = { s : 0.0 for s in component_states }
    out_count = { s : 0 for s in component_states }
    for tr in kept :
        out_mass[ tr.origin_state_idx ] += tr.prob
        out_count[ tr.origin_state_idx ] += 1

    # Equal additive share of the missing mass per surviving edge. A state with
    # no surviving edge (a single-state component without a self-loop) has
    # nothing to adjust; the previous code divided by zero there.
    adjustment = {}
    for s in component_states :
        missing = 1.0 - out_mass[ s ]
        if out_count[ s ] and abs( missing ) > self.EPS :
            adjustment[ s ] = missing / out_count[ s ]

    new_transitions = [
        Transition(
            origin_state_idx=new_index[ tr.origin_state_idx ],
            target_state_idx=new_index[ tr.target_state_idx ],
            prob=(
                tr.prob + adjustment[ tr.origin_state_idx ]
                if tr.origin_state_idx in adjustment
                else tr.prob
            ),
            symbol_idx=tr.symbol_idx
        )
        for tr in kept
    ]

    self.set_states( states=kept_states )
    self.set_transitions( transitions=new_transitions )

    if rename_states :
        self.set_states( [
            CausalState( name=f"{i}" )
            for i in range( len( self.states ) )
        ] )


def to_q_weighted( self, denominator_limit=1000 ) :

    """
    Approximates the existing transition probabilities with exact fractions, stores
    the fractional probabilities as Fraction in Transition.pq, and sets the floating
    point probability to `float(pq)`. If `denominator_limit` is too small for a sane
    conversion, the function recurses with `denominator_limit=denominator_limit*10`.

    Within each state, any rounding residue is absorbed by the largest
    outgoing probability so that the fractions sum to exactly 1.

    Args:
        denominator_limit (int): The initial input to :meth:`Fraction.limit_denominator` in
            the conversion.

    Raises:
        ValueError: If the machine is not row stochastic.
    """

    if self.is_q_weighted :
        return

    if not self.is_row_stochastic() :
        raise ValueError( "Cannot convert to q-weighted because not row stochastic" )

    # transition indices grouped by origin state
    t_from = [ [] for _ in range( len( self.states ) ) ]
    for i, tr in enumerate( self.transitions ) :
        t_from[ tr.origin_state_idx ].append( i )

    new_transitions = []
    for t_list in t_from :

        if not t_list :
            continue

        p_qs = [
            Fraction( self.transitions[ t_idx ].prob ).limit_denominator( denominator_limit )
            for t_idx in t_list
        ]

        diff = sum( p_qs, Fraction( 0, 1 ) ) - Fraction( 1, 1 )

        if diff != 0 :

            # index of the first largest fraction
            max_pq_i = max( range( len( p_qs ) ), key=p_qs.__getitem__ )

            # The correction must leave the largest probability strictly
            # positive; otherwise retry with a finer resolution.
            if diff >= p_qs[ max_pq_i ] :
                return self.to_q_weighted( denominator_limit*10 )

            p_qs[ max_pq_i ] -= diff

        for p_q, t_idx in zip( p_qs, t_list ) :
            tr = self.transitions[ t_idx ]
            new_transitions.append(
                Transition(
                    origin_state_idx=tr.origin_state_idx,
                    target_state_idx=tr.target_state_idx,
                    prob=float( p_q ),
                    symbol_idx=tr.symbol_idx,
                    pq=p_q
                )
            )

    self.set_transitions( new_transitions )
    self.is_q_weighted = True

#-------------------------------------------------------#
#                    Data Generation                    #
#-------------------------------------------------------#

def isomorphic_shift(
    self,
    input_symbol_indices: np.ndarray,
    input_state_indices: np.ndarray,
    shift : int = 1
) -> dict[str, np.ndarray]:

    """
    Generates a new sequence of symbols that are permuted with the symbols emitted by
    isomorphic states, if they exist.

    :math:`\\sigma_o = \\mathcal{S}\\left[\\texttt{input\\_state\\_indices}[i]\\right]`<br>
    :math:`\\sigma_t = \\mathcal{S}\\left[\\texttt{input\\_state\\_indices}[i+1]\\right]`

    :math:`\\mathcal{I}(\\sigma_o) = \\\\{\\sigma^0_o,\\, \\sigma^1_o,\\, \\dots,\\, \\sigma^{n-1}_o \\\\}`<br>
    :math:`\\mathcal{I}(\\sigma_t) = \\\\{\\sigma^0_t,\\, \\sigma^1_t,\\, \\dots,\\, \\sigma^{n-1}_t \\\\}`

    :math:`k = \\bigl(\\mathcal{I}(\\sigma_o).\\texttt{index}(\\sigma_o) + \\texttt{shift}\\bigr) \\bmod n`

    :math:`\\texttt{output\\_symbol\\_indices}[i] := T(\\sigma_o^k,\\, \\sigma_t^k).\\text{symbol\\_index}`<br>
    :math:`\\texttt{output\\_state\\_indices}[i] := \\mathcal{S}.\\texttt{index}(\\sigma_o^k)`<br>
    :math:`\\texttt{output\\_state\\_indices}[i+1] := \\mathcal{S}.\\texttt{index}(\\sigma_t^k)`

    Where :math:`\\mathcal{I}(\\sigma)` is the ordered set of states isomorphic to :math:`\\sigma` including :math:`\\sigma` itself.

    Args:
        input_symbol_indices (np.ndarray): The sequence of generated symbols.
        input_state_indices (np.ndarray): The sequence of states that generated symbols with the final state at the end.
        shift (int): How much to shift the symbols across the isomorphic states.

    Returns:
        dict[str, np.ndarray]: `"symbol_index"` (cast to the input symbol
        dtype) and `"state_index"` (in the input state dtype). A symbol of
        `-1` marks a shifted state pair that has no transition.

    Raises:
        ValueError: If no state of the machine lists any isomorphs.
    """

    if not any( state.isomorphs for state in self.states ):
        raise ValueError("HMM has no states with isomorphs")

    inputs = np.asarray(input_symbol_indices)
    states = np.asarray(input_state_indices)

    n_states = len(self.states)

    # (origin, target) -> emitted symbol index; -1 where no transition exists.
    # If several symbols connect the same pair, the last one listed wins.
    tr_sym_table = np.full((n_states, n_states), -1, dtype=np.int32)
    for tr in self.transitions:
        tr_sym_table[ tr.origin_state_idx, tr.target_state_idx ] = tr.symbol_idx

    # State remapping: identity by default, a cyclic shift within each
    # isomorph group where isomorphs exist.
    iso_table = np.arange(n_states, dtype=np.int32)

    for i, state in enumerate(self.states):
        if not state.isomorphs:
            continue

        # the group, ordered by state index, including the state itself
        group = sorted( [ i ] + [ self.state_idx_map[iso] for iso in state.isomorphs ] )
        iso_table[ i ] = group[ ( group.index( i ) + shift ) % len( group ) ]

    # Shift every visited state at once; this also handles a state sequence
    # with a single entry, which previously raised IndexError.
    shifted_states = iso_table[ states ].astype( states.dtype )

    out_origins = iso_table[ states[:-1] ]
    out_targets = iso_table[ states[1:] ]
    shifted_symbols = tr_sym_table[ out_origins, out_targets ]

    return {
        "symbol_index": shifted_symbols.astype(inputs.dtype),
        "state_index": shifted_states,
    }

def generate_data(
    self,
    file_prefix: str,
    n_gen: int,
    include_states: bool,
    isomorphic_shifts : set[int] | None = None,
    random_seed : int=42 ) -> dict[str, any] :

    """
    Generate a symbol sequence from the machine, optionally add isomorphically
    shifted variants, and save the result through `am_fast.save_data`.

    Args:
        file_prefix (str): Prefix handed to `am_fast.save_data`.
        n_gen (int): Number of steps handed to `am_fast.generate_data`.
        include_states (bool): Whether the state sequence is generated too.
        isomorphic_shifts (set[int] | None): Shifts for which
            :meth:`isomorphic_shift` variants are added under
            `data["isomorphic_shifts"]`. Requires `include_states=True`.
        random_seed (int): Seed handed to the generator.

    Returns:
        dict: The data returned by `am_fast.generate_data`, extended with the
        shifted variants when requested.

    Raises:
        ValueError: If shifts are requested without `include_states`.
    """

    # Fail before the (potentially long) generation rather than after it.
    if isomorphic_shifts is not None and not include_states :
        raise ValueError( "Isomorphic inversion requires include_states=True" )

    # per-origin adjacency: (symbol index, probability, target state index)
    trs = [ [] for _ in range( len( self.states ) ) ]
    for tr in self.transitions :
        trs[ tr.origin_state_idx ].append( (
            tr.symbol_idx,
            float( tr.prob ),
            tr.target_state_idx ) )

    alphabet = sorted( self.alphabet )

    data = am_fast.generate_data(
        n_gen=n_gen,
        start_state=self.start_state,
        transitions=trs,
        alphabet=alphabet,
        include_states=include_states,
        random_seed=random_seed
    )

    if isomorphic_shifts is not None :

        data[ "isomorphic_shifts" ] = {}

        for shift in isomorphic_shifts :

            # Best effort: a failing shift is reported and skipped so the
            # remaining shifts and the base data are still saved.
            try :

                shifted = self.isomorphic_shift(
                    input_symbol_indices=data[ "symbol_index" ],
                    input_state_indices=data[ "state_index" ],
                    shift=shift
                )

                data[ "isomorphic_shifts" ][ shift ] = {
                    "symbol_index" : shifted[ "symbol_index" ],
                    "state_index" : shifted[ "state_index" ]
                }

            except Exception as e :
                print( f"Exception {e}" )

    am_fast.save_data(
        data=data,
        file_prefix=file_prefix,
        alphabet=alphabet,
        n_states=len( self.states ),
        start_state=self.start_state,
        random_seed=random_seed,
        machine_metadata=self.get_metadata() )

    return data

#-------------------------------------------------------#
#                  Basic Visualization                  #
#-------------------------------------------------------#

def draw_graph(
    self,
    engine : str = 'dot',
    output_dir : Path = Path('.'),
    show : bool = True
) -> None :

    """
    Draws the machine via `am_vis.draw_graph` and saves it.

    If the machine is not strongly connected, its strongly connected
    components are passed along as subgraphs.

    Args:
        engine (str): Layout engine name forwarded to `am_vis.draw_graph`.
        output_dir (Path): Directory forwarded to `am_vis.draw_graph`.
        show (bool): Forwarded as `view`.
    """

    G = self.as_digraph()

    if nx.is_strongly_connected( G ) :
        subgraphs = None
    else :
        subgraphs = list( nx.strongly_connected_components( G ) )

    am_vis.draw_graph(
        self,
        output_dir=output_dir,
        title="am_graph",
        view=show,
        subgraphs=subgraphs,
        engine=engine )

#-------------------------------------------------------#
#                    Alternate Forms                    #
#-------------------------------------------------------#


def as_digraph( self ) -> nx.DiGraph :

    """
    Builds a [networkx.DiGraph](https://networkx.org/documentation/stable/reference/classes/digraph.html) constructed from the machine's transitions with no edge symbols or weights.

    Nodes are state indices; parallel transitions collapse into one edge.

    Returns:

        networkx.DiGraph : the resulting graph.
    """

    G = nx.DiGraph()
    G.add_nodes_from( range( len( self.states ) ) )
    G.add_edges_from(
        ( tr.origin_state_idx, tr.target_state_idx ) for tr in self.transitions
    )

    return G

def as_dfa( self, with_probs : bool ) :

    """
    Builds an [automata.fa.dfa.DFA](https://caleb531.github.io/automata/api/fa/class-dfa/) constructed from the machine's transitions.

    Every state is final and the DFA is partial (`allow_partial=True`).

    Args:
        with_probs (bool): If true, the DFA transitions are labeled with
            the symbol of the machine's transition combined with its
            probability (rounded to 8 decimals); otherwise, only with the symbol.

    Returns:

        automata.fa.dfa.DFA : the resulting DFA.
    """

    precision = 8

    def edge_label( tr ) :
        if with_probs :
            return f"({tr.symbol_idx},{round(tr.prob, precision)})"
        return tr.symbol_idx

    dfa_states = set( range( len( self.states ) ) )
    dfa_symbols = set()
    dfa_transitions = defaultdict(dict)

    for t in self.transitions :
        label = edge_label( t )
        dfa_symbols.add( label )
        dfa_transitions[ t.origin_state_idx ][ label ] = t.target_state_idx

    # Construct the DFA
    return DFA(
        states=dfa_states,
        input_symbols=dfa_symbols,
        transitions=dfa_transitions,
        initial_state=self.start_state,
        allow_partial=True,
        final_states=set( dfa_states )
    )
Hidden Markov model implementing epsilon machines, mixed state presentations, complexity measures, and data generation.
Arguments:
- states (list[CausalState] | None): A list of causal states.
- transitions (list[Transition] | None): A list of transitions between states.
- start_state (int): Index of the start state.
- alphabet (list[str]): List of symbols making up the alphabet.
- name (str): Name of the model.
- description (str): Description of the model.
def __init__(
    self,
    states : list[CausalState] | None = None,
    transitions : list[Transition] | None = None,
    start_state : int = 0,
    alphabet : list[str] | None = None,
    name : str = "",
    description : str = "" ) :

    """Initialise the machine from its states, transitions and alphabet.

    Args:
        states (list[CausalState] | None): A list of causal states.
        transitions (list[Transition] | None): A list of transitions between states.
        start_state (int): Index of the start state.
        alphabet (list[str] | None): List of symbols making up the alphabet.
        name (str): Name of the model.
        description (str): Description of the model.
    """

    # --- const ----------

    self.EPS : float = 1e-12

    self.name : str = name
    self.description : str = description

    # To be deprecated
    self.isoclass : str = None

    # --- derived --------
    # These defaults must exist BEFORE the setters below run: the setters fill
    # the index maps, and resetting the maps afterwards (as was done before)
    # left them empty on a freshly constructed machine.

    self.complexity : dict[str, any] = {}

    self.symbol_idx_map : dict[str,int] = {}
    self.state_idx_map : dict[str,int] = {}

    self.pi_fractional = None
    self.pi : np.ndarray | None = None
    self.T : np.ndarray | None = None
    self.T_x : list[np.ndarray] | None = None   # per-symbol matrices cache, read by get_T_X
    self.msp : MSP | None = None
    self.reverse_am : HMM | None = None
    self.is_q_weighted : bool = False
    self.is_minimal : bool = False

    # --- primary data ---

    self.alphabet : list[str] | None = alphabet or []
    self.states : list[CausalState] | None = states or []
    # Copied because set_alphabet rewrites entries of self.transitions in
    # place, which would otherwise modify the caller's list.
    self.transitions : list[Transition] | None = list( transitions ) if transitions else []

    self.set_alphabet( self.alphabet )
    self.set_states( self.states )
    self.set_transitions( self.transitions )

    self.start_state : int = start_state
def to_dict(self) -> dict[str,any]:

    """Serialise the HMM configuration into a plain dictionary.

    States and transitions are converted with `dataclasses.asdict`; the
    alphabet list is included as-is (not copied).

    Returns:

        dict[str,any]: Dictionary with the keys `name`, `description`,
        `states`, `transitions`, `alphabet` and `isoclass`.
    """

    serialised_states = list( map( asdict, self.states ) )
    serialised_transitions = list( map( asdict, self.transitions ) )

    config = dict(
        name=self.name,
        description=self.description,
        states=serialised_states,
        transitions=serialised_transitions,
        alphabet=self.alphabet,
        isoclass=self.isoclass,
    )

    return config
Create a dict representing the HMM configuration.
Returns:
dict[str,any]: Dictionary containing name, description, states, transitions, alphabet, and isoclass.
def from_dict( self, config : dict[str,any] ) :

    """Configure the HMM from a configuration mapping.

    Reads the keys `name`, `description`, `alphabet`, `states` and
    `transitions`, i.e. the layout produced by :meth:`to_dict`. Transition
    `symbol_idx` values are interpreted as positions in `config["alphabet"]`
    and re-indexed against the machine's sorted alphabet.

    Args:
        config (dict[str,any]): The HMM configuration. Attribute-style
            configuration objects (`config.name`, ...) are accepted as well.
    """

    def entry( key ) :
        # Plain dicts need item access; objects that are not subscriptable
        # fall back to the attribute access the previous code used.
        try :
            return config[ key ]
        except TypeError :
            return getattr( config, key )

    def restore_state( state ) :
        fields = {
            "name" : state[ "name" ],
            "classes" : set( state[ "classes" ] ),
        }
        # to_dict serialises isomorphs too; restore them when present so a
        # save/load round trip does not lose them.
        if "isomorphs" in state and state[ "isomorphs" ] is not None :
            fields[ "isomorphs" ] = set( state[ "isomorphs" ] )
        return CausalState( **fields )

    self.name = entry( "name" )
    self.description = entry( "description" )

    config_alphabet = list( entry( "alphabet" ) )

    # set_alphabet re-indexes whatever transitions are currently stored using
    # the PREVIOUS alphabet. Drop the stale transitions first and install the
    # alphabet before the new transitions, so nothing is re-indexed against
    # the wrong symbol list.
    self.transitions = []
    self.set_alphabet( alphabet=config_alphabet )

    self.set_states( states=[
        restore_state( state ) for state in entry( "states" )
    ] )

    self.set_transitions( transitions=[
        Transition(
            origin_state_idx=tr[ "origin_state_idx" ],
            target_state_idx=tr[ "target_state_idx" ],
            prob=tr[ "prob" ],
            # position in the config alphabet -> position in the sorted alphabet
            symbol_idx=self.symbol_idx_map[ config_alphabet[ tr[ "symbol_idx" ] ] ]
        )
        for tr in entry( "transitions" )
    ] )
Configure the HMM from a dictionary.
Arguments:
- config (dict[str,any]): The HMM configuration.
def save_config(
    self,
    path : Path,
    with_complexity : bool = False,
    with_block_convergence : bool = False,
    with_structural_properties : bool = False,
    with_causal_properties : bool = False ) :

    """Write the machine configuration to `<path>/am_config.json`.

    The JSON holds the output of :meth:`to_dict`, a `structural_properties`
    section, and optionally a `complexity` section. Values that are not JSON
    serialisable are written as lists (`default=list`).

    Args:
        path (Path): Existing directory to write into; a `str` is accepted too.
        with_complexity (bool): Add the result of :meth:`get_complexities`.
        with_block_convergence (bool): Forwarded to :meth:`get_complexities`.
        with_structural_properties (bool): Currently not consulted.
        with_causal_properties (bool): Currently not consulted.
    """

    # NOTE(review): the extracted source lost its indentation. The structural
    # properties are written unconditionally here; confirm they were not meant
    # to be nested under `with_complexity` or gated by
    # `with_structural_properties`.

    config = self.to_dict()

    if with_complexity :
        config[ "complexity" ] = self.get_complexities(
            with_block_convergence=with_block_convergence
        )

    config[ "structural_properties" ] = {
        "unifilar" : self.is_unifilar(),
        "row_stochastic" : self.is_row_stochastic(),
        "strongly_connected" : self.is_strongly_connected(),
        "aperiodic" : self.is_aperiodic(),
        "minimal" : self._is_minimal_as_dfa( topological_only=False ),
        "topologically_minimal" : self._is_minimal_as_dfa( topological_only=True ),
        "is_epsilon_machine" : self.is_epsilon_machine()
    }

    # Path() is a no-op for Path inputs and lets plain string paths work.
    target = Path( path ) / "am_config.json"

    with open( target, "w", encoding="utf-8" ) as f :
        json.dump( config, f, ensure_ascii=False, indent=2, default=list )
def set_alphabet( self, alphabet : list[str] ) :

    """Replace the alphabet and re-index the stored transitions.

    The alphabet is de-duplicated and sorted, `symbol_idx_map` is rebuilt from
    it, and every stored transition has its `symbol_idx` translated from the
    previous alphabet's ordering to the new one.

    Args:
        alphabet (list[str]): The new symbols (any order, duplicates allowed).

    Raises:
        KeyError: If a stored transition uses a symbol that is missing from
            the new alphabet.
        IndexError: If a stored transition's `symbol_idx` is outside the
            previous alphabet.
    """

    # Local import: the module only imports dataclass/asdict/field.
    from dataclasses import replace

    self._invalidate()

    old_alphabet = self.alphabet.copy()

    self.alphabet = sorted( set( alphabet ) )
    self.symbol_idx_map = {
        symbol : idx for idx, symbol in enumerate( self.alphabet )
    }

    for i, tr in enumerate( self.transitions ) :
        symbol = old_alphabet[ tr.symbol_idx ]
        # replace() copies every other field, so an exact fraction stored in
        # `pq` survives; rebuilding field by field used to drop it.
        self.transitions[ i ] = replace(
            tr,
            symbol_idx=self.symbol_idx_map[ symbol ]
        )
def get_complexities(
    self,
    with_block_convergence=False ) :

    """Collect the machine's complexity measures into a dictionary.

    `C_mu`, `h_mu`, `H_1` and `rho_mu` are always evaluated. With
    `with_block_convergence=True`, `E`, `T_inf`, `S` and `chi` are evaluated
    as well, and any of the series `H_L`, `T_L`, `h_mu_L`, `H_sync` already
    stored in `self.complexity` are copied over.

    Args:
        with_block_convergence (bool): Include the second group of measures.

    Returns:
        dict: Measure name (the method's `__name__`) mapped to its value.
    """

    closed_form = ( self.C_mu, self.h_mu, self.H_1, self.rho_mu )
    convergent = ( self.E, self.T_inf, self.S, self.chi )

    report = {}
    for measure in closed_form :
        report[ measure.__name__ ] = measure()

    if not with_block_convergence :
        return report

    for measure in convergent :
        report[ measure.__name__ ] = measure()

    for key in ( 'H_L', 'T_L', 'h_mu_L', 'H_sync' ) :
        if key in self.complexity :
            report[ key ] = self.complexity[ key ]

    return report
def get_transition_matrix(self) :

    """Return the state-to-state transition matrix, building it on first use.

    Entry ``T[i, j]`` is the total probability of moving from state ``i`` to
    state ``j``, summed over every transition between the two states. Several
    transitions (one per emitted symbol) may connect the same pair, so the
    probabilities are accumulated rather than overwritten. The result is
    cached in ``self.T``.

    Returns:
        np.ndarray: Matrix of shape ``(n_states, n_states)``.
    """

    if self.T is not None :
        return self.T

    n_states = len( self.states )
    T = np.zeros((n_states, n_states))

    for tr in self.transitions :
        # += : parallel edges on different symbols must add up
        T[ tr.origin_state_idx, tr.target_state_idx ] += tr.prob

    self.T = T

    return self.T
def get_T_X(self) :

    """Return the symbol-labeled transition matrices, building them on first use.

    For each symbol index ``x`` the list holds a matrix whose entry
    ``[i, j]`` is the probability of the transition from state ``i`` to
    state ``j`` that emits symbol ``x``. The result is cached in ``self.T_x``.

    Returns:
        list[np.ndarray]: One ``(n_states, n_states)`` matrix per symbol, in
        alphabet order.
    """

    # The attribute may not exist yet (it is not set in __init__), so do
    # not assume it is there.
    cached = getattr( self, "T_x", None )
    if cached is not None :
        return cached

    n_states = len( self.states )
    n_symbols = len( self.alphabet )

    T_x = [ np.zeros((n_states, n_states)) for _ in range( n_symbols ) ]

    for tr in self.transitions :
        T_x[ tr.symbol_idx ][tr.origin_state_idx, tr.target_state_idx] = tr.prob

    self.T_x = T_x
    return self.T_x
def get_msp_qw(
    self,
    exact_state_cap: int = 1000,
    verbose: bool = True,
):
    """Return the mixed state presentation, preferring exact fractions.

    The exact computation (``compute_msp_exact``) is tried first on the
    fractional transition matrices and stationary distribution. If it
    raises ``RuntimeError``, a warning is issued and the floating point
    approximation from :meth:`get_msp` is used instead. The result is
    cached in ``self.msp``.

    Args:
        exact_state_cap (int): State cap forwarded to the exact computation.
        verbose (bool): Verbosity flag forwarded to the computation.

    Returns:
        MSP: The mixed state presentation.
    """

    if self.msp is not None:
        return self.msp

    try :

        print( "\nTrying to Compute Mixed State Presentation using Exact Fractions\n" )

        # Forward the caller's arguments instead of hard-coded values.
        # NOTE(review): assumes compute_msp_exact accepts `verbose` like
        # compute_msp does - confirm against am_msp.
        self.msp = compute_msp_exact(
            T_x=self.get_Tx_fractional(),
            pi=self.get_fractional_stationary_distribution(),
            n_states=len(self.states),
            alphabet=self.alphabet,
            exact_state_cap=exact_state_cap,
            verbose=verbose
        )

        return self.msp

    except RuntimeError as e :
        warnings.warn( f"Exact msp failed: {e} Falling back to msp approximation." )

    return self.get_msp( verbose=verbose )
def get_msp(
    self,
    exact_state_cap: int = 175_000,
    jsd_eps: float = 1e-7,
    k_ann: int = 50,
    verbose = True,
) :

    """Return the mixed state presentation computed in floating point.

    Calls ``compute_msp`` with the symbol-labeled transition matrices and
    the stationary distribution, and caches the result in ``self.msp``.

    Args:
        exact_state_cap (int): State cap forwarded to ``compute_msp``.
        jsd_eps (float): Currently unused by this method.
        k_ann (int): Currently unused by this method.
        verbose (bool): Verbosity flag forwarded to ``compute_msp``.

    Returns:
        MSP: The mixed state presentation.
    """

    if self.msp is not None:
        return self.msp

    T_x = self.get_T_X()
    pi = self.get_stationary_distribution()

    print( "\nComputing Mixed State Presentation..." )

    self.msp = compute_msp(
        T_x=T_x,
        pi=pi,
        n_states=len(self.states),
        alphabet=self.alphabet,
        exact_state_cap=exact_state_cap,
        verbose=verbose
    )

    return self.msp
def get_reverse_am(self) :

    """Return the time-reversed machine, building it on first use.

    Every transition ``j -> i`` with probability ``p`` is replaced by
    ``i -> j`` with probability ``pi[j] * p / pi[i]``, where ``pi`` is the
    stationary distribution. If the reversed machine is not already an
    epsilon machine, it is replaced by its mixed state presentation,
    collapsed to the largest strongly connected component, and minimized.

    The machine is assembled in a local variable and only cached in
    ``self.reverse_am`` once construction has succeeded, so a failure
    part-way through does not leave a half-built machine in the cache.

    Returns:
        HMM: The reverse machine.
    """

    if self.reverse_am is not None:
        return self.reverse_am

    pi = self.get_stationary_distribution()
    reverse = copy.deepcopy(self)

    new_transitions = []
    for tr in self.transitions:
        i = tr.target_state_idx
        j = tr.origin_state_idx

        # reversed edge probability; pi[i] is the mass of the state the
        # reversed edge leaves from
        p_reversed = (pi[j] * tr.prob) / pi[i]

        new_transitions.append(
            Transition(
                origin_state_idx=i,
                target_state_idx=j,
                prob=p_reversed,
                symbol_idx=tr.symbol_idx
            )
        )

    reverse.set_transitions(new_transitions)

    if not reverse.is_epsilon_machine():

        rmsp = reverse.get_msp_qw( exact_state_cap=len(self.states)*4 )

        reverse.set_states( rmsp.states )
        reverse.set_transitions( rmsp.transitions )
        reverse.msp = rmsp
        reverse.start_state = 0

        reverse.collapse_to_largest_strongly_connected_subgraph()
        reverse.minimize()

    self.reverse_am = reverse

    return self.reverse_am
def get_Tx_fractional(self) -> list[ list[ list[ Fraction ] ] ] :

    """Return the symbol-labeled transition matrices as nested lists of fractions.

    The machine is first converted with :meth:`to_q_weighted`. The result is
    indexed ``[symbol][origin][target]``; entries without a transition are
    the integer ``0``, all others are the transition's ``pq`` value.

    Returns:
        list[list[list[Fraction]]]: One square table per symbol.
    """

    self.to_q_weighted()

    size = len( self.states )

    tables = [
        [ [ 0 for _ in range( size ) ] for _ in range( size ) ]
        for _ in range( len( self.alphabet ) )
    ]

    for tr in self.transitions :
        row = tables[ tr.symbol_idx ][ tr.origin_state_idx ]
        row[ tr.target_state_idx ] = tr.pq

    return tables
def get_fractional_stationary_distribution(self) :

    """Return the stationary distribution solved from the exact transition matrix.

    The result is cached in ``self.pi_fractional``. The cache is consulted
    before anything else so that a cached call does not rebuild the
    matrix from :meth:`get_T_sympy`.

    Returns:
        The value produced by ``solve_for_pi_fractional``.

    Raises:
        ValueError: If the machine is not strongly connected.
    """

    if self.pi_fractional is not None :
        return self.pi_fractional

    G = self.as_digraph()

    if not nx.is_strongly_connected(G):
        raise ValueError( "Single stationary distribution requires strongly connected HMM." )

    T = self.get_T_sympy()

    self.pi_fractional = solve_for_pi_fractional( T )

    return self.pi_fractional
def get_stationary_distribution(self):

    """Return the stationary distribution over states.

    The distribution is solved from the transition matrix with
    ``solve_for_pi`` and cached in ``self.pi`` so later calls reuse it.

    Returns:
        The value produced by ``solve_for_pi``.

    Raises:
        ValueError: If the machine is not strongly connected.
    """

    if self.pi is not None :
        return self.pi

    G = self.as_digraph()

    if not nx.is_strongly_connected(G):
        raise ValueError( "Single stationary distribution requires strongly connected HMM." )

    T = self.get_transition_matrix()

    # store the result; otherwise the cache check above can never succeed
    self.pi = solve_for_pi( T )

    return self.pi
def C_mu( self ) :

    """The *statistical complexity* (aka *forecasting complexity*) :

    .. math::

        C_{\\mu} = - \\sum_{\\sigma \\in \\mathcal{S}} \\Pr(\\sigma) \\log_2 \\Pr(\\sigma),

    where :math:`\\mathcal{S}` are the machine's states [^crutchfield_exact_2016], p.2.

    .. note::

        **Interpretations**

        * The amount of historical information a process stores.
        * The amount of structure in a process.

    The value is read from the stored complexity measures when present and
    stored there after it is computed.

    Returns:

        float: :math:`C_{\\mu}`.

    [^crutchfield_exact_2016]: Crutchfield et al., *Exact Complexity: The Spectral
        Decomposition of Intrinsic Computation*, 2016.
        <https://arxiv.org/abs/1309.3792>
    """

    stored = self.get_complexity_measure_if_exists( "C_mu" )
    if stored is not None :
        return stored

    entropy = 0
    for pr in self.get_stationary_distribution() :
        # states with (numerically) zero mass contribute nothing
        if not pr < self.EPS :
            entropy += -pr * np.log2( pr )

    self.set_complexity_measure( "C_mu", entropy )

    return entropy
The statistical complexity (aka forecasting complexity) :
$$C_{\mu} = - \sum_{\sigma \in \mathcal{S}} \Pr(\sigma) \log_2 \Pr(\sigma),$$
where \( \mathcal{S} \) are the machine's states 1, p.2.
Interpretations
- The amount of historical information a process stores.
- The amount of structure in a process.
Returns:
float: \( C_{\mu} \).
-
Crutchfield et al., Exact Complexity: The Spectral Decomposition of Intrinsic Computation, 2016. https://arxiv.org/abs/1309.3792 ↩
def h_mu( self ) :

    """The *entropy rate* :

    .. math::

        h_{\\mu}(\\boldsymbol{\\mathcal{S}}) = - \\sum_{\\sigma \\in \\mathcal{S}} \\Pr(\\sigma) \\sum_{x \\in \\mathcal{A}} \\Pr(x|\\sigma) \\log_2 \\Pr(x|\\sigma),

    where :math:`\\mathcal{A}` is the alphabet and :math:`\\mathcal{S}` are the machine's states [^crutchfield_exact_2016], p.2.

    :math:`\\Pr(x|\\sigma)` is taken from the symbol-labeled matrices of
    :meth:`get_T_X` (row sum of the matrix for symbol :math:`x`), so the
    result is also correct when several symbols lead to the same state.

    .. note::

        **Interpretations**

        * The lower bound on achievable loss in bits.
        * The irreducible randomness in the process.
        * The intrinsic randomness in the process.

    The value is read from the stored complexity measures when present and
    stored there after it is computed.

    Returns:

        float: :math:`h_{\\mu}`.

    [^crutchfield_exact_2016]: Crutchfield et al., *Exact Complexity: The Spectral
        Decomposition of Intrinsic Computation*, 2016.
        <https://arxiv.org/abs/1309.3792>
    """

    m = self.get_complexity_measure_if_exists( "h_mu" )

    if m is not None :
        return m

    T_X = self.get_T_X()
    pi = self.get_stationary_distribution()

    h = 0.0
    for i, pr in enumerate( pi ) :

        if pr < self.EPS :
            continue

        # entropy of the symbol emitted from state i
        state_entropy = 0.0
        for T_x in T_X :

            p_emit = T_x[ i, : ].sum()

            if p_emit < self.EPS :
                continue

            state_entropy -= p_emit * np.log2( p_emit )

        h += pr * state_entropy

    self.set_complexity_measure( "h_mu", h )

    return h
The entropy rate :
$$h_{\mu}(\boldsymbol{\mathcal{S}}) = - \sum_{\sigma \in \mathcal{S}} \Pr(\sigma) \sum_{x \in \mathcal{A}} \Pr(x|\sigma) \log_2 \Pr(x|\sigma),$$
where \( \mathcal{A} \) is the alphabet and \( \mathcal{S} \) are the machine's states 1, p.2.
Interpretations
- The lower bound on achievable loss in bits.
- The irreducible randomness in the process.
- The intrinsic randomness in the process.
Returns:
float: \( h_{\mu} \).
-
Crutchfield et al., Exact Complexity: The Spectral Decomposition of Intrinsic Computation, 2016. https://arxiv.org/abs/1309.3792 ↩
def H_1(self) -> float :

    """The *single symbol uncertainty*:

    .. math::

        H(1)=-\\sum_{x\\in\\mathcal{A}} \\Pr(x) \\log_2{\\Pr(x)},

    where :math:`\\mathcal{A}` is the alphabet [^James_2018], p.2.

    .. note::

        **Interpretations**

        * How uncertain you are on average about a single measurement with no context.

    The value is read from the stored complexity measures when present and
    stored there after it is computed.

    Returns:

        float: :math:`H(1)`.

    [^James_2018]: James et al., Anatomy of a Bit: Information in a Time Series Observation, 2018.
        <https://arxiv.org/abs/1105.2988>
    """

    stored = self.get_complexity_measure_if_exists("H_1")
    if stored is not None:
        return stored

    pi = self.get_stationary_distribution()
    symbol_matrices = self.get_T_X()  # one matrix per symbol index

    entropy = 0.0
    for matrix in symbol_matrices:

        # Pr(x) = sum_i pi[i] * sum_j T^(x)[i,j]
        p_sym = 0.0
        for i, pr in enumerate(pi):
            if not pr < self.EPS:
                p_sym += pr * matrix[i, :].sum()

        if not p_sym < self.EPS:
            entropy -= p_sym * np.log2(p_sym)

    self.set_complexity_measure("H_1", entropy)
    return entropy
The single symbol uncertainty:
$$H(1)=-\sum_{x\in\mathcal{A}} \Pr(x) \log_2{\Pr(x)},$$
where \( \mathcal{A} \) is the alphabet 1, p.2.
Interpretations
- How uncertain you are on average about a single measurement with no context.
Returns:
float: \( H(1) \).
-
James et al., Anatomy of a Bit: Information in a Time Series Observation, 2018. https://arxiv.org/abs/1105.2988 ↩
def rho_mu(self) -> float :

    """The *anticipated information* [^James_2018], p.3.:

    .. math::

        \\rho_{\\mu}= H(1) - h_{\\mu}

    Computed from :meth:`H_1` and :meth:`h_mu`. The value is read from the
    stored complexity measures when present and stored there after it is
    computed.

    Returns:

        float: :math:`\\rho_{\\mu}`

    [^James_2018]: James et al., Anatomy of a Bit: Information in a Time Series Observation, 2018.
        <https://arxiv.org/abs/1105.2988>
    """

    stored = self.get_complexity_measure_if_exists("rho_mu")
    if stored is not None:
        return stored

    single_symbol = self.H_1()
    entropy_rate = self.h_mu()
    anticipated = single_symbol - entropy_rate

    self.set_complexity_measure("rho_mu", anticipated)

    return anticipated
The anticipated information 1, p.3.:
$$\rho_{\mu}= H(1) - h_{\mu}$$
Returns:
float: \( \rho_{\mu} \)
-
James et al., Anatomy of a Bit: Information in a Time Series Observation, 2018. https://arxiv.org/abs/1105.2988 ↩
def block_convergence( self ) :

    """
    Run block entropy convergence and store its results as complexity measures.

    Calls ``am_fast.block_entropy_convergence`` and stores, from the returned
    object, ``E``, ``S``, ``T`` (as ``"T_inf"``), and the list forms of
    ``T_L``, ``H_L``, ``h_mu_L`` and ``H_sync``.

    Returns:
        The object returned by ``am_fast.block_entropy_convergence``.
    """

    n_states = len( self.states )

    # outgoing transitions per state as (symbol, probability, target) tuples
    outgoing = [ [] for _ in range( n_states ) ]
    for tr in self.transitions :
        outgoing[ tr.origin_state_idx ].append(
            ( tr.symbol_idx, float( tr.prob ), tr.target_state_idx ) )

    pi = self.get_stationary_distribution()

    state_dist = [ float( pi[ i ] ) for i in range( n_states ) ]
    branches = [ ( 1.0, list( state_dist ) ) ]

    print( "\nComputing Block Entropy\n" )

    C = am_fast.block_entropy_convergence(
        h_mu            = self.h_mu(),
        n_states        = n_states,
        n_symbols       = len( self.alphabet ),
        convergence_tol = 1e-6,
        precision       = 10,
        eps             = 1e-25,
        branches        = branches,
        trans           = outgoing,
        max_branches    = 30_000_000
    )

    print( "Done\n" )

    self.set_complexity_measure( "E", C.E )
    self.set_complexity_measure( "S", C.S )
    self.set_complexity_measure( "T_inf", C.T )

    for key in ( "T_L", "H_L", "h_mu_L", "H_sync" ) :
        self.set_complexity_measure( key, getattr( C, key ).tolist() )

    return C
Run block entropy convergence and return the result object carrying $\mathbf{E}$, $\mathbf{S}$, $\mathbf{T}$, $\mathbf{T}(L)$, $\mathcal{H}(L)$, and $h_{\mu}(L)$.
def E( self ) -> float :

    """The *excess entropy* [^crutchfield_exact_2016], p.4:

    .. math::

        \\mathbf{E} \\equiv I[X_{-\\infty:0}; X_{0:\\infty}] = \\sum_{L=1}^{\\infty} \\left[ h_{\\mu}(L) - h_{\\mu} \\right]

    Computed via :meth:`get_msp` and :meth:`amachine.am_msp.MSP.get_E_S_T`, or :meth:`amachine.am_fast.block_entropy_convergence`

    .. note::

        **Interpretations**

        * The information from the past that reduces uncertainty in the future [^crutchfield_exact_2016].
        * How much information an observer must extract to synchronize to the process.
        * Measures how long the process appears more complex than it asymptotically is.
        * Vanishes for immediately synchronizable processes.

    If the mixed state presentation route raises, the value comes from
    :meth:`block_convergence` instead.

    Returns:

        float: :math:`\\mathbf{E}`

    [^crutchfield_exact_2016]: Crutchfield et al., *Exact Complexity: The Spectral
        Decomposition of Intrinsic Computation*, 2016.
        <https://arxiv.org/abs/1309.3792>
    """

    stored = self.get_complexity_measure_if_exists( "E" )
    if stored is not None :
        return stored

    try :
        # the MSP yields all three quantities at once; keep them all
        E, S, T = self.get_msp().get_E_S_T()
        self.set_complexity_measure( "E", E )
        self.set_complexity_measure( "S", S )
        self.set_complexity_measure( "T_inf", T )

    except Exception as e :

        print( f"MSP failed {e}" )

        E = self.block_convergence().E
        self.set_complexity_measure( "E", E )

    return E
The excess entropy 1, p.4:
$$\mathbf{E} \equiv I[X_{-\infty:0}; X_{0:\infty}] = \sum_{L=1}^{\infty} \left[ h_{\mu}(L) - h_{\mu} \right]$$
Computed via get_msp() and amachine.am_msp.MSP.get_E_S_T(), or amachine.am_fast.block_entropy_convergence()
Interpretations
- The information from the past that reduces uncertainty in the future 1.
- How much information an observer must extract to synchronize to the process.
- Measures how long the process appears more complex than it asymptotically is.
- Vanishes for immediately synchronizable processes.
Returns:
float: \( \mathbf{E} \)
-
Crutchfield et al., Exact Complexity: The Spectral Decomposition of Intrinsic Computation, 2016. https://arxiv.org/abs/1309.3792 ↩
def S( self ) -> float :

    """The *synchronization* information:

    .. math::

        \\mathbf{S} \\equiv \\sum_{L=1}^{\\infty} \\mathcal{H}(L),

    where :math:`\\mathcal{H}(L)` is the average state uncertainty having seen all length-L words [^crutchfield_exact_2016], p.4.

    .. note::

        **Interpretations**

        * The total amount of state information that an observer must extract to become synchronized [^crutchfield_exact_2016].

    Computed via :meth:`get_msp` and :meth:`amachine.am_msp.MSP.get_E_S_T`, or :meth:`amachine.am_fast.block_entropy_convergence`

    If the mixed state presentation route raises, the value comes from
    :meth:`block_convergence` instead.

    Returns:

        float: :math:`\\mathbf{S}`

    [^crutchfield_exact_2016]: Crutchfield et al., *Exact Complexity: The Spectral
        Decomposition of Intrinsic Computation*, 2016.
        <https://arxiv.org/abs/1309.3792>
    """

    m = self.get_complexity_measure_if_exists( "S" )

    if m is not None :
        return m

    try :
        msp = self.get_msp()
        E, S, T = msp.get_E_S_T()
        self.set_complexity_measure( "E", E )
        self.set_complexity_measure( "S", S )
        self.set_complexity_measure( "T_inf", T )

    except Exception as e :
        # Fall back instead of terminating the interpreter.
        print( f"{e} \nFalling back to iterative estimation.")

        C = self.block_convergence()
        S = C.S
        self.set_complexity_measure( "S", S )

    return S
The synchronization information:
$$\mathbf{S} \equiv \sum_{L=1}^{\infty} \mathcal{H}(L),$$
where \( \mathcal{H}(L) \) is the average state uncertainty having seen all length-L words 1, p.4.
Interpretations
- The total amount of state information that an observer must extract to become synchronized 1.
Computed via get_msp() and amachine.am_msp.MSP.get_E_S_T(), or amachine.am_fast.block_entropy_convergence()
Returns:
float: \( \mathbf{S} \)
-
Crutchfield et al., Exact Complexity: The Spectral Decomposition of Intrinsic Computation, 2016. https://arxiv.org/abs/1309.3792 ↩
def T_inf( self ) -> float :

    """The *transient information* [^crutchfield_exact_2016], p.4:

    .. math::

        \\mathbf{T} \\equiv \\sum_{L=1}^{\\infty} L \\left[ h_{\\mu}(L) - h_{\\mu} \\right]

    Computed via :meth:`get_msp` and :meth:`amachine.am_msp.MSP.get_E_S_T`, or :meth:`amachine.am_fast.block_entropy_convergence`

    .. note::

        **Interpretations**

        * The amount of information one must extract from observations so that the block entropy converges to its linear asymptote [^crutchfield_exact_2016].

    If the mixed state presentation route raises, the value comes from
    :meth:`block_convergence` instead.

    Returns:

        float: :math:`\\mathbf{T}`

    [^crutchfield_exact_2016]: Crutchfield et al., *Exact Complexity: The Spectral
        Decomposition of Intrinsic Computation*, 2016.
        <https://arxiv.org/abs/1309.3792>
    """

    m = self.get_complexity_measure_if_exists( "T_inf" )

    if m is not None :
        return m

    try :
        msp = self.get_msp()
        # bind the third value to the name that is returned below
        E, S, T_inf = msp.get_E_S_T()
        self.set_complexity_measure( "E", E )
        self.set_complexity_measure( "S", S )
        self.set_complexity_measure( "T_inf", T_inf )

    except Exception as e :
        print( f"{e} \nFalling back to iterative estimation.")
        C = self.block_convergence()
        T_inf = C.T
        self.set_complexity_measure( "T_inf", T_inf )

    return T_inf
The transient information1, p.4:
$$\mathbf{T} \equiv \sum_{L=1}^{\infty} L \left[ h_{\mu}(L) - h_{\mu} \right]$$
Computed via get_msp() and amachine.am_msp.MSP.get_E_S_T(), or amachine.am_fast.block_entropy_convergence()
Interpretations
- The amount of information one must extract from observations so that the block entropy converges to its linear asymptote1.
Returns:
float: \( \mathbf{T} \)
-
Crutchfield et al., Exact Complexity: The Spectral Decomposition of Intrinsic Computation, 2016. https://arxiv.org/abs/1309.3792 ↩
def chi( self ) -> float :

    """Computes the forward crypticity[^crutchfield_crypticity_2009][^Mahoney_crypticity_2021], p.2:

    .. math::

        \\chi = C_{\\mu} - \\mathbf{E}

    :math:`C_{\\mu}` is trivially computed from the stationary distribution in :meth:`C_mu` and :math:`\\mathbf{E}` in :meth:`E`.

    .. note::

        **Interpretations**

        * Difference between internal stored information and apparent information to an observer.
        * How much information is hidden in the system.

    A negative difference is clamped to zero; a warning is issued when it
    is below ``-1e-5``.

    Returns:

        float: :math:`\\chi`

    [^crutchfield_crypticity_2009]: Crutchfield et al., Time’s barbed arrow: Irreversibility, crypticity, and stored information, 2009.
        <https://arxiv.org/abs/0902.1209>

    [^Mahoney_crypticity_2021]: Mahoney et al., Information Accessibility and Cryptic Processes, 2021.
        <https://arxiv.org/abs/0905.4787>
    """

    m = self.get_complexity_measure_if_exists( "chi" )

    if m is not None :
        return m

    chi = self.C_mu() - self.E()

    if chi < 0 :

        # if chi is 0, accumulated floating point error can result in small negative values
        if chi < -1e-5:
            warnings.warn(f"Crypticity is negative ({chi:.6e}).")

        # clamp to the valid range
        chi = 0.0

    self.set_complexity_measure( "chi", chi )

    return chi
Computes the forward crypticity [1, 2], p.2:
$$\chi = C_{\mu} - \mathbf{E}$$
\( C_{\mu} \) is trivially computed from the stationary distribution in C_mu() and \( \mathbf{E} \) in E().
Interpretations
- Difference between internal stored information and apparent information to an observer.
- How much information is hidden in the system.
Returns:
float: \( \chi \)
-
Crutchfield et al., Time’s barbed arrow: Irreversibility, crypticity, and stored information, 2009. https://arxiv.org/abs/0902.1209 ↩
-
Mahoney et al., Information Accessibility and Cryptic Processes, 2021. https://arxiv.org/abs/0905.4787 ↩
def is_row_stochastic(self) :

    """
    Check that the outgoing transition probabilities of every state sum to 1
    (within the default tolerance of `np.allclose`).
    """

    row_totals = np.zeros( len( self.states ) )

    for tr in self.transitions :
        origin = tr.origin_state_idx
        row_totals[ origin ] = row_totals[ origin ] + tr.prob

    return np.allclose( row_totals, 1.0 )
Check that all states have outgoing transition probabilities that sum to 1.
def is_unifilar(self) :

    """
    Check that no state emits the same symbol on transitions to different states.
    """

    # table[state, symbol] holds the target seen so far, -1 meaning "none yet"
    table = np.full( ( len( self.states ), len( self.alphabet ) ), -1 )

    for tr in self.transitions :

        cell = ( tr.origin_state_idx, tr.symbol_idx )
        known_target = table[ cell ]

        if known_target == -1 :
            table[ cell ] = tr.target_state_idx
            continue

        if known_target != tr.target_state_idx :
            return False

    return True
Check that no state emits the same symbol on transitions to different states.
def is_strongly_connected(self) :

    """
    Check if every state is reachable from every other state. Relies on [nx.is_strongly_connected](https://networkx.org/documentation/latest/reference/algorithms/generated/networkx.algorithms.components.is_strongly_connected.html), applied to the graph returned by `as_digraph`.
    """

    graph = self.as_digraph()
    return nx.is_strongly_connected( graph )
Check if every state is reachable from every other state. Relies on nx.is_strongly_connected.
def is_aperiodic(self) :

    """
    Checks if the machine is aperiodic. Relies on [nx.is_aperiodic](https://networkx.org/documentation/latest/reference/algorithms/generated/networkx.algorithms.dag.is_aperiodic.html), "A strongly connected directed graph is aperiodic if there is no integer k > 1 that divides the length of every cycle in the graph."
    """

    graph = self.as_digraph()
    return nx.is_aperiodic( graph )
Checks if the machine is aperiodic. Relies on nx.is_aperiodic, "A strongly connected directed graph is aperiodic if there is no integer k > 1 that divides the length of every cycle in the graph."
def is_topological_epsilon_machine( self, verbose=True ) :

    """
    Checks if the HMM is a topological $\\epsilon$-machine [^1]: unifilar,
    strongly connected, and minimal as a DFA when only the topology is
    considered.

    [^1]: Johnson et al, Enumerating Finitary Processes, 2024.
    <https://arxiv.org/abs/1011.0036>
    """

    if self.is_unifilar() and self.is_strongly_connected() :
        return self._is_minimal_as_dfa( topological_only=True, verbose=verbose )

    if verbose :
        print( "Either non unifilar or not strongly connected" )

    return False
Checks if the HMM is a topological $\epsilon$-machine 1.
-
Johnson et al, Enumerating Finitary Processes, 2024. https://arxiv.org/abs/1011.0036 ↩
def is_epsilon_machine( self, verbose=True ) :

    """
    Checks if the HMM is an $\\epsilon$-machine: unifilar, strongly connected,
    and minimal as a DFA with probabilities taken into account
    (`topological_only=False`).
    """

    if self.is_unifilar() and self.is_strongly_connected() :
        return self._is_minimal_as_dfa( topological_only=False, verbose=verbose )

    if verbose :
        print( "Either non unifilar or not strongly connected" )

    return False
def minimize(self, retain_names: bool = True, verbose=False):

    """
    Minimizes the HMM in place, resulting in an :math:`\\epsilon-`machine if the HMM
    is unifilar and strongly connected. Converts the HMM to a DFA with
    `as_dfa(with_probs=True)`, minimizes it with `am_fast.minify_cpp`, and
    rebuilds the states and transitions from the resulting equivalence
    classes. The class containing the DFA's initial state becomes state 0;
    the remaining classes are ordered by their smallest original state index.

    Does nothing if `self.is_minimal` is already set.

    Args:
        retain_names (bool): If `True`, the merged states will be named by their union, e.g. `{s_0,s_1}`, and other states will retain their original names. Otherwise, they will be relabeled `{ '0', '1', ..., 'n-1' }`.
        verbose (bool): If `True`, print whether the number of states changed.

    Returns:

        None

    Raises:
        ValueError: If the HMM is not unifilar.
        RuntimeError: If strong connectivity or row stochasticity differs
            before and after minimization.
    """

    if self.is_minimal :
        return

    # NOTE(review): `start` is never read afterwards.
    start = time.perf_counter()

    if not self.is_unifilar():
        raise ValueError(
            "DFA minimization is not valid for non-unifilar HMMs"
        )

    # Remember the structural properties so they can be compared afterwards.
    was_strongly_connected = self.is_strongly_connected()

    was_row_stochastic = self.is_row_stochastic()
    n_states_before = len(self.states)

    dfa = self.as_dfa(with_probs=True)

    #min_dfa = self.as_dfa(with_probs=True).minify(retain_names=True)
    # retain_names=True here is independent of this method's `retain_names`;
    # the code below iterates each resulting state as a collection of
    # original state indices.
    min_dfa = am_fast.minify_cpp( dfa, retain_names=True )

    # Build lookup from original state index -> CausalState object
    orig_state = {i: s for i, s in enumerate(self.states)}
    eq_list = list(min_dfa.states)

    start_eq = min_dfa.initial_state

    # Separate the start state, then sort the rest by the
    # smallest original state index inside each equivalence class.
    other_eqs = [eq for eq in eq_list if eq != start_eq]
    other_eqs.sort(key=lambda eq: min(eq))

    # Recombine so start eq comes first, followed by the sorted remaining classes
    eq_list = [start_eq] + other_eqs
    # ----------------------------------------------------------

    # Recompute eq_to_idx with the new ordering
    eq_to_idx = {eq: i for i, eq in enumerate(eq_list)}

    # new_start is now guaranteed to be 0
    new_start = 0

    # Map each original state index -> its equivalence class
    # NOTE(review): if the minimizer drops unreachable states, a transition
    # targeting one of them would raise KeyError below - confirm.
    orig_to_eq = {s: eq for eq in min_dfa.states for s in eq}

    # Build lookup from original state index -> its transitions
    orig_trs = defaultdict(list)
    for t in self.transitions:
        orig_trs[t.origin_state_idx].append(t)

    # One representative per class supplies the class's outgoing transitions.
    new_trs = []
    for eq in min_dfa.states:
        rep = next(iter(eq))
        origin_idx = eq_to_idx[eq]
        for t in orig_trs[rep]:
            target_eq = orig_to_eq[t.target_state_idx]
            target_idx = eq_to_idx[target_eq]
            new_trs.append(Transition(
                origin_state_idx = origin_idx,
                target_state_idx = target_idx,
                prob = t.prob,
                symbol_idx = t.symbol_idx,
            ))

    members_list = [[orig_state[i] for i in sorted(eq)] for eq in eq_list]  # sorted for determinism

    # Compute new names
    if retain_names:
        new_names = [
            "{" + ",".join(str(m.name) for m in members) + "}" if len(members) > 1
            else members[0].name
            for members in members_list
        ]
    else:
        new_names = [str(j) for j in range(len(eq_list))]

    # Used to translate isomorph references to the new state names.
    old_name_to_new_name = {
        m.name: new_names[j]
        for j, members in enumerate(members_list)
        for m in members
    }

    # Build the new states, preserving classes and isomorphs regardless of naming
    new_states = []
    for j, (eq, members, name) in enumerate(zip(eq_list, members_list, new_names)):
        classes = set().union(*(m.classes for m in members))
        # a state is never listed as its own isomorph
        isomorphs = {
            old_name_to_new_name.get(iso, iso)
            for m in members
            for iso in m.isomorphs
            if old_name_to_new_name.get(iso, iso) != name
        }
        new_states.append(CausalState(
            name = name,
            classes = classes,
            isomorphs = isomorphs,
        ))

    self.set_states(new_states)
    self.set_transitions(new_trs)
    self.start_state = new_start

    if n_states_before == len(new_states) and verbose :
        print( f"{n_states_before} state HMM was already minimal.\n" )
    elif verbose :
        print( f"Minimized from {n_states_before} to {len(new_states)}\n" )

    # Sanity checks: minimization must not change these properties.
    if not ( was_strongly_connected == self.is_strongly_connected() ) :
        raise RuntimeError(
            f"Minimization broke strongly connected"
        )

    if not ( was_row_stochastic == self.is_row_stochastic() ) :
        raise RuntimeError(
            f"Minimization broke row stochasticity"
        )

    self.is_minimal = True
Minimizes the HMM, resulting in an \( \epsilon- \)machine if the HMM
is unifilar and strongly connected. Converts the HMM to a DFA with symbols
labeled jointly with symbols and probabilities, and uses Myhill-Nerode
equivalence for minimization. Relies on automata_lib and uses
automata.fa.dfa.DFA.minify with allow_partial=True, and all states
final.
Arguments:
- retain_names (bool): If True, the merged states will be named by their union, e.g. {s_0, s_1}, and other states will retain their original names. Otherwise, they will be relabeled { '0', '1', ..., 'n-1' }.
Returns:
None: the HMM is minimized in place.
def collapse_to_largest_strongly_connected_subgraph( self, rename_states=True ) :

    """
    Reduce the HMM in place to its largest strongly connected component.

    States outside the component and transitions that leave or enter it are
    dropped. For each remaining state whose outgoing probabilities no longer
    sum to 1, the missing amount is split equally (additively) among its
    remaining outgoing transitions. Nothing is changed when the graph from
    `as_digraph` is already strongly connected.

    Args:
        rename_states (bool): If `True`, the remaining states are replaced by
            fresh `CausalState` objects named `"0"`, `"1"`, ... after the collapse.
    """

    # get equivalent networkx graph
    G = self.as_digraph()

    # if already strongly connected, nothing to do
    if not nx.is_strongly_connected( G ) :

        # NOTE(review): `start` is never read afterwards.
        start = time.perf_counter()
        subgraph_nodes = list( nx.strongly_connected_components( G ) )

        # decompose into strongly connected components and sort by length
        # subgraph_nodes = list(nx.strongly_connected_components( G ))
        subgraph_nodes.sort(key=len)
        component_state_set = subgraph_nodes[-1]

        # Take the largest strongly connected component; its members are
        # compared against state indices below
        component_states = sorted( list( component_state_set ) )

        # make temporary copies of the old transitions and states
        old_transitions = [
            Transition(
                origin_state_idx=tr.origin_state_idx,
                target_state_idx=tr.target_state_idx,
                prob=tr.prob,
                symbol_idx=tr.symbol_idx
            )

            for tr in self.transitions
        ]

        old_states = [
            CausalState(
                name=s.name,
                classes=s.classes,
                isomorphs=s.isomorphs
            )
            for s in self.states
        ]

        self.set_states(
            states=[
                state
                for i, state in enumerate( old_states ) if i in component_state_set
            ]
        )

        # we will build new transition list based on those belonging to the component
        self.set_transitions( transitions= [] )

        # for tracking which new transitions leave each state (keyed by OLD state index)
        transitions_from_state = { state : set() for state in component_states }
        new_transitions = []

        for tr in old_transitions :

            origin_state_name = old_states[ tr.origin_state_idx ].name
            target_state_name = old_states[ tr.target_state_idx ].name

            # skip transitions that connect separate strongly connected components
            if not ( tr.origin_state_idx in component_state_set and tr.target_state_idx in component_state_set ) :
                continue

            # track transitions (by index in new transitions list) that leave this state
            transitions_from_state[ tr.origin_state_idx ].add( len( new_transitions ) )

            # translate old indices to indices in the reduced state list via the state name
            my_origin_state_idx = self.state_idx_map[ origin_state_name ]
            my_target_state_idx = self.state_idx_map[ target_state_name ]

            new_transitions.append(
                Transition(
                    origin_state_idx=my_origin_state_idx,
                    target_state_idx=my_target_state_idx,
                    prob=tr.prob,
                    symbol_idx=tr.symbol_idx
                ) )

        # NOTE(review): the indices stored in transitions_from_state assume
        # set_transitions keeps the list order - confirm.
        self.set_transitions( transitions=new_transitions )

        # if we removed an outgoing transition from a state, we need to distribute its probability
        # among the remaining outgoing transitions from the state
        for state in component_states :

            # get the set of transitions leaving this state
            state_trs = transitions_from_state[ state ]

            # sum the probabilities of the outgoing transitions from the state
            p_sum = np.sum( [ self.transitions[ i ].prob for i in state_trs ] )

            # how much probability is missing
            diff = 1.0 - p_sum

            # if significant difference
            if abs( diff ) > self.EPS :

                # calculate how much of the difference each transition gets
                # NOTE(review): raises ZeroDivisionError if a state has no
                # remaining outgoing transitions.
                adjustment = diff / len( state_trs )

                # update the transitions
                for i in state_trs :

                    # transition with original probability
                    tr = self.transitions[ i ]

                    # adjusted probability
                    self.transitions[ i ] = Transition(
                        origin_state_idx=tr.origin_state_idx,
                        target_state_idx=tr.target_state_idx,
                        prob=tr.prob + adjustment,
                        symbol_idx=tr.symbol_idx )

        # the fresh states carry only a name
        if rename_states :
            self.set_states( [
                CausalState( name=f"{i}" )
                for i, s in enumerate( self.states )
            ] )
def to_q_weighted( self, denominator_limit=1000 ) :

    """
    Approximates the existing transition probabilities with exact fractions, stores
    the fractional probabilities as Fraction in Transition.pq, and sets the floating
    point probability to `float(pq)`. Within each state, any rounding error in the
    sum is absorbed by the largest fraction; if the error exceeds that fraction,
    the function recurses with `denominator_limit=denominator_limit*10`.

    Does nothing if the machine is already q-weighted.

    Args:
        denominator_limit (int): The initial input to :meth:`Fraction.limit_denominator` in
            the conversion.

    Raises:
        ValueError: If the machine is not row stochastic.
    """

    if self.is_q_weighted :
        return

    if not self.is_row_stochastic() :
        raise ValueError( "Cannot convert to q-weighted because not row stochastic" )

    # positions (in self.transitions) of the transitions leaving each state
    outgoing = [ [] for _ in self.states ]
    for position, tr in enumerate( self.transitions ) :
        outgoing[ tr.origin_state_idx ].append( position )

    unity = Fraction(1,1)
    rebuilt = []

    for positions in outgoing :

        if not positions :
            continue

        fracs = [
            Fraction( self.transitions[ p ].prob ).limit_denominator( denominator_limit )
            for p in positions
        ]
        total = sum( fracs, Fraction(0,1) )

        if total != unity :

            largest_at = np.argmax( fracs )
            largest = fracs[ largest_at ]

            excess = total - unity

            # too coarse to repair: retry with higher resolution
            if excess > largest :
                return self.to_q_weighted( denominator_limit*10 )

            fracs[ largest_at ] -= excess

        for frac, p in zip( fracs, positions ) :
            source = self.transitions[ p ]
            rebuilt.append(
                Transition(
                    origin_state_idx=source.origin_state_idx,
                    target_state_idx=source.target_state_idx,
                    prob=float( frac ),
                    symbol_idx=source.symbol_idx,
                    pq=frac
                )
            )

    self.set_transitions( rebuilt )
    self.is_q_weighted = True
Approximates the existing transition probabilities with exact fractions, stores
the fractional probabilities as Fraction in Transition.pq, and sets the floating
point probability to float(pq). If denominator_limit is too small for a sane
conversion, the function recurses with denominator_limit=denominator_limit*10.
Arguments:
- denominator_limit (int): The initial input to
Fraction.limit_denominator() in the conversion.
def isomorphic_shift(
    self,
    input_symbol_indices: np.ndarray,
    input_state_indices: np.ndarray,
    shift : int = 1
) -> dict[str, np.ndarray]:

    """
    Generates a new sequence of symbols that are permuted with the symbols emitted by
    isomorphic states, if they exist.

    :math:`\\sigma_o = \\mathcal{S}\\left[\\texttt{input\\_state\\_indices}[i]\\right]`<br>
    :math:`\\sigma_t = \\mathcal{S}\\left[\\texttt{input\\_state\\_indices}[i+1]\\right]`

    :math:`\\mathcal{I}(\\sigma_o) = \\\\{\\sigma^0_o,\\, \\sigma^1_o,\\, \\dots,\\, \\sigma^{n-1}_o \\\\}`<br>
    :math:`\\mathcal{I}(\\sigma_t) = \\\\{\\sigma^0_t,\\, \\sigma^1_t,\\, \\dots,\\, \\sigma^{n-1}_t \\\\}`

    :math:`k = \\bigl(\\mathcal{I}(\\sigma_o).\\texttt{index}(\\sigma_o) + \\texttt{shift}\\bigr) \\bmod n`

    :math:`\\texttt{output\\_symbol\\_indices}[i] := T(\\sigma_o^k,\\, \\sigma_t^k).\\text{symbol\\_index}`<br>
    :math:`\\texttt{output\\_state\\_indices}[i] := \\mathcal{S}.\\texttt{index}(\\sigma_o^k)`<br>
    :math:`\\texttt{output\\_state\\_indices}[i+1] := \\mathcal{S}.\\texttt{index}(\\sigma_t^k)`

    Where :math:`\\mathcal{I}(\\sigma)` is the ordered set of states isomorphic to :math:`\\sigma` including :math:`\\sigma` itself.

    Args:
        input_symbol_indices (np.ndarray): The sequence of generated symbols. Only its
            dtype is used; the output symbols are cast to it.
        input_state_indices (np.ndarray): The sequence of states that generated symbols with the final state at the end.
        shift (int): How much to shift the symbols across the isomorphic states.

    Returns:
        dict[str, np.ndarray]: `"symbol_index"` holds the symbols of the transitions
        between consecutive shifted states (one fewer entry than there are states),
        `"state_index"` holds the shifted states. A shifted state pair with no
        transition between them yields the symbol index -1 before the dtype cast.

    Raises:
        ValueError: If no state of the machine has isomorphs.
    """

    if not any( state.isomorphs for state in self.states ):
        raise ValueError("HMM has no states with isomorphs")

    inputs = np.asarray(input_symbol_indices)
    states = np.asarray(input_state_indices)

    n_states = len(self.states)

    # symbol emitted on the transition origin -> target, -1 where there is none
    tr_sym_table = np.full((n_states, n_states), -1, dtype=np.int32)

    for tr in self.transitions:
        tr_sym_table[ tr.origin_state_idx, tr.target_state_idx ] = tr.symbol_idx

    # Build isomorph remapping: identity by default, overridden where isomorphs exist
    iso_table = np.arange(n_states, dtype=np.int32)

    for i, state in enumerate(self.states):

        if not state.isomorphs:
            continue

        # extend the isomorph list to include the identity
        isomorphs_with_identity = sorted( [ i ] + [ self.state_idx_map[iso] for iso in state.isomorphs ] )

        # find the identity index
        pos = isomorphs_with_identity.index( i )

        # cyclical shift
        iso_table[ i ] = isomorphs_with_identity[ ( pos + shift ) % len( isomorphs_with_identity ) ]

    # Remap the whole state sequence at once; this also covers a sequence that
    # holds only a single state (no transitions, hence no symbols).
    out_states = iso_table[ states ]

    inv_sym = tr_sym_table[ out_states[:-1], out_states[1:] ]

    return {
        "symbol_index": inv_sym.astype(inputs.dtype),
        "state_index": out_states.astype(states.dtype),
    }
Generates a new sequence of symbols that are permuted with the symbols emitted by isomorphic states, if they exist.
\( \sigma_o = \mathcal{S}\left[\texttt{input_state_indices}[i]\right] \)
\( \sigma_t = \mathcal{S}\left[\texttt{input_state_indices}[i+1]\right] \)
\( \mathcal{I}(\sigma_o) = \{\sigma^0_o,\, \sigma^1_o,\, \dots,\, \sigma^{n-1}_o \} \)
\( \mathcal{I}(\sigma_t) = \{\sigma^0_t,\, \sigma^1_t,\, \dots,\, \sigma^{n-1}_t \} \)
\( k = \bigl(\mathcal{I}(\sigma_o).\texttt{index}(\sigma_o) + \texttt{shift}\bigr) \bmod n \)
\( \texttt{output_symbol_indices}[i] := T(\sigma_o^k,\, \sigma_t^k).\text{symbol_index} \)
\( \texttt{output_state_indices}[i] := \mathcal{S}.\texttt{index}(\sigma_o^k) \)
\( \texttt{output_state_indices}[i+1] := \mathcal{S}.\texttt{index}(\sigma_t^k) \)
Where \( \mathcal{I}(\sigma) \) is the ordered set of states isomorphic to \( \sigma \) including \( \sigma \) itself.
Arguments:
- input_symbol_indices (np.ndarray): The sequence of generated symbols.
- input_state_indices (np.ndarray): The sequence of states that generated symbols with the final state at the end.
- shift (int): How much to shift the symbols across the isomorphic states.
def generate_data(
    self,
    file_prefix: str,
    n_gen: int,
    include_states: bool,
    isomorphic_shifts : set[int] | None = None,
    random_seed : int = 42 ) -> dict[str, any] :

    """
    Generates data from the machine with `am_fast.generate_data`, optionally adds
    isomorphically shifted copies of it, and saves everything with `am_fast.save_data`.

    Args:
        file_prefix (str): Passed to `am_fast.save_data` as `file_prefix`.
        n_gen (int): Passed to `am_fast.generate_data` as `n_gen`.
        include_states (bool): Passed to `am_fast.generate_data` as `include_states`.
            Must be true when `isomorphic_shifts` is given.
        isomorphic_shifts (set[int] | None): Shifts to apply with
            :meth:`isomorphic_shift`. The results are stored under
            `data["isomorphic_shifts"][shift]`. A shift that fails is reported with a
            warning and left out; the remaining shifts are still processed.
        random_seed (int): Passed to `am_fast.generate_data` and `am_fast.save_data`.

    Returns:
        dict[str, any]: The dict returned by `am_fast.generate_data`, extended with the
        `"isomorphic_shifts"` entry when shifts were requested.

    Raises:
        ValueError: If `isomorphic_shifts` is given while `include_states` is false.
    """

    shifts_requested = isomorphic_shifts is not None

    # Validate before generating anything so a bad call fails immediately
    if shifts_requested and not include_states :
        raise ValueError( "Isomorphic inversion requires include_states=True" )

    # per origin state: ( symbol index, probability, target state index )
    trs = [ [] for _ in self.states ]
    for tr in self.transitions :
        trs[ tr.origin_state_idx ].append( (
            tr.symbol_idx,
            float( tr.prob ),
            tr.target_state_idx ) )

    alphabet = sorted( self.alphabet )

    data = am_fast.generate_data(
        n_gen=n_gen,
        start_state=self.start_state,
        transitions=trs,
        alphabet=alphabet,
        include_states=include_states,
        random_seed=random_seed
    )

    if shifts_requested :

        data[ "isomorphic_shifts" ] = {}

        for shift in isomorphic_shifts :

            try :
                shifted = self.isomorphic_shift(
                    input_symbol_indices=data[ "symbol_index" ],
                    input_state_indices=data[ "state_index" ],
                    shift=shift
                )
            except Exception as e :
                # best effort: one failing shift must not discard the generated data
                warnings.warn(
                    f"Isomorphic shift {shift} skipped: {e}",
                    RuntimeWarning,
                    stacklevel=2 )
                continue

            data[ "isomorphic_shifts" ][ shift ] = {
                "symbol_index" : shifted[ "symbol_index" ],
                "state_index" : shifted[ "state_index" ]
            }

    am_fast.save_data(
        data=data,
        file_prefix=file_prefix,
        alphabet=alphabet,
        n_states=len( self.states ),
        start_state=self.start_state,
        random_seed=random_seed,
        machine_metadata=self.get_metadata() )

    return data
def draw_graph(
    self,
    engine : str = 'dot',
    output_dir : Path = Path('.'),
    show : bool = True
) -> None :

    """
    Draws the machine using [pygraphviz](https://pygraphviz.github.io/documentation/stable/) and saves it.

    The drawing itself is delegated to `am_vis.draw_graph` with the title
    `"am_graph"`. When the machine's graph is not strongly connected, its
    strongly connected components are handed over as `subgraphs`.

    Args:
        engine (str): Forwarded to `am_vis.draw_graph` as `engine`.
        output_dir (Path): Forwarded to `am_vis.draw_graph` as `output_dir`.
        show (bool): Forwarded to `am_vis.draw_graph` as `view`.

    Returns:
        None
    """

    graph = self.as_digraph()

    if nx.is_strongly_connected( graph ) :
        components = None
    else :
        components = list( nx.strongly_connected_components( graph ) )

    am_vis.draw_graph(
        self,
        engine=engine,
        subgraphs=components,
        view=show,
        title="am_graph",
        output_dir=output_dir )
def as_digraph( self ) -> nx.DiGraph :

    """
    Builds a [networkx.DiGraph](https://networkx.org/documentation/stable/reference/classes/digraph.html)
    from the machine: one node per state index and one edge per transition,
    carrying neither edge symbols nor weights.

    Returns:

        networkx.DiGraph : the resulting graph.
    """

    graph = nx.DiGraph()

    # nodes are the state indices, so states without transitions still appear
    graph.add_nodes_from( range( len( self.states ) ) )

    graph.add_edges_from(
        ( tr.origin_state_idx, tr.target_state_idx )
        for tr in self.transitions
    )

    return graph
Builds a networkx.DiGraph constructed from the machine's transitions with no edge symbols or weights.
Returns:
networkx.DiGraph : the resulting graph.
def as_dfa( self, with_probs : bool ) :

    """
    Builds an [automata.fa.dfa.DFA](https://caleb531.github.io/automata/api/fa/class-dfa/) constructed from the machine's transitions.

    Every state index is both a DFA state and a final state; the DFA is built
    with `allow_partial=True` and starts in `self.start_state`.

    Args:
        with_probs (bool): If true, each DFA transition is labeled with the
            symbol index of the machine's transition combined with its
            probability rounded to 8 decimals, formatted as `"(symbol,prob)"`;
            otherwise only the symbol index is used.

    Returns:

        automata.fa.dfa.DFA : the resulting DFA.
    """

    precision = 8

    # pick the labeling scheme once instead of branching per collection
    if with_probs :
        def label_of( t ) :
            return f"({t.symbol_idx},{round(t.prob, precision)})"
    else :
        def label_of( t ) :
            return t.symbol_idx

    dfa_states = set( range( len( self.states ) ) )
    dfa_symbols = set()
    dfa_transitions = defaultdict(dict)

    for t in self.transitions :
        label = label_of( t )
        dfa_symbols.add( label )
        dfa_transitions[ t.origin_state_idx ][ label ] = t.target_state_idx

    # Construct the DFA
    return DFA(
        states=dfa_states,
        input_symbols=dfa_symbols,
        transitions=dfa_transitions,
        initial_state=self.start_state,
        allow_partial=True,
        final_states=set( dfa_states )
    )
Builds an automata.fa.dfa.DFA constructed from the machine's transitions.
Arguments:
- with_probs (bool): If true, the DFA transitions are labeled with the symbol of the machine's transition concatenated with its probability; otherwise, only with the symbol.
Returns:
automata.fa.dfa.DFA : the resulting DFA.