Coverage for src/swiift/api/api.py: 0%
164 statements
« prev ^ index » next coverage.py v7.9.1, created at 2025-09-11 16:23 +0200
« prev ^ index » next coverage.py v7.9.1, created at 2025-09-11 16:23 +0200
1from __future__ import annotations
3from collections import namedtuple
4from collections.abc import Sequence
5import functools
6import logging
7import operator
8import pathlib
9import pickle
10import typing
11from typing import Any
13import attrs
14import numpy as np
16from .. import __about__
17from ..lib import att
18from ..model import frac_handlers as fh, model as md
# TODO: make into an attrs class for more flexibility (repr of subdomains)
# One saved state of an `Experiment`: a tuple of copies of the domain's
# subdomains, and a copy of its growth parameters (or `None`).
Step = namedtuple("Step", "subdomains growth_params")

logger = logging.getLogger(__name__)
if typing.TYPE_CHECKING:

    class _ProgressBarProtocol(typing.Protocol):
        """Minimal progress-bar interface used by `Experiment.run`."""

        def update(self, n):
            """Advance the bar by `n` units."""

        def close(self):
            """Finalise the bar."""

    class _VerboseProgressBarProtocol(_ProgressBarProtocol):
        """Progress bar that can also print messages (e.g. `tqdm`)."""

        @classmethod
        def write(cls, *args):
            """Print a message without breaking the bar."""
38def _create_path(path: str | pathlib.Path) -> pathlib.Path:
39 _path = pathlib.Path(path)
40 if not _path.exists():
41 _path.mkdir(parents=True)
42 return _path
def _load_pickle(fname: str | pathlib.Path) -> Experiment:
    """Unpickle `fname` and check that it holds an `Experiment`.

    Parameters
    ----------
    fname : str | pathlib.Path
        Path to the pickle file.

    Returns
    -------
    Experiment

    Raises
    ------
    TypeError
        If the unpickled object is not an `Experiment`.

    """
    # NOTE(security): `pickle.load` can execute arbitrary code; only load
    # files from trusted sources.
    with open(fname, "rb") as file:
        logger.info(f"Reading {fname}...")
        loaded = pickle.load(file)
    if isinstance(loaded, Experiment):
        return loaded
    raise TypeError("The pickled object is not an instance of `Experiment`.")
def _glob(
    pattern: str, dir_path: pathlib.Path, recursive: bool, **kwargs
) -> list[Experiment]:
    """Load every pickled `Experiment` under `dir_path` matching `pattern`.

    Uses `dir_path.rglob` when `recursive` is true, `dir_path.glob`
    otherwise; `kwargs` are forwarded to that method. Files are loaded in the
    order the glob yields them.
    """
    finder = dir_path.rglob if recursive else dir_path.glob
    experiments = []
    for fname in finder(pattern, **kwargs):
        experiments.append(_load_pickle(fname))
    return experiments
def _dct_keys_to_array(dct, dtype=np.float64) -> np.ndarray:
    """Collect the keys of `dct`, in iteration order, into a 1D array."""
    n_keys = len(dct)
    return np.fromiter(iter(dct), dtype=dtype, count=n_keys)
def _assemble_experiments(experiments: list[Experiment]) -> Experiment:
    """Merge several experiments into one with a combined history.

    The time, domain and fracture handler are taken from the experiment with
    the latest `time`. Histories are merged and sorted by timestep. When a
    timestep appears in several experiments, the entry from the experiment
    with the latest `time` wins.

    Parameters
    ----------
    experiments : list[Experiment]
        Experiments to merge; must not be empty.

    Returns
    -------
    Experiment
        A new experiment holding the merged history.

    Raises
    ------
    ValueError
        If `experiments` is empty.

    """
    if not experiments:
        raise ValueError("Cannot assemble an empty sequence of experiments.")
    # Glob order is filesystem-dependent; sorting by time (stable) makes the
    # precedence between duplicated timesteps deterministic, since later
    # operands of `|` override earlier ones.
    ordered = sorted(experiments, key=operator.attrgetter("time"))
    latest_experiment = ordered[-1]
    merged = functools.reduce(operator.or_, (_exp.history for _exp in ordered))
    # `sorted` keeps the original key objects (rather than converting them
    # to `np.float64`) while producing the same ascending order.
    common_history = {k: merged[k] for k in sorted(merged)}
    return Experiment(
        latest_experiment.time,
        latest_experiment.domain,
        common_history,
        latest_experiment.fracture_handler,
    )
81def _str_to_path(dir_path: str | pathlib.Path | None) -> pathlib.Path:
82 match dir_path:
83 case None:
84 _dir_path = pathlib.Path.cwd()
85 case _:
86 _dir_path = pathlib.Path(dir_path)
87 return _dir_path
def load_pickles(
    pattern: str,
    dir_path: str | pathlib.Path | None = None,
    recursive: bool = False,
    **kwargs,
) -> Experiment:
    """Load pickle objects and assemble them into a single `Experiment`.

    This function relies on `pathlib.Path`'s `glob` and `rglob` methods.
    Files found matching the pattern are assembled into a single `Experiment`
    object: that is, histories are concatenated, and the time, domain and
    fracture handler are taken from the experiment with the latest time.
    When the same timestep appears in several files, only one entry is
    kept. This function is therefore intended to be used on files which the
    user knows have no overlap between their time axes.

    Parameters
    ----------
    pattern : str
        A pattern to glob upon.
    dir_path : str | pathlib.Path | None
        The directory in which files will be looked for. If `None`, search from
        the current working directory.
    recursive : bool
        Whether to search for the pattern recursively.
    **kwargs
        Arguments passed to `pathlib.Path.[r]glob`.

    Returns
    -------
    Experiment

    Raises
    ------
    FileNotFoundError
        If no file matches the pattern.
    TypeError
        If a found file does not correspond to an instance of `Experiment`.

    Notes
    -----
    Unpickling can execute arbitrary code: only load files from trusted
    sources.

    """
    _dir_path = _str_to_path(dir_path)

    experiments = _glob(pattern, _dir_path, recursive, **kwargs)
    if not experiments:
        raise FileNotFoundError(f"No file matching {pattern} was found.")
    return _assemble_experiments(experiments)
def load_pickle(
    fname: str | pathlib.Path,
) -> Experiment:
    """Read and return an `Experiment` object stored in a pickle file.

    Parameters
    ----------
    fname : str | pathlib.Path
        A file name or path object.

    Returns
    -------
    Experiment

    Raises
    ------
    FileNotFoundError
        If `fname` does not exist.
    TypeError
        If the unpickled object is not an instance of `Experiment`.

    Notes
    -----
    Unpickling can execute arbitrary code: only load files from trusted
    sources.

    """
    return _load_pickle(fname)
@attrs.define
class Experiment:
    """Time-stepped experiment on a `md.Domain`, with a history of states.

    The experiment holds a domain and evolves it with `step` (or `run`). It
    saves a `Step` in `history` after every modification, keyed by the
    experiment time at which it was saved.

    Attributes
    ----------
    time : float
        Current time of the experiment, in seconds.
    domain : md.Domain
        The domain being evolved.
    history : dict[float, Step]
        Saved states, keyed by time. Excluded from the repr.
    fracture_handler : fh._FractureHandler
        Handler passed to `md.Domain.breakup` at every step. Defaults to a
        new `fh.BinaryFracture` instance.

    """

    time: float
    domain: md.Domain
    history: dict[float, Step] = attrs.field(factory=dict, repr=False)
    fracture_handler: fh._FractureHandler = attrs.field(factory=fh.BinaryFracture)

    @classmethod
    def from_discrete(
        cls,
        gravity: float,
        spectrum: md.DiscreteSpectrum,
        ocean: md.Ocean,
        growth_params: tuple | None = None,
        fracture_handler: fh._FractureHandler | None = None,
        attenuation_spec: att.Attenuation | None = None,
    ):
        """Build an experiment at time 0 from a discrete spectrum.

        The domain is created with `md.Domain.from_discrete`. The returned
        experiment has an empty history.

        Parameters
        ----------
        gravity : float
            Forwarded to `md.Domain.from_discrete`.
        spectrum : md.DiscreteSpectrum
            Forwarded to `md.Domain.from_discrete`.
        ocean : md.Ocean
            Forwarded to `md.Domain.from_discrete`.
        growth_params : tuple, optional
            Forwarded to `md.Domain.from_discrete`.
        fracture_handler : fh._FractureHandler, optional
            If `None`, the class default is used.
        attenuation_spec : att.Attenuation, optional
            If `None`, `att.AttenuationParameterisation(1)` is used.

        Returns
        -------
        Experiment

        """
        if attenuation_spec is None:
            attenuation_spec = att.AttenuationParameterisation(1)
        domain = md.Domain.from_discrete(
            gravity, spectrum, ocean, attenuation_spec, growth_params
        )

        if fracture_handler is None:
            return cls(0, domain)
        return cls(0, domain, fracture_handler=fracture_handler)

    @property
    def timesteps(self) -> np.ndarray:
        """The experiment timesteps in s.

        These can be used to index `self.history`.

        Returns
        -------
        1D array
            The existing timesteps.

        """
        return np.array(list(self.history.keys()))

    def add_floes(self, floes: md.Floe | Sequence[md.Floe]):
        """Add floes to the domain and save the state at the current time."""
        self.domain.add_floes(floes)
        self._save_step()

    def _find_fracture_indices(self) -> np.ndarray[tuple[Any, ...], np.dtype[np.int_]]:
        """Find the indices of states immediately before fracture.

        A fracture is detected as a change in the number of subdomains
        between two consecutive saved states.

        Returns
        -------
        1D array of int
            The indices of the current timesteps corresponding to the states
            that broke on the next iteration.

        """
        _t = [len(step.subdomains) for step in self.history.values()]
        return np.nonzero(np.ediff1d(_t))[0]

    def get_pre_fracture_times(self) -> np.ndarray:
        """Return the times corresponding to states immediately before fracture.

        These can be used to index `self.history`.

        Returns
        -------
        1D array
            Output times.

        """
        return self.timesteps[self._find_fracture_indices()]

    def get_post_fracture_times(self) -> np.ndarray:
        """Return the times corresponding to states immediately after fracture.

        These can be used to index `self.history`.
        Note: in these states, the waves have been advected, compared to the
        corresponding pre-fracture states.

        Returns
        -------
        1D array
            Output times.

        """
        return self.timesteps[self._find_fracture_indices() + 1]

    def get_final_state(self) -> Step:
        """Return the final state of the experiment.

        Returns
        -------
        Step
            The `Step` corresponding to the last timestep.

        Raises
        ------
        StopIteration
            If the history is empty.

        """
        return self.history[next(reversed(self.history))]

    def _save_step(self):
        """Store copies of the current subdomains and growth parameters.

        The entry is keyed by `self.time`, overwriting any existing entry at
        that time.
        """
        self.history[self.time] = Step(
            tuple(wuf.make_copy() for wuf in self.domain.subdomains),
            (
                (self.domain.growth_params[0].copy(), self.domain.growth_params[1])
                if self.domain.growth_params is not None
                else None
            ),
        )

    def step(
        self,
        delta_time: float,
        an_sol: bool | None = None,
        num_params: dict | None = None,
    ):
        """Move the experiment forward in time.

        One step is a succession of events. First, the current floes are
        scanned for fractures. The domain is eventually updated with the newly
        formed fragments replacing the fractured floes. Then, the actual time
        progression happens, by updating the wave phases at the edge of every
        individual floe. Finally, this new state is saved to the history, at
        the index corresponding to the updated time.

        Parameters
        ----------
        delta_time : float
            The time increment in seconds.
        an_sol : bool, optional
            Whether to force the use of a numerical or analytical solution for
            the deflection of the floes.
        num_params : dict, optional
            Optional parameters to pass to the numerical solver, if applicable.

        """
        self.domain.breakup(self.fracture_handler, an_sol, num_params)
        self.domain.iterate(delta_time)
        self.time += delta_time
        self._save_step()

    def get_states(self, times: np.ndarray | float) -> dict[float, Step]:
        """Return a subset of the history matching the given times.

        Each requested time is matched to the saved timestep with the
        smallest absolute difference. On ties, the timestep that comes first
        in the history is used. Several requested times mapping to the same
        timestep yield a single entry.

        Parameters
        ----------
        times : 1D array_like, float
            Time, or sequence of times.

        Returns
        -------
        dict[float, Step]
            A dictionary containing the `Step`s closest to the input `times`.

        """
        times = np.ravel(times)  # ensure we have exactly a 1D array
        timestep_keys = _dct_keys_to_array(self.history)
        # Pairwise |time - key| matrix of shape (n_keys, n_times); argmin
        # along the keys axis picks the closest key for every requested time.
        indexes = (np.abs(times - timestep_keys[:, None])).argmin(axis=0)
        return {k: self.history[k] for k in timestep_keys[indexes]}

    def get_states_strict(self, times: np.ndarray | float) -> dict[float, Step]:
        """Return a subset of the history matching the given times exactly.

        Requested times absent from the history are skipped. Duplicated
        requested times are returned once. The output preserves the order
        in which times were first requested.

        Parameters
        ----------
        times : np.ndarray | float
            Time, or sequence of times.

        Returns
        -------
        dict[float, Step]
            A dictionary containing the `Step`s matching exactly the input.

        """
        unique_times, first_idx = np.unique(np.ravel(times), return_index=True)
        # `np.unique` sorts its output; reorder the unique values by their
        # first occurrence so states come back in the order they were passed.
        ordered_times = unique_times[np.argsort(first_idx)]
        timestep_keys = _dct_keys_to_array(self.history)
        # The mask must be applied to the same (reordered) array it was
        # computed on. Otherwise, times missing from the history could be
        # selected (raising KeyError) and valid ones dropped.
        mask = np.isin(ordered_times, timestep_keys, assume_unique=True)
        return {_time: self.history[_time] for _time in ordered_times[mask]}

    def _time_interval_str(self):
        """Format the span from the first saved time to `self.time`."""
        first_time = next(iter(self.history))
        return f"{first_time:.3f}--{self.time:.3f}"

    def _generate_name(self, prefix: str | None) -> str:
        """Build `<prefix>_v<version>_<first>--<current>`.

        `prefix` defaults to the hexadecimal `id` of the object.
        """
        if prefix is None:
            prefix = f"{id(self):x}"
        return prefix + f"_v{__about__.__version__}_" + self._time_interval_str()

    def _dump(self, prefix: str | None, dir_path: pathlib.Path):
        """Pickle the whole object to `<dir_path>/<name>.pickle`."""
        fname = f"{self._generate_name(prefix)}.pickle"
        dir_path = _create_path(dir_path)
        full_path = dir_path.joinpath(fname)
        with open(full_path, "bw") as file:
            pickle.dump(self, file)

    def _clean_history(self):
        """Reset the history to only the final state, keyed by `self.time`."""
        current_state = self.get_final_state()
        self.history.clear()
        self.history[self.time] = current_state

    def dump_history(
        self,
        prefix: str | None = None,
        dir_path: str | pathlib.Path | None = None,
    ):
        """Write the results to disk and clear the history.

        The whole object is pickled, before emptying the current history from
        memory; only the final state is kept. The filename is constructed with
        the `prefix` passed as argument, the package version number, and the
        time interval covered by the history.

        Parameters
        ----------
        prefix : str | None
            Prefix for the file name. If none is provided, defaults to the `id`
            of the `Experiment` object.
        dir_path : str | pathlib.Path | None
            Directory where the file is written; it is created if missing. If
            none is provided, the current working directory is used.

        """
        _dir_path = _str_to_path(dir_path)
        self._dump(prefix, _dir_path)
        self._clean_history()

    def _should_terminate(
        self,
        initial_number_of_fragments: int,
        number_of_fragments: int,
        time_since_fracture: float,
        break_time: float | None,
    ):
        """Return whether `run` should stop early.

        True only if `break_time` is set, the number of fragments has grown
        since the start of the run, and more than `break_time` has elapsed
        since the last fracture.
        """
        return (
            break_time is not None
            and number_of_fragments > initial_number_of_fragments
            and time_since_fracture > break_time
        )

    @typing.overload
    def run(
        self,
        time: float,
        delta_time: float,
        break_time: float | None = ...,
        chunk_size: int | None = ...,
        verbose: None = ...,
        pbar: _ProgressBarProtocol | None = ...,
        path: str | pathlib.Path | None = ...,
        dump_final: bool = ...,
        dump_prefix: str | None = ...,
    ): ...

    @typing.overload
    def run(
        self,
        time: float,
        delta_time: float,
        break_time: float | None = ...,
        chunk_size: int | None = ...,
        verbose: int = ...,
        pbar: _VerboseProgressBarProtocol | None = ...,
        path: str | pathlib.Path | None = ...,
        dump_final: bool = ...,
        dump_prefix: str | None = ...,
    ): ...

    def run(
        self,
        time: float,
        delta_time: float,
        break_time: float | None = None,
        chunk_size: int | None = None,
        verbose: int | None = None,
        pbar: _ProgressBarProtocol | None = None,
        path: str | pathlib.Path | None = None,
        dump_final: bool = True,
        dump_prefix: str | None = None,
    ):
        """Run the experiment for a specified duration.

        The experiment is run from its current time for a duration
        corresponding to `time`, with states regularly spaced with step
        `delta_time`. If `time` is not an integer multiple of `delta_time`, the
        number of steps will be rounded up. The experiment can optionally be
        stopped before `time`, if no fracture happens for `break_time`, and at
        least one fracture has occurred.

        The current object can be saved at regularly spaced step intervals, as
        specified by `chunk_size`.

        Optional messages can be printed to stdout, with a verbosity level
        controlled by `verbose`.

        A progress bar can be passed as an optional parameter to monitor the
        experiment. The implementation expects an object that behaves as a
        `tqdm` bar; in particular, it needs to expose `update` and `close`
        methods. If used in conjunction with `verbose`, it also needs to expose
        a `write` method.

        Parameters
        ----------
        time : float
            Duration to run the experiment for, in seconds.
        delta_time : float
            Time step between iterations, in seconds.
        break_time : float | None
            Time before stopping the experiment if no fracture occurs, in seconds.
        chunk_size : int | None
            Number of steps before writing the results to a file.
        verbose : int | None
            Verbosity level. If 1, outputs for disk writes. If 2, additional
            outputs for fractures.
        pbar : progress bar | None
            Progress bar monitoring the experiment. It is updated once per
            performed step and closed at the end of the run.
        path : str | pathlib.Path | None
            Directory where files will be saved. If none is provided, files
            will be saved in the current directory.
        dump_final : bool
            Whether the results should be saved to disk at the end of the run
            by calling `dump_history`, thus clearing the history from memory.
        dump_prefix : str | None
            Prefix for the file names used in the dumps. If none is provided,
            defaults to the `id` of the `Experiment` object.

        """

        def pbar_print(msg: str, pbar: _VerboseProgressBarProtocol | None):
            # Route messages through the bar when there is one, so its
            # display is not broken; otherwise use the module logger.
            if pbar is not None:
                pbar.write(msg)
            else:
                logger.info(msg)

        @typing.overload
        def dump_and_print(
            dump_prefix: str | None,
            path: str | pathlib.Path | None,
            verbose: None,
            pbar: _ProgressBarProtocol | None,
        ): ...

        @typing.overload
        def dump_and_print(
            dump_prefix: str | None,
            path: str | pathlib.Path | None,
            verbose: int,
            pbar: _VerboseProgressBarProtocol | None,
        ): ...

        def dump_and_print(
            dump_prefix,
            path,
            verbose,
            pbar,
        ):
            # Dump (and clear) the history, reporting it if verbose >= 1.
            self.dump_history(dump_prefix, path)
            if verbose is not None and verbose >= 1:
                msg = f"t = {self.time:.3f} s; history dumped"
                pbar_print(msg, pbar)

        initial_number_of_fragments = len(self.domain.subdomains)
        number_of_fragments = initial_number_of_fragments
        number_of_steps = np.ceil(time / delta_time).astype(int)
        time_since_fracture = 0.0
        if chunk_size is not None:
            # Dump after every `chunk_size`-th step (0-based index i).
            modulo_target = chunk_size - 1

        for i in range(number_of_steps):
            self.step(delta_time)
            new_nof = len(self.domain.subdomains)
            if new_nof > number_of_fragments:
                time_since_fracture = 0.0
                number_of_fragments = new_nof
                if verbose is not None and verbose >= 2:
                    msg = f"t = {self.time:.3f} s; N_f = {number_of_fragments}"
                    pbar_print(msg, pbar)
            else:
                time_since_fracture += delta_time

            if chunk_size is not None:
                if i % chunk_size == modulo_target:
                    dump_and_print(dump_prefix, path, verbose, pbar)

            # Count the step before a possible early exit, so the bar also
            # reflects the step that triggered termination.
            if pbar is not None:
                pbar.update(1)

            if self._should_terminate(
                initial_number_of_fragments,
                number_of_fragments,
                time_since_fracture,
                break_time,
            ):
                msg = f"No fracture in {break_time:.3f} s, stopping"
                pbar_print(msg, pbar)
                break

        if pbar is not None:
            pbar.close()

        # If single item in history, either we just dumped, or we did not
        # step, and there is no need to dump again.
        if dump_final and len(self.history) > 1:
            # No `pbar` passed as it should have been closed
            dump_and_print(dump_prefix, path, verbose, None)