Coverage for src/swiift/api/api.py: 0%

164 statements  

« prev     ^ index     » next       coverage.py v7.9.1, created at 2025-09-11 16:23 +0200

1from __future__ import annotations 

2 

3from collections import namedtuple 

4from collections.abc import Sequence 

5import functools 

6import logging 

7import operator 

8import pathlib 

9import pickle 

10import typing 

11from typing import Any 

12 

13import attrs 

14import numpy as np 

15 

16from .. import __about__ 

17from ..lib import att 

18from ..model import frac_handlers as fh, model as md 

19 

# TODO: make into an attrs class for more flexibility (repr of subdomains)
# One recorded state of an experiment, as built by `Experiment._save_step`:
# a tuple of copies of the domain's subdomains, and a copy of the domain's
# growth parameters (or None when the domain has none).
Step = namedtuple("Step", "subdomains growth_params")

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

24 

25 

if typing.TYPE_CHECKING:

    class _ProgressBarProtocol(typing.Protocol):
        """Minimal progress-bar interface used by `Experiment.run`.

        Any object exposing these methods (e.g. a `tqdm` bar) is accepted.
        """

        def update(self, n): ...

        def close(self): ...

    # PEP 544: a subclass of a protocol is itself a protocol only if
    # `typing.Protocol` is listed again among its bases. Without it, this
    # would be a nominal class that third-party bars (which do not inherit
    # from it) could never satisfy structurally.
    class _VerboseProgressBarProtocol(_ProgressBarProtocol, typing.Protocol):
        """Progress bar that can also print messages, like `tqdm.write`."""

        @classmethod
        def write(cls, *args): ...

36 

37 

def _create_path(path: str | pathlib.Path) -> pathlib.Path:
    """Return `path` as a `pathlib.Path`, creating the directory if needed.

    Missing parent directories are created too. Using ``exist_ok=True``
    instead of a separate existence check avoids a race with another process
    creating the same directory.

    Parameters
    ----------
    path : str | pathlib.Path
        Directory to ensure exists.

    Returns
    -------
    pathlib.Path
        The directory path.

    Raises
    ------
    FileExistsError
        If `path` exists but is not a directory.

    """
    _path = pathlib.Path(path)
    _path.mkdir(parents=True, exist_ok=True)
    return _path

43 

44 

def _load_pickle(fname: str | pathlib.Path) -> Experiment:
    """Unpickle `fname` and check that it holds an `Experiment`.

    Security note: `pickle.load` can execute arbitrary code; only load files
    from trusted sources.

    Parameters
    ----------
    fname : str | pathlib.Path
        Path to the pickle file.

    Returns
    -------
    Experiment

    Raises
    ------
    FileNotFoundError
        If `fname` does not exist.
    TypeError
        If the unpickled object is not an `Experiment`.

    """
    with open(fname, "rb") as file:
        logger.info("Reading %s...", fname)
        instance = pickle.load(file)
    if not isinstance(instance, Experiment):
        # Name the file: callers such as `load_pickles` may load many at once.
        raise TypeError(
            f"The pickled object in {fname} is not an instance of `Experiment`."
        )
    return instance

52 

53 

54def _glob( 

55 pattern: str, dir_path: pathlib.Path, recursive: bool, **kwargs 

56) -> list[Experiment]: 

57 attribute = "rglob" if recursive else "glob" 

58 iterator = getattr(dir_path, attribute) 

59 return [_load_pickle(fname) for fname in iterator(pattern, **kwargs)] 

60 

61 

def _dct_keys_to_array(dct, dtype=np.float64) -> np.ndarray:
    """Return the keys of `dct`, in iteration order, as a 1D array of `dtype`."""
    n_keys = len(dct)
    return np.fromiter(iter(dct), dtype=dtype, count=n_keys)

64 

65 

def _assemble_experiments(experiments: list[Experiment]) -> Experiment:
    """Merge several experiments into one with a combined, time-sorted history.

    The result takes its `time`, `domain` and `fracture_handler` from the
    experiment with the largest `time`. Histories are merged in list order,
    so for a timestep present in several experiments, the entry from the
    experiment appearing last in `experiments` is kept.

    Parameters
    ----------
    experiments : list[Experiment]
        Experiments to merge; must not be empty.

    Returns
    -------
    Experiment

    Raises
    ------
    ValueError
        If `experiments` is empty (raised by `max`).

    """
    latest_experiment = max(experiments, key=lambda _exp: _exp.time)
    # Update a single dict in place instead of building a new one per merge
    # (same precedence as chaining `|`: later entries win).
    merged_history: dict[float, Step] = {}
    for _exp in experiments:
        merged_history.update(_exp.history)
    # Sort the original keys directly so they keep their Python type rather
    # than being round-tripped through a float64 array.
    common_history = dict(sorted(merged_history.items(), key=lambda item: item[0]))
    return Experiment(
        latest_experiment.time,
        latest_experiment.domain,
        common_history,
        latest_experiment.fracture_handler,
    )

79 

80 

def _str_to_path(dir_path: str | pathlib.Path | None) -> pathlib.Path:
    """Convert `dir_path` to a `pathlib.Path`; None means the working directory."""
    if dir_path is None:
        return pathlib.Path.cwd()
    return pathlib.Path(dir_path)

88 

89 

def load_pickles(
    pattern: str,
    dir_path: str | pathlib.Path | None = None,
    recursive: bool = False,
    **kwargs,
) -> Experiment:
    """Load pickle objects and assemble them into a single `Experiment`.

    This function relies on `pathlib.Path`'s `glob` and `rglob` methods.
    Files found matching the pattern are assembled into a single `Experiment`
    object: that is, their histories are merged and sorted by time. The
    returned object takes its `time`, `domain` and `fracture_handler` from
    the loaded experiment with the latest `time`.

    If several files contain the same timestep, only one entry is kept, and
    which one depends on the order in which the files are found, which
    `glob` does not specify. This function is therefore intended to be used
    on files which the user knows have no overlap between their time axes.

    Security note: files are read with `pickle`; only load trusted files.

    Parameters
    ----------
    pattern : str
        A pattern to glob upon.
    dir_path : str | pathlib.Path | None
        The directory in which files will be looked for. If `None`, search from
        the current working directory.
    recursive : bool
        Whether to search for the pattern recursively.
    **kwargs
        Arguments passed to `pathlib.Path.[r]glob`.

    Returns
    -------
    Experiment

    Raises
    ------
    FileNotFoundError
        If no file matches the pattern.
    TypeError
        If a found file does not correspond to an instance of `Experiment`.

    """
    _dir_path = _str_to_path(dir_path)

    experiments = _glob(pattern, _dir_path, recursive, **kwargs)
    if not experiments:
        raise FileNotFoundError(f"No file matching {pattern} was found.")
    return _assemble_experiments(experiments)

134 

135 

def load_pickle(
    fname: str | pathlib.Path,
) -> Experiment:
    """Read and return an `Experiment` object stored in a pickle file.

    Security note: the file is read with `pickle`, which can execute
    arbitrary code; only load trusted files.

    Parameters
    ----------
    fname : str | pathlib.Path
        A file name or path object.

    Returns
    -------
    Experiment

    Raises
    ------
    FileNotFoundError
        If `fname` does not exist.
    TypeError
        If the unpickled object is not an instance of `Experiment`.

    """
    return _load_pickle(fname)

157 

158 

@attrs.define
class Experiment:
    """A domain of floes evolved in time, together with its recorded states.

    Attributes
    ----------
    time : float
        Current time of the experiment, in s.
    domain : md.Domain
        The domain whose floes are stepped and fractured.
    history : dict[float, Step]
        Recorded states, keyed by the value of `time` when they were saved.
    fracture_handler : fh._FractureHandler
        Passed to `self.domain.breakup` at every `step`. Defaults to
        `fh.BinaryFracture()`.

    """

    time: float
    domain: md.Domain
    history: dict[float, Step] = attrs.field(factory=dict, repr=False)
    fracture_handler: fh._FractureHandler = attrs.field(factory=fh.BinaryFracture)

    @classmethod
    def from_discrete(
        cls,
        gravity: float,
        spectrum: md.DiscreteSpectrum,
        ocean: md.Ocean,
        growth_params: tuple | None = None,
        fracture_handler: fh._FractureHandler | None = None,
        attenuation_spec: att.Attenuation | None = None,
    ):
        """Create an experiment at time 0 from a discrete spectrum.

        The domain is built with `md.Domain.from_discrete`. The history
        starts empty: states are only recorded once `add_floes` or `step`
        is called.

        Parameters
        ----------
        gravity : float
            Passed to `md.Domain.from_discrete`.
        spectrum : md.DiscreteSpectrum
            Passed to `md.Domain.from_discrete`.
        ocean : md.Ocean
            Passed to `md.Domain.from_discrete`.
        growth_params : tuple, optional
            Passed to `md.Domain.from_discrete`.
        fracture_handler : fh._FractureHandler, optional
            If None, the class default (`fh.BinaryFracture()`) is used.
        attenuation_spec : att.Attenuation, optional
            If None, `att.AttenuationParameterisation(1)` is used.

        Returns
        -------
        Experiment

        """
        if attenuation_spec is None:
            attenuation_spec = att.AttenuationParameterisation(1)
        domain = md.Domain.from_discrete(
            gravity, spectrum, ocean, attenuation_spec, growth_params
        )

        if fracture_handler is None:
            return cls(0, domain)
        return cls(0, domain, fracture_handler=fracture_handler)

    @property
    def timesteps(self) -> np.ndarray:
        """The experiment timesteps in s.

        These can be used to index `self.history`.

        Returns
        -------
        1D array
            The existing timesteps.

        """
        return np.array(list(self.history.keys()))

    def add_floes(self, floes: md.Floe | Sequence[md.Floe]):
        """Add one or several floes to the domain, then record the new state."""
        self.domain.add_floes(floes)
        self._save_step()

    def _find_fracture_indices(self) -> np.ndarray[tuple[Any, ...], np.dtype[np.int_]]:
        """Find the indices of states immediately before fracture.

        A fracture is detected as a change in the number of subdomains
        between two consecutive recorded states.

        Returns
        -------
        1D array of int
            The indices of the current timesteps corresponding to the states
            that broke on the next iteration.

        """
        _t = [len(step.subdomains) for step in self.history.values()]
        return np.nonzero(np.ediff1d(_t))[0]

    def get_pre_fracture_times(self) -> np.ndarray:
        """Return the times corresponding to states immediately before fracture.

        These can be used to index `self.history`.

        Returns
        -------
        1D array
            Output times.

        """
        return self.timesteps[self._find_fracture_indices()]

    def get_post_fracture_times(self) -> np.ndarray:
        """Return the times corresponding to states immediately after fracture.

        These can be used to index `self.history`.
        Note: in these states, the waves have been advected, compared to the
        corresponding pre-fracture states.

        Returns
        -------
        1D array
            Output times.

        """
        return self.timesteps[self._find_fracture_indices() + 1]

    def get_final_state(self) -> Step:
        """Return the final state of the experiment.

        Returns
        -------
        Step
            The `Step` corresponding to the last timestep.

        """
        return self.history[next(reversed(self.history))]

    def _save_step(self):
        """Record copies of the current subdomains and growth parameters.

        The state is stored in `self.history` under the current `self.time`.
        """
        self.history[self.time] = Step(
            tuple(wuf.make_copy() for wuf in self.domain.subdomains),
            (
                (self.domain.growth_params[0].copy(), self.domain.growth_params[1])
                if self.domain.growth_params is not None
                else None
            ),
        )

    def step(
        self,
        delta_time: float,
        an_sol: bool | None = None,
        num_params: dict | None = None,
    ):
        """Move the experiment forward in time.

        One step is a succession of events. First, the current floes are
        scanned for fractures; if any floe breaks, the domain is updated with
        the newly formed fragments replacing the fractured floes. Then, the
        actual time progression happens, by updating the wave phases at the
        edge of every individual floe. Finally, this new state is saved to the
        history, at the index corresponding to the updated time.

        Parameters
        ----------
        delta_time : float
            The time increment in seconds.
        an_sol : bool, optional
            Whether to force the use of a numerical or analytical solution for
            the deflection of the floes.
        num_params : dict, optional
            Optional parameters to pass to the numerical solver, if applicable.

        """
        self.domain.breakup(self.fracture_handler, an_sol, num_params)
        self.domain.iterate(delta_time)
        self.time += delta_time
        self._save_step()

    def get_states(self, times: np.ndarray | float) -> dict[float, Step]:
        """Return a subset of the history matching the given times.

        Parameters
        ----------
        times : 1D array_like, float
            Time, or sequence of times.

        Returns
        -------
        dict[float, Step]
            A dictionary containing the `Step`s closest to the input `times`.

        """
        times = np.ravel(times)  # ensure we have exactly a 1D array
        timestep_keys = _dct_keys_to_array(self.history)
        indexes = (np.abs(times - timestep_keys[:, None])).argmin(axis=0)
        return {k: self.history[k] for k in timestep_keys[indexes]}

    def get_states_strict(self, times: np.ndarray | float) -> dict[float, Step]:
        """Return a subset of the history matching the given times.

        Times that are not exactly a key of `self.history` are ignored.
        Duplicated times are returned once, in order of first appearance.

        Parameters
        ----------
        times : np.ndarray | float
            Time, or sequence of times.

        Returns
        -------
        dict[float, Step]
            A dictionary containing the `Step`s matching exactly the input.

        """
        unique_times, first_idx = np.unique(np.ravel(times), return_index=True)
        timestep_keys = _dct_keys_to_array(self.history)
        # `np.unique` sorts its output; reorder the unique values by their
        # first occurrence so that they come back in the order passed in.
        ordered_times = unique_times[np.argsort(first_idx)]
        # The membership mask must be applied to the array it was computed on
        # (applying it to the sorted values selected the wrong times).
        mask = np.isin(ordered_times, timestep_keys, assume_unique=True)
        return {_time: self.history[_time] for _time in ordered_times[mask]}

    def _time_interval_str(self):
        """Format the span from the first recorded time to `self.time`."""
        first_time = next(iter(self.history))
        return f"{first_time:.3f}--{self.time:.3f}"

    def _generate_name(self, prefix: str | None) -> str:
        """Build a file stem from `prefix` (default: hex `id`), version and span."""
        if prefix is None:
            prefix = f"{id(self):x}"
        return prefix + f"_v{__about__.__version__}_" + self._time_interval_str()

    def _dump(self, prefix: str | None, dir_path: pathlib.Path):
        """Pickle the whole object to a `.pickle` file in `dir_path`.

        The directory is created if it does not exist.
        """
        fname = f"{self._generate_name(prefix)}.pickle"
        dir_path = _create_path(dir_path)
        full_path = dir_path.joinpath(fname)
        with open(full_path, "bw") as file:
            pickle.dump(self, file)

    def _clean_history(self):
        """Keep only the final state, stored under the current time."""
        current_state = self.get_final_state()
        self.history.clear()
        self.history[self.time] = current_state

    def dump_history(
        self,
        prefix: str | None = None,
        dir_path: str | pathlib.Path | None = None,
    ):
        """Write the results to disk and clear the history.

        The whole object is pickled, before emptying the current history from
        memory; only the final state is kept. The filename is constructed with
        the `prefix` passed as argument, the package version number, and the
        time interval covered by the history.

        Parameters
        ----------
        prefix : str | None
            Prefix for the file name. If none is provided, defaults to the `id`
            of the `Experiment` object.
        dir_path : str | pathlib.Path | None
            Directory in which the file is written; it is created if missing.
            If none is provided, the current working directory is used.

        """
        _dir_path = _str_to_path(dir_path)
        self._dump(prefix, _dir_path)
        self._clean_history()

    def _should_terminate(
        self,
        initial_number_of_fragments: int,
        number_of_fragments: int,
        time_since_fracture: float,
        break_time: float | None,
    ):
        """Whether `run` should stop early.

        True only if `break_time` is set, at least one fracture happened
        during the run, and more than `break_time` elapsed since the last one.
        """
        return (
            break_time is not None
            and number_of_fragments > initial_number_of_fragments
            and time_since_fracture > break_time
        )

    @typing.overload
    def run(
        self,
        time: float,
        delta_time: float,
        break_time: float | None = ...,
        chunk_size: int | None = ...,
        verbose: None = ...,
        pbar: _ProgressBarProtocol | None = ...,
        path: str | pathlib.Path | None = ...,
        dump_final: bool = ...,
        dump_prefix: str | None = ...,
    ): ...

    @typing.overload
    def run(
        self,
        time: float,
        delta_time: float,
        break_time: float | None = ...,
        chunk_size: int | None = ...,
        verbose: int = ...,
        pbar: _VerboseProgressBarProtocol | None = ...,
        path: str | pathlib.Path | None = ...,
        dump_final: bool = ...,
        dump_prefix: str | None = ...,
    ): ...

    def run(
        self,
        time: float,
        delta_time: float,
        break_time: float | None = None,
        chunk_size: int | None = None,
        verbose: int | None = None,
        pbar: _ProgressBarProtocol | None = None,
        path: str | pathlib.Path | None = None,
        dump_final: bool = True,
        dump_prefix: str | None = None,
    ):
        """Run the experiment for a specified duration.

        The experiment is run from its current time for a duration
        corresponding to `time`, with states regularly spaced with step
        `delta_time`. If `time` is not an integer multiple of `delta_time`, the
        number of steps will be rounded up. The experiment can optionally be
        stopped before `time`, if no fracture happens for `break_time`, and at
        least one fracture has occurred.

        The current object can be saved at regularly spaced step intervals, as
        specified by `chunk_size`.

        Optional messages can be printed, with a verbosity level controlled by
        `verbose`.

        A progress bar can be passed as an optional parameter to monitor the
        experiment. The implementation expects an object that behaves as a
        `tqdm` bar; in particular, it needs to expose `update` and `close`
        methods. If used conjointly with `verbose`, it also needs to expose a
        `write` method. The bar is updated once per step performed and closed
        at the end of the run.

        Parameters
        ----------
        time : float
            Duration to run the experiment for, in seconds.
        delta_time : float
            Time step between iterations, in seconds.
        break_time : float | None
            Time before stopping the experiment if no fracture occurs, in seconds.
        chunk_size : int | None
            Number of steps before writing the results to a file.
        verbose : int | None
            Verbosity level. If 1, outputs for disk writes. If 2, additional
            outputs for fractures. If None, the early-stop message goes to the
            module logger instead of the progress bar.
        pbar : progress bar | None
            Progress bar monitoring the experiment.
        path : str | pathlib.Path | None
            Directory where files will be saved. If none is provided, files
            will be saved in the current directory.
        dump_final : bool
            Whether the results should be saved to disk at the end of the run
            by calling `dump_history`, thus clearing the history from memory.
        dump_prefix : str | None
            Prefix for the file names used in the dumps. If none is provided,
            defaults to the `id` of the `Experiment` object.

        """

        def pbar_print(msg: str, pbar: _VerboseProgressBarProtocol | None):
            # Write through the bar when there is one, otherwise use the logger.
            if pbar is not None:
                pbar.write(msg)
            else:
                logger.info(msg)

        @typing.overload
        def dump_and_print(
            dump_prefix: str | None,
            path: str | pathlib.Path | None,
            verbose: None,
            pbar: _ProgressBarProtocol | None,
        ): ...

        @typing.overload
        def dump_and_print(
            dump_prefix: str | None,
            path: str | pathlib.Path | None,
            verbose: int,
            pbar: _VerboseProgressBarProtocol | None,
        ): ...

        def dump_and_print(
            dump_prefix,
            path,
            verbose,
            pbar,
        ):
            self.dump_history(dump_prefix, path)
            if verbose is not None and verbose >= 1:
                msg = f"t = {self.time:.3f} s; history dumped"
                pbar_print(msg, pbar)

        initial_number_of_fragments = len(self.domain.subdomains)
        number_of_fragments = initial_number_of_fragments
        number_of_steps = np.ceil(time / delta_time).astype(int)
        time_since_fracture = 0.0
        if chunk_size is not None:
            # Dump after the last step of each chunk (0-based index
            # chunk_size - 1, 2 * chunk_size - 1, ...).
            modulo_target = chunk_size - 1

        for i in range(number_of_steps):
            self.step(delta_time)
            new_nof = len(self.domain.subdomains)
            if new_nof > number_of_fragments:
                time_since_fracture = 0
                number_of_fragments = new_nof
                if verbose is not None and verbose >= 2:
                    msg = f"t = {self.time:.3f} s; N_f = {number_of_fragments}"
                    pbar_print(msg, pbar)
            else:
                time_since_fracture += delta_time

            if chunk_size is not None:
                if i % chunk_size == modulo_target:
                    dump_and_print(dump_prefix, path, verbose, pbar)

            # Count this step before a possible early exit, so that the bar
            # reflects every step that was actually performed.
            if pbar is not None:
                pbar.update(1)

            if self._should_terminate(
                initial_number_of_fragments,
                number_of_fragments,
                time_since_fracture,
                break_time,
            ):
                msg = f"No fracture in {break_time:.3f} s, stopping"
                if verbose is not None:
                    pbar_print(msg, pbar)
                else:
                    # Without `verbose`, `pbar` is only required to expose
                    # `update` and `close`, so it cannot be used to print.
                    logger.info(msg)
                break

        if pbar is not None:
            pbar.close()

        # A single item in the history means we either just dumped, or did
        # not step; there is no need to dump again.
        if dump_final and len(self.history) > 1:
            # No `pbar` passed as it has already been closed.
            dump_and_print(dump_prefix, path, verbose, None)