Coverage for tests/test_analytical_energy.py: 100%
63 statements
« prev ^ index » next coverage.py v7.9.1, created at 2025-07-18 14:20 +0200
« prev ^ index » next coverage.py v7.9.1, created at 2025-07-18 14:20 +0200
1from collections.abc import Callable
2import pathlib
4import numpy as np
5import pytest
7import swiift.lib.physics as ph
# TODO: turn these loops into parametrized tests / fixtures instead of looping
# inside individual test functions.

# Test configurations visually examined against solution from scipy.solve_bvp.
# The reference data live next to this file ("tests/target/..."), so resolve
# them from __file__ rather than from the current working directory: the suite
# then finds its data whatever directory pytest is launched from.
_TARGET_DIR = pathlib.Path(__file__).resolve().parent / "target"
PATH_DIS = _TARGET_DIR / "displacement"
PATH_CUR = _TARGET_DIR / "curvature"
PATH_EGY = _TARGET_DIR / "energy"
PATH_PLY = _TARGET_DIR / "poly_analytical"
def format_to_pack(
    red_num, length, wave_params_real
) -> tuple[tuple[float, float], tuple[np.ndarray, np.ndarray]]:
    """Pack raw real-valued parameters into the tuples the handlers expect.

    Parameters
    ----------
    red_num, length
        Floe parameters; returned unchanged as a 2-tuple.
    wave_params_real
        Indexable with at least four entries (scalars, or rows of a 2-D
        array). Entries 0 and 3 are combined as modulus and phase into a
        complex value. Entries 1 and 2 become its real and imaginary parts
        of a second complex value. (Presumably amplitude and
        wavenumber + i * attenuation; confirm against swiift.lib.physics.)

    Returns
    -------
    floe_params : tuple
        ``(red_num, length)``.
    wave_params : tuple of np.ndarray
        Two complex arrays, each at least 1-D.
    """
    floe_params = red_num, length
    complex_amplitude = wave_params_real[0] * np.exp(1j * wave_params_real[3])
    complex_wavenumber = wave_params_real[1] + 1j * wave_params_real[2]
    # The handlers work on arrays of wave components, so promote scalars.
    wave_params = (
        np.atleast_1d(complex_amplitude),
        np.atleast_1d(complex_wavenumber),
    )
    return floe_params, wave_params
def read_header(handle: pathlib.Path) -> tuple[float, float, list[float]]:
    """Parse the commented first line of a reference ``*ssv`` file.

    The first line is expected to be a ``#`` comment marker followed by
    comma-separated numbers, e.g. ``# red_num,length,w0,w1,...``.

    Returns
    -------
    red_num, length : float
        The first two header values.
    wave_params_real : list of float
        All remaining header values, in file order.
    """
    with open(handle, "r") as file:
        header = file.readline()
    # Strip the leading '#' and surrounding whitespace, including the newline.
    # The newline is absent when the header is the file's last line. Slicing
    # fixed offsets ([2:-1]) would then drop a digit of the last value.
    payload = header.lstrip("#").strip()
    red_num, length, *wave_params_real = map(float, payload.split(","))
    return red_num, length, wave_params_real
43def _test_analytical(root_dir: pathlib.Path, func: Callable):
44 sentinel = 0 # make sure no error in path and at least one test was run 1cd
45 for handle in root_dir.glob("*ssv"): 1cd
46 sentinel += 1 1cd
47 loaded = np.loadtxt(handle) 1cd
48 # loaded[0]: along-floe space variable x
49 # loaded[1]: reference values for func(x)
50 floe_params, wave_params = format_to_pack(*read_header(handle)) 1cd
51 handler = func(floe_params, wave_params) 1cd
53 # test func(x) against existing displacement
54 assert np.allclose(loaded[1] - handler.compute(loaded[0]), 0) 1cd
55 assert sentinel > 0 1cd
def test_displacement():
    """Check ``DisplacementHandler`` against the reference files in ``PATH_DIS``."""
    _test_analytical(root_dir=PATH_DIS, func=ph.DisplacementHandler)
62# @pytest.mark.parametrize
63# def test_displacement_wuf():
64# # TODO:
65# # * instantiate wui from parameters
66# # * instantiate wuf from parameters and wui
67# # * call `displacement`
68# # * profit
69# pass
def test_curvature():
    """Check ``CurvatureHandler`` against the reference files in ``PATH_CUR``."""
    _test_analytical(root_dir=PATH_CUR, func=ph.CurvatureHandler)
def test_dce_poly():
    """Compare displacement, curvature and energy with the references in ``PATH_PLY``.

    Each sub-directory of ``PATH_PLY`` is one case, holding:

    * ``values.ssv``: three rows, i.e. x, reference displacement and
      reference curvature;
    * ``energy``: the reference energy (a single value);
    * ``floe_params.ssv``: the two floe values unpacked into ``format_to_pack``;
    * ``wave_params.ssv``: a 2-D array of real wave parameters.
    """
    # Only directories are cases. pathlib's "*" also matches plain files,
    # including dotfiles, which would break the joinpath() loads below.
    case_dirs = sorted(path for path in PATH_PLY.glob("*") if path.is_dir())
    # NOTE(review): this requires at least *two* cases, unlike the "> 0"
    # guards elsewhere in this file; confirm that is intentional.
    assert len(case_dirs) > 1, f"expected several cases in {PATH_PLY}"
    for case_dir in case_dirs:
        x, dis, cur = np.loadtxt(case_dir.joinpath("values.ssv"))
        egy = float(np.loadtxt(case_dir.joinpath("energy")))
        floe_params_raw = np.loadtxt(case_dir.joinpath("floe_params.ssv"))
        wave_params_real = np.loadtxt(case_dir.joinpath("wave_params.ssv"))
        assert wave_params_real.ndim == 2
        floe_params, wave_params = format_to_pack(*floe_params_raw, wave_params_real)
        displacement = ph.DisplacementHandler(floe_params, wave_params).compute
        curvature = ph.CurvatureHandler(floe_params, wave_params).compute
        energy = ph.EnergyHandler(floe_params, wave_params).compute

        # NOTE(review): these calls pass floe_params/wave_params again to
        # ``compute``, although the handlers were built from them. The other
        # tests call ``compute(x)`` / ``compute()``. Confirm what these extra
        # positional arguments bind to in swiift.lib.physics.
        _test_poly(dis, displacement, x, floe_params, wave_params)
        _test_poly(cur, curvature, x, floe_params, wave_params)
        _test_poly(egy, energy, floe_params, wave_params)
97def _test_poly(ref_val, function, *args):
98 assert np.allclose(ref_val - function(*args), 0) 1a
def test_energy():
    """Compare ``EnergyHandler.compute()`` with the cases in ``energy_mono.ssv``.

    Each column of the file is one case: ``red_num``, ``length``, the real
    wave parameters passed to ``format_to_pack``, and, last, the reference
    energy.
    """
    # ndmin=2 keeps a single-case (single-column) file 2-D. Otherwise
    # np.loadtxt squeezes it to 1-D, and ``.T`` would yield bare scalars.
    loaded = np.loadtxt(PATH_EGY.joinpath("energy_mono.ssv"), ndmin=2)
    # An empty file must not make the test pass vacuously.
    assert loaded.size > 0, "no energy reference cases found"
    for case in loaded.T:
        red_num, length, *wave_params_real = case[:-1]
        reference = case[-1]
        floe_params, wave_params = format_to_pack(red_num, length, wave_params_real)
        handler = ph.EnergyHandler(floe_params, wave_params)
        # Compare the values directly, so np.isclose's rtol is applied.
        assert np.isclose(handler.compute(), reference)
@pytest.mark.filterwarnings("error::RuntimeWarning")
def test_energy_no_attenuation():
    """Computing the energy must not emit any RuntimeWarning.

    The second wave-parameter array has zero imaginary parts, which is the
    "no attenuation" case this test is named after. The ``filterwarnings``
    mark turns any RuntimeWarning raised during ``compute()`` into a failure.
    """
    amplitudes = np.array([0.14808142 + 0.34891663j, 0.08581965 + 0.54191726j])
    wavenumbers = np.array([0.02674772, 0.09422177], dtype=complex)
    handler = ph.EnergyHandler((0.34, 126.12), (amplitudes, wavenumbers))
    handler.compute()