Coverage for src/flexfrac1d/model/frac_handlers.py: 89%
119 statements
« prev ^ index » next coverage.py v7.4.1, created at 2024-08-30 14:00 +0200
« prev ^ index » next coverage.py v7.4.1, created at 2024-08-30 14:00 +0200
1import abc
2from collections.abc import Iterator, Sequence
3import functools
4from numbers import Real
6import attrs
7import numpy as np
8import scipy.optimize as optimize
9import scipy.signal as signal
11from . import model
12from ..lib import phase_shift as ps, physics as ph
13from ..lib.constants import PI_2
def _make_search_array(wuf: model.WavesUnderFloe, coef: int):
    """Return a coarse grid of candidate positions along a floe.

    The base point count is ``ceil(4 * wuf.length * max(wavenumbers) / PI_2) + 2``;
    it is multiplied by ``coef`` and used to build a uniform grid over
    ``[0, wuf.length]``. The two end points of that grid are dropped.

    Parameters
    ----------
    wuf : model.WavesUnderFloe
        Floe whose length and wave wavenumbers size the grid.
    coef : int
        Refinement factor applied to the base point count.

    Returns
    -------
    np.ndarray
        Interior grid points, strictly between 0 and ``wuf.length``.
    """
    k_max = wuf.wui.wavenumbers.max()
    base_count = np.ceil(4 * wuf.length * k_max / PI_2).astype(int) + 2
    full_grid = np.linspace(0, wuf.length, base_count * coef)
    # Discard the floe edges themselves.
    return full_grid[1:-1]
def _make_diagnose_array(wuf: model.WavesUnderFloe, res: float):
    """Return evenly spaced positions over ``[0, wuf.length]``, both ends included.

    The segment is cut into ``ceil(wuf.length / res)`` intervals, so for a
    positive ``res`` the resulting spacing never exceeds ``res``.

    Parameters
    ----------
    wuf : model.WavesUnderFloe
        Floe whose ``length`` sets the extent of the grid.
    res : float
        Target (maximal) spacing between consecutive points.
    """
    span = wuf.length
    n_intervals = np.ceil(span / res).astype(int)
    return np.linspace(0, span, n_intervals + 1)
@attrs.frozen
class _StrainDiag:
    """Immutable bundle of strain diagnostics, as built by ``_StrainFracture.diagnose``."""

    # Positions at which the strain is sampled.
    x: np.ndarray
    # Strain evaluated at ``x``.
    strain: np.ndarray
    # Positions returned by the successful strain optimisations.
    peaks: np.ndarray
    # Strain evaluated at ``peaks``.
    strain_extrema: np.ndarray
@attrs.frozen
class _FractureDiag:
    """Immutable bundle of energy diagnostics, as built by ``BinaryFracture.diagnose``."""

    # Candidate split positions, measured from the floe's left edge.
    x: np.ndarray
    # Energies of the fragments for each entry of ``x``; one row per
    # position, one column per fragment (left, then right).
    energy: np.ndarray
    # Energy of the unbroken floe.
    initial_energy: float
    # Copied from ``wuf.wui.ice.frac_energy_rate``.
    frac_energy_rate: float
@attrs.define
class _FractureHandler(abc.ABC):
    """Shared machinery for fracture handlers: splitting a floe into fragments.

    Attributes
    ----------
    coef_nd : int
        Refinement factor passed as ``coef`` to ``_make_search_array`` by the
        subclasses' coarse sweeps.
    scattering_handler : ps.ScatteringHandler
        Object used by ``split`` (outside of searches) to compute the edge
        amplitudes of the new fragments. Defaults to a fresh
        ``ps.ContinuousScatteringHandler``.
    """

    coef_nd: int = 4
    scattering_handler: ps.ScatteringHandler = attrs.field(
        factory=ps.ContinuousScatteringHandler
    )

    def split(
        self,
        wuf: model.WavesUnderFloe,
        xf: Real | np.ndarray,
        is_searching: bool = False,
    ) -> list[model.WavesUnderFloe]:
        """Cut ``wuf`` at one or several positions.

        Parameters
        ----------
        wuf : model.WavesUnderFloe
            Floe to cut.
        xf : Real or np.ndarray
            Fracture position(s), measured from the floe's left edge
            (presumably increasing and inside ``(0, wuf.length)`` — TODO
            confirm with callers).
        is_searching : bool
            If True, the edge amplitudes of all fragments are obtained from
            ``ps.ContinuousScatteringHandler.compute_edge_amplitudes`` (called
            on the class) instead of from ``self.scattering_handler``.

        Returns
        -------
        list of model.WavesUnderFloe
            One fragment per interval, in the order of ``xf``.
        """
        # With 0 prepended, ``cuts[i]`` is the start of fragment i relative
        # to the parent's left edge.
        cuts = np.hstack((0, xf))
        lengths = np.ediff1d(np.hstack((cuts, wuf.length)))
        left_edges = wuf.left_edge + cuts
        n_fragments = cuts.size

        if not is_searching:
            amplitudes = np.full(
                (n_fragments, wuf.edge_amplitudes.size), np.nan, dtype=complex
            )
            # The leftmost fragment keeps the parent's amplitudes.
            amplitudes[0] = wuf.edge_amplitudes.copy()
            amplitudes[1:] = self.scattering_handler.compute_edge_amplitudes(
                wuf.edge_amplitudes, wuf.wui._c_wavenumbers, cuts[1:]
            )
        else:
            amplitudes = ps.ContinuousScatteringHandler.compute_edge_amplitudes(
                wuf.edge_amplitudes, wuf.wui._c_wavenumbers, cuts
            )

        # Every fragment except the rightmost one moves up one generation.
        generations = wuf.generation * np.ones(n_fragments, dtype=int)
        generations[:-1] += 1

        return [
            model.WavesUnderFloe(
                left_edge=left_edge,
                length=length,
                wui=wuf.wui,
                edge_amplitudes=fragment_amplitudes,
                generation=generation,
            )
            for left_edge, length, fragment_amplitudes, generation in zip(
                left_edges, lengths, amplitudes, generations
            )
        ]

    @abc.abstractmethod
    def search(self, wuf: model.WavesUnderFloe, growth_params, an_sol, num_params):
        """Locate fracture position(s) in ``wuf``; provided by subclasses."""
        raise NotImplementedError
@attrs.define(frozen=True)
class BinaryFracture(_FractureHandler):
    """Energy-based fracture of a floe into two fragments.

    The candidate position minimises the summed energy of the two fragments;
    the floe breaks there if that sum plus the ice's ``frac_energy_rate`` is
    smaller than the energy of the intact floe.
    """

    def compute_energies(
        self,
        wuf_collection: Sequence[model.WavesUnderFloe],
        growth_params,
        an_sol: bool,
        num_params,
    ) -> np.ndarray:
        """Return the energy of each floe of ``wuf_collection``, in order.

        Each energy is ``ph.EnergyHandler.from_wuf(wuf, growth_params)
        .compute(an_sol, num_params)``; the result is a float array.
        """
        energies = np.full(len(wuf_collection), np.nan)
        for i, wuf in enumerate(wuf_collection):
            handler = ph.EnergyHandler.from_wuf(wuf, growth_params)
            energies[i] = handler.compute(an_sol, num_params)
        return energies

    def _ener_min(self, length, wuf, growth_params, an_sol, num_params) -> float:
        """Objective function to minimise for energy-based fracture.

        Returns the log of the summed energy of the two fragments obtained by
        splitting ``wuf`` at ``length`` (searching split).
        """
        # Only the searching split is needed; no extra full split is computed.
        energy_left, energy_right = self.compute_energies(
            self.split(wuf, length, True), growth_params, an_sol, num_params
        )
        return np.log(energy_left + energy_right)

    def diagnose(
        self,
        wuf: model.WavesUnderFloe,
        res: float = 0.5,
        growth_params=None,
        an_sol=False,
        num_params=None,
    ) -> _FractureDiag:
        """Tabulate fragment energies for split positions spaced by about ``res``.

        Returns
        -------
        _FractureDiag
            Interior split positions, the (left, right) fragment energies at
            each of them, the energy of the intact floe, and the ice's
            ``frac_energy_rate``.
        """
        lengths = _make_diagnose_array(wuf, res)[1:-1]
        energies = np.full((lengths.size, 2), np.nan)
        # A plain scalar, matching ``_FractureDiag.initial_energy: float``
        # (a trailing comma used to wrap it in a 1-tuple).
        initial_energy = ph.EnergyHandler.from_wuf(wuf, growth_params).compute(
            an_sol, num_params
        )
        for i, length in enumerate(lengths):
            energies[i, :] = self.compute_energies(
                self.split(wuf, length), growth_params, an_sol, num_params
            )
        return _FractureDiag(
            lengths,
            energies,
            initial_energy,
            wuf.wui.ice.frac_energy_rate,
        )

    def discrete_sweep(
        self, wuf, an_sol, growth_params, num_params
    ) -> Iterator[tuple[float]]:
        """Return (lower, upper) brackets for the bounded energy minimisation.

        The objective is sampled on a coarse grid. Its local maxima, plus the
        first and last grid points, delimit intervals that each bracket a local
        minimum.
        """
        lengths = _make_search_array(wuf, self.coef_nd)
        ener = np.full(lengths.shape, np.nan)
        for i, length in enumerate(lengths):
            ener[i] = self._ener_min(length, wuf, growth_params, an_sol, num_params)

        peak_idxs = np.hstack(
            (0, signal.find_peaks(ener, distance=2)[0], ener.size - 1)
        )
        return zip(lengths[peak_idxs[:-1]], lengths[peak_idxs[1:]])

    def search(
        self, wuf: model.WavesUnderFloe, growth_params, an_sol, num_params
    ) -> float | None:
        """Return the energy-minimising fracture position, or None.

        None is returned in three cases:

        * the energy of the intact floe is below ``frac_energy_rate``;
        * no bounded optimisation succeeds;
        * the best split does not reduce the energy by more than
          ``frac_energy_rate``.
        """
        base_energy = ph.EnergyHandler.from_wuf(wuf, growth_params).compute(
            an_sol, num_params
        )
        frac_energy_rate = wuf.wui.ice.frac_energy_rate

        # No fracture if the elastic energy is below the threshold
        if base_energy < frac_energy_rate:
            return None

        bounds_iterator = self.discrete_sweep(wuf, an_sol, growth_params, num_params)
        local_ener_cost = functools.partial(
            self._ener_min,
            wuf=wuf,
            growth_params=growth_params,
            an_sol=an_sol,
            num_params=num_params,
        )
        # Explicit "bounded" method: the brackets are hard bounds, not a
        # Brent bracket. Recent SciPy defaults to it when ``bounds`` is given.
        opts = (
            optimize.minimize_scalar(local_ener_cost, bounds=bounds, method="bounded")
            for bounds in bounds_iterator
        )
        opt = min(
            (opt for opt in opts if opt.success),
            key=lambda opt: opt.fun,
            default=None,
        )
        # No successful optimisation: no fracture rather than a ValueError.
        if opt is None:
            return None
        # Minimisation is done on the log of energy
        if np.exp(opt.fun) + frac_energy_rate < base_energy:
            return opt.x
        return None
@attrs.define
class _StrainFracture(_FractureHandler):
    """Base class for strain-based fracture handlers.

    Candidate fracture positions are the local maxima of the squared strain,
    refined by bounded scalar optimisation.
    """

    def discrete_sweep(
        self, strain_handler, wuf, growth_params, an_sol, num_params
    ) -> Iterator[tuple[float]]:
        """Return (lower, upper) brackets for the strain optimisation.

        The strain is sampled on a coarse grid. The local minima of its square,
        plus the first and last grid points, delimit the intervals.
        """
        grid = _make_search_array(wuf, self.coef_nd)
        squared = strain_handler.compute(grid, an_sol, num_params) ** 2
        troughs = signal.find_peaks(-squared)[0]
        boundaries = np.hstack((0, troughs, grid.size - 1))
        return zip(grid[boundaries[:-1]], grid[boundaries[1:]])

    def search_peaks(
        self,
        wuf: model.WavesUnderFloe,
        growth_params: tuple | None,
        an_sol: bool,
        num_params: dict | None,
    ) -> Iterator[optimize.OptimizeResult]:
        """Lazily yield the successful strain optimisations, one per bracket.

        Each result's ``fun`` is minus the squared strain at ``x``.
        """
        handler = ph.StrainHandler.from_wuf(wuf, growth_params)

        def neg_squared_strain(x):
            # Negated so that minimising locates maxima of the squared strain.
            return -(handler.compute(x, an_sol, num_params) ** 2)

        brackets = self.discrete_sweep(
            handler, wuf, growth_params, an_sol, num_params
        )
        results = (
            optimize.minimize_scalar(neg_squared_strain, bounds=bracket)
            for bracket in brackets
        )
        return (result for result in results if result.success)

    def diagnose(
        self,
        wuf: model.WavesUnderFloe,
        res: float = 0.5,
        growth_params: tuple | None = None,
        an_sol: bool = True,
        num_params: dict | None = None,
    ) -> _StrainDiag:
        """Sample the strain on a grid spaced by about ``res`` and at its extrema."""
        x = _make_diagnose_array(wuf, res)
        strain_handler = ph.StrainHandler.from_wuf(wuf, growth_params)
        peaks = np.array(
            [
                result.x
                for result in self.search_peaks(
                    wuf, growth_params, an_sol, num_params
                )
            ]
        )
        return _StrainDiag(
            x=x,
            strain=strain_handler.compute(x, an_sol, num_params),
            peaks=peaks,
            strain_extrema=strain_handler.compute(peaks, an_sol, num_params),
        )
@attrs.define(frozen=True)
class BinaryStrainFracture(_StrainFracture):
    """Strain-based fracture at the single location of largest strain."""

    def search(
        self,
        wuf: model.WavesUnderFloe,
        growth_params,
        an_sol,
        num_params,
    ) -> float | None:
        """Return the position of maximal strain if it reaches the threshold.

        Returns
        -------
        float or None
            The position of the largest strain magnitude among the successful
            optimisations, if that magnitude is at least
            ``wuf.wui.ice.strain_threshold``; None otherwise, including
            when no optimisation succeeded.
        """
        opts = self.search_peaks(wuf, growth_params, an_sol, num_params)
        opt = min(opts, key=lambda opt: opt.fun, default=None)
        # ``search_peaks`` may yield nothing; ``min`` would raise ValueError.
        if opt is None:
            return None
        # ``opt.fun`` is minus the squared strain, so this is |strain|.
        if (-opt.fun) ** 0.5 >= wuf.wui.ice.strain_threshold:
            return opt.x
        return None
@attrs.define(frozen=True)
class MultipleStrainFracture(_StrainFracture):
    """Strain-based fracture at every location where the threshold is reached."""

    def search(
        self,
        wuf: model.WavesUnderFloe,
        growth_params,
        an_sol,
        num_params,
    ) -> list[float] | None:
        """Return every candidate position whose strain reaches the threshold.

        Returns
        -------
        list of float or None
            Positions in the order produced by ``search_peaks``; None when
            none of them reaches ``wuf.wui.ice.strain_threshold``.
        """
        candidates = self.search_peaks(wuf, growth_params, an_sol, num_params)
        # ``candidate.fun`` is minus the squared strain at ``candidate.x``.
        xfs = [
            candidate.x
            for candidate in candidates
            if (-candidate.fun) ** 0.5 >= wuf.wui.ice.strain_threshold
        ]
        return xfs or None