Coverage for src/flexfrac1d/lib/numerical.py: 90%
65 statements
« prev ^ index » next coverage.py v7.4.1, created at 2024-08-30 14:00 +0200
1from numbers import Real
2import warnings
4import numpy as np
5import scipy.integrate as integrate
6import scipy.interpolate as interpolate
8from ._ph_utils import _unit_wavefield
11def _growth_kernel(x: np.ndarray, mean: np.ndarray, std):
12 kern = np.ones((mean.size, x.size))
13 mask = np.nonzero(x > mean)
14 kern[mask] = np.exp(-((x - mean) ** 2) / (2 * std**2))[mask]
15 return kern
def free_surface(
    x,
    wave_params: tuple[np.ndarray, np.ndarray],
    growth_params: tuple[np.ndarray, Real] | None,
) -> np.ndarray:
    """Evaluate the free-surface elevation at positions ``x``.

    Parameters
    ----------
    x : array_like
        Positions at which to evaluate the surface.
    wave_params : tuple
        ``(c_amplitudes, c_wavenumbers)`` of the wave components
        (presumably complex-valued, given the ``c_`` prefix — confirm).
    growth_params : tuple or None
        ``(mean, std)`` of the growth kernel applied to the wavefield, or
        ``None`` for a fully developed wavefield with no kernel.

    Returns
    -------
    np.ndarray
        ``imag(c_amplitudes @ wave_shape)``, where ``wave_shape`` is the
        unit wavefield, optionally multiplied by the growth kernel.
    """
    c_amplitudes, c_wavenumbers = wave_params
    wave_shape = _unit_wavefield(x, c_wavenumbers)
    if growth_params is not None:
        kern = _growth_kernel(np.asarray(x), *growth_params)
        # NOTE(review): in-place; assumes _unit_wavefield returns a fresh
        # array that is safe to mutate.
        wave_shape *= kern
    eta = np.imag(c_amplitudes @ wave_shape)
    return eta
def _ode_system(
    x,
    w,
    *,
    floe_params: tuple[float, float],
    wave_params: tuple[np.ndarray, np.ndarray],
    growth_params: tuple[np.ndarray, Real] | None,
) -> np.ndarray:
    """Right-hand side of the first-order ODE system passed to ``solve_bvp``.

    ``w`` stacks the displacement and its first three derivatives
    (``w[0]`` .. ``w[3]``). The system encodes
    ``w'''' = 4 * red_num**4 * (eta - w)``, where ``eta`` is the
    free-surface elevation at ``x``.

    Parameters
    ----------
    x : np.ndarray
        Mesh positions.
    w : np.ndarray
        State array with 4 rows, one column per position in ``x``.
    floe_params : tuple
        ``(red_num, length)``; only ``red_num`` is used here.
    wave_params, growth_params
        Forwarded to :func:`free_surface`.

    Returns
    -------
    np.ndarray
        Derivative of ``w``, same shape as ``w``.
    """
    red_num, _ = floe_params
    eta = free_surface(x, wave_params, growth_params)
    # Factor 4 comes from sqrt(2)**4
    wprime = np.vstack((w[1], w[2], w[3], 4 * red_num**4 * (eta - w[0])))
    return wprime
47def _boundary_conditions(wa, wb):
48 return np.array((wa[2], wb[2], wa[3], wb[3]))
def _solve_bvp(
    floe_params, wave_params, growth_params, **kwargs
) -> integrate._bvp.BVPResult:
    """Solve the floe boundary-value problem with ``scipy.integrate.solve_bvp``.

    The initial mesh spans ``[0, length]``. Its node count scales with
    ``length`` times the larger of ``red_num`` and the largest real
    wavenumber, with a minimum of 5 nodes. The initial guess is identically
    zero. Extra keyword arguments are forwarded to ``solve_bvp``.
    """
    red_num, length = floe_params
    k_max = np.real(wave_params[1]).max()
    n_nodes = max(5, int(length * max(red_num, k_max)))
    mesh = np.linspace(0, length, n_nodes)
    guess = np.zeros((4, mesh.size))

    def rhs(x, w):
        return _ode_system(
            x,
            w,
            floe_params=floe_params,
            wave_params=wave_params,
            growth_params=growth_params,
        )

    return integrate.solve_bvp(
        rhs, _boundary_conditions, mesh, guess, **kwargs
    )
def _get_result(
    floe_params, wave_params, growth_params, num_params
) -> integrate._bvp.BVPResult:
    """Solve the BVP and warn if the solver reports failure.

    ``num_params`` is an optional mapping of extra ``solve_bvp`` keyword
    arguments; ``None`` means no extra arguments.
    """
    solver_kwargs = {} if num_params is None else num_params
    result = _solve_bvp(
        floe_params, wave_params, growth_params, **solver_kwargs
    )
    if not result.success:
        warnings.warn("Numerical solution did not converge", stacklevel=2)
    return result
87def _use_an_sol(
88 an_sol: bool | None, length: float, growth_params: tuple | None
89) -> None:
90 if an_sol is None:
91 if growth_params is None:
92 an_sol = True
93 else:
94 # If the wave growth kernel mean is to the right of the floe
95 # for every wave component, the wave is fully developed
96 # and the analytical solution can be used
97 an_sol = np.all(growth_params[0] > length)
98 return an_sol
101def _extract_from_poly(sol: interpolate.PPoly, n: int) -> interpolate.PPoly:
102 return interpolate.PPoly(sol.c[:, :, n], sol.x, extrapolate=False)
def _extract_dis_poly(sol: interpolate.PPoly) -> interpolate.PPoly:
    """Return the displacement component (state index 0) of a BVP solution."""
    displacement_index = 0
    return _extract_from_poly(sol, displacement_index)
def _extract_cur_poly(sol: interpolate.PPoly) -> interpolate.PPoly:
    """Return the curvature component (state index 2) of a BVP solution."""
    curvature_index = 2
    return _extract_from_poly(sol, curvature_index)
def displacement(x, floe_params, wave_params, growth_params, num_params):
    """Evaluate the numerically computed displacement at positions ``x``.

    The boundary-value problem is solved anew on every call.
    """
    solution = _get_result(
        floe_params, wave_params, growth_params, num_params
    ).sol
    dis_poly = _extract_dis_poly(solution)
    return dis_poly(x)
def curvature(x, floe_params, wave_params, growth_params, num_params):
    """Evaluate the numerically computed curvature at positions ``x``.

    The boundary-value problem is solved anew on every call.
    """
    solution = _get_result(
        floe_params, wave_params, growth_params, num_params
    ).sol
    cur_poly = _extract_cur_poly(solution)
    return cur_poly(x)
def energy(
    floe_params, wave_params, growth_params, num_params
) -> tuple[float, float]:
    """Numerically evaluate the energy, up to a prefactor.

    The energy is the integral of the squared curvature between the first
    and last nodes of the solution mesh.

    Returns
    -------
    tuple[float, float]
        ``(value, abserr)`` as returned by ``scipy.integrate.quad``: the
        integral and an estimate of its absolute error.
    """
    opt = _get_result(floe_params, wave_params, growth_params, num_params)
    curvature_poly = _extract_cur_poly(opt.sol)

    def unit_energy(x: float) -> float:
        return curvature_poly(x) ** 2

    return integrate.quad(unit_energy, *opt.x[[0, -1]])