Coverage for src/swiift/model/frac_handlers.py: 37%
118 statements
« prev ^ index » next coverage.py v7.9.1, created at 2025-09-11 16:23 +0200
« prev ^ index » next coverage.py v7.9.1, created at 2025-09-11 16:23 +0200
1import abc
2from collections.abc import Iterator, Sequence
3import functools
5import attrs
6import numpy as np
7import scipy.optimize as optimize
8import scipy.signal as signal
10from . import model
11from ..lib import phase_shift as ps, physics as ph
12from ..lib.constants import PI_2
def _make_search_array(wuf: model.WavesUnderFloe, coef: int):
    """Return interior sample points for a coarse sweep along a floe.

    The base point count is ``ceil(4 * length * k_max / PI_2) + 2``, where
    ``k_max`` is the largest value in ``wuf.wui.wavenumbers``. Assuming
    ``PI_2`` is 2π, that is roughly four points per shortest wavelength, plus
    two. The count is then multiplied by ``coef`` to refine the grid.

    Parameters
    ----------
    wuf : model.WavesUnderFloe
        Floe whose ``length`` and ``wui.wavenumbers`` determine the grid.
    coef : int
        Refinement factor applied to the base point count.

    Returns
    -------
    np.ndarray
        Evenly spaced positions in ``[0, wuf.length]`` (relative to the
        floe's left edge) with both endpoints removed.
    """
    scaled_span = 4 * wuf.length * wuf.wui.wavenumbers.max()
    base_count = np.ceil(scaled_span / PI_2).astype(int) + 2
    full_grid = np.linspace(0, wuf.length, base_count * coef)
    # Keep interior points only: the floe ends themselves are excluded.
    return full_grid[1:-1]
def _make_diagnose_array(wuf: model.WavesUnderFloe, res: float):
    """Return evenly spaced points covering the whole floe, endpoints included.

    The grid uses ``ceil(wuf.length / res)`` intervals, so for positive
    ``length`` and ``res`` the spacing never exceeds ``res``.

    Parameters
    ----------
    wuf : model.WavesUnderFloe
        Floe whose ``length`` sets the extent of the grid.
    res : float
        Maximum spacing between consecutive points.

    Returns
    -------
    np.ndarray
        Positions from ``0`` to ``wuf.length``, relative to the left edge.
    """
    n_points = np.ceil(wuf.length / res).astype(int) + 1
    return np.linspace(0, wuf.length, n_points)
@attrs.frozen
class _StrainDiag:
    """Immutable bundle of strain diagnostics, as built by ``_StrainFracture.diagnose``."""

    # Diagnostic abscissae along the floe, relative to its left edge.
    x: np.ndarray
    # Strain evaluated at each point of ``x``.
    strain: np.ndarray
    # Positions of the refined strain peaks.
    peaks: np.ndarray
    # Strain evaluated at each position in ``peaks``.
    strain_extrema: np.ndarray
@attrs.frozen
class _FractureDiag:
    """Immutable bundle of energy diagnostics, as built by ``BinaryFracture.diagnose``."""

    # Candidate fracture positions, relative to the floe's left edge.
    x: np.ndarray
    # One row per candidate in ``x``: energies of the two resulting pieces.
    energy: np.ndarray
    # Energy of the intact floe.
    initial_energy: float
    # Copied from ``wuf.wui.ice.frac_energy_rate``.
    frac_energy_rate: float
@attrs.define
class _FractureHandler(abc.ABC):
    """Abstract base for fracture handlers.

    Provides ``split``, which cuts a floe into pieces at given positions.
    Subclasses must implement ``search``.
    """

    # Refinement factor forwarded to ``_make_search_array`` by subclasses.
    coef_nd: int = 4
    # Used by ``split`` (outside search mode) to compute edge amplitudes of
    # the newly created pieces.
    scattering_handler: ps._ScatteringHandler = attrs.field(
        factory=ps.ContinuousScatteringHandler
    )

    def split(
        self,
        wuf: model.WavesUnderFloe,
        xf: float | Sequence[float],
        is_searching: bool = False,
    ) -> list[model.WavesUnderFloe]:
        """Cut ``wuf`` at the relative position(s) ``xf``.

        Parameters
        ----------
        wuf : model.WavesUnderFloe
            Floe to split.
        xf : float or sequence of float
            Fracture position(s), relative to the floe's left edge.
            NOTE(review): presumably sorted ascending and strictly inside
            ``(0, wuf.length)`` -- not checked here; confirm with callers.
        is_searching : bool
            If true, every piece's edge amplitudes come from a fresh
            ``ps.ContinuousScatteringHandler`` evaluated at every relative
            edge (including 0). Otherwise the leftmost piece keeps a copy of
            the parent's amplitudes and the others are computed by
            ``self.scattering_handler``.

        Returns
        -------
        list of model.WavesUnderFloe
            One floe per piece, in the order of the edges. Every piece except
            the last has generation ``wuf.generation + 1``. The last piece
            keeps ``wuf.generation``.
        """
        rel_edges = np.hstack((0, xf))
        abs_edges = wuf.left_edge + rel_edges
        piece_lengths = np.ediff1d(np.hstack((abs_edges, wuf.right_edge)))
        n_pieces = rel_edges.size

        if not is_searching:
            amplitudes = np.full(
                (n_pieces, wuf.edge_amplitudes.size), np.nan, dtype=complex
            )
            amplitudes[0] = wuf.edge_amplitudes.copy()
            amplitudes[1:] = self.scattering_handler.compute_edge_amplitudes(
                wuf.edge_amplitudes, wuf.wui._c_wavenumbers, rel_edges[1:]
            )
        else:
            continuous = ps.ContinuousScatteringHandler()
            amplitudes = continuous.compute_edge_amplitudes(
                wuf.edge_amplitudes, wuf.wui._c_wavenumbers, rel_edges
            )

        generations = wuf.generation * np.ones(n_pieces, dtype=int)
        generations[:-1] += 1

        pieces = [
            model.WavesUnderFloe(
                left_edge=left,
                length=span,
                wui=wuf.wui,
                edge_amplitudes=amps,
                generation=generation,
            )
            for left, span, amps, generation in zip(
                abs_edges, piece_lengths, amplitudes, generations
            )
        ]
        return pieces

    @abc.abstractmethod
    def search(self, wuf: model.WavesUnderFloe, growth_params, an_sol, num_params):
        """Look for fracture position(s) in ``wuf``; implemented by subclasses."""
        raise NotImplementedError
@attrs.define(frozen=True)
class BinaryFracture(_FractureHandler):
    """Energy-based fracture handler that looks for a single fracture point.

    A fracture at relative position ``x`` is accepted when the summed elastic
    energy of the two resulting pieces, plus ``wuf.wui.ice.frac_energy_rate``,
    is lower than the elastic energy of the intact floe.
    """

    def compute_energies(
        self,
        wuf_collection: Sequence[model.WavesUnderFloe],
        growth_params,
        an_sol: bool,
        num_params,
        linear_curvature: bool | None = None,
    ) -> np.ndarray:
        """Return the elastic energy of each floe in ``wuf_collection``.

        Each energy comes from ``ph.EnergyHandler.from_wuf(wuf, growth_params)``.
        ``an_sol``, ``num_params`` and ``linear_curvature`` are forwarded to
        its ``compute`` method.
        """
        energies = np.full(len(wuf_collection), np.nan)
        for i, wuf in enumerate(wuf_collection):
            handler = ph.EnergyHandler.from_wuf(wuf, growth_params)
            energies[i] = handler.compute(an_sol, num_params, linear_curvature)
        return energies

    def _ener_min(
        self,
        length,
        wuf,
        growth_params,
        an_sol,
        num_params,
        linear_curvature: bool | None = None,
    ) -> float:
        """Objective function to minimise for energy-based fracture.

        The floe is split at ``length`` (relative to its left edge) in search
        mode. The result is the log of the summed energies of the two pieces.
        """
        # Only the search-mode split is needed. A non-search split would also
        # run the configured scattering handler, whose result was never used.
        energy_left, energy_right = self.compute_energies(
            self.split(wuf, length, True),
            growth_params,
            an_sol,
            num_params,
            linear_curvature,
        )
        return np.log(energy_left + energy_right)

    def diagnose(
        self,
        wuf: model.WavesUnderFloe,
        res: float = 0.5,
        growth_params=None,
        an_sol=None,
        num_params=None,
        linear_curvature: bool | None = None,
    ):
        """Tabulate the post-fracture energies over a regular grid of positions.

        Candidate positions are the interior points of
        ``_make_diagnose_array(wuf, res)``. For each one, the floe is split
        (outside search mode) and the energies of both pieces are stored.

        Returns
        -------
        _FractureDiag
            Candidate positions, an ``(n, 2)`` array of piece energies, the
            intact-floe energy, and the ice fracture energy rate.
        """
        lengths = _make_diagnose_array(wuf, res)[1:-1]
        energies = np.full((lengths.size, 2), np.nan)
        initial_energy = ph.EnergyHandler.from_wuf(wuf, growth_params).compute(
            an_sol, num_params, linear_curvature
        )
        for i, length in enumerate(lengths):
            energies[i, :] = self.compute_energies(
                self.split(wuf, length),
                growth_params,
                an_sol,
                num_params,
                linear_curvature,
            )
        return _FractureDiag(
            lengths,
            energies,
            initial_energy,
            wuf.wui.ice.frac_energy_rate,
        )

    def discrete_sweep(
        self,
        wuf,
        an_sol,
        growth_params,
        num_params,
        linear_curvature: bool | None = None,
    ) -> Iterator[tuple[float, float]]:
        """Return brackets for the local refinement of the energy objective.

        ``_ener_min`` is evaluated on ``_make_search_array(wuf, self.coef_nd)``.
        Its local maxima (``signal.find_peaks`` with ``distance=2``) together
        with the first and last grid points delimit consecutive
        ``(lower, upper)`` brackets.
        """
        lengths = _make_search_array(wuf, self.coef_nd)
        ener = np.full(lengths.shape, np.nan)
        for i, length in enumerate(lengths):
            ener[i] = self._ener_min(
                length, wuf, growth_params, an_sol, num_params, linear_curvature
            )

        peak_idxs = np.hstack(
            (0, signal.find_peaks(ener, distance=2)[0], ener.size - 1)
        )
        return zip(lengths[peak_idxs[:-1]], lengths[peak_idxs[1:]])

    def search(
        self,
        wuf: model.WavesUnderFloe,
        growth_params,
        an_sol,
        num_params,
        linear_curvature: bool | None = None,
    ) -> float | None:
        """Return the fracture position, relative to the left edge, or ``None``.

        ``None`` is returned in three cases:

        - the intact-floe energy is below the ice fracture energy rate;
        - no bounded minimisation succeeds;
        - the best split does not lower the energy enough.

        ``linear_curvature`` is used consistently for the intact floe and for
        every candidate split.
        """
        base_handler = ph.EnergyHandler.from_wuf(wuf, growth_params)
        base_energy = base_handler.compute(an_sol, num_params, linear_curvature)
        frac_energy_rate = wuf.wui.ice.frac_energy_rate

        # No fracture if the elastic energy is below the threshold
        if base_energy < frac_energy_rate:
            return None

        bounds_iterator = self.discrete_sweep(
            wuf, an_sol, growth_params, num_params, linear_curvature
        )
        local_ener_cost = functools.partial(
            self._ener_min,
            wuf=wuf,
            growth_params=growth_params,
            an_sol=an_sol,
            num_params=num_params,
            linear_curvature=linear_curvature,
        )
        # Explicit method: older SciPy versions do not default to "bounded"
        # when only ``bounds`` is given.
        opts = [
            optimize.minimize_scalar(
                local_ener_cost, bounds=bounds, method="bounded"
            )
            for bounds in bounds_iterator
        ]
        opt = min(
            (res for res in opts if res.success),
            key=lambda res: res.fun,
            default=None,
        )
        if opt is None:
            return None
        # Minimisation is done on the log of energy
        if np.exp(opt.fun) + frac_energy_rate < base_energy:
            return opt.x
        return None
@attrs.define
class _StrainFracture(_FractureHandler):
    """Base class for strain-based fracture handlers.

    Candidate fracture points are local maxima of the squared strain along
    the floe. They are refined by bounded scalar minimisation of
    ``-strain**2``.
    """

    def discrete_sweep(
        self, strain_handler, wuf, growth_params, an_sol, num_params
    ) -> Iterator[tuple[float, float]]:
        """Return ``(lower, upper)`` brackets around the squared-strain maxima.

        The strain is evaluated on ``_make_search_array(wuf, self.coef_nd)``.
        Local maxima of ``-strain**2`` (i.e. local minima of the squared
        strain) together with the first and last grid points delimit
        consecutive brackets. ``growth_params`` is not used here.
        """
        # NOTE: caveat: some small peaks close to the edges can be missed. The
        # multi-fracture handler is here more as a demo anyway, so this is
        # very low priority.
        x = _make_search_array(wuf, self.coef_nd)
        strain = strain_handler.compute(x, an_sol, num_params)
        peak_idxs = np.hstack((0, signal.find_peaks(-(strain**2))[0], x.size - 1))
        return zip(x[peak_idxs[:-1]], x[peak_idxs[1:]])

    def search_peaks(
        self,
        wuf: model.WavesUnderFloe,
        growth_params: tuple | None,
        an_sol: bool,
        num_params: dict | None,
    ) -> Iterator[optimize.OptimizeResult]:
        """Return the successful minimisations of ``-strain**2``, one per bracket.

        The minimisations run lazily as the returned iterator is consumed.
        Each result's ``fun`` is the negated squared strain at ``x``.
        """
        strain_handler = ph.StrainHandler.from_wuf(wuf, growth_params)
        bounds_iterator = self.discrete_sweep(
            strain_handler, wuf, growth_params, an_sol, num_params
        )
        # Explicit method: older SciPy versions do not default to "bounded"
        # when only ``bounds`` is given.
        opts = (
            optimize.minimize_scalar(
                lambda x: -strain_handler.compute(x, an_sol, num_params) ** 2,
                bounds=bounds,
                method="bounded",
            )
            for bounds in bounds_iterator
        )
        return filter(lambda opt: opt.success, opts)

    def diagnose(
        self,
        wuf: model.WavesUnderFloe,
        res: float = 0.5,
        growth_params: tuple | None = None,
        an_sol: bool = True,
        num_params: dict | None = None,
    ) -> _StrainDiag:
        """Return the strain along the floe and at its refined peaks.

        The strain is sampled on ``_make_diagnose_array(wuf, res)``. Peak
        positions come from ``search_peaks``.
        """
        x = _make_diagnose_array(wuf, res)
        strain_handler = ph.StrainHandler.from_wuf(wuf, growth_params)
        opts = self.search_peaks(wuf, growth_params, an_sol, num_params)
        peaks = np.array([opt.x for opt in opts])
        return _StrainDiag(
            x,
            strain_handler.compute(x, an_sol, num_params),
            peaks,
            strain_handler.compute(peaks, an_sol, num_params),
        )
@attrs.define(frozen=True)
class BinaryStrainFracture(_StrainFracture):
    """Strain-based handler returning at most one fracture point."""

    def search(
        self,
        wuf: model.WavesUnderFloe,
        growth_params,
        an_sol,
        num_params,
    ) -> float | None:
        """Return the position of the largest strain peak, or ``None``.

        ``search_peaks`` minimises ``-strain**2``, so the result with the
        smallest ``fun`` has the largest strain magnitude. That position is
        returned when the magnitude reaches ``wuf.wui.ice.strain_threshold``.
        ``None`` is returned when it does not, or when no peak optimisation
        succeeded (previously the latter raised ``ValueError``).
        """
        opts = self.search_peaks(wuf, growth_params, an_sol, num_params)
        opt = min(opts, key=lambda res: res.fun, default=None)
        if opt is None:
            return None
        if (-opt.fun) ** 0.5 >= wuf.wui.ice.strain_threshold:
            return opt.x
        return None
@attrs.define(frozen=True)
class MultipleStrainFracture(_StrainFracture):
    """Strain-based handler that may return several fracture points at once."""

    def search(
        self,
        wuf: model.WavesUnderFloe,
        growth_params,
        an_sol,
        num_params,
    ) -> list[float] | None:
        """Return every strain peak whose magnitude reaches the ice threshold.

        A successful result of ``search_peaks`` is kept when
        ``sqrt(-opt.fun) >= wuf.wui.ice.strain_threshold``.

        Returns
        -------
        list of float or None
            Positions of the retained peaks, or ``None`` if none qualifies.
        """
        candidates = self.search_peaks(wuf, growth_params, an_sol, num_params)
        fracture_points = [
            res.x
            for res in candidates
            if (-res.fun) ** 0.5 >= wuf.wui.ice.strain_threshold
        ]
        # An empty list is falsy, which maps "no qualifying peak" to None.
        return fracture_points or None