Coverage for tests/test_experiment.py: 0%
360 statements
« prev ^ index » next coverage.py v7.9.1, created at 2025-09-11 16:23 +0200
1import io
2import logging
3import pathlib
4import pickle
6from hypothesis import HealthCheck, given, settings, strategies as st
7import numpy as np
8import pytest
9from pytest_mock import MockerFixture
11import swiift.api.api as api
12from swiift.api.api import Experiment
13import swiift.lib.att as att
14import swiift.lib.phase_shift as ps
15import swiift.model.frac_handlers as fh
16from swiift.model.model import DiscreteSpectrum, Domain, Floe, Ice, Ocean
17from tests.model_strategies import coupled_ocean_ice, ocean_and_mono_spectrum, spec_mono
18from tests.utils import float_kw, fracture_handler_types
# Directory holding the reference pickles, and the glob pattern selecting them.
epxeriment_targets_path = "tests/target/experiments"
fname_pattern = "exper_test_no_sortedlist*"

# All attenuation parameterisations, for parametrised tests.
attenuation_parameterisations = att.AttenuationParameterisation
# Growth parameter candidates: none at all, a scalar with no second entry,
# a scalar pair, and an array-valued first entry with no second entry.
growth_params = (
    None,
    (-13, None),
    (-28, 75),
    (np.array([-45]), None),
)

# Ways of pointing load_pickles at the reference directory.
loading_options = ("str", "path", "cwd")
class DummyPbar:
    """Minimal progress-bar stand-in that records how it is driven.

    Exposes the ``update`` / ``close`` / ``write`` interface that the tests
    pass to ``Experiment.run`` as ``pbar``.
    """

    def __init__(self):
        # Running total of the increments passed to ``update``.
        self.updates = 0
        # Becomes True once ``close`` has been called.
        self.closed = False

    def update(self, n):
        """Add ``n`` to the running total."""
        self.updates = self.updates + n

    def close(self):
        """Mark the bar as closed."""
        self.closed = True

    @classmethod
    def write(cls, msg):
        """Discard ``msg``; tests spy on this method to count calls."""
        return None
def mock_breakup(*args):
    """No-op stand-in for ``Domain.breakup``: accepts any arguments, does nothing."""
    return None
@st.composite
def run_time_chunks_composite(draw: st.DrawFn) -> tuple[int, float, int]:
    """Draw ``(n_step, delta_time, chunk_size)`` for chunked-run tests.

    ``chunk_size`` never exceeds ``n_step``.
    """
    steps = draw(st.integers(min_value=1, max_value=15))
    dt = draw(st.floats(min_value=0.01, max_value=5.0, **float_kw))
    chunk = draw(st.integers(min_value=1, max_value=steps))
    return steps, dt, chunk
def setup_experiment() -> api.Experiment:
    """Build the reference experiment.

    Monochromatic spectrum (amplitude 2, period 7) on an infinitely deep
    ocean, with gravity 9.8.
    """
    wave_period = 7
    wave_spectrum = DiscreteSpectrum(2, 1 / wave_period)
    deep_ocean = Ocean(depth=np.inf)
    return Experiment.from_discrete(9.8, wave_spectrum, deep_ocean)
def setup_experiment_with_floe() -> tuple[api.Experiment, Floe]:
    """Return the reference experiment with one floe added, plus that floe.

    The floe starts at 0, has length 200 and ice thickness 0.5. It is
    returned so tests can compare against its initial geometry.
    """
    experiment = setup_experiment()
    floe = Floe(left_edge=0, length=200, ice=Ice(thickness=0.5))
    experiment.add_floes(floe)
    return experiment, floe
def step_experiment(experiment: api.Experiment, delta_t: float) -> api.Experiment:
    """Advance ``experiment`` by one step of ``delta_t`` and return the same object."""
    experiment.step(delta_t)
    return experiment
@pytest.fixture(scope="function")
def experiment_with_history() -> api.Experiment:
    """Load a fresh copy of the reference experiment, history included.

    Function scope: some tests (e.g. ``test_history_dump``) mutate the
    history, so every test gets its own instance.
    """
    loaded = api.load_pickles(fname_pattern, epxeriment_targets_path)
    return loaded
@pytest.mark.parametrize("dir_to_create", ("tmp_dir", pathlib.Path("tmp_dir2")))
def test_create_directory(tmp_path: pathlib.Path, dir_to_create: str | pathlib.Path):
    """_create_path creates the directory and returns the same path when re-run."""
    target = tmp_path.joinpath(dir_to_create)
    first = api._create_path(target)
    assert first.exists()
    # A second call on an already-existing directory is harmless.
    second = api._create_path(target)
    assert first == second
@pytest.mark.parametrize("step", (False, True))
def test_simple_read(mocker: MockerFixture, step):
    """A pickled experiment round-trips through _load_pickle unchanged."""
    experiment = setup_experiment()
    # Non-default time, simply to check that instance state is recovered.
    step_size = 10
    if step:
        # Use the named value, so the assertion below cannot drift from it.
        experiment.time = step_size
    file_content = io.BytesIO(pickle.dumps(experiment))
    # Serve the in-memory pickle to whatever _load_pickle opens.
    mocker.patch("builtins.open", return_value=file_content)
    loaded_result = api._load_pickle("dummy.pickle")
    assert experiment == loaded_result
    if step:
        assert loaded_result.time == step_size
def test_read_wrong_type(mocker: MockerFixture):
    """_load_pickle raises TypeError when the pickle is not an Experiment."""
    not_an_experiment = 1.12
    mocker.patch(
        "builtins.open", return_value=io.BytesIO(pickle.dumps(not_an_experiment))
    )
    with pytest.raises(TypeError):
        api._load_pickle("dummy.pickle")
@pytest.mark.parametrize("use_glob", (True, False))
def test_file_error(use_glob: bool):
    """Both loaders raise FileNotFoundError for a missing file."""
    fname = "exper_test.pickle"
    loader = api.load_pickles if use_glob else api.load_pickle
    with pytest.raises(FileNotFoundError):
        loader(fname)
129@pytest.mark.parametrize("loading_option", loading_options)
130def test_load_pickles(loading_option: str, monkeypatch):
131 path_as_str = epxeriment_targets_path
132 path = pathlib.Path(path_as_str)
133 experiments = [api._load_pickle(_p) for _p in sorted(path.glob(fname_pattern))]
134 if loading_option == "str":
135 experiment = api.load_pickles(fname_pattern, path_as_str)
136 elif loading_option == "path":
137 experiment = api.load_pickles(fname_pattern, path)
138 else:
139 # Reading from cwd. To be able to read, we chdir to the path we want
140 # first.
141 monkeypatch.chdir(epxeriment_targets_path)
142 experiment = api.load_pickles(fname_pattern)
144 # Check the expected length. The read length should match the sum of the
145 # individually loaded length, minus (total of experiment minus 1), as the
146 # last key of a saved file should match the first key of the next one.
147 assert len(experiment.history) == (
148 sum(len(_exper.history) for _exper in experiments) - (len(experiments) - 1)
149 )
150 # Check the first history entry matches the first entry of the first history saved
151 assert next(iter(experiment.history)) == next(iter(experiments[0].history))
152 # Check the last history entry matches the last entry of the last history saved
153 assert experiment.time == experiments[-1].time
@pytest.mark.parametrize("do_recursive", (True, False))
def test_recursive_load(do_recursive: bool):
    """load_pickles finds the reference pickles, recursively from the parent dir.

    In recursive mode the search starts one directory above the targets, so
    the pickles can only be found by descending into subdirectories.
    """
    targets = pathlib.Path(epxeriment_targets_path)
    path = targets.parent if do_recursive else targets
    experiment = api.load_pickles(fname_pattern, path, do_recursive)
    # Previously nothing was checked; at least ensure something usable came back.
    assert isinstance(experiment, api.Experiment)
@given(**ocean_and_mono_spectrum)
def test_initialisation(gravity, spectrum, ocean):
    """A bare experiment starts at t=0, empty, and with default handlers."""
    experiment = Experiment.from_discrete(gravity, spectrum, ocean)
    domain = experiment.domain

    assert experiment.time == 0
    assert isinstance(domain, Domain)
    assert domain.growth_params is None
    # No floes yet.
    assert isinstance(domain.subdomains, list)
    assert len(domain.subdomains) == 0
    # Default attenuation is PARAM_01.
    assert isinstance(domain.attenuation, att.AttenuationParameterisation)
    assert domain.attenuation == att.AttenuationParameterisation.PARAM_01
    # Default scattering handler is the continuous one.
    assert isinstance(
        experiment.fracture_handler.scattering_handler, ps.ContinuousScatteringHandler
    )
    # Empty history, binary fracture by default.
    assert isinstance(experiment.history, dict)
    assert len(experiment.history) == 0
    assert isinstance(experiment.fracture_handler, fh.BinaryFracture)
@given(**ocean_and_mono_spectrum)
@pytest.mark.parametrize("growth_params", growth_params)
@pytest.mark.parametrize("fracture_handler_type", fracture_handler_types)
@pytest.mark.parametrize("att_spec", att.AttenuationParameterisation)
def test_initialisation_with_opt_params(
    gravity,
    spectrum,
    ocean,
    growth_params,
    fracture_handler_type,
    att_spec,
):
    """Optional constructor arguments end up on the experiment and its domain."""
    experiment = Experiment.from_discrete(
        gravity,
        spectrum,
        ocean,
        growth_params=growth_params,
        fracture_handler=fracture_handler_type(),
        attenuation_spec=att_spec,
    )
    domain_growth = experiment.domain.growth_params

    if growth_params is not None:
        # First entry kept as given; the second must be set even when None
        # was passed in.
        assert len(domain_growth) == 2
        assert domain_growth[0] == growth_params[0]
        assert domain_growth[1] is not None
    else:
        assert domain_growth is None

    assert isinstance(experiment.fracture_handler, fracture_handler_type)
    assert isinstance(experiment.domain.attenuation, att.AttenuationParameterisation)
    assert experiment.domain.attenuation == att_spec
@given(spectrum=spec_mono(), **coupled_ocean_ice)
def test_add_floes_single(gravity, spectrum, ocean, ice):
    """Adding one floe yields one subdomain, one history entry, and a cache entry."""
    floe = Floe(left_edge=0, length=100, ice=ice)
    experiment = Experiment.from_discrete(gravity, spectrum, ocean)

    # Nothing recorded before any floe is added.
    assert len(experiment.history) == 0
    assert len(experiment.domain.subdomains) == 0
    assert len(experiment.domain.cached_wuis) == 0

    experiment.add_floes(floe)

    subdomains = experiment.domain.subdomains
    assert len(experiment.history) == 1
    assert len(subdomains) == 1
    added = subdomains[0]
    assert added.left_edge == floe.left_edge
    assert added.length == floe.length
    # The floe's ice is now present in the domain cache.
    assert ice in experiment.domain.cached_wuis
    # The initial history entry holds the same subdomain.
    assert experiment.history[0].subdomains[0] == added
@given(spectrum=spec_mono(), **coupled_ocean_ice)
def test_add_floes_collection(gravity, spectrum, ocean, ice):
    """Adding two adjacent floes at once records both in a single history entry."""
    floes = (
        Floe(left_edge=0, length=100, ice=ice),
        Floe(left_edge=100, length=100, ice=ice),
    )
    experiment = Experiment.from_discrete(gravity, spectrum, ocean)
    experiment.add_floes(floes)

    assert len(experiment.history) == 1
    assert len(experiment.history[0].subdomains) == len(floes)
    assert len(experiment.domain.subdomains) == len(floes)
@given(spectrum=spec_mono(), **coupled_ocean_ice)
def test_add_floes_overlap(gravity, spectrum, ocean, ice):
    """Overlapping floes are rejected with a ValueError."""
    # The first floe covers [0, 100], the second starts at 80.
    overlapping = (
        Floe(left_edge=0, length=100, ice=ice),
        Floe(left_edge=80, length=100, ice=ice),
    )
    experiment = Experiment.from_discrete(gravity, spectrum, ocean)
    with pytest.raises(ValueError):
        experiment.add_floes(overlapping)
def total_length_comparison(subdomains, floe: Floe):
    """Return whether the subdomain lengths sum to ``floe.length`` (within tolerance)."""
    length_sum = sum(_sub.length for _sub in subdomains)
    return np.allclose(length_sum - floe.length, 0)
def test_step():
    """Stepping advances time, records history, fractures the floe, keeps length."""
    experiment, floe = setup_experiment_with_floe()
    assert len(experiment.history) == 1
    assert len(experiment.domain.subdomains) == 1

    # NOTE: an integer step avoids floating point precision issues later on.
    delta_t = 1
    experiment = step_experiment(experiment, delta_t)
    assert np.allclose(experiment.time - delta_t, 0)
    assert len(experiment.history) == 2
    # The floe should definitely have fractured in these conditions.
    assert len(experiment.domain.subdomains) == 2
    assert total_length_comparison(experiment.domain.subdomains, floe)
    assert delta_t in experiment.domain.cached_phases

    extra_steps = 5
    for _ in range(extra_steps):
        experiment.step(delta_t)

    n_steps_total = extra_steps + 1
    assert np.allclose(experiment.time - n_steps_total * delta_t, 0)
    # One history entry per step, plus the initial state.
    assert len(experiment.history) == n_steps_total + 1
    assert total_length_comparison(experiment.domain.subdomains, floe)
    final_state = experiment.get_final_state()
    assert experiment.history[n_steps_total * delta_t] == final_state
@pytest.mark.parametrize("delta_t", (0.1, 0.5, 1, 1.5))
def test_get_timesteps(delta_t):
    """``timesteps`` holds 0 and every step time, spaced by delta_t."""
    experiment, _ = setup_experiment_with_floe()
    n_steps = 4
    for _ in range(n_steps):
        experiment = step_experiment(experiment, delta_t)
    expected = np.linspace(0, n_steps, n_steps + 1) * delta_t
    assert np.allclose(expected, experiment.timesteps)
def test_pre_post_factures(experiment_with_history):
    """Pre/post-fracture times are one step apart and the floe count grows by one."""
    history = experiment_with_history.history
    timesteps = experiment_with_history.timesteps
    pre_times = experiment_with_history.get_pre_fracture_times()
    post_times = experiment_with_history.get_post_fracture_times()

    # Post-fracture times follow pre-fracture times by exactly one timestep.
    assert np.allclose(post_times - pre_times, timesteps[1])

    def floe_counts(times):
        # Number of floes recorded at each of the given times.
        return np.array([len(history[_t].subdomains) for _t in times])

    # Across each fracture the number of floes increases by exactly one.
    assert np.all(floe_counts(post_times) - floe_counts(pre_times) == 1)
@given(data=st.data())
@settings(suppress_health_check=(HealthCheck.function_scoped_fixture,))
def test_get_states_strict(data, experiment_with_history: api.Experiment):
    """For exact recorded times, get_states and get_states_strict agree.

    Checked for a scalar time, a list of times and a numpy array of times.
    """
    # Cast to list for hypothesis type correctness
    timesteps = experiment_with_history.timesteps.tolist()

    # Scalar input.
    single_time = data.draw(st.sampled_from(timesteps), label="single_time")
    result_single = experiment_with_history.get_states(single_time)
    assert isinstance(result_single, dict)
    assert single_time in result_single
    result_single_strict = experiment_with_history.get_states_strict(single_time)
    assert isinstance(result_single_strict, dict)
    assert result_single == result_single_strict

    # Non-empty list input (min_size=1; values may repeat).
    subset = data.draw(
        st.lists(st.sampled_from(timesteps), min_size=1, max_size=len(timesteps)),
        label="subset",
    )
    result_list = experiment_with_history.get_states(subset)
    assert isinstance(result_list, dict)
    assert np.all([t in result_list for t in subset])
    # Bug fix: this previously called get_states, comparing it with itself.
    result_list_strict = experiment_with_history.get_states_strict(subset)
    assert isinstance(result_list_strict, dict)
    assert result_list == result_list_strict

    # Numpy array input.
    subset_as_array = np.array(subset)
    result_array = experiment_with_history.get_states(subset_as_array)
    assert isinstance(result_array, dict)
    assert np.all([t in result_array for t in subset])
    # Bug fix: pass the array (not the list) so the strict array path is tested.
    result_array_strict = experiment_with_history.get_states_strict(subset_as_array)
    assert isinstance(result_array_strict, dict)
    assert result_array == result_array_strict
@given(data=st.data())
@settings(suppress_health_check=(HealthCheck.function_scoped_fixture,))
def test_get_states_perturbated(data, experiment_with_history: api.Experiment):
    """Slightly perturbed times are matched by get_states but not get_states_strict.

    ``get_states`` maps a time shifted by a small perturbation back to the
    recorded timestep. ``get_states_strict`` returns an empty dict for the
    same input. Checked for scalar, list and numpy array input.
    """
    perturbation = 1e-3  # delta_time := 5/6 ~ 0.833
    # Cast to list for hypothesis type correctness
    timesteps = experiment_with_history.timesteps.tolist()

    # Scalar input.
    single_time = data.draw(st.sampled_from(timesteps), label="single_time")
    perturbated_time = single_time + perturbation
    result_single = experiment_with_history.get_states(perturbated_time)
    assert isinstance(result_single, dict)
    assert single_time in result_single
    result_single_strict = experiment_with_history.get_states_strict(perturbated_time)
    assert isinstance(result_single_strict, dict)
    assert len(result_single_strict) == 0

    # Non-empty list input (min_size=1; values may repeat).
    subset = data.draw(
        st.lists(st.sampled_from(timesteps), min_size=1, max_size=len(timesteps)),
        label="subset",
    )
    perturbated_subset = [_v + perturbation for _v in subset]
    result_list = experiment_with_history.get_states(perturbated_subset)  # type: ignore
    assert isinstance(result_list, dict)
    assert np.all([t in result_list for t in subset])
    result_list_strict = experiment_with_history.get_states_strict(perturbated_subset)  # type: ignore
    assert isinstance(result_list_strict, dict)
    assert len(result_list_strict) == 0

    # Numpy array input.
    perturbated_array = np.array(perturbated_subset)
    result_array = experiment_with_history.get_states(perturbated_array)
    assert isinstance(result_array, dict)
    assert np.all([t in result_array for t in subset])
    result_array_strict = experiment_with_history.get_states_strict(perturbated_array)
    assert isinstance(result_array_strict, dict)
    assert len(result_array_strict) == 0
@pytest.mark.parametrize("with_prefix", (True, False))
def test_history_dump(
    tmp_path: pathlib.Path,
    experiment_with_history: api.Experiment,
    with_prefix: bool,
):
    """Dumping the history to disk keeps only the latest state in memory."""
    prefix = None
    if with_prefix:
        prefix = "test_prefix"
    final_time = experiment_with_history.timesteps[-1]
    assert len(experiment_with_history.history) > 1

    experiment_with_history.dump_history(prefix, dir_path=tmp_path)

    # Only the last state remains after the dump.
    assert len(experiment_with_history.history) == 1
    assert final_time in experiment_with_history.history
    if with_prefix:
        # Exactly one file named after the prefix was written.
        dumped = list(tmp_path.glob(f"{prefix}*.pickle"))
        assert len(dumped) == 1
@given(
    n_steps=st.integers(1, 5),
    delta_time=st.floats(min_value=0.01, max_value=5.0, **float_kw),  # type: ignore
)
def test_run_basic(n_steps, delta_time):
    """run() calls step() as many times as needed to cover ``time``."""
    time = n_steps * delta_time
    expected_n_steps = np.ceil(time / delta_time).astype(int)
    # Rounding errors can push the step count one above n_steps.
    assert expected_n_steps in (n_steps, n_steps + 1)

    recorded_calls = []

    def step_spy(*args, **kwargs):
        # Record the call instead of actually stepping.
        recorded_calls.append(None)

    with pytest.MonkeyPatch().context() as mp:
        # Patch the class, not the instance, because methods are read-only.
        mp.setattr(api.Experiment, "step", step_spy)
        experiment, _ = setup_experiment_with_floe()
        experiment.run(time=time, delta_time=delta_time, dump_final=False)
    assert len(recorded_calls) == expected_n_steps
def test_run_with_pbar(monkeypatch):
    """run() advances the progress bar by the number of steps, then closes it."""
    progress = DummyPbar()
    experiment, _ = setup_experiment_with_floe()
    experiment.run(time=2.0, delta_time=1.0, pbar=progress, dump_final=False)
    # 2.0 / 1.0 -> two steps.
    assert progress.updates == 2
    assert progress.closed
@given(args=run_time_chunks_composite())
@settings(suppress_health_check=(HealthCheck.function_scoped_fixture,))
@pytest.mark.parametrize("dump_final", (True, False))
def test_run_with_chunk_size(
    args: tuple[int, float, int], tmp_path: pathlib.Path, dump_final: bool
):
    """run() writes one pickle per full chunk, plus a partial one if dump_final."""
    n_steps, delta_time, chunk_size = args
    time = n_steps * delta_time
    # Divide back rather than trusting n_steps, to account for float errors.
    actual_n_steps = np.ceil(time / delta_time).astype(int)
    full_chunks, leftover_steps = divmod(actual_n_steps, chunk_size)
    # A trailing, incomplete chunk is only written when dump_final is set.
    expected_chunks = full_chunks + int(dump_final and leftover_steps != 0)

    # tmp_path is function-scoped and shared by all @given examples of one
    # parametrised case, so every (example, dump_final) gets its own prefix.
    prefix = f"test_{hash(args + (dump_final,)):x}"

    with pytest.MonkeyPatch().context() as mp:
        # Patch the class, not the instance, because methods are read-only.
        mp.setattr(Domain, "breakup", mock_breakup)
        experiment, _ = setup_experiment_with_floe()
        experiment.run(
            time=time,
            delta_time=delta_time,
            chunk_size=chunk_size,
            path=tmp_path,
            dump_final=dump_final,
            dump_prefix=prefix,
        )
    saved_chunks = len(list(tmp_path.glob(f"{prefix}*pickle")))
    assert saved_chunks == expected_chunks
@pytest.mark.parametrize("verbose", (None, 1, 2))
def test_verbose_run(
    verbose: int | None,
    tmp_path: pathlib.Path,
    caplog: pytest.LogCaptureFixture,
):
    """The amount of INFO logging from a one-step run depends on ``verbose``."""
    caplog.set_level(logging.INFO)
    experiment, _ = setup_experiment_with_floe()

    step_length = 1
    experiment.run(
        time=1 * step_length,
        delta_time=step_length,
        chunk_size=1,
        verbose=verbose,
        path=tmp_path,
        dump_final=True,
    )
    n_floes_after = len(experiment.get_final_state().subdomains)
    # One step is enough to break the floe in two.
    assert n_floes_after == 2

    if verbose is None:
        assert len(caplog.text) == 0
    elif verbose == 1:
        assert len(caplog.messages) == 1
        assert "history dumped" in caplog.text
    elif verbose == 2:
        assert len(caplog.messages) == 2
        assert f"N_f = {n_floes_after}" in caplog.text
@pytest.mark.parametrize("verbose", (None, 1, 2))
@pytest.mark.parametrize("chunk_size", (None, 2))
@pytest.mark.parametrize("dump_final", (False, True))
def test_verbose_run_with_pbar(
    verbose: int | None,
    chunk_size: int | None,
    dump_final: bool,
    tmp_path: pathlib.Path,
    mocker: MockerFixture,
):
    """With a progress bar, ``pbar.write`` is used only when verbose and chunked.

    Expected: called exactly once if both ``verbose`` and ``chunk_size`` are
    set, never otherwise.
    """
    pbar = DummyPbar()
    write_spy = mocker.spy(pbar, "write")
    with pytest.MonkeyPatch().context() as mp:
        # Patch the class, not the instance, because methods are read-only.
        mp.setattr(Domain, "breakup", mock_breakup)
        experiment, _ = setup_experiment_with_floe()
        experiment.run(
            time=3,
            delta_time=1,
            chunk_size=chunk_size,
            verbose=verbose,
            pbar=pbar,
            path=tmp_path,
            dump_final=dump_final,
        )
    if verbose is None or chunk_size is None:
        write_spy.assert_not_called()
    else:
        write_spy.assert_called_once()
@given(data=st.data())
def test_run_early_termination(data):
    """run() honours ``break_time`` when it falls before the requested ``time``.

    ``_should_terminate`` is patched to forward to the real implementation
    with its first argument replaced by 0. Breakup is disabled.
    """
    n_steps = 5
    with pytest.MonkeyPatch().context() as mp:
        real_should_terminate = Experiment._should_terminate

        def mock_should_terminate(self, *args):
            # Same call as the real check, with the first argument forced to 0.
            return real_should_terminate(self, 0, *args[1:])

        # Patch the class, not the instance, because methods are read-only.
        mp.setattr(Experiment, "_should_terminate", mock_should_terminate)
        mp.setattr(Domain, "breakup", mock_breakup)

        experiment, _ = setup_experiment_with_floe()
        # Run for one wave period, split into n_steps steps.
        time = 1 / experiment.domain.spectrum.frequencies[0]
        delta_time = time / n_steps
        break_time = data.draw(st.floats(delta_time, max_value=2 * time, **float_kw))

        # Final time when running to completion.
        expected_time = np.ceil(time / delta_time).astype(int) * delta_time
        # Step time just past break_time: nextafter pushes an exact multiple of
        # delta_time onto the following step.
        expected_time_with_break = (
            np.ceil(np.nextafter(break_time / delta_time, np.inf)).astype(int)
            * delta_time
        )

        experiment.run(
            time=time,
            delta_time=delta_time,
            break_time=break_time,
            dump_final=False,
        )
    if break_time < time:
        assert experiment.time == expected_time_with_break
    else:
        assert experiment.time == expected_time