Coverage for tests/test_domain.py: 0%

92 statements  

« prev     ^ index     » next       coverage.py v7.9.1, created at 2025-09-11 16:23 +0200

1from hypothesis import given 

2import numpy as np 

3import pytest 

4 

5import swiift.lib.att as att 

6import swiift.model.frac_handlers as fh 

7from swiift.model.model import ( 

8 DiscreteSpectrum, 

9 Domain, 

10 Floe, 

11 FreeSurfaceWaves, 

12 Ocean, 

13) 

14from tests.model_strategies import ocean_and_spectrum, simple_objects 

15from tests.utils import fracture_handler_types 

16 

# Presumably candidate values for the ``growth_params`` argument of ``Domain``
# (a two-element tuple, or None).
# NOTE(review): not referenced by any test visible in this module — confirm
# whether it is still needed.
growth_params = (
    None,
    (-13, None),
    (-28, 75),
    (np.array([-45]), None),
)

18 

19 

def instantiate_domain(att_spec: att.Attenuation, is_mono: bool) -> Domain:
    """Build a ``Domain`` from the shared ``simple_objects`` fixtures.

    Parameters
    ----------
    att_spec : att.Attenuation
        Attenuation specification, forwarded as ``attenuation`` to
        ``Domain.from_discrete``.
    is_mono : bool
        If True, use ``simple_objects["spec_mono"]`` as the spectrum;
        otherwise use ``simple_objects["spec_poly"]``.

    Returns
    -------
    Domain
        The domain built by ``Domain.from_discrete``.
    """
    spectrum_key = "spec_mono" if is_mono else "spec_poly"
    return Domain.from_discrete(
        simple_objects["gravity"],
        simple_objects[spectrum_key],
        simple_objects["ocean"],
        attenuation=att_spec,
    )

28 

29 

@given(**ocean_and_spectrum)
def test_initialisation(gravity: float, spectrum: DiscreteSpectrum, ocean: Ocean):
    """A freshly built ``Domain`` stores its inputs and starts with empty caches.

    Checks that:

    - gravity and ocean are stored as given;
    - the wavenumbers match those of a ``FreeSurfaceWaves`` built
      independently from the same inputs;
    - with no optional arguments, ``growth_params`` is None and
      ``attenuation`` is ``PARAM_01``;
    - both caches are empty.
    """
    domain = Domain.from_discrete(gravity, spectrum, ocean)
    fsw = FreeSurfaceWaves.from_ocean(ocean, spectrum, gravity)

    assert domain.gravity == gravity

    assert domain.fsw.ocean == ocean
    # Unlike ``np.all(a == b)``, which broadcasts (so a length-1 array could
    # spuriously "equal" a longer one), this fails on mismatched array shapes
    # and reports the differing elements.
    np.testing.assert_array_equal(domain.fsw.wavenumbers, fsw.wavenumbers)

    # TODO: to reenable when DiscreteSpectrum has been attrs'd
    # assert np.all(domain.spectrum.amplitudes == spectrum.amplitudes)
    # assert np.all(domain.spectrum.frequencies == spectrum.frequencies)
    # assert np.all(domain.spectrum.phases == spectrum.phases)

    assert domain.growth_params is None
    assert domain.attenuation == att.AttenuationParameterisation.PARAM_01

    assert len(domain.cached_wuis) == 0
    assert len(domain.cached_phases) == 0

50 

51 

@given(**ocean_and_spectrum)
def test_failing(gravity: float, spectrum: DiscreteSpectrum, ocean: Ocean):
    """Invalid growth arguments are rejected when building a ``Domain``.

    - An integer ``growth_params`` raises ``TypeError``.
    - A three-element ``growth_params`` tuple raises ``ValueError``.
    - ``growth_means`` of a length that does not match the spectrum raises
      ``ValueError``.
    """
    with pytest.raises(TypeError):
        Domain.from_discrete(gravity, spectrum, ocean, growth_params=1)

    with pytest.raises(ValueError):
        Domain.from_discrete(gravity, spectrum, ocean, growth_params=(1, 1, 1))

    # Choose a number of means that cannot match the spectrum: 3 when the
    # spectrum has one or two frequencies, otherwise one fewer than it has.
    # NOTE(review): a length of 1 is never produced — presumably because a
    # single mean is accepted for any spectrum; confirm.
    n_freq = spectrum.nf
    wrong_means = np.zeros(3 if n_freq in (1, 2) else n_freq - 1)
    with pytest.raises(ValueError):
        Domain.with_growth_means(gravity, spectrum, ocean, growth_means=wrong_means)

68 

69 

def instantiate_floe() -> Floe:
    """Return a ``Floe`` built from the ``left_edge``, ``length`` and ``ice``
    entries of the shared ``simple_objects`` fixtures."""
    floe_kwargs = {
        key: simple_objects[key] for key in ("left_edge", "length", "ice")
    }
    return Floe(**floe_kwargs)

76 

77 

@pytest.mark.parametrize("is_mono", (True, False))
@pytest.mark.parametrize("att_spec", att.AttenuationParameterisation)
def test_att_parameterisations(att_spec, is_mono):
    """Adding one floe caches exactly one WUI, and the cache contains the
    floe's ice.

    Runs for every attenuation parameterisation, with both the ``spec_mono``
    and ``spec_poly`` spectrum fixtures.
    """
    new_floe = instantiate_floe()
    dom = instantiate_domain(att_spec, is_mono)
    dom.add_floes(new_floe)

    assert len(dom.cached_wuis) == 1
    assert new_floe.ice in dom.cached_wuis

86 

87 

def test_promote():
    """Check the three cases handled by ``Domain._promote_floe``.

    - A single floe is wrapped in a one-element tuple.
    - A list of floes comes back equal to the input.
    - An integer raises ``ValueError``.
    """
    single = instantiate_floe()

    wrapped = Domain._promote_floe(single)
    assert isinstance(wrapped, tuple)
    assert len(wrapped) == 1
    assert wrapped[0] == single

    as_list = [single]
    assert Domain._promote_floe(as_list) == as_list

    with pytest.raises(ValueError):
        Domain._promote_floe(1)

101 

102 

103@pytest.mark.parametrize("is_mono", (True, False)) 

104@pytest.mark.parametrize("att_spec", att.AttenuationParameterisation) 

105@pytest.mark.parametrize("fracture_handler_type", fracture_handler_types) 

106def test_breakup( 

107 att_spec: att.AttenuationParameterisation, 

108 is_mono: bool, 

109 fracture_handler_type: type[fh._FractureHandler], 

110): 

111 fracture_handler = fracture_handler_type() 

112 domain = instantiate_domain(att_spec, is_mono) 

113 floe = instantiate_floe() 

114 domain.add_floes(floe) 

115 wuf0 = domain.subdomains[0] 

116 assert len(domain.subdomains) == 1 

117 

118 domain.breakup(fracture_handler, an_sol=True) 

119 

120 # Check we did have some breakup 

121 match fracture_handler: 

122 case fh.BinaryFracture() | fh.BinaryStrainFracture(): 

123 assert len(domain.subdomains) == 2 

124 case fh.MultipleStrainFracture(): 

125 assert len(domain.subdomains) >= 2 

126 case _: # pragma: no cover 

127 raise ValueError("Unknown fracture handler") 

128 

129 # Check the edge has not moved 

130 assert domain.subdomains[0].left_edge == wuf0.left_edge 

131 

132 # Check we did not duplicate WUIs 

133 for _wuf in domain.subdomains[1:]: 

134 assert _wuf.wui is wuf0.wui 

135 

136 # Check all new floes except the last have had their generation counter 

137 # incremented. The last one should have the same generation counter. 

138 for _wuf in domain.subdomains[:-1]: 

139 assert _wuf.generation == wuf0.generation + 1 

140 assert domain.subdomains[-1].generation == wuf0.generation 

141 

142 # Check individual fragment lengths are less than that of the original floe. 

143 lengths = np.array([_wuf.length for _wuf in domain.subdomains]) 

144 assert np.all(lengths < wuf0.length) 

145 

146 # Check floes are in order of their left edges. 

147 left_edges = np.array([_wuf.left_edge for _wuf in domain.subdomains]) 

148 assert np.all(np.ediff1d(left_edges) > 0) 

149 

150 # Check the two definitions are equivalent 

151 relative_new_edges = left_edges[1:] - left_edges[0] 

152 assert np.allclose(relative_new_edges, lengths[:-1].cumsum()) 

153 

154 # Checked the complex amplitude at the edge is identical, as it should be 

155 # modified at a later step 

156 assert np.all(domain.subdomains[0].edge_amplitudes == wuf0.edge_amplitudes) 

157 # Check new fragments have the expected complex amplitudes at their left 

158 # edges. As there is no random scattering here, these amplitudes are 

159 # obtained by "propagating" spatially the original edge amplitudes over the 

160 # fragment lengths. 

161 phase_diffs = relative_new_edges[:, None] * ( 

162 wuf0.wui.wavenumbers + 1j * wuf0.wui.attenuations 

163 ) 

164 assert np.all( 

165 np.isclose( 

166 np.vstack([_wuf.edge_amplitudes for _wuf in domain.subdomains[1:]]), 

167 wuf0.edge_amplitudes * np.exp(1j * phase_diffs), 

168 ) 

169 )