Coverage for src/swiift/model/frac_handlers.py: 37%

118 statements  

« prev     ^ index     » next       coverage.py v7.9.1, created at 2025-09-11 16:23 +0200

1import abc 

2from collections.abc import Iterator, Sequence 

3import functools 

4 

5import attrs 

6import numpy as np 

7import scipy.optimize as optimize 

8import scipy.signal as signal 

9 

10from . import model 

11from ..lib import phase_shift as ps, physics as ph 

12from ..lib.constants import PI_2 

13 

14 

def _make_search_array(wuf: model.WavesUnderFloe, coef: int):
    """Return interior candidate positions for a discrete sweep along ``wuf``.

    A base point count is derived from ``4 * wuf.length * max(wavenumbers) /
    PI_2`` (rounded up, plus 2) and multiplied by ``coef``. The grid spans
    ``[0, wuf.length]`` uniformly, and both end points are dropped, so only
    strictly interior positions are returned.
    """
    k_max = wuf.wui.wavenumbers.max()
    n_base = np.ceil(4 * wuf.length * k_max / PI_2).astype(int) + 2
    grid = np.linspace(0, wuf.length, n_base * coef)
    # Exclude the floe edges themselves: a break there is meaningless.
    return grid[1:-1]

18 

19 

def _make_diagnose_array(wuf: model.WavesUnderFloe, res: float):
    """Return a uniform grid over ``[0, wuf.length]`` with spacing at most ``res``.

    Both end points are included in the returned array.
    """
    n_points = np.ceil(wuf.length / res).astype(int) + 1
    return np.linspace(0, wuf.length, n_points)

22 

23 

@attrs.frozen
class _StrainDiag:
    """Immutable container for strain-based fracture diagnostics."""

    # Positions along the floe at which the strain is sampled
    x: np.ndarray
    # Strain evaluated at every position of ``x``
    strain: np.ndarray
    # Refined positions of the strain extrema
    peaks: np.ndarray
    # Strain evaluated at ``peaks``
    strain_extrema: np.ndarray

30 

31 

@attrs.frozen
class _FractureDiag:
    """Immutable container for energy-based fracture diagnostics."""

    # Candidate fracture positions, relative to the floe's left edge
    x: np.ndarray
    # Energies of the (left, right) fragments for each candidate; shape (x.size, 2)
    energy: np.ndarray
    # Energy of the intact floe
    initial_energy: float
    # Fracture energy rate of the ice (``wuf.wui.ice.frac_energy_rate``)
    frac_energy_rate: float

38 

39 

@attrs.define
class _FractureHandler(abc.ABC):
    """Abstract base class for fracture handlers.

    Attributes
    ----------
    coef_nd : int
        Multiplier on the number of points of the discrete search grid
        (passed as ``coef`` to ``_make_search_array`` by subclasses).
    scattering_handler : ps._ScatteringHandler
        Handler used by ``split`` outside of search mode to compute the edge
        amplitudes of the new fragments. Defaults to a fresh
        ``ps.ContinuousScatteringHandler``.
    """

    coef_nd: int = 4
    scattering_handler: ps._ScatteringHandler = attrs.field(
        factory=ps.ContinuousScatteringHandler
    )

    def split(
        self,
        wuf: model.WavesUnderFloe,
        xf: float | Sequence[float],
        is_searching: bool = False,
    ) -> list[model.WavesUnderFloe]:
        """Split ``wuf`` at the fracture position(s) ``xf``.

        Parameters
        ----------
        wuf : model.WavesUnderFloe
            Floe to split.
        xf : float or sequence of float
            Fracture position(s), relative to the left edge of ``wuf``.
            Assumed sorted and strictly inside the floe (not checked).
        is_searching : bool
            If True, the edge amplitudes of all fragments are computed with a
            ``ps.ContinuousScatteringHandler``. Otherwise the first fragment
            reuses the parent's ``edge_amplitudes`` and the others are
            computed with ``self.scattering_handler``.

        Returns
        -------
        list of model.WavesUnderFloe
            The fragments, ordered from left to right. Every fragment but the
            rightmost has its generation incremented by one.
        """
        rel_edges = np.hstack((0, xf))
        abs_edges = wuf.left_edge + rel_edges
        lengths = np.ediff1d(np.hstack((abs_edges, wuf.right_edge)))
        n_fragments = rel_edges.size

        if not is_searching:
            amplitudes = np.full(
                (n_fragments, wuf.edge_amplitudes.size), np.nan, dtype=complex
            )
            # The first fragment shares its left edge with the parent floe.
            amplitudes[0] = wuf.edge_amplitudes.copy()
            amplitudes[1:] = self.scattering_handler.compute_edge_amplitudes(
                wuf.edge_amplitudes, wuf.wui._c_wavenumbers, rel_edges[1:]
            )
        else:
            searcher = ps.ContinuousScatteringHandler()
            amplitudes = searcher.compute_edge_amplitudes(
                wuf.edge_amplitudes, wuf.wui._c_wavenumbers, rel_edges
            )

        generations = wuf.generation * np.ones(n_fragments, dtype=int)
        generations[:-1] += 1

        fragments = []
        for left, length, amps, gen in zip(
            abs_edges, lengths, amplitudes, generations
        ):
            fragments.append(
                model.WavesUnderFloe(
                    left_edge=left,
                    length=length,
                    wui=wuf.wui,
                    edge_amplitudes=amps,
                    generation=gen,
                )
            )
        return fragments

    @abc.abstractmethod
    def search(self, wuf: model.WavesUnderFloe, growth_params, an_sol, num_params):
        """Return fracture position(s) for ``wuf``, or None if it does not break."""
        raise NotImplementedError

96 

97 

@attrs.define(frozen=True)
class BinaryFracture(_FractureHandler):
    """Energy-based binary fracture.

    A floe is split in two at the position minimising the summed elastic
    energy of the two fragments. The fracture is accepted only if that
    energy, plus the ice's ``frac_energy_rate``, is below the energy of the
    intact floe.
    """

    def compute_energies(
        self,
        wuf_collection: Sequence[model.WavesUnderFloe],
        growth_params,
        an_sol: bool,
        num_params,
        linear_curvature: bool | None = None,
    ) -> np.ndarray:
        """Return the elastic energy of each floe of ``wuf_collection``.

        The result is a float array with one entry per floe, in input order.
        """
        energies = np.full(len(wuf_collection), np.nan)
        for i, wuf in enumerate(wuf_collection):
            handler = ph.EnergyHandler.from_wuf(wuf, growth_params)
            energies[i] = handler.compute(an_sol, num_params, linear_curvature)
        return energies

    def _ener_min(
        self,
        length,
        wuf,
        growth_params,
        an_sol,
        num_params,
        linear_curvature: bool | None = None,
    ) -> float:
        """Objective function to minimise for energy-based fracture.

        Returns the natural log of the summed energies of the two fragments
        obtained by splitting ``wuf`` at ``length`` (relative to its left
        edge). The split is done in searching mode.
        """
        # Only the searching-mode split is needed. A non-searching split here
        # would run ``self.scattering_handler`` on every objective evaluation
        # only for the result to be discarded.
        energy_left, energy_right = self.compute_energies(
            self.split(wuf, length, True),
            growth_params,
            an_sol,
            num_params,
            linear_curvature,
        )
        return np.log(energy_left + energy_right)

    def diagnose(
        self,
        wuf: model.WavesUnderFloe,
        res: float = 0.5,
        growth_params=None,
        an_sol=None,
        num_params=None,
        linear_curvature: bool | None = None,
    ):
        """Sample fragment energies over a regular grid of fracture positions.

        Candidate positions are the interior points of a grid with spacing at
        most ``res``. Unlike the search objective, the fragments here are
        built in non-searching mode, so ``self.scattering_handler`` is used.

        Returns a ``_FractureDiag`` holding the candidate positions, an array
        of (left, right) fragment energies of shape ``(n, 2)``, the intact
        floe's energy and the ice's ``frac_energy_rate``.
        """
        lengths = _make_diagnose_array(wuf, res)[1:-1]
        energies = np.full((lengths.size, 2), np.nan)
        initial_energy = ph.EnergyHandler.from_wuf(wuf, growth_params).compute(
            an_sol, num_params, linear_curvature
        )
        for i, length in enumerate(lengths):
            energies[i, :] = self.compute_energies(
                self.split(wuf, length),
                growth_params,
                an_sol,
                num_params,
                linear_curvature,
            )
        return _FractureDiag(
            lengths,
            energies,
            initial_energy,
            wuf.wui.ice.frac_energy_rate,
        )

    def discrete_sweep(
        self,
        wuf,
        an_sol,
        growth_params,
        num_params,
        linear_curvature: bool | None = None,
    ) -> Iterator[tuple[float, float]]:
        """Bracket the local minima of the fracture objective.

        ``_ener_min`` is evaluated on the grid from ``_make_search_array``.
        Consecutive entries of [first grid point, local maxima of the
        objective, last grid point] are returned as ``(lower, upper)``
        bounds. Each pair is intended to enclose one local minimum.

        Note that ``an_sol`` comes before ``growth_params`` here, unlike in
        the other methods.
        """
        lengths = _make_search_array(wuf, self.coef_nd)
        ener = np.full(lengths.shape, np.nan)
        for i, length in enumerate(lengths):
            ener[i] = self._ener_min(
                length, wuf, growth_params, an_sol, num_params, linear_curvature
            )

        peak_idxs = np.hstack(
            (0, signal.find_peaks(ener, distance=2)[0], ener.size - 1)
        )
        return zip(lengths[peak_idxs[:-1]], lengths[peak_idxs[1:]])

    def search(
        self,
        wuf: model.WavesUnderFloe,
        growth_params,
        an_sol,
        num_params,
        linear_curvature: bool | None = None,
    ) -> float | None:
        """Return the energy-minimising fracture position, or None.

        Returns None in any of these cases:

        * the intact floe's energy is below ``frac_energy_rate``;
        * no bounded minimisation succeeded;
        * fracturing at the best position would not lower the total energy.

        Otherwise, returns the fracture position relative to the floe's left
        edge.
        """
        base_handler = ph.EnergyHandler.from_wuf(wuf, growth_params)
        base_energy = base_handler.compute(an_sol, num_params, linear_curvature)
        frac_energy_rate = wuf.wui.ice.frac_energy_rate

        # No fracture if the elastic energy is below the threshold
        if base_energy < frac_energy_rate:
            return None

        # Forward linear_curvature so that the sweep and the refined
        # minimisation use the same energy model as base_energy, which their
        # result is compared against below.
        bounds_iterator = self.discrete_sweep(
            wuf, an_sol, growth_params, num_params, linear_curvature
        )
        local_ener_cost = functools.partial(
            self._ener_min,
            wuf=wuf,
            growth_params=growth_params,
            an_sol=an_sol,
            num_params=num_params,
            linear_curvature=linear_curvature,
        )
        opts = [
            optimize.minimize_scalar(local_ener_cost, bounds=bounds)
            for bounds in bounds_iterator
        ]
        opt = min(
            (opt for opt in opts if opt.success),
            key=lambda opt: opt.fun,
            default=None,
        )
        if opt is None:
            # Every bounded minimisation failed: no reliable fracture position.
            return None
        # Minimisation is done on the log of energy
        if np.exp(opt.fun) + frac_energy_rate < base_energy:
            return opt.x
        return None

217 

218 

@attrs.define
class _StrainFracture(_FractureHandler):
    """Base class for strain-based fracture handlers.

    Candidate fracture positions are the local maxima of the squared strain
    along the floe. They are first bracketed on a discrete grid, then refined
    by bounded scalar minimisation of the negative squared strain.
    """

    def discrete_sweep(
        self, strain_handler, wuf, growth_params, an_sol, num_params
    ) -> Iterator[tuple[float]]:
        """Return ``(lower, upper)`` bounds, each bracketing one strain maximum.

        The squared strain is sampled on the grid from ``_make_search_array``.
        Its local minima, together with both grid ends, delimit the brackets.
        ``growth_params`` is accepted but not used here.
        """
        # NOTE: caveat: some small peaks close to the edges can be missed. The
        # multi-fracture handler is more here as a demo anyway, so this is very
        # low priority.
        grid = _make_search_array(wuf, self.coef_nd)
        sq_strain = strain_handler.compute(grid, an_sol, num_params) ** 2
        troughs = signal.find_peaks(-sq_strain)[0]
        bounds_idx = np.concatenate(([0], troughs, [grid.size - 1]))
        return zip(grid[bounds_idx[:-1]], grid[bounds_idx[1:]])

    def search_peaks(
        self,
        wuf: model.WavesUnderFloe,
        growth_params: tuple | None,
        an_sol: bool,
        num_params: dict | None,
    ) -> Iterator[optimize.OptimizeResult]:
        """Lazily yield the successful refinements of every strain maximum.

        Each yielded result has ``x`` set to the peak position and ``fun`` set
        to the negative squared strain there. The strain handler and the
        discrete sweep are computed eagerly; the minimisations are run lazily,
        as the iterator is consumed.
        """
        strain_handler = ph.StrainHandler.from_wuf(wuf, growth_params)
        bounds_iterator = self.discrete_sweep(
            strain_handler, wuf, growth_params, an_sol, num_params
        )

        def neg_squared_strain(x):
            return -(strain_handler.compute(x, an_sol, num_params) ** 2)

        results = (
            optimize.minimize_scalar(neg_squared_strain, bounds=bounds)
            for bounds in bounds_iterator
        )
        return (found for found in results if found.success)

    def diagnose(
        self,
        wuf: model.WavesUnderFloe,
        res: float = 0.5,
        growth_params: tuple | None = None,
        an_sol: bool = True,
        num_params: dict | None = None,
    ) -> _StrainDiag:
        """Sample the strain on a grid of spacing at most ``res`` and locate its extrema."""
        x = _make_diagnose_array(wuf, res)
        strain_handler = ph.StrainHandler.from_wuf(wuf, growth_params)
        peaks = np.array(
            [
                found.x
                for found in self.search_peaks(
                    wuf, growth_params, an_sol, num_params
                )
            ]
        )
        strain = strain_handler.compute(x, an_sol, num_params)
        peak_strain = strain_handler.compute(peaks, an_sol, num_params)
        return _StrainDiag(
            x=x, strain=strain, peaks=peaks, strain_extrema=peak_strain
        )

270 

271 

@attrs.define(frozen=True)
class BinaryStrainFracture(_StrainFracture):
    """Strain-based binary fracture at the largest strain extremum."""

    def search(
        self,
        wuf: model.WavesUnderFloe,
        growth_params,
        an_sol,
        num_params,
    ) -> float | None:
        """Return the position of the largest strain extremum if it breaks the ice.

        Returns the position (relative to the floe's left edge) of the
        extremum with the largest absolute strain, provided that strain is at
        least ``wuf.wui.ice.strain_threshold``. Returns None otherwise, and
        also when no bounded minimisation succeeded.
        """
        opts = self.search_peaks(wuf, growth_params, an_sol, num_params)
        # ``opts`` only contains successful minimisations and may therefore
        # be empty; a plain min() would raise ValueError in that case.
        opt = min(opts, key=lambda opt: opt.fun, default=None)
        if opt is None:
            return None
        # opt.fun is the negative squared strain at opt.x
        if (-opt.fun) ** 0.5 >= wuf.wui.ice.strain_threshold:
            return opt.x
        return None

286 

287 

@attrs.define(frozen=True)
class MultipleStrainFracture(_StrainFracture):
    """Strain-based fracture at every strain extremum reaching the threshold."""

    def search(
        self,
        wuf: model.WavesUnderFloe,
        growth_params,
        an_sol,
        num_params,
    ) -> list[float] | None:
        """Return all strain-extremum positions that break the ice, or None.

        Positions (relative to the floe's left edge) are returned in the order
        produced by ``search_peaks``. None is returned when no extremum's
        absolute strain reaches ``wuf.wui.ice.strain_threshold``.
        """
        fracture_positions = []
        for candidate in self.search_peaks(wuf, growth_params, an_sol, num_params):
            # candidate.fun is the negative squared strain at candidate.x
            peak_strain = (-candidate.fun) ** 0.5
            if peak_strain >= wuf.wui.ice.strain_threshold:
                fracture_positions.append(candidate.x)
        return fracture_positions if fracture_positions else None