Coverage for tests/test_analytical_energy.py: 100%

63 statements  

« prev     ^ index     » next       coverage.py v7.4.1, created at 2024-08-30 14:00 +0200

1from collections.abc import Callable 

2import pathlib 

3 

4import numpy as np 

5import pytest 

6 

7import flexfrac1d.lib.physics as ph 

8 

9# TODO: use pytest fixtures/parametrization for these tests instead of running loops in individual functions

10 

# Test configurations visually examined against solution from scipy.solve_bvp.
# The reference data sits in a `target` directory next to this test file, so
# resolve it relative to the file rather than to the current working
# directory: the tests then find their data whichever directory pytest is
# launched from.
_TARGET_DIR = pathlib.Path(__file__).resolve().parent / "target"
PATH_DIS = _TARGET_DIR / "displacement"
PATH_CUR = _TARGET_DIR / "curvature"
PATH_EGY = _TARGET_DIR / "energy"
PATH_PLY = _TARGET_DIR / "poly_analytical"

16 

17 

def format_to_pack(
    red_num: float, length: float, wave_params_real
) -> tuple[tuple[float, float], tuple[np.ndarray, np.ndarray]]:
    """Pack raw real-valued parameters into the tuples the handlers take.

    Parameters
    ----------
    red_num, length : float
        Floe parameters, passed through unchanged as ``(red_num, length)``.
    wave_params_real : sequence
        Indexable container whose first four entries (scalars, or
        equal-length arrays when it is a 2-D array) are combined as
        ``[0] * exp(1j * [3])`` and ``[1] + 1j * [2]``.
        NOTE(review): presumably amplitude, wavenumber, attenuation and
        phase respectively -- confirm against ``flexfrac1d.lib.physics``.

    Returns
    -------
    floe_params : tuple of float
        ``(red_num, length)``.
    wave_params : tuple of numpy.ndarray
        Two at-least-1-D complex arrays,
        ``([0] * exp(1j * [3]), [1] + 1j * [2])``.
    """
    floe_params = red_num, length
    # Magnitude/phase pair -> complex coefficient.
    complex_coef = wave_params_real[0] * np.exp(1j * wave_params_real[3])
    # Real/imaginary pair -> complex number.
    complex_num = wave_params_real[1] + 1j * wave_params_real[2]
    # Promote scalar inputs to shape (1,) so both fields are always arrays.
    wave_params = (np.atleast_1d(complex_coef), np.atleast_1d(complex_num))
    return floe_params, wave_params

33 

34 

def read_header(handle: pathlib.Path) -> tuple[float, float, list[float]]:
    """Parse the comment header on the first line of a reference file.

    The first line is expected to look like ``# v0,v1,v2,...``: a ``#``
    comment marker followed by comma-separated floats.

    Parameters
    ----------
    handle : pathlib.Path
        Path of the reference file.

    Returns
    -------
    red_num : float
        First value on the header line.
    length : float
        Second value on the header line.
    wave_params_real : list of float
        All remaining values, in file order.
    """
    with open(handle, "r") as file:
        header = file.readline()
    # Strip the leading comment marker only. Do not chop the last character
    # blindly: a header on the final line may lack a trailing newline, and
    # ``float`` already ignores surrounding whitespace (including "\n").
    fields = header.lstrip("#").split(",")
    red_num, length, *wave_params_real = map(float, fields)
    return red_num, length, wave_params_real

41 

42 

def _test_analytical(root_dir: pathlib.Path, func: Callable):
    """Check a handler class against every reference file in ``root_dir``.

    Each ``*ssv`` file carries its parameters on a ``#`` header line (parsed
    by ``read_header``). Its first data row holds the along-floe coordinates
    ``x``, and its second row holds the reference values at ``x``.

    Parameters
    ----------
    root_dir : pathlib.Path
        Directory holding the reference files.
    func : Callable
        Handler class (e.g. ``ph.DisplacementHandler``), constructed as
        ``func(floe_params, wave_params)`` and evaluated via ``.compute(x)``.
    """
    handles = sorted(root_dir.glob("*ssv"))
    # Guard against a wrong path silently turning the test into a no-op.
    assert handles, f"no reference files found in {root_dir}"
    for handle in handles:
        loaded = np.loadtxt(handle)
        x, reference = loaded[0], loaded[1]
        floe_params, wave_params = format_to_pack(*read_header(handle))
        handler = func(floe_params, wave_params)

        # allclose(diff, 0) is a purely absolute check (|diff| <= 1e-8):
        # the relative tolerance is multiplied by the zero reference.
        residual = reference - handler.compute(x)
        assert np.allclose(residual, 0), f"mismatch against {handle}"

56 

57 

def test_displacement():
    """Displacement handler reproduces the stored displacement references."""
    _test_analytical(root_dir=PATH_DIS, func=ph.DisplacementHandler)

60 

61 

62# @pytest.mark.parametrize 

63# def test_displacement_wuf(): 

64# # TODO: 

65# # * instantiate wui from parameters 

66# # * instantiate wuf from parameters and wui 

67# # * call `displacement` 

68# # * profit 

69# pass 

70 

71 

def test_curvature():
    """Curvature handler reproduces the stored curvature references."""
    _test_analytical(root_dir=PATH_CUR, func=ph.CurvatureHandler)

74 

75 

def test_dce_poly():
    """Check displacement, curvature and energy on the polychromatic cases.

    Each sub-directory of ``PATH_PLY`` is one case containing:

    * ``values.ssv``: three rows -- ``x``, reference displacement and
      reference curvature at ``x``;
    * ``energy``: the reference energy (a single value);
    * ``floe_params.ssv``: the ``red_num`` and ``length`` arguments of
      ``format_to_pack``;
    * ``wave_params.ssv``: a 2-D array whose rows are the real wave
      parameters consumed by ``format_to_pack``.
    """
    # Only directories are cases. Stray files (READMEs, dotfiles, OS
    # droppings) would otherwise crash the ``joinpath(...)`` loading below.
    case_dirs = sorted(path for path in PATH_PLY.glob("*") if path.is_dir())
    # NOTE(review): this requires at least two cases, stricter than the
    # ``> 0`` check in ``_test_analytical`` -- confirm this is intended.
    assert len(case_dirs) > 1, f"expected at least two cases in {PATH_PLY}"
    for case_dir in case_dirs:
        x, dis, cur = np.loadtxt(case_dir.joinpath("values.ssv"))
        egy = float(np.loadtxt(case_dir.joinpath("energy")))
        floe_params_raw = np.loadtxt(case_dir.joinpath("floe_params.ssv"))
        wave_params_real = np.loadtxt(case_dir.joinpath("wave_params.ssv"))
        assert wave_params_real.ndim == 2
        floe_params, wave_params = format_to_pack(
            *floe_params_raw, wave_params_real
        )
        displacement = ph.DisplacementHandler(floe_params, wave_params).compute
        curvature = ph.CurvatureHandler(floe_params, wave_params).compute
        energy = ph.EnergyHandler(floe_params, wave_params).compute

        # NOTE(review): elsewhere in this file ``compute`` is called with only
        # ``x`` (or with no argument for energy). The extra
        # ``floe_params, wave_params`` arguments here may be a leftover from
        # an older function-based API -- confirm against the handler
        # signatures.
        _test_poly(dis, displacement, x, floe_params, wave_params)
        _test_poly(cur, curvature, x, floe_params, wave_params)
        _test_poly(egy, energy, floe_params, wave_params)

95 

96 

97def _test_poly(ref_val, function, *args): 

98 assert np.allclose(ref_val - function(*args), 0) 

99 

100 

def test_energy():
    """Check the energy handler against tabulated monochromatic cases.

    ``energy_mono.ssv`` stores one case per column. All rows but the last are
    the ``red_num, length, *wave_params_real`` inputs of ``format_to_pack``;
    the last row is the reference energy.
    """
    # ndmin=2 keeps a single-case (single-column) file two-dimensional, so
    # iterating over the transpose still yields whole columns instead of
    # scalars.
    table = np.loadtxt(PATH_EGY.joinpath("energy_mono.ssv"), ndmin=2)
    # Loop variable named ``case`` rather than ``vars`` to avoid shadowing
    # the builtin.
    for case_idx, case in enumerate(table.T):
        red_num, length, *wave_params_real = case[:-1]
        reference = case[-1]
        floe_params, wave_params = format_to_pack(red_num, length, wave_params_real)
        handler = ph.EnergyHandler(floe_params, wave_params)
        assert np.isclose(reference - handler.compute(), 0), (
            f"energy mismatch for column {case_idx}"
        )

108 

109 

@pytest.mark.filterwarnings("error::RuntimeWarning")
def test_energy_no_attenuation():
    """Energy computation runs without RuntimeWarning when imaginary parts are zero.

    The ``filterwarnings`` mark turns any RuntimeWarning into an error, so
    this only checks that ``compute`` completes cleanly. Its return value is
    not inspected.
    """
    red_num, length = 0.34, 126.12
    # NOTE(review): presumably complex amplitudes and wavenumbers; the zero
    # imaginary parts of the second array match the "no attenuation" name.
    wave_amp = np.array([0.14808142 + 0.34891663j, 0.08581965 + 0.54191726j])
    wave_num = np.array([0.02674772 + 0.0j, 0.09422177 + 0.0j])
    handler = ph.EnergyHandler((red_num, length), (wave_amp, wave_num))
    handler.compute()