Coverage for src/flexfrac1d/api/api.py: 100%

33 statements  

« prev     ^ index     » next       coverage.py v7.4.1, created at 2024-08-30 14:00 +0200

1from __future__ import annotations 

2 

3from collections import namedtuple 

4from collections.abc import Sequence 

5 

6import attrs 

7 

8from ..lib import att 

9from ..model import frac_handlers as fh, model as md 

10 

# Immutable snapshot record stored in ``Experiment.history``: ``subdomains``
# is a tuple of copied subdomains and ``growth_params`` is either a copied
# growth-parameter pair or None (see ``Experiment.save_step``).
# TODO: make into an attrs class for more flexibility (repr of subdomains)
Step = namedtuple("Step", "subdomains growth_params")

13 

14 

@attrs.define
class Experiment:
    """Time-stepped driver around an ``md.Domain`` that records snapshots.

    Every call to :meth:`add_floes` or :meth:`step` stores a :class:`Step`
    snapshot in :attr:`history`, keyed by the value of :attr:`time` at which
    it was taken.

    Attributes
    ----------
    time : float
        Current simulation time; advanced by :meth:`step`.
    domain : md.Domain
        The domain being evolved.
    history : dict[float, Step]
        Snapshots keyed by time, in insertion order. Not an ``__init__``
        argument and excluded from ``repr``.
    fracture_handler : fh._FractureHandler
        Passed to ``Domain.breakup`` on every step; defaults to
        ``fh.BinaryFracture()``.
    """

    time: float
    domain: md.Domain
    history: dict[float, Step] = attrs.field(init=False, factory=dict, repr=False)
    fracture_handler: fh._FractureHandler = attrs.field(factory=fh.BinaryFracture)

    @classmethod
    def from_discrete(
        cls,
        gravity: float,
        spectrum: md.DiscreteSpectrum,
        ocean: md.Ocean,
        growth_params: tuple | None = None,
        fracture_handler: fh._FractureHandler | None = None,
        attenuation_spec: att.Attenuation | None = None,
    ) -> Experiment:
        """Build an experiment starting at time 0 from a discrete spectrum.

        Parameters
        ----------
        gravity, spectrum, ocean, growth_params
            Forwarded to ``md.Domain.from_discrete``.
        fracture_handler
            Fracture handler of the experiment. When ``None``, the class
            default (``fh.BinaryFracture()``) is used.
        attenuation_spec
            Forwarded to ``md.Domain.from_discrete``. When ``None``,
            ``att.AttenuationParameterisation(1)`` is used.

        Returns
        -------
        Experiment
            A new experiment with ``time == 0`` and an empty history.
        """
        if attenuation_spec is None:
            attenuation_spec = att.AttenuationParameterisation(1)
        domain = md.Domain.from_discrete(
            gravity, spectrum, ocean, attenuation_spec, growth_params
        )

        # Omit the argument rather than passing None so that the attrs
        # field factory supplies the default handler.
        if fracture_handler is None:
            return cls(0, domain)
        return cls(0, domain, fracture_handler)

    def add_floes(self, floes: md.Floe | Sequence[md.Floe]) -> None:
        """Add one or several floes to the domain, then save a snapshot.

        The snapshot is stored under the current :attr:`time`, replacing
        any snapshot already recorded at that time.
        """
        self.domain.add_floes(floes)
        self.save_step()

    def last_step(self) -> Step:
        """Return the snapshot stored under the last key of :attr:`history`.

        Raises
        ------
        LookupError
            If no snapshot has been saved yet.
        """
        # Guard explicitly: ``next`` on an empty iterator raises a bare
        # StopIteration. That exception silently ends an enclosing
        # ``map``/generator expression, or becomes RuntimeError inside a
        # generator (PEP 479), instead of reporting the empty history.
        if not self.history:
            raise LookupError("no step has been saved yet")
        return next(reversed(self.history.values()))

    def save_step(self) -> None:
        """Record a snapshot of the domain under the current :attr:`time`.

        Each subdomain is duplicated with its ``make_copy`` method. If the
        domain has ``growth_params``, the first element is duplicated with
        ``.copy()`` and the second is stored as-is (not copied). Otherwise
        ``None`` is stored. An existing snapshot at the same time is
        overwritten.
        """
        subdomains = tuple(wuf.make_copy() for wuf in self.domain.subdomains)
        growth_params = self.domain.growth_params
        if growth_params is not None:
            growth_params = (growth_params[0].copy(), growth_params[1])
        self.history[self.time] = Step(subdomains, growth_params)

    def step(self, delta_time: float, an_sol=None, num_params=None) -> None:
        """Advance the experiment by ``delta_time`` and save a snapshot.

        Runs ``domain.breakup`` with :attr:`fracture_handler` and the
        optional ``an_sol``/``num_params`` (forwarded unchanged), then
        ``domain.iterate(delta_time)``. It then adds ``delta_time`` to
        :attr:`time` and calls :meth:`save_step`.
        """
        self.domain.breakup(self.fracture_handler, an_sol, num_params)
        self.domain.iterate(delta_time)
        self.time += delta_time
        self.save_step()