GitLab Repo

amachine.am_hmm

   1from __future__ import annotations
   2
   3from abc import abstractmethod
   4from typing import override
   5import gc
   6import copy
   7from collections import deque
   8from collections import defaultdict
   9import json
  10from fractions import Fraction
  11from pathlib import Path
  12import warnings
  13import time
  14from dataclasses import dataclass, asdict, field
  15
  16import networkx as nx
  17
  18from automata.fa.dfa import DFA
  19
  20import sympy
  21import numpy as np
  22
  23from .am_machine import Machine
  24from .am_symbol import Symbol
  25
  26from   .am_msp import MSP, compute_msp, compute_msp_exact
  27from   .am_solve import solve_for_pi, solve_for_pi_fractional
  28from . import am_vis
  29from . import am_fast
  30
  31from .am_causal_state import CausalState
  32from .am_transition   import Transition
  33
  34from .am_fast.distance import jensenshannondivergence_gpu as af_jensenshannondivergence
  35from .am_fast.distance import jensenshannondivergence_cpu as af_jensenshannondivergence_cpu
  36
  37class HMM(Machine) :
  38
  39	"""Hidden Markov model implementing epsilon machines, mixed state presentations,
  40	complexity measures, and data generation.
  41
  42	Args:
  43		states (list[CausalState] | None): A list of causal states.
  44		transitions (list[Transition] | None): A list of transitions between states.
  45		start_state (int): Index of the start state.
  46		alphabet (list[str]): List of symbols making up the alphabet.
  47		name (str): Name of the model.
  48		description (str): Description of the model.
  49	"""
  50
  51	def __init__( 
  52		self,
  53		states : list[CausalState] | None = None,
  54		transitions : list[Transition] | None = None,
  55		start_state : int = 0,
  56		alphabet : list[str] | None = None,
  57		name : str = "",
  58		description : str = "" ) : 
  59
  60		self.alphabet    : list[str] | None = alphabet or []
  61		self.states      : list[CausalState] | None= states or []
  62		self.transitions : list[Transition] | None= transitions or []
  63
  64		self.set_alphabet( self.alphabet )
  65		self.set_states( self.states )
  66		self.set_transitions( self.transitions )
  67
  68		self.start_state : int = start_state
  69
  70		self.name : str = name
  71		self.description : str = description
  72
  73		# To be depreciated
  74		self.isoclass : str = None
  75
  76		# --- derived --------
  77
  78		self.complexity : dict[str, any] = {}
  79
  80		self.symbol_idx_map : dict[str,int] = {}
  81		self.state_idx_map  : dict[str,int] = {}
  82
  83		self.pi_fractional = None
  84		self.pi : np.ndarray | None = None
  85		self.T  : np.ndarray | None = None
  86		self.msp : MSP | None = None
  87		self.reverse_am : HMM | None = None
  88		self.is_q_weighted : bool = False
  89		self.is_minimal : bool = False
  90
  91		# --- const ----------
  92
  93		self.EPS : float = 1e-12
  94
  95	#-------------------------------------------------------#
  96	#                Overrides / Getters                    #
  97	#-------------------------------------------------------#
  98
  99	@override
 100	def get_states(self) -> list[CausalState] : 
 101		return self.states
 102
 103	@override
 104	def get_transitions(self) -> list[Transition] : 
 105		return self.transitions
 106
	@override
	def get_alphabet(self) -> list[Symbol]:
		"""Return the alphabet with each entry wrapped in a new ``Symbol``."""
		wrapped = []
		for letter in self.alphabet :
			wrapped.append( Symbol(letter) )
		return wrapped
 110
 111	#-------------------------------------------------------#
 112	#                     Serialization                     #
 113	#-------------------------------------------------------#
 114
 115
 116	def to_dict(self) -> dict[str,any]:
 117
 118		"""Create a dict representing the HMM configuration.
 119
 120		Returns:
 121
 122			dict[str,any]: Dictionary containing, name, description, states, transitions, alphabet, and isoclass.
 123		"""
 124
 125		return {
 126			"name"            : self.name,
 127			"description"     : self.description,
 128			"states"          : [ asdict(state)      for state      in self.states      ],
 129			"transitions"     : [ asdict(transition) for transition in self.transitions ],
 130			"alphabet"        : self.alphabet,
 131			"isoclass"        : self.isoclass
 132		}
 133
 134	def from_dict( self, config : dict[str,any] ) :
 135
 136		"""Configure the HMM configuration from a dictionary.
 137
 138		Args:
 139			config (dict[str,any]): The HMM configuration.
 140		"""
 141
 142		self.name            = config.name 
 143		self.description     = config.description 
 144		
 145		self.set_states( states=[ 
 146			CausalState( 
 147				name=state[ "name" ],
 148				classes=set( state[ "classes" ] )
 149			)
 150			for state in config.states
 151		] )
 152
 153		self.set_transitions( transitions=[ 
 154			Transition( 
 155				origin_state_idx=tr[ "origin_state_idx" ],
 156				target_state_idx=tr[ "target_state_idx" ],
 157				prob=tr[ "prob" ],
 158				symbol_idx=tr[ "symbol_idx" ]
 159			)
 160			for tr in config.transitions
 161		] )
 162
 163		self.set_alphabet( alphabet=config.alphabet ) 
 164
	def save_config(
		self,
		path : Path,
		with_complexity : bool = False,
		with_block_convergence :  bool = False,
		with_structural_properties : bool = False,
		with_causal_properties : bool = False ) :

		"""Write the configuration to ``<path>/am_config.json``.

		Args:
			path (Path): Directory that receives ``am_config.json``.
			with_complexity (bool): Also store the result of :meth:`get_complexities`.
			with_block_convergence (bool): Forwarded to :meth:`get_complexities`;
				only has an effect together with ``with_complexity``.
			with_structural_properties (bool): Currently not consulted; the
				structural properties are always written.
			with_causal_properties (bool): Currently not consulted.
		"""

		config = self.to_dict()

		if with_complexity :
			config[ "complexity" ] = self.get_complexities(
				with_block_convergence=with_block_convergence
			)

		# Evaluated one by one, in the order the keys appear in the file.
		structural = {}
		structural[ "unifilar" ]              = self.is_unifilar()
		structural[ "row_stochastic" ]        = self.is_row_stochastic()
		structural[ "strongly_connected" ]    = self.is_strongly_connected()
		structural[ "aperiodic" ]             = self.is_aperiodic()
		structural[ "minimal" ]               = self._is_minimal_as_dfa( topological_only=False )
		structural[ "topologically_minimal" ] = self._is_minimal_as_dfa( topological_only=True )
		structural[ "is_epsilon_machine" ]    = self.is_epsilon_machine()

		config[ "structural_properties" ] = structural

		target = path / "am_config.json"

		# default=list lets json serialise values it does not handle natively
		# but that can be turned into a list (e.g. sets).
		with open( target, "w", encoding="utf-8" ) as f :
			json.dump( config, f, ensure_ascii=False, indent=2, default=list )
 195
	def from_file( self, path : Path ) :

		"""Load ``<path>/am_config.json`` and apply it with :meth:`from_dict`.

		Args:
			path (Path): Directory containing ``am_config.json`` (the layout
				written by :meth:`save_config`).
		"""

		# Use the *argument* (coerced so plain strings work too), not the Path class.
		with open( Path( path ) / "am_config.json", "r", encoding="utf-8" ) as f :
			config = json.load( f )

		self.from_dict( config )
 200
 201	#-------------------------------------------------------#
 202	#             Setters and State Management              #
 203	#-------------------------------------------------------#
 204	
 205	def _invalidate(self) :
 206
 207		"""
 208		Reset all derived properties so they will be recomputed when requested later.
 209		"""
 210
 211		self.complexity = {}
 212		self.T = None
 213		self.T_x = None
 214		self.pi = None
 215		self.pi_fractional = None
 216		self.msp = None 
 217		self.reverse_am = None
 218		self.is_q_weighted = False
 219		self.is_minimal = False
 220
 221	def set_states( self, states : list[CausalState] ) :
 222		self._invalidate()
 223		self.states = states.copy()
 224		self.state_idx_map = {}
 225		for idx, state in enumerate( self.states ) :
 226			self.state_idx_map[ state.name ] = idx
 227
 228	def set_alphabet( self, alphabet : list[str] ) :
 229
 230		self._invalidate()
 231
 232		old_alphabet = self.alphabet.copy()
 233
 234		aSet = set()
 235		aSet.update( alphabet )
 236
 237		self.alphabet = sorted(list(aSet))
 238
 239		self.symbol_idx_map = {}
 240		for idx, symbol in enumerate( self.alphabet ) :
 241			self.symbol_idx_map[ symbol ] = idx
 242
 243		for i, tr in enumerate( self.transitions ) :
 244			symbol = old_alphabet[ tr.symbol_idx ]
 245			self.transitions[ i ] = Transition(
 246				origin_state_idx=tr.origin_state_idx,
 247				target_state_idx=tr.target_state_idx,
 248				prob=tr.prob,
 249				symbol_idx=self.symbol_idx_map[ symbol ]
 250			)
 251
 252	def set_transitions( self, transitions : list[Transition] ) :
 253		self._invalidate()
 254		self.transitions = transitions.copy()
 255
 256	def extend_states( self, states : list[CausalState] ) :
 257		self.set_states( self.states + states )
 258
 259	def extend_alphabet( self, alphabet : list[str] ) :
 260		self.set_alphabet( self.alphabet + alphabet  )
 261
 262	def extend_transitions( self, transitions : list[Transition] ) :
 263		self.set_transitions( self.transitions + transitions  )
 264
 265	def get_complexity_measure_if_exists(self, measure ) :
 266		m = self.complexity.get( measure, None )
 267		return m
 268
 269	def set_complexity_measure(self, measure, value ) :
 270		self.complexity[ measure ] = value
 271
 272	#-------------------------------------------------------#
 273	#                   Get and Compute                     #
 274	#-------------------------------------------------------#
 275
 276	def get_complexities( 
 277		self, 
 278		with_block_convergence=False ) :
 279
 280		directly_calculable = [
 281			self.C_mu,
 282			self.h_mu,
 283			self.H_1,
 284			self.rho_mu
 285		]
 286
 287		requires_block_convergence = [
 288			self.E, 
 289			self.T_inf,
 290			self.S,
 291			self.chi
 292		]
 293			
 294		complexities = { m.__name__ : m() for m in directly_calculable }
 295
 296		if with_block_convergence :
 297
 298			complexities |= { m.__name__ : m() for m in requires_block_convergence }
 299
 300			for key in [ 'H_L', 'T_L', 'h_mu_L', 'H_sync' ] :
 301				if key in self.complexity :
 302					complexities[ key ] = self.complexity[ key ]
 303
 304		return complexities
 305
 306	#-------------------------------------------------------#
 307
 308	def get_metadata(self) :
 309		return {
 310			"name" : self.name,
 311			'complexity' : self.complexity,
 312			"description" : self.description
 313		}
 314
 315	def get_transition_matrix(self) :
 316
 317		if self.T  is not None :
 318			return self.T
 319
 320		n_states = len( self.states )
 321		T = np.zeros((n_states, n_states))
 322
 323		for tr in self.transitions :    
 324			T[ tr.origin_state_idx, tr.target_state_idx  ] = tr.prob
 325
 326		self.T = T
 327
 328		return self.T
 329
 330	#-------------------------------------------------------#
 331
 332	def get_T_X(self) :
 333
 334		if self.T_x  is not None :
 335			return self.T_x
 336
 337		n_states  = len( self.states )
 338		n_symbols = len( self.alphabet )
 339
 340		T_x = [ np.zeros((n_states, n_states)) for _ in range( n_symbols ) ]
 341
 342		for tr in self.transitions :
 343			T_x[ tr.symbol_idx ][tr.origin_state_idx, tr.target_state_idx] = tr.prob
 344
 345		self.T_x = T_x
 346		return self.T_x
 347
 348	#-------------------------------------------------------#
 349
 350	def get_msp_qw(
 351		self,
 352		exact_state_cap: int = 1000,
 353		verbose: bool = True,
 354	):
 355		if self.msp is not None:
 356			return self.msp
 357
 358		try : 
 359
 360			print( "\nTrying to Compute Mixed State Presentation using Exact Fractions\n" )
 361
 362			self.msp = compute_msp_exact(
 363				T_x=self.get_Tx_fractional(),
 364				pi=self.get_fractional_stationary_distribution(),
 365				n_states=len(self.states),
 366				alphabet=self.alphabet,
 367				exact_state_cap=1000,
 368				bool = True
 369			)
 370
 371			return self.msp 
 372
 373		except RuntimeError as e :
 374			warnings.warn( f"Exact msp failed: {e} Falling back to msp approximation." )
 375
 376		return self.get_msp()
 377
 378	def get_msp(
 379		self,
 380		exact_state_cap: int = 175_000,
 381		jsd_eps:         float = 1e-7,
 382		k_ann:           int   = 50,
 383		verbose                = True,
 384	) :
 385
 386		if self.msp is not None:
 387			return self.msp
 388	 
 389		T_x = self.get_T_X()
 390		pi  = self.get_stationary_distribution()
 391
 392		T_stacked      = np.stack(T_x)
 393		n_symbols      = len(self.alphabet)
 394		n_input_states = T_stacked.shape[1]
 395	
 396		print( "\nComputing Mixed State Presentation..." )
 397
 398		self.msp = compute_msp( 
 399			T_x=T_x,
 400			pi=pi,
 401			n_states=len(self.states),
 402			alphabet=self.alphabet,
 403			exact_state_cap=exact_state_cap,
 404			verbose=verbose
 405		)
 406
 407		return self.msp
 408
 409	def get_reverse_am(self) :
 410
 411		if self.reverse_am is not None:
 412			return self.reverse_am
 413
 414		pi = self.get_stationary_distribution()
 415		self.reverse_am = copy.deepcopy(self)
 416		
 417		new_transitions = []
 418		for tr in self.transitions:
 419			i = tr.target_state_idx
 420			j = tr.origin_state_idx
 421			
 422			p_reversed = (pi[j] * tr.prob) / pi[i]
 423			
 424			new_transitions.append(
 425				Transition(
 426					origin_state_idx=i,
 427					target_state_idx=j,
 428					prob=p_reversed,
 429					symbol_idx=tr.symbol_idx
 430				)
 431			)
 432
 433		self.reverse_am.set_transitions(new_transitions)
 434
 435		if self.reverse_am.is_epsilon_machine():
 436			return self.reverse_am
 437
 438		rmsp = self.reverse_am.get_msp_qw( exact_state_cap=len(self.states)*4 )
 439
 440		self.reverse_am.set_states( rmsp.states )
 441		self.reverse_am.set_transitions( rmsp.transitions )
 442		self.reverse_am.msp = rmsp
 443		self.reverse_am.start_state = 0
 444
 445		self.reverse_am.collapse_to_largest_strongly_connected_subgraph()
 446		self.reverse_am.minimize()
 447
 448		return self.reverse_am
 449
 450	#-------------------------------------------------------#
 451
	def get_Tx_fractional(self) -> list[ list[ list[ Fraction ] ] ] :

		"""Return the symbol-labeled transition matrices as nested lists of ``tr.pq`` values.

		Calls :meth:`to_q_weighted` first. ``result[x][i][j]`` is the ``pq``
		attribute of the transition from state ``i`` to state ``j`` emitting
		symbol index ``x``; positions without a transition hold the integer ``0``.
		"""

		self.to_q_weighted()

		state_count  = len( self.states )
		symbol_count = len( self.alphabet )

		matrices = [
			[ [ 0 for _ in range( state_count ) ] for _ in range( state_count ) ]
			for _ in range( symbol_count )
		]

		for tr in self.transitions :
			row = matrices[ tr.symbol_idx ][ tr.origin_state_idx ]
			row[ tr.target_state_idx ] = tr.pq

		return matrices
 470
	def get_T_sympy( self ) :

		"""Return the state-to-state transition matrix as a sympy matrix of ``tr.pq`` values.

		Calls :meth:`to_q_weighted` first. Entry ``T[i, j]`` is the sum of
		``pq`` over every transition from state ``i`` to state ``j`` (two
		states may be linked by several transitions emitting different symbols).
		"""

		self.to_q_weighted()

		n = len( self.states )
		T = sympy.zeros( n, n )

		for tr in self.transitions :
			# Accumulate: plain assignment would keep only the last of several
			# parallel transitions between the same pair of states.
			T[ tr.origin_state_idx, tr.target_state_idx ] = (
				T[ tr.origin_state_idx, tr.target_state_idx ] + tr.pq
			)

		return T
 482
 483	def get_fractional_stationary_distribution(self) :
 484
 485		T = self.get_T_sympy()
 486
 487		if self.pi_fractional is not None :
 488			return self.pi_fractional
 489
 490		G = self.as_digraph()
 491
 492		if not nx.is_strongly_connected(G):
 493			raise ValueError( "Single stationary distribution requires strongly connected HMM." )
 494
 495		self.pi_fractional = solve_for_pi_fractional( T )
 496
 497		return self.pi_fractional
 498
 499	def get_stationary_distribution(self):
 500
 501		if self.pi is not None :
 502			return self.pi
 503
 504		G = self.as_digraph()
 505		
 506		if not nx.is_strongly_connected(G):
 507			raise ValueError( "Single stationary distribution requires strongly connected HMM." )
 508
 509		T = self.get_transition_matrix()
 510		return solve_for_pi( T )		
 511
 512	#-------------------------------------------------------#
 513	#                 Complexity Measures                   #
 514	#-------------------------------------------------------#
 515	
 516	def C_mu( self ) :
 517
 518		"""The *statistical complexity* (aka *forecasting complexity*) :
 519
 520		.. math::
 521
 522			C_{\\mu} = - \\sum_{\\sigma \\in \\mathcal{S}} \\Pr(\\sigma) \\log_2 \\Pr(\\sigma),
 523
 524		where :math:`\\mathcal{S}` are the machine's states [^crutchfield_exact_2016], p.2.
 525
 526		.. note::
 527
 528			**Interpretations**
 529
 530			* The amount of historical information a process stores.
 531			* The amount of structure in a process.
 532
 533		Returns:
 534
 535			float: :math:`C_{\\mu}`.
 536
 537		[^crutchfield_exact_2016]: Crutchfield et al., *Exact Complexity: The Spectral
 538			Decomposition of Intrinsic Computation*, 2016.
 539			<https://arxiv.org/abs/1309.3792>
 540		"""
 541
 542		m = self.get_complexity_measure_if_exists( "C_mu" )
 543
 544		if m is not None :
 545			return m
 546
 547		pi = self.get_stationary_distribution()
 548
 549		h = 0
 550		for i, pr in enumerate( pi ) :
 551			
 552			if pr < self.EPS :
 553				continue
 554
 555			h += -pr * np.log2( pr )
 556
 557		self.set_complexity_measure( "C_mu", h )
 558
 559		return h
 560
 561	#-------------------------------------------------------#
 562
 563	def h_mu( self ) :
 564
 565		"""The *entropy rate* :
 566
 567		.. math::
 568
 569			h_{\\mu}(\\boldsymbol{\\mathcal{S}}) = - \\sum_{\\sigma \\in \\mathcal{S}} \\Pr(\\sigma) \\sum_{x \\in \\mathcal{A}} \\Pr(x|\\sigma) \\log_2 \\Pr(x|\\sigma),
 570
 571		where :math:`\\mathcal{A}` is the alphabet and :math:`\\mathcal{S}` are the machine's states [^crutchfield_exact_2016], p.2.
 572
 573		.. note::
 574
 575			**Interpretations**
 576
 577			* The lower bound on achievable loss in bits. 
 578			* The irreducable randomness in the process.
 579			* The intrinsic Randomness in the process.
 580
 581		Returns:
 582			
 583			float: :math:`h_{\\mu}`.
 584
 585		[^crutchfield_exact_2016]: Crutchfield et al., *Exact Complexity: The Spectral
 586			Decomposition of Intrinsic Computation*, 2016.
 587			<https://arxiv.org/abs/1309.3792>
 588		"""
 589
 590		m = self.get_complexity_measure_if_exists( "h_mu" )
 591
 592		if m is not None :
 593			return m
 594
 595		T  = self.get_transition_matrix()
 596		pi = self.get_stationary_distribution()
 597
 598		n_states = pi.size
 599
 600		h = 0
 601		for i, pr in enumerate( pi ) :
 602
 603			if pr < self.EPS :
 604				continue
 605
 606			row_entropy = 0
 607			for j in range( len( pi ) ) :
 608
 609				if T[ i, j ]  < self.EPS :
 610					continue
 611
 612				row_entropy -= T[ i, j ] * np.log2( T[ i, j ] )
 613
 614			h += pr * row_entropy
 615
 616		self.set_complexity_measure( "h_mu", h )
 617
 618		return h
 619
 620	#-------------------------------------------------------#
 621
 622	def H_1(self) -> float :
 623
 624		"""The *single symbol uncertainty*:
 625
 626		.. math::
 627
 628			H(1)=-\\sum_{x\\in\\mathcal{A}} \\Pr(x) \\log_2{\\Pr(x)},
 629
 630		where :math:`\\mathcal{A}` is the alphabet [^James_2018], p.2.
 631
 632		.. note::
 633
 634			**Interpretations**
 635
 636			* How uncertain you are on average about a single measurement with no context.
 637
 638		Returns:
 639
 640			float: :math:`H(1)`.
 641
 642		[^James_2018]: James et al., Anatomy of a Bit: Information in a Time Series Observation, 2018.
 643			<https://arxiv.org/abs/1105.2988>
 644		"""
 645
 646		m = self.get_complexity_measure_if_exists("H_1")
 647		if m is not None:
 648			return m
 649
 650		pi  = self.get_stationary_distribution()
 651		T_X = self.get_T_X()  # dict: symbol -> matrix
 652
 653		h = 0.0
 654		for T_x in T_X:
 655			# Pr(x) = sum_i pi[i] * sum_j T^(x)[i,j]
 656			p_sym = 0.0
 657			for i, pr in enumerate(pi):
 658				if pr < self.EPS:
 659					continue
 660				p_sym += pr * T_x[i, :].sum()
 661
 662			if p_sym < self.EPS:
 663				continue
 664			h -= p_sym * np.log2(p_sym)
 665
 666		self.set_complexity_measure("H_1", h)
 667		return h
 668
 669	#-------------------------------------------------------#
 670
 671	def rho_mu(self) -> float :
 672		
 673		"""The *anticipated information* [^James_2018], p.3.:
 674
 675		.. math::
 676
 677			\\rho_{\\mu}= H(1) - h_{\\mu}
 678
 679		Returns:
 680			
 681			float: :math:`\\rho_{\\mu}`
 682
 683		[^James_2018]: James et al., Anatomy of a Bit: Information in a Time Series Observation, 2018.
 684			<https://arxiv.org/abs/1105.2988>
 685		"""
 686
 687		m = self.get_complexity_measure_if_exists("rho_mu")
 688		
 689		if m is not None:
 690			return m
 691
 692		rho = self.H_1() - self.h_mu()
 693		
 694		self.set_complexity_measure("rho_mu", rho)
 695		
 696		return rho
 697
 698	#-------------------------------------------------------#
 699
	def block_convergence( self )  :

		"""
		Run ``am_fast.block_entropy_convergence`` and cache its results.

		Stores ``E``, ``S``, ``T_inf``, ``T_L``, ``H_L``, ``h_mu_L`` and
		``H_sync`` in the complexity cache (the per-length series as lists)
		and returns the result object produced by ``am_fast``.
		"""

		state_count = len( self.states )

		# Outgoing (symbol, probability, target) triples, grouped by origin state.
		outgoing = [ [] for _ in range( state_count ) ]
		for tr in self.transitions :
			triple = ( tr.symbol_idx, float( tr.prob ), tr.target_state_idx )
			outgoing[ tr.origin_state_idx ].append( triple )

		pi = self.get_stationary_distribution()

		# A single initial branch of weight 1 carrying the stationary distribution.
		stationary = [ float( pi[ idx ] ) for idx in range( state_count ) ]
		initial_branches = [ ( 1.0, list( stationary ) ) ]

		print( "\nComputing Block Entropy\n" )

		C = am_fast.block_entropy_convergence(
			h_mu            = self.h_mu(),
			n_states        = state_count,
			n_symbols       = len( self.alphabet ),
			convergence_tol = 1e-6,
			precision       = 10,
			eps             = 1e-25,
			branches        = initial_branches,
			trans           = outgoing,
			max_branches    = 30_000_000
		)

		print( "Done\n" )

		self.set_complexity_measure( "E",      C.E )
		self.set_complexity_measure( "S",      C.S )
		self.set_complexity_measure( "T_inf",  C.T )
		self.set_complexity_measure( "T_L",    C.T_L.tolist() )
		self.set_complexity_measure( "H_L",    C.H_L.tolist() )
		self.set_complexity_measure( "h_mu_L", C.h_mu_L.tolist() )
		self.set_complexity_measure( "H_sync", C.H_sync.tolist() )

		return C
 743
 744	#-------------------------------------------------------#
 745
 746	def E( self ) -> float :
 747
 748		"""The *excess entropy* [^crutchfield_exact_2016], p.4:
 749
 750		.. math::
 751
 752			\\mathbf{E} \\equiv \\sum_{L=1}^{\\infty} I[X_{-\\infty:0}; X_{0:\\infty}]
 753		
 754		Computed via :meth:`get_msp` and :meth:`amachine.am_msp.MSP.get_E_S_T`, or :meth:`amachine.am_fast.block_entropy_convergence`
 755
 756		.. note::
 757
 758			**Interpretations**
 759
 760			* The information from the past that reduces uncertainty in the future [^crutchfield_exact_2016].
 761			* How much information an observer must extract to synchronize to the process.
 762			* Measures how long the process appears more complex than it asymptotically is.
 763			* Vanishes for immediately synchronizable processes.
 764
 765		Returns:
 766		
 767			float: :math:`\\mathbf{E}`
 768
 769		[^crutchfield_exact_2016]: Crutchfield et al., *Exact Complexity: The Spectral
 770			Decomposition of Intrinsic Computation*, 2016.
 771			<https://arxiv.org/abs/1309.3792>
 772		"""
 773
 774		m = self.get_complexity_measure_if_exists( "E" )
 775
 776		if m is not None :
 777			return m
 778
 779		try : 
 780			msp = self.get_msp()
 781			E, S, T = msp.get_E_S_T()
 782			self.set_complexity_measure( "E", E )
 783			self.set_complexity_measure( "S", S )
 784			self.set_complexity_measure( "T_inf", T )
 785			
 786		except Exception as e :
 787
 788			print( f"MSP failed {e}" )
 789
 790			C = self.block_convergence()	
 791			E = C.E
 792			self.set_complexity_measure( "E", E )
 793
 794		return E
 795
 796	#-------------------------------------------------------#
 797
 798	def S( self ) -> float :
 799
 800		"""The *synchronization* information:
 801
 802		.. math::
 803
 804			\\mathbf{S} \\equiv \\sum_{L=1}^{\\infty} \\mathcal{H}(L),
 805
 806		where :math:`\\mathcal{H}(L)` is the average state uncertainty having seen all length-L words [^crutchfield_exact_2016], p.4.
 807
 808		.. note::
 809
 810			**Interpretations**
 811
 812			* The total amount of state information that an observer must extract to become synchronized [^crutchfield_exact_2016].
 813
 814		Computed via :meth:`get_msp` and :meth:`amachine.am_msp.MSP.get_E_S_T`, or :meth:`amachine.am_fast.block_entropy_convergence`
 815
 816		Returns:
 817		
 818			float: :math:`\\mathbf{S}`
 819
 820		[^crutchfield_exact_2016]: Crutchfield et al., *Exact Complexity: The Spectral
 821			Decomposition of Intrinsic Computation*, 2016.
 822			<https://arxiv.org/abs/1309.3792>
 823		"""
 824
 825		m = self.get_complexity_measure_if_exists( "S" )
 826
 827		if m is not None :
 828			return m
 829
 830		try : 
 831			msp = self.get_msp()
 832			E, S, T = msp.get_E_S_T()
 833			self.set_complexity_measure( "E", E )
 834			self.set_complexity_measure( "S", S )
 835			self.set_complexity_measure( "T_inf", T )
 836
 837		except Exception as e :
 838			print( f"{e} \nFalling back to iterative estimation.")
 839			exit()
 840
 841			C = self.block_convergence()	
 842			S = C.S
 843			self.set_complexity_measure( "S", S )
 844
 845		return S
 846
 847	#-------------------------------------------------------#
 848
 849	def T_inf( self ) -> float :
 850
 851		"""The *transient information*[^crutchfield_exact_2016], p.4:
 852
 853		.. math::
 854
 855			\\mathbf{T} \\equiv \\sum_{L=1}^{\\infty} L \\left[ h_{\\mu}(L) - h_{\\mu} \\right]
 856
 857		Computed via :meth:`get_msp` and :meth:`amachine.am_msp.MSP.get_E_S_T`, or :meth:`amachine.am_fast.block_entropy_convergence`
 858
 859		.. note::
 860
 861			**Interpretations**
 862
 863			* The amount of information one must extract from observations so that the block entropy converges to its linear asymptote[^crutchfield_exact_2016].
 864
 865		Returns:
 866		
 867			float: :math:`\\mathbf{T}`
 868
 869		[^crutchfield_exact_2016]: Crutchfield et al., *Exact Complexity: The Spectral
 870			Decomposition of Intrinsic Computation*, 2016.
 871			<https://arxiv.org/abs/1309.3792>
 872		"""
 873
 874		m = self.get_complexity_measure_if_exists( "T_inf" )
 875
 876		if m is not None :
 877			return m
 878
 879		try : 
 880			msp = self.get_msp()
 881			E, S, T = msp.get_E_S_T()
 882			self.set_complexity_measure( "E", E )
 883			self.set_complexity_measure( "S", S )
 884			self.set_complexity_measure( "T_inf", T )
 885
 886		except Exception as e :
 887			print( f"{e} \nFalling back to iterative estimation.")
 888			C = self.block_convergence()	
 889			T_inf = C.T
 890			self.set_complexity_measure( "T_inf", T_inf )
 891
 892		return T_inf
 893
 894	#-------------------------------------------------------#
 895
 896	def chi( self ) -> float :
 897
 898		"""Computes the foward crypticity[^crutchfield_crypticity_2009][^Mahoney_crypticity_2021], p.2:
 899
 900		.. math::
 901
 902			\\chi = C_{\\mu} - \\mathbf{E}
 903
 904		:math:`C_{\\mu}` is trivially computed from the stationary distribution in :meth:`C_mu` and :math:`\\mathbf{E}` in :meth:`E`.
 905
 906		.. note::
 907
 908			**Interpretations**
 909
 910			* Difference between internal stored information and apparent information to an observer.
 911			* How muching information is hiding in the system.
 912
 913		Returns:
 914		
 915			float: :math:`\\chi`
 916
 917		[^crutchfield_crypticity_2009]: Crutchfield et al., Time’s barbed arrow: Irreversibility, crypticity, and stored information, 2009.
 918			<https://arxiv.org/abs/0902.1209>
 919
 920		[^Mahoney_crypticity_2021]: Mahoney et al., Information Accessibility and Cryptic Processes, 2021.
 921			<https://arxiv.org/abs/0905.4787>
 922		"""
 923
 924		m = self.get_complexity_measure_if_exists( "chi" )
 925
 926		if m is not None :
 927			return m
 928
 929		chi = self.C_mu() - self.E()
 930
 931		if chi < 0 :
 932			
 933			# if chi is 0, accumulated floating point error can result in small negative values
 934			if chi < -1e-5:
 935				warnings.warn(f"Crypticity is negative ({chi:.6e}).")
 936			
 937			chi = np.clamp( chi, 0 )
 938
 939		self.set_complexity_measure( "chi", chi )
 940
 941		return chi
 942
 943	#-------------------------------------------------------#
 944	#                      Properties                       #
 945	#-------------------------------------------------------#
 946
 947	def is_row_stochastic(self) :
 948
 949		"""
 950		Check that all states have outgoing transition probabilities that sum to 1.
 951		"""
 952
 953		sums = np.zeros( len( self.states ) )
 954		for tr in self.transitions :
 955			sums[ tr.origin_state_idx ] += tr.prob
 956		return np.allclose( sums, 1.0 )
 957
 958	#-------------------------------------------------------#
 959
 960	def is_unifilar(self) :
 961
 962		"""
 963		Check that no state emits the same symbol on transitions to different states. 
 964		"""
 965
 966		symbol_trs = np.full( ( len( self.states ), len( self.alphabet) ), -1 )
 967
 968		for tr in self.transitions : 
 969		
 970			if symbol_trs[ tr.origin_state_idx, tr.symbol_idx ] == -1 :
 971				symbol_trs[ tr.origin_state_idx, tr.symbol_idx ] = tr.target_state_idx
 972		
 973			elif symbol_trs[ tr.origin_state_idx, tr.symbol_idx ] != tr.target_state_idx :
 974				return False
 975		
 976		return True
 977
 978	#-------------------------------------------------------#
 979
	def is_strongly_connected(self) :

		"""
		Check if every state is reachable from every other state. Relies on [nx.is_strongly_connected](https://networkx.org/documentation/latest/reference/algorithms/generated/networkx.algorithms.components.is_strongly_connected.html) applied to the graph from ``as_digraph``.
		"""

		graph = self.as_digraph()
		return nx.is_strongly_connected( graph )
 987
 988	#-------------------------------------------------------#
 989
	def is_aperiodic(self) :

		"""
		Check if the machine is aperiodic. Relies on [nx.is_aperiodic](https://networkx.org/documentation/latest/reference/algorithms/generated/networkx.algorithms.dag.is_aperiodic.html), "A strongly connected directed graph is aperiodic if there is no integer k > 1 that divides the length of every cycle in the graph."
		"""

		graph = self.as_digraph()
		return nx.is_aperiodic( graph )
 997
 998	#-------------------------------------------------------#
 999
1000	def _is_minimal_as_dfa( self, topological_only : bool, verbose=True ) :
1001
1002		with_probs = not topological_only
1003
1004		# Construct the DFA
1005		dfa = self.as_dfa( with_probs=with_probs )
1006
1007		# Minimize the DFA
1008		#dfa = dfa.minify(retain_names=True)
1009		dfa = am_fast.minify_cpp( dfa, retain_names=True )
1010
1011		# check we have minimal number of states
1012		if len( dfa.states ) != len( self.states ) :
1013			if verbose : 
1014				print( f"Not minimal reduces from {len( self.states )} to {len( dfa.states )} states" )
1015			return False
1016
1017		return True
1018
1019	def is_topological_epsilon_machine( self, verbose=True ) :
1020
1021		"""
1022		Checks if the HMM is a topological $\\epsilon$-machine [^1].
1023
1024		[^1]: Johnson et al, Enumerating Finitary Processes, 2024.
1025			<https://arxiv.org/abs/1011.0036>
1026		"""
1027
1028		if not ( self.is_unifilar() and self.is_strongly_connected() ) :
1029			if verbose : 
1030				print( f"Either non unifilar or not strongly connected" )
1031			return False
1032		else :
1033			return self._is_minimal_as_dfa( topological_only=True, verbose=verbose )
1034
1035	def is_epsilon_machine( self, verbose=True ) :
1036
1037		if not ( self.is_unifilar() and self.is_strongly_connected() ) :
1038			if verbose : 
1039				print( f"Either non unifilar or not strongly connected" )
1040			return False
1041		else :
1042			return self._is_minimal_as_dfa( topological_only=False, verbose=verbose )
1043
1044	#-------------------------------------------------------#
1045	#                      Properties                       #
1046	#-------------------------------------------------------#
1047
1048	def minimize(self, retain_names: bool = True, verbose=False):
1049
1050		"""
1051		Minimizes the HMM, resulting in an :math:`\\epsilon-`machine if the HMM
1052		is unifilar and strongly connected. Converts the HMM to a DFA with symbols
1053		labeled jointly with symbols and probabilities, and uses Myhill-Nerode 
1054		equivalence for minimization. Relies on `automata_lib` and uses
1055		 `automata.fa.dfa.DFA.minify` with `allow_partial=True`, and all states
1056		 final.
1057
1058		Args:
1059			retain_names (bool): If `True`, the merged states will be named by their union, e.g. `{s_0, s_1}`, and other states will retain their origion names. Otherwise, they will be relabled `{ '0', '1', ..., 'n-1' }`.
1060
1061		Returns:
1062		
1063			automata.fa.dfa.DFA : the resulting DFA.
1064		"""
1065
1066		if self.is_minimal :
1067			return
1068
1069		start = time.perf_counter()
1070
1071		if not self.is_unifilar():
1072			raise ValueError(
1073				"DFA minimization is not valid for non-unifilar HMMs"
1074			)
1075
1076		was_strongly_connected = self.is_strongly_connected()
1077
1078		was_row_stochastic = self.is_row_stochastic()
1079		n_states_before = len(self.states)
1080
1081		dfa = self.as_dfa(with_probs=True)
1082
1083		#min_dfa = self.as_dfa(with_probs=True).minify(retain_names=True)
1084		min_dfa = am_fast.minify_cpp( dfa, retain_names=True )
1085
1086		# Build lookup from original state index -> CausalState object
1087		orig_state   = {i: s for i, s in enumerate(self.states)}
1088		eq_list      = list(min_dfa.states)
1089
1090		start_eq = min_dfa.initial_state
1091		
1092		# Separate the start state, then sort the rest by the 
1093		# smallest original state index inside each equivalence class.
1094		other_eqs = [eq for eq in eq_list if eq != start_eq]
1095		other_eqs.sort(key=lambda eq: min(eq))
1096
1097		# Recombine so start eq comes first, followed by the sorted remaining classes
1098		eq_list = [start_eq] + other_eqs
1099		# ----------------------------------------------------------
1100
1101		# Recompute eq_to_idx with the new ordering
1102		eq_to_idx = {eq: i for i, eq in enumerate(eq_list)}
1103
1104		# new_start is now guaranteed to be 0
1105		new_start = 0
1106
1107		# Map each original state index -> its equivalence class
1108		# Guard: minify() silently drops unreachable states
1109		orig_to_eq = {s: eq for eq in min_dfa.states for s in eq}
1110
1111		# Build lookup from original state index -> its transitions
1112		orig_trs = defaultdict(list)
1113		for t in self.transitions:
1114			orig_trs[t.origin_state_idx].append(t)
1115
1116		new_trs = []
1117		for eq in min_dfa.states:
1118			rep        = next(iter(eq))
1119			origin_idx = eq_to_idx[eq]
1120			for t in orig_trs[rep]:
1121				target_eq  = orig_to_eq[t.target_state_idx]
1122				target_idx = eq_to_idx[target_eq]
1123				new_trs.append(Transition(
1124					origin_state_idx = origin_idx,
1125					target_state_idx = target_idx,
1126					prob             = t.prob,
1127					symbol_idx       = t.symbol_idx,
1128				))
1129
1130		members_list = [[orig_state[i] for i in sorted(eq)] for eq in eq_list]  # sorted for determinism
1131
1132		# Compute new names
1133		if retain_names:
1134			new_names = [
1135				"{" + ",".join(str(m.name) for m in members) + "}" if len(members) > 1
1136				else members[0].name
1137				for members in members_list
1138			]
1139		else:
1140			new_names = [str(j) for j in range(len(eq_list))]
1141
1142		old_name_to_new_name = {
1143			m.name: new_names[j]
1144			for j, members in enumerate(members_list)
1145			for m in members
1146		}
1147
1148		# Build the new states, preserving classes and isomorphs regardless of naming
1149		new_states = []
1150		for j, (eq, members, name) in enumerate(zip(eq_list, members_list, new_names)):
1151			classes   = set().union(*(m.classes for m in members))
1152			isomorphs = {
1153				old_name_to_new_name.get(iso, iso)
1154				for m in members
1155				for iso in m.isomorphs
1156				if old_name_to_new_name.get(iso, iso) != name
1157			}
1158			new_states.append(CausalState(
1159				name      = name,
1160				classes   = classes,
1161				isomorphs = isomorphs,
1162			))
1163
1164		self.set_states(new_states)
1165		self.set_transitions(new_trs)
1166		self.start_state = new_start
1167
1168		if n_states_before == len(new_states) and verbose :
1169			print( f"{n_states_before} state HMM was already minimal.\n" )
1170		elif verbose :
1171			print( f"Minimized from {n_states_before} to {len(new_states)}\n" )
1172
1173		if not ( was_strongly_connected ==  self.is_strongly_connected() ) :
1174			raise RuntimeError(
1175				f"Minimization broke strongly connected"
1176			)
1177
1178		if not ( was_row_stochastic ==  self.is_row_stochastic() ) :
1179			raise RuntimeError(
1180				f"Minimization broke row stochasticity"
1181			)
1182
1183		self.is_minimal = True
1184
1185	#-------------------------------------------------------#
1186	#                      Modifiers                        #
1187	#-------------------------------------------------------#
1188
1189
1190	def collapse_to_largest_strongly_connected_subgraph( self, rename_states=True ) :
1191
1192		# get equivalent networkx graph
1193		G = self.as_digraph()
1194
1195		# if already strongly connected, nothing to do
1196		if not nx.is_strongly_connected( G ) :
1197
1198			start = time.perf_counter()
1199			subgraph_nodes = list( nx.strongly_connected_components( G ) )
1200
1201			# decompose into strongly connected components and sort by length
1202			# subgraph_nodes = list(nx.strongly_connected_components( G ))
1203			subgraph_nodes.sort(key=len)
1204			component_state_set = subgraph_nodes[-1]
1205
1206			# Take the largest strongly connected component (as list of state names)
1207			component_states = sorted( list( component_state_set ) )
1208
1209			# make temporary copies of the old transitions and states
1210			old_transitions = [
1211				Transition(
1212					origin_state_idx=tr.origin_state_idx,
1213					target_state_idx=tr.target_state_idx,
1214					prob=tr.prob,
1215					symbol_idx=tr.symbol_idx
1216				)
1217
1218				for tr in self.transitions
1219			]
1220
1221			old_states = [
1222				CausalState(
1223					name=s.name,
1224					classes=s.classes,
1225					isomorphs=s.isomorphs
1226				)
1227				for s in self.states
1228			]
1229
1230			self.set_states(
1231				states=[ 
1232					state
1233					for i, state in enumerate( old_states ) if i in component_state_set
1234				]
1235			)
1236
1237			# we will build new transition list based on those belonging to the component
1238			self.set_transitions( transitions= [] )
1239
1240			# for tracking which new transitions leave each state
1241			transitions_from_state = { state : set() for state in component_states }
1242			new_transitions = []
1243
1244			for tr in old_transitions :
1245
1246				origin_state_name = old_states[ tr.origin_state_idx ].name
1247				target_state_name = old_states[ tr.target_state_idx ].name
1248
1249				# skip transitions that connect separate strongly connected components
1250				if not ( tr.origin_state_idx in component_state_set and tr.target_state_idx in component_state_set ) :
1251					continue
1252
1253				# track transitions (by index in new transitions list) that leave this state
1254				transitions_from_state[ tr.origin_state_idx ].add( len( new_transitions ) )
1255
1256				my_origin_state_idx = self.state_idx_map[ origin_state_name ]
1257				my_target_state_idx = self.state_idx_map[ target_state_name ]
1258
1259				new_transitions.append( 
1260					Transition(
1261						origin_state_idx=my_origin_state_idx,
1262						target_state_idx=my_target_state_idx,
1263						prob=tr.prob,
1264						symbol_idx=tr.symbol_idx
1265					)  )
1266
1267			self.set_transitions( transitions=new_transitions )
1268
1269			# if we removed an outgoing transition from a state, we need to distribute its probability 
1270			# among the remaining outgoing transitions from the state
1271			for state in component_states :
1272				
1273				# get the set of transitions leaving this state
1274				state_trs = transitions_from_state[ state ]
1275
1276				# sum the probabilities of the outgoing transitions from the state
1277				p_sum = np.sum( [ self.transitions[ i ].prob for i in state_trs ] )
1278
1279				# how much probability is missing
1280				diff = 1.0 - p_sum
1281
1282				# if significant difference
1283				if abs( diff ) > self.EPS : 
1284
1285					# calculate how much of the difference each transition gets
1286					adjustment = diff / len( state_trs )
1287					
1288					# update the transitions
1289					for i in state_trs :
1290
1291						# transition with origional probability
1292						tr = self.transitions[ i ]
1293
1294						# adjusted probability
1295						self.transitions[ i ] = Transition(
1296							origin_state_idx=tr.origin_state_idx, 
1297							target_state_idx=tr.target_state_idx, 
1298							prob=tr.prob + adjustment, 
1299							symbol_idx=tr.symbol_idx )
1300
1301			if rename_states :
1302				self.set_states( [
1303					CausalState( name=f"{i}" )
1304					for i, s in enumerate( self.states )
1305				] )
1306
1307
1308	def to_q_weighted( self, denominator_limit=1000 ) :
1309
1310		"""
1311		Approximates the existing transition probabilities with exact fractions, stores 
1312		the fractional probabilities as Fraction in Transition.pq, and sets the floating
1313		point probabilty to `float(pq)`. If `denominator_limit` is too small for a sane 
1314		conversion, the function recurses with `denominator_limit=denominator_limit*10`.
1315
1316		Args:
1317			denominator_limit (int): The initial input to :meth:`Fraction.limit_denominator` in 
1318			the conversion.
1319		"""
1320
1321		if self.is_q_weighted :
1322			return
1323
1324		if not self.is_row_stochastic() :
1325			raise ValueError( "Cannot convert to q-weighted because not row stochastic" )
1326
1327		t_from = [[] for _ in range(len(self.states))]
1328
1329		for i, tr in enumerate( self.transitions ) :
1330			t_from[ tr.origin_state_idx ].append( i )
1331
1332		new_transitions = []
1333		for t_list in t_from :
1334			
1335			if not t_list : 
1336				continue
1337
1338			p_q_sum = Fraction(0,1)
1339			p_qs = []
1340
1341			for t_idx in t_list :
1342			
1343				p_q = Fraction( self.transitions[ t_idx ].prob ).limit_denominator( denominator_limit )
1344				p_q_sum += p_q
1345				p_qs.append( p_q )
1346
1347			if p_q_sum != Fraction(1,1) :
1348				
1349				max_pq_i = np.argmax( p_qs )
1350				max_oq = p_qs[ max_pq_i ]
1351
1352				diff = p_q_sum - Fraction(1,1)
1353
1354				# If recurse with higher resolution
1355				if diff > max_oq :
1356					return self.to_q_weighted( denominator_limit*10 )
1357				else :
1358					p_qs[ max_pq_i ] -= diff
1359
1360			for i, t_idx in enumerate( t_list ) :	
1361				new_transitions.append( 
1362					Transition( 
1363						origin_state_idx=self.transitions[  t_idx ].origin_state_idx,
1364						target_state_idx=self.transitions[  t_idx ].target_state_idx,
1365						prob=float(p_qs[ i ]),
1366						symbol_idx=self.transitions[  t_idx ].symbol_idx,
1367						pq=p_qs[ i ]
1368					)
1369				)
1370
1371		self.set_transitions( new_transitions )
1372		self.is_q_weighted = True
1373
1374	#-------------------------------------------------------#
1375	#                    Data Generation                    #
1376	#-------------------------------------------------------#
1377
1378	def isomorphic_shift(
1379		self,
1380		input_symbol_indices: np.ndarray,
1381		input_state_indices:  np.ndarray,
1382		shift : int = 1
1383	) -> dict[str, np.ndarray]:
1384
1385		"""
1386		Generates a new sequence of of symbols that are permuted with the symbols emitted by
1387		isomorphic states, if they exists.
1388
1389		:math:`\\sigma_o = \\mathcal{S}\\left[\\texttt{input\\_state\\_indices}[i]\\right]`<br>
1390		:math:`\\sigma_t = \\mathcal{S}\\left[\\texttt{input\\_state\\_indices}[i+1]\\right]`
1391 
1392		:math:`\\mathcal{I}(\\sigma_o) = \\\\{\\sigma^0_o,\\, \\sigma^1_o,\\, \\dots,\\, \\sigma^{n-1}_o \\\\}`<br>
1393		:math:`\\mathcal{I}(\\sigma_t) = \\\\{\\sigma^0_t,\\, \\sigma^1_t,\\, \\dots,\\, \\sigma^{n-1}_t \\\\}`
1394 
1395		:math:`k = \\bigl(\\mathcal{I}(\\sigma_o).\\texttt{index}(\\sigma_o) + \\texttt{shift}\\bigr) \\bmod n`
1396 
1397		:math:`\\texttt{output\\_symbol\\_indices}[i]   := T(\\sigma_o^k,\\, \\sigma_t^k).\\text{symbol\\_index}`<br>
1398		:math:`\\texttt{output\\_state\\_indices}[i]    := \\mathcal{S}.\\texttt{index}(\\sigma_o^k)`<br>
1399		:math:`\\texttt{output\\_state\\_indices}[i+1]  := \\mathcal{S}.\\texttt{index}(\\sigma_t^k)`
1400
1401		Where :math:`\\mathcal{I}(\\sigma)`` is the ordered set of states isomorphic to :math:`\\sigma` including :math:`\\sigma` itself.
1402
1403		Args:
1404			input_symbol_indices (np.ndarray): The sequence of generated symbols.
1405			input_state_indices (np.ndarray): The sequence of states that generated symbols with the final state at the end.
1406			shift : int: How much to shift the symbols across the isomorphic states.
1407		"""
1408
1409		if not any( state.isomorphs for state in self.states ):
1410			raise ValueError("HMM has no states with isomorphs")
1411
1412		inputs = np.asarray(input_symbol_indices)
1413		states = np.asarray(input_state_indices)
1414
1415		n_states = len(self.states)
1416
1417		tr_sym_table = np.full((n_states, n_states), -1, dtype=np.int32)
1418
1419		for tr in self.transitions:
1420			tr_sym_table[ tr.origin_state_idx, tr.target_state_idx ] = tr.symbol_idx
1421
1422		# Build isomorph remapping: identity by default, overridden where isomorphs exist
1423		iso_table = np.arange(n_states, dtype=np.int32)
1424
1425		for i, state in enumerate(self.states):
1426			if state.isomorphs:
1427				isomorph       = sorted(state.isomorphs)[0]
1428				iso_table[ i ] = self.state_idx_map[isomorph]
1429
1430		for i, state in enumerate(self.states):
1431			if state.isomorphs:
1432
1433				# extend the isomorph list to include the identity
1434				isormorphs_with_identity = sorted( [ i ] + [ self.state_idx_map[iso] for iso in state.isomorphs ] )
1435
1436				# find the identity index
1437				pos = isormorphs_with_identity.index( i )
1438
1439				# cyclical shift 
1440				iso_table[ i ] = isormorphs_with_identity[ ( pos + shift ) % len( isormorphs_with_identity ) ]
1441
1442		origins = states[:-1]
1443		targets = states[1:]
1444
1445		out_origins = iso_table[origins]
1446		out_targets = iso_table[targets]
1447
1448		inv_sym = tr_sym_table[ out_origins, out_targets ]
1449		inv_sts = np.empty(states.size, dtype=states.dtype)
1450
1451		inv_sts[:-1] = out_origins
1452		inv_sts[-1]  = out_targets[-1]
1453
1454		return {
1455			"symbol_index": inv_sym.astype(inputs.dtype),
1456			"state_index":  inv_sts,
1457		}
1458
1459	def generate_data(
1460		self,
1461		file_prefix: str,
1462		n_gen: int,
1463		include_states: bool,
1464		isomorphic_shifts : set[int]=None,
1465		random_seed : int=42 ) -> dict[any] : 
1466
1467		trs = [ [] for _ in range( len( self.states ) ) ]
1468		for tr in self.transitions :
1469			trs[ tr.origin_state_idx ].append( ( 
1470				tr.symbol_idx, 
1471				float( tr.prob ),
1472				tr.target_state_idx ) )
1473
1474		data = am_fast.generate_data(
1475			n_gen=n_gen,
1476			start_state=self.start_state,
1477			transitions=trs,
1478			alphabet=sorted(list(self.alphabet)),
1479			include_states=include_states,
1480			random_seed=random_seed
1481		)
1482
1483		if isomorphic_shifts is not None :
1484
1485			if not include_states :
1486				raise ValueError( "Isomorphic inversion requires include_states=True" )
1487
1488			data[ "isomorphic_shifts" ] = {}
1489
1490			for shift in isomorphic_shifts :
1491
1492				try : 
1493
1494					shifted = self.isomorphic_shift(
1495						input_symbol_indices=data[ "symbol_index" ], 
1496						input_state_indices=data[ "state_index" ], 
1497						shift=shift
1498					)
1499
1500					data[ "isomorphic_shifts" ][ shift ] = {
1501						"symbol_index" : shifted[ "symbol_index" ],
1502						"state_index"  : shifted[ "state_index" ]
1503					}
1504
1505				except Exception as e :
1506					print( f"Exception {e}" )
1507
1508		am_fast.save_data(
1509			data=data,
1510			file_prefix=file_prefix,
1511			alphabet=sorted(list(self.alphabet)),
1512			n_states=len( self.states ),
1513			start_state=self.start_state,
1514			random_seed=random_seed,
1515			machine_metadata=self.get_metadata() )
1516
1517		return data
1518
1519	#-------------------------------------------------------#
1520	#                 Basic Visualization                   #
1521	#-------------------------------------------------------#
1522
1523	def draw_graph(
1524		self,
1525		engine     : str  = 'dot',
1526		output_dir : Path = Path('.'),
1527		show       : bool = True
1528	) -> None :
1529
1530		"""
1531		Draws the machine using [pygraphiviz](https://pygraphviz.github.io/documentation/stable/) and saves it.
1532
1533		Returns:
1534		
1535			networkx.DiGraph : the resulting graph.
1536		"""
1537
1538		G = self.as_digraph()
1539
1540		subgraphs = None if nx.is_strongly_connected( G ) else list( nx.strongly_connected_components( G ) )
1541		
1542		am_vis.draw_graph( 
1543			self, 
1544			output_dir=output_dir, 
1545			title="am_graph", 
1546			view=show, 
1547			subgraphs=subgraphs, 
1548			engine=engine )
1549
1550	#-------------------------------------------------------#
1551	#                   Alternate Forms                     #
1552	#-------------------------------------------------------#
1553
1554
1555	def as_digraph( self ) -> nx.DiGraph :
1556
1557		"""
1558		Builds a [networkx.DiGraph](https://networkx.org/documentation/stable/reference/classes/digraph.html) constructed from the machine's transitions with no edge symbols or weights.
1559
1560		Returns:
1561		
1562			networkx.DiGraph : the resulting graph.
1563		"""
1564
1565		G = nx.DiGraph()
1566		G.add_nodes_from( [ i for i, s in enumerate( self.states ) ] )
1567
1568		for tr in self.transitions :
1569			G.add_edge( tr.origin_state_idx, tr.target_state_idx )
1570
1571		return G
1572
1573	def as_dfa( self, with_probs : bool ) :
1574
1575		"""
1576		Builds an [automata.fa.dfa.DFA](https://caleb531.github.io/automata/api/fa/class-dfa/) constructed from the machine's transitions.
1577
1578		Args:
1579			with_probs (bool): If true the DFA transitions are labeled based on
1580				the symbol of the machines transition concatenated with its
1581				probability, othwise, the only the symbols.
1582
1583		Returns:
1584		
1585			automata.fa.dfa.DFA : the resulting DFA.
1586		"""
1587
1588		precision=8
1589
1590		def edge_label( symb, prob ) :
1591			return f"({symb},{round(prob, precision)})"
1592
1593		# Build states, symbols, and transitions 
1594		dfa_states  = { i for i, _ in enumerate( self.states ) }
1595			
1596		if not with_probs :
1597			dfa_symbols = set( { t.symbol_idx for t in self.transitions } )
1598		else : 
1599			dfa_symbols = set( { edge_label( t.symbol_idx, t.prob ) for t in self.transitions } )
1600
1601		dfa_transitions = defaultdict(dict)
1602
1603		if not with_probs :
1604			for t in self.transitions :
1605				dfa_transitions[ t.origin_state_idx ][ t.symbol_idx ] = t.target_state_idx
1606		else :
1607			for t in self.transitions :
1608				dfa_transitions[ t.origin_state_idx ][ edge_label( t.symbol_idx, t.prob ) ] = t.target_state_idx
1609
1610		# Construct the DFA
1611		return DFA(
1612			states=dfa_states,
1613			input_symbols=dfa_symbols,
1614			transitions=dfa_transitions,
1615			initial_state=self.start_state,
1616			allow_partial=True,
1617			final_states={ 
1618				s for s in dfa_states
1619			}
1620		)
class HMM(amachine.am_machine.Machine):
  38class HMM(Machine) :
  39
  40	"""Hidden Markov model implementing epsilon machines, mixed state presentations,
  41	complexity measures, and data generation.
  42
  43	Args:
  44		states (list[CausalState] | None): A list of causal states.
  45		transitions (list[Transition] | None): A list of transitions between states.
  46		start_state (int): Index of the start state.
  47		alphabet (list[str]): List of symbols making up the alphabet.
  48		name (str): Name of the model.
  49		description (str): Description of the model.
  50	"""
  51
  52	def __init__( 
  53		self,
  54		states : list[CausalState] | None = None,
  55		transitions : list[Transition] | None = None,
  56		start_state : int = 0,
  57		alphabet : list[str] | None = None,
  58		name : str = "",
  59		description : str = "" ) : 
  60
  61		self.alphabet    : list[str] | None = alphabet or []
  62		self.states      : list[CausalState] | None= states or []
  63		self.transitions : list[Transition] | None= transitions or []
  64
  65		self.set_alphabet( self.alphabet )
  66		self.set_states( self.states )
  67		self.set_transitions( self.transitions )
  68
  69		self.start_state : int = start_state
  70
  71		self.name : str = name
  72		self.description : str = description
  73
  74		# To be depreciated
  75		self.isoclass : str = None
  76
  77		# --- derived --------
  78
  79		self.complexity : dict[str, any] = {}
  80
  81		self.symbol_idx_map : dict[str,int] = {}
  82		self.state_idx_map  : dict[str,int] = {}
  83
  84		self.pi_fractional = None
  85		self.pi : np.ndarray | None = None
  86		self.T  : np.ndarray | None = None
  87		self.msp : MSP | None = None
  88		self.reverse_am : HMM | None = None
  89		self.is_q_weighted : bool = False
  90		self.is_minimal : bool = False
  91
  92		# --- const ----------
  93
  94		self.EPS : float = 1e-12
  95
  96	#-------------------------------------------------------#
  97	#                Overrides / Getters                    #
  98	#-------------------------------------------------------#
  99
 100	@override
 101	def get_states(self) -> list[CausalState] : 
 102		return self.states
 103
 104	@override
 105	def get_transitions(self) -> list[Transition] : 
 106		return self.transitions
 107
 108	@override
 109	def get_alphabet(self) -> list[Symbol]:
 110		return [ Symbol(a) for a in self.alphabet ]
 111
 112	#-------------------------------------------------------#
 113	#                     Serialization                     #
 114	#-------------------------------------------------------#
 115
 116
	def to_dict(self) -> dict[str,any]:

		"""Create a dict representing the HMM configuration.

		States and transitions are converted with `dataclasses.asdict`.

		Returns:

			dict[str,any]: Dictionary with the keys `name`, `description`,
			`states`, `transitions`, `alphabet` and `isoclass`.
		"""

		state_dicts      = [ asdict( s ) for s in self.states ]
		transition_dicts = [ asdict( t ) for t in self.transitions ]

		return dict(
			name=self.name,
			description=self.description,
			states=state_dicts,
			transitions=transition_dicts,
			alphabet=self.alphabet,
			isoclass=self.isoclass,
		)
 134
	def from_dict( self, config : dict[str,any] ) :

		"""Configure the HMM from a dictionary, e.g. one produced by :meth:`to_dict`.

		Reads the keys `name`, `description`, `states`, `transitions` and
		`alphabet`. `isoclass` and each state's `isomorphs` are restored when
		present. Each transition's `symbol_idx` is interpreted relative to the
		order of `config["alphabet"]`.

		Args:
			config (dict[str,any]): The HMM configuration.

		Raises:
			KeyError: If a required key is missing.
		"""

		# config is a mapping: use item access (attribute access fails on dicts)
		self.name        = config[ "name" ]
		self.description = config[ "description" ]

		if "isoclass" in config :
			self.isoclass = config[ "isoclass" ]

		states = []
		for state in config[ "states" ] :
			kwargs = {
				"name"    : state[ "name" ],
				"classes" : set( state[ "classes" ] ),
			}
			# only pass isomorphs when the config carries them
			if state.get( "isomorphs" ) is not None :
				kwargs[ "isomorphs" ] = set( state[ "isomorphs" ] )
			states.append( CausalState( **kwargs ) )

		self.set_states( states=states )

		# Install the alphabet and the transitions exactly as stored, then let
		# set_alphabet sort the alphabet and remap every transition's symbol
		# index. Setting the transitions first and the alphabet afterwards would
		# make set_alphabet look the symbols up in the previous alphabet.
		self.alphabet    = list( config[ "alphabet" ] )
		self.transitions = [
			Transition(
				origin_state_idx=tr[ "origin_state_idx" ],
				target_state_idx=tr[ "target_state_idx" ],
				prob=tr[ "prob" ],
				symbol_idx=tr[ "symbol_idx" ]
			)
			for tr in config[ "transitions" ]
		]

		self.set_alphabet( alphabet=self.alphabet )
 165
	def save_config(
		self, 
		path : Path, 
		with_complexity : bool = False, 
		with_block_convergence :  bool = False,
		with_structural_properties : bool = False,
		with_causal_properties : bool = False ) :

		"""Write the HMM configuration to `am_config.json` inside directory `path`.

		The file holds the output of :meth:`to_dict`, a `structural_properties`
		section and, optionally, a `complexity` section. Values that JSON cannot
		represent directly are converted: `Fraction` to its string form (e.g.
		`"1/3"`), any other such value (e.g. a set) to a list.

		Args:
			path (Path): Existing directory to write `am_config.json` into.
			with_complexity (bool): If `True`, add the result of :meth:`get_complexities`.
			with_block_convergence (bool): Forwarded to :meth:`get_complexities`.
			with_structural_properties (bool): Currently not consulted; the
				structural properties are always written.
			with_causal_properties (bool): Currently not consulted.
		"""

		def _json_default( obj ) :
			# Fractions (q-weighted probabilities) are not iterable, so a plain
			# default=list hook raises TypeError on them.
			if isinstance( obj, Fraction ) :
				return str( obj )
			return list( obj )

		config = self.to_dict()

		if with_complexity :
			config[ "complexity" ] = self.get_complexities(
				with_block_convergence=with_block_convergence
			)

		config[ "structural_properties" ] = {
			"unifilar"              : self.is_unifilar(),
			"row_stochastic"        : self.is_row_stochastic(),
			"strongly_connected"    : self.is_strongly_connected(),
			"aperiodic"             : self.is_aperiodic(),
			"minimal"               : self._is_minimal_as_dfa( topological_only=False ),
			"topologically_minimal" : self._is_minimal_as_dfa( topological_only=True ),
			"is_epsilon_machine"    : self.is_epsilon_machine()
		}

		# Path( path ) also accepts a plain string
		with open( Path( path ) / "am_config.json", "w", encoding="utf-8" ) as f :
			json.dump( config, f, ensure_ascii=False, indent=2, default=_json_default )
 196
	def from_file( self, path : Path ) :

		"""Load `am_config.json` from directory `path` and configure the HMM via :meth:`from_dict`.

		Args:
			path (Path): Directory containing `am_config.json`.
		"""

		# use the `path` argument (not the Path class) and hand the parsed
		# configuration on to from_dict
		with open( Path( path ) / "am_config.json", "r", encoding="utf-8" ) as f:
			config = json.load(f)

		self.from_dict( config )
 201
 202	#-------------------------------------------------------#
 203	#             Setters and State Management              #
 204	#-------------------------------------------------------#
 205	
 206	def _invalidate(self) :
 207
 208		"""
 209		Reset all derived properties so they will be recomputed when requested later.
 210		"""
 211
 212		self.complexity = {}
 213		self.T = None
 214		self.T_x = None
 215		self.pi = None
 216		self.pi_fractional = None
 217		self.msp = None 
 218		self.reverse_am = None
 219		self.is_q_weighted = False
 220		self.is_minimal = False
 221
 222	def set_states( self, states : list[CausalState] ) :
 223		self._invalidate()
 224		self.states = states.copy()
 225		self.state_idx_map = {}
 226		for idx, state in enumerate( self.states ) :
 227			self.state_idx_map[ state.name ] = idx
 228
	def set_alphabet( self, alphabet : list[str] ) :

		"""
		Replace the alphabet with the sorted, de-duplicated symbols of `alphabet`.

		Resets the derived properties, rebuilds the symbol -> index map and
		rewrites every existing transition so its `symbol_idx` refers to the same
		symbol in the new alphabet as it did in the previous one.

		Args:
			alphabet (list[str]): The new symbols.
		"""

		self._invalidate()

		# needed to translate the existing symbol indices back to symbols
		previous = self.alphabet.copy()

		self.alphabet = sorted( set( alphabet ) )

		self.symbol_idx_map = {
			symbol : idx
			for idx, symbol in enumerate( self.alphabet )
		}

		for position, tr in enumerate( self.transitions ) :
			self.transitions[ position ] = Transition(
				origin_state_idx=tr.origin_state_idx,
				target_state_idx=tr.target_state_idx,
				prob=tr.prob,
				symbol_idx=self.symbol_idx_map[ previous[ tr.symbol_idx ] ]
			)
 252
 253	def set_transitions( self, transitions : list[Transition] ) :
 254		self._invalidate()
 255		self.transitions = transitions.copy()
 256
 257	def extend_states( self, states : list[CausalState] ) :
 258		self.set_states( self.states + states )
 259
	def extend_alphabet( self, alphabet : list[str] ) :
		"""Add the symbols of `alphabet` to the existing alphabet via :meth:`set_alphabet`."""
		combined = self.alphabet + alphabet
		self.set_alphabet( combined )
 262
 263	def extend_transitions( self, transitions : list[Transition] ) :
 264		self.set_transitions( self.transitions + transitions  )
 265
	def get_complexity_measure_if_exists(self, measure ) :
		"""Return the stored value of complexity measure `measure`, or `None` if it has not been stored."""
		return self.complexity.get( measure )
 269
	def set_complexity_measure(self, measure, value ) :
		"""Store `value` as the complexity measure named `measure`, replacing any previous value."""
		store = self.complexity
		store[ measure ] = value
 272
 273	#-------------------------------------------------------#
 274	#                   Get and Compute                     #
 275	#-------------------------------------------------------#
 276
	def get_complexities( 
		self, 
		with_block_convergence=False ) :

		"""
		Evaluate the complexity measures and collect them in a dict keyed by
		the name of the measure's method.

		`C_mu`, `h_mu`, `H_1` and `rho_mu` are always evaluated. With
		`with_block_convergence=True`, `E`, `T_inf`, `S` and `chi` are evaluated
		as well, and the entries `H_L`, `T_L`, `h_mu_L` and `H_sync` are copied
		from `self.complexity` when they exist there.

		Args:
			with_block_convergence (bool): Include the second group of measures.

		Returns:
			dict: measure name -> value.
		"""

		always = ( self.C_mu, self.h_mu, self.H_1, self.rho_mu )
		block_based = ( self.E, self.T_inf, self.S, self.chi )

		result = {}
		for measure in always :
			result[ measure.__name__ ] = measure()

		if not with_block_convergence :
			return result

		for measure in block_based :
			result[ measure.__name__ ] = measure()

		# by-products stored in self.complexity, copied only if present
		for key in ( 'H_L', 'T_L', 'h_mu_L', 'H_sync' ) :
			if key in self.complexity :
				result[ key ] = self.complexity[ key ]

		return result
 306
 307	#-------------------------------------------------------#
 308
	def get_metadata(self) :
		"""Return a dict with the machine's `name`, stored `complexity` dict and `description`."""
		return dict(
			name=self.name,
			complexity=self.complexity,
			description=self.description,
		)
 315
	def get_transition_matrix(self) :

		"""
		Return the state-to-state transition matrix, computing and caching it
		in `self.T` on first use.

		Entry `[i, j]` is the total probability of moving from state `i` to
		state `j`, summed over all transitions (symbols) connecting the pair.

		Returns:
			np.ndarray: `(n_states, n_states)` float matrix.
		"""

		if self.T is not None :
			return self.T

		n_states = len( self.states )
		T = np.zeros((n_states, n_states))

		for tr in self.transitions :
			# accumulate: several symbols may connect the same pair of states
			T[ tr.origin_state_idx, tr.target_state_idx ] += float( tr.prob )

		self.T = T

		return self.T
 330
 331	#-------------------------------------------------------#
 332
	def get_T_X(self) :

		"""
		Return the symbol-labeled transition matrices, computing and caching
		them in `self.T_x` on first use.

		Returns:
			list[np.ndarray]: One `(n_states, n_states)` matrix per alphabet
			symbol. Entry `[i, j]` of matrix `x` is the probability of the
			transition from state `i` to state `j` that emits symbol `x`
			(0 where no such transition exists).
		"""

		cached = self.T_x
		if cached is not None :
			return cached

		size = len( self.states )

		matrices = [ np.zeros( ( size, size ) ) for _ in self.alphabet ]

		for tr in self.transitions :
			matrices[ tr.symbol_idx ][ tr.origin_state_idx, tr.target_state_idx ] = tr.prob

		self.T_x = matrices

		return matrices
 348
 349	#-------------------------------------------------------#
 350
	def get_msp_qw(
		self,
		exact_state_cap: int = 1000,
		verbose: bool = True,
	):

		"""
		Return the mixed state presentation, preferring an exact computation
		with fractions (`compute_msp_exact`) and falling back to :meth:`get_msp`
		when the exact computation raises `RuntimeError`.

		The result is cached in `self.msp`; a cached value is returned as is.

		Args:
			exact_state_cap (int): Forwarded to `compute_msp_exact` as
				`exact_state_cap`. It is not forwarded to the fallback, which
				uses its own default.
			verbose (bool): If `True`, announce the exact attempt on stdout.
				Also forwarded to the fallback.

		Returns:
			MSP: the mixed state presentation.
		"""

		if self.msp is not None:
			return self.msp

		try : 

			if verbose :
				print( "\nTrying to Compute Mixed State Presentation using Exact Fractions\n" )

			self.msp = compute_msp_exact(
				T_x=self.get_Tx_fractional(),
				pi=self.get_fractional_stationary_distribution(),
				n_states=len(self.states),
				alphabet=self.alphabet,
				# honour the caller's cap instead of a hard-coded 1000
				exact_state_cap=exact_state_cap,
				# NOTE(review): `bool` looks like a typo for `verbose=verbose`;
				# kept unchanged because compute_msp_exact's signature is not
				# visible here - confirm against am_msp.
				bool = True
			)

			return self.msp 

		except RuntimeError as e :
			warnings.warn( f"Exact msp failed: {e} Falling back to msp approximation." )

		return self.get_msp( verbose=verbose )
 378
	def get_msp(
		self,
		exact_state_cap: int = 175_000,
		jsd_eps:         float = 1e-7,
		k_ann:           int   = 50,
		verbose                = True,
	) :

		"""
		Return the mixed state presentation computed with `compute_msp` from the
		symbol-labeled transition matrices and the stationary distribution.

		The result is cached in `self.msp`; a cached value is returned as is.

		Args:
			exact_state_cap (int): Forwarded to `compute_msp`.
			jsd_eps (float): Currently not used by this method.
			k_ann (int): Currently not used by this method.
			verbose (bool): If `True`, announce the computation on stdout.
				Also forwarded to `compute_msp`.

		Returns:
			MSP: the mixed state presentation.
		"""

		if self.msp is not None:
			return self.msp

		T_x = self.get_T_X()
		pi  = self.get_stationary_distribution()

		if verbose :
			print( "\nComputing Mixed State Presentation..." )

		self.msp = compute_msp( 
			T_x=T_x,
			pi=pi,
			n_states=len(self.states),
			alphabet=self.alphabet,
			exact_state_cap=exact_state_cap,
			verbose=verbose
		)

		return self.msp
 409
	def get_reverse_am(self) :

		"""
		Return the time-reversed machine, computing and caching it in
		`self.reverse_am` on first use.

		Every transition `j -> i` with probability `p` becomes `i -> j` with
		probability `pi[j] * p / pi[i]`, where `pi` is the stationary
		distribution. If the reversed machine is not an epsilon machine, it is
		replaced by its mixed state presentation (:meth:`get_msp_qw`), collapsed
		to its largest strongly connected component and minimized.

		The result is stored in `self.reverse_am` only once it is complete, so a
		failure in any step does not leave a half-built machine in the cache.

		Returns:
			HMM: the reversed machine.
		"""

		if self.reverse_am is not None:
			return self.reverse_am

		pi = self.get_stationary_distribution()

		# work on a local copy; publish it only when it is finished
		reverse = copy.deepcopy(self)

		reversed_transitions = [
			Transition(
				origin_state_idx=tr.target_state_idx,
				target_state_idx=tr.origin_state_idx,
				prob=( pi[ tr.origin_state_idx ] * tr.prob ) / pi[ tr.target_state_idx ],
				symbol_idx=tr.symbol_idx
			)
			for tr in self.transitions
		]

		reverse.set_transitions( reversed_transitions )

		if not reverse.is_epsilon_machine():

			rmsp = reverse.get_msp_qw( exact_state_cap=len(self.states)*4 )

			reverse.set_states( rmsp.states )
			reverse.set_transitions( rmsp.transitions )
			reverse.msp = rmsp
			reverse.start_state = 0

			reverse.collapse_to_largest_strongly_connected_subgraph()
			reverse.minimize()

		self.reverse_am = reverse

		return self.reverse_am
 450
 451	#-------------------------------------------------------#
 452
	def get_Tx_fractional(self) -> list[ list[ list[ Fraction ] ] ] :

		"""
		Return the symbol-labeled transition matrices as nested lists holding
		the fractional probabilities (`Transition.pq`).

		The machine is converted with :meth:`to_q_weighted` first. Entries
		without a transition hold the integer `0`.

		Returns:
			list[list[list[Fraction]]]: indexed as `[symbol][origin][target]`.
		"""

		self.to_q_weighted()

		size = len( self.states )

		# one independent size x size table per alphabet symbol
		tables = [
			[ [ 0 ] * size for _ in range( size ) ]
			for _ in self.alphabet
		]

		for tr in self.transitions :
			tables[ tr.symbol_idx ][ tr.origin_state_idx ][ tr.target_state_idx ] = tr.pq

		return tables
 471
	def get_T_sympy( self ) :

		"""
		Return the state-to-state transition matrix as a sympy matrix built from
		the fractional probabilities (`Transition.pq`).

		The machine is converted with :meth:`to_q_weighted` first. Entry
		`[i, j]` is the total probability of moving from state `i` to state `j`,
		summed over all transitions (symbols) connecting the pair.

		Returns:
			sympy matrix of shape `(n_states, n_states)`.
		"""

		self.to_q_weighted()

		n = len( self.states )
		T = sympy.zeros( n, n )

		for tr in self.transitions :
			# accumulate: several symbols may connect the same pair of states
			T[ tr.origin_state_idx, tr.target_state_idx ] += sympy.sympify( tr.pq )

		return T
 483
	def get_fractional_stationary_distribution(self) :

		"""
		Return the stationary distribution computed by `solve_for_pi_fractional`
		from the sympy transition matrix, caching it in `self.pi_fractional`.

		Returns:
			The cached or newly computed result of `solve_for_pi_fractional`.

		Raises:
			ValueError: If the machine is not strongly connected.
		"""

		# check the cache before building the sympy matrix
		if self.pi_fractional is not None :
			return self.pi_fractional

		G = self.as_digraph()

		if not nx.is_strongly_connected(G):
			raise ValueError( "Single stationary distribution requires strongly connected HMM." )

		# get_T_sympy converts to q-weighted, which resets the caches, so the
		# result is assigned only after the matrix has been built
		T = self.get_T_sympy()

		self.pi_fractional = solve_for_pi_fractional( T )

		return self.pi_fractional
 499
	def get_stationary_distribution(self):

		"""Return the stationary distribution over the machine's states.

		If ``self.pi`` is already set it is returned as is. Otherwise the graph
		from :meth:`as_digraph` must be strongly connected, and the result of
		``solve_for_pi`` on :meth:`get_transition_matrix` is returned.

		Raises:

			ValueError: If the machine's graph is not strongly connected.
		"""

		# NOTE(review): the freshly solved distribution is returned without being
		# stored in self.pi -- confirm whether self.pi is populated elsewhere.
		cached = self.pi
		if cached is not None :
			return cached

		if not nx.is_strongly_connected( self.as_digraph() ) :
			raise ValueError( "Single stationary distribution requires strongly connected HMM." )

		return solve_for_pi( self.get_transition_matrix() )
 512
 513	#-------------------------------------------------------#
 514	#                 Complexity Measures                   #
 515	#-------------------------------------------------------#
 516	
	def C_mu( self ) :

		"""The *statistical complexity* (aka *forecasting complexity*):

		.. math::

			C_{\\mu} = - \\sum_{\\sigma \\in \\mathcal{S}} \\Pr(\\sigma) \\log_2 \\Pr(\\sigma),

		where :math:`\\mathcal{S}` are the machine's states [^crutchfield_exact_2016], p.2.
		It is the Shannon entropy (in bits) of the stationary state distribution.
		The value is cached under the key ``"C_mu"``.

		.. note::

			**Interpretations**

			* The amount of historical information a process stores.
			* The amount of structure in a process.

		Returns:

			float: :math:`C_{\\mu}`.

		[^crutchfield_exact_2016]: Crutchfield et al., *Exact Complexity: The Spectral
			Decomposition of Intrinsic Computation*, 2016.
			<https://arxiv.org/abs/1309.3792>
		"""

		cached = self.get_complexity_measure_if_exists( "C_mu" )
		if cached is not None :
			return cached

		entropy = 0
		for p_state in self.get_stationary_distribution() :

			# states with negligible weight are skipped (log2 of ~0 is not usable)
			if not ( p_state < self.EPS ) :
				entropy -= p_state * np.log2( p_state )

		self.set_complexity_measure( "C_mu", entropy )

		return entropy
 561
 562	#-------------------------------------------------------#
 563
	def h_mu( self ) :

		"""The *entropy rate* :

		.. math::

			h_{\\mu}(\\boldsymbol{\\mathcal{S}}) = - \\sum_{\\sigma \\in \\mathcal{S}} \\Pr(\\sigma) \\sum_{x \\in \\mathcal{A}} \\Pr(x|\\sigma) \\log_2 \\Pr(x|\\sigma),

		where :math:`\\mathcal{A}` is the alphabet and :math:`\\mathcal{S}` are the machine's states [^crutchfield_exact_2016], p.2.

		:math:`\\Pr(x|\\sigma)` is obtained by summing the probabilities of all
		transitions that leave :math:`\\sigma` on symbol :math:`x`. The value is
		cached under the key ``"h_mu"``.

		.. note::

			**Interpretations**

			* The lower bound on achievable loss in bits.
			* The irreducible randomness in the process.
			* The intrinsic randomness in the process.

		Returns:

			float: :math:`h_{\\mu}`.

		[^crutchfield_exact_2016]: Crutchfield et al., *Exact Complexity: The Spectral
			Decomposition of Intrinsic Computation*, 2016.
			<https://arxiv.org/abs/1309.3792>
		"""

		m = self.get_complexity_measure_if_exists( "h_mu" )

		if m is not None :
			return m

		pi = self.get_stationary_distribution()

		# Pr(x | sigma) per (state, symbol). The entropy must be taken over emitted
		# symbols, not over target states: two symbols leading to the same target
		# would otherwise be merged into one outcome (a one-state fair coin would
		# report 0 bits instead of 1).
		emission = np.zeros( ( len( self.states ), len( self.alphabet ) ) )
		for tr in self.transitions :
			emission[ tr.origin_state_idx, tr.symbol_idx ] += tr.prob

		h = 0
		for i, pr in enumerate( pi ) :

			if pr < self.EPS :
				continue

			row_entropy = 0
			for p_symbol in emission[ i ] :

				# symbols the state never emits contribute nothing
				if p_symbol < self.EPS :
					continue

				row_entropy -= p_symbol * np.log2( p_symbol )

			h += pr * row_entropy

		self.set_complexity_measure( "h_mu", h )

		return h
 620
 621	#-------------------------------------------------------#
 622
	def H_1(self) -> float :

		"""The *single symbol uncertainty*:

		.. math::

			H(1)=-\\sum_{x\\in\\mathcal{A}} \\Pr(x) \\log_2{\\Pr(x)},

		where :math:`\\mathcal{A}` is the alphabet [^James_2018], p.2.

		:math:`\\Pr(x)` is accumulated as the stationary weight of each state
		times the row sum of that state in the symbol's matrix from
		:meth:`get_T_X`. The value is cached under the key ``"H_1"``.

		.. note::

			**Interpretations**

			* How uncertain you are on average about a single measurement with no context.

		Returns:

			float: :math:`H(1)`.

		[^James_2018]: James et al., Anatomy of a Bit: Information in a Time Series Observation, 2018.
			<https://arxiv.org/abs/1105.2988>
		"""

		cached = self.get_complexity_measure_if_exists("H_1")
		if cached is not None:
			return cached

		pi = self.get_stationary_distribution()

		# NOTE(review): get_T_X() is iterated directly and each item is indexed as a
		# 2-D matrix, so it is presumably a sequence of per-symbol matrices -- confirm.
		uncertainty = 0.0
		for symbol_matrix in self.get_T_X():

			# Pr(x) = sum_i pi[i] * sum_j T^(x)[i,j]
			p_symbol = 0.0
			for row, weight in enumerate(pi):
				if not ( weight < self.EPS ):
					p_symbol += weight * symbol_matrix[row, :].sum()

			# symbols that are (numerically) never emitted add nothing
			if not ( p_symbol < self.EPS ):
				uncertainty -= p_symbol * np.log2(p_symbol)

		self.set_complexity_measure("H_1", uncertainty)
		return uncertainty
 669
 670	#-------------------------------------------------------#
 671
	def rho_mu(self) -> float :

		"""The *anticipated information* [^James_2018], p.3.:

		.. math::

			\\rho_{\\mu}= H(1) - h_{\\mu}

		Computed from :meth:`H_1` and :meth:`h_mu` and cached under the key
		``"rho_mu"``.

		Returns:

			float: :math:`\\rho_{\\mu}`

		[^James_2018]: James et al., Anatomy of a Bit: Information in a Time Series Observation, 2018.
			<https://arxiv.org/abs/1105.2988>
		"""

		value = self.get_complexity_measure_if_exists("rho_mu")

		if value is None:
			single_symbol = self.H_1()
			entropy_rate  = self.h_mu()
			value = single_symbol - entropy_rate
			self.set_complexity_measure("rho_mu", value)

		return value
 698
 699	#-------------------------------------------------------#
 700
	def block_convergence( self )  :

		"""Run the iterative block-entropy convergence and cache its results.

		The machine's transitions (grouped by origin state) and its stationary
		distribution are handed to ``am_fast.block_entropy_convergence``. The
		attributes ``E``, ``S``, ``T``, ``T_L``, ``H_L``, ``h_mu_L`` and
		``H_sync`` of the result are stored as complexity measures under the
		keys ``"E"``, ``"S"``, ``"T_inf"``, ``"T_L"``, ``"H_L"``, ``"h_mu_L"``
		and ``"H_sync"``; the array-valued ones are stored as lists.

		Returns:

			The result object returned by ``am_fast.block_entropy_convergence``.
		"""

		n_states = len( self.states )

		# per-state list of ( symbol, probability, target ) triples
		outgoing = [ [] for _ in range( n_states ) ]
		for tr in self.transitions :
			outgoing[ tr.origin_state_idx ].append(
				( tr.symbol_idx, float( tr.prob ), tr.target_state_idx )
			)

		pi = self.get_stationary_distribution()

		# a single starting branch of weight 1 carrying the stationary distribution
		initial_dist = [ float( pi[ k ] ) for k in range( n_states ) ]

		print( "\nComputing Block Entropy\n" )

		C = am_fast.block_entropy_convergence(
			h_mu            = self.h_mu(),
			n_states        = n_states,
			n_symbols       = len( self.alphabet ),
			convergence_tol = 1e-6,
			precision       = 10,
			eps             = 1e-25,
			branches        = [ ( 1.0, initial_dist ) ],
			trans           = outgoing,
			max_branches    = 30_000_000
		)

		print( "Done\n" )

		remember = self.set_complexity_measure

		remember( "E",      C.E )
		remember( "S",      C.S )
		remember( "T_inf",  C.T )
		remember( "T_L",    C.T_L.tolist() )
		remember( "H_L",    C.H_L.tolist() )
		remember( "h_mu_L", C.h_mu_L.tolist() )
		remember( "H_sync", C.H_sync.tolist() )

		return C
 744
 745	#-------------------------------------------------------#
 746
	def E( self ) -> float :

		"""The *excess entropy* [^crutchfield_exact_2016], p.4:

		.. math::

			\\mathbf{E} \\equiv I[X_{-\\infty:0}; X_{0:\\infty}] = \\sum_{L=1}^{\\infty} \\left[ h_{\\mu}(L) - h_{\\mu} \\right]

		Computed via :meth:`get_msp` and :meth:`amachine.am_msp.MSP.get_E_S_T`; if
		that raises, falls back to :meth:`block_convergence`
		(:meth:`amachine.am_fast.block_entropy_convergence`). The value is cached
		under the key ``"E"``; the MSP path also caches ``"S"`` and ``"T_inf"``.

		.. note::

			**Interpretations**

			* The information from the past that reduces uncertainty in the future [^crutchfield_exact_2016].
			* How much information an observer must extract to synchronize to the process.
			* Measures how long the process appears more complex than it asymptotically is.
			* Vanishes for immediately synchronizable processes.

		Returns:

			float: :math:`\\mathbf{E}`

		[^crutchfield_exact_2016]: Crutchfield et al., *Exact Complexity: The Spectral
			Decomposition of Intrinsic Computation*, 2016.
			<https://arxiv.org/abs/1309.3792>
		"""

		cached = self.get_complexity_measure_if_exists( "E" )

		if cached is not None :
			return cached

		try :
			# exact route: all three quantities come out of the MSP together
			excess, sync, transient = self.get_msp().get_E_S_T()
			self.set_complexity_measure( "E", excess )
			self.set_complexity_measure( "S", sync )
			self.set_complexity_measure( "T_inf", transient )

		except Exception as e :

			print( f"MSP failed {e}" )

			# iterative route
			excess = self.block_convergence().E
			self.set_complexity_measure( "E", excess )

		return excess
 796
 797	#-------------------------------------------------------#
 798
	def S( self ) -> float :

		"""The *synchronization* information:

		.. math::

			\\mathbf{S} \\equiv \\sum_{L=1}^{\\infty} \\mathcal{H}(L),

		where :math:`\\mathcal{H}(L)` is the average state uncertainty having seen all length-L words [^crutchfield_exact_2016], p.4.

		.. note::

			**Interpretations**

			* The total amount of state information that an observer must extract to become synchronized [^crutchfield_exact_2016].

		Computed via :meth:`get_msp` and :meth:`amachine.am_msp.MSP.get_E_S_T`; if
		that raises, falls back to :meth:`block_convergence`
		(:meth:`amachine.am_fast.block_entropy_convergence`). The value is cached
		under the key ``"S"``; the MSP path also caches ``"E"`` and ``"T_inf"``.

		Returns:

			float: :math:`\\mathbf{S}`

		[^crutchfield_exact_2016]: Crutchfield et al., *Exact Complexity: The Spectral
			Decomposition of Intrinsic Computation*, 2016.
			<https://arxiv.org/abs/1309.3792>
		"""

		m = self.get_complexity_measure_if_exists( "S" )

		if m is not None :
			return m

		try :
			msp = self.get_msp()
			E, S, T = msp.get_E_S_T()
			self.set_complexity_measure( "E", E )
			self.set_complexity_measure( "S", S )
			self.set_complexity_measure( "T_inf", T )

		except Exception as e :

			# The MSP route failed; estimate iteratively instead of terminating
			# the interpreter (a stray exit() used to make this branch unreachable).
			print( f"{e} \nFalling back to iterative estimation.")

			C = self.block_convergence()
			S = C.S
			self.set_complexity_measure( "S", S )

		return S
 847
 848	#-------------------------------------------------------#
 849
	def T_inf( self ) -> float :

		"""The *transient information* [^crutchfield_exact_2016], p.4:

		.. math::

			\\mathbf{T} \\equiv \\sum_{L=1}^{\\infty} L \\left[ h_{\\mu}(L) - h_{\\mu} \\right]

		Computed via :meth:`get_msp` and :meth:`amachine.am_msp.MSP.get_E_S_T`; if
		that raises, falls back to :meth:`block_convergence`
		(:meth:`amachine.am_fast.block_entropy_convergence`). The value is cached
		under the key ``"T_inf"``; the MSP path also caches ``"E"`` and ``"S"``.

		.. note::

			**Interpretations**

			* The amount of information one must extract from observations so that the block entropy converges to its linear asymptote [^crutchfield_exact_2016].

		Returns:

			float: :math:`\\mathbf{T}`

		[^crutchfield_exact_2016]: Crutchfield et al., *Exact Complexity: The Spectral
			Decomposition of Intrinsic Computation*, 2016.
			<https://arxiv.org/abs/1309.3792>
		"""

		m = self.get_complexity_measure_if_exists( "T_inf" )

		if m is not None :
			return m

		try :
			msp = self.get_msp()

			# Bind the result to the same name the fallback uses, so the final
			# return works on both paths.
			E, S, T_inf = msp.get_E_S_T()
			self.set_complexity_measure( "E", E )
			self.set_complexity_measure( "S", S )
			self.set_complexity_measure( "T_inf", T_inf )

		except Exception as e :
			print( f"{e} \nFalling back to iterative estimation.")
			C = self.block_convergence()
			T_inf = C.T
			self.set_complexity_measure( "T_inf", T_inf )

		return T_inf
 894
 895	#-------------------------------------------------------#
 896
	def chi( self ) -> float :

		"""Computes the forward crypticity [^crutchfield_crypticity_2009][^Mahoney_crypticity_2021], p.2:

		.. math::

			\\chi = C_{\\mu} - \\mathbf{E}

		:math:`C_{\\mu}` is computed from the stationary distribution in :meth:`C_mu` and :math:`\\mathbf{E}` in :meth:`E`.
		Negative results are clamped to ``0.0``; a warning is issued when the
		negative value is below ``-1e-5``. The value is cached under the key
		``"chi"``.

		.. note::

			**Interpretations**

			* Difference between internal stored information and apparent information to an observer.
			* How much information is hiding in the system.

		Returns:

			float: :math:`\\chi`

		[^crutchfield_crypticity_2009]: Crutchfield et al., Time’s barbed arrow: Irreversibility, crypticity, and stored information, 2009.
			<https://arxiv.org/abs/0902.1209>

		[^Mahoney_crypticity_2021]: Mahoney et al., Information Accessibility and Cryptic Processes, 2021.
			<https://arxiv.org/abs/0905.4787>
		"""

		m = self.get_complexity_measure_if_exists( "chi" )

		if m is not None :
			return m

		chi = self.C_mu() - self.E()

		if chi < 0 :

			# if chi is 0, accumulated floating point error can result in small negative values
			if chi < -1e-5:
				warnings.warn(f"Crypticity is negative ({chi:.6e}).")

			# clamp to zero (NumPy has no `clamp`; we already know chi < 0 here)
			chi = 0.0

		self.set_complexity_measure( "chi", chi )

		return chi
 943
 944	#-------------------------------------------------------#
 945	#                      Properties                       #
 946	#-------------------------------------------------------#
 947
	def is_row_stochastic(self) :

		"""
		Check that, for every state, the probabilities of its outgoing transitions
		sum to 1 (within the default tolerance of ``np.allclose``).
		"""

		# running total of outgoing probability, one slot per state
		outgoing_mass = [ 0.0 ] * len( self.states )

		for edge in self.transitions :
			outgoing_mass[ edge.origin_state_idx ] += edge.prob

		return np.allclose( outgoing_mass, 1.0 )
 958
 959	#-------------------------------------------------------#
 960
	def is_unifilar(self) :

		"""
		Check that no state emits the same symbol on transitions to different states.
		"""

		# target_of[ state, symbol ] remembers the first target seen; -1 means "none yet"
		target_of = np.full( ( len( self.states ), len( self.alphabet ) ), -1 )

		for edge in self.transitions :

			cell  = ( edge.origin_state_idx, edge.symbol_idx )
			known = target_of[ cell ]

			if known == -1 :
				target_of[ cell ] = edge.target_state_idx
				continue

			# same state and symbol seen before: it must lead to the same target
			if known != edge.target_state_idx :
				return False

		return True
 978
 979	#-------------------------------------------------------#
 980
	def is_strongly_connected(self) :

		"""
		Check if every state is reachable from every other state, by applying [nx.is_strongly_connected](https://networkx.org/documentation/latest/reference/algorithms/generated/networkx.algorithms.components.is_strongly_connected.html) to the graph returned by :meth:`as_digraph`.
		"""

		graph = self.as_digraph()

		return nx.is_strongly_connected( graph )
 988
 989	#-------------------------------------------------------#
 990
	def is_aperiodic(self) :

		"""
		Check whether the machine's graph is aperiodic, by applying [nx.is_aperiodic](https://networkx.org/documentation/latest/reference/algorithms/generated/networkx.algorithms.dag.is_aperiodic.html) to the graph returned by :meth:`as_digraph`: "A strongly connected directed graph is aperiodic if there is no integer k > 1 that divides the length of every cycle in the graph."
		"""

		graph = self.as_digraph()

		return nx.is_aperiodic( graph )
 998
 999	#-------------------------------------------------------#
1000
1001	def _is_minimal_as_dfa( self, topological_only : bool, verbose=True ) :
1002
1003		with_probs = not topological_only
1004
1005		# Construct the DFA
1006		dfa = self.as_dfa( with_probs=with_probs )
1007
1008		# Minimize the DFA
1009		#dfa = dfa.minify(retain_names=True)
1010		dfa = am_fast.minify_cpp( dfa, retain_names=True )
1011
1012		# check we have minimal number of states
1013		if len( dfa.states ) != len( self.states ) :
1014			if verbose : 
1015				print( f"Not minimal reduces from {len( self.states )} to {len( dfa.states )} states" )
1016			return False
1017
1018		return True
1019
1020	def is_topological_epsilon_machine( self, verbose=True ) :
1021
1022		"""
1023		Checks if the HMM is a topological $\\epsilon$-machine [^1].
1024
1025		[^1]: Johnson et al, Enumerating Finitary Processes, 2024.
1026			<https://arxiv.org/abs/1011.0036>
1027		"""
1028
1029		if not ( self.is_unifilar() and self.is_strongly_connected() ) :
1030			if verbose : 
1031				print( f"Either non unifilar or not strongly connected" )
1032			return False
1033		else :
1034			return self._is_minimal_as_dfa( topological_only=True, verbose=verbose )
1035
1036	def is_epsilon_machine( self, verbose=True ) :
1037
1038		if not ( self.is_unifilar() and self.is_strongly_connected() ) :
1039			if verbose : 
1040				print( f"Either non unifilar or not strongly connected" )
1041			return False
1042		else :
1043			return self._is_minimal_as_dfa( topological_only=False, verbose=verbose )
1044
1045	#-------------------------------------------------------#
	#                     Minimization                      #
1047	#-------------------------------------------------------#
1048
	def minimize(self, retain_names: bool = True, verbose=False):

		"""
		Minimizes the HMM in place, resulting in an :math:`\\epsilon-`machine if the HMM
		is unifilar and strongly connected. Converts the HMM to a DFA via
		:meth:`as_dfa` with `with_probs=True`, minimizes it with
		`am_fast.minify_cpp`, and rebuilds the states and transitions from the
		resulting equivalence classes. The class containing the DFA's initial
		state becomes state 0; the remaining classes are ordered by their
		smallest original state index.

		Does nothing if `self.is_minimal` is already set.

		Args:
			retain_names (bool): If `True`, the merged states will be named by their union, e.g. `{s_0,s_1}`, and other states will retain their original names. Otherwise, they will be relabeled `'0', '1', ..., 'n-1'`.
			verbose (bool): If `True`, print whether the number of states changed.

		Returns:

			None: the HMM is modified in place.

		Raises:

			ValueError: If the HMM is not unifilar.
			RuntimeError: If strong connectivity or row stochasticity differs before and after minimization.
		"""

		if self.is_minimal :
			return

		# NOTE(review): `start` is never read in this method.
		start = time.perf_counter()

		if not self.is_unifilar():
			raise ValueError(
				"DFA minimization is not valid for non-unifilar HMMs"
			)

		# Snapshot properties that are compared again after the rebuild
		was_strongly_connected = self.is_strongly_connected()

		was_row_stochastic = self.is_row_stochastic()
		n_states_before = len(self.states)

		dfa = self.as_dfa(with_probs=True)

		# Earlier implementation, kept for reference:
		#min_dfa = self.as_dfa(with_probs=True).minify(retain_names=True)
		min_dfa = am_fast.minify_cpp( dfa, retain_names=True )

		# Build lookup from original state index -> CausalState object
		orig_state   = {i: s for i, s in enumerate(self.states)}
		# Each minimized DFA state is treated below as a collection of original state indices
		eq_list      = list(min_dfa.states)

		start_eq = min_dfa.initial_state
		
		# Separate the start state, then sort the rest by the 
		# smallest original state index inside each equivalence class.
		other_eqs = [eq for eq in eq_list if eq != start_eq]
		other_eqs.sort(key=lambda eq: min(eq))

		# Recombine so start eq comes first, followed by the sorted remaining classes
		eq_list = [start_eq] + other_eqs
		# ----------------------------------------------------------

		# Map each equivalence class -> its index in the new ordering
		eq_to_idx = {eq: i for i, eq in enumerate(eq_list)}

		# new_start is now guaranteed to be 0
		new_start = 0

		# Map each original state index -> its equivalence class
		# NOTE(review): the original comment says minify() silently drops unreachable
		# states; no explicit guard is implemented here, so a transition into a
		# dropped state would raise KeyError in the lookup below.
		orig_to_eq = {s: eq for eq in min_dfa.states for s in eq}

		# Build lookup from original state index -> its transitions
		orig_trs = defaultdict(list)
		for t in self.transitions:
			orig_trs[t.origin_state_idx].append(t)

		# One representative member per class supplies that class's outgoing transitions
		new_trs = []
		for eq in min_dfa.states:
			rep        = next(iter(eq))
			origin_idx = eq_to_idx[eq]
			for t in orig_trs[rep]:
				target_eq  = orig_to_eq[t.target_state_idx]
				target_idx = eq_to_idx[target_eq]
				# only prob and symbol are copied; no `pq` is passed to the new Transition
				new_trs.append(Transition(
					origin_state_idx = origin_idx,
					target_state_idx = target_idx,
					prob             = t.prob,
					symbol_idx       = t.symbol_idx,
				))

		members_list = [[orig_state[i] for i in sorted(eq)] for eq in eq_list]  # sorted for determinism

		# Compute new names
		if retain_names:
			new_names = [
				"{" + ",".join(str(m.name) for m in members) + "}" if len(members) > 1
				else members[0].name
				for members in members_list
			]
		else:
			new_names = [str(j) for j in range(len(eq_list))]

		# Used to rewrite isomorph references so they point at the new state names
		old_name_to_new_name = {
			m.name: new_names[j]
			for j, members in enumerate(members_list)
			for m in members
		}

		# Build the new states, preserving classes and isomorphs regardless of naming
		new_states = []
		for j, (eq, members, name) in enumerate(zip(eq_list, members_list, new_names)):
			classes   = set().union(*(m.classes for m in members))
			# a merged state never lists itself as its own isomorph
			isomorphs = {
				old_name_to_new_name.get(iso, iso)
				for m in members
				for iso in m.isomorphs
				if old_name_to_new_name.get(iso, iso) != name
			}
			new_states.append(CausalState(
				name      = name,
				classes   = classes,
				isomorphs = isomorphs,
			))

		self.set_states(new_states)
		self.set_transitions(new_trs)
		self.start_state = new_start

		if n_states_before == len(new_states) and verbose :
			print( f"{n_states_before} state HMM was already minimal.\n" )
		elif verbose :
			print( f"Minimized from {n_states_before} to {len(new_states)}\n" )

		# Sanity checks: both properties must be the same as before the rebuild
		if not ( was_strongly_connected ==  self.is_strongly_connected() ) :
			raise RuntimeError(
				f"Minimization broke strongly connected"
			)

		if not ( was_row_stochastic ==  self.is_row_stochastic() ) :
			raise RuntimeError(
				f"Minimization broke row stochasticity"
			)

		self.is_minimal = True
1185
1186	#-------------------------------------------------------#
1187	#                      Modifiers                        #
1188	#-------------------------------------------------------#
1189
1190
	def collapse_to_largest_strongly_connected_subgraph( self, rename_states=True ) :

		"""
		Restrict the machine, in place, to its largest strongly connected component.

		If the graph from :meth:`as_digraph` is already strongly connected, nothing
		is changed (states are not renamed either). Otherwise only the states of
		the largest component and the transitions between them are kept. For each
		kept state whose remaining outgoing probabilities no longer sum to 1
		(beyond `self.EPS`), the difference is split equally and added to each of
		its remaining outgoing transitions.

		This method does not assign `self.start_state`.

		Args:
			rename_states (bool): If `True`, the kept states are replaced by fresh
				`CausalState` objects named `"0", "1", ...` (created with a name only).
		"""

		# get equivalent networkx graph
		G = self.as_digraph()

		# if already strongly connected, nothing to do
		if not nx.is_strongly_connected( G ) :

			# NOTE(review): `start` is never read in this method.
			start = time.perf_counter()
			subgraph_nodes = list( nx.strongly_connected_components( G ) )

			# decompose into strongly connected components and sort by length
			# subgraph_nodes = list(nx.strongly_connected_components( G ))
			subgraph_nodes.sort(key=len)
			component_state_set = subgraph_nodes[-1]

			# Take the largest strongly connected component, as a sorted list of its
			# graph nodes (used below as state indices)
			component_states = sorted( list( component_state_set ) )

			# make temporary copies of the old transitions and states
			old_transitions = [
				Transition(
					origin_state_idx=tr.origin_state_idx,
					target_state_idx=tr.target_state_idx,
					prob=tr.prob,
					symbol_idx=tr.symbol_idx
				)

				for tr in self.transitions
			]

			old_states = [
				CausalState(
					name=s.name,
					classes=s.classes,
					isomorphs=s.isomorphs
				)
				for s in self.states
			]

			# keep only the states of the component, in their original relative order
			self.set_states(
				states=[ 
					state
					for i, state in enumerate( old_states ) if i in component_state_set
				]
			)

			# we will build new transition list based on those belonging to the component
			self.set_transitions( transitions= [] )

			# for tracking which new transitions leave each state (keyed by OLD state index)
			transitions_from_state = { state : set() for state in component_states }
			new_transitions = []

			for tr in old_transitions :

				origin_state_name = old_states[ tr.origin_state_idx ].name
				target_state_name = old_states[ tr.target_state_idx ].name

				# skip transitions that connect separate strongly connected components
				if not ( tr.origin_state_idx in component_state_set and tr.target_state_idx in component_state_set ) :
					continue

				# track transitions (by index in new transitions list) that leave this state
				transitions_from_state[ tr.origin_state_idx ].add( len( new_transitions ) )

				# translate old indices to new ones by name
				# (presumably self.state_idx_map was refreshed by set_states above -- confirm)
				my_origin_state_idx = self.state_idx_map[ origin_state_name ]
				my_target_state_idx = self.state_idx_map[ target_state_name ]

				new_transitions.append( 
					Transition(
						origin_state_idx=my_origin_state_idx,
						target_state_idx=my_target_state_idx,
						prob=tr.prob,
						symbol_idx=tr.symbol_idx
					)  )

			self.set_transitions( transitions=new_transitions )

			# if we removed an outgoing transition from a state, we need to distribute its probability 
			# among the remaining outgoing transitions from the state
			for state in component_states :
				
				# get the set of transitions leaving this state
				state_trs = transitions_from_state[ state ]

				# sum the probabilities of the outgoing transitions from the state
				# (assumes self.transitions keeps the order of new_transitions -- TODO confirm)
				p_sum = np.sum( [ self.transitions[ i ].prob for i in state_trs ] )

				# how much probability is missing
				diff = 1.0 - p_sum

				# if significant difference
				if abs( diff ) > self.EPS : 

					# calculate how much of the difference each transition gets (equal shares)
					adjustment = diff / len( state_trs )
					
					# update the transitions
					for i in state_trs :

						# transition with original probability
						tr = self.transitions[ i ]

						# adjusted probability; the entry is replaced directly in self.transitions
						self.transitions[ i ] = Transition(
							origin_state_idx=tr.origin_state_idx, 
							target_state_idx=tr.target_state_idx, 
							prob=tr.prob + adjustment, 
							symbol_idx=tr.symbol_idx )

			# fresh states carry only a positional name
			if rename_states :
				self.set_states( [
					CausalState( name=f"{i}" )
					for i, s in enumerate( self.states )
				] )
1307
1308
1309	def to_q_weighted( self, denominator_limit=1000 ) :
1310
1311		"""
1312		Approximates the existing transition probabilities with exact fractions, stores 
1313		the fractional probabilities as Fraction in Transition.pq, and sets the floating
1314		point probabilty to `float(pq)`. If `denominator_limit` is too small for a sane 
1315		conversion, the function recurses with `denominator_limit=denominator_limit*10`.
1316
1317		Args:
1318			denominator_limit (int): The initial input to :meth:`Fraction.limit_denominator` in 
1319			the conversion.
1320		"""
1321
1322		if self.is_q_weighted :
1323			return
1324
1325		if not self.is_row_stochastic() :
1326			raise ValueError( "Cannot convert to q-weighted because not row stochastic" )
1327
1328		t_from = [[] for _ in range(len(self.states))]
1329
1330		for i, tr in enumerate( self.transitions ) :
1331			t_from[ tr.origin_state_idx ].append( i )
1332
1333		new_transitions = []
1334		for t_list in t_from :
1335			
1336			if not t_list : 
1337				continue
1338
1339			p_q_sum = Fraction(0,1)
1340			p_qs = []
1341
1342			for t_idx in t_list :
1343			
1344				p_q = Fraction( self.transitions[ t_idx ].prob ).limit_denominator( denominator_limit )
1345				p_q_sum += p_q
1346				p_qs.append( p_q )
1347
1348			if p_q_sum != Fraction(1,1) :
1349				
1350				max_pq_i = np.argmax( p_qs )
1351				max_oq = p_qs[ max_pq_i ]
1352
1353				diff = p_q_sum - Fraction(1,1)
1354
1355				# If recurse with higher resolution
1356				if diff > max_oq :
1357					return self.to_q_weighted( denominator_limit*10 )
1358				else :
1359					p_qs[ max_pq_i ] -= diff
1360
1361			for i, t_idx in enumerate( t_list ) :	
1362				new_transitions.append( 
1363					Transition( 
1364						origin_state_idx=self.transitions[  t_idx ].origin_state_idx,
1365						target_state_idx=self.transitions[  t_idx ].target_state_idx,
1366						prob=float(p_qs[ i ]),
1367						symbol_idx=self.transitions[  t_idx ].symbol_idx,
1368						pq=p_qs[ i ]
1369					)
1370				)
1371
1372		self.set_transitions( new_transitions )
1373		self.is_q_weighted = True
1374
1375	#-------------------------------------------------------#
1376	#                    Data Generation                    #
1377	#-------------------------------------------------------#
1378
1379	def isomorphic_shift(
1380		self,
1381		input_symbol_indices: np.ndarray,
1382		input_state_indices:  np.ndarray,
1383		shift : int = 1
1384	) -> dict[str, np.ndarray]:
1385
1386		"""
1387		Generates a new sequence of of symbols that are permuted with the symbols emitted by
1388		isomorphic states, if they exists.
1389
1390		:math:`\\sigma_o = \\mathcal{S}\\left[\\texttt{input\\_state\\_indices}[i]\\right]`<br>
1391		:math:`\\sigma_t = \\mathcal{S}\\left[\\texttt{input\\_state\\_indices}[i+1]\\right]`
1392 
1393		:math:`\\mathcal{I}(\\sigma_o) = \\\\{\\sigma^0_o,\\, \\sigma^1_o,\\, \\dots,\\, \\sigma^{n-1}_o \\\\}`<br>
1394		:math:`\\mathcal{I}(\\sigma_t) = \\\\{\\sigma^0_t,\\, \\sigma^1_t,\\, \\dots,\\, \\sigma^{n-1}_t \\\\}`
1395 
1396		:math:`k = \\bigl(\\mathcal{I}(\\sigma_o).\\texttt{index}(\\sigma_o) + \\texttt{shift}\\bigr) \\bmod n`
1397 
1398		:math:`\\texttt{output\\_symbol\\_indices}[i]   := T(\\sigma_o^k,\\, \\sigma_t^k).\\text{symbol\\_index}`<br>
1399		:math:`\\texttt{output\\_state\\_indices}[i]    := \\mathcal{S}.\\texttt{index}(\\sigma_o^k)`<br>
1400		:math:`\\texttt{output\\_state\\_indices}[i+1]  := \\mathcal{S}.\\texttt{index}(\\sigma_t^k)`
1401
1402		Where :math:`\\mathcal{I}(\\sigma)`` is the ordered set of states isomorphic to :math:`\\sigma` including :math:`\\sigma` itself.
1403
1404		Args:
1405			input_symbol_indices (np.ndarray): The sequence of generated symbols.
1406			input_state_indices (np.ndarray): The sequence of states that generated symbols with the final state at the end.
1407			shift : int: How much to shift the symbols across the isomorphic states.
1408		"""
1409
1410		if not any( state.isomorphs for state in self.states ):
1411			raise ValueError("HMM has no states with isomorphs")
1412
1413		inputs = np.asarray(input_symbol_indices)
1414		states = np.asarray(input_state_indices)
1415
1416		n_states = len(self.states)
1417
1418		tr_sym_table = np.full((n_states, n_states), -1, dtype=np.int32)
1419
1420		for tr in self.transitions:
1421			tr_sym_table[ tr.origin_state_idx, tr.target_state_idx ] = tr.symbol_idx
1422
1423		# Build isomorph remapping: identity by default, overridden where isomorphs exist
1424		iso_table = np.arange(n_states, dtype=np.int32)
1425
1426		for i, state in enumerate(self.states):
1427			if state.isomorphs:
1428				isomorph       = sorted(state.isomorphs)[0]
1429				iso_table[ i ] = self.state_idx_map[isomorph]
1430
1431		for i, state in enumerate(self.states):
1432			if state.isomorphs:
1433
1434				# extend the isomorph list to include the identity
1435				isormorphs_with_identity = sorted( [ i ] + [ self.state_idx_map[iso] for iso in state.isomorphs ] )
1436
1437				# find the identity index
1438				pos = isormorphs_with_identity.index( i )
1439
1440				# cyclical shift 
1441				iso_table[ i ] = isormorphs_with_identity[ ( pos + shift ) % len( isormorphs_with_identity ) ]
1442
1443		origins = states[:-1]
1444		targets = states[1:]
1445
1446		out_origins = iso_table[origins]
1447		out_targets = iso_table[targets]
1448
1449		inv_sym = tr_sym_table[ out_origins, out_targets ]
1450		inv_sts = np.empty(states.size, dtype=states.dtype)
1451
1452		inv_sts[:-1] = out_origins
1453		inv_sts[-1]  = out_targets[-1]
1454
1455		return {
1456			"symbol_index": inv_sym.astype(inputs.dtype),
1457			"state_index":  inv_sts,
1458		}
1459
1460	def generate_data(
1461		self,
1462		file_prefix: str,
1463		n_gen: int,
1464		include_states: bool,
1465		isomorphic_shifts : set[int]=None,
1466		random_seed : int=42 ) -> dict[any] : 
1467
1468		trs = [ [] for _ in range( len( self.states ) ) ]
1469		for tr in self.transitions :
1470			trs[ tr.origin_state_idx ].append( ( 
1471				tr.symbol_idx, 
1472				float( tr.prob ),
1473				tr.target_state_idx ) )
1474
1475		data = am_fast.generate_data(
1476			n_gen=n_gen,
1477			start_state=self.start_state,
1478			transitions=trs,
1479			alphabet=sorted(list(self.alphabet)),
1480			include_states=include_states,
1481			random_seed=random_seed
1482		)
1483
1484		if isomorphic_shifts is not None :
1485
1486			if not include_states :
1487				raise ValueError( "Isomorphic inversion requires include_states=True" )
1488
1489			data[ "isomorphic_shifts" ] = {}
1490
1491			for shift in isomorphic_shifts :
1492
1493				try : 
1494
1495					shifted = self.isomorphic_shift(
1496						input_symbol_indices=data[ "symbol_index" ], 
1497						input_state_indices=data[ "state_index" ], 
1498						shift=shift
1499					)
1500
1501					data[ "isomorphic_shifts" ][ shift ] = {
1502						"symbol_index" : shifted[ "symbol_index" ],
1503						"state_index"  : shifted[ "state_index" ]
1504					}
1505
1506				except Exception as e :
1507					print( f"Exception {e}" )
1508
1509		am_fast.save_data(
1510			data=data,
1511			file_prefix=file_prefix,
1512			alphabet=sorted(list(self.alphabet)),
1513			n_states=len( self.states ),
1514			start_state=self.start_state,
1515			random_seed=random_seed,
1516			machine_metadata=self.get_metadata() )
1517
1518		return data
1519
1520	#-------------------------------------------------------#
1521	#                 Basic Visualization                   #
1522	#-------------------------------------------------------#
1523
1524	def draw_graph(
1525		self,
1526		engine     : str  = 'dot',
1527		output_dir : Path = Path('.'),
1528		show       : bool = True
1529	) -> None :
1530
1531		"""
1532		Draws the machine using [pygraphiviz](https://pygraphviz.github.io/documentation/stable/) and saves it.
1533
1534		Returns:
1535		
1536			networkx.DiGraph : the resulting graph.
1537		"""
1538
1539		G = self.as_digraph()
1540
1541		subgraphs = None if nx.is_strongly_connected( G ) else list( nx.strongly_connected_components( G ) )
1542		
1543		am_vis.draw_graph( 
1544			self, 
1545			output_dir=output_dir, 
1546			title="am_graph", 
1547			view=show, 
1548			subgraphs=subgraphs, 
1549			engine=engine )
1550
1551	#-------------------------------------------------------#
1552	#                   Alternate Forms                     #
1553	#-------------------------------------------------------#
1554
1555
1556	def as_digraph( self ) -> nx.DiGraph :
1557
1558		"""
1559		Builds a [networkx.DiGraph](https://networkx.org/documentation/stable/reference/classes/digraph.html) constructed from the machine's transitions with no edge symbols or weights.
1560
1561		Returns:
1562		
1563			networkx.DiGraph : the resulting graph.
1564		"""
1565
1566		G = nx.DiGraph()
1567		G.add_nodes_from( [ i for i, s in enumerate( self.states ) ] )
1568
1569		for tr in self.transitions :
1570			G.add_edge( tr.origin_state_idx, tr.target_state_idx )
1571
1572		return G
1573
1574	def as_dfa( self, with_probs : bool ) :
1575
1576		"""
1577		Builds an [automata.fa.dfa.DFA](https://caleb531.github.io/automata/api/fa/class-dfa/) constructed from the machine's transitions.
1578
1579		Args:
1580			with_probs (bool): If true the DFA transitions are labeled based on
1581				the symbol of the machines transition concatenated with its
1582				probability, othwise, the only the symbols.
1583
1584		Returns:
1585		
1586			automata.fa.dfa.DFA : the resulting DFA.
1587		"""
1588
1589		precision=8
1590
1591		def edge_label( symb, prob ) :
1592			return f"({symb},{round(prob, precision)})"
1593
1594		# Build states, symbols, and transitions 
1595		dfa_states  = { i for i, _ in enumerate( self.states ) }
1596			
1597		if not with_probs :
1598			dfa_symbols = set( { t.symbol_idx for t in self.transitions } )
1599		else : 
1600			dfa_symbols = set( { edge_label( t.symbol_idx, t.prob ) for t in self.transitions } )
1601
1602		dfa_transitions = defaultdict(dict)
1603
1604		if not with_probs :
1605			for t in self.transitions :
1606				dfa_transitions[ t.origin_state_idx ][ t.symbol_idx ] = t.target_state_idx
1607		else :
1608			for t in self.transitions :
1609				dfa_transitions[ t.origin_state_idx ][ edge_label( t.symbol_idx, t.prob ) ] = t.target_state_idx
1610
1611		# Construct the DFA
1612		return DFA(
1613			states=dfa_states,
1614			input_symbols=dfa_symbols,
1615			transitions=dfa_transitions,
1616			initial_state=self.start_state,
1617			allow_partial=True,
1618			final_states={ 
1619				s for s in dfa_states
1620			}
1621		)

Hidden Markov model implementing epsilon machines, mixed state presentations, complexity measures, and data generation.

Arguments:
  • states (list[CausalState] | None): A list of causal states.
  • transitions (list[Transition] | None): A list of transitions between states.
  • start_state (int): Index of the start state.
  • alphabet (list[str]): List of symbols making up the alphabet.
  • name (str): Name of the model.
  • description (str): Description of the model.
HMM( states: list[amachine.am_causal_state.CausalState] | None = None, transitions: list[amachine.am_transition.Transition] | None = None, start_state: int = 0, alphabet: list[str] | None = None, name: str = '', description: str = '')
52	def __init__( 
53		self,
54		states : list[CausalState] | None = None,
55		transitions : list[Transition] | None = None,
56		start_state : int = 0,
57		alphabet : list[str] | None = None,
58		name : str = "",
59		description : str = "" ) : 
60
61		self.alphabet    : list[str] | None = alphabet or []
62		self.states      : list[CausalState] | None= states or []
63		self.transitions : list[Transition] | None= transitions or []
64
65		self.set_alphabet( self.alphabet )
66		self.set_states( self.states )
67		self.set_transitions( self.transitions )
68
69		self.start_state : int = start_state
70
71		self.name : str = name
72		self.description : str = description
73
74		# To be depreciated
75		self.isoclass : str = None
76
77		# --- derived --------
78
79		self.complexity : dict[str, any] = {}
80
81		self.symbol_idx_map : dict[str,int] = {}
82		self.state_idx_map  : dict[str,int] = {}
83
84		self.pi_fractional = None
85		self.pi : np.ndarray | None = None
86		self.T  : np.ndarray | None = None
87		self.msp : MSP | None = None
88		self.reverse_am : HMM | None = None
89		self.is_q_weighted : bool = False
90		self.is_minimal : bool = False
91
92		# --- const ----------
93
94		self.EPS : float = 1e-12
alphabet: list[str] | None
transitions: list[amachine.am_transition.Transition] | None
start_state: int
name: str
description: str
isoclass: str
complexity: dict[str, any]
symbol_idx_map: dict[str, int]
state_idx_map: dict[str, int]
pi_fractional
pi: numpy.ndarray | None
T: numpy.ndarray | None
msp: amachine.am_msp.MSP | None
reverse_am: HMM | None
is_q_weighted: bool
is_minimal: bool
EPS: float
@override
def get_states(self) -> list[amachine.am_causal_state.CausalState]:
100	@override
101	def get_states(self) -> list[CausalState] : 
102		return self.states
@override
def get_transitions(self) -> list[amachine.am_transition.Transition]:
104	@override
105	def get_transitions(self) -> list[Transition] : 
106		return self.transitions
@override
def get_alphabet(self) -> list[amachine.am_symbol.Symbol]:
108	@override
109	def get_alphabet(self) -> list[Symbol]:
110		return [ Symbol(a) for a in self.alphabet ]
def to_dict(self) -> dict[str, any]:
117	def to_dict(self) -> dict[str,any]:
118
119		"""Create a dict representing the HMM configuration.
120
121		Returns:
122
123			dict[str,any]: Dictionary containing, name, description, states, transitions, alphabet, and isoclass.
124		"""
125
126		return {
127			"name"            : self.name,
128			"description"     : self.description,
129			"states"          : [ asdict(state)      for state      in self.states      ],
130			"transitions"     : [ asdict(transition) for transition in self.transitions ],
131			"alphabet"        : self.alphabet,
132			"isoclass"        : self.isoclass
133		}

Create a dict representing the HMM configuration.

Returns:

dict[str,any]: Dictionary containing, name, description, states, transitions, alphabet, and isoclass.

def from_dict(self, config: dict[str, any]):
135	def from_dict( self, config : dict[str,any] ) :
136
137		"""Configure the HMM configuration from a dictionary.
138
139		Args:
140			config (dict[str,any]): The HMM configuration.
141		"""
142
143		self.name            = config.name 
144		self.description     = config.description 
145		
146		self.set_states( states=[ 
147			CausalState( 
148				name=state[ "name" ],
149				classes=set( state[ "classes" ] )
150			)
151			for state in config.states
152		] )
153
154		self.set_transitions( transitions=[ 
155			Transition( 
156				origin_state_idx=tr[ "origin_state_idx" ],
157				target_state_idx=tr[ "target_state_idx" ],
158				prob=tr[ "prob" ],
159				symbol_idx=tr[ "symbol_idx" ]
160			)
161			for tr in config.transitions
162		] )
163
164		self.set_alphabet( alphabet=config.alphabet ) 

Configure the HMM configuration from a dictionary.

Arguments:
  • config (dict[str,any]): The HMM configuration.
def save_config( self, path: pathlib.Path, with_complexity: bool = False, with_block_convergence: bool = False, with_structural_properties: bool = False, with_causal_properties: bool = False):
166	def save_config(
167		self, 
168		path : Path, 
169		with_complexity : bool = False, 
170		with_block_convergence :  bool = False,
171		with_structural_properties : bool = False,
172		with_causal_properties : bool = False ) :
173
174		config = self.to_dict()
175
176		if with_complexity :
177			
178			complexity = self.get_complexities( 
179				with_block_convergence=with_block_convergence
180			)
181
182			config[ "complexity" ] = complexity
183
184		config[ "structural_properties" ] = {
185			"unifilar"              : self.is_unifilar(),
186			"row_stochastic"        : self.is_row_stochastic(),
187			"strongly_connected"    : self.is_strongly_connected(),
188			"aperiodic"             : self.is_aperiodic(),
189			"minimal"               : self._is_minimal_as_dfa( topological_only=False ),
190			"topologically_minimal" : self._is_minimal_as_dfa( topological_only=True ),
191			"is_epsilon_machine"    : self.is_epsilon_machine()
192		}
193
194		with open( path / "am_config.json", "w", encoding="utf-8" ) as f :
195			json.dump( config, f, ensure_ascii=False, indent=2, default=list )
def from_file(self, path: pathlib.Path):
197	def from_file( self, path : Path ) :
198		with open( Path / "am_config.json", "r" ) as f:
199			config = json.load(f)
200		self.from_dict()
def set_states(self, states: list[amachine.am_causal_state.CausalState]):
222	def set_states( self, states : list[CausalState] ) :
223		self._invalidate()
224		self.states = states.copy()
225		self.state_idx_map = {}
226		for idx, state in enumerate( self.states ) :
227			self.state_idx_map[ state.name ] = idx
def set_alphabet(self, alphabet: list[str]):
229	def set_alphabet( self, alphabet : list[str] ) :
230
231		self._invalidate()
232
233		old_alphabet = self.alphabet.copy()
234
235		aSet = set()
236		aSet.update( alphabet )
237
238		self.alphabet = sorted(list(aSet))
239
240		self.symbol_idx_map = {}
241		for idx, symbol in enumerate( self.alphabet ) :
242			self.symbol_idx_map[ symbol ] = idx
243
244		for i, tr in enumerate( self.transitions ) :
245			symbol = old_alphabet[ tr.symbol_idx ]
246			self.transitions[ i ] = Transition(
247				origin_state_idx=tr.origin_state_idx,
248				target_state_idx=tr.target_state_idx,
249				prob=tr.prob,
250				symbol_idx=self.symbol_idx_map[ symbol ]
251			)
def set_transitions(self, transitions: list[amachine.am_transition.Transition]):
253	def set_transitions( self, transitions : list[Transition] ) :
254		self._invalidate()
255		self.transitions = transitions.copy()
def extend_states(self, states: list[amachine.am_causal_state.CausalState]):
257	def extend_states( self, states : list[CausalState] ) :
258		self.set_states( self.states + states )
def extend_alphabet(self, alphabet: list[str]):
260	def extend_alphabet( self, alphabet : list[str] ) :
261		self.set_alphabet( self.alphabet + alphabet  )
def extend_transitions(self, transitions: list[amachine.am_transition.Transition]):
263	def extend_transitions( self, transitions : list[Transition] ) :
264		self.set_transitions( self.transitions + transitions  )
def get_complexity_measure_if_exists(self, measure):
266	def get_complexity_measure_if_exists(self, measure ) :
267		m = self.complexity.get( measure, None )
268		return m
def set_complexity_measure(self, measure, value):
270	def set_complexity_measure(self, measure, value ) :
271		self.complexity[ measure ] = value
def get_complexities(self, with_block_convergence=False):
277	def get_complexities( 
278		self, 
279		with_block_convergence=False ) :
280
281		directly_calculable = [
282			self.C_mu,
283			self.h_mu,
284			self.H_1,
285			self.rho_mu
286		]
287
288		requires_block_convergence = [
289			self.E, 
290			self.T_inf,
291			self.S,
292			self.chi
293		]
294			
295		complexities = { m.__name__ : m() for m in directly_calculable }
296
297		if with_block_convergence :
298
299			complexities |= { m.__name__ : m() for m in requires_block_convergence }
300
301			for key in [ 'H_L', 'T_L', 'h_mu_L', 'H_sync' ] :
302				if key in self.complexity :
303					complexities[ key ] = self.complexity[ key ]
304
305		return complexities
def get_metadata(self):
309	def get_metadata(self) :
310		return {
311			"name" : self.name,
312			'complexity' : self.complexity,
313			"description" : self.description
314		}
def get_transition_matrix(self):
316	def get_transition_matrix(self) :
317
318		if self.T  is not None :
319			return self.T
320
321		n_states = len( self.states )
322		T = np.zeros((n_states, n_states))
323
324		for tr in self.transitions :    
325			T[ tr.origin_state_idx, tr.target_state_idx  ] = tr.prob
326
327		self.T = T
328
329		return self.T
def get_T_X(self):
333	def get_T_X(self) :
334
335		if self.T_x  is not None :
336			return self.T_x
337
338		n_states  = len( self.states )
339		n_symbols = len( self.alphabet )
340
341		T_x = [ np.zeros((n_states, n_states)) for _ in range( n_symbols ) ]
342
343		for tr in self.transitions :
344			T_x[ tr.symbol_idx ][tr.origin_state_idx, tr.target_state_idx] = tr.prob
345
346		self.T_x = T_x
347		return self.T_x
def get_msp_qw(self, exact_state_cap: int = 1000, verbose: bool = True):
351	def get_msp_qw(
352		self,
353		exact_state_cap: int = 1000,
354		verbose: bool = True,
355	):
356		if self.msp is not None:
357			return self.msp
358
359		try : 
360
361			print( "\nTrying to Compute Mixed State Presentation using Exact Fractions\n" )
362
363			self.msp = compute_msp_exact(
364				T_x=self.get_Tx_fractional(),
365				pi=self.get_fractional_stationary_distribution(),
366				n_states=len(self.states),
367				alphabet=self.alphabet,
368				exact_state_cap=1000,
369				bool = True
370			)
371
372			return self.msp 
373
374		except RuntimeError as e :
375			warnings.warn( f"Exact msp failed: {e} Falling back to msp approximation." )
376
377		return self.get_msp()
def get_msp( self, exact_state_cap: int = 175000, jsd_eps: float = 1e-07, k_ann: int = 50, verbose=True):
379	def get_msp(
380		self,
381		exact_state_cap: int = 175_000,
382		jsd_eps:         float = 1e-7,
383		k_ann:           int   = 50,
384		verbose                = True,
385	) :
386
387		if self.msp is not None:
388			return self.msp
389	 
390		T_x = self.get_T_X()
391		pi  = self.get_stationary_distribution()
392
393		T_stacked      = np.stack(T_x)
394		n_symbols      = len(self.alphabet)
395		n_input_states = T_stacked.shape[1]
396	
397		print( "\nComputing Mixed State Presentation..." )
398
399		self.msp = compute_msp( 
400			T_x=T_x,
401			pi=pi,
402			n_states=len(self.states),
403			alphabet=self.alphabet,
404			exact_state_cap=exact_state_cap,
405			verbose=verbose
406		)
407
408		return self.msp
def get_reverse_am(self):
410	def get_reverse_am(self) :
411
412		if self.reverse_am is not None:
413			return self.reverse_am
414
415		pi = self.get_stationary_distribution()
416		self.reverse_am = copy.deepcopy(self)
417		
418		new_transitions = []
419		for tr in self.transitions:
420			i = tr.target_state_idx
421			j = tr.origin_state_idx
422			
423			p_reversed = (pi[j] * tr.prob) / pi[i]
424			
425			new_transitions.append(
426				Transition(
427					origin_state_idx=i,
428					target_state_idx=j,
429					prob=p_reversed,
430					symbol_idx=tr.symbol_idx
431				)
432			)
433
434		self.reverse_am.set_transitions(new_transitions)
435
436		if self.reverse_am.is_epsilon_machine():
437			return self.reverse_am
438
439		rmsp = self.reverse_am.get_msp_qw( exact_state_cap=len(self.states)*4 )
440
441		self.reverse_am.set_states( rmsp.states )
442		self.reverse_am.set_transitions( rmsp.transitions )
443		self.reverse_am.msp = rmsp
444		self.reverse_am.start_state = 0
445
446		self.reverse_am.collapse_to_largest_strongly_connected_subgraph()
447		self.reverse_am.minimize()
448
449		return self.reverse_am
def get_Tx_fractional(self) -> list[list[list[fractions.Fraction]]]:
453	def get_Tx_fractional(self) -> list[ list[ list[ Fraction ] ] ] :
454
455		self.to_q_weighted()
456
457		n_states  = len( self.states )
458		n_symbols = len( self.alphabet )
459
460		T_x = []
461
462		for x in range( n_symbols ) :
463			T_x.append( [] )
464			for i in range( n_states ) :
465				T_x[ x ].append( [ 0 for _ in range( n_states ) ] )
466
467		for tr in self.transitions :
468			T_x[ tr.symbol_idx ][ tr.origin_state_idx ][ tr.target_state_idx ] = tr.pq
469
470		return T_x
def get_T_sympy(self):
472	def get_T_sympy( self ) :
473
474		self.to_q_weighted()
475
476		n = len( self.states )
477		T = sympy.zeros( n, n )
478
479		for tr in self.transitions :
480			T[ tr.origin_state_idx, tr.target_state_idx ] = tr.pq
481
482		return T
def get_fractional_stationary_distribution(self):
484	def get_fractional_stationary_distribution(self) :
485
486		T = self.get_T_sympy()
487
488		if self.pi_fractional is not None :
489			return self.pi_fractional
490
491		G = self.as_digraph()
492
493		if not nx.is_strongly_connected(G):
494			raise ValueError( "Single stationary distribution requires strongly connected HMM." )
495
496		self.pi_fractional = solve_for_pi_fractional( T )
497
498		return self.pi_fractional
def get_stationary_distribution(self):
500	def get_stationary_distribution(self):
501
502		if self.pi is not None :
503			return self.pi
504
505		G = self.as_digraph()
506		
507		if not nx.is_strongly_connected(G):
508			raise ValueError( "Single stationary distribution requires strongly connected HMM." )
509
510		T = self.get_transition_matrix()
511		return solve_for_pi( T )		
def C_mu(self):
517	def C_mu( self ) :
518
519		"""The *statistical complexity* (aka *forecasting complexity*) :
520
521		.. math::
522
523			C_{\\mu} = - \\sum_{\\sigma \\in \\mathcal{S}} \\Pr(\\sigma) \\log_2 \\Pr(\\sigma),
524
525		where :math:`\\mathcal{S}` are the machine's states [^crutchfield_exact_2016], p.2.
526
527		.. note::
528
529			**Interpretations**
530
531			* The amount of historical information a process stores.
532			* The amount of structure in a process.
533
534		Returns:
535
536			float: :math:`C_{\\mu}`.
537
538		[^crutchfield_exact_2016]: Crutchfield et al., *Exact Complexity: The Spectral
539			Decomposition of Intrinsic Computation*, 2016.
540			<https://arxiv.org/abs/1309.3792>
541		"""
542
543		m = self.get_complexity_measure_if_exists( "C_mu" )
544
545		if m is not None :
546			return m
547
548		pi = self.get_stationary_distribution()
549
550		h = 0
551		for i, pr in enumerate( pi ) :
552			
553			if pr < self.EPS :
554				continue
555
556			h += -pr * np.log2( pr )
557
558		self.set_complexity_measure( "C_mu", h )
559
560		return h

The statistical complexity (aka forecasting complexity) :

$$C_{\mu} = - \sum_{\sigma \in \mathcal{S}} \Pr(\sigma) \log_2 \Pr(\sigma),$$

where \( \mathcal{S} \) are the machine's states 1, p.2.

Interpretations

  • The amount of historical information a process stores.
  • The amount of structure in a process.
Returns:

float: \( C_{\mu} \).


  1. Crutchfield et al., Exact Complexity: The Spectral Decomposition of Intrinsic Computation, 2016. https://arxiv.org/abs/1309.3792 

def h_mu(self):
564	def h_mu( self ) :
565
566		"""The *entropy rate* :
567
568		.. math::
569
570			h_{\\mu}(\\boldsymbol{\\mathcal{S}}) = - \\sum_{\\sigma \\in \\mathcal{S}} \\Pr(\\sigma) \\sum_{x \\in \\mathcal{A}} \\Pr(x|\\sigma) \\log_2 \\Pr(x|\\sigma),
571
572		where :math:`\\mathcal{A}` is the alphabet and :math:`\\mathcal{S}` are the machine's states [^crutchfield_exact_2016], p.2.
573
574		.. note::
575
576			**Interpretations**
577
578			* The lower bound on achievable loss in bits. 
579			* The irreducable randomness in the process.
580			* The intrinsic Randomness in the process.
581
582		Returns:
583			
584			float: :math:`h_{\\mu}`.
585
586		[^crutchfield_exact_2016]: Crutchfield et al., *Exact Complexity: The Spectral
587			Decomposition of Intrinsic Computation*, 2016.
588			<https://arxiv.org/abs/1309.3792>
589		"""
590
591		m = self.get_complexity_measure_if_exists( "h_mu" )
592
593		if m is not None :
594			return m
595
596		T  = self.get_transition_matrix()
597		pi = self.get_stationary_distribution()
598
599		n_states = pi.size
600
601		h = 0
602		for i, pr in enumerate( pi ) :
603
604			if pr < self.EPS :
605				continue
606
607			row_entropy = 0
608			for j in range( len( pi ) ) :
609
610				if T[ i, j ]  < self.EPS :
611					continue
612
613				row_entropy -= T[ i, j ] * np.log2( T[ i, j ] )
614
615			h += pr * row_entropy
616
617		self.set_complexity_measure( "h_mu", h )
618
619		return h

The entropy rate :

$$h_{\mu}(\boldsymbol{\mathcal{S}}) = - \sum_{\sigma \in \mathcal{S}} \Pr(\sigma) \sum_{x \in \mathcal{A}} \Pr(x|\sigma) \log_2 \Pr(x|\sigma),$$

where \( \mathcal{A} \) is the alphabet and \( \mathcal{S} \) are the machine's states 1, p.2.

Interpretations

  • The lower bound on achievable loss in bits.
  • The irreducible randomness in the process.
  • The intrinsic Randomness in the process.
Returns:

float: \( h_{\mu} \).


  1. Crutchfield et al., Exact Complexity: The Spectral Decomposition of Intrinsic Computation, 2016. https://arxiv.org/abs/1309.3792 

def H_1(self) -> float:
623	def H_1(self) -> float :
624
625		"""The *single symbol uncertainty*:
626
627		.. math::
628
629			H(1)=-\\sum_{x\\in\\mathcal{A}} \\Pr(x) \\log_2{\\Pr(x)},
630
631		where :math:`\\mathcal{A}` is the alphabet [^James_2018], p.2.
632
633		.. note::
634
635			**Interpretations**
636
637			* How uncertain you are on average about a single measurement with no context.
638
639		Returns:
640
641			float: :math:`H(1)`.
642
643		[^James_2018]: James et al., Anatomy of a Bit: Information in a Time Series Observation, 2018.
644			<https://arxiv.org/abs/1105.2988>
645		"""
646
647		m = self.get_complexity_measure_if_exists("H_1")
648		if m is not None:
649			return m
650
651		pi  = self.get_stationary_distribution()
652		T_X = self.get_T_X()  # dict: symbol -> matrix
653
654		h = 0.0
655		for T_x in T_X:
656			# Pr(x) = sum_i pi[i] * sum_j T^(x)[i,j]
657			p_sym = 0.0
658			for i, pr in enumerate(pi):
659				if pr < self.EPS:
660					continue
661				p_sym += pr * T_x[i, :].sum()
662
663			if p_sym < self.EPS:
664				continue
665			h -= p_sym * np.log2(p_sym)
666
667		self.set_complexity_measure("H_1", h)
668		return h

The single symbol uncertainty:

$$H(1)=-\sum_{x\in\mathcal{A}} \Pr(x) \log_2{\Pr(x)},$$

where \( \mathcal{A} \) is the alphabet 1, p.2.

Interpretations

  • How uncertain you are on average about a single measurement with no context.
Returns:

float: \( H(1) \).


  1. James et al., Anatomy of a Bit: Information in a Time Series Observation, 2018. https://arxiv.org/abs/1105.2988 

def rho_mu(self) -> float:
672	def rho_mu(self) -> float :
673		
674		"""The *anticipated information* [^James_2018], p.3.:
675
676		.. math::
677
678			\\rho_{\\mu}= H(1) - h_{\\mu}
679
680		Returns:
681			
682			float: :math:`\\rho_{\\mu}`
683
684		[^James_2018]: James et al., Anatomy of a Bit: Information in a Time Series Observation, 2018.
685			<https://arxiv.org/abs/1105.2988>
686		"""
687
688		m = self.get_complexity_measure_if_exists("rho_mu")
689		
690		if m is not None:
691			return m
692
693		rho = self.H_1() - self.h_mu()
694		
695		self.set_complexity_measure("rho_mu", rho)
696		
697		return rho

The anticipated information 1, p.3.:

$$\rho_{\mu}= H(1) - h_{\mu}$$

Returns:

float: \( \rho_{\mu} \)


  1. James et al., Anatomy of a Bit: Information in a Time Series Observation, 2018. https://arxiv.org/abs/1105.2988 

def block_convergence(self):
701	def block_convergence( self )  :
702
703		"""
704		"Run block entropy convergence and return a dict with $\\mathbf{E}, \\mathbf{S}, $\\mathbf{T},\\mathbf{T}(L), \\mathcal{H}(L),$ and $h_{\\mu}(L)$.
705		"""
706
707		trs = [ [] for _ in range( len( self.states ) ) ]
708		for tr in self.transitions :
709			trs[ tr.origin_state_idx ].append( ( 
710				tr.symbol_idx, 
711				float( tr.prob ),
712				tr.target_state_idx ) )
713
714		pi = self.get_stationary_distribution()
715
716		state_dist = [ float( pi[ i ] ) for i in range( len( self.states ) ) ]
717		branches = [(1.0, list(state_dist))]
718
719		print( "\nComputing Block Entropy\n" )
720
721		C = am_fast.block_entropy_convergence(
722			h_mu            = self.h_mu(),
723			n_states        = len( self.states ),
724			n_symbols       = len( self.alphabet ),
725			convergence_tol = 1e-6,
726			precision       = 10,
727			eps             = 1e-25,
728			branches        = branches,
729			trans           = trs,
730			max_branches    = 30_000_000
731		)
732
733		print( "Done\n" )
734
735		self.set_complexity_measure( f"E",          C.E )
736		self.set_complexity_measure( f"S",          C.S )
737		self.set_complexity_measure( f"T_inf",      C.T )
738		self.set_complexity_measure( f"T_L",        C.T_L.tolist() )
739		self.set_complexity_measure( f"H_L",        C.H_L.tolist() )
740		self.set_complexity_measure( f"h_mu_L",  C.h_mu_L.tolist() )
741		self.set_complexity_measure( f"H_sync",  C.H_sync.tolist() )
742
743		return C

Run block entropy convergence and return a result object holding $\mathbf{E}$, $\mathbf{S}$, $\mathbf{T}$, $\mathbf{T}(L)$, $\mathcal{H}(L)$, and $h_{\mu}(L)$.

def E(self) -> float:
747	def E( self ) -> float :
748
749		"""The *excess entropy* [^crutchfield_exact_2016], p.4:
750
751		.. math::
752
753			\\mathbf{E} \\equiv \\sum_{L=1}^{\\infty} I[X_{-\\infty:0}; X_{0:\\infty}]
754		
755		Computed via :meth:`get_msp` and :meth:`amachine.am_msp.MSP.get_E_S_T`, or :meth:`amachine.am_fast.block_entropy_convergence`
756
757		.. note::
758
759			**Interpretations**
760
761			* The information from the past that reduces uncertainty in the future [^crutchfield_exact_2016].
762			* How much information an observer must extract to synchronize to the process.
763			* Measures how long the process appears more complex than it asymptotically is.
764			* Vanishes for immediately synchronizable processes.
765
766		Returns:
767		
768			float: :math:`\\mathbf{E}`
769
770		[^crutchfield_exact_2016]: Crutchfield et al., *Exact Complexity: The Spectral
771			Decomposition of Intrinsic Computation*, 2016.
772			<https://arxiv.org/abs/1309.3792>
773		"""
774
775		m = self.get_complexity_measure_if_exists( "E" )
776
777		if m is not None :
778			return m
779
780		try : 
781			msp = self.get_msp()
782			E, S, T = msp.get_E_S_T()
783			self.set_complexity_measure( "E", E )
784			self.set_complexity_measure( "S", S )
785			self.set_complexity_measure( "T_inf", T )
786			
787		except Exception as e :
788
789			print( f"MSP failed {e}" )
790
791			C = self.block_convergence()	
792			E = C.E
793			self.set_complexity_measure( "E", E )
794
795		return E

The excess entropy 1, p.4:

$$\mathbf{E} \equiv I[X_{-\infty:0}; X_{0:\infty}] = \sum_{L=1}^{\infty} \left[ h_{\mu}(L) - h_{\mu} \right]$$

Computed via get_msp() and amachine.am_msp.MSP.get_E_S_T(), or amachine.am_fast.block_entropy_convergence()

Interpretations

  • The information from the past that reduces uncertainty in the future 1.
  • How much information an observer must extract to synchronize to the process.
  • Measures how long the process appears more complex than it asymptotically is.
  • Vanishes for immediately synchronizable processes.
Returns:

float: \( \mathbf{E} \)


  1. Crutchfield et al., Exact Complexity: The Spectral Decomposition of Intrinsic Computation, 2016. https://arxiv.org/abs/1309.3792 

def S(self) -> float:
799	def S( self ) -> float :
800
801		"""The *synchronization* information:
802
803		.. math::
804
805			\\mathbf{S} \\equiv \\sum_{L=1}^{\\infty} \\mathcal{H}(L),
806
807		where :math:`\\mathcal{H}(L)` is the average state uncertainty having seen all length-L words [^crutchfield_exact_2016], p.4.
808
809		.. note::
810
811			**Interpretations**
812
813			* The total amount of state information that an observer must extract to become synchronized [^crutchfield_exact_2016].
814
815		Computed via :meth:`get_msp` and :meth:`amachine.am_msp.MSP.get_E_S_T`, or :meth:`amachine.am_fast.block_entropy_convergence`
816
817		Returns:
818		
819			float: :math:`\\mathbf{S}`
820
821		[^crutchfield_exact_2016]: Crutchfield et al., *Exact Complexity: The Spectral
822			Decomposition of Intrinsic Computation*, 2016.
823			<https://arxiv.org/abs/1309.3792>
824		"""
825
826		m = self.get_complexity_measure_if_exists( "S" )
827
828		if m is not None :
829			return m
830
831		try : 
832			msp = self.get_msp()
833			E, S, T = msp.get_E_S_T()
834			self.set_complexity_measure( "E", E )
835			self.set_complexity_measure( "S", S )
836			self.set_complexity_measure( "T_inf", T )
837
838		except Exception as e :
839			print( f"{e} \nFalling back to iterative estimation.")
840			exit()
841
842			C = self.block_convergence()	
843			S = C.S
844			self.set_complexity_measure( "S", S )
845
846		return S

The synchronization information:

$$\mathbf{S} \equiv \sum_{L=1}^{\infty} \mathcal{H}(L),$$

where \( \mathcal{H}(L) \) is the average state uncertainty having seen all length-L words 1, p.4.

Interpretations

  • The total amount of state information that an observer must extract to become synchronized 1.

Computed via get_msp() and amachine.am_msp.MSP.get_E_S_T(), or amachine.am_fast.block_entropy_convergence()

Returns:

float: \( \mathbf{S} \)


  1. Crutchfield et al., Exact Complexity: The Spectral Decomposition of Intrinsic Computation, 2016. https://arxiv.org/abs/1309.3792 

def T_inf(self) -> float:
850	def T_inf( self ) -> float :
851
852		"""The *transient information*[^crutchfield_exact_2016], p.4:
853
854		.. math::
855
856			\\mathbf{T} \\equiv \\sum_{L=1}^{\\infty} L \\left[ h_{\\mu}(L) - h_{\\mu} \\right]
857
858		Computed via :meth:`get_msp` and :meth:`amachine.am_msp.MSP.get_E_S_T`, or :meth:`amachine.am_fast.block_entropy_convergence`
859
860		.. note::
861
862			**Interpretations**
863
864			* The amount of information one must extract from observations so that the block entropy converges to its linear asymptote[^crutchfield_exact_2016].
865
866		Returns:
867		
868			float: :math:`\\mathbf{T}`
869
870		[^crutchfield_exact_2016]: Crutchfield et al., *Exact Complexity: The Spectral
871			Decomposition of Intrinsic Computation*, 2016.
872			<https://arxiv.org/abs/1309.3792>
873		"""
874
875		m = self.get_complexity_measure_if_exists( "T_inf" )
876
877		if m is not None :
878			return m
879
880		try : 
881			msp = self.get_msp()
882			E, S, T = msp.get_E_S_T()
883			self.set_complexity_measure( "E", E )
884			self.set_complexity_measure( "S", S )
885			self.set_complexity_measure( "T_inf", T )
886
887		except Exception as e :
888			print( f"{e} \nFalling back to iterative estimation.")
889			C = self.block_convergence()	
890			T_inf = C.T
891			self.set_complexity_measure( "T_inf", T_inf )
892
893		return T_inf

The transient information1, p.4:

$$\mathbf{T} \equiv \sum_{L=1}^{\infty} L \left[ h_{\mu}(L) - h_{\mu} \right]$$

Computed via get_msp() and amachine.am_msp.MSP.get_E_S_T(), or amachine.am_fast.block_entropy_convergence()

Interpretations

  • The amount of information one must extract from observations so that the block entropy converges to its linear asymptote1.
Returns:

float: \( \mathbf{T} \)


  1. Crutchfield et al., Exact Complexity: The Spectral Decomposition of Intrinsic Computation, 2016. https://arxiv.org/abs/1309.3792 

def chi(self) -> float:
897	def chi( self ) -> float :
898
899		"""Computes the foward crypticity[^crutchfield_crypticity_2009][^Mahoney_crypticity_2021], p.2:
900
901		.. math::
902
903			\\chi = C_{\\mu} - \\mathbf{E}
904
905		:math:`C_{\\mu}` is trivially computed from the stationary distribution in :meth:`C_mu` and :math:`\\mathbf{E}` in :meth:`E`.
906
907		.. note::
908
909			**Interpretations**
910
911			* Difference between internal stored information and apparent information to an observer.
912			* How muching information is hiding in the system.
913
914		Returns:
915		
916			float: :math:`\\chi`
917
918		[^crutchfield_crypticity_2009]: Crutchfield et al., Time’s barbed arrow: Irreversibility, crypticity, and stored information, 2009.
919			<https://arxiv.org/abs/0902.1209>
920
921		[^Mahoney_crypticity_2021]: Mahoney et al., Information Accessibility and Cryptic Processes, 2021.
922			<https://arxiv.org/abs/0905.4787>
923		"""
924
925		m = self.get_complexity_measure_if_exists( "chi" )
926
927		if m is not None :
928			return m
929
930		chi = self.C_mu() - self.E()
931
932		if chi < 0 :
933			
934			# if chi is 0, accumulated floating point error can result in small negative values
935			if chi < -1e-5:
936				warnings.warn(f"Crypticity is negative ({chi:.6e}).")
937			
938			chi = np.clamp( chi, 0 )
939
940		self.set_complexity_measure( "chi", chi )
941
942		return chi

Computes the forward crypticity[1][2], p.2:

$$\chi = C_{\mu} - \mathbf{E}$$

\( C_{\mu} \) is trivially computed from the stationary distribution in C_mu() and \( \mathbf{E} \) in E().

Interpretations

  • Difference between internal stored information and apparent information to an observer.
  • How much information is hidden in the system.
Returns:

float: \( \chi \)


  1. Crutchfield et al., Time’s barbed arrow: Irreversibility, crypticity, and stored information, 2009. https://arxiv.org/abs/0902.1209 

  2. Mahoney et al., Information Accessibility and Cryptic Processes, 2021. https://arxiv.org/abs/0905.4787 

def is_row_stochastic(self):
948	def is_row_stochastic(self) :
949
950		"""
951		Check that all states have outgoing transition probabilities that sum to 1.
952		"""
953
954		sums = np.zeros( len( self.states ) )
955		for tr in self.transitions :
956			sums[ tr.origin_state_idx ] += tr.prob
957		return np.allclose( sums, 1.0 )

Check that all states have outgoing transition probabilities that sum to 1.

def is_unifilar(self):
961	def is_unifilar(self) :
962
963		"""
964		Check that no state emits the same symbol on transitions to different states. 
965		"""
966
967		symbol_trs = np.full( ( len( self.states ), len( self.alphabet) ), -1 )
968
969		for tr in self.transitions : 
970		
971			if symbol_trs[ tr.origin_state_idx, tr.symbol_idx ] == -1 :
972				symbol_trs[ tr.origin_state_idx, tr.symbol_idx ] = tr.target_state_idx
973		
974			elif symbol_trs[ tr.origin_state_idx, tr.symbol_idx ] != tr.target_state_idx :
975				return False
976		
977		return True

Check that no state emits the same symbol on transitions to different states.

def is_strongly_connected(self):
981	def is_strongly_connected(self) :
982
983		"""
984		Check if every state is reachable from every other state. Relies on [nx.is_strongly_connected](https://networkx.org/documentation/latest/reference/algorithms/generated/networkx.algorithms.components.is_strongly_connected.html).
985		"""
986
987		return nx.is_strongly_connected( self.as_digraph() )

Check if every state is reachable from every other state. Relies on nx.is_strongly_connected.

def is_aperiodic(self):
991	def is_aperiodic(self) :
992
993		"""
994		Checks if machine is periodic. Relies on [nx.is_aperiodic](https://networkx.org/documentation/latest/reference/algorithms/generated/networkx.algorithms.dag.is_aperiodic.html), "A strongly connected directed graph is aperiodic if there is no integer k > 1 that divides the length of every cycle in the graph."
995		"""
996
997		return nx.is_aperiodic( self.as_digraph() )

Checks if the machine is aperiodic. Relies on nx.is_aperiodic, "A strongly connected directed graph is aperiodic if there is no integer k > 1 that divides the length of every cycle in the graph."

def is_topological_epsilon_machine(self, verbose=True):
1020	def is_topological_epsilon_machine( self, verbose=True ) :
1021
1022		"""
1023		Checks if the HMM is a topological $\\epsilon$-machine [^1].
1024
1025		[^1]: Johnson et al, Enumerating Finitary Processes, 2024.
1026			<https://arxiv.org/abs/1011.0036>
1027		"""
1028
1029		if not ( self.is_unifilar() and self.is_strongly_connected() ) :
1030			if verbose : 
1031				print( f"Either non unifilar or not strongly connected" )
1032			return False
1033		else :
1034			return self._is_minimal_as_dfa( topological_only=True, verbose=verbose )

Checks if the HMM is a topological $\epsilon$-machine 1.


  1. Johnson et al, Enumerating Finitary Processes, 2024. https://arxiv.org/abs/1011.0036 

def is_epsilon_machine(self, verbose=True):
1036	def is_epsilon_machine( self, verbose=True ) :
1037
1038		if not ( self.is_unifilar() and self.is_strongly_connected() ) :
1039			if verbose : 
1040				print( f"Either non unifilar or not strongly connected" )
1041			return False
1042		else :
1043			return self._is_minimal_as_dfa( topological_only=False, verbose=verbose )
def minimize(self, retain_names: bool = True, verbose=False):
	def minimize(self, retain_names: bool = True, verbose=False):

		"""
		Minimizes the HMM in place by merging equivalent states.

		The machine is converted to a DFA whose edge labels combine symbol and
		probability (:meth:`as_dfa` with ``with_probs=True``), the DFA is
		minimized with ``am_fast.minify_cpp( dfa, retain_names=True )``, and the
		states and transitions of this machine are rebuilt from the resulting
		equivalence classes. The class containing the DFA's initial state becomes
		state 0; the remaining classes are ordered by their smallest original
		state index. Does nothing if ``self.is_minimal`` is already truthy.

		Args:
			retain_names (bool): If `True`, merged states are named by joining their
				members' names, e.g. `{s_0,s_1}`, and unmerged states keep their
				original names. Otherwise states are relabeled `'0', '1', ..., 'n-1'`.
			verbose (bool): If true, print whether the state count changed.

		Returns:
		
			None : the machine is modified in place.

		Raises:
			ValueError: If the HMM is not unifilar.
			RuntimeError: If strong connectivity or row stochasticity differs
				before and after minimization.
		"""

		if self.is_minimal :
			return

		# NOTE(review): `start` is never read in this method.
		start = time.perf_counter()

		if not self.is_unifilar():
			raise ValueError(
				"DFA minimization is not valid for non-unifilar HMMs"
			)

		# Recorded so the same properties can be re-checked after rebuilding.
		was_strongly_connected = self.is_strongly_connected()

		was_row_stochastic = self.is_row_stochastic()
		n_states_before = len(self.states)

		dfa = self.as_dfa(with_probs=True)

		# Previous implementation, kept for reference:
		#min_dfa = self.as_dfa(with_probs=True).minify(retain_names=True)
		min_dfa = am_fast.minify_cpp( dfa, retain_names=True )

		# Build lookup from original state index -> CausalState object
		orig_state   = {i: s for i, s in enumerate(self.states)}

		# Each element of min_dfa.states is treated below as a collection of
		# original state indices (one equivalence class).
		eq_list      = list(min_dfa.states)

		start_eq = min_dfa.initial_state
		
		# Separate the start state, then sort the rest by the 
		# smallest original state index inside each equivalence class.
		other_eqs = [eq for eq in eq_list if eq != start_eq]
		other_eqs.sort(key=lambda eq: min(eq))

		# Recombine so start eq comes first, followed by the sorted remaining classes
		eq_list = [start_eq] + other_eqs
		# ----------------------------------------------------------

		# Recompute eq_to_idx with the new ordering
		eq_to_idx = {eq: i for i, eq in enumerate(eq_list)}

		# new_start is now guaranteed to be 0
		new_start = 0

		# Map each original state index -> its equivalence class
		# NOTE(review): the original comment here says minify() silently drops
		# unreachable states; such a state would have no entry in this map — confirm
		# that no transition looked up below can target one.
		orig_to_eq = {s: eq for eq in min_dfa.states for s in eq}

		# Build lookup from original state index -> its transitions
		orig_trs = defaultdict(list)
		for t in self.transitions:
			orig_trs[t.origin_state_idx].append(t)

		# One representative member per class supplies the class's outgoing
		# transitions, re-targeted to the new class indices.
		new_trs = []
		for eq in min_dfa.states:
			rep        = next(iter(eq))
			origin_idx = eq_to_idx[eq]
			for t in orig_trs[rep]:
				target_eq  = orig_to_eq[t.target_state_idx]
				target_idx = eq_to_idx[target_eq]
				new_trs.append(Transition(
					origin_state_idx = origin_idx,
					target_state_idx = target_idx,
					prob             = t.prob,
					symbol_idx       = t.symbol_idx,
				))

		members_list = [[orig_state[i] for i in sorted(eq)] for eq in eq_list]  # sorted for determinism

		# Compute new names
		if retain_names:
			new_names = [
				"{" + ",".join(str(m.name) for m in members) + "}" if len(members) > 1
				else members[0].name
				for members in members_list
			]
		else:
			new_names = [str(j) for j in range(len(eq_list))]

		# Used to rewrite isomorph references, which are stored by state name.
		old_name_to_new_name = {
			m.name: new_names[j]
			for j, members in enumerate(members_list)
			for m in members
		}

		# Build the new states, preserving classes and isomorphs regardless of naming
		new_states = []
		for j, (eq, members, name) in enumerate(zip(eq_list, members_list, new_names)):
			classes   = set().union(*(m.classes for m in members))

			# Isomorphs are translated to new names; a state is never listed as
			# its own isomorph.
			isomorphs = {
				old_name_to_new_name.get(iso, iso)
				for m in members
				for iso in m.isomorphs
				if old_name_to_new_name.get(iso, iso) != name
			}
			new_states.append(CausalState(
				name      = name,
				classes   = classes,
				isomorphs = isomorphs,
			))

		self.set_states(new_states)
		self.set_transitions(new_trs)
		self.start_state = new_start

		if n_states_before == len(new_states) and verbose :
			print( f"{n_states_before} state HMM was already minimal.\n" )
		elif verbose :
			print( f"Minimized from {n_states_before} to {len(new_states)}\n" )

		# Sanity checks: these run after the machine has already been modified.
		if not ( was_strongly_connected ==  self.is_strongly_connected() ) :
			raise RuntimeError(
				f"Minimization broke strongly connected"
			)

		if not ( was_row_stochastic ==  self.is_row_stochastic() ) :
			raise RuntimeError(
				f"Minimization broke row stochasticity"
			)

		self.is_minimal = True

Minimizes the HMM, resulting in an \( \epsilon- \)machine if the HMM is unifilar and strongly connected. Converts the HMM to a DFA with symbols labeled jointly with symbols and probabilities, and uses Myhill-Nerode equivalence for minimization. Relies on automata_lib and uses automata.fa.dfa.DFA.minify with allow_partial=True, and all states final.

Arguments:
  • retain_names (bool): If True, the merged states will be named by their union, e.g. {s_0, s_1}, and other states will retain their original names. Otherwise, they will be relabeled { '0', '1', ..., 'n-1' }.
Returns:

None : the machine is minimized in place (the method has no return value).

def collapse_to_largest_strongly_connected_subgraph(self, rename_states=True):
	def collapse_to_largest_strongly_connected_subgraph( self, rename_states=True ) :

		"""
		Restricts the machine, in place, to its largest strongly connected component.

		If the graph from :meth:`as_digraph` is already strongly connected the
		states and transitions are left as they are. Otherwise the states outside
		the chosen component and every transition touching them are removed, and
		for each remaining state whose outgoing probabilities no longer sum to 1
		(within ``self.EPS``) the missing probability is split equally across its
		remaining outgoing transitions.

		Args:
			rename_states (bool): If true, the states are finally replaced by new
				``CausalState`` objects named by their index (``"0"``, ``"1"``, ...).
				This happens whether or not any state was removed, and the new
				objects are constructed from the name only.
		"""

		# get equivalent networkx graph
		G = self.as_digraph()

		# if already strongly connected, nothing to do
		if not nx.is_strongly_connected( G ) :

			# NOTE(review): `start` is never read in this method.
			start = time.perf_counter()
			subgraph_nodes = list( nx.strongly_connected_components( G ) )

			# sort the components by size; the last one is (one of) the largest
			subgraph_nodes.sort(key=len)
			component_state_set = subgraph_nodes[-1]

			# The largest strongly connected component as a sorted list of state indices
			component_states = sorted( list( component_state_set ) )

			# make temporary copies of the old transitions and states
			old_transitions = [
				Transition(
					origin_state_idx=tr.origin_state_idx,
					target_state_idx=tr.target_state_idx,
					prob=tr.prob,
					symbol_idx=tr.symbol_idx
				)

				for tr in self.transitions
			]

			old_states = [
				CausalState(
					name=s.name,
					classes=s.classes,
					isomorphs=s.isomorphs
				)
				for s in self.states
			]

			# keep only the states that belong to the component, in their old order
			self.set_states(
				states=[ 
					state
					for i, state in enumerate( old_states ) if i in component_state_set
				]
			)

			# we will build new transition list based on those belonging to the component
			self.set_transitions( transitions= [] )

			# for tracking which new transitions leave each state (keyed by OLD state index)
			transitions_from_state = { state : set() for state in component_states }
			new_transitions = []

			for tr in old_transitions :

				origin_state_name = old_states[ tr.origin_state_idx ].name
				target_state_name = old_states[ tr.target_state_idx ].name

				# skip transitions that connect separate strongly connected components
				if not ( tr.origin_state_idx in component_state_set and tr.target_state_idx in component_state_set ) :
					continue

				# track transitions (by index in new transitions list) that leave this state
				transitions_from_state[ tr.origin_state_idx ].add( len( new_transitions ) )

				# translate old indices to new ones through the state names
				# (presumably state_idx_map is refreshed by set_states above — TODO confirm)
				my_origin_state_idx = self.state_idx_map[ origin_state_name ]
				my_target_state_idx = self.state_idx_map[ target_state_name ]

				new_transitions.append( 
					Transition(
						origin_state_idx=my_origin_state_idx,
						target_state_idx=my_target_state_idx,
						prob=tr.prob,
						symbol_idx=tr.symbol_idx
					)  )

			self.set_transitions( transitions=new_transitions )

			# if we removed an outgoing transition from a state, we need to distribute its probability 
			# among the remaining outgoing transitions from the state
			for state in component_states :
				
				# get the set of transitions leaving this state
				# (indices into new_transitions; assumes self.transitions keeps that order — TODO confirm)
				state_trs = transitions_from_state[ state ]

				# sum the probabilities of the outgoing transitions from the state
				p_sum = np.sum( [ self.transitions[ i ].prob for i in state_trs ] )

				# how much probability is missing
				diff = 1.0 - p_sum

				# if significant difference
				if abs( diff ) > self.EPS : 

					# calculate how much of the difference each transition gets
					# NOTE(review): state_trs can be empty for a one-state component
					# without a self-loop, making this a division by zero — confirm
					# whether that case needs explicit handling.
					adjustment = diff / len( state_trs )
					
					# update the transitions
					for i in state_trs :

						# transition with original probability
						tr = self.transitions[ i ]

						# adjusted probability
						self.transitions[ i ] = Transition(
							origin_state_idx=tr.origin_state_idx, 
							target_state_idx=tr.target_state_idx, 
							prob=tr.prob + adjustment, 
							symbol_idx=tr.symbol_idx )

		# runs for every machine, collapsed or not; only the name is carried into the new states
		if rename_states :
			self.set_states( [
				CausalState( name=f"{i}" )
				for i, s in enumerate( self.states )
			] )
def to_q_weighted(self, denominator_limit=1000):
1309	def to_q_weighted( self, denominator_limit=1000 ) :
1310
1311		"""
1312		Approximates the existing transition probabilities with exact fractions, stores 
1313		the fractional probabilities as Fraction in Transition.pq, and sets the floating
1314		point probabilty to `float(pq)`. If `denominator_limit` is too small for a sane 
1315		conversion, the function recurses with `denominator_limit=denominator_limit*10`.
1316
1317		Args:
1318			denominator_limit (int): The initial input to :meth:`Fraction.limit_denominator` in 
1319			the conversion.
1320		"""
1321
1322		if self.is_q_weighted :
1323			return
1324
1325		if not self.is_row_stochastic() :
1326			raise ValueError( "Cannot convert to q-weighted because not row stochastic" )
1327
1328		t_from = [[] for _ in range(len(self.states))]
1329
1330		for i, tr in enumerate( self.transitions ) :
1331			t_from[ tr.origin_state_idx ].append( i )
1332
1333		new_transitions = []
1334		for t_list in t_from :
1335			
1336			if not t_list : 
1337				continue
1338
1339			p_q_sum = Fraction(0,1)
1340			p_qs = []
1341
1342			for t_idx in t_list :
1343			
1344				p_q = Fraction( self.transitions[ t_idx ].prob ).limit_denominator( denominator_limit )
1345				p_q_sum += p_q
1346				p_qs.append( p_q )
1347
1348			if p_q_sum != Fraction(1,1) :
1349				
1350				max_pq_i = np.argmax( p_qs )
1351				max_oq = p_qs[ max_pq_i ]
1352
1353				diff = p_q_sum - Fraction(1,1)
1354
1355				# If recurse with higher resolution
1356				if diff > max_oq :
1357					return self.to_q_weighted( denominator_limit*10 )
1358				else :
1359					p_qs[ max_pq_i ] -= diff
1360
1361			for i, t_idx in enumerate( t_list ) :	
1362				new_transitions.append( 
1363					Transition( 
1364						origin_state_idx=self.transitions[  t_idx ].origin_state_idx,
1365						target_state_idx=self.transitions[  t_idx ].target_state_idx,
1366						prob=float(p_qs[ i ]),
1367						symbol_idx=self.transitions[  t_idx ].symbol_idx,
1368						pq=p_qs[ i ]
1369					)
1370				)
1371
1372		self.set_transitions( new_transitions )
1373		self.is_q_weighted = True

Approximates the existing transition probabilities with exact fractions, stores the fractional probabilities as Fraction in Transition.pq, and sets the floating point probability to float(pq). If denominator_limit is too small for a sane conversion, the function recurses with denominator_limit=denominator_limit*10.

Arguments:
  • denominator_limit (int): The initial input to Fraction.limit_denominator() in the conversion.
def isomorphic_shift( self, input_symbol_indices: numpy.ndarray, input_state_indices: numpy.ndarray, shift: int = 1) -> dict[str, numpy.ndarray]:
1379	def isomorphic_shift(
1380		self,
1381		input_symbol_indices: np.ndarray,
1382		input_state_indices:  np.ndarray,
1383		shift : int = 1
1384	) -> dict[str, np.ndarray]:
1385
1386		"""
1387		Generates a new sequence of of symbols that are permuted with the symbols emitted by
1388		isomorphic states, if they exists.
1389
1390		:math:`\\sigma_o = \\mathcal{S}\\left[\\texttt{input\\_state\\_indices}[i]\\right]`<br>
1391		:math:`\\sigma_t = \\mathcal{S}\\left[\\texttt{input\\_state\\_indices}[i+1]\\right]`
1392 
1393		:math:`\\mathcal{I}(\\sigma_o) = \\\\{\\sigma^0_o,\\, \\sigma^1_o,\\, \\dots,\\, \\sigma^{n-1}_o \\\\}`<br>
1394		:math:`\\mathcal{I}(\\sigma_t) = \\\\{\\sigma^0_t,\\, \\sigma^1_t,\\, \\dots,\\, \\sigma^{n-1}_t \\\\}`
1395 
1396		:math:`k = \\bigl(\\mathcal{I}(\\sigma_o).\\texttt{index}(\\sigma_o) + \\texttt{shift}\\bigr) \\bmod n`
1397 
1398		:math:`\\texttt{output\\_symbol\\_indices}[i]   := T(\\sigma_o^k,\\, \\sigma_t^k).\\text{symbol\\_index}`<br>
1399		:math:`\\texttt{output\\_state\\_indices}[i]    := \\mathcal{S}.\\texttt{index}(\\sigma_o^k)`<br>
1400		:math:`\\texttt{output\\_state\\_indices}[i+1]  := \\mathcal{S}.\\texttt{index}(\\sigma_t^k)`
1401
1402		Where :math:`\\mathcal{I}(\\sigma)`` is the ordered set of states isomorphic to :math:`\\sigma` including :math:`\\sigma` itself.
1403
1404		Args:
1405			input_symbol_indices (np.ndarray): The sequence of generated symbols.
1406			input_state_indices (np.ndarray): The sequence of states that generated symbols with the final state at the end.
1407			shift : int: How much to shift the symbols across the isomorphic states.
1408		"""
1409
1410		if not any( state.isomorphs for state in self.states ):
1411			raise ValueError("HMM has no states with isomorphs")
1412
1413		inputs = np.asarray(input_symbol_indices)
1414		states = np.asarray(input_state_indices)
1415
1416		n_states = len(self.states)
1417
1418		tr_sym_table = np.full((n_states, n_states), -1, dtype=np.int32)
1419
1420		for tr in self.transitions:
1421			tr_sym_table[ tr.origin_state_idx, tr.target_state_idx ] = tr.symbol_idx
1422
1423		# Build isomorph remapping: identity by default, overridden where isomorphs exist
1424		iso_table = np.arange(n_states, dtype=np.int32)
1425
1426		for i, state in enumerate(self.states):
1427			if state.isomorphs:
1428				isomorph       = sorted(state.isomorphs)[0]
1429				iso_table[ i ] = self.state_idx_map[isomorph]
1430
1431		for i, state in enumerate(self.states):
1432			if state.isomorphs:
1433
1434				# extend the isomorph list to include the identity
1435				isormorphs_with_identity = sorted( [ i ] + [ self.state_idx_map[iso] for iso in state.isomorphs ] )
1436
1437				# find the identity index
1438				pos = isormorphs_with_identity.index( i )
1439
1440				# cyclical shift 
1441				iso_table[ i ] = isormorphs_with_identity[ ( pos + shift ) % len( isormorphs_with_identity ) ]
1442
1443		origins = states[:-1]
1444		targets = states[1:]
1445
1446		out_origins = iso_table[origins]
1447		out_targets = iso_table[targets]
1448
1449		inv_sym = tr_sym_table[ out_origins, out_targets ]
1450		inv_sts = np.empty(states.size, dtype=states.dtype)
1451
1452		inv_sts[:-1] = out_origins
1453		inv_sts[-1]  = out_targets[-1]
1454
1455		return {
1456			"symbol_index": inv_sym.astype(inputs.dtype),
1457			"state_index":  inv_sts,
1458		}

Generates a new sequence of symbols that are permuted with the symbols emitted by isomorphic states, if they exist.

\( \sigma_o = \mathcal{S}\left[\texttt{input_state_indices}[i]\right] \)
\( \sigma_t = \mathcal{S}\left[\texttt{input_state_indices}[i+1]\right] \)

\( \mathcal{I}(\sigma_o) = \{\sigma^0_o,\, \sigma^1_o,\, \dots,\, \sigma^{n-1}_o \} \)
\( \mathcal{I}(\sigma_t) = \{\sigma^0_t,\, \sigma^1_t,\, \dots,\, \sigma^{n-1}_t \} \)

\( k = \bigl(\mathcal{I}(\sigma_o).\texttt{index}(\sigma_o) + \texttt{shift}\bigr) \bmod n \)

\( \texttt{output_symbol_indices}[i] := T(\sigma_o^k,\, \sigma_t^k).\text{symbol_index} \)
\( \texttt{output_state_indices}[i] := \mathcal{S}.\texttt{index}(\sigma_o^k) \)
\( \texttt{output_state_indices}[i+1] := \mathcal{S}.\texttt{index}(\sigma_t^k) \)

Where \( \mathcal{I}(\sigma) \) is the ordered set of states isomorphic to \( \sigma \) including \( \sigma \) itself.

Arguments:
  • input_symbol_indices (np.ndarray): The sequence of generated symbols.
  • input_state_indices (np.ndarray): The sequence of states that generated symbols with the final state at the end.
  • shift (int): How much to shift the symbols across the isomorphic states.
def generate_data( self, file_prefix: str, n_gen: int, include_states: bool, isomorphic_shifts: set[int] = None, random_seed: int = 42) -> dict[any]:
1460	def generate_data(
1461		self,
1462		file_prefix: str,
1463		n_gen: int,
1464		include_states: bool,
1465		isomorphic_shifts : set[int]=None,
1466		random_seed : int=42 ) -> dict[any] : 
1467
1468		trs = [ [] for _ in range( len( self.states ) ) ]
1469		for tr in self.transitions :
1470			trs[ tr.origin_state_idx ].append( ( 
1471				tr.symbol_idx, 
1472				float( tr.prob ),
1473				tr.target_state_idx ) )
1474
1475		data = am_fast.generate_data(
1476			n_gen=n_gen,
1477			start_state=self.start_state,
1478			transitions=trs,
1479			alphabet=sorted(list(self.alphabet)),
1480			include_states=include_states,
1481			random_seed=random_seed
1482		)
1483
1484		if isomorphic_shifts is not None :
1485
1486			if not include_states :
1487				raise ValueError( "Isomorphic inversion requires include_states=True" )
1488
1489			data[ "isomorphic_shifts" ] = {}
1490
1491			for shift in isomorphic_shifts :
1492
1493				try : 
1494
1495					shifted = self.isomorphic_shift(
1496						input_symbol_indices=data[ "symbol_index" ], 
1497						input_state_indices=data[ "state_index" ], 
1498						shift=shift
1499					)
1500
1501					data[ "isomorphic_shifts" ][ shift ] = {
1502						"symbol_index" : shifted[ "symbol_index" ],
1503						"state_index"  : shifted[ "state_index" ]
1504					}
1505
1506				except Exception as e :
1507					print( f"Exception {e}" )
1508
1509		am_fast.save_data(
1510			data=data,
1511			file_prefix=file_prefix,
1512			alphabet=sorted(list(self.alphabet)),
1513			n_states=len( self.states ),
1514			start_state=self.start_state,
1515			random_seed=random_seed,
1516			machine_metadata=self.get_metadata() )
1517
1518		return data
def draw_graph( self, engine: str = 'dot', output_dir: pathlib.Path = PosixPath('.'), show: bool = True) -> None:
1524	def draw_graph(
1525		self,
1526		engine     : str  = 'dot',
1527		output_dir : Path = Path('.'),
1528		show       : bool = True
1529	) -> None :
1530
1531		"""
1532		Draws the machine using [pygraphiviz](https://pygraphviz.github.io/documentation/stable/) and saves it.
1533
1534		Returns:
1535		
1536			networkx.DiGraph : the resulting graph.
1537		"""
1538
1539		G = self.as_digraph()
1540
1541		subgraphs = None if nx.is_strongly_connected( G ) else list( nx.strongly_connected_components( G ) )
1542		
1543		am_vis.draw_graph( 
1544			self, 
1545			output_dir=output_dir, 
1546			title="am_graph", 
1547			view=show, 
1548			subgraphs=subgraphs, 
1549			engine=engine )

Draws the machine using pygraphviz and saves it.

Returns:

None : the method draws and saves the graph but has no return value.

def as_digraph(self) -> networkx.classes.digraph.DiGraph:
1556	def as_digraph( self ) -> nx.DiGraph :
1557
1558		"""
1559		Builds a [networkx.DiGraph](https://networkx.org/documentation/stable/reference/classes/digraph.html) constructed from the machine's transitions with no edge symbols or weights.
1560
1561		Returns:
1562		
1563			networkx.DiGraph : the resulting graph.
1564		"""
1565
1566		G = nx.DiGraph()
1567		G.add_nodes_from( [ i for i, s in enumerate( self.states ) ] )
1568
1569		for tr in self.transitions :
1570			G.add_edge( tr.origin_state_idx, tr.target_state_idx )
1571
1572		return G

Builds a networkx.DiGraph constructed from the machine's transitions with no edge symbols or weights.

Returns:

networkx.DiGraph : the resulting graph.

def as_dfa(self, with_probs: bool):
1574	def as_dfa( self, with_probs : bool ) :
1575
1576		"""
1577		Builds an [automata.fa.dfa.DFA](https://caleb531.github.io/automata/api/fa/class-dfa/) constructed from the machine's transitions.
1578
1579		Args:
1580			with_probs (bool): If true the DFA transitions are labeled based on
1581				the symbol of the machines transition concatenated with its
1582				probability, othwise, the only the symbols.
1583
1584		Returns:
1585		
1586			automata.fa.dfa.DFA : the resulting DFA.
1587		"""
1588
1589		precision=8
1590
1591		def edge_label( symb, prob ) :
1592			return f"({symb},{round(prob, precision)})"
1593
1594		# Build states, symbols, and transitions 
1595		dfa_states  = { i for i, _ in enumerate( self.states ) }
1596			
1597		if not with_probs :
1598			dfa_symbols = set( { t.symbol_idx for t in self.transitions } )
1599		else : 
1600			dfa_symbols = set( { edge_label( t.symbol_idx, t.prob ) for t in self.transitions } )
1601
1602		dfa_transitions = defaultdict(dict)
1603
1604		if not with_probs :
1605			for t in self.transitions :
1606				dfa_transitions[ t.origin_state_idx ][ t.symbol_idx ] = t.target_state_idx
1607		else :
1608			for t in self.transitions :
1609				dfa_transitions[ t.origin_state_idx ][ edge_label( t.symbol_idx, t.prob ) ] = t.target_state_idx
1610
1611		# Construct the DFA
1612		return DFA(
1613			states=dfa_states,
1614			input_symbols=dfa_symbols,
1615			transitions=dfa_transitions,
1616			initial_state=self.start_state,
1617			allow_partial=True,
1618			final_states={ 
1619				s for s in dfa_states
1620			}
1621		)

Builds an automata.fa.dfa.DFA constructed from the machine's transitions.

Arguments:
  • with_probs (bool): If true, the DFA transitions are labeled with the symbol of the machine's transition concatenated with its probability; otherwise, with only the symbol.
Returns:

automata.fa.dfa.DFA : the resulting DFA.