Coverage for tests/test_physics.py: 0%
149 statements
« prev ^ index » next coverage.py v7.9.1, created at 2025-09-11 16:23 +0200
« prev ^ index » next coverage.py v7.9.1, created at 2025-09-11 16:23 +0200
1from __future__ import annotations
3import abc
4import pathlib
5import typing
7import numpy as np
8import pytest
10import swiift.lib.physics as ph
12if typing.TYPE_CHECKING:
13 from pytest_benchmark.fixture import BenchmarkFixture # type: ignore
# Inputs and targets were checked by eye against solutions obtained with
# scipy.integrate.solve_bvp.
TARGET_DIR_MONO = pathlib.Path("tests") / "target" / "physics_monochromatic"
TARGET_DIR_POLY = pathlib.Path("tests") / "target" / "physics_polychromatic"

# Case counts are spelled out by hand so that they are available when the
# test classes are defined; _TestPhysics::test_dimensions verifies that they
# agree with the shapes of the stored targets.
N_CASES_MONO = 49
N_N_FREQS = 8  # how many distinct spectrum lengths are tested (2 to 100)
N_TRIES = 5  # how many random draws per spectrum length
N_CASES_POLY = N_TRIES * N_N_FREQS
INTEGRATION_METHODS = ("pseudo_an", "tanhsinh", "quad")
def _flatten_and_squeeze(
    array: np.ndarray, n_dims_to_keep: int, expected_size: int
) -> np.ndarray:
    """Reshape an array by contracting middle dimensions.

    This function is intended to be flexible enough to transform
    different-shaped targets of mono- and polychromatic parameters to a
    standard shape that can be interpreted (and looped over) by the test
    functions. The first `n_dims_to_keep` dimensions of the array are
    preserved. The following dimensions are merged into one axis of size
    `expected_size`, and whatever is left is merged into a trailing axis.
    That trailing axis is removed when it has size 1.

    Parameters
    ----------
    array : np.ndarray
        ND-array to reshape.
    n_dims_to_keep : int
        Number of leading dimensions to preserve.
    expected_size : int
        Size of the axis the contracted dimensions are reshaped to.

    Returns
    -------
    np.ndarray
        The reshaped array. Only the trailing axis is squeezed out (when it
        has size 1); preserved leading axes and the contracted axis are kept
        even if they have size 1.

    Raises
    ------
    ValueError
        If the array cannot be reshaped to the requested layout (raised by
        `np.reshape`).

    Examples
    --------
    Situation corresponding to the polychromatic displacement and curvature
    targets. The two middle dimensions are contracted into one, preserving the
    first and last original dimensions.

    >>> arr1 = np.empty((2, 8, 5, 20))
    >>> arr1.shape
    (2, 8, 5, 20)
    >>> _flatten_and_squeeze(arr1, 1, 40).shape
    (2, 40, 20)

    Situation corresponding to the polychromatic energy target. The two last
    dimensions are contracted into one. The initial reshaped array has shape
    (4, 40, 1). That last axis is squeezed out.

    >>> arr2 = np.empty((4, 8, 5))
    >>> arr2.shape
    (4, 8, 5)
    >>> _flatten_and_squeeze(arr2, 1, 40).shape
    (4, 40)

    Situation corresponding to the monochromatic floe_params input. No
    contraction is necessary, but the function handles that case for
    genericity, returning an array of the same shape as the input.

    >>> arr3 = np.empty((49, 2))
    >>> arr3.shape
    (49, 2)
    >>> _flatten_and_squeeze(arr3, 0, 49).shape
    (49, 2)

    Situation corresponding to the polychromatic floe_params input. This time,
    contraction is necessary and behaves as expected. This spares us from
    overriding the `floe_params_all` in the concrete classes, which can rely on
    the behaviour of the abstract class, to produce the same kinds of outputs.

    >>> arr4 = np.empty((8, 5, 2))
    >>> arr4.shape
    (8, 5, 2)
    >>> _flatten_and_squeeze(arr4, 0, 40).shape
    (40, 2)

    """
    reshaped = np.reshape(
        array, (*array.shape[:n_dims_to_keep], expected_size, -1)
    )
    # Only drop the trailing axis produced by the `-1` above: a bare
    # `np.squeeze` would also remove a preserved or contracted axis that
    # happens to have size 1, breaking the `[i, j]` indexing done by callers.
    if reshaped.shape[-1] == 1:
        return reshaped[..., 0]
    return reshaped
class _TestPhysics(abc.ABC):
    """Base class exposing fixtures and logic for physics-related tests.

    By physics, we mean the computation of the vertical displacement and the
    associated curvature along a floe, and the computation of the energy of
    the deformed floe, as defined by the handler classes under
    swiift.lib.physics. This class tests the stability of the `compute`
    method of these handlers, by comparing outputs obtained from known
    inputs to expected outputs (targets) verified visually.

    The inputs are two positive real numbers and two complex numbers
    (monochromatic cases) or two positive real numbers and two arrays of
    complex numbers of same size (polychromatic cases). There are thus six
    independent real numbers for the monochromatic cases. These are generated
    using Latin hypercube sampling. These samples are in turn sampled and
    combined to generate inputs for the polychromatic cases.

    Concrete subclasses must set `target_dir` and `n_cases`, and implement
    `wave_params_all` and `growth_params_all` as class-scoped fixtures.

    Attributes
    ----------
    target_dir : pathlib.Path
        Path of the directory containing the inputs and targets.
    n_cases : int
        The number of expected test cases (combination of input and target).

    """

    target_dir: pathlib.Path
    n_cases: int

    def pytest_generate_tests(self, metafunc):
        """Parametrise the `j` argument over every test case of the class."""
        # Pytest magic. Equivalent to adding pytest.mark.parametrize on j, but
        # allows for using a class/instance attribute as a parameter, which
        # would not be possible with a simple decorator.
        if "j" in metafunc.fixturenames:
            metafunc.parametrize("j", range(self.n_cases))

    def _benchmark_group(self, label: str, j: int) -> str:
        """Build the benchmark group name for a quantity and a test case.

        Parameters
        ----------
        label : str
            Name of the benchmarked quantity.
        j : int
            Index of the test case.

        Returns
        -------
        str
            Name of the form ``<target dir name>_<label>:case_<jj>``.

        """
        # `Path.name` is the last path component (e.g. "physics_monochromatic").
        # `str(path).split()` would split on whitespace and keep the full path.
        return f"{self.target_dir.name}_{label}:case_{j:02d}"

    def _flatten_and_squeeze(
        self, array: np.ndarray, n_dims_to_keep: int = 0
    ) -> np.ndarray:
        """Wrapper over the module-level function, setting the size parameter.

        Parameters
        ----------
        array : np.ndarray
            Array to be reshaped.
        n_dims_to_keep : int
            Number of leading dimensions to preserve.

        Returns
        -------
        np.ndarray
            The array with its middle dimensions contracted to `n_cases`.

        """
        return _flatten_and_squeeze(array, n_dims_to_keep, self.n_cases)

    def _load(self, filename: str) -> np.ndarray:
        """Load an array from the class target directory.

        The argument is expected to be an NPY file.

        Parameters
        ----------
        filename : str
            File under `self.target_dir`.

        Returns
        -------
        np.ndarray

        """
        return np.load(self.target_dir.joinpath(filename))

    @pytest.fixture(scope="class")
    def x_axes(self) -> np.ndarray:
        """X-axes over which to compute displacement and curvature.

        Returns
        -------
        np.ndarray
            One axis per test case, indexed by `j`.

        """
        return self._flatten_and_squeeze(self._load("x.npy"))

    @pytest.fixture(scope="class")
    def floe_params_all(self) -> np.ndarray:
        """Floe parameters to instantiate physical handlers.

        Returns
        -------
        np.ndarray
            One set of floe parameters per test case, indexed by `j`.

        """
        return self._flatten_and_squeeze(self._load("floe_params.npy"))

    @abc.abstractmethod
    def wave_params_all(self) -> list[tuple[np.ndarray, np.ndarray]]:
        """Wave parameters to instantiate physical handlers.

        Concrete subclasses implement this as a class-scoped fixture.

        Returns
        -------
        list[tuple[np.ndarray, np.ndarray]]
            One (amplitudes, wavenumbers) pair per test case.

        """
        ...

    @abc.abstractmethod
    def growth_params_all(self) -> list[tuple[np.ndarray, float | np.ndarray]]:
        """Wave-growth parameters to instantiate physical handlers.

        Concrete subclasses implement this as a class-scoped fixture. The
        monochromatic implementation builds (mean, std) pairs; the second
        element is a float there and an array in the polychromatic case.

        Returns
        -------
        list[tuple[np.ndarray, float | np.ndarray]]
            One pair per test case.

        """
        ...

    @pytest.fixture(scope="class")
    def displacements(self) -> np.ndarray:
        """Vertical displacement targets.

        Returns
        -------
        np.ndarray

        Notes
        -----
        Shape: (2, n_cases, len(x_axes)). The first dimension is for analytical
        solution (0) or numerical solution (1).

        """
        return self._flatten_and_squeeze(self._load("displacements.npy"), 1)

    @pytest.fixture(scope="class")
    def displacements_growth(self) -> np.ndarray:
        """Vertical displacement targets, with wave growth.

        Returns
        -------
        np.ndarray

        Notes
        -----
        Shape: (n_cases, len(x_axes)).

        """
        return self._flatten_and_squeeze(self._load("displacements_growth.npy"))

    @pytest.fixture(scope="class")
    def curvatures(self) -> np.ndarray:
        """Curvature targets.

        Returns
        -------
        np.ndarray

        Notes
        -----
        Shape: (2, n_cases, len(x_axes)). The first dimension is for analytical
        solution (0) or numerical solution (1).

        """
        return self._flatten_and_squeeze(self._load("curvatures.npy"), 1)

    @pytest.fixture(scope="class")
    def curvatures_growth(self) -> np.ndarray:
        """Curvature targets, with wave growth.

        Returns
        -------
        np.ndarray

        Notes
        -----
        Shape: (n_cases, len(x_axes)).

        """
        return self._flatten_and_squeeze(self._load("curvatures_growth.npy"))

    @pytest.fixture(scope="class")
    def energies(self) -> np.ndarray:
        """Energy targets.

        Returns
        -------
        np.ndarray

        Notes
        -----
        Shape: (4, n_cases). The first dimension is for analytical solution (0)
        or numerical solution: pseudo_an (1), tanhsinh (2), quad (3).

        """
        return self._flatten_and_squeeze(self._load("energies.npy"), 1)

    @pytest.fixture(scope="class")
    def energies_growth(self) -> np.ndarray:
        """Energy targets, with wave growth.

        Returns
        -------
        np.ndarray

        Notes
        -----
        Shape: (3, n_cases). The first dimension corresponds to an entry in
        INTEGRATION_METHODS.

        """
        return self._flatten_and_squeeze(self._load("energies_growth.npy"), 1)

    def test_dimensions(
        self,
        x_axes: np.ndarray,
        floe_params_all: np.ndarray,
        wave_params_all: list[tuple[np.ndarray, np.ndarray]],
        growth_params_all: list[tuple[np.ndarray, float | np.ndarray]],
        displacements: np.ndarray,
        displacements_growth: np.ndarray,
        curvatures: np.ndarray,
        curvatures_growth: np.ndarray,
        energies: np.ndarray,
        energies_growth: np.ndarray,
    ):
        """Check that the dimensions match the expected number of cases.

        Parameters
        ----------
        x_axes : np.ndarray
        floe_params_all : np.ndarray
        wave_params_all : list[tuple[np.ndarray, np.ndarray]]
        growth_params_all : list[tuple[np.ndarray, float | np.ndarray]]
        displacements : np.ndarray
        displacements_growth : np.ndarray
        curvatures : np.ndarray
        curvatures_growth : np.ndarray
        energies : np.ndarray
        energies_growth : np.ndarray

        """
        # Inputs: one entry per test case.
        for arr in (x_axes, floe_params_all, wave_params_all, growth_params_all):
            assert len(arr) == self.n_cases
        # Energy targets: (n_methods, n_cases).
        for arr in (energies, energies_growth):
            assert arr.shape[1] == self.n_cases
        # Local targets: (..., n_cases, len(x_axes)).
        for arr in (
            displacements,
            displacements_growth,
            curvatures,
            curvatures_growth,
        ):
            assert arr.shape[-2] == self.n_cases
            assert arr.shape[-1] == x_axes.shape[-1]

    @pytest.mark.parametrize("an_sol", (True, False))
    @pytest.mark.parametrize(
        "handler_type, target_name",
        (
            (ph.DisplacementHandler, "displacements"),
            (ph.CurvatureHandler, "curvatures"),
        ),
    )
    def test_local(
        self,
        request: pytest.FixtureRequest,
        x_axes: np.ndarray,
        floe_params_all: np.ndarray,
        wave_params_all: list[tuple[np.ndarray, np.ndarray]],
        handler_type: type[ph.DisplacementHandler] | type[ph.CurvatureHandler],
        target_name: str,
        an_sol: bool,
        j: int,
        benchmark: BenchmarkFixture,
    ):
        """Compare local quantities (displacement, curvature) to targets.

        Parameters
        ----------
        request : pytest.FixtureRequest
        x_axes : np.ndarray
        floe_params_all : np.ndarray
        wave_params_all : list[tuple[np.ndarray, np.ndarray]]
        handler_type : type[ph.DisplacementHandler] | type[ph.CurvatureHandler]
            The type of handler to use.
        target_name : str
            The name of the fixture providing the target.
        an_sol : bool
            Whether to use the analytical solution formulation.
        j : int
            Index of the test case.
        benchmark : BenchmarkFixture

        """
        benchmark.group = self._benchmark_group(target_name, j)
        # The first dimension of the target has size 2.
        # The first entry (i := 0) corresponds to the analytical solution, the
        # second entry (i := 1) to the numerical solution.
        i = 0 if an_sol else 1
        # Pytest magic, get fixture by name as fixtures cannot be used directly
        # in parametrize.
        target = request.getfixturevalue(target_name)
        x = x_axes[j]
        floe_params = floe_params_all[j]
        wave_params = wave_params_all[j]
        handler = handler_type(floe_params, wave_params)

        computed = benchmark(handler.compute, x, an_sol=an_sol)

        assert np.allclose(computed, target[i, j])

    @pytest.mark.parametrize(
        "handler_type, target_name",
        (
            (ph.DisplacementHandler, "displacements_growth"),
            (ph.CurvatureHandler, "curvatures_growth"),
        ),
    )
    def test_local_with_growth(
        self,
        request: pytest.FixtureRequest,
        x_axes: np.ndarray,
        floe_params_all: np.ndarray,
        wave_params_all: list[tuple[np.ndarray, np.ndarray]],
        growth_params_all: list[tuple[np.ndarray, float | np.ndarray]],
        handler_type: type[ph.DisplacementHandler] | type[ph.CurvatureHandler],
        target_name: str,
        j: int,
    ):
        """Compare local quantities (displacement, curvature), with growth.

        Parameters
        ----------
        request : pytest.FixtureRequest
        x_axes : np.ndarray
        floe_params_all : np.ndarray
        wave_params_all : list[tuple[np.ndarray, np.ndarray]]
        growth_params_all : list[tuple[np.ndarray, float | np.ndarray]]
        handler_type : type[ph.DisplacementHandler] | type[ph.CurvatureHandler]
            The type of handler to use.
        target_name : str
            The name of the fixture providing the target.
        j : int
            Index of the test case.

        """
        # Pytest magic, get fixture by name as fixtures cannot be used directly
        # in parametrize.
        target = request.getfixturevalue(target_name)
        x = x_axes[j]
        floe_params = floe_params_all[j]
        wave_params = wave_params_all[j]
        growth_params = growth_params_all[j]
        handler = handler_type(floe_params, wave_params, growth_params)

        computed = handler.compute(x)

        assert np.allclose(computed, target[j])

    @pytest.mark.parametrize("integration_method", (None, *INTEGRATION_METHODS))
    @pytest.mark.filterwarnings("ignore::scipy.integrate.IntegrationWarning")
    def test_energy(
        self,
        floe_params_all: np.ndarray,
        wave_params_all: list[tuple[np.ndarray, np.ndarray]],
        energies: np.ndarray,
        integration_method: str | None,
        j: int,
        benchmark: BenchmarkFixture,
    ):
        """Compare energy to target.

        Parameters
        ----------
        floe_params_all : np.ndarray
        wave_params_all : list[tuple[np.ndarray, np.ndarray]]
        energies : np.ndarray
        integration_method : str | None
            Which integration method to use. If None, compute the analytical
            solution.
        j : int
            Index of the test case.
        benchmark : BenchmarkFixture

        Warns
        -----
        Four cases (j in {7, 15, 26, 29}) raise an IntegrationWarning when
        using the quad method. This is expected and inconsequential; the
        issue ("[...] The error may be underestimated.") happens when
        generating the test cases, and the accuracy (when compared to other
        methods) is correct. We thus filter the warnings when running the test
        to avoid clutter.

        """
        benchmark.group = self._benchmark_group("energy", j)
        floe_params = floe_params_all[j]
        wave_params = wave_params_all[j]
        handler = ph.EnergyHandler(floe_params, wave_params)
        # Row 0 of the target holds the analytical solution; rows 1.. follow
        # the order of INTEGRATION_METHODS.
        if integration_method is None:
            an_sol = True
            i = 0
        else:
            an_sol = False
            i = 1 + INTEGRATION_METHODS.index(integration_method)

        computed = benchmark(
            handler.compute, an_sol=an_sol, integration_method=integration_method
        )

        assert np.allclose(computed, energies[i, j])

    @pytest.mark.parametrize("integration_method", INTEGRATION_METHODS)
    @pytest.mark.filterwarnings("ignore::scipy.integrate.IntegrationWarning")
    def test_energy_with_growth(
        self,
        floe_params_all: np.ndarray,
        wave_params_all: list[tuple[np.ndarray, np.ndarray]],
        growth_params_all: list[tuple[np.ndarray, float | np.ndarray]],
        energies_growth: np.ndarray,
        integration_method: str,
        j: int,
        benchmark: BenchmarkFixture,
    ):
        """Compare energy, with wave growth, to target.

        Parameters
        ----------
        floe_params_all : np.ndarray
        wave_params_all : list[tuple[np.ndarray, np.ndarray]]
        growth_params_all : list[tuple[np.ndarray, float | np.ndarray]]
        energies_growth : np.ndarray
        integration_method : str
            Which integration method to use.
        j : int
            Index of the test case.
        benchmark : BenchmarkFixture

        Warns
        -----
        Two cases (j in {7, 29}) raise an IntegrationWarning when
        using the quad method, for the polychromatic case. This is expected and
        inconsequential; the issue ("[...] The error may be underestimated.")
        happens when generating the test cases, and the accuracy (when compared
        to other methods) is correct. We thus filter the warnings when running
        the test to avoid clutter.

        """
        benchmark.group = self._benchmark_group("energy_with_growth", j)
        # Rows of the target follow the order of INTEGRATION_METHODS.
        i_im = INTEGRATION_METHODS.index(integration_method)
        floe_params = floe_params_all[j]
        wave_params = wave_params_all[j]
        growth_params = growth_params_all[j]
        handler = ph.EnergyHandler(floe_params, wave_params, growth_params)

        computed = benchmark(handler.compute, integration_method=integration_method)

        assert np.allclose(computed, energies_growth[i_im, j])
class TestPhysicsMono(_TestPhysics):
    """Class dealing with monochromatic physics."""

    target_dir = TARGET_DIR_MONO
    n_cases = N_CASES_MONO

    @pytest.fixture(scope="class")
    def wave_params_all(self) -> list[tuple[np.ndarray, np.ndarray]]:
        """Wave parameters, one (amplitude, wavenumber) pair per test case.

        Returns
        -------
        list[tuple[np.ndarray, np.ndarray]]

        """
        # Materialise the zip iterator: the result is shared by several tests
        # through the class-scoped fixture, and must support indexing.
        return list(
            zip(self._load("c_amplitudes.npy"), self._load("c_wavenumbers.npy"))
        )

    @pytest.fixture(scope="class")
    def growth_params_all(self) -> list[tuple[np.ndarray, float]]:
        """Wave-growth parameters, one (mean, std) pair per test case.

        Returns
        -------
        list[tuple[np.ndarray, float]]
            The mean is promoted to a 2D array; the std is kept as a scalar.

        """
        arr = self._load("growth_params.npy")
        # Column 0 holds the means, column 1 the standard deviations.
        means = arr[:, 0]
        stds = arr[:, 1]
        return [(np.atleast_2d(mean), std) for mean, std in zip(means, stds)]
class TestPhysicsPoly(_TestPhysics):
    """Class dealing with polychromatic physics."""

    target_dir = TARGET_DIR_POLY
    n_cases = N_CASES_POLY

    @pytest.fixture(scope="class")
    def wave_params_all(self) -> list[tuple[np.ndarray, np.ndarray]]:
        """Wave parameters, one (amplitudes, wavenumbers) pair per test case.

        The NPZ archive is keyed by number of frequencies; each entry stacks
        N_TRIES draws whose column 0 holds the complex amplitudes and column 1
        the complex wavenumbers.

        Returns
        -------
        list[tuple[np.ndarray, np.ndarray]]
            Flattened in archive order: all draws of the first key, then all
            draws of the second key, and so on.

        """
        flat_list: list[tuple[np.ndarray, np.ndarray]] = []
        path = self.target_dir.joinpath("wave_params.npz")
        # An NpzFile keeps the archive open until closed; arrays are read into
        # memory on access, so they remain valid after the `with` block.
        with np.load(path) as wave_params:
            # Check we do have the expected number of different numbers of
            # frequencies.
            assert len(wave_params) == N_N_FREQS
            for key, tries in wave_params.items():
                nfreqs = int(key)
                assert len(tries) == N_TRIES
                for draw in tries:
                    amplitudes, wavenumbers = draw[:, 0], draw[:, 1]
                    # Sanity check: the number of frequencies matches the key.
                    assert len(amplitudes) == nfreqs
                    assert len(wavenumbers) == nfreqs
                    flat_list.append((amplitudes, wavenumbers))
        return flat_list

    @pytest.fixture(scope="class")
    def growth_params_all(self) -> list[tuple[np.ndarray, np.ndarray]]:
        """Wave-growth parameters, one pair of column vectors per test case.

        The NPZ archive is keyed by number of frequencies; each entry stacks
        N_TRIES draws with two columns.

        Returns
        -------
        list[tuple[np.ndarray, np.ndarray]]
            Each array has shape (nfreqs, 1). Flattened in archive order.

        """
        flat_list: list[tuple[np.ndarray, np.ndarray]] = []
        path = self.target_dir.joinpath("growth_params.npz")
        # Close the archive once every array has been read into memory.
        with np.load(path) as growth_params:
            assert len(growth_params) == N_N_FREQS
            for key, tries in growth_params.items():
                nfreqs = int(key)
                assert len(tries) == N_TRIES
                for draw in tries:
                    # Keep a trailing axis of size 1, giving (nfreqs, 1).
                    pair = (draw[:, 0, None], draw[:, 1, None])
                    # Sanity check: the number of frequencies matches the key.
                    for column in pair:
                        assert column.shape == (nfreqs, 1)
                    flat_list.append(pair)
        return flat_list