Coverage for src/swiift/lib/numerical.py: 18%

142 statements  

« prev     ^ index     » next       coverage.py v7.9.1, created at 2025-09-11 16:23 +0200

1from __future__ import annotations 

2 

3from collections.abc import Callable 

4import typing 

5import warnings 

6 

7import numpy as np 

8from scipy._lib._util import _RichResult 

9import scipy.integrate as integrate 

10import scipy.interpolate as interpolate 

11 

12from swiift.lib.constants import PI_2 

13 

14from ._ph_utils import _unit_wavefield 

15 

16IV = typing.TypeVar( 

17 "IV", float, np.ndarray[tuple[int], np.dtype[np.float64]] 

18) # Integration variable. 

19 

20CUBIC_BINOMIAL_COEFS = np.array([0, 0, 0, 1, 1, 2]), np.array([1, 2, 3, 2, 3, 3]) 

21 

22 

23def _growth_kernel(x: np.ndarray, mean: np.ndarray, std): 

24 kern = np.ones((mean.size, x.size)) 

25 mask = np.nonzero(x > mean) 

26 kern[mask] = np.exp(-((x - mean) ** 2) / (2 * std**2))[mask] 

27 return kern 

28 

29 

30def free_surface( 

31 x, 

32 wave_params: tuple[np.ndarray, np.ndarray], 

33 growth_params: tuple[np.ndarray, float] | None, 

34) -> np.ndarray: 

35 c_amplitudes, c_wavenumbers = wave_params 

36 wave_shape = _unit_wavefield(x, c_wavenumbers) 

37 if growth_params is not None: 

38 kern = _growth_kernel(np.asarray(x), *growth_params) 

39 wave_shape *= kern 

40 eta = np.imag(c_amplitudes @ wave_shape) 

41 return eta 

42 

43 

44def _ode_system( 

45 x, 

46 w, 

47 *, 

48 floe_params: tuple[float, float], 

49 wave_params: tuple[np.ndarray, np.ndarray], 

50 growth_params: tuple[np.ndarray, float] | None, 

51) -> np.ndarray: 

52 red_num, _ = floe_params 

53 eta = free_surface(x, wave_params, growth_params) 

54 # Factor 4 comes from sqrt(2)**4 

55 wprime = np.vstack((w[1], w[2], w[3], 4 * red_num**4 * (eta - w[0]))) 

56 return wprime 

57 

58 

59def _boundary_conditions(wa, wb): 

60 return np.array((wa[2], wb[2], wa[3], wb[3])) 

61 

62 

63def _solve_bvp( 

64 floe_params, wave_params, growth_params, **kwargs 

65) -> integrate._bvp.BVPResult: 

66 red_num, length = floe_params 

67 wavenumber = np.real(wave_params[1]) 

68 n_mesh = max(5, int(length * max(red_num, wavenumber.max()))) 

69 x0 = np.linspace(0, length, n_mesh) 

70 w0 = np.zeros((4, x0.size)) 

71 

72 opt = integrate.solve_bvp( 

73 lambda x, w: _ode_system( 

74 x, 

75 w, 

76 floe_params=floe_params, 

77 wave_params=wave_params, 

78 growth_params=growth_params, 

79 ), 

80 _boundary_conditions, 

81 x0, 

82 w0, 

83 **kwargs, 

84 ) 

85 return opt 

86 

87 

88def _get_result( 

89 floe_params, wave_params, growth_params, num_params 

90) -> integrate._bvp.BVPResult: 

91 if num_params is None: 

92 num_params = dict() 

93 opt = _solve_bvp(floe_params, wave_params, growth_params, **num_params) 

94 if not opt.success: 

95 warnings.warn("Numerical solution did not converge", stacklevel=2) 

96 return opt 

97 

98 

99def _use_an_sol( 

100 analytical_solution: bool | None, 

101 length: float, 

102 growth_params: tuple | None, 

103 linear_curvature: bool | None, 

104) -> bool: 

105 """Determine whether to use an analytical solution. 

106 

107 The displacement, curvature, and elastic energy have analytical expressions 

108 under certain conditions. These are used if `analytical_solution` is 

109 explicitely set to `True`. Otherwise, the other parameters are examined to 

110 determine if the analytical solutions can (and, therefore, should) be used. 

111 If `growth_params` is not provided, or if all its location values are 

112 greater than `length`, and if `linear_curvature` is not provided or set to 

113 `True`, analytical solutions will be used. If `linear_curvature` is set to 

114 `False`, numerical solutions will be used. 

115 

116 Parameters 

117 ---------- 

118 analytical_solution : bool, optional 

119 Set to `True` to force using analytical solutions. 

120 length : float 

121 Length of the floe. 

122 growth_params : tuple, optional 

123 Parameters of a wave growth kernel. 

124 linear_curvature : bool, optional 

125 Set to `False` to force using numerical approximations to the 

126 non-linear curvature. It has no effect for other variables 

127 (displacement, elastic energy) but *does* force using numerical 

128 solutions instead of analytical solutions, if set to `False`. 

129 

130 Returns 

131 ------- 

132 bool 

133 

134 """ 

135 if analytical_solution is not None: 

136 return analytical_solution 

137 if growth_params is None: 

138 if linear_curvature is None: 

139 return True 

140 # No analytical solution for non-linear curvature 

141 return linear_curvature 

142 else: 

143 if linear_curvature is None or linear_curvature: 

144 # If the wave growth kernel mean is to the right of the floe 

145 # for every wave component, the wave is fully developed 

146 # and the analytical solution can be used. 

147 # Alternatively, if there is a single component of the kernel whose 

148 # mean is to the left of the right edge, the numerical solution 

149 # must be used. 

150 return not np.any(growth_params[0] < length) 

151 return False 

152 

153 

154def _extract_from_poly(sol: interpolate.PPoly, n: int) -> interpolate.PPoly: 

155 return interpolate.PPoly(sol.c[:, :, n], sol.x, extrapolate=False) 

156 

157 

158def _extract_dis_poly(sol: interpolate.PPoly) -> interpolate.PPoly: 

159 return _extract_from_poly(sol, 0) 

160 

161 

162def _non_lin_curv(sol: interpolate.PPoly) -> Callable[[IV], np.ndarray]: 

163 def non_lin_curv(x: IV) -> np.ndarray: 

164 return ( 

165 _extract_from_poly(sol, 2)(x) 

166 / (1 + _extract_from_poly(sol, 1)(x) ** 2) ** 1.5 

167 ) 

168 

169 return non_lin_curv 

170 

171 

172@typing.overload 

173def _extract_cur_poly( 

174 sol: interpolate.PPoly, 

175 is_linear: typing.Literal[True] = ..., 

176) -> interpolate.PPoly: ... 

177 

178 

179@typing.overload 

180def _extract_cur_poly( 

181 sol: interpolate.PPoly, 

182 is_linear: typing.Literal[False] = ..., 

183) -> Callable[[IV], np.ndarray]: ... 

184 

185 

186@typing.overload 

187def _extract_cur_poly( 

188 sol: interpolate.PPoly, is_linear: bool = ... 

189) -> interpolate.PPoly | Callable[[IV], np.ndarray]: ... 

190 

191 

192def _extract_cur_poly(sol: interpolate.PPoly, is_linear: bool = True): 

193 if is_linear: 

194 return _extract_from_poly(sol, 2) 

195 else: 

196 return _non_lin_curv(sol) 

197 

198 

199def displacement(x, floe_params, wave_params, growth_params, num_params): 

200 opt = _get_result(floe_params, wave_params, growth_params, num_params) 

201 return _extract_dis_poly(opt.sol)(x) 

202 

203 

204def curvature( 

205 x, floe_params, wave_params, growth_params, num_params, is_linear: bool = True 

206): 

207 opt = _get_result(floe_params, wave_params, growth_params, num_params) 

208 return _extract_cur_poly(opt.sol, is_linear)(x) 

209 

210 

211@typing.overload 

212def _prepare_integrand0( 

213 floe_params: tuple[float, float], 

214 wave_params: tuple[np.ndarray, np.ndarray], 

215 growth_params, 

216 num_params, 

217 linear_curvature: typing.Literal[True], 

218) -> tuple[interpolate.PPoly, tuple[float, float]]: ... 

219 

220 

221@typing.overload 

222def _prepare_integrand0( 

223 floe_params: tuple[float, float], 

224 wave_params: tuple[np.ndarray, np.ndarray], 

225 growth_params, 

226 num_params, 

227 linear_curvature: typing.Literal[False], 

228) -> tuple[typing.Callable[[IV], np.ndarray], tuple[float, float]]: ... 

229 

230 

231@typing.overload 

232def _prepare_integrand0( 

233 floe_params: tuple[float, float], 

234 wave_params: tuple[np.ndarray, np.ndarray], 

235 growth_params, 

236 num_params, 

237 linear_curvature: bool, 

238) -> tuple[ 

239 interpolate.PPoly | typing.Callable[[IV], np.ndarray], tuple[float, float] 

240]: ... 

241 

242 

243def _prepare_integrand0( 

244 floe_params: tuple[float, float], 

245 wave_params: tuple[np.ndarray, np.ndarray], 

246 growth_params, 

247 num_params, 

248 linear_curvature: bool, 

249): 

250 opt = _get_result(floe_params, wave_params, growth_params, num_params) 

251 curvature_poly = _extract_cur_poly(opt.sol, linear_curvature) 

252 bounds = opt.x[0], opt.x[-1] 

253 return curvature_poly, bounds 

254 

255 

256def _square_cubic_poly(ppoly: interpolate.PPoly) -> interpolate.PPoly: 

257 # PPoly object have coefficients ordered opposite wrt to powers. That is, 

258 # for a cubic, c[0] is the coefficient of the cubic term, c[3] of the 

259 # constant term. 

260 new_cs = np.zeros((ppoly.c.shape[0] * 2 - 1, ppoly.c.shape[1])) 

261 cs = ppoly.c 

262 idx1, idx2 = CUBIC_BINOMIAL_COEFS 

263 new_cs[::2, :] = cs**2 

264 extra_terms = 2 * cs[idx1] * cs[idx2] 

265 

266 # Need to do them one by one to avoid silent errors, as idx1 + idx2 has 

267 # values with multiplicity > 1. 

268 for idx, term in zip(idx1 + idx2, extra_terms): 

269 new_cs[idx] += term 

270 

271 return interpolate.PPoly(new_cs, ppoly.x) 

272 

273 

274def _pseudo_analytical_integration( 

275 floe_params: tuple[float, float], 

276 wave_params: tuple[np.ndarray, np.ndarray], 

277 growth_params: tuple | None, 

278 num_params: dict, 

279) -> float: 

280 # TODO: for now, only works with linear curvature. Could be adapted to 

281 # nonlinear curvature, it would just require the partial fraction 

282 # decomposition of the ratio of two 6th order polynomials. 

283 curvature_poly, bounds = _prepare_integrand0( 

284 floe_params, wave_params, growth_params, num_params, True 

285 ) 

286 # TODO: integral could be computed manually, without building the PPoly 

287 # object first. 

288 squared_curvature = _square_cubic_poly(curvature_poly) 

289 bounds = 0, floe_params[1] 

290 return squared_curvature.integrate(*bounds).item() 

291 

292 

293def _prepare_integrand( 

294 floe_params: tuple[float, float], 

295 wave_params: tuple[np.ndarray, np.ndarray], 

296 growth_params: tuple | None, 

297 num_params: dict, 

298 linear_curvature: bool, 

299) -> tuple[Callable[[IV], IV], tuple[float, float]]: 

300 curvature_poly, bounds = _prepare_integrand0( 

301 floe_params, 

302 wave_params, 

303 growth_params, 

304 num_params, 

305 linear_curvature, 

306 ) 

307 

308 def unit_energy(x): 

309 return curvature_poly(x) ** 2 

310 

311 return unit_energy, bounds 

312 

313 

314def _estimate_quad_limit( 

315 floe_length: float, wave_params: tuple[np.ndarray, np.ndarray] 

316) -> int: 

317 # wave_params := complex amplitudes, complex wavenumbers 

318 # When using `quad`, there might be convergence issues if the integrand 

319 # observes many oscillations wrt the range of integration, that is the 

320 # length of the floe. A usual heuristic seems to be fixing `limit` to L / 

321 # lambda * N with N between 10 to 20. Choosing a big number doesn't hurt 

322 # computing time, as the integration stops when reaching the desired 

323 # tolerance anyway. 

324 factor = 20 / PI_2 # high N, scaled by 2pi to get a wavelength 

325 # We arbitrarily choose highest wavenumber. 

326 highest_wave_number = np.real(wave_params[1]).max() 

327 # 50 is the default 

328 return max(50, np.ceil(factor * floe_length * highest_wave_number).astype(int)) 

329 

330 

331@typing.overload 

332def _quad_integration( 

333 integrand: Callable[[float], float], 

334 bounds: tuple[float, float], 

335 limit: int, 

336 debug: typing.Literal[True], 

337 **kwargs, 

338) -> tuple[float, float]: ... 

339 

340 

341@typing.overload 

342def _quad_integration( 

343 integrand: Callable[[float], float], 

344 bounds: tuple[float, float], 

345 limit: int, 

346 debug: typing.Literal[False], 

347 **kwargs, 

348) -> float: ... 

349 

350 

351@typing.overload 

352def _quad_integration( 

353 integrand: Callable[[float], float], 

354 bounds: tuple[float, float], 

355 limit: int, 

356 debug: bool, 

357 **kwargs, 

358) -> float | tuple[float, float]: ... 

359 

360 

361def _quad_integration( 

362 integrand: Callable[[float], float], 

363 bounds: tuple[float, float], 

364 limit: int, 

365 debug: bool = False, 

366 **kwargs, 

367) -> float | tuple[float, float]: 

368 result = integrate.quad(integrand, *bounds, limit=limit, **kwargs) 

369 if debug: 

370 return result 

371 return result[0] 

372 

373 

374@typing.overload 

375def _tanhsinh_integration( 

376 integrand: Callable[[np.ndarray], np.ndarray], 

377 bounds: tuple[float, float], 

378 debug: typing.Literal[True], 

379 **kwargs, 

380) -> _RichResult[float]: ... 

381 

382 

383@typing.overload 

384def _tanhsinh_integration( 

385 integrand: Callable[[np.ndarray], np.ndarray], 

386 bounds: tuple[float, float], 

387 debug: typing.Literal[False], 

388 **kwargs, 

389) -> float: ... 

390 

391 

392@typing.overload 

393def _tanhsinh_integration( 

394 integrand: Callable[[np.ndarray], np.ndarray], 

395 bounds: tuple[float, float], 

396 debug: bool, 

397 **kwargs, 

398) -> float | _RichResult[float]: ... 

399 

400 

401def _tanhsinh_integration( 

402 integrand: Callable[[np.ndarray], np.ndarray], 

403 bounds: tuple[float, float], 

404 debug: bool = False, 

405 **kwargs, 

406) -> float | _RichResult[float]: 

407 default_quad_tol = 1.49e-8 

408 for key in ("atol", "rtol"): 

409 if key not in kwargs: 

410 kwargs[key] = default_quad_tol 

411 try: 

412 result = integrate.tanhsinh(integrand, *bounds, **kwargs) 

413 except AttributeError: 

414 warnings.warn( 

415 "tanhsinh integration was made public in scipy 1.15.0. " 

416 "Proceeding anyway, but you might want to upgrade if possible, " 

417 "or use another integration method." 

418 ) 

419 result = integrate._tanhsinh.tanhsinh(integrand, *bounds, **kwargs) 

420 if debug: 

421 return result 

422 

423 return result.integral 

424 

425 

426# TODO: improve docstring 

427def unit_energy( 

428 floe_params: tuple[float, float], 

429 wave_params: tuple[np.ndarray, np.ndarray], 

430 growth_params, 

431 num_params, 

432 integration_method: str | None = None, 

433 linear_curvature: bool = True, 

434 **kwargs, 

435) -> float: 

436 """Numerically evaluate the energy. 

437 

438 The energy is up to a prefactor. 

439 

440 """ 

441 if not linear_curvature and integration_method == "pseudo_an": 

442 warnings.warn( 

443 f"The method {integration_method} can only be used with linear curvature. " 

444 "Using tanhsinh instead." 

445 ) 

446 integration_method = "tanhsinh" 

447 

448 if integration_method is None: 

449 if linear_curvature: 

450 integration_method = "pseudo_an" 

451 else: 

452 if hasattr(integrate, "tanhsinh"): 

453 integration_method = "tanhsinh" 

454 else: 

455 integration_method = "quad" 

456 

457 if integration_method == "pseudo_an": 

458 return _pseudo_analytical_integration( 

459 floe_params, wave_params, growth_params, num_params 

460 ) 

461 

462 integrand, bounds = _prepare_integrand( 

463 floe_params, 

464 wave_params, 

465 growth_params, 

466 num_params, 

467 linear_curvature, 

468 ) 

469 

470 if integration_method == "quad": 

471 limit = kwargs.pop("limit", None) 

472 if limit is None: 

473 limit = _estimate_quad_limit(floe_params[1], wave_params) 

474 return _quad_integration(integrand, bounds, limit=limit, **kwargs) 

475 elif integration_method == "tanhsinh": 

476 return _tanhsinh_integration(integrand, bounds, **kwargs) 

477 else: 

478 raise ValueError( 

479 "Integration method should be `pseudo_an`, `quad`, or `tanhsinh`." 

480 )