Coverage for src/swiift/model/model.py: 43%

282 statements  

« prev     ^ index     » next       coverage.py v7.9.1, created at 2025-09-11 16:23 +0200

1from __future__ import annotations 

2 

3from collections.abc import Sequence 

4import functools 

5import itertools 

6from numbers import Real 

7import operator 

8import typing 

9from typing import Self 

10 

11import attrs 

12import numpy as np 

13 

14from ..lib import att, dr, physics as ph 

15from ..lib.constants import PI_2, SQR2 

16from ..lib.graphics import plot_displacement 

17 

18if typing.TYPE_CHECKING: 

19 # Guard against circular imports 

20 from . import frac_handlers as fh 

21 

22 

@attrs.frozen
class Ocean:
    """Incompressible ocean supporting the ice floes.

    The ocean is characterised by a uniform depth and a uniform density.

    Parameters
    ----------
    depth : float
        Depth of the water column in m (infinite by default)
    density : float
        Density of sea water in kg m^-3

    """

    depth: float = np.inf
    density: float = 1025

41 

42 

@attrs.frozen
class _Subdomain:
    """A spatially localised segment.

    Parameters
    ----------
    left_edge : float
        Position of the left end of the segment in m
    length : float
        Extent of the segment in m

    Attributes
    ----------
    right_edge : float
        Position of the right end of the segment in m

    """

    left_edge: float
    length: float

    @functools.cached_property
    def right_edge(self):
        # Right end = left end shifted by the segment's extent.
        edge = self.left_edge + self.length
        return edge

67 

68 

@attrs.define(frozen=True)
class Ice:
    """A container for ice mechanical properties.

    Ice is modelled as an elastic thin plate, with prescribed density,
    thickness, Poisson's ratio and Young's modulus. Its fracture under bending
    is considered either through the lens of Griffith's fracture mechanics, or
    through the framework of strain failure commonly used in the sea ice
    modelling community. The fracture toughness is relevant to the former,
    while the strain threshold is relevant to the latter. Ice is considered
    translationally invariant in one horizontal direction, so that its
    quadratic moment of area is given per unit length in that direction.

    Parameters
    ----------
    density : float
        Density in kg m^-3
    frac_toughness : float
        Fracture toughness in Pa m^1/2
    poissons_ratio : float
        Poisson's ratio
    strain_threshold : float
        Critical flexural strain in m m^-1
    thickness : float
        Ice thickness in m
    youngs_modulus : float
        Scalar Young's modulus in Pa

    Attributes
    ----------
    quad_moment : float
        Quadratic moment of area per unit length, in m^3. It includes the
        plate factor ``1 / (1 - poissons_ratio**2)``.
    flex_rigidity : float
        Flexural rigidity, ``quad_moment * youngs_modulus``, in Pa m^3
        (equivalently N m)
    frac_energy_rate : float
        Fracture energy release rate in J m^-2

    """

    density: float = 922.5
    frac_toughness: float = 1e5
    poissons_ratio: float = 0.3
    strain_threshold: float = 3e-5
    thickness: float = 1.0
    youngs_modulus: float = 6e9

    @functools.cached_property
    def quad_moment(self) -> float:
        # h^3 / 12 for a unit-width section, divided by (1 - nu^2) so that the
        # product with E directly gives the plate flexural rigidity.
        return self.thickness**3 / (12 * (1 - self.poissons_ratio**2))

    @functools.cached_property
    def flex_rigidity(self) -> float:
        return self.quad_moment * self.youngs_modulus

    @functools.cached_property
    def frac_energy_rate(self) -> float:
        # Plane-strain relation between toughness and energy release rate:
        # G = (1 - nu^2) K^2 / E.
        return (
            (1 - self.poissons_ratio**2) * self.frac_toughness**2 / self.youngs_modulus
        )

128 

129 

@attrs.define(kw_only=True, frozen=True)
class FloatingIce(Ice):
    """An extension of `Ice` to represent properties due to buoyancy.

    All parameters of `Ice` are also accepted, as keywords only.

    Parameters
    ----------
    draft : float
        Immersed ice thickness at rest in m
    dud : float
        Height of the water column underneath the ice at rest in m
    elastic_length_pow4 : float
        Characteristic elastic length scale, raised to the 4th power, in m^4

    Attributes
    ----------
    elastic_length : float
        Characteristic elastic length scale in m
    freeboard : float
        Emerged ice thickness at rest in m

    """

    draft: float
    dud: float
    elastic_length_pow4: float

    @classmethod
    def from_ice_ocean(cls, ice: Ice, ocean: Ocean, gravity: float) -> FloatingIce:
        """Build an instance by combining properties of existing objects.

        The draft follows from hydrostatic equilibrium,
        ``ice.density / ocean.density * ice.thickness``. The dud is the ocean
        depth minus that draft. The fourth power of the elastic length is
        ``ice.flex_rigidity / (ocean.density * gravity)``.

        Parameters
        ----------
        ice : Ice
            Mechanical properties of the ice
        ocean : Ocean
            Properties of the supporting ocean
        gravity : float
            Strength of the local gravitational field in m s^-2

        Returns
        -------
        FloatingIce

        """
        draft = ice.density / ocean.density * ice.thickness
        dud = ocean.depth - draft
        # NOTE: as the 4th power of the elastic length scale arises naturally,
        # we prefer using it to instantiate the class and computing the length
        # scale when needed, over using the length scale for instantiation and
        # recomputing the fourth power from it, as the latter approach can lead
        # to substantial numerical imprecision.
        el_lgth_pow4 = ice.flex_rigidity / (ocean.density * gravity)
        return cls(
            density=ice.density,
            frac_toughness=ice.frac_toughness,
            poissons_ratio=ice.poissons_ratio,
            strain_threshold=ice.strain_threshold,
            thickness=ice.thickness,
            youngs_modulus=ice.youngs_modulus,
            draft=draft,
            dud=dud,
            elastic_length_pow4=el_lgth_pow4,
        )

    @functools.cached_property
    def elastic_length(self) -> float:
        return self.elastic_length_pow4**0.25

    @functools.cached_property
    def freeboard(self) -> float:
        return self.thickness - self.draft

    @functools.cached_property
    def _elastic_number(self) -> float:
        """Reciprocal of the characteristic elastic length scale.

        Returns
        -------
        float
            Elastic number in m^-1

        """
        return 1 / self.elastic_length

    @functools.cached_property
    def _red_elastic_number(self) -> float:
        """Characteristic elastic number scaled by 1/sqrt(2).

        Returns
        -------
        float
            Reduced elastic number in m^-1

        """
        return 1 / (SQR2 * self.elastic_length)

223 

224 

225@attrs.define(frozen=True) 

226class WavesUnderElasticPlate: 

227 """A non-localised zone characterised by wave action. 

228 

229 The spatial behaviour of waves (wavelength) is linked to their temporal 

230 behaviour (period) through a dispersion relation. In the case of waves 

231 propagating underneath floating ice, considered as an elastic plate, this 

232 dispersion relation depends on the properties of the ice as encapsulated by 

233 the `FloatingIce` class. 

234 

235 Parameters 

236 ---------- 

237 ice : FloatingIce 

238 wavenumbers : 1d array_like of float 

239 Propagating wavenumbers, in rad m^-1 

240 

241 """ 

242 

243 ice: FloatingIce 

244 wavenumbers: np.ndarray = attrs.field(repr=False) 

245 

246 @classmethod 

247 def from_floating( 

248 cls, 

249 ice: FloatingIce, 

250 spectrum: DiscreteSpectrum, 

251 gravity: float, 

252 ) -> typing.Self: 

253 """Build an instance by combining properties of existing objects. 

254 

255 Parameters 

256 ---------- 

257 ice : FloatingIce 

258 spectrum : DiscreteSpectrum 

259 gravity : float 

260 Strengh of the local gravitational field in m s^-2 

261 

262 Returns 

263 ------- 

264 WavesUnderElasticPlate 

265 

266 """ 

267 solver = dr.ElasticMassLoadingSolver.from_floating(ice, spectrum, gravity) 

268 wavenumbers = solver.compute_wavenumbers() 

269 return cls(ice, wavenumbers) 

270 

271 @classmethod 

272 def from_ocean( 

273 cls, 

274 ice: Ice, 

275 ocean: Ocean, 

276 spectrum: DiscreteSpectrum, 

277 gravity: float, 

278 ) -> typing.Self: 

279 """Build an instance by combining properties of existing objects. 

280 

281 Parameters 

282 ---------- 

283 ice : Ice 

284 ocean : Ocean 

285 spectrum : DiscreteSpectrum 

286 gravity : float 

287 Strengh of the local gravitational field in m s^-2 

288 

289 Returns 

290 ------- 

291 WavesUnderElasticPlate 

292 

293 """ 

294 floating_ice = FloatingIce.from_ice_ocean(ice, ocean, gravity) 

295 return cls.from_floating(floating_ice, spectrum, gravity) 

296 

297 

298# TODO: check docstring 

299@attrs.define(frozen=True) 

300class WavesUnderIce: 

301 """A non-localised zone characetrised by wave action under floating ice. 

302 

303 This class extends the behaviour of `WavesUnderElasticPlate` by adding an 

304 `attenuations` attribute, that parametrises the observed exponential decay 

305 of waves underneath floating ice. 

306 

307 Parameters 

308 ---------- 

309 ice : FloatingIce 

310 wavenumbers : 1d array_like of float 

311 Propagating wavenumbers, in rad m^-1 

312 attenuations : 1d array_like of float 

313 Parametrised wave amplitude attenuation rate, in m^-1 

314 

315 """ 

316 

317 ice: FloatingIce 

318 wavenumbers: np.ndarray = attrs.field(repr=False) 

319 attenuations: np.ndarray | float = attrs.field(repr=False) 

320 

321 @classmethod 

322 def without_attenuation(cls, waves_under_ep: WavesUnderElasticPlate) -> typing.Self: 

323 """Build an instance by combining properties of existing objects. 

324 

325 Parameters 

326 ---------- 

327 waves_under_ep : WavesUnderElasticPlate 

328 An object instance 

329 

330 Returns 

331 ------- 

332 WavesUnderIce 

333 

334 See Also 

335 ----- 

336 lib.att.no_attenuation 

337 

338 """ 

339 return cls( 

340 waves_under_ep.ice, 

341 waves_under_ep.wavenumbers, 

342 att.no_attenuation(), 

343 ) 

344 

345 @classmethod 

346 def with_attenuation_01(cls, waves_under_ep: WavesUnderElasticPlate) -> typing.Self: 

347 """Build an instance by combining properties of existing objects. 

348 

349 Parameters 

350 ---------- 

351 waves_under_ep : WavesUnderElasticPlate 

352 An object instance 

353 

354 Returns 

355 ------- 

356 WavesUnderIce 

357 

358 See Also 

359 ----- 

360 lib.att.parameterisation_01 

361 

362 """ 

363 return cls( 

364 waves_under_ep.ice, 

365 waves_under_ep.wavenumbers, 

366 att.parameterisation_01( 

367 waves_under_ep.ice.thickness, waves_under_ep.wavenumbers 

368 ), 

369 ) 

370 

371 @classmethod 

372 def with_attenuation_yu2022( 

373 cls, 

374 waves_under_ep: WavesUnderElasticPlate, 

375 gravity: float, 

376 angular_frequencies: np.ndarray, 

377 ) -> typing.Self: 

378 return cls( 

379 waves_under_ep.ice, 

380 waves_under_ep.wavenumbers, 

381 att.parameterisation_yu2022( 

382 waves_under_ep.ice.thickness, gravity, angular_frequencies 

383 ), 

384 ) 

385 

386 @classmethod 

387 def with_generic_attenuation( 

388 cls, 

389 waves_under_ep: WavesUnderElasticPlate, 

390 parameterisation: typing.Callable, 

391 args: str | None = None, 

392 **kwargs, 

393 ) -> typing.Self: 

394 """Instantiate a `WavesUnderFloe` with custom attenuation. 

395 

396 Parameters 

397 ---------- 

398 waves_under_ep : WavesUnderElasticPlate 

399 An object instance. 

400 parameterisation : typing.Callable 

401 Function defining attenuation. 

402 Must return a type broadcastable to `waves_under_ep.wavenumbers`. 

403 args : str | None 

404 A string of attributes of `waves_under_ep`, separated by 

405 whitespace, to be passed as parameters to `parameterisation`. 

406 All parameters will be passed as a mapping between the stem of the 

407 attribute and its value. 

408 **kwargs : dict 

409 Additional parameters to pass to `parameterisation`. 

410 

411 Returns 

412 ------- 

413 WavesUnderFloe 

414 

415 Examples 

416 -------- 

417 Assuming an existing `wue` instance of `WavesUnderElasticPlate`, the 

418 three following objects are identical, setting the attenuation egal to 

419 the ice density for all wave modes. 

420 

421 >>> WavesUnderIce.with_generic_attenuation_param( 

422 wue, 

423 lambda density: density, 

424 "ice.density" 

425 ) 

426 >>> WavesUnderIce.with_generic_attenuation_param( 

427 wue, 

428 lambda density: density, 

429 {"density": wue.ice.density}, 

430 ) 

431 >>> WavesUnderIce(wue.ice, wue.wavenumbers, wue.ice.density) 

432 

433 """ 

434 if args is not None: 

435 kwargs |= { 

436 arg.split(".")[-1]: operator.attrgetter(arg)(waves_under_ep) 

437 for arg in args.split() 

438 } 

439 return cls( 

440 waves_under_ep.ice, waves_under_ep.wavenumbers, parameterisation(**kwargs) 

441 ) 

442 

443 @functools.cached_property 

444 def _c_wavenumbers(self) -> np.ndarray: 

445 """Complex wavenumbers. 

446 

447 Their real part correspond to the propagating wavenumber, while their 

448 imaginary part correspond to the attenuation rate. 

449 

450 Returns 

451 ------- 

452 1d np.ndarray of complex 

453 The complex wavenumbers in m^-1 

454 

455 """ 

456 return self.wavenumbers + 1j * self.attenuations 

457 

458 

459@attrs.define(frozen=True) 

460class FreeSurfaceWaves: 

461 """The wave state in the absence of ice. 

462 

463 The spatial behaviour of waves (wavelength) is linked to their temporal 

464 behaviour (period) through a dispersion relation. In the case of free 

465 surface waves, propagating underneath floating ice, this dispersion 

466 relation depends on the properties of the ocean as encapsulated in the 

467 `Ocean` class. 

468 

469 Parameters 

470 ---------- 

471 ocean : Ocean 

472 wavenumbers : array_like 

473 Propagating wavenumbers in rad m^-1 

474 

475 Attributes 

476 ---------- 

477 wavelengths : 1d np.ndarray of float 

478 Propagating wavelengths in m 

479 

480 """ 

481 

482 ocean: Ocean 

483 wavenumbers: np.ndarray 

484 

485 @classmethod 

486 def from_ocean(cls, ocean: Ocean, spectrum: DiscreteSpectrum, gravity: float): 

487 """Build an instance by combining properties of existing objects.""" 

488 solver = dr.FreeSurfaceSolver.from_ocean(ocean, spectrum, gravity) 

489 wavenumbers = solver.compute_wavenumbers() 

490 return cls(ocean, wavenumbers) 

491 

492 @functools.cached_property 

493 def wavelengths(self) -> np.ndarray: 

494 return PI_2 / self.wavenumbers 

495 

496 

497# TODO: docstring inheritance 

498@attrs.define(kw_only=True) 

499class Floe(_Subdomain): 

500 """An ice floe localised in space. 

501 

502 Parameters 

503 ---------- 

504 ice : Ice 

505 The mechanical properties of the floe 

506 

507 """ 

508 

509 ice: Ice 

510 

511 

512@attrs.define(kw_only=True) 

513class WavesUnderFloe(_Subdomain): 

514 """A localised zone characetrised by wave action under floating ice. 

515 

516 Parameters 

517 ---------- 

518 wui : WavesUnderIce 

519 edge_amplitudes : 1d np.ndarray of complex 

520 The wave complex amplitude at the left edge of the floe in m 

521 generation : int 

522 The number of fractures that led to the existence of this floe 

523 

524 """ 

525 

526 wui: WavesUnderIce 

527 edge_amplitudes: np.ndarray 

528 generation: int = 0 

529 

530 @functools.cached_property 

531 def _adim(self) -> float: 

532 """A non-dimentional number characetrising the floe. 

533 

534 Returns 

535 ------- 

536 float 

537 

538 """ 

539 return self.length * self.wui.ice._red_elastic_number 

540 

541 # TODO: typing.Self? 

542 def make_copy(self) -> WavesUnderFloe: 

543 return WavesUnderFloe( 

544 left_edge=self.left_edge, 

545 length=self.length, 

546 wui=self.wui, 

547 edge_amplitudes=self.edge_amplitudes.copy(), 

548 generation=self.generation, 

549 ) 

550 

551 @typing.overload 

552 def shift_waves( 

553 self, phase_shifts: np.ndarray, inplace: typing.Literal[True] = ... 

554 ): ... 

555 

556 @typing.overload 

557 def shift_waves( 

558 self, phase_shifts: np.ndarray, inplace: typing.Literal[False] = ... 

559 ) -> Self: ... 

560 

561 # TODO: docstring 

562 def shift_waves(self, phase_shifts: np.ndarray, inplace: bool = True): 

563 shifted_amplitudes = self.edge_amplitudes * np.exp(-1j * phase_shifts) 

564 if not inplace: 

565 return WavesUnderFloe( 

566 left_edge=self.left_edge, 

567 length=self.length, 

568 wui=self.wui, 

569 edge_amplitudes=shifted_amplitudes, 

570 generation=self.generation, 

571 ) 

572 # HACK: instantiating an new object is cheap, resorting a whole list of 

573 # subdomains is not, so we mutate the amplitude/phase instead 

574 object.__setattr__(self, "edge_amplitudes", shifted_amplitudes) 

575 

576 def displacement( 

577 self, 

578 x: np.ndarray, 

579 growth_params=None, 

580 an_sol: bool | None = None, 

581 num_params=None, 

582 ): 

583 return ph.DisplacementHandler.from_wuf(self, growth_params).compute( 

584 x, an_sol, num_params 

585 ) 

586 

587 def curvature( 

588 self, 

589 x: np.ndarray, 

590 growth_params=None, 

591 an_sol: bool | None = None, 

592 num_params=None, 

593 is_linear: bool | None = None, 

594 ): 

595 return ph.CurvatureHandler.from_wuf(self, growth_params).compute( 

596 x, an_sol, num_params, is_linear 

597 ) 

598 

599 def energy( 

600 self, 

601 growth_params=None, 

602 an_sol: bool | None = None, 

603 num_params=None, 

604 linear_curvature: bool | None = None, 

605 ): 

606 return ph.EnergyHandler.from_wuf(self, growth_params).compute( 

607 an_sol, num_params, linear_curvature 

608 ) 

609 

610 # TODO: method to return the local wave forcing 

611 

612 

613@attrs.frozen(init=False) 

614class DiscreteSpectrum: 

615 amplitudes: np.ndarray 

616 frequencies: np.ndarray 

617 phases: np.ndarray 

618 

619 def __init__( 

620 self, 

621 amplitudes: Sequence[float] | float, 

622 frequencies: Sequence[float] | float, 

623 phases: Sequence[float] | float = 0, 

624 ): 

625 

626 # np.ravel to force precisely 1D-arrays 

627 # Promote the map to list so the iterator can be used several times. 

628 # Eventual phases are modulo'd to 2pi rad. 

629 args = list(map(np.ravel, (amplitudes, frequencies, phases))) 

630 (size,) = np.broadcast_shapes(*(arr.shape for arr in args)) 

631 

632 # If size is one, all the arguments are scalar and the "spectrum" is 

633 # monochromatic. Otherwise, there is at least one argument with 

634 # different components. The eventual other arguments with a single 

635 # component are repeated so that the three arrays have the same size. 

636 if size != 1: 

637 for i, arr in enumerate(args): 

638 if arr.size == 0: 

639 raise ValueError(f"The spectral argument {i} is empty.") 

640 if arr.size == 1: 

641 args[i] = args[i] * np.ones(size) 

642 

643 # Remove entries corresponding to nan in any of the array 

644 nan_mask = functools.reduce(operator.or_, (np.isnan(arr) for arr in args)) 

645 args = [arr[~nan_mask] for arr in args] 

646 # Sort arrays by the frequency values 

647 sk = np.argsort(args[1]) 

648 _amplitudes, _frequencies, _phases = (arr[sk] for arr in args) 

649 _phases = _phases % PI_2 

650 

651 self.__attrs_init__(_amplitudes, _frequencies, _phases) 

652 

653 @classmethod 

654 def from_periods( 

655 cls, 

656 amplitudes: Sequence[float] | float, 

657 periods: Sequence[float] | float, 

658 phases: Sequence[float] | float = 0, 

659 ): 

660 return cls(amplitudes, 1 / np.asarray(periods), phases) 

661 

662 @functools.cached_property 

663 def periods(self): 

664 return 1 / self.frequencies 

665 

666 @functools.cached_property 

667 def angular_frequencies(self): 

668 return self.frequencies * PI_2 

669 

670 @functools.cached_property 

671 def _ang_freqs_pow2(self): 

672 return self.angular_frequencies**2 

673 

674 @functools.cached_property 

675 def nf(self): 

676 return len(self.frequencies) 

677 

678 @functools.cached_property 

679 def energy(self): 

680 return np.sum(self.amplitudes**2) / 2 

681 

682 

683# TODO: docstrings 

684@attrs.define 

685class Domain: 

686 """A spatial domain forced by waves. 

687 

688 This represents the state of a MIZ at a given time. 

689 

690 

691 Attributes 

692 ---------- 

693 gravity : float 

694 spectrum : DiscreteSpectrum 

695 fsw : FreeSurfaceWaves 

696 attenuation: flexrac1d.lib.att.Attenuation 

697 growth_params : list 

698 subdomains : list of WavesUnderFloe 

699 cached_wuis : 

700 cached_phases : 

701 

702 """ 

703 

704 gravity: float 

705 spectrum: DiscreteSpectrum 

706 fsw: FreeSurfaceWaves 

707 attenuation: att.Attenuation = attrs.field(repr=False) 

708 growth_params: tuple[np.ndarray, float] | None = None 

709 subdomains: list[WavesUnderFloe] = attrs.field(repr=False, init=False, factory=list) 

710 cached_wuis: dict[Ice, WavesUnderIce] = attrs.field( 

711 repr=False, init=False, factory=dict 

712 ) 

713 cached_phases: dict[float, np.ndarray] = attrs.field( 

714 repr=False, init=False, factory=dict 

715 ) 

716 

717 @classmethod 

718 def from_discrete( 

719 cls, 

720 gravity, 

721 spectrum, 

722 ocean, 

723 attenuation: att.Attenuation | None = None, 

724 growth_params: tuple | None = None, 

725 ): 

726 fsw = FreeSurfaceWaves.from_ocean(ocean, spectrum, gravity) 

727 if attenuation is None: 

728 attenuation = att.AttenuationParameterisation(1) 

729 return cls(gravity, spectrum, fsw, attenuation, growth_params) 

730 

731 @classmethod 

732 def with_growth_means( 

733 cls, 

734 gravity: float, 

735 spectrum: DiscreteSpectrum, 

736 ocean: Ocean, 

737 growth_means: np.ndarray | Sequence[Real] | Real, 

738 attenuation: att.Attenuation | None = None, 

739 ) -> typing.Self: 

740 return cls.from_discrete( 

741 gravity, spectrum, ocean, attenuation, (growth_means, None) 

742 ) 

743 

744 @classmethod 

745 def with_growth_std( 

746 cls, 

747 gravity: float, 

748 spectrum: DiscreteSpectrum, 

749 ocean: Ocean, 

750 growth_std: Real, 

751 attenuation: att.Attenuation | None = None, 

752 ) -> typing.Self: 

753 return cls.from_discrete(gravity, spectrum, ocean, attenuation, (0, growth_std)) 

754 

755 def __attrs_post_init__(self): 

756 if self.growth_params is not None: 

757 if len(self.growth_params) != 2: 

758 raise ValueError 

759 growth_means, growth_std = ( 

760 np.asarray(self.growth_params[0]), 

761 self.growth_params[1], 

762 ) 

763 # TODO: simplify all this. Ideally, do not test for anything or 

764 # babysit the user. Why was upping growth_mean to a column 

765 # necessary in case its of size 1? 

766 if growth_means.size == 1: 

767 # As `broadcast_to` returns a view, 

768 # copying is necessary to obtain a mutable array. It is easier 

769 # than dealing with 0-length and 1-length arrays seperately. 

770 growth_means = np.broadcast_to( 

771 growth_means, (self.spectrum.nf, 1) 

772 ).copy() 

773 else: 

774 if growth_means.size != self.spectrum.nf: 

775 raise ValueError( 

776 f"Means (size {growth_means.size}) could not be" 

777 "broadcast with the shape of the spectrum" 

778 f"({self.spectrum.nf} components)" 

779 ) 

780 if growth_std is None: 

781 growth_std = self.fsw.wavelengths[self.spectrum.amplitudes.argmax()] 

782 self.growth_params = [growth_means, growth_std] 

783 

784 def _compute_phase_shifts(self, delta_time: float) -> np.ndarray: 

785 if delta_time not in self.cached_phases: 

786 self.cached_phases[delta_time] = ( 

787 delta_time * self.spectrum.angular_frequencies 

788 ) 

789 return self.cached_phases[delta_time] 

790 

791 def _compute_wui(self, ice: Ice): 

792 if ice not in self.cached_wuis: 

793 wup = WavesUnderElasticPlate.from_ocean( 

794 ice, self.fsw.ocean, self.spectrum, self.gravity 

795 ) 

796 if isinstance(self.attenuation, att.AttenuationParameterisation): 

797 if self.attenuation == att.AttenuationParameterisation.NO: 

798 wui = WavesUnderIce.without_attenuation(wup) 

799 elif self.attenuation == att.AttenuationParameterisation.PARAM_01: 

800 wui = WavesUnderIce.with_attenuation_01(wup) 

801 elif self.attenuation == att.AttenuationParameterisation.PARAM_YU_2022: 

802 wui = WavesUnderIce.with_attenuation_yu2022( 

803 wup, self.gravity, self.spectrum.angular_frequencies 

804 ) 

805 else: 

806 wui = WavesUnderIce.with_generic_attenuation( 

807 wup, 

808 self.attenuation.function, 

809 self.attenuation.args, 

810 **self.attenuation.kwargs, 

811 ) 

812 self.cached_wuis[ice] = wui 

813 return self.cached_wuis[ice] 

814 

815 def _shift_phases(self, phases: np.ndarray): 

816 # NOTE: doesn't seem to be called 

817 for i in range(len(self.subdomains)): 

818 self.subdomains[i].phases -= phases 

819 

820 def _shift_growth_means(self, phases: np.ndarray): 

821 # TODO: refine to take into account subdomain transitions 

822 # and floes with variying properties 

823 mask = self.growth_params[0] < self.subdomains[0].left_edge 

824 if mask.any(): 

825 self.growth_params[0][mask] += ( 

826 phases[mask[:, 0]] / self.fsw.wavenumbers[mask[:, 0]] 

827 ) 

828 if not mask.all(): 

829 self.growth_params[0][~mask] += ( 

830 phases[~mask[:, 0]] / self.subdomains[0].wui.wavenumbers[~mask[:, 0]] 

831 ) 

832 

833 def add_floes(self, floes: Floe | Sequence[Floe]): 

834 self.subdomains = self._init_subdomains(floes) 

835 

836 @staticmethod 

837 def _promote_floe(floes: Floe | Sequence[Floe]) -> Sequence[Floe]: 

838 match floes: 

839 case Floe(): 

840 return (floes,) 

841 case Sequence(): 

842 return floes 

843 case _: 

844 raise ValueError( 

845 "`floes` should be a `Floe` object or a sequence of such objects" 

846 ) 

847 

848 def _check_overlap(self, floes: Sequence[Floe]): 

849 l_edges, r_edges = map( 

850 np.array, zip(*((floe.left_edge, floe.right_edge) for floe in floes)) 

851 ) 

852 if not (r_edges[:-1] <= l_edges[1:]).all(): 

853 raise ValueError("Floe overlap") # TODO: dedicated exception 

854 

855 def _init_phases(self, floes: Sequence[Floe]) -> np.ndarray: 

856 phases = np.full((len(floes), self.spectrum.nf), np.nan) 

857 phases[0] = self.spectrum.phases + floes[0].left_edge * self.fsw.wavenumbers 

858 for i, floe in enumerate(floes[1:], 1): 

859 wui = self._compute_wui(floe.ice) 

860 prev = floes[i - 1] 

861 phases[i:,] = ( 

862 phases[i - 1] 

863 + floe.length * wui.wavenumbers 

864 + (prev.right_edge - floe.left_edge) * self.fsw.wavenumbers 

865 ) 

866 return phases % PI_2 

867 

868 def _init_amplitudes(self, floes: Sequence[Floe]) -> np.ndarray: 

869 amplitudes = np.full((len(floes), self.spectrum.nf), np.nan) 

870 amplitudes[0, :] = self.spectrum.amplitudes 

871 for i, floe in enumerate(floes[1:], 1): 

872 amplitudes[i, :] = amplitudes[i - 1, :] * np.exp( 

873 -self._compute_wui(floe.ice).attenuations * floe.length 

874 ) 

875 return amplitudes 

876 

877 def _init_subdomains(self, floes: Floe | Sequence[Floe]) -> list[WavesUnderFloe]: 

878 # TODO: look for already existing floes. In the present state, only 

879 # valid for starting from scratch, not for adding floes to a domain 

880 # that already has some. 

881 floes = self.__class__._promote_floe(floes) 

882 self._check_overlap(floes) 

883 complex_amplitudes = self._init_amplitudes(floes) * np.exp( 

884 1j * self._init_phases(floes) 

885 ) 

886 

887 return [ 

888 WavesUnderFloe( 

889 left_edge=floe.left_edge, 

890 length=floe.length, 

891 wui=self._compute_wui(floe.ice), 

892 edge_amplitudes=edge_amplitudes, 

893 ) 

894 for floe, edge_amplitudes in zip(floes, complex_amplitudes) 

895 ] 

896 

897 def iterate(self, delta_time: float): 

898 phase_shifts = self._compute_phase_shifts(delta_time) 

899 # TODO: can be optimised by iterating a first time to extract the 

900 # edges, coerce them to a np.array, apply the product with 

901 # complex_shifts, and then iterate a second time to build the objects. 

902 # See Propagation_tests.ipynb/DNE06-26 

903 for i in range(len(self.subdomains)): 

904 self.subdomains[i].shift_waves(phase_shifts) 

905 if self.growth_params is not None: 

906 # Phases are only modulo'd in the setter 

907 self._shift_growth_means(phase_shifts) 

908 

909 def breakup( 

910 self, 

911 fracture_handler: fh._FractureHandler, 

912 an_sol=None, 

913 num_params=None, 

914 ): 

915 def get_broken_wufs(wuf: WavesUnderFloe) -> list[WavesUnderFloe]: 

916 xf = fracture_handler.search(wuf, self.growth_params, an_sol, num_params) 

917 if xf is None: 

918 return [wuf] 

919 else: 

920 return fracture_handler.split(wuf, xf) 

921 

922 self.subdomains = list( 

923 itertools.chain(*(get_broken_wufs(wuf) for wuf in self.subdomains)) 

924 ) 

925 

926 def plot( 

927 self, 

928 resolution: float, 

929 left_bound: float, 

930 ax=None, 

931 an_sol=None, 

932 add_surface=True, 

933 base=0, 

934 kw_dis=None, 

935 kw_sur=None, 

936 ): 

937 plot_displacement( 

938 resolution, self, left_bound, ax, an_sol, add_surface, base, kw_dis, kw_sur 

939 )