Coverage for src/flexfrac1d/lib/numerical.py: 90%

65 statements  

« prev     ^ index     » next       coverage.py v7.4.1, created at 2024-08-30 14:00 +0200

1from numbers import Real 

2import warnings 

3 

4import numpy as np 

5import scipy.integrate as integrate 

6import scipy.interpolate as interpolate 

7 

8from ._ph_utils import _unit_wavefield 

9 

10 

def _growth_kernel(x: np.ndarray, mean: np.ndarray, std):
    """Build the one-sided Gaussian weights used to modulate the wave field.

    The kernel equals 1 wherever ``x <= mean`` and
    ``exp(-(x - mean)**2 / (2 * std**2))`` wherever ``x > mean``.

    Parameters
    ----------
    x : np.ndarray
        Positions at which the kernel is evaluated.
    mean : np.ndarray
        Kernel means. Must broadcast against `x` to the shape
        ``(mean.size, x.size)``; presumably a column vector with one
        entry per wave component -- confirm against callers.
    std
        Standard deviation of the Gaussian.

    Returns
    -------
    np.ndarray
        Kernel values, of shape ``(mean.size, x.size)``.

    """
    weights = np.ones((mean.size, x.size))
    past_mean = np.nonzero(x > mean)
    offset = x - mean
    gaussian = np.exp(-(offset**2) / (2 * std**2))
    # Only the entries located after the mean depart from unity.
    weights[past_mean] = gaussian[past_mean]
    return weights

16 

17 

def free_surface(
    x,
    wave_params: tuple[np.ndarray],
    growth_params: tuple[np.ndarray, Real] | None,
) -> np.ndarray:
    """Evaluate the free surface at the positions `x`.

    The surface is the imaginary part of the matrix product between the
    amplitudes and the unit wave field returned by `_unit_wavefield`.
    When `growth_params` is given, the unit wave field is first
    multiplied by the growth kernel.

    Parameters
    ----------
    x
        Positions at which the surface is evaluated.
    wave_params : tuple
        Pair ``(amplitudes, wavenumbers)``; presumably both complex
        arrays with one entry per wave component -- confirm against
        callers.
    growth_params : tuple, optional
        Pair ``(mean, std)`` forwarded to `_growth_kernel`, or None to
        leave the wave field unmodulated.

    Returns
    -------
    np.ndarray
        Free surface values.

    """
    amplitudes, wavenumbers = wave_params
    components = _unit_wavefield(x, wavenumbers)
    if growth_params is not None:
        # Scale each component in place by its growth kernel.
        components *= _growth_kernel(np.asarray(x), *growth_params)
    return np.imag(amplitudes @ components)

30 

31 

def _ode_system(
    x,
    w,
    *,
    floe_params: tuple[float],
    wave_params: tuple[np.ndarray],
    growth_params: tuple[np.ndarray, Real] | None,
) -> np.ndarray:
    """Right-hand side of the first-order system given to the BVP solver.

    Rows 1 to 3 of the state `w` are returned as the derivatives of
    rows 0 to 2, so that row ``k`` holds the ``k``-th derivative of
    row 0. The derivative of row 3 is
    ``4 * red_num**4 * (eta - w[0])``, with ``eta`` the free surface
    evaluated at `x`.

    Parameters
    ----------
    x
        Mesh positions.
    w
        State, indexable along its first axis by 0 to 3.
    floe_params : tuple
        Pair whose first entry is the scale ``red_num`` entering the
        forcing term; the second entry is unused here.
    wave_params : tuple
        Forwarded to `free_surface`.
    growth_params : tuple, optional
        Forwarded to `free_surface`.

    Returns
    -------
    np.ndarray
        Derivative of the state, stacked row-wise.

    """
    red_num, _ = floe_params
    forcing = free_surface(x, wave_params, growth_params)
    # Factor 4 comes from sqrt(2)**4
    last_row = 4 * red_num**4 * (forcing - w[0])
    return np.vstack((w[1], w[2], w[3], last_row))

45 

46 

def _boundary_conditions(wa, wb):
    """Boundary residuals for the BVP solver.

    Components 2 and 3 of the state are the residuals, taken at the left
    end (`wa`) and at the right end (`wb`); the solver drives them to
    zero.

    Parameters
    ----------
    wa
        State at the left end of the interval.
    wb
        State at the right end of the interval.

    Returns
    -------
    np.ndarray
        The four residuals, ordered ``(wa[2], wb[2], wa[3], wb[3])``.

    """
    residuals = [wa[2], wb[2], wa[3], wb[3]]
    return np.array(residuals)

49 

50 

def _solve_bvp(
    floe_params, wave_params, growth_params, **kwargs
) -> integrate._bvp.BVPResult:
    """Set up and solve the boundary value problem on ``[0, length]``.

    The initial mesh is uniform. Its size is the integer part of
    ``length`` times the larger of ``red_num`` and the greatest real
    part of the wavenumbers, with a minimum of 5 nodes. The initial
    guess is identically zero.

    Parameters
    ----------
    floe_params : tuple
        Pair ``(red_num, length)``.
    wave_params : tuple
        Wave parameters; the real part of the second entry is used to
        size the mesh, and the whole tuple is forwarded to
        `_ode_system`.
    growth_params : tuple, optional
        Forwarded to `_ode_system`.
    **kwargs
        Extra options passed on to `scipy.integrate.solve_bvp`.

    Returns
    -------
    BVPResult
        Result object returned by `scipy.integrate.solve_bvp`.

    """
    red_num, length = floe_params
    real_wavenumbers = np.real(wave_params[1])
    largest_scale = max(red_num, real_wavenumbers.max())
    mesh_size = max(5, int(length * largest_scale))
    initial_mesh = np.linspace(0, length, mesh_size)
    initial_guess = np.zeros((4, initial_mesh.size))

    def rhs(x, w):
        # Bind the physical parameters so that the solver only sees (x, w).
        return _ode_system(
            x,
            w,
            floe_params=floe_params,
            wave_params=wave_params,
            growth_params=growth_params,
        )

    return integrate.solve_bvp(
        rhs,
        _boundary_conditions,
        initial_mesh,
        initial_guess,
        **kwargs,
    )

74 

75 

def _get_result(
    floe_params, wave_params, growth_params, num_params
) -> integrate._bvp.BVPResult:
    """Solve the boundary value problem, warning if it did not converge.

    Parameters
    ----------
    floe_params, wave_params, growth_params
        Forwarded to `_solve_bvp`.
    num_params : dict, optional
        Keyword arguments forwarded to `_solve_bvp`; None stands for no
        extra arguments.

    Returns
    -------
    BVPResult
        The solver result, returned whether or not the solver reports
        success.

    Warns
    -----
    UserWarning
        If the solver result is flagged as unsuccessful.

    """
    solver_options = dict() if num_params is None else num_params
    result = _solve_bvp(
        floe_params, wave_params, growth_params, **solver_options
    )
    if result.success:
        return result
    warnings.warn("Numerical solution did not converge", stacklevel=2)
    return result

85 

86 

87def _use_an_sol( 

88 an_sol: bool | None, length: float, growth_params: tuple | None 

89) -> None: 

90 if an_sol is None: 

91 if growth_params is None: 

92 an_sol = True 

93 else: 

94 # If the wave growth kernel mean is to the right of the floe 

95 # for every wave component, the wave is fully developed 

96 # and the analytical solution can be used 

97 an_sol = np.all(growth_params[0] > length) 

98 return an_sol 

99 

100 

def _extract_from_poly(sol: interpolate.PPoly, n: int) -> interpolate.PPoly:
    """Extract component `n` of a vector-valued piecewise polynomial.

    Parameters
    ----------
    sol : interpolate.PPoly
        Piecewise polynomial whose coefficient array has the components
        along its third axis.
    n : int
        Index of the component to extract.

    Returns
    -------
    interpolate.PPoly
        Piecewise polynomial sharing the breakpoints of `sol`, with
        extrapolation disabled.

    """
    component_coeffs = sol.c[:, :, n]
    return interpolate.PPoly(c=component_coeffs, x=sol.x, extrapolate=False)

103 

104 

def _extract_dis_poly(sol: interpolate.PPoly) -> interpolate.PPoly:
    """Extract the displacement, stored as component 0 of the solution.

    Parameters
    ----------
    sol : interpolate.PPoly
        Vector-valued piecewise polynomial produced by the BVP solver.

    Returns
    -------
    interpolate.PPoly
        Piecewise polynomial for component 0 only.

    """
    return _extract_from_poly(sol, n=0)

107 

108 

def _extract_cur_poly(sol: interpolate.PPoly) -> interpolate.PPoly:
    """Extract the curvature, stored as component 2 of the solution.

    Parameters
    ----------
    sol : interpolate.PPoly
        Vector-valued piecewise polynomial produced by the BVP solver.

    Returns
    -------
    interpolate.PPoly
        Piecewise polynomial for component 2 only.

    """
    return _extract_from_poly(sol, n=2)

111 

112 

def displacement(x, floe_params, wave_params, growth_params, num_params):
    """Numerically evaluate the displacement at the positions `x`.

    Parameters
    ----------
    x
        Positions at which the displacement is evaluated.
    floe_params, wave_params, growth_params
        Physical parameters forwarded to the BVP solver.
    num_params : dict, optional
        Keyword arguments for the BVP solver, or None.

    Returns
    -------
    Displacement polynomial evaluated at `x`. The polynomial does not
    extrapolate beyond the solver mesh.

    """
    result = _get_result(floe_params, wave_params, growth_params, num_params)
    displacement_poly = _extract_dis_poly(result.sol)
    return displacement_poly(x)

116 

117 

def curvature(x, floe_params, wave_params, growth_params, num_params):
    """Numerically evaluate the curvature at the positions `x`.

    Parameters
    ----------
    x
        Positions at which the curvature is evaluated.
    floe_params, wave_params, growth_params
        Physical parameters forwarded to the BVP solver.
    num_params : dict, optional
        Keyword arguments for the BVP solver, or None.

    Returns
    -------
    Curvature polynomial evaluated at `x`. The polynomial does not
    extrapolate beyond the solver mesh.

    """
    result = _get_result(floe_params, wave_params, growth_params, num_params)
    curvature_poly = _extract_cur_poly(result.sol)
    return curvature_poly(x)

121 

122 

def energy(
    floe_params, wave_params, growth_params, num_params
) -> tuple[float, float]:
    """Numerically evaluate the energy

    The energy is up to a prefactor: it is the integral of the squared
    curvature between the first and last nodes of the solver mesh.

    Parameters
    ----------
    floe_params, wave_params, growth_params
        Physical parameters forwarded to the BVP solver.
    num_params : dict, optional
        Keyword arguments for the BVP solver, or None.

    Returns
    -------
    tuple of float
        Pair ``(integral, abserr)`` as returned by
        `scipy.integrate.quad`: the value of the integral and an
        estimate of the absolute error on it.

    """
    opt = _get_result(floe_params, wave_params, growth_params, num_params)
    curvature_poly = _extract_cur_poly(opt.sol)

    def unit_energy(x: float) -> float:
        return curvature_poly(x) ** 2

    # Integrate over the whole mesh, from its first node to its last.
    lower, upper = opt.x[[0, -1]]
    return integrate.quad(unit_energy, lower, upper)

132 

133 return integrate.quad(unit_energy, *opt.x[[0, -1]])