Coverage for src/swiift/lib/phase_shift.py: 74%

43 statements  

« prev     ^ index     » next       coverage.py v7.9.1, created at 2025-09-11 16:23 +0200

1"""Pseudo-scattering parameterisations.""" 

2 

3from __future__ import annotations 

4 

5import abc 

6import typing 

7 

8import attrs 

9import numpy as np 

10 

11from .constants import PI_2 

12 

13 

14def _seed_rng(seed: int): 

15 return np.random.default_rng(seed) 

16 

17 

18class _ScatteringHandler(abc.ABC): 

19 @abc.abstractmethod 

20 def compute_edge_amplitudes( 

21 self, 

22 edge_amplitudes: np.ndarray, 

23 c_wavenumbers: np.ndarray, 

24 xf: np.ndarray, 

25 ) -> np.ndarray: 

26 """Determine post-breakup wave amplitudes at the edge of new floes. 

27 

28 The wave propagates and is attenuated underneath the floe. For the 

29 current timestep, for each wave component, the complex wave amplitude 

30 is fully determined at all the coordinates where fracture is about to 

31 occur. After fractures have occured, the complex amplitudes at these 

32 coordinates become the complex amplitudes at the left edges of new 

33 fragments. Further pseudo-scattering rules can be used if it is not 

34 desirable to keep the wave surface in phase on both sides of a floe 

35 edge. 

36 

37 

38 Parameters 

39 ---------- 

40 edge_amplitudes : np.ndarray of complex 

41 The complex wave amplitudes at the edge of a breaking floe, in m 

42 c_wavenumbers : np.ndarray of complex 

43 The complex wavenumbers stressing the floe, in m^-1 

44 xf : np.ndarray of float 

45 The coordinates of fractures, in m 

46 

47 Returns 

48 ------- 

49 np.ndarray of complex 

50 

51 """ 

52 

53 

@attrs.frozen
class ContinuousScatteringHandler(_ScatteringHandler):
    """No scattering.

    The surface stays continuous across floe edges: the amplitude at each
    fracture is the edge amplitude propagated with the complex wavenumbers,
    so new fragments inherit both the attenuated modulus and the phase of the
    wave just before breakup.

    """

    def compute_edge_amplitudes(
        self,
        edge_amplitudes: np.ndarray,
        c_wavenumbers: np.ndarray,
        xf: np.ndarray,
    ) -> np.ndarray:
        """Propagate the edge amplitudes to the fracture coordinates.

        Parameters
        ----------
        edge_amplitudes : np.ndarray of complex
            The complex wave amplitudes at the edge of a breaking floe, in m
        c_wavenumbers : np.ndarray of complex
            The complex wavenumbers stressing the floe, in m^-1
        xf : np.ndarray of float
            The coordinates of fractures, in m. Any array-like is accepted.

        Returns
        -------
        np.ndarray of complex
            ``edge_amplitudes * exp(1j * c_wavenumbers * xf)``, with `xf`
            laid along the first axis and broadcast against the other inputs

        """
        # Accept array-likes: `1j * list` and `list[:, None]` would raise.
        c_wavenumbers = np.asarray(c_wavenumbers)
        xf = np.asarray(xf)
        # With a complex wavenumber k = kr + i*ki, exp(1j*k*x) carries both
        # the phase shift exp(1j*kr*x) and the attenuation exp(-ki*x).
        return edge_amplitudes * np.exp(1j * c_wavenumbers * xf[:, None])

69 

70 

@attrs.define
class _RandomScatteringHandler(_ScatteringHandler):
    """Base class for handlers whose edge amplitudes involve random draws.

    Subclasses provide `from_seed`, an alternate constructor building their
    random generator from an integer seed.

    """

    @classmethod
    @abc.abstractmethod
    def from_seed(
        cls, seed: int, *_args: typing.Any, **_kwargs: typing.Any
    ) -> typing.Self:
        """Build an instance whose random generator is seeded by `seed`.

        Parameters
        ----------
        seed : int
            Integer handed to `numpy.random.default_rng`

        Returns
        -------
        Self
            A new handler of the calling class

        """

84 

85 

@attrs.frozen
class UniformScatteringHandler(_RandomScatteringHandler):
    r"""Scattering with uniformly sampled new phases.

    The wave phase at the edge of each new floe is sampled independently, for
    every wave component, from the uniform distribution on
    :math:`[0; 2\pi)`. The modulus is still attenuated along the breaking
    floe as in the continuous case; only the phase is randomised.

    Attributes
    ----------
    rng : numpy.random.Generator
        Random generator used to sample phases

    """

    rng: np.random.Generator

    @classmethod
    def from_seed(cls, seed: int) -> typing.Self:
        """Instantiate self with an RNG seeded by an integer.

        Parameters
        ----------
        seed : int
            A seed passed to `numpy.random.default_rng`

        Returns
        -------
        UniformScatteringHandler

        """
        return cls(_seed_rng(seed))

    def compute_edge_amplitudes(
        self,
        edge_amplitudes: np.ndarray,
        c_wavenumbers: np.ndarray,
        xf: np.ndarray,
    ) -> np.ndarray:
        """Attenuate the edge amplitudes and give them random phases.

        Parameters
        ----------
        edge_amplitudes : np.ndarray of complex
            The complex wave amplitudes at the edge of a breaking floe, in m
        c_wavenumbers : np.ndarray of complex
            The complex wavenumbers stressing the floe, in m^-1
        xf : np.ndarray of float
            The coordinates of fractures, in m

        Returns
        -------
        np.ndarray of complex
            Amplitudes of modulus ``|edge_amplitudes| * exp(-Im(k) * xf)``
            and independent uniform phases, one per output entry

        """
        xf = np.asarray(xf)
        # |a * exp(1j*k*x)| = |a| * exp(-Im(k)*x): the real part of k only
        # shifts the phase, which is replaced below anyway.
        moduli = np.abs(edge_amplitudes) * np.exp(
            -np.imag(c_wavenumbers) * xf[:, None]
        )
        # Sample after broadcasting so that every fracture (row) gets its own
        # phases, even when `edge_amplitudes` is shared by all fractures.
        # When the inputs already have the full shape, the draws are unchanged.
        phases = self.rng.uniform(0, PI_2, size=moduli.shape)
        return moduli * np.exp(1j * phases)

134 

135 

@attrs.frozen
class PerturbationScatteringHandler(_RandomScatteringHandler):
    """Scattering with phases perturbed around the continuous solution.

    The wave phase at the left edge of a new floe is computed to maintain
    continuity of the surface across the edge. Then a random perturbation,
    sampled from a normal distribution, is added to the phase.

    Attributes
    ----------
    rng : numpy.random.Generator
        Random generator used to sample perturbations
    loc : float
        The mean of the normal distribution used to sample perturbations,
        in rad
    scale : float
        The standard deviation of the normal distribution used to sample
        perturbations, in rad

    Notes
    -----
    Perturbations are always added to an existing phase. The expectation of the
    resulting phase is thus the sum of `loc` and the phase of the continuous
    solution.

    """

    rng: np.random.Generator
    loc: float
    scale: float

    @classmethod
    def from_seed(cls, seed: int, loc: float = 0, scale: float = 1) -> typing.Self:
        """Instantiate with an RNG seeded with an integer.

        Parameters
        ----------
        seed : int
            A seed passed to `numpy.random.default_rng`
        loc : float
            Mean of a normal distribution, in rad
        scale : float
            Standard deviation of a normal distribution, in rad

        Returns
        -------
        PerturbationScatteringHandler

        """
        rng = _seed_rng(seed)
        return cls(rng, loc, scale)

    def compute_edge_amplitudes(
        self,
        edge_amplitudes: np.ndarray,
        c_wavenumbers: np.ndarray,
        xf: np.ndarray,
    ) -> np.ndarray:
        """Perturb the phase of the continuous solution at each fracture.

        Parameters
        ----------
        edge_amplitudes : np.ndarray of complex
            The complex wave amplitudes at the edge of a breaking floe, in m
        c_wavenumbers : np.ndarray of complex
            The complex wavenumbers stressing the floe, in m^-1
        xf : np.ndarray of float
            The coordinates of fractures, in m

        Returns
        -------
        np.ndarray of complex
            The continuous-surface amplitudes, each multiplied by
            ``exp(1j * perturbation)`` with ``perturbation ~ N(loc, scale)``

        """
        continuous = ContinuousScatteringHandler().compute_edge_amplitudes(
            edge_amplitudes, c_wavenumbers, xf
        )
        # Draw after broadcasting: one perturbation per fracture and wave
        # component, so new floe edges are perturbed independently.
        perturbations = self.rng.normal(self.loc, self.scale, size=continuous.shape)
        # In-place on the freshly computed array (never the caller's input),
        # which also keeps the dtype of the continuous solution.
        continuous *= np.exp(1j * perturbations)
        return continuous