Coverage for src/flexfrac1d/model/frac_handlers.py: 89%

119 statements  

« prev     ^ index     » next       coverage.py v7.4.1, created at 2024-08-30 14:00 +0200

1import abc 

2from collections.abc import Iterator, Sequence 

3import functools 

4from numbers import Real 

5 

6import attrs 

7import numpy as np 

8import scipy.optimize as optimize 

9import scipy.signal as signal 

10 

11from . import model 

12from ..lib import phase_shift as ps, physics as ph 

13from ..lib.constants import PI_2 

14 

15 

def _make_search_array(wuf: model.WavesUnderFloe, coef: int):
    """Return the interior points of a uniform grid over the floe.

    The base point count is ``ceil(4 * length * k_max / PI_2) + 2``, where
    ``k_max`` is the largest wavenumber of ``wuf.wui``. It is multiplied by
    ``coef`` to get the grid size, and the two end points (0 and
    ``wuf.length``) are dropped from the returned array.
    """
    k_max = wuf.wui.wavenumbers.max()
    n_base = np.ceil(4 * wuf.length * k_max / PI_2).astype(int) + 2
    grid = np.linspace(0, wuf.length, n_base * coef)
    # Interior points only: fracturing exactly at an edge is meaningless.
    return grid[1:-1]

19 

20 

def _make_diagnose_array(wuf: model.WavesUnderFloe, res: float):
    """Return evenly spaced points covering ``[0, wuf.length]``.

    The number of points is chosen so that consecutive points are at most
    ``res`` apart. Both end points are included.
    """
    n_points = np.ceil(wuf.length / res).astype(int) + 1
    return np.linspace(0, wuf.length, n_points)

23 

24 

@attrs.frozen
class _StrainDiag:
    """Immutable container for strain-based fracture diagnostics."""

    # Positions along the floe where the strain was sampled.
    x: np.ndarray
    # Strain evaluated at ``x``.
    strain: np.ndarray
    # Positions of the strain extrema found by the bounded search.
    peaks: np.ndarray
    # Strain evaluated at ``peaks``.
    strain_extrema: np.ndarray

31 

32 

@attrs.frozen
class _FractureDiag:
    """Immutable container for energy-based fracture diagnostics."""

    # Candidate fracture positions, measured from the floe's left edge.
    x: np.ndarray
    # Energies of the (left, right) fragments for each candidate in ``x``.
    energy: np.ndarray
    # Energy of the unbroken floe.
    initial_energy: float
    # Fracture energy rate of the ice (``wuf.wui.ice.frac_energy_rate``).
    frac_energy_rate: float

39 

40 

@attrs.define
class _FractureHandler(abc.ABC):
    """Base class for fracture handlers.

    Attributes
    ----------
    coef_nd : int
        Multiplier applied to the base number of points of the search grid
        built by ``_make_search_array``.
    scattering_handler : ps.ScatteringHandler
        Handler used by :meth:`split` to compute the edge amplitudes of the
        fragments created by an actual (non-search) fracture.
    """

    coef_nd: int = 4
    scattering_handler: ps.ScatteringHandler = attrs.field(
        factory=ps.ContinuousScatteringHandler
    )

    def split(
        self,
        wuf: model.WavesUnderFloe,
        xf: Real | np.ndarray,
        is_searching: bool = False,
    ) -> list[model.WavesUnderFloe]:
        """Split a floe at one or more fracture positions.

        Parameters
        ----------
        wuf : model.WavesUnderFloe
            Floe to split.
        xf : Real or np.ndarray
            Fracture position(s), measured from the floe's left edge.
            Presumably increasing and strictly inside ``(0, wuf.length)``;
            this is not checked here.
        is_searching : bool, default False
            If True, the edge amplitudes of every fragment are computed with
            ``ps.ContinuousScatteringHandler``. This path is used while
            evaluating candidate positions. Otherwise the leftmost fragment
            keeps the parent's edge amplitudes and the others are computed
            with ``self.scattering_handler``.

        Returns
        -------
        list[model.WavesUnderFloe]
            Fragments ordered from left to right. Every fragment except the
            rightmost gets generation ``wuf.generation + 1``; the rightmost
            keeps ``wuf.generation``.
        """
        # Prepend 0 so that xf holds the left edge of every fragment,
        # relative to the parent's left edge.
        xf = np.hstack((0, xf))
        lengths = np.ediff1d(np.hstack((xf, wuf.length)))
        edges = wuf.left_edge + xf

        if is_searching:
            post_breakup_amplitudes = (
                ps.ContinuousScatteringHandler.compute_edge_amplitudes(
                    wuf.edge_amplitudes, wuf.wui._c_wavenumbers, xf
                )
            )
        else:
            post_breakup_amplitudes = np.full(
                (xf.size, wuf.edge_amplitudes.size), np.nan, dtype=complex
            )
            # The leftmost fragment shares the parent's left edge, so it
            # inherits the parent's amplitudes unchanged.
            post_breakup_amplitudes[0] = wuf.edge_amplitudes.copy()
            post_breakup_amplitudes[1:] = (
                self.scattering_handler.compute_edge_amplitudes(
                    wuf.edge_amplitudes, wuf.wui._c_wavenumbers, xf[1:]
                )
            )

        gens = wuf.generation * np.ones(xf.size, dtype=int)
        # NOTE(review): the rightmost fragment keeps the parent generation;
        # confirm this asymmetry is intended.
        gens[:-1] += 1
        return [
            model.WavesUnderFloe(
                left_edge=edge,
                length=lgth,
                wui=wuf.wui,
                edge_amplitudes=amplitudes,
                generation=gen,
            )
            for edge, lgth, amplitudes, gen in zip(
                edges, lengths, post_breakup_amplitudes, gens
            )
        ]

    @abc.abstractmethod
    def search(self, wuf: model.WavesUnderFloe, growth_params, an_sol, num_params):
        """Return fracture position(s) for ``wuf``, or None if it does not break."""
        raise NotImplementedError

95 

96 

@attrs.define(frozen=True)
class BinaryFracture(_FractureHandler):
    """Energy-based fracture producing at most one break per search.

    A floe breaks at the position that minimises the summed energies of the
    two resulting fragments, provided that sum plus the ice fracture energy
    rate is below the energy of the unbroken floe.
    """

    def compute_energies(
        self,
        wuf_collection: Sequence[model.WavesUnderFloe],
        growth_params,
        an_sol: bool,
        num_params,
    ) -> np.ndarray:
        """Return a 1-D float array with the energy of each floe.

        Each energy is computed with ``ph.EnergyHandler.from_wuf(...)``.
        """
        energies = np.full(len(wuf_collection), np.nan)
        for i, wuf in enumerate(wuf_collection):
            handler = ph.EnergyHandler.from_wuf(wuf, growth_params)
            energies[i] = handler.compute(an_sol, num_params)
        return energies

    def _ener_min(self, length, wuf, growth_params, an_sol, num_params) -> float:
        """Objective function to minimise for energy-based fracture.

        The floe is split at ``length`` in search mode (continuous
        scattering). The function returns the natural log of the summed
        energies of the two fragments.
        """
        # Only the search-mode split is needed; the previous extra call to
        # ``self.split(wuf, length)`` was computed and discarded on every
        # objective evaluation.
        energy_left, energy_right = self.compute_energies(
            self.split(wuf, length, True), growth_params, an_sol, num_params
        )
        return np.log(energy_left + energy_right)

    def diagnose(
        self,
        wuf: model.WavesUnderFloe,
        res: float = 0.5,
        growth_params=None,
        an_sol=False,
        num_params=None,
    ) -> _FractureDiag:
        """Tabulate fragment energies for fracture positions along the floe.

        Candidate positions are the interior points of a grid with spacing at
        most ``res``. For each candidate, the energies of the left and right
        fragments are stored in an ``(n, 2)`` array.
        """
        lengths = _make_diagnose_array(wuf, res)[1:-1]
        energies = np.full((lengths.size, 2), np.nan)
        # A scalar energy: the former trailing comma wrapped it in a 1-tuple,
        # contradicting ``_FractureDiag.initial_energy: float``.
        initial_energy = ph.EnergyHandler.from_wuf(wuf, growth_params).compute(
            an_sol, num_params
        )
        for i, length in enumerate(lengths):
            # NOTE(review): unlike ``_ener_min``, this uses the non-search
            # split (``self.scattering_handler``); confirm this is intended.
            energies[i, :] = self.compute_energies(
                self.split(wuf, length), growth_params, an_sol, num_params
            )
        return _FractureDiag(
            lengths,
            energies,
            initial_energy,
            wuf.wui.ice.frac_energy_rate,
        )

    def discrete_sweep(
        self, wuf, an_sol, growth_params, num_params
    ) -> Iterator[tuple[float]]:
        """Return ``(lower, upper)`` intervals for the bounded minimisation.

        The objective ``_ener_min`` is evaluated on the search grid. Its local
        maxima (peaks at least 2 samples apart), plus the first and last grid
        points, delimit consecutive intervals, each bracketing a minimum.
        """
        lengths = _make_search_array(wuf, self.coef_nd)
        ener = np.full(lengths.shape, np.nan)
        for i, length in enumerate(lengths):
            ener[i] = self._ener_min(length, wuf, growth_params, an_sol, num_params)

        peak_idxs = np.hstack(
            (0, signal.find_peaks(ener, distance=2)[0], ener.size - 1)
        )
        return zip(lengths[peak_idxs[:-1]], lengths[peak_idxs[1:]])

    def search(
        self, wuf: model.WavesUnderFloe, growth_params, an_sol, num_params
    ) -> float | None:
        """Return the fracture position of ``wuf``, or None if it does not break."""
        base_handler = ph.EnergyHandler.from_wuf(wuf, growth_params)
        base_energy = base_handler.compute(an_sol, num_params)
        frac_energy_rate = wuf.wui.ice.frac_energy_rate

        # No fracture if the elastic energy is below the threshold
        if base_energy < frac_energy_rate:
            return None

        bounds_iterator = self.discrete_sweep(wuf, an_sol, growth_params, num_params)
        local_ener_cost = functools.partial(
            self._ener_min,
            wuf=wuf,
            growth_params=growth_params,
            an_sol=an_sol,
            num_params=num_params,
        )
        # Request the bounded method explicitly so the bounds are honoured
        # regardless of SciPy's method default.
        opts = [
            optimize.minimize_scalar(local_ener_cost, bounds=bounds, method="bounded")
            for bounds in bounds_iterator
        ]
        # If every optimisation failed there is no candidate; previously
        # min() raised ValueError on the empty sequence.
        opt = min(
            (opt for opt in opts if opt.success),
            key=lambda opt: opt.fun,
            default=None,
        )
        if opt is None:
            return None
        # Minimisation is done on the log of energy
        if np.exp(opt.fun) + frac_energy_rate < base_energy:
            return opt.x
        return None

186 

187 

@attrs.define
class _StrainFracture(_FractureHandler):
    """Base class for strain-based fracture handlers.

    Candidate fracture positions are local maxima of the squared strain. They
    are located by bounded minimisation of ``-strain**2`` between consecutive
    minima of ``strain**2`` sampled on a coarse grid.
    """

    def discrete_sweep(
        self, strain_handler, wuf, growth_params, an_sol, num_params
    ) -> Iterator[tuple[float]]:
        """Return ``(lower, upper)`` intervals bracketing squared-strain maxima.

        The strain is sampled on the grid from ``_make_search_array``. Local
        minima of ``strain**2``, plus the first and last grid points, delimit
        consecutive intervals. ``growth_params`` is not used here.
        """
        x = _make_search_array(wuf, self.coef_nd)
        strain = strain_handler.compute(x, an_sol, num_params)
        # Minima of strain**2 are found as peaks of its negation.
        peak_idxs = np.hstack((0, signal.find_peaks(-(strain**2))[0], x.size - 1))
        return zip(x[peak_idxs[:-1]], x[peak_idxs[1:]])

    def search_peaks(
        self,
        wuf: model.WavesUnderFloe,
        growth_params: tuple | None,
        an_sol: bool,
        num_params: dict | None,
    ) -> Iterator[optimize.OptimizeResult]:
        """Lazily yield the successful bounded minimisations of ``-strain**2``.

        There is one minimisation per interval from :meth:`discrete_sweep`.
        For each yielded result, ``opt.x`` is the position and ``opt.fun``
        is ``-strain**2`` at that position.
        """
        strain_handler = ph.StrainHandler.from_wuf(wuf, growth_params)
        bounds_iterator = self.discrete_sweep(
            strain_handler, wuf, growth_params, an_sol, num_params
        )
        # Request the bounded method explicitly so the bounds are honoured
        # regardless of SciPy's method default.
        opts = (
            optimize.minimize_scalar(
                lambda x: -strain_handler.compute(x, an_sol, num_params) ** 2,
                bounds=bounds,
                method="bounded",
            )
            for bounds in bounds_iterator
        )
        return filter(lambda opt: opt.success, opts)

    def diagnose(
        self,
        wuf: model.WavesUnderFloe,
        res: float = 0.5,
        growth_params: tuple | None = None,
        an_sol: bool = True,
        num_params: dict | None = None,
    ) -> _StrainDiag:
        """Return the strain along the floe together with its located extrema.

        The strain is sampled on a grid with spacing at most ``res``. It is
        also evaluated at the positions found by :meth:`search_peaks`.
        """
        x = _make_diagnose_array(wuf, res)
        strain_handler = ph.StrainHandler.from_wuf(wuf, growth_params)
        opts = self.search_peaks(wuf, growth_params, an_sol, num_params)
        peaks = np.array([opt.x for opt in opts])
        return _StrainDiag(
            x,
            strain_handler.compute(x, an_sol, num_params),
            peaks,
            strain_handler.compute(peaks, an_sol, num_params),
        )

236 

237 

@attrs.define(frozen=True)
class BinaryStrainFracture(_StrainFracture):
    """Strain-based fracture producing at most one break per search."""

    def search(
        self,
        wuf: model.WavesUnderFloe,
        growth_params,
        an_sol,
        num_params,
    ) -> float | None:
        """Return the position of the largest strain, if it reaches the threshold.

        Returns None when no bounded optimisation succeeded, or when the
        largest ``|strain|`` found is below ``wuf.wui.ice.strain_threshold``.
        """
        opts = self.search_peaks(wuf, growth_params, an_sol, num_params)
        # ``default=None`` avoids a ValueError when every optimisation failed.
        opt = min(opts, key=lambda opt: opt.fun, default=None)
        if opt is None:
            return None
        # opt.fun is -strain**2, so this recovers |strain| at the peak.
        if (-opt.fun) ** 0.5 >= wuf.wui.ice.strain_threshold:
            return opt.x
        return None

252 

253 

@attrs.define(frozen=True)
class MultipleStrainFracture(_StrainFracture):
    """Strain-based fracture that may break a floe at several positions."""

    def search(
        self,
        wuf: model.WavesUnderFloe,
        growth_params,
        an_sol,
        num_params,
    ) -> list[float] | None:
        """Return every located strain peak reaching the strain threshold.

        Positions are returned in the order produced by ``search_peaks``. If
        no peak reaches ``wuf.wui.ice.strain_threshold``, None is returned
        instead of an empty list.
        """
        positions = []
        for opt in self.search_peaks(wuf, growth_params, an_sol, num_params):
            # opt.fun is -strain**2; compare |strain| to the threshold.
            if (-opt.fun) ** 0.5 >= wuf.wui.ice.strain_threshold:
                positions.append(opt.x)
        return positions if positions else None