Coverage for src/swiift/lib/numerical.py: 18%
142 statements
« prev ^ index » next coverage.py v7.9.1, created at 2025-09-11 16:23 +0200
1from __future__ import annotations
3from collections.abc import Callable
4import typing
5import warnings
7import numpy as np
8from scipy._lib._util import _RichResult
9import scipy.integrate as integrate
10import scipy.interpolate as interpolate
12from swiift.lib.constants import PI_2
14from ._ph_utils import _unit_wavefield
# Integration variable: a scalar abscissa or a 1-D float64 array of abscissae.
IV = typing.TypeVar("IV", float, np.ndarray[tuple[int], np.dtype[np.float64]])

# Index pairs (i, j) with i < j over the four coefficients of a cubic. When a
# cubic PPoly is squared, each cross term 2 * c[i] * c[j] contributes to the
# squared-coefficient row i + j.
CUBIC_BINOMIAL_COEFS = (
    np.array([0, 0, 0, 1, 1, 2]),
    np.array([1, 2, 3, 2, 3, 3]),
)
def _growth_kernel(x: np.ndarray, mean: np.ndarray, std):
    """Gaussian ramp damping each wave component beyond its kernel location.

    For every component ``i`` and abscissa ``x[j]`` the kernel equals 1 where
    ``x[j] <= mean[i]`` and ``exp(-(x[j] - mean[i])**2 / (2 * std**2))``
    where ``x[j] > mean[i]``.

    Parameters
    ----------
    x : np.ndarray
        Abscissae; a scalar (0-d) or 1-D array.
    mean : np.ndarray
        Kernel location of each wave component. It is flattened to one value
        per component, so both a 1-D array and a column of shape ``(n, 1)``
        are accepted.
    std : float
        Width of the Gaussian decay.

    Returns
    -------
    np.ndarray
        Kernel of shape ``(mean.size, x.size)``.
    """
    x = np.ravel(x)
    # Column vector: comparisons then broadcast to (n_components, n_x), even
    # when a 1-D `mean` is given (which previously broke the broadcast).
    mean = np.reshape(mean, (-1, 1))
    kern = np.ones((mean.size, x.size))
    ahead = x > mean
    kern[ahead] = np.exp(-((x - mean) ** 2) / (2 * std**2))[ahead]
    return kern
def free_surface(
    x,
    wave_params: tuple[np.ndarray, np.ndarray],
    growth_params: tuple[np.ndarray, float] | None,
) -> np.ndarray:
    """Evaluate the free-surface elevation at `x`.

    The elevation is the imaginary part of the unit wave fields weighted by
    the complex amplitudes. When `growth_params` is given, every component is
    first multiplied by the growth kernel.

    Parameters
    ----------
    x : float or np.ndarray
        Abscissae at which to evaluate the surface.
    wave_params : tuple of np.ndarray
        Complex amplitudes and complex wavenumbers of the wave components.
    growth_params : tuple or None
        ``(mean, std)`` forwarded to the growth kernel, or None for no kernel.

    Returns
    -------
    np.ndarray
    """
    amplitudes, wavenumbers = wave_params
    components = _unit_wavefield(x, wavenumbers)
    if growth_params is not None:
        # In-place product on the array returned by `_unit_wavefield`.
        components *= _growth_kernel(np.asarray(x), *growth_params)
    return np.imag(amplitudes @ components)
def _ode_system(
    x,
    w,
    *,
    floe_params: tuple[float, float],
    wave_params: tuple[np.ndarray, np.ndarray],
    growth_params: tuple[np.ndarray, float] | None,
) -> np.ndarray:
    """Right-hand side of the first-order system handed to `solve_bvp`.

    Rows of `w` are the displacement and its first three derivatives. The
    fourth derivative is driven by the gap between the free surface ``eta``
    and the displacement: ``w'''' = 4 * red_num**4 * (eta - w)``.
    """
    red_num, _length = floe_params
    surface = free_surface(x, wave_params, growth_params)
    # Factor 4 comes from sqrt(2)**4
    fourth_derivative = 4 * red_num**4 * (surface - w[0])
    return np.vstack((w[1], w[2], w[3], fourth_derivative))
def _boundary_conditions(wa, wb):
    """Boundary residuals for `solve_bvp` at the left (`wa`) and right (`wb`) ends.

    The second and third components of the state (second and third
    derivatives of the displacement) must vanish at both ends, which is
    presumably the free-edge condition (no bending moment, no shear).
    Residual order: ``wa[2], wb[2], wa[3], wb[3]``.
    """
    residuals = [edge[order] for order in (2, 3) for edge in (wa, wb)]
    return np.array(residuals)
def _solve_bvp(
    floe_params, wave_params, growth_params, **kwargs
) -> integrate._bvp.BVPResult:
    """Solve the boundary value problem defined by `_ode_system`.

    The initial mesh is uniform on ``[0, length]``. It has at least five
    nodes, and otherwise ``int(length * max(red_num, max Re(k)))`` nodes.
    The initial guess is identically zero. Extra keyword arguments are
    forwarded to `scipy.integrate.solve_bvp`.
    """
    red_num, length = floe_params
    largest_wavenumber = np.real(wave_params[1]).max()
    node_count = max(5, int(length * max(red_num, largest_wavenumber)))
    mesh = np.linspace(0, length, node_count)
    initial_guess = np.zeros((4, mesh.size))

    def rhs(x, w):
        return _ode_system(
            x,
            w,
            floe_params=floe_params,
            wave_params=wave_params,
            growth_params=growth_params,
        )

    return integrate.solve_bvp(
        rhs, _boundary_conditions, mesh, initial_guess, **kwargs
    )
def _get_result(
    floe_params, wave_params, growth_params, num_params
) -> integrate._bvp.BVPResult:
    """Solve the BVP, warning when the solver reports failure.

    `num_params` is a mapping of `solve_bvp` options, or None for defaults.
    The result is returned even when it did not converge.
    """
    solver_options = dict() if num_params is None else num_params
    opt = _solve_bvp(floe_params, wave_params, growth_params, **solver_options)
    if opt.success:
        return opt
    warnings.warn("Numerical solution did not converge", stacklevel=2)
    return opt
def _use_an_sol(
    analytical_solution: bool | None,
    length: float,
    growth_params: tuple | None,
    linear_curvature: bool | None,
) -> bool:
    """Determine whether to use an analytical solution.

    The displacement, curvature, and elastic energy have analytical
    expressions under certain conditions. An explicit `analytical_solution`
    always takes precedence. Otherwise, analytical solutions are used when
    `linear_curvature` is not provided or is true, and the wave field is
    fully developed over the floe. That is, `growth_params` is not provided,
    or all its location values are greater than or equal to `length`. If
    `linear_curvature` is set to `False`, numerical solutions are used.

    Parameters
    ----------
    analytical_solution : bool, optional
        Set to `True` (or `False`) to force (or forbid) analytical solutions.
    length : float
        Length of the floe.
    growth_params : tuple, optional
        Parameters of a wave growth kernel; its first item holds the kernel
        location of each wave component.
    linear_curvature : bool, optional
        Set to `False` to force using numerical approximations to the
        non-linear curvature. It has no effect for other variables
        (displacement, elastic energy) but *does* force using numerical
        solutions instead of analytical solutions, if set to `False`.

    Returns
    -------
    bool
    """
    if analytical_solution is not None:
        return analytical_solution
    if growth_params is None:
        # No growth kernel: only a request for non-linear curvature, which
        # has no analytical solution, rules the analytical path out.
        return True if linear_curvature is None else linear_curvature
    if linear_curvature is not None and not linear_curvature:
        return False
    # Analytical only if no kernel location lies strictly left of the right
    # edge, i.e. every component is fully developed over the whole floe.
    return not np.any(growth_params[0] < length)
def _extract_from_poly(sol: interpolate.PPoly, n: int) -> interpolate.PPoly:
    """Return component `n` of a vector-valued PPoly as a scalar PPoly.

    Assumes ``sol.c`` is laid out as ``(order, n_intervals, n_components)``,
    as for `solve_bvp` solutions. The result does not extrapolate:
    evaluating it outside its breakpoints yields NaN.
    """
    component_coefs = sol.c[:, :, n]
    breakpoints = sol.x
    return interpolate.PPoly(component_coefs, breakpoints, extrapolate=False)
def _extract_dis_poly(sol: interpolate.PPoly) -> interpolate.PPoly:
    """Return the displacement (state component 0) as a non-extrapolating PPoly."""
    displacement_index = 0
    return _extract_from_poly(sol, displacement_index)
def _non_lin_curv(sol: interpolate.PPoly) -> Callable[[IV], np.ndarray]:
    """Build the non-linear curvature ``w'' / (1 + w'**2) ** 1.5``.

    State components 1 and 2 of the BVP solution are the first and second
    derivatives of the displacement (see the first-order system passed to
    `solve_bvp`). Both are extracted once, when the function is built. The
    returned callable is evaluated many times by quadrature routines, so
    re-extracting them on every call was wasted work.

    Parameters
    ----------
    sol : interpolate.PPoly
        Vector-valued solution of the BVP.

    Returns
    -------
    Callable
        Function of `x` returning the non-linear curvature (NaN outside the
        solution's breakpoints, as the extracted polynomials do not
        extrapolate).
    """
    slope = _extract_from_poly(sol, 1)
    second_derivative = _extract_from_poly(sol, 2)

    def non_lin_curv(x: IV) -> np.ndarray:
        return second_derivative(x) / (1 + slope(x) ** 2) ** 1.5

    return non_lin_curv
@typing.overload
def _extract_cur_poly(
    sol: interpolate.PPoly,
    is_linear: typing.Literal[True] = ...,
) -> interpolate.PPoly: ...


# No default here: the runtime default is True, and a default on this
# overload made a bare `_extract_cur_poly(sol)` match two overloads with
# incompatible return types.
@typing.overload
def _extract_cur_poly(
    sol: interpolate.PPoly,
    is_linear: typing.Literal[False],
) -> Callable[[IV], np.ndarray]: ...


@typing.overload
def _extract_cur_poly(
    sol: interpolate.PPoly, is_linear: bool = ...
) -> interpolate.PPoly | Callable[[IV], np.ndarray]: ...


def _extract_cur_poly(sol: interpolate.PPoly, is_linear: bool = True):
    """Return the curvature of a BVP solution.

    Parameters
    ----------
    sol : interpolate.PPoly
        Vector-valued solution of the BVP.
    is_linear : bool, default True
        If true, return the linear curvature (the second derivative, state
        component 2) as a PPoly. Otherwise return a callable evaluating the
        non-linear curvature.
    """
    if is_linear:
        return _extract_from_poly(sol, 2)
    return _non_lin_curv(sol)
def displacement(x, floe_params, wave_params, growth_params, num_params):
    """Evaluate the numerically computed floe displacement at `x`.

    The BVP is solved on each call. Values outside the solver mesh are NaN,
    because the extracted polynomial does not extrapolate.
    """
    solution = _get_result(floe_params, wave_params, growth_params, num_params).sol
    displacement_poly = _extract_dis_poly(solution)
    return displacement_poly(x)
def curvature(
    x, floe_params, wave_params, growth_params, num_params, is_linear: bool = True
):
    """Evaluate the numerically computed curvature at `x`.

    With `is_linear` true this is the second derivative of the displacement;
    otherwise it is the non-linear curvature ``w'' / (1 + w'**2) ** 1.5``.
    """
    solution = _get_result(floe_params, wave_params, growth_params, num_params).sol
    curvature_fn = _extract_cur_poly(solution, is_linear)
    return curvature_fn(x)
@typing.overload
def _prepare_integrand0(
    floe_params: tuple[float, float],
    wave_params: tuple[np.ndarray, np.ndarray],
    growth_params,
    num_params,
    linear_curvature: typing.Literal[True],
) -> tuple[interpolate.PPoly, tuple[float, float]]: ...


@typing.overload
def _prepare_integrand0(
    floe_params: tuple[float, float],
    wave_params: tuple[np.ndarray, np.ndarray],
    growth_params,
    num_params,
    linear_curvature: typing.Literal[False],
) -> tuple[Callable[[IV], np.ndarray], tuple[float, float]]: ...


@typing.overload
def _prepare_integrand0(
    floe_params: tuple[float, float],
    wave_params: tuple[np.ndarray, np.ndarray],
    growth_params,
    num_params,
    linear_curvature: bool,
) -> tuple[interpolate.PPoly | Callable[[IV], np.ndarray], tuple[float, float]]: ...


def _prepare_integrand0(
    floe_params: tuple[float, float],
    wave_params: tuple[np.ndarray, np.ndarray],
    growth_params,
    num_params,
    linear_curvature: bool,
):
    """Solve the BVP and return its curvature with the integration bounds.

    Returns
    -------
    tuple
        ``(curvature, (x_first, x_last))``. `curvature` is a PPoly for linear
        curvature and a callable otherwise. The bounds are the first and last
        nodes of the solver's final mesh.
    """
    opt = _get_result(floe_params, wave_params, growth_params, num_params)
    curvature_poly = _extract_cur_poly(opt.sol, linear_curvature)
    bounds = opt.x[0], opt.x[-1]
    return curvature_poly, bounds
def _square_cubic_poly(ppoly: interpolate.PPoly) -> interpolate.PPoly:
    """Return the piecewise polynomial equal to ``ppoly(x) ** 2``.

    PPoly coefficients are ordered from the highest power down: with ``k``
    coefficient rows, ``c[i]`` multiplies ``(x - x_j) ** (k - 1 - i)`` on
    interval ``j``. The product ``c[i] * c[j]`` therefore multiplies power
    ``2 * (k - 1) - (i + j)``, i.e. it lands in row ``i + j`` of the squared
    coefficients, which have ``2 * k - 1`` rows.

    Named after the cubic case (the splines produced by `solve_bvp`), but
    it works for any polynomial order and any trailing dimensions. The
    result keeps the input's `extrapolate` setting.

    Parameters
    ----------
    ppoly : interpolate.PPoly
        Piecewise polynomial to square.

    Returns
    -------
    interpolate.PPoly
        Squared polynomial on the same breakpoints.
    """
    cs = ppoly.c
    order = cs.shape[0]
    new_cs = np.zeros(
        (2 * order - 1,) + cs.shape[1:], dtype=np.result_type(cs.dtype, np.float64)
    )
    for i in range(order):
        # Row i times every row j contributes to rows i + j, j = 0..order-1.
        # Summing over all ordered pairs yields c_i**2 once and 2*c_i*c_j.
        new_cs[i : i + order] += cs[i] * cs
    return interpolate.PPoly(new_cs, ppoly.x, extrapolate=ppoly.extrapolate)
def _pseudo_analytical_integration(
    floe_params: tuple[float, float],
    wave_params: tuple[np.ndarray, np.ndarray],
    growth_params: tuple | None,
    num_params: dict | None,
) -> float:
    """Integrate the squared linear curvature exactly over ``[0, length]``.

    The linear curvature of the `solve_bvp` spline is piecewise cubic. Its
    square is piecewise polynomial and is integrated exactly by
    `PPoly.integrate`.

    Parameters
    ----------
    floe_params : tuple of float
        ``(red_num, length)``.
    wave_params : tuple of np.ndarray
        Complex amplitudes and complex wavenumbers.
    growth_params : tuple or None
        Growth kernel parameters, or None.
    num_params : dict or None
        Options forwarded to `solve_bvp`; None means defaults.

    Returns
    -------
    float
    """
    # TODO: for now, only works with linear curvature. Could be adapted to
    # nonlinear curvature, it would just require the partial fraction
    # decomposition of the ratio of two 6th order polynomials.
    # The solver-mesh bounds are not needed: integrate over [0, length].
    curvature_poly, _ = _prepare_integrand0(
        floe_params, wave_params, growth_params, num_params, True
    )
    # TODO: integral could be computed manually, without building the PPoly
    # object first.
    squared_curvature = _square_cubic_poly(curvature_poly)
    return squared_curvature.integrate(0, floe_params[1]).item()
def _prepare_integrand(
    floe_params: tuple[float, float],
    wave_params: tuple[np.ndarray, np.ndarray],
    growth_params: tuple | None,
    num_params: dict,
    linear_curvature: bool,
) -> tuple[Callable[[IV], IV], tuple[float, float]]:
    """Build the squared-curvature integrand and its integration bounds.

    Returns
    -------
    tuple
        ``(integrand, (x_first, x_last))``, where ``integrand(x)`` is the
        square of the (linear or non-linear) curvature at `x`.
    """
    curvature_fn, bounds = _prepare_integrand0(
        floe_params, wave_params, growth_params, num_params, linear_curvature
    )

    def squared_curvature(x):
        return curvature_fn(x) ** 2

    return squared_curvature, bounds
def _estimate_quad_limit(
    floe_length: float, wave_params: tuple[np.ndarray, np.ndarray]
) -> int:
    """Heuristic value of the `limit` argument of `scipy.integrate.quad`.

    Parameters
    ----------
    floe_length : float
        Length of the integration range.
    wave_params : tuple of np.ndarray
        Complex amplitudes and complex wavenumbers; only the largest real
        wavenumber is used.

    Returns
    -------
    int
        At least 50 (quad's default), as a built-in ``int``.
    """
    # When using `quad`, there might be convergence issues if the integrand
    # observes many oscillations wrt the range of integration, that is the
    # length of the floe. A usual heuristic seems to be fixing `limit` to L /
    # lambda * N with N between 10 to 20. Choosing a big number doesn't hurt
    # computing time, as the integration stops when reaching the desired
    # tolerance anyway.
    factor = 20 / PI_2  # high N, scaled by 2pi to get a wavelength
    # We arbitrarily choose highest wavenumber.
    highest_wave_number = np.real(wave_params[1]).max()
    # 50 is the default. Cast explicitly: np.ceil yields a numpy float and
    # `.astype(int)` a numpy integer, not the annotated built-in int.
    return max(50, int(np.ceil(factor * floe_length * highest_wave_number)))
@typing.overload
def _quad_integration(
    integrand: Callable[[float], float],
    bounds: tuple[float, float],
    limit: int,
    debug: typing.Literal[True],
    **kwargs,
) -> tuple[float, float]: ...


# `debug` defaults to False at runtime; this overload must accept its
# omission, otherwise no overload matches `_quad_integration(f, b, limit)`.
@typing.overload
def _quad_integration(
    integrand: Callable[[float], float],
    bounds: tuple[float, float],
    limit: int,
    debug: typing.Literal[False] = ...,
    **kwargs,
) -> float: ...


@typing.overload
def _quad_integration(
    integrand: Callable[[float], float],
    bounds: tuple[float, float],
    limit: int,
    debug: bool,
    **kwargs,
) -> float | tuple[float, float]: ...


def _quad_integration(
    integrand: Callable[[float], float],
    bounds: tuple[float, float],
    limit: int,
    debug: bool = False,
    **kwargs,
) -> float | tuple[float, float]:
    """Integrate `integrand` over `bounds` with `scipy.integrate.quad`.

    Parameters
    ----------
    integrand : Callable
        Scalar function to integrate.
    bounds : tuple of float
        Lower and upper integration limits.
    limit : int
        Maximum number of subintervals used by quad.
    debug : bool, default False
        If true, return quad's full output (integral and absolute error
        estimate, plus extras when `full_output` is passed) instead of the
        integral alone.
    **kwargs
        Forwarded to `scipy.integrate.quad`.
    """
    result = integrate.quad(integrand, *bounds, limit=limit, **kwargs)
    return result if debug else result[0]
374@typing.overload
375def _tanhsinh_integration(
376 integrand: Callable[[np.ndarray], np.ndarray],
377 bounds: tuple[float, float],
378 debug: typing.Literal[True],
379 **kwargs,
380) -> _RichResult[float]: ...
383@typing.overload
384def _tanhsinh_integration(
385 integrand: Callable[[np.ndarray], np.ndarray],
386 bounds: tuple[float, float],
387 debug: typing.Literal[False],
388 **kwargs,
389) -> float: ...
392@typing.overload
393def _tanhsinh_integration(
394 integrand: Callable[[np.ndarray], np.ndarray],
395 bounds: tuple[float, float],
396 debug: bool,
397 **kwargs,
398) -> float | _RichResult[float]: ...
401def _tanhsinh_integration(
402 integrand: Callable[[np.ndarray], np.ndarray],
403 bounds: tuple[float, float],
404 debug: bool = False,
405 **kwargs,
406) -> float | _RichResult[float]:
407 default_quad_tol = 1.49e-8
408 for key in ("atol", "rtol"):
409 if key not in kwargs:
410 kwargs[key] = default_quad_tol
411 try:
412 result = integrate.tanhsinh(integrand, *bounds, **kwargs)
413 except AttributeError:
414 warnings.warn(
415 "tanhsinh integration was made public in scipy 1.15.0. "
416 "Proceeding anyway, but you might want to upgrade if possible, "
417 "or use another integration method."
418 )
419 result = integrate._tanhsinh.tanhsinh(integrand, *bounds, **kwargs)
420 if debug:
421 return result
423 return result.integral
def unit_energy(
    floe_params: tuple[float, float],
    wave_params: tuple[np.ndarray, np.ndarray],
    growth_params,
    num_params,
    integration_method: str | None = None,
    linear_curvature: bool = True,
    **kwargs,
) -> float:
    """Numerically evaluate the energy, up to a prefactor.

    The returned value is the integral of the squared curvature of the
    numerically solved displacement over the floe.

    Parameters
    ----------
    floe_params : tuple of float
        ``(red_num, length)`` as used by the BVP solver.
    wave_params : tuple of np.ndarray
        Complex amplitudes and complex wavenumbers of the wave components.
    growth_params : tuple or None
        ``(mean, std)`` of the wave growth kernel, or None.
    num_params : dict or None
        Options forwarded to `scipy.integrate.solve_bvp`.
    integration_method : {"pseudo_an", "quad", "tanhsinh"}, optional
        ``"pseudo_an"`` integrates the squared cubic spline exactly and only
        supports linear curvature. If it is requested with non-linear
        curvature, a warning is issued and ``"tanhsinh"`` is used instead.
        When omitted: ``"pseudo_an"`` for linear curvature, otherwise
        ``"tanhsinh"`` if scipy exposes it, else ``"quad"``.
    linear_curvature : bool, default True
        Use the linear curvature (second derivative) rather than the
        non-linear one.
    **kwargs
        Forwarded to the quadrature routine (ignored by ``"pseudo_an"``).
        For ``"quad"``, a missing or None `limit` is estimated from the floe
        length and the largest wavenumber.

    Returns
    -------
    float

    Raises
    ------
    ValueError
        If `integration_method` is not a supported name. This is checked
        before the (expensive) BVP solve.
    """
    if not linear_curvature and integration_method == "pseudo_an":
        warnings.warn(
            f"The method {integration_method} can only be used with linear curvature. "
            "Using tanhsinh instead.",
            stacklevel=2,
        )
        integration_method = "tanhsinh"

    if integration_method is None:
        if linear_curvature:
            integration_method = "pseudo_an"
        elif hasattr(integrate, "tanhsinh"):
            integration_method = "tanhsinh"
        else:
            integration_method = "quad"

    # Fail fast: previously an invalid name was only rejected after solving
    # the BVP.
    if integration_method not in ("pseudo_an", "quad", "tanhsinh"):
        raise ValueError(
            "Integration method should be `pseudo_an`, `quad`, or `tanhsinh`."
        )

    if integration_method == "pseudo_an":
        return _pseudo_analytical_integration(
            floe_params, wave_params, growth_params, num_params
        )

    integrand, bounds = _prepare_integrand(
        floe_params,
        wave_params,
        growth_params,
        num_params,
        linear_curvature,
    )

    if integration_method == "quad":
        limit = kwargs.pop("limit", None)
        if limit is None:
            limit = _estimate_quad_limit(floe_params[1], wave_params)
        return _quad_integration(integrand, bounds, limit=limit, **kwargs)
    return _tanhsinh_integration(integrand, bounds, **kwargs)