Coverage for tests/test_experiment.py: 0%

360 statements  

« prev     ^ index     » next       coverage.py v7.9.1, created at 2025-09-11 16:23 +0200

1import io 

2import logging 

3import pathlib 

4import pickle 

5 

6from hypothesis import HealthCheck, given, settings, strategies as st 

7import numpy as np 

8import pytest 

9from pytest_mock import MockerFixture 

10 

11import swiift.api.api as api 

12from swiift.api.api import Experiment 

13import swiift.lib.att as att 

14import swiift.lib.phase_shift as ps 

15import swiift.model.frac_handlers as fh 

16from swiift.model.model import DiscreteSpectrum, Domain, Floe, Ice, Ocean 

17from tests.model_strategies import coupled_ocean_ice, ocean_and_mono_spectrum, spec_mono 

18from tests.utils import float_kw, fracture_handler_types 

19 

# Directory holding the reference experiment pickles used by the loading tests.
experiment_targets_path = "tests/target/experiments"
# Backward-compatible alias for the historical (misspelled) name, which the
# tests in this module still reference.
epxeriment_targets_path = experiment_targets_path
# Glob pattern matching the reference pickle files inside that directory.
fname_pattern = "exper_test_no_sortedlist*"

# The attenuation parameterisation enum (iterable over all its members).
attenuation_parameterisations = att.AttenuationParameterisation
# Growth-parameter variants passed to `Experiment.from_discrete`: either None,
# or a pair whose second entry may be None.
growth_params = (None, (-13, None), (-28, 75), (np.array([-45]), None))


# How the loading tests locate the pickles: directory as str, as
# pathlib.Path, or implicitly via the current working directory.
loading_options = ("str", "path", "cwd")

28 

29 

class DummyPbar:
    """Minimal progress-bar stand-in for exercising ``Experiment.run``.

    It accumulates the increments passed to ``update`` and remembers whether
    ``close`` was called. ``write`` is a no-op classmethod, so tests can spy on it.
    """

    def __init__(self):
        self.updates = 0  # running total of all increments given to update()
        self.closed = False  # becomes True once close() has been called

    def update(self, n):
        """Add ``n`` to the running total of updates."""
        self.updates = self.updates + n

    def close(self):
        """Record that the bar has been closed."""
        self.closed = True

    @classmethod
    def write(cls, msg):
        """Accept a message and discard it."""
        return None

44 

45 

def mock_breakup(*args):
    """No-op stand-in for ``Domain.breakup``: take any positional args, return None."""
    return None

48 

49 

@st.composite
def run_time_chunks_composite(draw: st.DrawFn) -> tuple[int, float, int]:
    """Draw ``(n_step, delta_time, chunk_size)`` for the chunked-run tests.

    ``chunk_size`` never exceeds ``n_step``, so at least one full chunk fits.
    """
    # The chunk-size bound depends on the step count, so the step count is drawn first.
    steps = draw(st.integers(min_value=1, max_value=15))
    dt = draw(st.floats(min_value=0.01, max_value=5.0, **float_kw))
    chunk = draw(st.integers(min_value=1, max_value=steps))
    return steps, dt, chunk

57 

58 

def setup_experiment() -> api.Experiment:
    """Build an Experiment without floes.

    The spectrum has amplitude 2 and frequency 1/7, the ocean has infinite
    depth, and gravity is 9.8.
    """
    wave_period = 7
    wave_spectrum = DiscreteSpectrum(2, 1 / wave_period)
    deep_ocean = Ocean(depth=np.inf)
    return Experiment.from_discrete(9.8, wave_spectrum, deep_ocean)

67 

68 

def setup_experiment_with_floe() -> tuple[api.Experiment, Floe]:
    """Return the default experiment plus the single floe added to it.

    The floe starts at 0, has length 200 and ice thickness 0.5.
    """
    experiment = setup_experiment()
    single_floe = Floe(left_edge=0, length=200, ice=Ice(thickness=0.5))
    experiment.add_floes(single_floe)
    return experiment, single_floe

76 

77 

def step_experiment(experiment: api.Experiment, delta_t: float) -> api.Experiment:
    """Advance ``experiment`` in place by one step of ``delta_t`` and return it.

    Returning the same object lets callers write ``exp = step_experiment(exp, dt)``.
    """
    advance = experiment.step
    advance(delta_t)
    return experiment

81 

82 

@pytest.fixture
def experiment_with_history() -> api.Experiment:
    """Experiment reloaded from the reference pickles.

    The fixture is function-scoped (the default) and reloaded for every test,
    so tests that mutate it, such as by dumping its history, do not leak
    state into one another.
    """
    return api.load_pickles(fname_pattern, epxeriment_targets_path)

86 

87 

@pytest.mark.parametrize("dir_to_create", ("tmp_dir", pathlib.Path("tmp_dir2")))
def test_create_directory(tmp_path: pathlib.Path, dir_to_create: str | pathlib.Path):
    """``_create_path`` creates the target, and a second call returns the same path."""
    target = tmp_path / dir_to_create
    first = api._create_path(target)
    assert first.exists()
    assert api._create_path(target) == first

95 

96 

@pytest.mark.parametrize("step", (False, True))
def test_simple_read(mocker: MockerFixture, step):
    """Round-trip an Experiment through pickle, using a mocked ``open``.

    When ``step`` is True, the experiment's time is changed before pickling.
    This checks that non-default instance state survives the round trip.
    """
    experiment = setup_experiment()
    # Arbitrary time, distinct from the initial value of 0.
    elapsed_time = 10
    if step:
        experiment.time = elapsed_time
    # Hand the pickled bytes to whatever `_load_pickle` opens.
    file_content = io.BytesIO(pickle.dumps(experiment))
    mocker.patch("builtins.open", return_value=file_content)
    loaded_result = api._load_pickle("dummy.pickle")
    assert experiment == loaded_result
    if step:
        assert loaded_result.time == elapsed_time

109 

110 

def test_read_wrong_type(mocker: MockerFixture):
    """Loading a pickle that does not hold an Experiment raises TypeError."""
    payload = io.BytesIO(pickle.dumps(1.12))
    mocker.patch("builtins.open", return_value=payload)
    with pytest.raises(TypeError):
        api._load_pickle("dummy.pickle")

117 

118 

@pytest.mark.parametrize("use_glob", (True, False))
def test_file_error(use_glob: bool):
    """Both the single-file and the glob loader raise FileNotFoundError.

    The file name used is expected not to exist in the current directory.
    """
    missing_name = "exper_test.pickle"
    loader = api.load_pickles if use_glob else api.load_pickle
    with pytest.raises(FileNotFoundError):
        loader(missing_name)

127 

128 

@pytest.mark.parametrize("loading_option", loading_options)
def test_load_pickles(loading_option: str, monkeypatch):
    """``load_pickles`` stitches the reference chunk files into one Experiment.

    The directory is given as a str, as a pathlib.Path, or implicitly as the
    current working directory, depending on ``loading_option``.
    """
    path_as_str = epxeriment_targets_path
    path = pathlib.Path(path_as_str)
    # Load each chunk individually. This happens before any chdir, because
    # `path` is relative.
    experiments = [api._load_pickle(_p) for _p in sorted(path.glob(fname_pattern))]
    assert experiments, f"no reference pickles matching {fname_pattern!r} in {path}"
    if loading_option == "str":
        experiment = api.load_pickles(fname_pattern, path_as_str)
    elif loading_option == "path":
        experiment = api.load_pickles(fname_pattern, path)
    else:
        # Reading from cwd: chdir into the target directory first.
        monkeypatch.chdir(epxeriment_targets_path)
        experiment = api.load_pickles(fname_pattern)

    # The merged history length is the sum of the individual lengths, minus one
    # per boundary between consecutive files. The last key of a saved file
    # matches the first key of the next one.
    assert len(experiment.history) == (
        sum(len(_exper.history) for _exper in experiments) - (len(experiments) - 1)
    )
    # The first history key matches the first key of the first saved chunk.
    assert next(iter(experiment.history)) == next(iter(experiments[0].history))
    # The merged experiment ends at the same time as the last saved chunk.
    assert experiment.time == experiments[-1].time

154 

155 

@pytest.mark.parametrize("do_recursive", (True, False))
def test_recursive_load(do_recursive: bool):
    """Loading works from the target directory and, recursively, from its parent."""
    path = pathlib.Path(epxeriment_targets_path)
    if do_recursive:
        # Start one level up, so the pickles must be found in a subdirectory.
        path = path.parent
    experiment = api.load_pickles(fname_pattern, path, do_recursive)
    assert isinstance(experiment, Experiment)

163 

164 

@given(**ocean_and_mono_spectrum)
def test_initialisation(gravity, spectrum, ocean):
    """A freshly built Experiment starts at t=0 with empty domain and history.

    Default attenuation and fracture handling are also checked.
    """
    experiment = Experiment.from_discrete(gravity, spectrum, ocean)

    assert experiment.time == 0

    domain = experiment.domain
    assert isinstance(domain, Domain)
    assert domain.growth_params is None
    assert isinstance(domain.subdomains, list)
    assert len(domain.subdomains) == 0
    assert isinstance(domain.attenuation, att.AttenuationParameterisation)
    assert domain.attenuation == att.AttenuationParameterisation.PARAM_01

    handler = experiment.fracture_handler
    assert isinstance(handler.scattering_handler, ps.ContinuousScatteringHandler)
    assert isinstance(experiment.history, dict)
    assert len(experiment.history) == 0
    assert isinstance(handler, fh.BinaryFracture)

185 

186 

@given(**ocean_and_mono_spectrum)
@pytest.mark.parametrize("growth_params", growth_params)
@pytest.mark.parametrize("fracture_handler_type", fracture_handler_types)
@pytest.mark.parametrize("att_spec", att.AttenuationParameterisation)
def test_initialisation_with_opt_params(
    gravity,
    spectrum,
    ocean,
    growth_params,
    fracture_handler_type,
    att_spec,
):
    """Optional constructor arguments end up on the domain and the fracture handler."""
    experiment = Experiment.from_discrete(
        gravity,
        spectrum,
        ocean,
        growth_params=growth_params,
        fracture_handler=fracture_handler_type(),
        attenuation_spec=att_spec,
    )
    domain_growth = experiment.domain.growth_params

    if growth_params is not None:
        # A None second entry in the input is expected to be filled in by the domain.
        assert len(domain_growth) == 2
        assert domain_growth[0] == growth_params[0]
        assert domain_growth[1] is not None
    else:
        assert domain_growth is None
    assert isinstance(experiment.fracture_handler, fracture_handler_type)
    assert isinstance(experiment.domain.attenuation, att.AttenuationParameterisation)
    assert experiment.domain.attenuation == att_spec

218 

219 

@given(spectrum=spec_mono(), **coupled_ocean_ice)
def test_add_floes_single(gravity, spectrum, ocean, ice):
    """Adding one floe records a history entry, a subdomain, and a cache entry for its ice."""
    floe = Floe(left_edge=0, length=100, ice=ice)
    experiment = Experiment.from_discrete(gravity, spectrum, ocean)
    assert len(experiment.history) == 0
    assert len(experiment.domain.subdomains) == 0
    assert len(experiment.domain.cached_wuis) == 0

    experiment.add_floes(floe)

    subdomains = experiment.domain.subdomains
    assert len(experiment.history) == 1
    assert len(subdomains) == 1
    added = subdomains[0]
    assert added.left_edge == floe.left_edge
    assert added.length == floe.length
    assert ice in experiment.domain.cached_wuis
    # The initial history snapshot (key 0) holds the same subdomain.
    assert experiment.history[0].subdomains[0] == added

234 

235 

@given(spectrum=spec_mono(), **coupled_ocean_ice)
def test_add_floes_collection(gravity, spectrum, ocean, ice):
    """Two touching floes added in one call share a single history entry."""
    floes = (
        Floe(left_edge=0, length=100, ice=ice),
        Floe(left_edge=100, length=100, ice=ice),
    )
    experiment = Experiment.from_discrete(gravity, spectrum, ocean)
    experiment.add_floes(floes)
    assert len(experiment.history) == 1
    assert len(experiment.history[0].subdomains) == len(floes)
    assert len(experiment.domain.subdomains) == len(floes)

245 

246 

@given(spectrum=spec_mono(), **coupled_ocean_ice)
def test_add_floes_overlap(gravity, spectrum, ocean, ice):
    """Overlapping floes are rejected with ValueError."""
    left = Floe(left_edge=0, length=100, ice=ice)
    # Starts at 80, inside `left`, which spans [0, 100].
    right = Floe(left_edge=80, length=100, ice=ice)
    experiment = Experiment.from_discrete(gravity, spectrum, ocean)
    with pytest.raises(ValueError):
        experiment.add_floes((left, right))

254 

255 

def total_length_comparison(subdomains, floe: Floe):
    """Return True when the subdomain lengths add up to ``floe.length``.

    The difference is compared to 0 using ``np.allclose`` defaults. Because
    the reference value is 0, only the absolute tolerance (1e-8) matters.
    """
    length_gap = sum([subdomain.length for subdomain in subdomains]) - floe.length
    return np.allclose(length_gap, 0)

259 

260 

def test_step():
    """Stepping advances time, records history, and fractures the test floe.

    The total floe length is conserved across steps.
    """
    experiment, floe = setup_experiment_with_floe()
    assert len(experiment.history) == 1
    assert len(experiment.domain.subdomains) == 1

    # NOTE: an integer step avoids floating point precision issues down the line.
    delta_t = 1
    experiment = step_experiment(experiment, delta_t)
    assert np.allclose(experiment.time - delta_t, 0)
    assert len(experiment.history) == 2
    # This floe should definitely have fractured in these conditions.
    assert len(experiment.domain.subdomains) == 2
    assert total_length_comparison(experiment.domain.subdomains, floe)
    assert delta_t in experiment.domain.cached_phases

    extra_steps = 5
    for _ in range(extra_steps):
        experiment = step_experiment(experiment, delta_t)
    final_time = (extra_steps + 1) * delta_t

    assert np.allclose(experiment.time - final_time, 0)
    assert len(experiment.history) == extra_steps + 2
    assert total_length_comparison(experiment.domain.subdomains, floe)
    final_state = experiment.get_final_state()
    assert experiment.history[final_time] == final_state

287 

288 

@pytest.mark.parametrize("delta_t", (0.1, 0.5, 1, 1.5))
def test_get_timesteps(delta_t):
    """``timesteps`` lists t=0 followed by every time reached by stepping."""
    experiment, _ = setup_experiment_with_floe()
    n_steps = 4
    for _ in range(n_steps):
        experiment = step_experiment(experiment, delta_t)
    expected = np.arange(n_steps + 1) * delta_t
    assert np.allclose(expected, experiment.timesteps)

297 assert np.allclose(target_times, times) 

298 

299 

def test_pre_post_factures(experiment_with_history):
    """Pre- and post-fracture times are one step apart and differ by exactly one floe."""
    timesteps = experiment_with_history.timesteps
    pre_times = experiment_with_history.get_pre_fracture_times()
    post_times = experiment_with_history.get_post_fracture_times()

    # Post-fracture times follow pre-fracture times by one time step. The step
    # is taken as the spacing of the first two timesteps; `timesteps[1]` alone
    # only equals the step when the history starts at t = 0.
    delta_time = timesteps[1] - timesteps[0]
    assert np.allclose(post_times - pre_times, delta_time)

    # Each fracture event increases the floe count by exactly 1.
    history = experiment_with_history.history
    n_floes_pre = np.array([len(history[_t].subdomains) for _t in pre_times])
    n_floes_post = np.array([len(history[_t].subdomains) for _t in post_times])
    assert np.all(n_floes_post - n_floes_pre == 1)

323 

324 

@given(data=st.data())
@settings(suppress_health_check=(HealthCheck.function_scoped_fixture,))
def test_get_states_strict(data, experiment_with_history: api.Experiment):
    """For exact timesteps, ``get_states`` and ``get_states_strict`` agree.

    This is checked for a single time, a list of times, and a numpy array of times.
    """
    # Cast to list for hypothesis type correctness
    timesteps = experiment_with_history.timesteps.tolist()

    # Draw a single time from timesteps
    single_time = data.draw(st.sampled_from(timesteps), label="single_time")
    result_single = experiment_with_history.get_states(single_time)
    assert isinstance(result_single, dict)
    assert single_time in result_single
    result_single_strict = experiment_with_history.get_states_strict(single_time)
    assert isinstance(result_single_strict, dict)
    assert result_single == result_single_strict

    # Draw a random non-empty subset of timesteps (single or multiple entries)
    subset = data.draw(
        st.lists(st.sampled_from(timesteps), min_size=1, max_size=len(timesteps)),
        label="subset",
    )
    result_list = experiment_with_history.get_states(subset)
    assert isinstance(result_list, dict)
    assert all(t in result_list for t in subset)
    result_list_strict = experiment_with_history.get_states_strict(subset)
    assert isinstance(result_list_strict, dict)
    assert result_list == result_list_strict

    # Same subset, passed as a numpy array of floats
    subset_as_array = np.array(subset)
    result_array = experiment_with_history.get_states(subset_as_array)
    assert isinstance(result_array, dict)
    assert all(t in result_array for t in subset)
    result_array_strict = experiment_with_history.get_states_strict(subset_as_array)
    assert isinstance(result_array_strict, dict)
    assert result_array == result_array_strict

360 

361 

@given(data=st.data())
@settings(suppress_health_check=(HealthCheck.function_scoped_fixture,))
def test_get_states_perturbated(data, experiment_with_history: api.Experiment):
    """Slightly shifted times are matched by ``get_states`` but not by ``get_states_strict``.

    ``get_states`` is expected to return the original, unshifted timesteps as
    keys. The strict variant is expected to return an empty dict. This is
    checked for a single time, a list of times, and a numpy array of times.
    """
    # Small compared with the reference data's time step (5/6 ~ 0.833), so each
    # perturbed time stays closest to the timestep it was derived from.
    perturbation = 1e-3
    # Cast to list for hypothesis type correctness
    timesteps = experiment_with_history.timesteps.tolist()

    # Draw a single time from timesteps and shift it
    single_time = data.draw(st.sampled_from(timesteps), label="single_time")
    perturbated_time = single_time + perturbation
    result_single = experiment_with_history.get_states(perturbated_time)
    assert isinstance(result_single, dict)
    assert single_time in result_single
    result_single_strict = experiment_with_history.get_states_strict(perturbated_time)
    assert isinstance(result_single_strict, dict)
    assert len(result_single_strict) == 0

    # Draw a random non-empty subset of timesteps (single or multiple entries)
    subset = data.draw(
        st.lists(st.sampled_from(timesteps), min_size=1, max_size=len(timesteps)),
        label="subset",
    )
    perturbated_subset = [_v + perturbation for _v in subset]
    result_list = experiment_with_history.get_states(perturbated_subset)  # type: ignore
    assert isinstance(result_list, dict)
    assert all(t in result_list for t in subset)
    result_list_strict = experiment_with_history.get_states_strict(perturbated_subset)  # type: ignore
    assert isinstance(result_list_strict, dict)
    assert len(result_list_strict) == 0

    # Same perturbed subset, passed as a numpy array of floats
    perturbated_array = np.array(perturbated_subset)
    result_array = experiment_with_history.get_states(perturbated_array)
    assert isinstance(result_array, dict)
    assert all(t in result_array for t in subset)
    result_array_strict = experiment_with_history.get_states_strict(perturbated_array)
    assert isinstance(result_array_strict, dict)
    assert len(result_array_strict) == 0

400 

401 

@pytest.mark.parametrize("with_prefix", (True, False))
def test_history_dump(
    tmp_path: pathlib.Path,
    experiment_with_history: api.Experiment,
    with_prefix: bool,
):
    """Dumping history writes one pickle and keeps only the latest state in memory."""
    prefix = "test_prefix" if with_prefix else None
    last_timestep = experiment_with_history.timesteps[-1]
    assert len(experiment_with_history.history) > 1
    experiment_with_history.dump_history(prefix, dir_path=tmp_path)
    # Only the latest state is retained after dumping.
    assert len(experiment_with_history.history) == 1
    assert last_timestep in experiment_with_history.history
    # Exactly one pickle is written into the (fresh) directory, with or
    # without a prefix.
    assert len(list(tmp_path.glob("*pickle"))) == 1
    if with_prefix:
        assert len(list(tmp_path.glob(f"{prefix}*.pickle"))) == 1

416 

417 

@given(
    n_steps=st.integers(1, 5),
    delta_time=st.floats(min_value=0.01, max_value=5.0, **float_kw),  # type: ignore
)
def test_run_basic(n_steps, delta_time):
    """``run`` calls ``step`` ceil(time / delta_time) times."""
    time = n_steps * delta_time
    expected_n_steps = np.ceil(time / delta_time).astype(int)
    # Rounding errors can lead to the actual number of steps exceeding the
    # expected number of steps.
    assert expected_n_steps in (n_steps, n_steps + 1)

    recorded_calls = []

    def step_spy(*args, **kwargs):
        # Stand-in for Experiment.step that only records each call.
        recorded_calls.append(args)

    with pytest.MonkeyPatch().context() as mp:
        # Patching the class, not the instance, because methods are read-only.
        mp.setattr(api.Experiment, "step", step_spy)
        experiment, _ = setup_experiment_with_floe()
        experiment.run(time=time, delta_time=delta_time, dump_final=False)
    assert len(recorded_calls) == expected_n_steps

441 

442 

def test_run_with_pbar(monkeypatch):
    """``run`` advances the progress bar by the number of steps, then closes it.

    The ``monkeypatch`` fixture is requested but not used.
    """
    bar = DummyPbar()
    experiment, _ = setup_experiment_with_floe()
    experiment.run(time=2.0, delta_time=1.0, pbar=bar, dump_final=False)
    assert bar.updates == 2
    assert bar.closed

450 

451 

@given(args=run_time_chunks_composite())
@settings(suppress_health_check=(HealthCheck.function_scoped_fixture,))
@pytest.mark.parametrize("dump_final", (True, False))
def test_run_with_chunk_size(
    args: tuple[int, float, int], tmp_path: pathlib.Path, dump_final: bool
):
    """``run`` writes one pickle per full chunk of steps.

    When ``dump_final`` is set, one extra pickle is written for a trailing
    partial chunk.
    """
    import tempfile

    n_steps, delta_time, chunk_size = args
    time = n_steps * delta_time
    # Extra division to account for float errors
    actual_n_steps = np.ceil(time / delta_time).astype(int)
    expected_chunks = actual_n_steps // chunk_size
    # Expected model: step indices run from 0 to actual_n_steps - 1, and a
    # chunk closes whenever the index is chunk_size - 1 modulo chunk_size. If
    # the last step does not close a chunk, dump_final adds one more file.
    # With chunk_size == 1 every step closes a chunk, so nothing is added.
    if dump_final and ((actual_n_steps - 1) % chunk_size) != (chunk_size - 1):
        expected_chunks += 1

    # tmp_path is function-scoped and is not reinitialised between Hypothesis
    # examples, so each example gets its own fresh directory to count files in.
    out_dir = pathlib.Path(tempfile.mkdtemp(dir=tmp_path))
    prefix = "test_chunks"

    with pytest.MonkeyPatch().context() as mp:
        # Patching the class, not the instance, because methods are read-only.
        mp.setattr(Domain, "breakup", mock_breakup)
        experiment, _ = setup_experiment_with_floe()
        experiment.run(
            time=time,
            delta_time=delta_time,
            chunk_size=chunk_size,
            path=out_dir,
            dump_final=dump_final,
            dump_prefix=prefix,
        )
    saved_chunks = len(list(out_dir.glob(f"{prefix}*pickle")))
    assert saved_chunks == expected_chunks

489 

490 

@pytest.mark.parametrize("verbose", (None, 1, 2))
def test_verbose_run(
    verbose: int | None,
    tmp_path: pathlib.Path,
    caplog: pytest.LogCaptureFixture,
):
    """The log output of ``run`` depends on ``verbose``.

    - None logs nothing.
    - 1 logs only a "history dumped" message.
    - 2 additionally logs the post-fracture floe count.
    """
    caplog.set_level(logging.INFO)
    experiment, _ = setup_experiment_with_floe()

    n_steps = 1
    delta_time = 1
    experiment.run(
        time=n_steps * delta_time,
        delta_time=delta_time,
        chunk_size=1,
        verbose=verbose,
        path=tmp_path,
        dump_final=True,
    )
    n_floes_after = len(experiment.get_final_state().subdomains)
    # The single step is expected to break the floe in two.
    assert n_floes_after == 2

    if verbose is None:
        assert len(caplog.text) == 0
        return

    assert "history dumped" in caplog.text
    if verbose == 1:
        assert len(caplog.messages) == 1
    if verbose == 2:
        assert len(caplog.messages) == 2
        assert f"N_f = {n_floes_after}" in caplog.text

523 

524 

@pytest.mark.parametrize("verbose", (None, 1, 2))
@pytest.mark.parametrize("chunk_size", (None, 2))
@pytest.mark.parametrize("dump_final", (False, True))
def test_verbose_run_with_pbar(
    verbose: int | None,
    chunk_size: int | None,
    dump_final: bool,
    tmp_path: pathlib.Path,
    mocker: MockerFixture,
):
    """Check how often ``pbar.write`` is called during ``run``.

    Exactly one call is expected when both verbose and chunked; none otherwise.
    """
    bar = DummyPbar()
    write_spy = mocker.spy(bar, "write")
    with pytest.MonkeyPatch().context() as mp:
        # Patching the class, not the instance, because methods are read-only.
        mp.setattr(Domain, "breakup", mock_breakup)
        experiment, _ = setup_experiment_with_floe()
        experiment.run(
            time=3,
            delta_time=1,
            chunk_size=chunk_size,
            verbose=verbose,
            pbar=bar,
            path=tmp_path,
            dump_final=dump_final,
        )
    if verbose is None or chunk_size is None:
        write_spy.assert_not_called()
    else:
        write_spy.assert_called_once()

559 

560 

@given(data=st.data())
def test_run_early_termination(data):
    """With a ``break_time`` earlier than ``time``, ``run`` stops early.

    It stops at the first multiple of ``delta_time`` strictly after
    ``break_time``. Otherwise ``run`` goes to the full ``time``.
    """
    n_steps = 5
    with pytest.MonkeyPatch().context() as mp:
        # Patching the class, not the instance, because methods are read-only.
        original_should_terminate = Experiment._should_terminate

        def mock_should_terminate(self, *args):
            # Forward to the real check, with the first argument forced to 0.
            return original_should_terminate(self, 0, *args[1:])

        mp.setattr(Experiment, "_should_terminate", mock_should_terminate)
        mp.setattr(Domain, "breakup", mock_breakup)

        experiment, _ = setup_experiment_with_floe()
        # Inverse of the first spectral frequency, split into n_steps steps.
        time = 1 / experiment.domain.spectrum.frequencies[0]
        delta_time = time / n_steps
        break_time = data.draw(
            st.floats(min_value=delta_time, max_value=2 * time, **float_kw)
        )

        if break_time < time:
            # nextafter nudges an exact multiple up by one ulp before the ceil,
            # so the expected stop is the step strictly after break_time.
            n_expected = np.ceil(np.nextafter(break_time / delta_time, np.inf)).astype(int)
        else:
            n_expected = np.ceil(time / delta_time).astype(int)
        expected_time = n_expected * delta_time

        experiment.run(
            time=time,
            delta_time=delta_time,
            break_time=break_time,
            dump_final=False,
        )
        assert experiment.time == expected_time