grib2io.xarray_backend
grib2io Backend Engine for Xarray
grib2io provides an Xarray backend entrypoint for decoding many GRIB2 messages from a single file or from many files. The decoded messages are represented as Xarray DataArray objects and are collected along common coordinates into Datasets and DataTrees.
The grib2io.xarray_backend engine API is experimental.
Its interface and behavior may change in future releases,
which could affect backward compatibility.
Users are encouraged to treat this backend as subject to change
and to pin their grib2io version if depending on its current
implementation details.
1""" 2grib2io Backend Engine for Xarray 3================================= 4grib2io provides a Xarray backend entrypoint for decoding many GRIB2 messages 5from a single file or many files and represented as Xarray DataArray objects and 6collected along common coordinates as Datasets and DataTrees. 7 8.. warning:: 9 10 The ``grib2io.xarray_backend`` engine API is **experimental**. 11 Its interface and behavior may change in future releases, 12 which could affect backward compatibility. 13 14 Users are encouraged to treat this backend as subject to change 15 and to pin their ``grib2io`` version if depending on its current 16 implementation details. 17""" 18from grib2io._grib2io import _data 19from grib2io import Grib2Message, Grib2GridDef, msgs_from_index 20import grib2io 21from xarray.backends.locks import SerializableLock 22from xarray.core import indexing 23from xarray.backends import ( 24 BackendArray, 25 BackendEntrypoint, 26) 27from copy import copy 28from collections import defaultdict 29from dataclasses import dataclass, field, astuple 30import importlib.metadata 31import itertools 32import logging 33import typing 34import warnings 35 36from . 
import tables 37 38import numpy as np 39import pandas as pd 40import xarray as xr 41import re 42from pyproj import CRS 43import datetime 44 45# Check if xarray version supports DataTree 46_HAS_DATATREE = False 47try: 48 # Try importing DataTree to check if it's available 49 xarray_version = importlib.metadata.version('xarray') 50 xarray_parts = [int(x) if x.isdigit() else x for x in xarray_version.split('.')] 51 min_version_parts = [2024, 10, 0] 52 _HAS_DATATREE = xarray_parts >= min_version_parts 53except (ImportError, ValueError): 54 _HAS_DATATREE = False 55 56_logger = logging.getLogger(__name__) 57 58_LOCK = SerializableLock() 59 60_LEVEL_NAME_MAPPING = grib2io.tables.get_table('4.5.grib2io.level.name') 61 62_TREE_HIERARCHY_LEVELS = [ 63 "typeOfFirstFixedSurface", 64 "valueOfFirstFixedSurface", 65 "productDefinitionTemplateNumber", 66 "perturbationNumber", 67 "leadTime", 68 "duration", 69 "percentileValue", 70 "typeOfProbability", 71 "thresholdLowerLimit", 72 "thresholdUpperLimit" 73] 74 75AVAILABLE_NON_GEO_COORDS = [ 76 "duration", 77 "leadTime", 78 "percentileValue", 79 "perturbationNumber", 80 "refDate", 81 "thresholdLowerLimit", 82 "thresholdUpperLimit", 83 "valueOfFirstFixedSurface", 84 "valueOfSecondFixedSurface", 85 "aerosolType", 86 "scaledValueOfFirstWavelength", 87 "scaledValueOfSecondWavelength", 88 "scaledValueOfCentralWaveNumber", 89 "scaledValueOfFirstSize", 90 "scaledValueOfSecondSize" 91] 92"""Available non-geographic coordinate names.""" 93 94AVAILABLE_NON_GEO_DIMS = [ 95 "duration", 96 "leadTime", 97 "percentileValue", 98 "perturbationNumber", 99 "refDate", 100 "threshold", 101 "level", 102] 103"""Available non-geographic dimension names.""" 104 105# Lookup table to define surface types that should be parsed as vertical coordinates 106VERTICAL_COORDINATE_SURFACES = [ 107 "Ground or Water Surface", 108 "Isothermal Level", 109 "Specified radius from the centre of the Sun", 110 "Isobaric Surface", 111 "Mean Sea Level", 112 "Specific Altitude 
def parse_data_model(ds, data_model):
    """
    Normalize a GRIB2-derived Dataset to a target data model (currently ``"nws-viz"``).

    When ``data_model == "nws-viz"``, this function converts coordinate and
    variable names to snake_case, derives CF-like metadata, promotes select
    GRIB-derived quantities to coordinates, optionally swaps dimensions, and
    standardizes units/attributes. If ``data_model`` is anything else, the
    input dataset is returned unchanged.

    Parameters
    ----------
    ds : xarray.Dataset
        GRIB2-derived dataset whose variables and attributes follow the
        conventions emitted by ``grib2io``. Expected to contain GRIB-related
        attributes such as ``typeOfFirstFixedSurface``,
        ``typeOfSecondFixedSurface``, and (for probabilistic variables)
        ``typeOfProbability``.
    data_model : str
        Target data model name. Only the value ``"nws-viz"`` triggers
        transformations.

    Returns
    -------
    xarray.Dataset
        A new dataset with:
          * Selected coordinates renamed:
            ``refDate -> forecast_reference_time``,
            ``leadTime -> lead_time``,
            ``validDate -> time``,
            ``percentileValue -> percentile``,
            ``thresholdLowerLimit -> threshold_lower_limit``,
            ``thresholdUpperLimit -> threshold_upper_limit``.
          * Vertical coordinates derived from
            ``valueOfFirstFixedSurface`` / ``valueOfSecondFixedSurface`` and their
            corresponding ``typeOf*FixedSurface`` definitions. New coordinate
            names are generated from the surface definition (lowercased, spaces
            to underscores, punctuation removed). For the second fixed surface,
            a ``"_2"`` suffix is appended if the name already exists.
          * Possible dimension swaps:
            ``level -> <derived_vertical_coord>`` when present; and for
            probabilistic variables, ``threshold -> threshold_lower_limit`` or
            ``threshold -> threshold_upper_limit`` when
            ``typeOfProbability`` indicates the appropriate semantics.
          * Variable names lowercased; dataset- and variable-level attributes
            converted to snake_case (except GRIB section attributes which are
            normalized to ``grib...``).
          * CF-adjacent metadata populated: ``standard_name`` and
            ``cell_methods`` are set via the shortname-to-CF lookup table
            (``'unknown'`` when the variable has no entry).
          * Percent units normalized from ``"%"`` to ``"percent"`` on
            variables and coordinates.
          * For precipitation type (``PTYPE``) thresholds, numeric codes are
            decoded to strings (GRIB2 Table 4.201) in relevant attrs/coords.

    Notes
    -----
    - Precipitation type decoding uses GRIB2 Table 4.201 via
      ``tables.get_value_from_table(code, "4.201")`` and returns a NumPy
      array created with ``dtype=np.dtypes.StringDType``.
    - CF-related lookups are performed using
      ``tables.get_table("shortname_to_cf")``.
    - Vertical coordinate surface names are validated against
      ``VERTICAL_COORDINATE_SURFACES`` before promotion to coordinates.
    - Units and fixed-surface definitions are always read from the *first*
      data variable of the dataset.

    Warnings
    --------
    This function assumes the presence of certain GRIB-derived attributes on the
    first data variable (e.g., ``typeOfFirstFixedSurface``,
    ``typeOfSecondFixedSurface``, and possibly ``typeOfProbability``).
    If these are absent or malformed, errors (e.g., ``KeyError``) may occur.

    Examples
    --------
    >>> ds2 = parse_data_model(ds, "nws-viz")
    >>> list(ds2.coords)
    ['forecast_reference_time', 'lead_time', 'time', 'percentile', ...]
    """

    def _decode_ptype(values):
        """
        Decode precipitation type values into human-readable strings.

        Uses GRIB2 Table 4.201 to map numeric codes to precipitation type descriptions.

        Parameters
        ----------
        values : array_like
            Iterable of numeric precipitation type codes (e.g., integers or floats).
            Each value corresponds to a GRIB2 Table 4.201 precipitation type code.

        Returns
        -------
        numpy.ndarray
            Array of decoded precipitation type strings created with
            ``dtype=np.dtypes.StringDType``.
        """
        results = []
        for val in values:
            # Convert each numeric code to string and look up in Table 4.201
            results.append(str(tables.get_value_from_table(str(int(val)), '4.201')))

        # Return array of strings using numpy's string data type
        return np.array(results, dtype=np.dtypes.StringDType)

    # convert coordinates and attributes to CF if requested
    if data_model == 'nws-viz':

        # regex matching the position before every capital letter except at the
        # start of the string; substituting '_' there and lowercasing gives snake case
        pattern = re.compile(r'(?<!^)(?=[A-Z])')

        # check for coordinates and rename
        # NOTE: ``ds`` is rebound inside the loop, while the loop keeps iterating
        # over the coords object obtained from the dataset bound at loop entry.
        for coord in ds.coords:
            if coord == 'refDate':
                ds = ds.rename({'refDate': 'forecast_reference_time'})

            elif coord == 'leadTime':
                ds = ds.rename({'leadTime': 'lead_time'})

            elif coord == 'validDate':
                ds = ds.rename({'validDate': 'time'})

            elif coord == 'percentileValue':
                ds = ds.rename({'percentileValue': 'percentile'})

            elif coord == 'thresholdLowerLimit':
                ds = ds.rename({'thresholdLowerLimit': 'threshold_lower_limit'})
                ds['threshold_lower_limit'].attrs['long_name'] = 'Threshold Lower Limit'
                # thresholds share the units of the (first) data variable
                ds['threshold_lower_limit'].attrs['units'] = ds[list(ds.data_vars.keys())[0]].attrs['units']

                # PTYPE thresholds are table codes, not physical values
                if 'PTYPE' in ds.data_vars:
                    ds['threshold_lower_limit'] = xr.apply_ufunc(_decode_ptype, ds['threshold_lower_limit'])

                # check if thresholdLowerLimit should be a dimension coordinate
                if 'threshold' in ds.dims:
                    var_key = list(ds.data_vars.keys())[0]
                    prob_types = [
                        'Probability of event below lower limit',
                        'Probability of event above lower limit',
                        'Probability of event equal to lower limit',
                        'Probability of event between upper and lower limits (the range includes lower limit but not the upper limit)'
                    ]
                    if ds[var_key].attrs['typeOfProbability'] in prob_types:
                        ds = ds.swap_dims({'threshold': 'threshold_lower_limit'})

            elif coord == 'thresholdUpperLimit':
                ds = ds.rename({'thresholdUpperLimit': 'threshold_upper_limit'})
                ds['threshold_upper_limit'].attrs['long_name'] = 'Threshold Upper Limit'
                # thresholds share the units of the (first) data variable
                ds['threshold_upper_limit'].attrs['units'] = ds[list(ds.data_vars.keys())[0]].attrs['units']

                # PTYPE thresholds are table codes, not physical values
                if 'PTYPE' in ds.data_vars:
                    ds['threshold_upper_limit'] = xr.apply_ufunc(_decode_ptype, ds['threshold_upper_limit'])

                # check if thresholdUpperLimit should be a dimension coordinate
                if 'threshold' in ds.dims:
                    var_key = list(ds.data_vars.keys())[0]
                    prob_types = [
                        'Probability of event below upper limit',
                        'Probability of event above upper limit'
                    ]
                    if ds[var_key].attrs['typeOfProbability'] in prob_types:
                        ds = ds.swap_dims({'threshold': 'threshold_upper_limit'})

            # If the dataset has valueOfFirstFixedSurface as a coordinate
            elif coord == 'valueOfFirstFixedSurface':
                # Get the valueOfFirstFixedSurface coordinate
                da = ds.valueOfFirstFixedSurface

                # Get the definition and units from typeOfFirstFixedSurface
                # (the attribute is unpacked as a 2-item (definition, units) pair)
                var_key = list(ds.data_vars.keys())[0]
                definition, units = ds[var_key].attrs['typeOfFirstFixedSurface']

                if definition in VERTICAL_COORDINATE_SURFACES:
                    # Convert definition to lowercase and replace spaces with underscores
                    key = definition.lower().replace(' ', '_')

                    # remove special characters
                    key = re.sub(r'[^a-z0-9_]', '', key)

                    # Add units and grib_name attributes
                    da.attrs['units'] = units
                    da.attrs['grib_name'] = ['valueOfFirstFixedSurface', 'typeOfFirstFixedSurface']

                    # Assign the coordinate with the new key name
                    ds = ds.assign_coords({key: da})

                    # If the dataset has a "level" dimension, swap it with the new key
                    if 'level' in ds.dims:
                        ds = ds.swap_dims({"level": key})

                    # Remove the original coordinates
                    del ds['valueOfFirstFixedSurface']

            # If the dataset has valueOfSecondFixedSurface as a coordinate
            elif coord == 'valueOfSecondFixedSurface':
                # Get the valueOfSecondFixedSurface coordinate
                da = ds.valueOfSecondFixedSurface

                # Get the definition and units from typeOfSecondFixedSurface
                var_key = list(ds.data_vars.keys())[0]
                definition, units = ds[var_key].attrs['typeOfSecondFixedSurface']

                if definition in VERTICAL_COORDINATE_SURFACES:
                    # Convert definition to lowercase and replace spaces with underscores
                    key = definition.lower().replace(' ', '_')

                    # remove special characters
                    key = re.sub(r'[^a-z0-9_]', '', key)

                    # check if key is already in coords (e.g. same surface type as
                    # the first fixed surface) and disambiguate with a suffix
                    if key in ds.coords:
                        key = key + '_2'

                    # Add units and grib_name attributes
                    da.attrs['units'] = units
                    da.attrs['grib_name'] = ['valueOfSecondFixedSurface', 'typeOfSecondFixedSurface']

                    # Assign the coordinate with the new key name
                    ds = ds.assign_coords({key: da})

                    # Remove the original coordinates
                    del ds['valueOfSecondFixedSurface']
            else:
                # change coord name to snake case
                new_coord_name = pattern.sub('_', coord).lower()
                ds = ds.rename({coord: new_coord_name})

        # convert all attributes and variable names to snake case
        for var in ds.data_vars:
            da = ds[var]
            # CF lookup is keyed by the variable's (GRIB short) name; missing
            # entries fall back to 'unknown'
            record = tables.get_table('shortname_to_cf').get(da.name)
            da.attrs['standard_name'] = 'unknown' if record is None else record['cf_standard_name']
            da.attrs['cell_methods'] = 'unknown' if record is None else record['cf_cell_methods']

            ds[var] = da

            # rename variable
            new_var_name = var.lower()
            ds = ds.rename({var: new_var_name})

            # collapse the (definition, units) pair of typeOfFirstFixedSurface
            # into a single display string
            if 'typeOfFirstFixedSurface' in ds[new_var_name].attrs:
                definition, units = ds[new_var_name].attrs['typeOfFirstFixedSurface']
                ds[new_var_name].attrs['typeOfFirstFixedSurface'] = f'{definition} ({units})'

            if 'typeOfSecondFixedSurface' in ds[new_var_name].attrs:
                definition, units = ds[new_var_name].attrs['typeOfSecondFixedSurface']
                ds[new_var_name].attrs['typeOfSecondFixedSurface'] = f'{definition} ({units})'

            # drop attrs that are already represented as coordinates
            ds[new_var_name].attrs.pop('percentileValue', None)

            if 'threshold_lower_limit' in ds.coords:
                ds[new_var_name].attrs.pop('thresholdLowerLimit', None)

            if 'threshold_upper_limit' in ds.coords:
                ds[new_var_name].attrs.pop('thresholdUpperLimit', None)

            # iterate over a snapshot of the keys because attrs are renamed in place
            for attr in list(ds[new_var_name].attrs.keys()):
                # grib section attrs are not snake-cased
                if 'GRIB2IO_section' in attr:
                    # replace GRIB2IO with grib in attr
                    new_attr_name = attr.replace('GRIB2IO', 'grib')
                else:
                    # change attr name to snake case
                    new_attr_name = pattern.sub('_', attr).lower()

                # update new attr name for specific CF names
                if new_attr_name == 'full_name':
                    new_attr_name = 'long_name'

                # change % to percent
                if attr == 'units' and ds[new_var_name].attrs[attr] == '%':
                    ds[new_var_name].attrs[attr] = 'percent'

                if new_var_name == 'ptype' and 'threshold' in new_attr_name:
                    # NOTE(review): the decoded value is stored back under the
                    # original attribute name (``attr``), not ``new_attr_name`` —
                    # confirm this is intended.
                    value = ds[new_var_name].attrs.pop(attr)
                    ds[new_var_name].attrs[attr] = _decode_ptype(value)
                else:
                    # change attr name in attrs
                    ds[new_var_name].attrs[new_attr_name] = ds[new_var_name].attrs.pop(attr)

        # change dataset attrs to snake case
        for attr in list(ds.attrs.keys()):
            # change attr name to snake case
            new_attr_name = pattern.sub('_', attr).lower()

            # change attr name in attrs
            ds.attrs[new_attr_name] = ds.attrs.pop(attr)

        # change % to percent
        for coord in ds.coords:
            if 'units' in ds[coord].attrs and ds[coord].attrs['units'] == '%':
                ds[coord].attrs['units'] = 'percent'

    return ds
class GribBackendEntrypoint(BackendEntrypoint):
    """
    xarray backend engine entrypoint for opening and decoding grib2 files.

    .. warning::

        This backend is experimental and the API/behavior may change without
        backward compatibility.
    """

    def open_dataset(
        self,
        filename,
        *,
        drop_variables=None,
        save_index=True,
        filters: typing.Optional[typing.Mapping[str, typing.Any]] = None,
        data_model=None
    ):
        """
        Read and parse metadata from grib file.

        Parameters
        ----------
        filename
            GRIB2 file to be opened.
        drop_variables
            Accepted for compatibility with the xarray backend API; it is not
            used by this method.
        save_index
            Passed through to ``grib2io.open``.
        filters
            Filter GRIB2 messages to single hypercube. Dict keys can be any
            GRIB2 metadata attribute name. ``None`` (the default) means no
            filtering.
        data_model
            Parse GRIB metadata following a defined data model convention.
            When not ``None`` the dataset is passed through
            ``parse_data_model``.

        Returns
        -------
        open_dataset
            Xarray dataset of grib2 messages. An empty dataset is returned
            when the filters match no messages or no variables can be built.
        """
        # Avoid a shared mutable default; an absent filter set is an empty one.
        if filters is None:
            filters = {}

        with grib2io.open(filename, save_index=save_index, _xarray_backend=True) as f:
            file_index = pd.DataFrame(f._index)
            file_index = file_index.assign(msg=msgs_from_index(f._index))

        # parse grib2io _index to dataframe and acquire non-geo possible dims
        # (scalar coord when not dim due to squeeze) parse_grib_index applies
        # filters to index and expands metadata based on product definition
        # template number
        parsed = parse_grib_index(file_index, filters)

        # Filtering down to nothing is allowed and yields an empty dataset.
        # Check before unpacking: on an empty index parse_grib_index may return
        # fewer items than the four unpacked below.
        if len(parsed[0]) == 0:
            return xr.Dataset()

        file_index, dim_coords, attrs, coord_attrs = parsed

        # Divide up records by variable
        frames, cube, extra_geo = make_variables(file_index, filename, dim_coords)

        # return empty dataset if no data
        if frames is None:
            return xr.Dataset()

        # create dataset and add dataarrays without any coords
        ds = xr.Dataset()
        for var_df in frames:
            da = build_da_without_coords(var_df, cube, filename, attrs)
            ds[da.name] = da

        # add coords and dataset meta
        ds = assign_xr_meta(ds, frames, cube, dim_coords, extra_geo, coord_attrs)

        if data_model is not None:
            ds = parse_data_model(ds, data_model)

        # assign attributes
        ds.attrs['engine'] = 'grib2io'

        return ds

    def open_datatree(
        self,
        filename,
        *,
        drop_variables=None,
        save_index=True,
        filters: typing.Mapping[str, typing.Any] = None,
        stack_vertical: bool = False,
    ):
        """
        Open a GRIB2 file as an xarray DataTree.

        Parameters
        ----------
        filename : str
            Path to the GRIB2 file.
        drop_variables : list, optional
            Accepted for compatibility with the xarray backend API; it is not
            used by this method.
        save_index : bool, optional
            Passed through to ``grib2io.open``.
        filters : dict, optional
            Filter criteria for GRIB2 messages.
        stack_vertical : bool, optional
            If True, organize the tree with vertical layers stacked in a single dataset.

        Returns
        -------
        xarray.DataTree
            A hierarchical DataTree representation of the GRIB2 data.

        Raises
        ------
        ImportError
            If the installed xarray version does not support DataTree.
        """
        if not _HAS_DATATREE:
            raise ImportError("xarray version does not support DataTree functionality.")

        if filters is None:
            filters = {}

        # Open the file without any filters first to get all messages
        with grib2io.open(filename, save_index=save_index, _xarray_backend=True) as f:
            file_index = pd.DataFrame(f._index)
            file_index = file_index.assign(msg=msgs_from_index(f._index))

        # Build tree structure from GRIB messages with specified options
        tree = build_datatree_from_grib(filename, file_index, filters, stack_vertical=stack_vertical)

        # Put warning here so it is the last message from likely other Xarray warnings.
        warnings.warn(
            "grib2io’s xarray backend DataTree support is experimental. "
            "The DataTree structure or attributes may change in future releases.",
            UserWarning,
            stacklevel=2,
        )

        return tree


class GribBackendArray(BackendArray):
    """Lazy array wrapper that gives xarray lock-protected access to ``array``."""

    def __init__(self, array, lock):
        # ``array`` must expose ``shape``, ``dtype`` and ``__getitem__``.
        self.array = array
        self.shape = array.shape
        self.dtype = np.dtype(array.dtype)
        self.lock = lock

    def __getitem__(self, key: xr.core.indexing.ExplicitIndexer) -> np.typing.ArrayLike:
        # Let xarray reduce any indexer to the BASIC (int/slice) form that
        # ``_raw_getitem`` is called with.
        return xr.core.indexing.explicit_indexing_adapter(
            key,
            self.shape,
            indexing.IndexingSupport.BASIC,
            self._raw_getitem,
        )

    def _raw_getitem(self, key: tuple):
        """Index the wrapped array while holding ``self.lock``."""
        with self.lock:
            return self.array[key]
def exclusive_slice_to_inclusive(item: slice):
    """
    Convert a slice with exclusive stop to an inclusive slice.

    ``[start, stop)`` with step ``step`` is converted so that the new stop is
    the *last element actually selected* by the exclusive slice. When
    ``stop - start`` is a multiple of ``step`` this is ``stop - step``.

    Parameters
    ----------
    item
        The slice to convert. ``start``/``stop``/``step`` are expected to be
        non-negative-step integers or ``None``. A ``None`` start is treated as
        position 0 for the arithmetic but is kept as ``None`` in the result.

    Returns
    -------
    slice or list
        * ``item`` itself for ``slice(None, None, None)``;
        * ``slice(start, None, step)`` when ``stop`` is ``None`` (an open-ended
          slice is the same in both interpretations);
        * a one-element list ``[start]`` when exactly one element is selected;
        * otherwise ``slice(start, last, step)`` with ``last`` inclusive.

    Raises
    ------
    ValueError
        If ``item`` is not a slice, if ``step < 1`` or if ``stop < start``.
    """
    # Validate the type before touching slice attributes.
    if not isinstance(item, slice):
        raise ValueError(f'item must be a slice; it was of type {type(item)}')
    # return the None slice
    if item.start is None and item.stop is None and item.step is None:
        return item
    # if step is None, it's one
    step = 1 if item.step is None else item.step
    # a missing start means "from the first position"
    start = 0 if item.start is None else item.start
    if step < 1 or (item.stop is not None and item.stop < start):
        raise ValueError(f'slice {item} not accounted for')
    # open-ended: nothing to reduce
    if item.stop is None:
        return slice(item.start, None, step)
    # number of elements selected by the exclusive slice (ceil division)
    count = -(-(item.stop - start) // step)
    # handle case where slice has one item
    if count == 1:
        return [start]
    # empty selection: keep a stop that lies before the start
    if count == 0:
        return slice(item.start, item.stop - step, step)
    # other cases: stop on the last element that is actually selected
    return slice(item.start, start + (count - 1) * step, step)


class Validator:
    """Descriptor base that stores its value on the instance as ``_<name>``."""

    def __set_name__(self, owner, name):
        self.private_name = f'_{name}'
        self.name = name

    def __get__(self, obj, objtype=None):
        # An attribute that was never set reads as None instead of raising.
        return getattr(obj, self.private_name, None)


class PdIndex(Validator):
    """Descriptor that coerces every assigned value to a ``pandas.Index``."""

    def __set__(self, obj, value):
        try:
            value = pd.Index(value)
        except TypeError:
            # pd.Index rejects scalars; wrap them in a one-element index
            value = pd.Index([value])
        setattr(obj, self.private_name, value)
def _asarray_tuplesafe(values):
    """
    Convert values to a numpy array of at most 1-dimension and preserve tuples.

    A tuple is kept as a single object element; a sequence that numpy would
    turn into a 2-D array is stored as a 1-D object array of its items.

    Adapted from pandas.core.common._asarray_tuplesafe
    """
    if isinstance(values, tuple):
        result = np.empty(1, dtype=object)
        result[0] = values
    else:
        result = np.asarray(values)
        if result.ndim == 2:
            result = np.empty(len(values), dtype=object)
            result[:] = values

    return result


def array_safe_eq(a, b) -> bool:
    """
    Check if a and b are equal, even if they are numpy arrays.

    Objects exposing ``equals`` (e.g. pandas objects) are compared with it;
    array-likes (objects exposing ``all``) must match in shape and in every
    element; an array-like never equals a non-array-like. Anything else uses
    ``==``, and values that cannot be compared are reported as not equal.
    """
    if a is b:
        return True
    if hasattr(a, 'equals'):
        return a.equals(b)
    if hasattr(a, 'all') and hasattr(b, 'all'):
        return a.shape == b.shape and bool((a == b).all())
    if hasattr(a, 'all') or hasattr(b, 'all'):
        return False
    try:
        return a == b
    except TypeError:
        # Not comparable means not equal. Returning an exception class here
        # would be truthy and read as "equal".
        return False


def dc_eq(dc1, dc2) -> bool:
    """
    Check if two dataclasses which hold numpy arrays are equal.

    Instances of different classes are never equal; otherwise the fields are
    compared pairwise with ``array_safe_eq``.
    """
    if dc1 is dc2:
        return True
    if dc1.__class__ is not dc2.__class__:
        return False
    t1 = astuple(dc1)
    t2 = astuple(dc2)
    return all(array_safe_eq(a1, a2) for a1, a2 in zip(t1, t2))


def coords_from_cube(cube) -> "typing.Dict[str, xr.Variable]":
    """
    Build xarray coordinate variables from the non-geographic cube entries.

    Entries ``'x'`` and ``'y'`` are excluded. An entry with more than one
    value becomes a 1-D variable along a dimension of the same name; an entry
    with exactly one value becomes a scalar variable. Entries that are
    ``None`` or empty produce no coordinate.

    Raises
    ------
    ValueError
        If the cube has no ``'x'`` or no ``'y'`` key.
    """
    keys = list(cube.keys())
    keys.remove('x')
    keys.remove('y')
    coords = dict()
    for k in keys:
        # cube values may be None (metadata that is not a coordinate); skip
        # them instead of failing on len(None)
        if k is None or cube[k] is None:
            continue
        if len(cube[k]) > 1:
            coords[k] = xr.Variable(dims=k, data=cube[k], attrs=dict(grib_name=k))
        elif len(cube[k]) == 1:
            coords[k] = xr.Variable(dims=tuple(), data=cube[k][0], attrs=dict(grib_name=k))
    return coords
@dataclass
class OnDiskArray:
    """
    Array-like view over GRIB2 messages that reads data from disk on indexing.

    The leading dimensions come from the (multi)index of ``index``; the two
    trailing dimensions are the grid dimensions ``(ny, nx)`` taken from the
    first row of ``index``.
    """
    file_name: str
    index: pd.DataFrame = field(repr=False)
    cube: dict = field(repr=False)
    shape: typing.Tuple[int, ...] = field(init=False)
    ndim: int = field(init=False)
    geo_ndim: int = field(init=False)
    # no annotation: plain class attribute, not a dataclass field
    dtype = 'float32'

    def __post_init__(self):
        # multiple grids not allowed so can just use first
        geo_shape = (self.index.iloc[0].ny, self.index.iloc[0].nx)

        self.geo_shape = geo_shape
        self.geo_ndim = len(geo_shape)

        # a single message is exposed as a bare grid; otherwise one leading
        # dimension per index level precedes the grid dimensions
        if len(self.index) == 1:
            self.shape = geo_shape
        else:
            if self.index.index.nlevels == 1:
                self.shape = tuple([len(self.index.index)]) + geo_shape
            else:
                self.shape = tuple([len(i) for i in self.index.index.levels]) + geo_shape
        self.ndim = len(self.shape)

        # only these two columns are needed to read the data back
        cols = ['msg', 'sectionOffset']
        self.index = self.index[cols]

    def __getitem__(self, item) -> np.array:
        # the trailing ``geo_ndim`` entries of ``item`` address the grid; the
        # entries before them address the index levels

        index_slicer = item[:-self.geo_ndim]
        # maintain all multindex levels
        index_slicer = tuple([[i] if isinstance(i, int) else i for i in index_slicer])

        # pandas loc slicing is inclusive, therefore convert slices into
        # their inclusive equivalents
        index_slicer_inclusive = tuple([exclusive_slice_to_inclusive(
            i) if isinstance(i, slice) else i for i in index_slicer])

        # get records selected by item in new index dataframe
        if len(index_slicer_inclusive) == 1:
            index = self.index.loc[index_slicer_inclusive]
        elif len(index_slicer_inclusive) > 1:
            index = self.index.loc[index_slicer_inclusive, :]
        else:
            index = self.index
        index = index.set_index(index.index)

        # set miloc to new relative locations in sub array
        index['miloc'] = list(
            zip(*[index.index.unique(level=dim).get_indexer(index.index.get_level_values(dim)) for dim in index.index.names]))

        if len(index_slicer_inclusive) == 1:
            array_field_shape = tuple([len(index.index)]) + self.geo_shape
        elif len(index_slicer_inclusive) > 1:
            array_field_shape = index.index.levshape + self.geo_shape
        else:
            array_field_shape = self.geo_shape

        # positions without a matching message stay NaN
        array_field = np.full(array_field_shape, fill_value=np.nan, dtype="float32")

        with open(self.file_name, mode='rb') as filehandle:
            for key, row in index.iterrows():

                # entries 6 and 7 of sectionOffset are passed to ``_data`` as
                # the bitmap offset (None when missing) and the data offset
                bitmap_offset = None if pd.isna(row['sectionOffset'][6]) else int(row['sectionOffset'][6])
                values = _data(filehandle, row.msg, bitmap_offset, row['sectionOffset'][7])

                if len(index_slicer_inclusive) >= 1:
                    array_field[row.miloc] = values
                else:
                    array_field = values

        # handle geo dim slicing
        array_field = array_field[(Ellipsis,) + item[-self.geo_ndim:]]

        # squeeze array dimensions expressed as integer; go from the last
        # dimension to the first so earlier axis positions stay valid
        for i, it in reversed(list(enumerate(item[: -self.geo_ndim]))):
            if isinstance(it, int):
                array_field = array_field[(slice(None, None, None),) * i + (0,)]

        return array_field
def dims_to_shape(d) -> tuple:
    """Return ``(ny, nx)`` when ``d`` has an ``'nx'`` entry, else ``(nsta,)``."""
    return (d['ny'], d['nx']) if 'nx' in d else (d['nsta'],)


def filter_index(index, k, v):
    """
    Select the rows of ``index`` whose column ``k`` matches ``v``.

    ``v`` may be a slice (label-based, applied through ``.loc`` on column
    ``k``), a scalar, or a collection of values. A scalar that is not present
    yields an empty frame; values of a collection that are not present are
    ignored.
    """
    if isinstance(v, slice):
        return index.set_index(k).loc[v].reset_index()

    # arrays with more than one dimension (vectorized-indexing) are used as is
    label = v if getattr(v, "ndim", 1) > 1 else _asarray_tuplesafe(v)

    if label.ndim != 0:
        positions = pd.Index(index[k]).get_indexer_for(np.ravel(v))
        # -1 marks values that were not found
        return index.iloc[positions[positions >= 0]]

    # see https://github.com/pydata/xarray/pull/4292 for details
    label_value = label[()] if label.dtype.kind in "mM" else label.item()
    try:
        found = pd.Index(index[k]).get_loc(label_value)
        # an integer position would select a Series; wrap it to keep a frame
        return index.iloc[[found]] if isinstance(found, int) else index.iloc[found]
    except KeyError:
        return index.iloc[[]]
def parse_grib_index(
    index: pd.DataFrame,
    filters: typing.Mapping[str, typing.Any] = dict(),
):
    """
    Apply filters.

    Evaluate remaining dimensions based on pdtn and parse each out.

    Parameters
    ----------
    index
        Pandas DataFrame containing the GRIB2 message index. Must have a
        ``msg`` column holding the message objects.
    filters
        Filter GRIB2 messages to single hypercube. Dict keys can be any
        GRIB2 metadata attribute name.

    Returns
    -------
    index
        Modified Pandas DataFrame with added GRIB2 metadata columns.
    dim_coords
        Dict mapping each possible dimension name to the list of GRIB2
        attributes used as its coordinates.
    attrs
        Dict of metadata attributes (non-coordinates, non-geo)
    coord_attrs
        Dict mapping coordinate names to attribute dicts for that coordinate.

        When no message is left after filtering, ``index`` is empty and the
        three dicts are empty.

    Raises
    ------
    ValueError
        If, after filtering, more than one value remains for an attribute
        that must be unique (e.g. ``productDefinitionTemplateNumber``).
    """

    # work on a shallow copy so the caller's mapping is left alone
    filters = copy(filters)

    for k, v in filters.items():
        if k not in index.columns:
            kwarg = {k: index.msg.apply(lambda msg: getattr(msg, k))}
            index = index.assign(**kwarg)
        # adopt parts of xarray's sel logic so that filters behave similarly
        # allowed to filter to nothing to make empty dataset
        index = filter_index(index, k, v)

    # Return the same four-item shape as the normal path so callers can
    # always unpack the result.
    if len(index) == 0:
        return index, dict(), dict(), dict()

    dim_coords = dict()  # key=name of dim, value=list of coord names
    attrs = dict()
    coord_attrs = dict()

    # expand index
    index = index.assign(shortName=index.msg.apply(lambda msg: msg.shortName))
    index = index.assign(nx=index.msg.apply(lambda msg: msg.nx))
    index = index.assign(ny=index.msg.apply(lambda msg: msg.ny))
    index = index.astype({'ny': 'int', 'nx': 'int'})

    # apply common filters(to all definition templates) to reduce dataset to
    # single cube
    # ensure only one of each of the below exists after filters applied
    required_uniques = [
        "productDefinitionTemplateNumber",
        "typeOfGeneratingProcess",
        "typeOfFirstFixedSurface",
        "typeOfSecondFixedSurface",
    ]

    def meta_check(index, attrs, meta):
        """
        Add ``meta`` to the dataframe index, check that it has a single
        value and record that value in ``attrs``.

        Returns index, attrs
        """
        index = index.assign(**{meta: index.msg.apply(lambda msg: getattr(msg, meta))})

        unique = index[meta].unique()
        if len(unique) > 1:
            raise ValueError(f'filter to a single {meta}; found: {[str(i) for i in unique]}')
        value = unique.item()
        if isinstance(value, grib2io.templates.Grib2Metadata):
            value = value.definition

        # None is returned if no value found,
        # check and change to string None
        if value is None:
            value = 'None'

        attrs[meta] = value
        return index, attrs

    for meta in required_uniques:
        index, attrs = meta_check(index, attrs, meta)

    pdtn = index.productDefinitionTemplateNumber.iloc[0].value

    # determine which non geo dimensions can be created from data by this point
    # the index is filtered down to a single type for all required_uniques

    # dim name -> coordinate names; the entry matching the dim name is used as
    # the index coordinate, other entries are used as non-index coordinates
    dim_coords["refDate"] = ["refDate"]
    coord_attrs["refDate"] = dict(standard_name="forecast_reference_time")

    dim_coords["leadTime"] = ["leadTime"]
    coord_attrs["leadTime"] = dict(standard_name="forecast_period")

    if 'valueOfFirstFixedSurface' not in index.columns:
        index = index.assign(valueOfFirstFixedSurface=index.msg.apply(lambda msg: msg.valueOfFirstFixedSurface))
    if 'valueOfSecondFixedSurface' not in index.columns:
        index = index.assign(valueOfSecondFixedSurface=index.msg.apply(lambda msg: msg.valueOfSecondFixedSurface))

    # "level" is the composite of both fixed surface values
    # dim name api change, user could run ds = ds.swap_dims(fixedSurface="valueOfFirstFixedSurface")
    index = index.assign(level=list(zip(index['valueOfFirstFixedSurface'], index['valueOfSecondFixedSurface'])))
    # "level" is not listed among its own coordinates, so no extra index
    # coordinate "level" is created
    dim_coords["level"] = ["valueOfFirstFixedSurface", "valueOfSecondFixedSurface"]

    # logic for parsing possible dims from specific product definition section

    if pdtn in {5, 9}:

        # Probability forecasts at a horizontal level or in a horizontal layer
        # in a continuous or non-continuous time interval. (see Template
        # 4.9)
        # Which limits each type of probability uses (original author's note):
        # AVAILABLE_THRESHOLD = {
        #     0: {'has_lower': True, 'has_upper': False},
        #     1: {'has_lower': False, 'has_upper': True},
        #     2: {'has_lower': True, 'has_upper': True},
        #     3: {'has_lower': True, 'has_upper': False},
        #     4: {'has_lower': False, 'has_upper': True},
        #     5: {'has_lower': True, 'has_upper': False},
        # }

        index, attrs = meta_check(index, attrs, "typeOfProbability")
        if 'thresholdLowerLimit' not in index.columns:
            index = index.assign(thresholdLowerLimit=index.msg.apply(lambda msg: msg.thresholdLowerLimit))
        if 'thresholdUpperLimit' not in index.columns:
            index = index.assign(thresholdUpperLimit=index.msg.apply(lambda msg: msg.thresholdUpperLimit))
        if 'threshold' not in index.columns:
            # using composite of lower and upper, but could use threshold string from grib2io as long as that is unique and based on lower and upper
            index = index.assign(threshold=list(zip(index['thresholdLowerLimit'], index['thresholdUpperLimit'])))

        # omitting "threshold" from the list results in no index being assigned for this possible dim
        dim_coords["threshold"] = ["thresholdLowerLimit", "thresholdUpperLimit"]

    if pdtn in {6, 10}:

        # Percentile forecasts at a horizontal level or in a horizontal layer
        # in a continuous or non-continuous time interval. (see Template
        # 4.10)
        dim_coords["percentileValue"] = ["percentileValue"]
        coord_attrs["percentileValue"] = dict(long_name='percentile', units='percent')

    if pdtn in {8, 9, 10, 11, 12, 13, 14, 42, 43, 45, 46, 47, 61, 62, 63, 67, 68, 72, 73, 78, 79, 82, 83, 84, 85, 87, 91}:
        dim_coords["duration"] = ["duration"]

    if pdtn in {1, 11, 33, 34, 41, 43, 45, 47, 49, 54, 56, 58, 59, 63, 68, 77, 79, 81, 83, 84, 85, 92}:
        dim_coords["perturbationNumber"] = ["perturbationNumber"]

    if pdtn in {2, 3, 4, 12, 13, 14}:
        index, attrs = meta_check(index, attrs, 'typeOfDerivedForecast')

    if pdtn in {8, 15, 42, 46, 62, 67, 72, 78, 82, 1001, 1002, 1100, 1101}:
        index, attrs = meta_check(index, attrs, 'statisticalProcess')

    # Finish logic by pdtn

    # make sure every coordinate named above exists as an index column
    for coord_names in dim_coords.values():
        for meta in coord_names:
            if meta not in index.columns:
                index = index.assign(**{meta: index.msg.apply(lambda msg: getattr(msg, meta))})

    return index, dim_coords, attrs, coord_attrs
def open_datatree(filename, *, filters: typing.Optional[typing.Mapping[str, typing.Any]] = None, engine="grib2io"):
    """
    Open a GRIB2 file as an xarray DataTree.

    The file is opened once to read its complete message index; the index is
    then handed, together with ``filters``, to ``build_datatree_from_grib``.

    Parameters
    ----------
    filename : str
        Path to the GRIB2 file.
    filters : dict, optional
        Filter criteria for GRIB2 messages. ``None`` means "no filters".
    engine : str, optional
        Engine name, defaults to "grib2io". It is part of the signature but
        is not read by this function.

    Returns
    -------
    xarray.DataTree
        A hierarchical DataTree representation of the GRIB2 data.

    Raises
    ------
    ImportError
        If the module-level ``_HAS_DATATREE`` flag is False.
    """
    if not _HAS_DATATREE:
        raise ImportError("xarray version does not support DataTree functionality.")

    if filters is None:
        filters = {}

    # Read the unfiltered index first; filtering is applied by the tree builder.
    with grib2io.open(filename, _xarray_backend=True) as f:
        file_index = pd.DataFrame(f._index)

    # The tree builder creates its own root node, so none is created here.
    return build_datatree_from_grib(filename, file_index, filters)
def build_da_without_coords(index, cube, filename, attrs) -> xr.DataArray:
    """
    Build a lazily-loaded DataArray, without coordinates, for one variable.

    Parameters
    ----------
    index
        DataFrame of the GRIB2 messages of the variable. Its ``msg`` and
        ``shortName`` columns are read here.
    cube
        Mapping of dimension / metadata name to its values. Entries holding
        more than one value become dimensions; entries that are ``None`` are
        treated as metadata that is constant over the messages.
    filename
        Filename of the grib2 file.
    attrs
        Extra attributes; applied last, so they override the ones set here.

    Returns
    -------
    DataArray
        DataArray without coordinates, named after the first ``shortName``.

    Raises
    ------
    ValueError
        If the product of the non-spatial dimension lengths differs from the
        number of indexed messages, or if the number of dimension names does
        not match the rank of the backing array.
    """
    # Classify the cube entries in a single pass.
    dim_names = []
    constant_meta_names = []
    for key in cube.keys():
        entry = cube[key]
        if entry is None:
            constant_meta_names.append(key)
        elif len(entry) > 1:
            dim_names.append(key)

    # guard against bad DataArrays being formed: the non-X/Y dimensions must
    # account for exactly one GRIB2 message per combination of their values.
    dims_to_filter = [name for name in dim_names if name not in {'x', 'y', 'station'}]
    dims_total = 1
    for name in dims_to_filter:
        dims_total *= len(cube[name])

    if dims_total != len(index):
        raise ValueError(
            f"DataArray dimensions are not compatible with number of GRIB2 messages; DataArray has {dims_total} "
            f"and GRIB2 index has {len(index)}. Consider applying a filter for dimensions: {dims_to_filter}"
        )

    # Wrap the on-disk messages so xarray only reads them when indexed.
    backend = GribBackendArray(OnDiskArray(filename, index, cube), _LOCK)
    data = indexing.LazilyIndexedArray(backend)
    if len(dim_names) != len(data.shape):
        raise ValueError(
            "different number of dimensions on data "
            f"and dims: {len(data.shape)} vs {len(dim_names)}\n"
            "Grib2 messages could not be formed into a data cube; "
            "It's possible extra messages exist along a non-accounted for dimension based on PDTN\n"
            "It might be possible to get around this by applying a filter on the non-accounted for dimension"
        )
    da = xr.DataArray(data, dims=dim_names)

    da.encoding['original_shape'] = data.shape
    da.encoding['preferred_chunks'] = {'y': -1, 'x': -1}

    # plain language metadata is minimized; the raw section arrays of the
    # first message are stored so a message can be rebuilt from the attrs.
    first = index.msg.iloc[0]
    da.attrs.update({
        'GRIB2IO_section0': first.section0,
        'GRIB2IO_section1': first.section1,
        'GRIB2IO_section2': first.section2 if first.section2 else [],
        'GRIB2IO_section3': first.section3,
        'GRIB2IO_section4': first.section4,
        'GRIB2IO_section5': first.section5,
        'fullName': str(first.fullName),
        'shortName': str(first.shortName),
        'units': str(first.units),
        'originatingCenter': str(first.originatingCenter.definition),
        'originatingSubCenter': str(first.originatingSubCenter.definition),
        # master table
        'masterTableInfo': str(first.masterTableInfo.definition),
    })

    da.name = index.shortName.iloc[0]

    # Constant metadata is only copied when the index carries such a column.
    for meta_name in constant_meta_names:
        if meta_name in index.columns:
            da.attrs[meta_name] = index[meta_name].iloc[0]

    da.attrs.update(attrs)

    return da
def assign_xr_meta(ds, frames, cube, non_geo_dims, extra_geo, coord_attrs):
    """
    Attach coordinates and metadata to a Dataset built from GRIB2 messages.

    Parameters
    ----------
    ds
        Dataset to decorate.
    frames
        List of per-variable dataframes; only the first one is read.
    cube
        Cube of dimension values, turned into coordinates by
        ``coords_from_cube``.
    non_geo_dims
        Mapping of dimension name to the coordinate names tied to it. A
        dimension whose own name is absent from its list has its index
        coordinate dropped.
    extra_geo
        Extra geographic coordinates to assign.
    coord_attrs
        Mapping of coordinate name to the attributes to set on it.

    Returns
    -------
    Dataset
        Dataset with coordinates, grid attributes and ``engine`` attribute.
    """
    # assign coords from the cube; the cube prevents DataArrays with
    # different shapes
    ds = ds.assign_coords(coords_from_cube(cube))

    # assign extra index associated coords; the first variable is used since
    # all of them share the same shape and index metadata
    df = frames[0]
    for dim_name, coord_names in non_geo_dims.items():
        keep_dim_coord = dim_name in coord_names
        for name in coord_names:
            if name == dim_name:
                continue
            if ds[dim_name].size == 1:
                # for assigning scalar coords
                scalar = [df[name].unique().item()]
                ds = ds.assign_coords({name: scalar}).squeeze()
                continue
            # "ValueError: can only convert an array of size 1 to a Python
            # scalar" indicates the coord is not compatible with the index
            per_position = []
            for val in range(ds[dim_name].size):
                at_position = df[df.index.get_level_values(f'{dim_name}_ix') == val]
                per_position.append(at_position[name].unique().item())
            ds = ds.assign_coords({name: pd.Index(per_position, name=dim_name)})
        if not keep_dim_coord:
            ds = ds.drop_vars(dim_name)

    # assign extra geo coords
    ds = ds.assign_coords(extra_geo)

    # add crs data from first grib message to each data variable and the dataset
    first_msg = df.msg.iloc[0]
    geo_attrs = {
        'crs_wkt': CRS.from_dict(first_msg.projParameters).to_wkt(),
        'gridlengthXDirection': first_msg.gridlengthXDirection,
        'gridlengthYDirection': first_msg.gridlengthYDirection,
        'latitudeFirstGridpoint': first_msg.latitudeFirstGridpoint,
        'longitudeFirstGridpoint': first_msg.longitudeFirstGridpoint,
    }
    for data_var in ds.data_vars:
        ds[data_var].attrs.update(geo_attrs)
    ds.attrs.update(geo_attrs)

    # add coordinate specific attributes
    for coord, attrs in coord_attrs.items():
        ds[coord].attrs.update(attrs)

    # assign valid date coords; best effort, a failure only produces a warning
    try:
        ds = ds.assign_coords({'validDate': ds.coords['refDate'] + ds.coords['leadTime']})
        ds.validDate.attrs['standard_name'] = 'time'
        ds.validDate.attrs['long_name'] = 'time'
    except Exception as e:
        warnings.warn(f'could not parse validTime: {e}')

    # assign attributes
    ds.attrs['engine'] = 'grib2io'

    return ds
def _cubes_match(first, second):
    """Return True when two cube dicts hold the same keys with equal, equally ordered values."""
    if first.keys() != second.keys():
        return False
    # Cube values are either pandas Index objects or one-element lists, so a
    # plain dict comparison would end up asking for the truth value of an
    # element-wise comparison array. Index.equals yields a single bool.
    return all(pd.Index(first[name]).equals(pd.Index(second[name])) for name in first)


def make_variables(index, f, non_geo_dims, allow_uneven_dims=False):
    """
    Create an individual dataframe index and cube for each variable.

    Variables are determined by ``shortName``.

    Parameters
    ----------
    index
        DataFrame index of GRIB2 messages. The columns ``shortName``, ``ny``,
        ``nx``, ``msg`` and one column per key of ``non_geo_dims`` are read.
    f
        Not used by this function.
    non_geo_dims
        Dimensions not associated with the x,y grid
    allow_uneven_dims
        If True, allows uneven dimensions (used for DataTree creation); the
        cube of the first variable is kept in that case.

    Returns
    -------
    ordered_frames
        List of dataframes, one for each variable.
    cube
        Cube of grib2 messages.
    extra_geo
        Extra geographic coordinates.

    All three are ``None`` when ``index`` holds no messages.

    Raises
    ------
    ValueError
        If a dimension has an uneven number of messages per value, if two
        variables do not share the same cube (both only when
        ``allow_uneven_dims`` is False), or if more than one grid shape is
        present.
    """
    # let shortName determine the variables
    index = index.set_index('shortName').sort_index()
    # return nothing if no data
    if index.empty:
        return None, None, None

    ordered_meta = list(non_geo_dims.keys())
    cube = None
    ordered_frames = []
    for key in index.index.unique():
        # frame is a dataframe with all records for one variable
        frame = index.loc[[key]].reset_index()

        c = {}
        for colname in ordered_meta:
            uniques = pd.Index(frame[colname]).unique()
            c[colname] = uniques.sort_values() if len(uniques) > 1 else [uniques[0]]

        # only metadata with more than one value spans a dimension
        dims = [k for k in ordered_meta if len(c[k]) > 1]

        for dim in dims:
            if frame[dim].value_counts().nunique() > 1 and not allow_uneven_dims:
                raise ValueError(
                    f'uneven number of grib msgs associated with dimension: {dim}\n unique values for {dim}: {frame[dim].unique()} ')

        if dims:  # dims may be empty if no extra dims on top of x,y
            frame = frame.sort_values(dims).set_index(dims)

        if cube is None:
            cube = c
        elif not allow_uneven_dims and not _cubes_match(cube, c):
            raise ValueError(f'{cube},\n {c};\n cubes are not the same; filter to a single cube')

        # miloc is multi-index integer location of msg in nd DataArray
        miloc = list(zip(*[frame.index.unique(level=dim).get_indexer(frame.index.get_level_values(dim))
                           for dim in dims]))

        # set frame multi index
        if len(miloc) >= 1:  # miloc will be empty when no extra dims, thus no multiindex
            dim_ix = tuple([n + '_ix' for n in dims])
            frame = frame.set_index(pd.MultiIndex.from_tuples(miloc, names=dim_ix))

        ordered_frames.append(frame)

    # no variables
    if cube is None:
        cube = dict()

    # check geography of data and assign to cube
    if len(index.ny.unique()) > 1 or len(index.nx.unique()) > 1:
        raise ValueError('multiple grids not accommodated')
    cube["y"] = range(int(index.ny.iloc[0]))
    cube["x"] = range(int(index.nx.iloc[0]))

    msg = index.msg.iloc[0]

    # we want the lat lons; make them via accessing a record; we are assuming
    # all records are the same grid because they have the same shape;
    # may want a unique grid identifier from grib2io to avoid assuming this
    latitude, longitude = msg.latlons()
    latitude = xr.DataArray(latitude, dims=['y', 'x'])
    latitude.attrs['standard_name'] = 'latitude'
    latitude.attrs['units'] = 'degrees_north'
    longitude = xr.DataArray(longitude, dims=['y', 'x'])
    longitude.attrs['standard_name'] = 'longitude'
    longitude.attrs['units'] = 'degrees_east'
    extra_geo = dict(latitude=latitude, longitude=longitude)

    return ordered_frames, cube, extra_geo
def interp_nd(a, *, method, grid_def_in, grid_def_out, method_options=None, num_threads=1):
    """
    Interpolate an array whose two rightmost axes are the grid axes.

    All leading axes are flattened into one before calling
    ``grib2io.interpolate`` and restored on the result.
    """
    leading = a.shape[:-2]
    stacked = a.reshape(-1, a.shape[-2], a.shape[-1])
    regridded = grib2io.interpolate(
        stacked,
        method,
        grid_def_in,
        grid_def_out,
        method_options=method_options,
        num_threads=num_threads,
    )
    # Put the leading axes back in front of the two output grid axes.
    return regridded.reshape(leading + (regridded.shape[-2], regridded.shape[-1]))


def interp_nd_stations(a, *, method, grid_def_in, lats, lons, method_options=None, num_threads=1):
    """
    Interpolate an array whose two rightmost axes are the grid axes to points.

    All leading axes are flattened into one before calling
    ``grib2io.interpolate_to_stations``; the result is reshaped to the
    leading axes followed by one axis of length ``len(lats)``.
    """
    leading = a.shape[:-2]
    stacked = a.reshape(-1, a.shape[-2], a.shape[-1])
    at_points = grib2io.interpolate_to_stations(
        stacked,
        method,
        grid_def_in,
        lats,
        lons,
        method_options=method_options,
        num_threads=num_threads,
    )
    return at_points.reshape(leading + (len(lats),))
@xr.register_dataset_accessor("grib2io")
class Grib2ioDataSet:
    """Accessor registered on xarray Datasets under the name ``grib2io``."""

    def __init__(self, xarray_obj):
        self._obj = xarray_obj

    def griddef(self):
        """Return the grid definition built from the first data variable's ``GRIB2IO_section3`` attribute."""
        first_name = list(self._obj.data_vars)[0]
        section3 = self._obj[first_name].attrs['GRIB2IO_section3']
        return Grib2GridDef.from_section3(section3)

    def interp(self, method, grid_def_out, method_options=None, num_threads=1) -> xr.Dataset:
        """
        Interpolate every data variable to a new grid.

        The variables are stacked into one DataArray, which is given the
        first variable's ``GRIB2IO_section3`` attribute, interpolated through
        the DataArray accessor and split back along ``variable``. See the
        ``interp`` method of class Grib2ioDataArray for the arguments.
        """
        stacked = self._obj.to_array()
        first_name = list(self._obj.data_vars)[0]
        stacked.attrs['GRIB2IO_section3'] = self._obj[first_name].attrs['GRIB2IO_section3']
        regridded = stacked.grib2io.interp(
            method,
            grid_def_out,
            method_options=method_options,
            num_threads=num_threads,
        )
        return regridded.to_dataset(dim='variable')

    def interp_to_stations(self, method, calls, lats, lons, method_options=None, num_threads=1) -> xr.Dataset:
        """
        Interpolate every data variable to station points.

        Works like ``interp`` on this class but delegates to the
        ``interp_to_stations`` method of class Grib2ioDataArray, which also
        documents the arguments.
        """
        stacked = self._obj.to_array()
        first_name = list(self._obj.data_vars)[0]
        stacked.attrs['GRIB2IO_section3'] = self._obj[first_name].attrs['GRIB2IO_section3']
        at_stations = stacked.grib2io.interp_to_stations(
            method,
            calls,
            lats,
            lons,
            method_options=method_options,
            num_threads=num_threads,
        )
        return at_stations.to_dataset(dim='variable')

    def to_grib2(self, filename, mode: typing.Literal["x", "w", "a"] = "x"):
        """
        Write a DataSet to a grib2 file.

        Data variables are written in sorted name order; ``mode`` applies to
        the first one and every later one is appended.

        Parameters
        ----------
        filename
            Name of the grib2 file to write to.
        mode: {"x", "w", "a"}, optional, default="x"
            Persistence mode

            | mode | Description                       |
            | :---:| :---:                             |
            | 'x'  | create (fail if exists)           |
            | 'w'  | create (overwrite if exists)      |
            | 'a'  | append (create if does not exist) |

        """
        ds = self._obj

        for position, shortName in enumerate(sorted(ds)):
            # Only the first variable may create or overwrite the file.
            write_mode = mode if position == 0 else "a"
            ds[shortName].grib2io.to_grib2(filename, mode=write_mode)

    def update_attrs(self, **kwargs):
        """
        Always raise; attribute updates are only offered on DataArrays.

        Parameters
        ----------
        **kwargs
            Attributes to update; only echoed in the error message.

        Raises
        ------
        ValueError
            Always.
        """
        raise ValueError(
            f"Datasets do not have a .attrs attribute; use .grib2io.update_attrs({kwargs}) on a DataArray instead."
        )

    def subset(self, lats, lons) -> xr.Dataset:
        """
        Subset the DataSet to a region defined by latitudes and longitudes.

        Parameters
        ----------
        lats
            Latitude bounds of the region.
        lons
            Longitude bounds of the region.

        Returns
        -------
        subset
            New DataSet holding a copy of each data variable's subset.
        """
        source = self._obj

        result = xr.Dataset()
        for shortName in source:
            clipped = source[shortName].grib2io.subset(lats, lons)
            result[shortName] = clipped.copy()

        return result
@xr.register_dataarray_accessor("grib2io")
class Grib2ioDataArray:
    """Accessor registered on xarray DataArrays under the name ``grib2io``."""

    def __init__(self, xarray_obj):
        self._obj = xarray_obj

    def griddef(self):
        """Return the grid definition built from the ``GRIB2IO_section3`` attribute."""
        return Grib2GridDef.from_section3(self._obj.attrs['GRIB2IO_section3'])

    def interp(self, method, grid_def_out, method_options=None, num_threads=1) -> xr.DataArray:
        """
        Perform grid spatial interpolation.

        Uses the [NCEPLIBS-ip library](https://github.com/NOAA-EMC/NCEPLIBS-ip).

        Parameters
        ----------
        method
            Interpolate method to use. This can either be an integer or string
            using the following mapping:

            | Interpolate Scheme | Integer Value |
            | :---:              | :---:         |
            | 'bilinear'         | 0             |
            | 'bicubic'          | 1             |
            | 'neighbor'         | 2             |
            | 'budget'           | 3             |
            | 'spectral'         | 4             |
            | 'neighbor-budget'  | 6             |
        grid_def_out
            Grib2GridDef object of the output grid.
        method_options : list of ints, optional
            Interpolation options. See the NCEPLIBS-ip documentation for
            more information on how these are used.
        num_threads : int, optional
            Number of OpenMP threads to use for interpolation. The default
            value is 1. If grib2io_interp was not built with OpenMP, then
            this keyword argument and value will have no impact. The value
            is forwarded for both in-memory and chunked (dask) data.

        Returns
        -------
        interp
            DataArray interpolated to new grid definition. The attribute
            GRIB2IO_section3 is replaced with the section3 array from the new
            grid definition.
        """
        da = self._obj
        # y, x are expected to be the rightmost dims; they should be if
        # opening with grib2io engine

        # gdtn and gdt is not the entirety of the new s3
        npoints = grid_def_out.npoints
        s3_new = np.array([0, npoints, 0, 0, grid_def_out.gdtn] + list(grid_def_out.gdt))

        # make new lat lons
        lats, lons = Grib2Message(section3=s3_new, pdtn=0, drtn=0).grid()
        latitude = xr.DataArray(lats, dims=['y', 'x'])
        longitude = xr.DataArray(lons, dims=['y', 'x'])

        # create new coords; the old latitude/longitude must exist and are
        # replaced by the ones of the output grid
        new_coords = dict(da.coords)
        del new_coords['latitude']
        del new_coords['longitude']
        new_coords['longitude'] = longitude
        new_coords['latitude'] = latitude

        # make grid def in from section3 on da.attrs
        grid_def_in = self.griddef()

        if da.chunks is None:
            data = interp_nd(da.data, method=method, grid_def_in=grid_def_in,
                             grid_def_out=grid_def_out,
                             method_options=method_options, num_threads=num_threads)
        else:
            # Chunked data: interpolate block by block. Extra keywords are
            # handed to interp_nd, including num_threads so the caller's
            # setting is honoured here as well.
            data = da.data.map_blocks(interp_nd, method=method, grid_def_in=grid_def_in,
                                      grid_def_out=grid_def_out, method_options=method_options,
                                      num_threads=num_threads,
                                      chunks=da.chunks[:-2]+latitude.shape, dtype=da.dtype)

        new_da = xr.DataArray(data, dims=da.dims, coords=new_coords, attrs=da.attrs)

        new_da.attrs['GRIB2IO_section3'] = s3_new
        new_da.name = da.name
        return new_da

    def interp_to_stations(self, method, calls, lats, lons, method_options=None, num_threads=1) -> xr.DataArray:
        """
        Perform spatial interpolation to station points.

        Parameters
        ----------
        method
            Interpolate method to use. This can either be an integer or string
            using the following mapping:

            | Interpolate Scheme | Integer Value |
            | :---:              | :---:         |
            | 'bilinear'         | 0             |
            | 'bicubic'          | 1             |
            | 'neighbor'         | 2             |
            | 'budget'           | 3             |
            | 'spectral'         | 4             |
            | 'neighbor-budget'  | 6             |

        calls
            Station calls used for labeling new station index coordinate
        lats
            Latitudes of the station points.
        lons
            Longitudes of the station points.
        method_options : list of ints, optional
            Interpolation options, passed through to the interpolation
            routine.
        num_threads : int, optional
            Number of threads passed through to the interpolation routine
            for both in-memory and chunked (dask) data. The default is 1.

        Returns
        -------
        interp_to_stations
            DataArray interpolated to lat and lon locations and labeled with
            dimension and coordinate 'station'. (..., y, x) -> (..., station)
        """
        da = self._obj
        # TODO ensure that y, x are rightmost dims; they should be if opening
        # with grib2io engine

        calls = np.asarray(calls)
        lats = np.asarray(lats)
        lons = np.asarray(lons)
        latitude = xr.DataArray(lats, dims=['station'])
        longitude = xr.DataArray(lons, dims=['station'])

        # create new coords; the gridded latitude/longitude must exist and are
        # replaced by the station ones
        new_coords = dict(da.coords)
        del new_coords['latitude']
        del new_coords['longitude']
        new_coords['longitude'] = longitude
        new_coords['latitude'] = latitude
        new_coords['station'] = calls

        new_dims = da.dims[:-2] + ('station',)

        # make grid def in from section3 on da attrs
        grid_def_in = self.griddef()

        if da.chunks is None:
            data = interp_nd_stations(da.data, method=method, grid_def_in=grid_def_in, lats=lats,
                                      lons=lons, method_options=method_options, num_threads=num_threads)
        else:
            # Chunked data: each block loses its last axis (drop_axis=-1) and
            # gets a station axis instead; num_threads is forwarded so the
            # caller's setting is honoured here as well.
            data = da.data.map_blocks(interp_nd_stations, method=method, grid_def_in=grid_def_in,
                                      lats=lats, lons=lons, method_options=method_options,
                                      num_threads=num_threads,
                                      drop_axis=-1, chunks=da.chunks[:-2]+latitude.shape,
                                      dtype=da.dtype)

        new_da = xr.DataArray(data, dims=new_dims, coords=new_coords, attrs=da.attrs)

        new_da.name = da.name
        return new_da

    def to_grib2(self, filename, mode: typing.Literal["x", "w", "a"] = "x"):
        """
        Write a DataArray to a grib2 file.

        One GRIB2 message is written for every combination of values of the
        dimensions listed in ``AVAILABLE_NON_GEO_DIMS``. ``mode`` applies to
        the first message; all later messages are appended.

        Parameters
        ----------
        filename
            Name of the grib2 file to write to.
        mode: {"x", "w", "a"}, optional, default="x"
            Persistence mode

            +------+-----------------------------------+
            | mode | Description                       |
            +======+===================================+
            | x    | create (fail if exists)           |
            +------+-----------------------------------+
            | w    | create (overwrite if exists)      |
            +------+-----------------------------------+
            | a    | append (create if does not exist) |
            +------+-----------------------------------+

        Raises
        ------
        ValueError
            If a dimension coordinate holds duplicate values.
        """
        da = self._obj.copy(deep=True)

        coords_keys = sorted(da.coords.keys())
        coords_keys = [k for k in coords_keys if k in AVAILABLE_NON_GEO_COORDS]

        # If there are dimension coordinates, the DataArray is a hypercube of
        # grib2 messages.

        # Create `indexes` which is a list of lists of dictionaries for all
        # dimension coordinates. Each dictionary key is the dimension
        # coordinate name and the value is one dimension coordinate value.
        # This allows for easy iteration over all possible grib2 messages in
        # the DataArray by using itertools.product.
        #
        # For example:
        # indexes = [
        #     [
        #         {"leadTime": 9},
        #         {"leadTime": 12},
        #     ],
        #     [
        #         {"valueOfFirstFixedSurface": 900},
        #         {"valueOfFirstFixedSurface": 925},
        #         {"valueOfFirstFixedSurface": 950},
        #     ],
        # ]

        # assign loc indexes to dimensions without indexes for uniform selection by name
        loc_indexes = list()
        for dim in da.dims:
            if dim not in da.indexes:
                da = da.assign_coords({dim: range(da[dim].size)})
                loc_indexes.append(dim)

        indexes = []
        for index in [i for i in AVAILABLE_NON_GEO_DIMS if i in da.dims]:
            values = da.coords[index].values
            if len(values) != len(set(values)):
                raise ValueError(
                    f"Dimension coordinate '{index}' has duplicate values, but to_grib2 requires unique values to find each GRIB2 message in the DataArray."
                )
            listeach = [{index: value} for value in sorted(values)]
            indexes.append(listeach)

        # If `indexes` is [], then the DataArray is a single grib2 message and
        # itertools.product(*indexes) will run once with `selectors = ()`.
        for selectors in itertools.product(*indexes):
            # Need to find the correct data in the DataArray based on the
            # dimension coordinates.
            filters = {k: v for d in selectors for k, v in d.items()}

            # If `filters` is {}, then the DataArray is a single grib2 message
            # and da.sel(indexers={}) returns the DataArray.
            selected = da.sel(indexers=filters)

            newmsg = Grib2Message(
                selected.attrs["GRIB2IO_section0"],
                selected.attrs["GRIB2IO_section1"],
                selected.attrs["GRIB2IO_section2"],
                selected.attrs["GRIB2IO_section3"],
                selected.attrs["GRIB2IO_section4"],
                selected.attrs["GRIB2IO_section5"],
            )
            newmsg.data = np.array(selected.data)

            # For dimension coordinates, set the grib2 message metadata to the
            # dimension coordinate value. Positional indexes added above are
            # not message metadata and are skipped.
            for index, value in filters.items():
                if index not in loc_indexes:
                    setattr(newmsg, index, value)

            # For non-dimension coordinates, set the grib2 message metadata to
            # the DataArray coordinate value.
            for index in [i for i in coords_keys if i not in da.dims]:
                setattr(newmsg, index, selected.coords[index].values)

            # Set section 5 attributes to the da.encoding dictionary.
            # NOTE(review): every other encoding key is set on the message
            # too - confirm that is intended for keys such as
            # 'preferred_chunks'.
            for key, value in selected.encoding.items():
                if key in ["dtype", "chunks", "original_shape"]:
                    continue
                setattr(newmsg, key, value)

            # write the message to file
            with grib2io.open(filename, mode=mode) as f:
                f.write(newmsg)
            mode = "a"

    def update_attrs(self, **kwargs):
        """
        Update many of the attributes of the DataArray.

        A Grib2Message is rebuilt from the ``GRIB2IO_section*`` attributes,
        the keywords are set on it, and the resulting sections, fullName,
        shortName and units are written to a deep copy of the DataArray.

        Parameters
        ----------
        **kwargs
            Attributes to update. This can include many of the GRIB2IO message
            attributes that you can find when you print a GRIB2IO message. For
            conflicting updates, the last keyword will be used. Keywords that
            name a coordinate, or that are not attributes of the message, are
            skipped with a warning.

            +-----------------------+------------------------------------------+
            | kwargs                | Description                              |
            +=======================+==========================================+
            | shortName="VTMP"      | Set shortName to "VTMP", along with      |
            |                       | appropriate discipline,                  |
            |                       | parameterCategory, parameterNumber,      |
            |                       | fullName and units.                      |
            +-----------------------+------------------------------------------+
            | discipline=0,         | Set shortName, discipline,               |
            | parameterCategory=0,  | parameterCategory, parameterNumber,      |
            | parameterNumber=1     | fullName and units appropriate for       |
            |                       | "Virtual Temperature".                   |
            +-----------------------+------------------------------------------+
            | discipline=0,         | Conflicting keywords but                 |
            | parameterCategory=0,  | 'shortName="TMP"' wins.  Set shortName,  |
            | parameterNumber=1,    | discipline, parameterCategory,           |
            | shortName="TMP"       | parameterNumber, fullName and units      |
            |                       | appropriate for "Temperature".           |
            +-----------------------+------------------------------------------+

        Returns
        -------
        DataArray
            DataArray with updated attributes.

        Raises
        ------
        ValueError
            If gridDefinitionTemplateNumber, productDefinitionTemplateNumber
            or dataRepresentationTemplateNumber is passed.
        """
        da = self._obj.copy(deep=True)

        newmsg = Grib2Message(
            da.attrs["GRIB2IO_section0"],
            da.attrs["GRIB2IO_section1"],
            da.attrs["GRIB2IO_section2"],
            da.attrs["GRIB2IO_section3"],
            da.attrs["GRIB2IO_section4"],
            da.attrs["GRIB2IO_section5"],
        )

        coords_keys = [
            k
            for k in da.coords.keys()
            if k in AVAILABLE_NON_GEO_COORDS
        ]

        for grib2_name, value in kwargs.items():
            if grib2_name == "gridDefinitionTemplateNumber":
                raise ValueError(
                    "The gridDefinitionTemplateNumber attribute cannot be updated.  The best way to change to a different grid is to interpolate the data to a new grid using the grib2io interpolate functions."
                )
            if grib2_name == "productDefinitionTemplateNumber":
                raise ValueError(
                    "The productDefinitionTemplateNumber attribute cannot be updated."
                )
            if grib2_name == "dataRepresentationTemplateNumber":
                raise ValueError(
                    "The dataRepresentationTemplateNumber attribute cannot be updated."
                )
            if grib2_name in coords_keys:
                warnings.warn(
                    f"Skipping attribute '{grib2_name}' because it is a coordinate. Use da.assign_coords() to change coordinate values."
                )
                continue
            if hasattr(newmsg, grib2_name):
                setattr(newmsg, grib2_name, value)
            else:
                warnings.warn(
                    f"Skipping attribute '{grib2_name}' because it is not a valid GRIB2 attribute for this message and cannot be updated."
                )
                continue

        # Copy the (possibly changed) sections and names back to the attrs.
        da.attrs["GRIB2IO_section0"] = newmsg.section0
        da.attrs["GRIB2IO_section1"] = newmsg.section1
        da.attrs["GRIB2IO_section2"] = newmsg.section2 or []
        da.attrs["GRIB2IO_section3"] = newmsg.section3
        da.attrs["GRIB2IO_section4"] = newmsg.section4
        da.attrs["GRIB2IO_section5"] = newmsg.section5
        da.attrs["fullName"] = newmsg.fullName
        da.attrs["shortName"] = newmsg.shortName
        da.attrs["units"] = newmsg.units

        return da

    def subset(self, lats, lons) -> xr.DataArray:
        """
        Subset the DataArray to a region defined by latitudes and longitudes.

        A Grib2Message rebuilt from the ``GRIB2IO_section*`` attributes is
        subset to obtain the new section3 and the first/last grid point
        latitudes and longitudes, which are then used to mask the DataArray.

        Parameters
        ----------
        lats
            Latitude bounds of the region.
        lons
            Longitude bounds of the region.

        Returns
        -------
        subset
            DataArray subset to the region.
        """
        da = self._obj.copy(deep=True)

        newmsg = Grib2Message(
            da.attrs["GRIB2IO_section0"],
            da.attrs["GRIB2IO_section1"],
            da.attrs["GRIB2IO_section2"],
            da.attrs["GRIB2IO_section3"],
            da.attrs["GRIB2IO_section4"],
            da.attrs["GRIB2IO_section5"],
        )

        # Placeholder data of the full grid shape; presumably needed so the
        # message can be subset - only its grid metadata is used afterwards.
        newmsg.data = np.zeros((newmsg.ny, newmsg.nx), dtype=np.float32)

        newmsg = newmsg.subset(lats, lons)

        da.attrs["GRIB2IO_section3"] = newmsg.section3

        # The latitude mask treats the last grid point as the lower bound and
        # the first grid point as the upper bound.
        mask_lat = (da.latitude >= newmsg.latitudeLastGridpoint) & (
            da.latitude <= newmsg.latitudeFirstGridpoint
        )
        mask_lon = (da.longitude >= newmsg.longitudeFirstGridpoint) & (
            da.longitude <= newmsg.longitudeLastGridpoint
        )

        del newmsg

        return da.where((mask_lon & mask_lat).compute(), drop=True)
def build_datatree_from_grib(filename, file_index, filters=None, stack_vertical=False):
    """
    Build a DataTree from GRIB2 messages.

    The index is filtered, completed with the metadata columns listed in
    ``_TREE_HIERARCHY_LEVELS`` plus ``shortName``, ``nx`` and ``ny``, and then
    split by ``typeOfFirstFixedSurface``. Each level type becomes one child
    node of the root, populated by ``process_level_branch``.

    Parameters
    ----------
    filename : str
        Path to the GRIB2 file.
    file_index : pandas.DataFrame
        DataFrame of GRIB2 message index.
    filters : dict, optional
        Filter criteria for GRIB2 messages.
    stack_vertical : bool, optional
        Intended to stack vertical levels in a single dataset instead of
        organizing them in separate tree nodes. Nothing in this function
        currently reads it.

    Returns
    -------
    xarray.DataTree
        A hierarchical DataTree representation of the GRIB2 data.
    """
    if filters is None:
        filters = {}

    # Apply any filters from user
    for k, v in filters.items():
        if k not in file_index.columns:
            file_index = file_index.copy()
            file_index[k] = file_index.msg.apply(lambda msg: getattr(msg, k, None))
        file_index = filter_index(file_index, k, v)

    # Make a copy to avoid the SettingWithCopyWarning
    file_index = file_index.copy()

    # Extract metadata needed for tree organization
    # Use a safer approach to handle missing attributes
    def safe_getattr(obj, name):
        try:
            attr = getattr(obj, name)
            # Need to test if the attribute is Grib2Metadata. If so,
            # then get the value attribute.
            if isinstance(attr, grib2io.templates.Grib2Metadata):
                attr = attr.value
            return attr
        except (AttributeError, KeyError):
            return None

    for attr in _TREE_HIERARCHY_LEVELS:
        if (attr not in file_index.columns) and (attr != 'valueOfFirstFixedSurface'):
            file_index[attr] = file_index.msg.apply(lambda msg: safe_getattr(msg, attr))

    # Also extract shortName for variable naming
    if 'shortName' not in file_index.columns:
        file_index = file_index.assign(shortName=file_index.msg.apply(lambda msg: getattr(msg, 'shortName', None)))
    file_index = file_index.assign(nx=file_index.msg.apply(lambda msg: getattr(msg, 'nx', None)))
    file_index = file_index.assign(ny=file_index.msg.apply(lambda msg: getattr(msg, 'ny', None)))

    # Create root DataTree
    root = xr.DataTree()

    # One branch per level type
    for level_type in file_index['typeOfFirstFixedSurface'].unique():
        if not pd.notna(level_type):  # Skip None/NaN values
            continue

        # Table entries are indexed for their name; level types missing from
        # the table get a generated name instead of indexing into a fallback
        # string, which would name every such node "l".
        level_info = _LEVEL_NAME_MAPPING.get(level_type)
        if level_info is None:
            level_name = f"level_{level_type}"
        else:
            level_name = level_info[0]

        # Get all rows for this level type
        level_df = file_index[file_index['typeOfFirstFixedSurface'] == level_type]

        # Create a branch for this level type and process it based on PDTN,
        # perturbation number, etc.
        level_tree = xr.DataTree()
        process_level_branch(level_tree, level_df, filename)

        # Add this branch to the main tree
        root[level_name] = level_tree

    return root
1922 has_probabilities = ('typeOfProbability' in pdtn_df.columns and 1923 len(pdtn_df['typeOfProbability'].dropna().unique()) > 1) 1924 1925 if has_perturbations: 1926 # Process perturbations directly on the level tree 1927 process_perturbation_groups(level_tree, pdtn_df, filename) 1928 elif has_probabilities: 1929 # Process probability groups 1930 process_probability_groups(level_tree, pdtn_df, filename) 1931 else: 1932 # Try to create dataset directly on level 1933 try: 1934 dss = create_datasets_from_df(pdtn_df, filename) 1935 if dss is not None: 1936 dt = xr.DataTree() 1937 if len(dss) == 1: 1938 dt.ds = dss[0] 1939 else: 1940 for ds in dss: 1941 varname = list(ds.data_vars)[0] 1942 dt[f"var_{varname}"] = ds 1943 level_tree[pdtn_name] = dt 1944 except Exception as e: 1945 print(f"Error creating dataset for level with pdtn {int(pdtn)}: {e}") 1946 1947 # Try to separate by variable name as a fallback 1948 try_process_by_variables(level_tree, pdtn_df, filename) 1949 else: 1950 # Multiple PDTN values, process each group with PDTN branch nodes 1951 for pdtn, pdtn_df in pdtn_groups.items(): 1952 # Use a simple node name that's easy to use in code 1953 pdtn_name = f"pdtn_{int(pdtn)}" 1954 1955 # Check if we need to further subdivide by perturbation number 1956 has_perturbations = ('perturbationNumber' in pdtn_df.columns and 1957 len(pdtn_df['perturbationNumber'].dropna().unique()) > 1) 1958 1959 # Check if we need to further subdivide by probabilities unique for each variable. 
1960 has_probabilities = ('typeOfProbability' in pdtn_df.columns and 1961 len(pdtn_df['typeOfProbability'].dropna().unique()) > 1) 1962 1963 if has_perturbations: 1964 # Create a branch for this PDTN 1965 pdtn_tree = xr.DataTree() 1966 1967 # Process perturbation groups 1968 process_perturbation_groups(pdtn_tree, pdtn_df, filename) 1969 1970 # Only add the PDTN branch if it has children 1971 if len(pdtn_tree.children) > 0 or pdtn_tree.ds is not None: 1972 level_tree[pdtn_name] = pdtn_tree 1973 elif has_probabilities: 1974 # Create a branch for this PDTN 1975 pdtn_tree = xr.DataTree() 1976 1977 # Process probability groups 1978 process_probability_groups(pdtn_tree, pdtn_df, filename) 1979 1980 # Only add the PDTN branch if it has children 1981 if len(pdtn_tree.children) > 0 or pdtn_tree.ds is not None: 1982 level_tree[pdtn_name] = pdtn_tree 1983 else: 1984 # Create a subtree for this PDTN 1985 pdtn_tree = xr.DataTree() 1986 1987 # Try to create dataset directly on level 1988 try: 1989 dss = create_datasets_from_df(pdtn_df, filename) 1990 if dss is not None: 1991 if len(dss) == 1: 1992 pdtn_tree.ds = dss[0] 1993 else: 1994 for ds in dss: 1995 varname = list(ds.data_vars)[0] 1996 pdtn_tree[f"var_{varname}"] = ds 1997 level_tree[pdtn_name] = pdtn_tree 1998 except Exception as e: 1999 print(f"Error creating dataset for level with pdtn {int(pdtn)}: {e}") 2000 2001 # Try to separate by variable name as a fallback 2002 try_process_by_variables(pdtn_tree, pdtn_df, filename) 2003 level_tree[pdtn_name] = pdtn_tree 2004 2005 2006def process_probability_groups(target_tree, pdtn_df, filename): 2007 """ 2008 """ 2009 success = False 2010 # Group by type of probability 2011 prob_groups = {} 2012 for prob_value in pdtn_df['typeOfProbability'].unique(): 2013 if pd.notna(prob_value): 2014 prob_df = pdtn_df[pdtn_df['typeOfProbability'] == prob_value] 2015 prob_groups[prob_value] = prob_df 2016 2017 # Process each probability group 2018 prob_dict = {} 2019 for prob_num, prob_df in 
prob_groups.items(): 2020 prob_name = f"prob_{int(prob_num)}" 2021 2022 # Try to create dataset for this probability group 2023 try: 2024 dss = create_datasets_from_df(prob_df, filename) 2025 dt = xr.DataTree() 2026 if len(dss) == 1: 2027 dt.ds = dss[0] 2028 target_tree[prob_name] = dt 2029 elif len(dss) > 1: 2030 for ds in dss: 2031 dt[f"var_{ds.data_vars[0]}"] = ds 2032 target_tree[prob_name] = dt 2033 except Exception as e: 2034 # Log error but continue processing other groups 2035 print(f"Error creating dataset for type of probability {prob_name}: {e}") 2036 2037 return success 2038 2039 2040def process_perturbation_groups(target_tree, pdtn_df, filename): 2041 """ 2042 Process perturbation groups and add them to the target tree. 2043 2044 Parameters 2045 ---------- 2046 target_tree : xarray.DataTree 2047 The tree node to add perturbation groups to 2048 pdtn_df : pandas.DataFrame 2049 DataFrame of messages for a specific PDTN 2050 filename : str 2051 Path to the GRIB2 file 2052 2053 Returns 2054 ------- 2055 bool 2056 True if at least one perturbation was successfully processed 2057 """ 2058 success = False 2059 # Group by perturbation number 2060 pert_groups = {} 2061 for pert_value in pdtn_df['perturbationNumber'].unique(): 2062 if pd.notna(pert_value): 2063 pert_df = pdtn_df[pdtn_df['perturbationNumber'] == pert_value] 2064 pert_groups[pert_value] = pert_df 2065 2066 # Process each perturbation group 2067 for pert_num, pert_df in pert_groups.items(): 2068 pert_name = f"pert_{int(pert_num)}" 2069 2070 ## Try to create dataset for this perturbation group 2071 #try: 2072 # dss = create_datasets_from_df(pert_df, filename) 2073 # if dss is not None: 2074 # if len(dss) == 1: 2075 # target_tree.ds = dss[0] 2076 # else: 2077 # dss_dict = {f"ds_{i}": ds for i, ds in enumerate(dss)} 2078 # atree = xr.DataTree(dss_dict) 2079 # target_tree[prob_name] = atree 2080 # success = True 2081 #except Exception as e: 2082 # # Log error but continue processing other groups 2083 # 
print(f"Error creating dataset for perturbation {pert_name}: {e}") 2084 2085 # Try to create dataset for this perturbation group 2086 try: 2087 dss = create_datasets_from_df(pert_df, filename) 2088 dt = xr.DataTree() 2089 if len(dss) == 1: 2090 dt.ds = dss[0] 2091 target_tree[pert_name] = dt 2092 elif len(dss) > 1: 2093 for ds in dss: 2094 dt[f"pert{ds.data_vars[0]}"] = ds 2095 target_tree[pert_name] = dt 2096 except Exception as e: 2097 # Log error but continue processing other groups 2098 print(f"Error creating dataset for perturbation {pert_name}: {e}") 2099 2100 return success 2101 2102 2103def try_process_by_variables(target_tree, df, filename): 2104 """ 2105 Try to separate data by variable names and create datasets. 2106 2107 Parameters 2108 ---------- 2109 target_tree : xarray.DataTree 2110 The tree node to add variable datasets to 2111 df : pandas.DataFrame 2112 DataFrame of messages 2113 filename : str 2114 Path to the GRIB2 file 2115 2116 Returns 2117 ------- 2118 bool 2119 True if at least one variable was successfully processed 2120 """ 2121 success = False 2122 2123 try: 2124 for var_name in df['shortName'].unique(): 2125 if pd.notna(var_name): 2126 var_df = df[df['shortName'] == var_name] 2127 try: 2128 var_ds = create_datasets_from_df(var_df, filename) 2129 if var_ds is not None: 2130 target_tree[f"var_{var_name}"] = var_ds[0] 2131 success = True 2132 except Exception as var_e: 2133 print(f"Error creating dataset for variable {var_name}: {var_e}") 2134 except Exception as nested_e: 2135 print(f"Failed to process variables: {nested_e}") 2136 2137 return success 2138 2139 2140def create_datasets_from_df( 2141 df, 2142 filename, 2143 verbose=False 2144) -> typing.Optional[typing.List[xr.Dataset]]: 2145 """ 2146 Create a list of xarray Datasets from a DataFrame of messages. 
2147 2148 Parameters 2149 ---------- 2150 df : pandas.DataFrame 2151 DataFrame of GRIB messages 2152 filename : str 2153 Path to the GRIB2 file 2154 verbose : bool, optional 2155 If True, prints detailed debugging information 2156 2157 Returns 2158 ------- 2159 dss 2160 List of Datasets, or None if creation failed 2161 """ 2162 try: 2163 if verbose: 2164 print(f"\n==== VERBOSE DEBUG INFO ====") 2165 print(f"Creating dataset from DataFrame with {len(df)} messages") 2166 print(f"DataFrame columns: {df.columns.tolist()}") 2167 2168 if 'shortName' in df.columns: 2169 print(f"Variables in group: {df['shortName'].unique().tolist()}") 2170 2171 if 'valueOfFirstFixedSurface' in df.columns: 2172 print(f"Vertical levels: {df['valueOfFirstFixedSurface'].unique().tolist()}") 2173 2174 # Process by variables 2175 datasets = {} 2176 2177 # Process each variable separately, regardless of whether there are vertical levels 2178 for var_name, var_df in df.groupby('shortName'): 2179 if verbose: 2180 print( 2181 f"\n Processing variable: {var_name} with {len(var_df)} messages, with pdtn(s) = {var_df['productDefinitionTemplateNumber'].unique()}") 2182 2183 # Process vertical levels if present 2184 if 'valueOfFirstFixedSurface' in var_df.columns and len(var_df['valueOfFirstFixedSurface'].unique()) > 1: 2185 if verbose: 2186 print(f" Variable {var_name} has multiple vertical levels") 2187 # Process each level separately 2188 level_das = [] 2189 2190 for level, level_df in var_df.groupby('valueOfFirstFixedSurface'): 2191 if verbose: 2192 print(f" Processing level {level} with {len(level_df)} messages") 2193 try: 2194 # Parse the index and get dimensions for this level 2195 file_index, non_geo_dims, attrs, coord_attrs = parse_grib_index(level_df, {}) 2196 # Remove valueOfFirstFixedSurface from dimensions since we're handling it separately 2197 non_geo_dims = [d for d in non_geo_dims if d.__name__ != "ValueOfFirstFixedSurfaceDim"] 2198 2199 frames, cube, extra_geo = make_variables( 2200 
file_index, filename, non_geo_dims, allow_uneven_dims=True) 2201 2202 if frames is not None and len(frames) == 1: 2203 level_da = build_da_without_coords(frames[0], cube, filename, attrs) 2204 # Add this level to the list with its level value as coord 2205 level_da = level_da.assign_coords(valueOfFirstFixedSurface=level) 2206 level_das.append(level_da) 2207 except Exception as e: 2208 if verbose: 2209 print(f" Error processing level {level} for {var_name}: {e}") 2210 2211 if level_das: 2212 # Combine all levels into a single DataArray along the valueOfFirstFixedSurface dimension 2213 if verbose: 2214 print(f" Combining {len(level_das)} levels for {var_name}") 2215 try: 2216 combined_da = xr.concat(level_das, dim='valueOfFirstFixedSurface') 2217 # Create a simple dataset with just this variable 2218 var_ds = xr.Dataset({var_name: combined_da}) 2219 # Assign the coords from the first level's cube 2220 var_ds = assign_xr_meta(var_ds, frames, cube, non_geo_dims, extra_geo, coord_attrs) 2221 # TODO: is the below code all now in assign_xr_meta? was there instances where refDate and leadTime were not coords? 
2222 # var_ds = var_ds.assign_coords(coords_from_cube(cube)) 2223 # Add extra geo coords 2224 # if extra_geo: 2225 # var_ds = var_ds.assign_coords(extra_geo) 2226 # Add valid date coords if available 2227 # if 'refDate' in var_ds.coords and 'leadTime' in var_ds.coords: 2228 # var_ds = var_ds.assign_coords(dict(validDate=var_ds.coords['refDate']+var_ds.coords['leadTime'])) 2229 2230 # Store this variable's dataset 2231 datasets[var_name] = var_ds 2232 if verbose: 2233 print(f" Created dataset for {var_name} with levels") 2234 except Exception as e: 2235 if verbose: 2236 print(f" Error combining levels for {var_name}: {e}") 2237 else: 2238 # Single level or no vertical levels 2239 if verbose: 2240 print(f" Variable {var_name} is a single level or has no vertical dimension") 2241 try: 2242 # Parse the index and get dimensions 2243 file_index, non_geo_dims, attrs, coord_attrs = parse_grib_index(var_df, {}) 2244 frames, cube, extra_geo = make_variables(file_index, filename, non_geo_dims, allow_uneven_dims=True) 2245 2246 if frames is not None and len(frames) == 1: 2247 # Create dataset with this variable 2248 var_ds = xr.Dataset() 2249 da = build_da_without_coords(frames[0], cube, filename, attrs) 2250 var_ds[da.name] = da 2251 2252 # Assign coords 2253 var_ds = assign_xr_meta(var_ds, frames, cube, non_geo_dims, extra_geo, coord_attrs) 2254 # TODO: is the below code all now in assign_xr_meta? was there instances where refDate and leadTime were not coords? 
2255 # var_ds = var_ds.assign_coords(coords_from_cube(cube)) 2256 # if extra_geo: 2257 # var_ds = var_ds.assign_coords(extra_geo) 2258 # if 'refDate' in var_ds.coords and 'leadTime' in var_ds.coords: 2259 # var_ds = var_ds.assign_coords(dict(validDate=var_ds.coords['refDate']+var_ds.coords['leadTime'])) 2260 2261 # Store this variable's dataset 2262 datasets[var_name] = var_ds 2263 if verbose: 2264 print(f" Created dataset for {var_name}") 2265 elif frames is not None and len(frames) > 1: 2266 if verbose: 2267 print(f" Variable {var_name} has multiple frames, possibly different parameters") 2268 # Just use the first frame for now (simplified approach) 2269 var_ds = xr.Dataset() 2270 da = build_da_without_coords(frames[0], cube, filename, attrs) 2271 var_ds[da.name] = da 2272 2273 # Assign coords 2274 var_ds = assign_xr_meta(var_ds, frames, cube, non_geo_dims, extra_geo, coord_attrs) 2275 # TODO: is the below code all now in assign_xr_meta? was there instances where refDate and leadTime were not coords? 
2276 # var_ds = var_ds.assign_coords(coords_from_cube(cube)) 2277 # if extra_geo: 2278 # var_ds = var_ds.assign_coords(extra_geo) 2279 # if 'refDate' in var_ds.coords and 'leadTime' in var_ds.coords: 2280 # var_ds = var_ds.assign_coords(dict(validDate=var_ds.coords['refDate']+var_ds.coords['leadTime'])) 2281 2282 datasets[var_name] = var_ds 2283 if verbose: 2284 print(f" Created dataset with first frame for {var_name}") 2285 except Exception as e: 2286 if verbose: 2287 print(f" Error processing variable {var_name}: {e}") 2288 2289 # Attempt to merge all the variable datasets 2290 if datasets: 2291 try: 2292 if verbose: 2293 print(f"\nMerging {len(datasets)} datasets...") 2294 # Get the list of datasets to merge 2295 ds_list = list(datasets.values()) 2296 2297 # Try merging them all at once 2298 try: 2299 combined_ds = xr.merge(ds_list) 2300 if verbose: 2301 print(f"Successfully merged all datasets into one.") 2302 print(f"Final dataset has variables: {list(combined_ds.data_vars)}") 2303 print(f"==== END VERBOSE DEBUG INFO ====\n") 2304 return [combined_ds] 2305 except Exception as merge_error: 2306 if verbose: 2307 print(f"Error merging all datasets: {merge_error}") 2308 return ds_list 2309 except Exception as e: 2310 if verbose: 2311 print(f"Error in final merge process: {e}") 2312 print(f"==== END VERBOSE DEBUG INFO ====\n") 2313 return None 2314 else: 2315 if verbose: 2316 print(f"No datasets were created for any variables") 2317 print(f"==== END VERBOSE DEBUG INFO ====\n") 2318 return None 2319 2320 except Exception as e: 2321 # If there's an error, log it and return None 2322 if verbose: 2323 print(f"Error creating dataset: {e}") 2324 import traceback 2325 traceback.print_exc() 2326 print(f"==== END VERBOSE DEBUG INFO ====\n") 2327 return None 2328 2329 2330# Only register the DataTree accessor if DataTree is supported 2331if _HAS_DATATREE: 2332 @xr.register_datatree_accessor("grib2io") 2333 class Grib2ioDataTree: 2334 """ 2335 DataTree accessor for GRIB2 
files. 2336 2337 This accessor provides methods for working with GRIB2 data organized 2338 in a hierarchical tree structure. 2339 """ 2340 2341 def __init__(self, datatree_obj): 2342 self._obj = datatree_obj 2343 2344 def to_grib2(self, filename, mode: typing.Literal["x", "w", "a"] = "x"): 2345 """ 2346 Write all datasets in the DataTree to a GRIB2 file. 2347 2348 Parameters 2349 ---------- 2350 filename : str 2351 Name of the GRIB2 file to write to. 2352 mode : {"x", "w", "a"}, optional 2353 Persistence mode, default is "x" (create, fail if exists) 2354 """ 2355 # Start with the specified mode 2356 current_mode = mode 2357 2358 # Function to recursively process the tree 2359 def process_tree(node): 2360 nonlocal current_mode 2361 2362 # If this is a Dataset node with data variables 2363 if node.ds is not None and node.ds.data_vars: 2364 # Write dataset to GRIB2 file 2365 node.ds.grib2io.to_grib2(filename, mode=current_mode) 2366 # Switch to append mode after first write 2367 current_mode = "a" 2368 2369 # Process children 2370 for child_name, child_node in node.children.items(): 2371 process_tree(child_node) 2372 2373 # Start processing from the root 2374 process_tree(self._obj) 2375 2376 def griddef(self): 2377 """ 2378 Get the grid definition from the first dataset in the tree that has one. 
2379 2380 Returns 2381 ------- 2382 grib2io.Grib2GridDef 2383 Grid definition object 2384 """ 2385 # Function to find first dataset with GRIB2IO_section3 2386 def find_griddef(node): 2387 if node.ds is not None and node.ds.data_vars: 2388 for var_name in node.ds.data_vars: 2389 if 'GRIB2IO_section3' in node.ds[var_name].attrs: 2390 return Grib2GridDef.from_section3(node.ds[var_name].attrs['GRIB2IO_section3']) 2391 2392 # Check children 2393 for child_name, child_node in node.children.items(): 2394 griddef = find_griddef(child_node) 2395 if griddef is not None: 2396 return griddef 2397 2398 return None 2399 2400 return find_griddef(self._obj) 2401 2402 def interp(self, method, grid_def_out, method_options=None, num_threads=1): 2403 """ 2404 Interpolate all datasets in the tree to a new grid. 2405 2406 Parameters 2407 ---------- 2408 method : str or int 2409 Interpolation method to use 2410 grid_def_out : grib2io.Grib2GridDef 2411 Target grid definition 2412 method_options : list, optional 2413 Options for interpolation method 2414 num_threads : int, optional 2415 Number of threads to use for interpolation 2416 2417 Returns 2418 ------- 2419 xarray.DataTree 2420 New DataTree with interpolated data 2421 """ 2422 new_tree = xr.DataTree() 2423 2424 # Function to recursively process the tree 2425 def process_tree(node, new_parent): 2426 # If this is a Dataset node with data variables 2427 if node.ds is not None and node.ds.data_vars: 2428 # Interpolate dataset 2429 interp_ds = node.ds.grib2io.interp(method, grid_def_out, 2430 method_options=method_options, 2431 num_threads=num_threads) 2432 2433 # Add to new tree at the same path 2434 if node == self._obj: # Root node 2435 new_parent.ds = interp_ds 2436 else: 2437 new_parent.ds = interp_ds 2438 2439 # Process children 2440 for child_name, child_node in node.children.items(): 2441 # Create same child in new tree 2442 new_child = xr.DataTree() 2443 new_parent[child_name] = new_child 2444 process_tree(child_node, new_child) 
2445 2446 # Start processing from the root 2447 process_tree(self._obj, new_tree) 2448 2449 return new_tree 2450 2451 def subset(self, lats, lons): 2452 """ 2453 Subset all datasets in the tree to a region. 2454 2455 Parameters 2456 ---------- 2457 lats : list or tuple 2458 Latitude bounds [min_lat, max_lat] 2459 lons : list or tuple 2460 Longitude bounds [min_lon, max_lon] 2461 2462 Returns 2463 ------- 2464 xarray.DataTree 2465 New DataTree with subset data 2466 """ 2467 new_tree = xr.DataTree() 2468 2469 # Function to recursively process the tree 2470 def process_tree(node, new_parent): 2471 # If this is a Dataset node with data variables 2472 if node.ds is not None and node.ds.data_vars: 2473 # Subset dataset 2474 subset_ds = node.ds.grib2io.subset(lats, lons) 2475 2476 # Add to new tree at the same path 2477 if node == self._obj: # Root node 2478 new_parent.ds = subset_ds 2479 else: 2480 new_parent.ds = subset_ds 2481 2482 # Process children 2483 for child_name, child_node in node.children.items(): 2484 # Create same child in new tree 2485 new_child = xr.DataTree() 2486 new_parent[child_name] = new_child 2487 process_tree(child_node, new_child) 2488 2489 # Start processing from the root 2490 process_tree(self._obj, new_tree) 2491 2492 return new_tree
Available non-geographic coordinate names.
Available non-geographic dimension names.
Lookup table to define surface types that should be parsed as vertical coordinates
when data_model="nws-viz".
def parse_data_model(ds, data_model):
    """
    Reshape a grib2io-derived Dataset to follow a named data model.

    Only ``data_model == "nws-viz"`` is acted upon; for any other value the
    object passed in is returned as-is.

    Parameters
    ----------
    ds : xarray.Dataset
        Dataset produced by ``grib2io``. The first data variable is expected
        to carry ``units`` and, where the matching coordinates exist,
        ``typeOfFirstFixedSurface`` / ``typeOfSecondFixedSurface`` (each a
        ``(definition, units)`` pair) and ``typeOfProbability`` attributes.
        A missing attribute raises ``KeyError``.
    data_model : str
        Target data model name.

    Returns
    -------
    xarray.Dataset
        For ``"nws-viz"``, a dataset in which:

        * ``refDate``, ``leadTime``, ``validDate`` and ``percentileValue``
          are renamed to ``forecast_reference_time``, ``lead_time``,
          ``time`` and ``percentile``.
        * ``thresholdLowerLimit`` / ``thresholdUpperLimit`` are renamed to
          ``threshold_lower_limit`` / ``threshold_upper_limit``, given
          ``long_name`` and ``units`` attributes, decoded through GRIB2
          Table 4.201 when a ``PTYPE`` variable is present, and swapped in
          for the ``threshold`` dimension when ``typeOfProbability`` refers
          to that limit.
        * ``valueOfFirstFixedSurface`` / ``valueOfSecondFixedSurface`` are
          replaced by a coordinate named after the surface definition when
          that definition is listed in ``VERTICAL_COORDINATE_SURFACES``
          (``"_2"`` is appended for the second surface on a name clash);
          the first-surface coordinate also replaces the ``level``
          dimension when present.
        * Every other coordinate name, every variable attribute name and
          every dataset attribute name is converted to snake_case
          (``GRIB2IO_section*`` attributes become ``grib_section*``,
          ``full_name`` becomes ``long_name``).
        * Variable names are lowercased, ``standard_name`` and
          ``cell_methods`` are filled from the ``shortname_to_cf`` table
          (``'unknown'`` when absent), and ``"%"`` units become
          ``"percent"``.
    """

    def _decode_ptype(values):
        """
        Translate precipitation type codes into their descriptions.

        Parameters
        ----------
        values : array_like
            Iterable of numeric GRIB2 Table 4.201 codes.

        Returns
        -------
        numpy.ndarray
            The looked-up descriptions as strings, built with
            ``dtype=np.dtypes.StringDType``.
        """
        decoded = [
            str(tables.get_value_from_table(str(int(code)), '4.201'))
            for code in values
        ]
        return np.array(decoded, dtype=np.dtypes.StringDType)

    def _promote_threshold(dset, grib_name, new_name, long_name, swap_for):
        """Rename a threshold coordinate, annotate it and optionally make it the dimension."""
        dset = dset.rename({grib_name: new_name})
        dset[new_name].attrs['long_name'] = long_name
        dset[new_name].attrs['units'] = dset[list(dset.data_vars.keys())[0]].attrs['units']

        if 'PTYPE' in dset.data_vars:
            dset[new_name] = xr.apply_ufunc(_decode_ptype, dset[new_name])

        # Use this limit as the dimension coordinate when the probability is
        # defined relative to it.
        if 'threshold' in dset.dims:
            first_var = list(dset.data_vars.keys())[0]
            if dset[first_var].attrs['typeOfProbability'] in swap_for:
                dset = dset.swap_dims({'threshold': new_name})
        return dset

    def _promote_vertical(dset, value_name, type_name, disambiguate, swap_level):
        """Replace a fixed-surface value coordinate by one named after its surface type."""
        level_da = getattr(dset, value_name)
        first_var = list(dset.data_vars.keys())[0]
        definition, units = dset[first_var].attrs[type_name]

        if definition in VERTICAL_COORDINATE_SURFACES:
            # lowercase, spaces to underscores, then strip special characters
            key = re.sub(r'[^a-z0-9_]', '', definition.lower().replace(' ', '_'))
            if disambiguate and key in dset.coords:
                key = key + '_2'

            level_da.attrs['units'] = units
            level_da.attrs['grib_name'] = [value_name, type_name]
            dset = dset.assign_coords({key: level_da})

            if swap_level and 'level' in dset.dims:
                dset = dset.swap_dims({"level": key})

        # NOTE(review): indentation was reconstructed from a flattened listing;
        # confirm the original coordinate is meant to be removed even when the
        # surface is not in VERTICAL_COORDINATE_SURFACES.
        del dset[value_name]
        return dset

    # Any model other than nws-viz leaves the dataset alone.
    if data_model != 'nws-viz':
        return ds

    # matches the position before every capital letter except at string start
    pattern = re.compile(r'(?<!^)(?=[A-Z])')

    simple_renames = {
        'refDate': 'forecast_reference_time',
        'leadTime': 'lead_time',
        'validDate': 'time',
        'percentileValue': 'percentile',
    }
    lower_limit_types = [
        'Probability of event below lower limit',
        'Probability of event above lower limit',
        'Probability of event equal to lower limit',
        'Probability of event between upper and lower limits (the range includes lower limit but not the upper limit)'
    ]
    upper_limit_types = [
        'Probability of event below upper limit',
        'Probability of event above upper limit'
    ]

    # --- coordinates ---
    for coord in ds.coords:
        if coord in simple_renames:
            ds = ds.rename({coord: simple_renames[coord]})
        elif coord == 'thresholdLowerLimit':
            ds = _promote_threshold(
                ds, 'thresholdLowerLimit', 'threshold_lower_limit',
                'Threshold Lower Limit', lower_limit_types)
        elif coord == 'thresholdUpperLimit':
            ds = _promote_threshold(
                ds, 'thresholdUpperLimit', 'threshold_upper_limit',
                'Threshold Upper Limit', upper_limit_types)
        elif coord == 'valueOfFirstFixedSurface':
            ds = _promote_vertical(
                ds, 'valueOfFirstFixedSurface', 'typeOfFirstFixedSurface',
                False, True)
        elif coord == 'valueOfSecondFixedSurface':
            ds = _promote_vertical(
                ds, 'valueOfSecondFixedSurface', 'typeOfSecondFixedSurface',
                True, False)
        else:
            # everything else just gets a snake_case name
            ds = ds.rename({coord: pattern.sub('_', coord).lower()})

    # --- data variables and their attributes ---
    for var in ds.data_vars:
        data_array = ds[var]
        cf_record = tables.get_table('shortname_to_cf').get(data_array.name)
        if cf_record is None:
            data_array.attrs['standard_name'] = 'unknown'
            data_array.attrs['cell_methods'] = 'unknown'
        else:
            data_array.attrs['standard_name'] = cf_record['cf_standard_name']
            data_array.attrs['cell_methods'] = cf_record['cf_cell_methods']
        ds[var] = data_array

        new_var_name = var.lower()
        ds = ds.rename({var: new_var_name})
        var_attrs = ds[new_var_name].attrs

        # collapse the (definition, units) pairs into a single display string
        for surface_attr in ('typeOfFirstFixedSurface', 'typeOfSecondFixedSurface'):
            if surface_attr in var_attrs:
                definition, units = var_attrs[surface_attr]
                var_attrs[surface_attr] = f'{definition} ({units})'

        # these are carried by coordinates, so drop the duplicate attrs
        var_attrs.pop('percentileValue', None)
        if 'threshold_lower_limit' in ds.coords:
            var_attrs.pop('thresholdLowerLimit', None)
        if 'threshold_upper_limit' in ds.coords:
            var_attrs.pop('thresholdUpperLimit', None)

        for attr in list(var_attrs.keys()):
            if 'GRIB2IO_section' in attr:
                # GRIB section attrs only have their prefix replaced
                new_attr_name = attr.replace('GRIB2IO', 'grib')
            else:
                new_attr_name = pattern.sub('_', attr).lower()

            if new_attr_name == 'full_name':
                new_attr_name = 'long_name'

            if attr == 'units' and var_attrs[attr] == '%':
                var_attrs[attr] = 'percent'

            if new_var_name == 'ptype' and 'threshold' in new_attr_name:
                value = var_attrs.pop(attr)
                var_attrs[attr] = _decode_ptype(value)
            else:
                var_attrs[new_attr_name] = var_attrs.pop(attr)

    # --- dataset attributes ---
    for attr in list(ds.attrs.keys()):
        ds.attrs[pattern.sub('_', attr).lower()] = ds.attrs.pop(attr)

    # --- percent units on coordinates ---
    for coord in ds.coords:
        coord_attrs = ds[coord].attrs
        if 'units' in coord_attrs and coord_attrs['units'] == '%':
            ds[coord].attrs['units'] = 'percent'

    return ds
Normalize a GRIB2-derived Dataset to a target data model (currently "nws-viz").
When data_model == "nws-viz", this function converts coordinate and
variable names to snake_case, derives CF-like metadata, promotes select
GRIB-derived quantities to coordinates, optionally swaps dimensions, and
standardizes units/attributes. If data_model is anything else, the
input dataset is returned unchanged.
Parameters
- ds (xarray.Dataset):
GRIB2-derived dataset whose variables and attributes follow the
conventions emitted by
grib2io. Expected to contain GRIB-related attributes such as typeOfFirstFixedSurface, typeOfSecondFixedSurface, and (for probabilistic variables) typeOfProbability.
- data_model (str):
Target data model name. Only the value
"nws-viz" triggers transformations.
Returns
- xarray.Dataset: A new dataset with:
- Selected coordinates renamed:
refDate -> forecast_reference_time, leadTime -> lead_time, validDate -> time, percentileValue -> percentile, thresholdLowerLimit -> threshold_lower_limit, thresholdUpperLimit -> threshold_upper_limit.
- Vertical coordinates derived from
valueOfFirstFixedSurface / valueOfSecondFixedSurface and their corresponding typeOf*FixedSurface definitions. New coordinate names are generated from the surface definition (lowercased, spaces to underscores, punctuation removed). If the name already exists, a "_2" suffix is appended.
- Possible dimension swaps:
level -> <derived_vertical_coord> when present; and for probabilistic variables, threshold -> threshold_lower_limit or threshold -> threshold_upper_limit when typeOfProbability indicates the appropriate semantics.
- Variable names lowercased; dataset- and variable-level attributes
converted to snake_case (except GRIB section attributes which are
normalized to
grib...).
- CF-adjacent metadata populated:
standard_name and cell_methods are set via the shortname→CF lookup table.
- Percent units normalized from
"%" to "percent" on coordinates.
- For precipitation type (PTYPE) thresholds, numeric codes are decoded to strings (GRIB2 Table 4.201) in relevant attrs/coords.
- Selected coordinates renamed:
Notes
- Precipitation type decoding uses GRIB2 Table 4.201 via
tables.get_value_from_table(code, "4.201") and returns a NumPy array with np.dtypes.StringDType.
- CF-related lookups are performed using
tables.get_table("shortname_to_cf").
- Vertical coordinate surface names are validated against
VERTICAL_COORDINATE_SURFACES before promotion to coordinates.
Warnings
This function assumes the presence of certain GRIB-derived attributes on the
first data variable (e.g., typeOfFirstFixedSurface,
typeOfSecondFixedSurface, and possibly typeOfProbability).
If these are absent or malformed, errors (e.g., KeyError) may occur.
Examples
>>> ds2 = parse_data_model(ds, "nws-viz")
>>> list(ds2.coords)
['forecast_reference_time', 'lead_time', 'time', 'percentile', ...]
class GribBackendEntrypoint(BackendEntrypoint):
    """
    xarray backend engine entrypoint for opening and decoding grib2 files.

    .. warning::

            This backend is experimental and the API/behavior may change without
            backward compatibility.
    """

    def open_dataset(
        self,
        filename,
        *,
        drop_variables=None,
        save_index=True,
        filters: typing.Optional[typing.Mapping[str, typing.Any]] = None,
        data_model=None
    ):
        """
        Read and parse metadata from grib file.

        Parameters
        ----------
        filename
            GRIB2 file to be opened.
        drop_variables
            Not referenced by this implementation.
        save_index
            Forwarded to ``grib2io.open``.
        filters
            Filter GRIB2 messages to single hypercube. Dict keys can be any
            GRIB2 metadata attribute name. ``None`` (the default) applies no
            filters.
        data_model
            Parse GRIB metadata following a defined data model convention.

        Returns
        -------
        open_dataset
            Xarray dataset of grib2 messages. An empty dataset is returned
            when the filters leave no messages.
        """
        # avoid a shared mutable default; None means "no filters"
        if filters is None:
            filters = {}

        with grib2io.open(filename, save_index=save_index, _xarray_backend=True) as f:
            file_index = pd.DataFrame(f._index)
            file_index = file_index.assign(msg=msgs_from_index(f._index))

        # parse grib2io _index to dataframe and acquire non-geo possible dims
        # (scalar coord when not dim due to squeeze) parse_grib_index applies
        # filters to index and expands metadata based on product definition
        # template number
        parsed = parse_grib_index(file_index, filters)

        # parse_grib_index returns early (with a shorter tuple) when the
        # filters match nothing; check before unpacking so that case yields an
        # empty dataset instead of an unpacking error
        if len(parsed[0]) == 0:
            return xr.Dataset()
        file_index, dim_coords, attrs, coord_attrs = parsed

        # Divide up records by variable
        frames, cube, extra_geo = make_variables(file_index, filename, dim_coords)

        # return empty dataset if no data
        if frames is None:
            return xr.Dataset()

        # create dataset and add datarrays without any coords
        ds = xr.Dataset()
        for var_df in frames:
            da = build_da_without_coords(var_df, cube, filename, attrs)
            ds[da.name] = da

        # add coords and dataset meta
        ds = assign_xr_meta(ds, frames, cube, dim_coords, extra_geo, coord_attrs)

        if data_model is not None:
            ds = parse_data_model(ds, data_model)

        # assign attributes
        ds.attrs['engine'] = 'grib2io'

        return ds

    def open_datatree(
        self,
        filename,
        *,
        drop_variables=None,
        save_index=True,
        filters: typing.Mapping[str, typing.Any] = None,
        stack_vertical: bool = False,
    ):
        """
        Open a GRIB2 file as an xarray DataTree.

        Parameters
        ----------
        filename : str
            Path to the GRIB2 file.
        drop_variables : list, optional
            Not referenced by this implementation.
        save_index : bool, optional
            Forwarded to ``grib2io.open``.
        filters : dict, optional
            Filter criteria for GRIB2 messages.
        stack_vertical : bool, optional
            If True, organize the tree with vertical layers stacked in a single dataset.

        Returns
        -------
        xarray.DataTree
            A hierarchical DataTree representation of the GRIB2 data.

        Raises
        ------
        ImportError
            If the installed xarray version does not support DataTree.
        """
        if not _HAS_DATATREE:
            raise ImportError("xarray version does not support DataTree functionality.")

        if filters is None:
            filters = {}

        # Open the file without any filters first to get all messages
        with grib2io.open(filename, save_index=save_index, _xarray_backend=True) as f:
            file_index = pd.DataFrame(f._index)
            file_index = file_index.assign(msg=msgs_from_index(f._index))

        # Build tree structure from GRIB messages with specified options
        tree = build_datatree_from_grib(filename, file_index, filters, stack_vertical=stack_vertical)

        # Put warning here so it is the last message from likely other Xarray warnings.
        warnings.warn(
            "grib2io’s xarray backend DataTree support is experimental. "
            "The DataTree structure or attributes may change in future releases.",
            UserWarning,
            stacklevel=2,
        )

        return tree
xarray backend engine entrypoint for opening and decoding grib2 files.
This backend is experimental and the API/behavior may change without backward compatibility.
def open_dataset(
    self,
    filename,
    *,
    drop_variables=None,
    save_index=True,
    filters: typing.Optional[typing.Mapping[str, typing.Any]] = None,
    data_model=None
):
    """
    Read and parse metadata from grib file.

    Parameters
    ----------
    filename
        GRIB2 file to be opened.
    drop_variables
        Not referenced by this implementation.
    save_index
        Forwarded to ``grib2io.open``.
    filters
        Filter GRIB2 messages to single hypercube. Dict keys can be any
        GRIB2 metadata attribute name. ``None`` (the default) applies no
        filters.
    data_model
        Parse GRIB metadata following a defined data model convention.

    Returns
    -------
    open_dataset
        Xarray dataset of grib2 messages. An empty dataset is returned
        when the filters leave no messages.
    """
    # avoid a shared mutable default; None means "no filters"
    if filters is None:
        filters = {}

    with grib2io.open(filename, save_index=save_index, _xarray_backend=True) as f:
        file_index = pd.DataFrame(f._index)
        file_index = file_index.assign(msg=msgs_from_index(f._index))

    # parse grib2io _index to dataframe and acquire non-geo possible dims
    # (scalar coord when not dim due to squeeze) parse_grib_index applies
    # filters to index and expands metadata based on product definition
    # template number
    parsed = parse_grib_index(file_index, filters)

    # parse_grib_index returns early (with a shorter tuple) when the filters
    # match nothing; check before unpacking so that case yields an empty
    # dataset instead of an unpacking error
    if len(parsed[0]) == 0:
        return xr.Dataset()
    file_index, dim_coords, attrs, coord_attrs = parsed

    # Divide up records by variable
    frames, cube, extra_geo = make_variables(file_index, filename, dim_coords)

    # return empty dataset if no data
    if frames is None:
        return xr.Dataset()

    # create dataset and add datarrays without any coords
    ds = xr.Dataset()
    for var_df in frames:
        da = build_da_without_coords(var_df, cube, filename, attrs)
        ds[da.name] = da

    # add coords and dataset meta
    ds = assign_xr_meta(ds, frames, cube, dim_coords, extra_geo, coord_attrs)

    if data_model is not None:
        ds = parse_data_model(ds, data_model)

    # assign attributes
    ds.attrs['engine'] = 'grib2io'

    return ds
Read and parse metadata from grib file.
Parameters
- filename: GRIB2 file to be opened.
- filters: Filter GRIB2 messages to single hypercube. Dict keys can be any GRIB2 metadata attribute name.
- data_model: Parse GRIB metadata following a defined data model convention.
Returns
- open_dataset: Xarray dataset of grib2 messages.
def open_datatree(
    self,
    filename,
    *,
    drop_variables=None,
    save_index=True,
    filters: typing.Mapping[str, typing.Any] = None,
    stack_vertical: bool = False,
):
    """
    Open a GRIB2 file and arrange its messages as an xarray DataTree.

    Parameters
    ----------
    filename : str
        Path to the GRIB2 file.
    drop_variables : list, optional
        Not referenced by this implementation.
    save_index : bool, optional
        Forwarded to ``grib2io.open``.
    filters : dict, optional
        Filter criteria for GRIB2 messages; ``None`` means no filters.
    stack_vertical : bool, optional
        If True, organize the tree with vertical layers stacked in a single dataset.

    Returns
    -------
    xarray.DataTree
        A hierarchical DataTree representation of the GRIB2 data.

    Raises
    ------
    ImportError
        If the installed xarray version does not support DataTree.
    """
    if not _HAS_DATATREE:
        raise ImportError("xarray version does not support DataTree functionality.")

    active_filters = {} if filters is None else filters

    # Read the complete, unfiltered message index; filters are handed to the
    # tree builder below.
    with grib2io.open(filename, save_index=save_index, _xarray_backend=True) as grib_file:
        raw_index = grib_file._index
        file_index = pd.DataFrame(raw_index).assign(msg=msgs_from_index(raw_index))

    tree = build_datatree_from_grib(
        filename,
        file_index,
        active_filters,
        stack_vertical=stack_vertical,
    )

    # Emitted last so it follows any warnings xarray itself may have raised.
    warnings.warn(
        "grib2io’s xarray backend DataTree support is experimental. "
        "The DataTree structure or attributes may change in future releases.",
        UserWarning,
        stacklevel=2,
    )

    return tree
Open a GRIB2 file as an xarray DataTree.
Parameters
- filename (str): Path to the GRIB2 file.
- drop_variables (list, optional): List of variables to exclude.
- filters (dict, optional): Filter criteria for GRIB2 messages.
- stack_vertical (bool, optional): If True, organize the tree with vertical layers stacked in a single dataset.
Returns
- xarray.DataTree: A hierarchical DataTree representation of the GRIB2 data.
class GribBackendArray(BackendArray):
    """Wrap an array-like object so xarray can index it lazily under a lock."""

    def __init__(self, array, lock):
        # keep the wrapped object and the lock guarding reads from it
        self.array = array
        self.lock = lock
        # mirror the wrapped object's shape and dtype on this wrapper
        self.shape = array.shape
        self.dtype = np.dtype(array.dtype)

    def __getitem__(self, key: xr.core.indexing.ExplicitIndexer) -> np.typing.ArrayLike:
        # let xarray reduce the indexer to basic indexing and call back into
        # _raw_getitem with the resulting tuple key
        return indexing.explicit_indexing_adapter(
            key,
            self.shape,
            indexing.IndexingSupport.BASIC,
            self._raw_getitem,
        )

    def _raw_getitem(self, key: tuple):
        """Index the wrapped array while holding the lock."""
        with self.lock:
            result = self.array[key]
        return result
Mixin class that extends a class that defines a shape property to
one that also defines ndim, size and __len__.
def exclusive_slice_to_inclusive(item: slice):
    """
    Convert a slice with exclusive stop to an inclusive slice.

    The inclusive stop is the last element the exclusive slice would select,
    i.e. ``start + ((stop - start - 1) // step) * step``. When
    ``stop - start`` is a multiple of ``step`` this equals ``stop - step``.

    Parameters
    ----------
    item
        The slice to convert.

    Returns
    -------
    slice or list
        The all-``None`` slice is returned unchanged. A slice selecting
        exactly one element is returned as the single-item list ``[start]``.
        Otherwise the converted slice is returned.

    Raises
    ------
    ValueError
        If ``item`` is not a slice, if only some of start/stop are given,
        if ``stop < start`` or if ``step < 1``.
    """
    # validate the type first so that non-slices raise the intended
    # ValueError rather than an AttributeError on ``item.start``
    if not isinstance(item, slice):
        raise ValueError(f'item must be a slice; it was of type {type(item)}')
    # return the None slice
    if item.start is None and item.stop is None and item.step is None:
        return item
    # if step is None, it's one
    step = 1 if item.step is None else item.step
    # partially open slices cannot be converted without knowing the length
    if item.start is None or item.stop is None:
        raise ValueError(f'slice {item} not accounted for')
    if item.stop < item.start or step < 1:
        raise ValueError(f'slice {item} not accounted for')
    # empty selection: keep an inclusive stop that lies before start
    if item.stop == item.start:
        return slice(item.start, item.stop - step, step)
    # last element actually selected by the exclusive slice
    last = item.start + ((item.stop - item.start - 1) // step) * step
    # handle case where slice has one item
    if last == item.start:
        return [item.start]
    return slice(item.start, last, step)
Convert a slice with exclusive stop to an inclusive slice.
If the slice has a step, the stop is reduced by the step, so that both interpretations would yield the same result.
This means that [start, stop) is converted to [start, stop - step].
Parameters
- item: The slice to convert.
Returns
- slice: The converted slice.
def array_safe_eq(a, b) -> bool:
    """
    Check if a and b are equal, even if they are numpy arrays.

    Parameters
    ----------
    a, b
        Objects to compare.

    Returns
    -------
    bool
        True when the objects are identical or compare equal. Objects whose
        ``==`` comparison raises ``TypeError`` are reported as not equal.
    """
    if a is b:
        return True
    # objects exposing ``equals`` decide for themselves
    if hasattr(a, 'equals'):
        return a.equals(b)
    # both array-like: shapes must match before comparing element-wise
    if hasattr(a, 'all') and hasattr(b, 'all'):
        return a.shape == b.shape and bool((a == b).all())
    # only one side is array-like
    if hasattr(a, 'all') or hasattr(b, 'all'):
        return False
    try:
        return a == b
    except TypeError:
        # previously returned the NotImplementedError class, which is truthy
        # and made incomparable objects look equal
        return False
Check if a and b are equal, even if they are numpy arrays.
def dc_eq(dc1, dc2) -> bool:
    """
    Check if two dataclasses which hold numpy arrays are equal.

    Parameters
    ----------
    dc1, dc2
        Dataclass instances to compare.

    Returns
    -------
    bool
        True when both are the same object, or are instances of the same
        class whose fields all compare equal via ``array_safe_eq``.
        Instances of different classes are not equal.
    """
    if dc1 is dc2:
        return True
    if dc1.__class__ is not dc2.__class__:
        # previously returned the NotImplementedError class, which is truthy
        # and made instances of different classes look equal
        return False
    t1 = astuple(dc1)
    t2 = astuple(dc2)
    return all(array_safe_eq(a1, a2) for a1, a2 in zip(t1, t2))
Check if two dataclasses which hold numpy arrays are equal.
def coords_from_cube(cube) -> typing.Dict[str, xr.Variable]:
    """
    Build xarray coordinate variables from the non-x/y entries of a cube.

    Entries with more than one value become one-dimensional variables whose
    dimension is the entry's key; entries with exactly one value become
    scalar variables; empty entries are left out. Every variable carries a
    ``grib_name`` attribute holding its key.

    Raises
    ------
    ValueError
        If ``cube`` has no ``'x'`` or no ``'y'`` key.
    """
    names = list(cube)
    # the geographic entries are handled elsewhere; both must be present
    for geo_name in ('x', 'y'):
        names.remove(geo_name)

    result = {}
    for name in names:
        if name is None:
            continue
        values = cube[name]
        count = len(values)
        if count > 1:
            result[name] = xr.Variable(dims=name, data=values, attrs={'grib_name': name})
        elif count == 1:
            result[name] = xr.Variable(dims=(), data=values[0], attrs={'grib_name': name})
    return result
@dataclass
class OnDiskArray:
    """
    Array-like view over the GRIB2 messages listed in an index DataFrame.

    The shape is derived from the index: the trailing two dimensions are
    ``(ny, nx)`` taken from the first index row, and any leading dimensions
    come from the levels of the DataFrame's (Multi)Index. Data are read from
    ``file_name`` only when the object is indexed.
    """
    file_name: str
    index: pd.DataFrame = field(repr=False)
    cube: dict = field(repr=False)
    shape: typing.Tuple[int, ...] = field(init=False)
    ndim: int = field(init=False)
    geo_ndim: int = field(init=False)
    # plain class attribute (no annotation), so it is not a dataclass field
    dtype = 'float32'

    def __post_init__(self):
        """Compute shape/ndim from the index and keep only the columns needed for reading."""
        # multiple grids not allowed so can just use first
        geo_shape = (self.index.iloc[0].ny, self.index.iloc[0].nx)

        self.geo_shape = geo_shape
        self.geo_ndim = len(geo_shape)

        if len(self.index) == 1:
            # a single message: the array is just the grid
            self.shape = geo_shape
        else:
            if self.index.index.nlevels == 1:
                self.shape = tuple([len(self.index.index)]) + geo_shape
            else:
                # one leading dimension per MultiIndex level
                self.shape = tuple([len(i) for i in self.index.index.levels]) + geo_shape
        self.ndim = len(self.shape)

        # only these two columns are read in __getitem__
        cols = ['msg', 'sectionOffset']
        self.index = self.index[cols]

    def __getitem__(self, item) -> np.ndarray:
        """
        Read the messages selected by ``item`` and return them as an array.

        The last ``geo_ndim`` entries of ``item`` index the grid; the entries
        before them select rows of the message index.
        NOTE(review): ``item`` is assumed to be a tuple of ints/slices as
        produced by xarray basic indexing -- confirm against the caller.
        """
        # dimensions not in index are internal to tdlpack records; 2 dims for
        # grids; 1 dim for stations

        index_slicer = item[:-self.geo_ndim]
        # maintain all multindex levels
        index_slicer = tuple([[i] if isinstance(i, int) else i for i in index_slicer])

        # pandas loc slicing is inclusive, therefore convert slices into
        # explicit lists
        index_slicer_inclusive = tuple([exclusive_slice_to_inclusive(
            i) if isinstance(i, slice) else i for i in index_slicer])

        # get records selected by item in new index dataframe
        if len(index_slicer_inclusive) == 1:
            index = self.index.loc[index_slicer_inclusive]
        elif len(index_slicer_inclusive) > 1:
            index = self.index.loc[index_slicer_inclusive, :]
        else:
            index = self.index
        index = index.set_index(index.index)

        # set miloc to new relative locations in sub array
        index['miloc'] = list(
            zip(*[index.index.unique(level=dim).get_indexer(index.index.get_level_values(dim)) for dim in index.index.names]))

        if len(index_slicer_inclusive) == 1:
            array_field_shape = tuple([len(index.index)]) + self.geo_shape
        elif len(index_slicer_inclusive) > 1:
            array_field_shape = index.index.levshape + self.geo_shape
        else:
            array_field_shape = self.geo_shape

        # NaN-filled so positions without a matching message stay NaN
        array_field = np.full(array_field_shape, fill_value=np.nan, dtype="float32")

        with open(self.file_name, mode='rb') as filehandle:
            for key, row in index.iterrows():

                # sectionOffset[6] is passed as the bitmap offset (None when NA)
                # and sectionOffset[7] as the final positional argument of _data
                bitmap_offset = None if pd.isna(row['sectionOffset'][6]) else int(row['sectionOffset'][6])
                values = _data(filehandle, row.msg, bitmap_offset, row['sectionOffset'][7])

                if len(index_slicer_inclusive) >= 1:
                    array_field[row.miloc] = values
                else:
                    # no leading dimensions: the single message is the array
                    array_field = values

        # handle geo dim slicing
        array_field = array_field[(Ellipsis,) + item[-self.geo_ndim:]]

        # squeeze array dimensions expressed as integer
        for i, it in reversed(list(enumerate(item[: -self.geo_ndim]))):
            if isinstance(it, int):
                array_field = array_field[(slice(None, None, None),) * i + (0,)]

        return array_field
def filter_index(index, k, v):
    """
    Reduce ``index`` to the rows whose column ``k`` matches ``v``.

    Parameters
    ----------
    index
        DataFrame to filter.
    k
        Name of the column to match against.
    v
        A slice (label-based selection on column ``k``), a scalar label, or
        an array-like of labels.

    Returns
    -------
    DataFrame
        The matching rows; empty when a scalar label is not found.
    """
    if isinstance(v, slice):
        # label-based slice on column k, then restore k as a regular column
        return index.set_index(k).loc[v].reset_index()

    label = v if getattr(v, "ndim", 1) > 1 else _asarray_tuplesafe(v)

    if label.ndim != 0:
        # several labels: keep positions of those present (-1 marks missing)
        positions = pd.Index(index[k]).get_indexer_for(np.ravel(v))
        return index.iloc[positions[positions >= 0]]

    # scalar label
    # see https://github.com/pydata/xarray/pull/4292 for details
    if label.dtype.kind in "mM":
        label_value = label[()]
    else:
        label_value = label.item()
    try:
        located = pd.Index(index[k]).get_loc(label_value)
        selection = [located] if isinstance(located, int) else located
        return index.iloc[selection]
    except KeyError:
        return index.iloc[[]]
def parse_grib_index(
    index: pd.DataFrame,
    filters: typing.Optional[typing.Mapping[str, typing.Any]] = None,
):
    """
    Apply filters.

    Evaluate remaining dimensions based on pdtn and parse each out.

    Parameters
    ----------
    index
        Pandas DataFrame containing the GRIB2 message index.
    filters
        Filter GRIB2 messages to single hypercube. Dict keys can be any
        GRIB2 metadata attribute name. ``None`` (the default) applies no
        filters.

    Returns
    -------
    index
        Modified Pandas DataFrame with added GRIB2 metadata columns.
    dim_coords
        Dict mapping each dimension name to the list of GRIB2 attribute
        names used as its coordinates.
    attrs
        Dict of metadata attributes (non-coordinates, non-geo)
    coord_attrs
        Dict mapping coordinate names to dicts of attributes for them.

    Raises
    ------
    ValueError
        If, after filtering, a metadata item that must be unique has more
        than one value.

    Notes
    -----
    When the filters leave no messages, the (empty) index is returned
    together with three empty dicts.
    """
    # work on a shallow copy so the caller's mapping is left alone
    filters = {} if filters is None else copy(filters)

    for k, v in filters.items():
        if k not in index.columns:
            kwarg = {k: index.msg.apply(lambda msg: getattr(msg, k))}
            index = index.assign(**kwarg)
        # adopt parts of xarray's sel logic so that filters behave similarly
        # allowed to filter to nothing to make empty dataset
        index = filter_index(index, k, v)

    if len(index) == 0:
        # same arity as the normal return so callers can always unpack
        return index, dict(), dict(), dict()

    dim_coords = dict()  # key=name of dim, value=list of coord names
    attrs = dict()
    coord_attrs = dict()

    # expand index
    index = index.assign(shortName=index.msg.apply(lambda msg: msg.shortName))
    index = index.assign(nx=index.msg.apply(lambda msg: msg.nx))
    index = index.assign(ny=index.msg.apply(lambda msg: msg.ny))
    index = index.astype({'ny': 'int', 'nx': 'int'})

    # apply common filters(to all definition templates) to reduce dataset to
    # single cube
    # ensure only one of each of the below exists after filters applied
    required_uniques = [
        "productDefinitionTemplateNumber",
        "typeOfGeneratingProcess",
        "typeOfFirstFixedSurface",
        "typeOfSecondFixedSurface",
    ]

    def meta_check(index, attrs, meta):
        """
        add meta to the dataframe index
        check that there is a single type
        add the type to attrs

        returns index, attrs
        """
        index = index.assign(**{meta: index.msg.apply(lambda msg: getattr(msg, meta))})

        unique = index[meta].unique()
        if len(unique) > 1:
            raise ValueError(f'filter to a single {meta}; found: {[str(i) for i in unique]}')
        value = unique.item()
        if isinstance(value, grib2io.templates.Grib2Metadata):
            value = value.definition

        # None is returned if no value found,
        # check and change to string None
        if value is None:
            value = 'None'

        attrs[meta] = value
        return index, attrs

    for meta in required_uniques:
        index, attrs = meta_check(index, attrs, meta)

    pdtn = index.productDefinitionTemplateNumber.iloc[0].value

    # determine which non geo dimensions can be created from data by this point
    # the index is filtered down to a single type for all required_uniques

    # Dim Name                # matching dim_name for using this data as index coordinate
    dim_coords["refDate"] = ["refDate"]
    coord_attrs["refDate"] = dict(standard_name="forecast_reference_time")

    dim_coords["leadTime"] = ["leadTime"]
    coord_attrs["leadTime"] = dict(standard_name="forecast_period")

    if 'valueOfFirstFixedSurface' not in index.columns:
        index = index.assign(valueOfFirstFixedSurface=index.msg.apply(lambda msg: msg.valueOfFirstFixedSurface))
    # column name was previously misspelled ('valueOfsecondFixedSurface'),
    # which made this check always true
    if 'valueOfSecondFixedSurface' not in index.columns:
        index = index.assign(valueOfSecondFixedSurface=index.msg.apply(lambda msg: msg.valueOfSecondFixedSurface))

    # dim name api change, user could run ds = ds.swap_dims(fixedSurface="valueOfFirstFixedSurface")
    index = index.assign(level=list(zip(index['valueOfFirstFixedSurface'], index['valueOfSecondFixedSurface'])))
    # lack of "level" indicates don't create extra index coordinate "level"
    dim_coords["level"] = ["valueOfFirstFixedSurface", "valueOfSecondFixedSurface"]

    # logic for parsing possible dims from specific product definition section

    if pdtn in {5, 9}:

        # Probability forecasts at a horizontal level or in a horizontal layer
        # in a continuous or non-continuous time interval.  (see Template
        # 4.9)
        # AVAILABLE_THRESHOLD = {
        #     0: {'has_lower': True, 'has_upper': False},
        #     1: {'has_lower': False, 'has_upper': True},
        #     2: {'has_lower': True, 'has_upper': True},
        #     3: {'has_lower': True, 'has_upper': False},
        #     4: {'has_lower': False, 'has_upper': True},
        #     5: {'has_lower': True, 'has_upper': False},
        # }

        index, attrs = meta_check(index, attrs, "typeOfProbability")
        if 'thresholdLowerLimit' not in index.columns:
            index = index.assign(thresholdLowerLimit=index.msg.apply(lambda msg: msg.thresholdLowerLimit))
        if 'thresholdUpperLimit' not in index.columns:
            index = index.assign(thresholdUpperLimit=index.msg.apply(lambda msg: msg.thresholdUpperLimit))
        if 'threshold' not in index.columns:
            # using composite of lower and upper, but could use threshold string from grib2io as long as that is unique and based on lower and upper
            index = index.assign(threshold=list(zip(index['thresholdLowerLimit'], index['thresholdUpperLimit'])))

        # omitting threshold results in no index being assigned for this possible dim
        dim_coords["threshold"] = ["thresholdLowerLimit", "thresholdUpperLimit"]

    if pdtn in {6, 10}:

        # Percentile forecasts at a horizontal level or in a horizontal layer
        # in a continuous or non-continuous time interval.  (see Template
        # 4.10)
        dim_coords["percentileValue"] = ["percentileValue"]
        coord_attrs["percentileValue"] = dict(long_name='percentile', units='percent')

    if pdtn in {8, 9, 10, 11, 12, 13, 14, 42, 43, 45, 46, 47, 61, 62, 63, 67, 68, 72, 73, 78, 79, 82, 83, 84, 85, 87, 91}:
        dim_coords["duration"] = ["duration"]

    if pdtn in {1, 11, 33, 34, 41, 43, 45, 47, 49, 54, 56, 58, 59, 63, 68, 77, 79, 81, 83, 84, 85, 92}:
        dim_coords["perturbationNumber"] = ["perturbationNumber"]

    if pdtn in {2, 3, 4, 12, 13, 14}:
        index, attrs = meta_check(index, attrs, 'typeOfDerivedForecast')

    if pdtn in {8, 15, 42, 46, 62, 67, 72, 78, 82, 1001, 1002, 1100, 1101}:
        index, attrs = meta_check(index, attrs, 'statisticalProcess')

    # Finish logic by pdtn

    # make sure every coordinate named above exists as an index column
    for coord_names in dim_coords.values():
        for meta in coord_names:
            if meta not in index.columns:
                index = index.assign(**{meta: index.msg.apply(lambda msg: getattr(msg, meta))})

    return index, dim_coords, attrs, coord_attrs
Apply filters.
Evaluate remaining dimensions based on pdtn and parse each out.
Parameters
- index: Pandas DataFrame containing the GRIB2 message index.
- filters: Filter GRIB2 messages to single hypercube. Dict keys can be any GRIB2 metadata attribute name.
Returns
- index: Modified Pandas DataFrame with added GRIB2 metadata columns.
- dim_coords: Dict mapping each dimension name to the list of GRIB2 attribute names that will be used for its coordinates and/or dimension.
- attrs: Dict of metadata attributes (non-coordinates, non-geo)
def open_datatree(filename, *, filters: typing.Mapping[str, typing.Any] = None, engine="grib2io"):
    """
    Open a GRIB2 file as an xarray DataTree.

    Parameters
    ----------
    filename : str
        Path to the GRIB2 file.
    filters : dict, optional
        Filter criteria for GRIB2 messages.
    engine : str, optional
        Defaults to "grib2io". Currently not referenced by the
        implementation.

    Returns
    -------
    xarray.DataTree
        A hierarchical DataTree representation of the GRIB2 data.

    Raises
    ------
    ImportError
        If the installed xarray version does not support DataTree.
    """
    if not _HAS_DATATREE:
        raise ImportError("xarray version does not support DataTree functionality.")

    if filters is None:
        filters = {}

    # Open the file without any filters first to get all messages
    with grib2io.open(filename, _xarray_backend=True) as f:
        file_index = pd.DataFrame(f._index)
        # attach the message objects, as the backend entrypoint's
        # open_dataset/open_datatree methods do before using the index
        file_index = file_index.assign(msg=msgs_from_index(f._index))

    # Build tree structure from GRIB messages
    return build_datatree_from_grib(filename, file_index, filters)
Open a GRIB2 file as an xarray DataTree.
Parameters
- filename (str): Path to the GRIB2 file.
- filters (dict, optional): Filter criteria for GRIB2 messages.
- engine (str, optional): Engine to use for opening the file, defaults to "grib2io".
Returns
- xarray.DataTree: A hierarchical DataTree representation of the GRIB2 data.
def build_da_without_coords(index, cube, filename, attrs) -> xr.DataArray:
    """
    Assemble a lazily loaded, coordinate-free DataArray for one variable.

    Parameters
    ----------
    index
        DataFrame of the GRIB2 messages that make up the variable.
    cube
        Mapping of dimension/metadata name to its values; entries with more
        than one value become dimensions, entries that are ``None`` are
        treated as constant metadata.
    filename
        Filename of grib2 file
    attrs
        Extra attributes merged into the DataArray attributes last.

    Returns
    -------
    DataArray
        DataArray without coordinates

    Raises
    ------
    ValueError
        If the product of the non-x/y dimension lengths differs from the
        number of indexed messages, or if the number of dimension names
        differs from the number of data dimensions.
    """
    dim_names = []
    constant_meta_names = []
    for cube_key in cube.keys():
        cube_values = cube[cube_key]
        if cube_values is None:
            constant_meta_names.append(cube_key)
        elif len(cube_values) > 1:
            dim_names.append(cube_key)

    # guard against bad datarrays being formed
    dims_to_filter = [name for name in dim_names if name not in {'x', 'y', 'station'}]
    dims_total = 1
    for name in dims_to_filter:
        dims_total *= len(cube[name])

    # Check number of GRIB2 message indexed compared to non-X/Y
    # dimensions.
    if dims_total != len(index):
        raise ValueError(
            f"DataArray dimensions are not compatible with number of GRIB2 messages; DataArray has {dims_total} "
            f"and GRIB2 index has {len(index)}. Consider applying a filter for dimensions: {dims_to_filter}"
        )

    # wrap the on-disk reader so xarray only loads data on access
    data = indexing.LazilyIndexedArray(
        GribBackendArray(OnDiskArray(filename, index, cube), _LOCK)
    )
    if len(dim_names) != len(data.shape):
        raise ValueError(
            "different number of dimensions on data "
            f"and dims: {len(data.shape)} vs {len(dim_names)}\n"
            "Grib2 messages could not be formed into a data cube; "
            "It's possible extra messages exist along a non-accounted for dimension based on PDTN\n"
            "It might be possible to get around this by applying a filter on the non-accounted for dimension"
        )

    da = xr.DataArray(data, dims=dim_names)
    da.encoding.update({
        'original_shape': data.shape,
        'preferred_chunks': {'y': -1, 'x': -1},
    })

    # metadata is taken from the first message of the variable
    first_msg = index.msg.iloc[0]

    # plain language metadata is minimized
    # add grib section metadata
    da.attrs['GRIB2IO_section0'] = first_msg.section0
    da.attrs['GRIB2IO_section1'] = first_msg.section1
    da.attrs['GRIB2IO_section2'] = first_msg.section2 if first_msg.section2 else []
    da.attrs['GRIB2IO_section3'] = first_msg.section3
    da.attrs['GRIB2IO_section4'] = first_msg.section4
    da.attrs['GRIB2IO_section5'] = first_msg.section5
    da.attrs['fullName'] = str(first_msg.fullName)
    da.attrs['shortName'] = str(first_msg.shortName)
    da.attrs['units'] = str(first_msg.units)
    da.attrs['originatingCenter'] = str(first_msg.originatingCenter.definition)
    da.attrs['originatingSubCenter'] = str(first_msg.originatingSubCenter.definition)

    # add master table
    da.attrs['masterTableInfo'] = str(first_msg.masterTableInfo.definition)

    da.name = index.shortName.iloc[0]

    # constant metadata present in the index is copied from its first row
    available_columns = index.columns
    for meta_name in constant_meta_names:
        if meta_name in available_columns:
            da.attrs[meta_name] = index[meta_name].iloc[0]

    da.attrs.update(attrs)

    return da
Build a DataArray without coordinates from a cube of grib2 messages.
Parameters
- index: Index of cube.
- cube: Cube of grib2 messages.
- filename: Filename of grib2 file
- add_grib_section_attrs: Include grib section arrays as dataArray attributes
Returns
- DataArray: DataArray without coordinates
def assign_xr_meta(ds, frames, cube, non_geo_dims, extra_geo, coord_attrs):
    """
    Attach coordinates and grid/dataset attributes to a dataset of variables.

    Parameters
    ----------
    ds
        Dataset holding the coordinate-free data variables.
    frames
        List of per-variable index DataFrames; only the first is used.
    cube
        Cube from which the dimension coordinates are built.
    non_geo_dims
        Mapping of dimension name to the coordinate names tied to it.
    extra_geo
        Extra geographic coordinates to assign.
    coord_attrs
        Mapping of coordinate name to attributes to add to that coordinate.

    Returns
    -------
    Dataset
        The dataset with coordinates and attributes assigned.
    """
    # assign coords from the cube; the cube prevents datarrays with
    # different shapes
    ds = ds.assign_coords(coords_from_cube(cube))

    # assign extra index associated coords
    # use first variable as they all have same shape and index metadata
    df = frames[0]
    for dim_name, coord_names in non_geo_dims.items():
        keep_dim_coord = False
        for name in coord_names:
            if name == dim_name:
                keep_dim_coord = True
                continue
            dim_size = ds[dim_name].size
            if dim_size == 1:
                # for assigning scalar coords
                scalar_value = df[name].unique().item()
                ds = ds.assign_coords({name: [scalar_value]}).squeeze()
            else:
                # "ValueError: can only convert an array of size 1 to a Python scalar" indicates the coord is not compatible with the index
                level_positions = df.index.get_level_values(f'{dim_name}_ix')
                coord_data = []
                for val in range(dim_size):
                    coord_data.append(df[level_positions == val][name].unique().item())
                ds = ds.assign_coords({name: pd.Index(coord_data, name=dim_name)})
        if not keep_dim_coord:
            ds = ds.drop_vars(dim_name)

    # assign extra geo coords
    ds = ds.assign_coords(extra_geo)

    # add crs data from first grib message to each data variable and the dataset
    first_msg = df.msg.iloc[0]
    geo_attrs = dict(
        crs_wkt=CRS.from_dict(first_msg.projParameters).to_wkt(),
        gridlengthXDirection=first_msg.gridlengthXDirection,
        gridlengthYDirection=first_msg.gridlengthYDirection,
        latitudeFirstGridpoint=first_msg.latitudeFirstGridpoint,
        longitudeFirstGridpoint=first_msg.longitudeFirstGridpoint,
    )
    for data_var in ds.data_vars:
        ds[data_var].attrs.update(geo_attrs)
    ds.attrs.update(geo_attrs)

    # add coordinate specific attributes
    for coord, attrs in coord_attrs.items():
        ds[coord].attrs.update(attrs)

    # assign valid date coords; failure only produces a warning
    try:
        valid_date = ds.coords['refDate'] + ds.coords['leadTime']
        ds = ds.assign_coords({'validDate': valid_date})
        ds.validDate.attrs['standard_name'] = 'time'
        ds.validDate.attrs['long_name'] = 'time'
    except Exception as e:
        warnings.warn(f'could not parse validTime: {e}')

    # assign attributes
    ds.attrs['engine'] = 'grib2io'

    return ds
def _same_cube(left, right):
    """Return True when two cubes hold the same dimensions and values."""
    if left.keys() != right.keys():
        return False
    # Compare value-by-value with Index.equals. Comparing the dicts directly
    # evaluates ``Index == Index``, which yields a boolean array whose truth
    # value is ambiguous as soon as a dimension has more than one value.
    return all(pd.Index(left[name]).equals(pd.Index(right[name])) for name in left)


def make_variables(index, f, non_geo_dims, allow_uneven_dims=False):
    """
    Create an individual dataframe index and cube for each variable.

    Variables are determined by ``shortName``.

    Parameters
    ----------
    index
        Dataframe of grib2 message records; must provide the ``shortName``,
        ``ny``, ``nx`` and ``msg`` columns plus one column per entry of
        ``non_geo_dims``.
    f
        Not used by this function; kept so existing callers keep working.
    non_geo_dims
        Dimensions not associated with the x,y grid (a mapping keyed by
        dimension name).
    allow_uneven_dims
        If True, allows uneven dimensions (used for DataTree creation)

    Returns
    -------
    ordered_frames
        List of dataframes, one for each variable.
    cube
        Cube of grib2 messages.
    extra_geo
        Extra geographic coordinates (``latitude`` and ``longitude``).

    All three values are ``None`` when ``index`` is empty.

    Raises
    ------
    ValueError
        If a dimension has an uneven number of messages or the variables do
        not share one cube (both only when ``allow_uneven_dims`` is False),
        or if the records span more than one grid shape.
    """
    # let shortName determine the variables
    index = index.set_index('shortName').sort_index()
    # return nothing if no data
    if index.empty:
        return None, None, None

    ordered_meta = list(non_geo_dims.keys())
    cube = None
    ordered_frames = list()
    for key in index.index.unique():
        # frame is a dataframe with all records for one variable
        frame = index.loc[[key]].reset_index()

        c = dict()
        for colname in ordered_meta:
            uniques = pd.Index(frame[colname]).unique()
            if len(uniques) > 1:
                c[colname] = uniques.sort_values()
            else:
                c[colname] = [uniques[0]]

        # only columns with more than one value become real dimensions
        dims = [k for k in ordered_meta if len(c[k]) > 1]

        if not allow_uneven_dims:
            for dim in dims:
                if frame[dim].value_counts().nunique() > 1:
                    raise ValueError(
                        f'uneven number of grib msgs associated with dimension: {dim}\n unique values for {dim}: {frame[dim].unique()} ')

        if dims:  # dims may be empty if no extra dims on top of x,y
            frame = frame.sort_values(dims)
            frame = frame.set_index(dims)

        if cube is None:
            cube = c
        elif not allow_uneven_dims and not _same_cube(cube, c):
            raise ValueError(f'{cube},\n {c};\n cubes are not the same; filter to a single cube')

        # miloc is multi-index integer location of msg in nd DataArray
        miloc = list(zip(*[frame.index.unique(level=dim).get_indexer(frame.index.get_level_values(dim))
                           for dim in dims]))

        # miloc is empty when there are no extra dims, thus no multiindex
        if miloc:
            dim_ix = tuple([n + '_ix' for n in dims])
            frame = frame.set_index(pd.MultiIndex.from_tuples(miloc, names=dim_ix))

        ordered_frames.append(frame)

    # no variables
    if cube is None:
        cube = dict()

    # check geography of data and assign to cube
    if len(index.ny.unique()) > 1 or len(index.nx.unique()) > 1:
        raise ValueError('multiple grids not accommodated')
    cube["y"] = range(int(index.ny.iloc[0]))
    cube["x"] = range(int(index.nx.iloc[0]))

    # we want the lat lons; make them via accessing a record; we are assuming
    # all records are the same grid because they have the same shape;
    # may want a unique grid identifier from grib2io to avoid assuming this
    msg = index.msg.iloc[0]
    latitude, longitude = msg.latlons()
    latitude = xr.DataArray(latitude, dims=['y', 'x'])
    latitude.attrs['standard_name'] = 'latitude'
    latitude.attrs['units'] = 'degrees_north'
    longitude = xr.DataArray(longitude, dims=['y', 'x'])
    longitude.attrs['standard_name'] = 'longitude'
    longitude.attrs['units'] = 'degrees_east'
    extra_geo = dict(latitude=latitude, longitude=longitude)

    return ordered_frames, cube, extra_geo
Create an individual dataframe index and cube for each variable.
Parameters
- index: Index of cube.
- f: Not referenced by the function body.
- non_geo_dims: Dimensions not associated with the x,y grid
- allow_uneven_dims: If True, allows uneven dimensions (used for DataTree creation)
Returns
- ordered_frames: List of dataframes, one for each variable.
- cube: Cube of grib2 messages.
- extra_geo: Extra geographic coordinates.
def interp_nd(a, *, method, grid_def_in, grid_def_out, method_options=None, num_threads=1):
    """
    Interpolate an N-d array whose last two axes form the grid.

    All leading axes are flattened into one before calling
    ``grib2io.interpolate`` and restored on the result.

    Parameters
    ----------
    a
        Array with at least two axes; the last two are handed to the
        interpolation as the grid.
    method
        Interpolation method, forwarded to ``grib2io.interpolate``.
    grid_def_in
        Grid definition of the input.
    grid_def_out
        Grid definition of the output.
    method_options
        Interpolation options, forwarded unchanged.
    num_threads
        Number of threads, forwarded unchanged.

    Returns
    -------
    Array with the leading shape of ``a`` followed by the last two axes of
    the interpolated result.
    """
    leading_shape = a.shape[:-2]
    stacked = a.reshape(-1, a.shape[-2], a.shape[-1])
    regridded = grib2io.interpolate(
        stacked,
        method,
        grid_def_in,
        grid_def_out,
        method_options=method_options,
        num_threads=num_threads,
    )
    ny, nx = regridded.shape[-2], regridded.shape[-1]
    return regridded.reshape(leading_shape + (ny, nx))
def interp_nd_stations(a, *, method, grid_def_in, lats, lons, method_options=None, num_threads=1):
    """
    Interpolate an N-d array whose last two axes form the grid to stations.

    All leading axes are flattened into one before calling
    ``grib2io.interpolate_to_stations``; the result is reshaped to the
    leading shape followed by one axis of length ``len(lats)``.

    Parameters
    ----------
    a
        Array with at least two axes; the last two are handed to the
        interpolation as the grid.
    method
        Interpolation method, forwarded unchanged.
    grid_def_in
        Grid definition of the input.
    lats
        Latitudes of the station points.
    lons
        Longitudes of the station points.
    method_options
        Interpolation options, forwarded unchanged.
    num_threads
        Number of threads, forwarded unchanged.

    Returns
    -------
    Array shaped ``a.shape[:-2] + (len(lats),)``.
    """
    leading_shape = a.shape[:-2]
    stacked = a.reshape(-1, a.shape[-2], a.shape[-1])
    at_stations = grib2io.interpolate_to_stations(
        stacked,
        method,
        grid_def_in,
        lats,
        lons,
        method_options=method_options,
        num_threads=num_threads,
    )
    return at_stations.reshape(leading_shape + (len(lats),))
@xr.register_dataset_accessor("grib2io")
class Grib2ioDataSet:
    """``grib2io`` accessor registered on xarray Datasets."""

    def __init__(self, xarray_obj):
        self._obj = xarray_obj

    def griddef(self):
        """Return the Grib2GridDef built from the first data variable's section 3."""
        first_name = list(self._obj.data_vars)[0]
        section3 = self._obj[first_name].attrs['GRIB2IO_section3']
        return Grib2GridDef.from_section3(section3)

    def interp(self, method, grid_def_out, method_options=None, num_threads=1) -> xr.Dataset:
        """
        Interpolate every data variable to a new grid.

        The variables are stacked into one DataArray, which receives the
        section 3 of the first data variable, and the work is delegated to
        the DataArray accessor's ``interp``.
        """
        stacked = self._obj.to_array()
        first_name = list(self._obj.data_vars)[0]
        stacked.attrs['GRIB2IO_section3'] = self._obj[first_name].attrs['GRIB2IO_section3']
        regridded = stacked.grib2io.interp(
            method, grid_def_out, method_options=method_options, num_threads=num_threads
        )
        return regridded.to_dataset(dim='variable')

    def interp_to_stations(self, method, calls, lats, lons, method_options=None, num_threads=1) -> xr.Dataset:
        """
        Interpolate every data variable to station points.

        The variables are stacked into one DataArray, which receives the
        section 3 of the first data variable, and the work is delegated to
        the DataArray accessor's ``interp_to_stations``.
        """
        stacked = self._obj.to_array()
        first_name = list(self._obj.data_vars)[0]
        stacked.attrs['GRIB2IO_section3'] = self._obj[first_name].attrs['GRIB2IO_section3']
        at_stations = stacked.grib2io.interp_to_stations(
            method, calls, lats, lons, method_options=method_options, num_threads=num_threads
        )
        return at_stations.to_dataset(dim='variable')

    def to_grib2(self, filename, mode: typing.Literal["x", "w", "a"] = "x"):
        """
        Write a DataSet to a grib2 file.

        Parameters
        ----------
        filename
            Name of the grib2 file to write to.
        mode: {"x", "w", "a"}, optional, default="x"
            Persistence mode

            | mode | Description                       |
            | :---:| :---:                             |
            | 'x'  | create (fail if exists)           |
            | 'w'  | create (overwrite if exists)      |
            | 'a'  | append (create if does not exist) |

        """
        dataset = self._obj

        # variables are written in sorted name order; the caller's mode is
        # used for the first one only, every later write appends
        next_mode = mode
        for var_name in sorted(dataset):
            dataset[var_name].grib2io.to_grib2(filename, mode=next_mode)
            next_mode = "a"

    def update_attrs(self, **kwargs):
        """
        Always raise: attribute updates are only offered on DataArrays.

        Parameters
        ----------
        **kwargs
            Attributes the caller wanted to update; echoed in the error
            message.

        Raises
        ------
        ValueError
            Unconditionally.
        """
        message = (
            f"Datasets do not have a .attrs attribute; use .grib2io.update_attrs({kwargs}) on a DataArray instead."
        )
        raise ValueError(message)

    def subset(self, lats, lons) -> xr.Dataset:
        """
        Subset the DataSet to a region defined by latitudes and longitudes.

        Parameters
        ----------
        lats
            Latitude bounds of the region.
        lons
            Longitude bounds of the region.

        Returns
        -------
        subset
            New DataSet holding a copy of each variable's subset.
        """
        source = self._obj

        result = xr.Dataset()
        for var_name in source:
            result[var_name] = source[var_name].grib2io.subset(lats, lons).copy()

        return result
def interp(self, method, grid_def_out, method_options=None, num_threads=1) -> xr.Dataset:
    """
    Interpolate every data variable to a new grid.

    The variables are stacked into one DataArray, which receives the
    section 3 of the first data variable, and the work is delegated to the
    DataArray accessor's ``interp``; see that method for the parameters.
    """
    stacked = self._obj.to_array()
    first_name = list(self._obj.data_vars)[0]
    stacked.attrs['GRIB2IO_section3'] = self._obj[first_name].attrs['GRIB2IO_section3']
    regridded = stacked.grib2io.interp(
        method, grid_def_out, method_options=method_options, num_threads=num_threads
    )
    return regridded.to_dataset(dim='variable')
def interp_to_stations(self, method, calls, lats, lons, method_options=None, num_threads=1) -> xr.Dataset:
    """
    Interpolate every data variable to station points.

    The variables are stacked into one DataArray, which receives the
    section 3 of the first data variable, and the work is delegated to the
    DataArray accessor's ``interp_to_stations``; see that method for the
    parameters.
    """
    stacked = self._obj.to_array()
    first_name = list(self._obj.data_vars)[0]
    stacked.attrs['GRIB2IO_section3'] = self._obj[first_name].attrs['GRIB2IO_section3']
    at_stations = stacked.grib2io.interp_to_stations(
        method, calls, lats, lons, method_options=method_options, num_threads=num_threads
    )
    return at_stations.to_dataset(dim='variable')
def to_grib2(self, filename, mode: typing.Literal["x", "w", "a"] = "x"):
    """
    Write a DataSet to a grib2 file.

    Parameters
    ----------
    filename
        Name of the grib2 file to write to.
    mode: {"x", "w", "a"}, optional, default="x"
        Persistence mode

        | mode | Description                       |
        | :---:| :---:                             |
        | 'x'  | create (fail if exists)           |
        | 'w'  | create (overwrite if exists)      |
        | 'a'  | append (create if does not exist) |

    """
    dataset = self._obj

    # variables are written in sorted name order; the caller's mode is used
    # for the first one only, every later write appends
    next_mode = mode
    for var_name in sorted(dataset):
        dataset[var_name].grib2io.to_grib2(filename, mode=next_mode)
        next_mode = "a"
Write a DataSet to a grib2 file.
Parameters
- filename: Name of the grib2 file to write to.
- mode ({"x", "w", "a"}, optional, default="x"): Persistence mode
| mode | Description |
|---|---|
| 'x' | create (fail if exists) |
| 'w' | create (overwrite if exists) |
| 'a' | append (create if does not exist) |
def update_attrs(self, **kwargs):
    """
    Always raise: attribute updates are only offered on DataArrays.

    Parameters
    ----------
    **kwargs
        Attributes the caller wanted to update; echoed in the error message.

    Raises
    ------
    ValueError
        Unconditionally.
    """
    message = (
        f"Datasets do not have a .attrs attribute; use .grib2io.update_attrs({kwargs}) on a DataArray instead."
    )
    raise ValueError(message)
Raises an error because Datasets don't have a .attrs attribute.
Parameters
- attrs: Attributes to update.
def subset(self, lats, lons) -> xr.Dataset:
    """
    Subset the DataSet to a region defined by latitudes and longitudes.

    Parameters
    ----------
    lats
        Latitude bounds of the region.
    lons
        Longitude bounds of the region.

    Returns
    -------
    subset
        New DataSet holding a copy of each variable's subset.
    """
    source = self._obj

    result = xr.Dataset()
    for var_name in source:
        result[var_name] = source[var_name].grib2io.subset(lats, lons).copy()

    return result
Subset the DataSet to a region defined by latitudes and longitudes.
Parameters
- lats: Latitude bounds of the region.
- lons: Longitude bounds of the region.
Returns
- subset: DataSet subset to the region.
@xr.register_dataarray_accessor("grib2io")
class Grib2ioDataArray:
    """``grib2io`` accessor registered on xarray DataArrays."""

    def __init__(self, xarray_obj):
        self._obj = xarray_obj

    def griddef(self):
        """Return the Grib2GridDef built from the ``GRIB2IO_section3`` attribute."""
        return Grib2GridDef.from_section3(self._obj.attrs['GRIB2IO_section3'])

    def interp(self, method, grid_def_out, method_options=None, num_threads=1) -> xr.DataArray:
        """
        Perform grid spatial interpolation.

        Uses the [NCEPLIBS-ip library](https://github.com/NOAA-EMC/NCEPLIBS-ip).

        Parameters
        ----------
        method
            Interpolate method to use. This can either be an integer or string
            using the following mapping:

            | Interpolate Scheme | Integer Value |
            | :---: | :---: |
            | 'bilinear' | 0 |
            | 'bicubic' | 1 |
            | 'neighbor' | 2 |
            | 'budget' | 3 |
            | 'spectral' | 4 |
            | 'neighbor-budget' | 6 |
        grid_def_out
            Grib2GridDef object of the output grid.
        method_options : list of ints, optional
            Interpolation options. See the NCEPLIBS-ip documentation for
            more information on how these are used.
        num_threads : int, optional
            Number of OpenMP threads to use for interpolation. The default
            value is 1. If grib2io_interp was not built with OpenMP, then
            this keyword argument and value will have no impact.

        Returns
        -------
        interp
            DataArray interpolated to new grid definition. The attribute
            GRIB2IO_section3 is replaced with the section3 array from the new
            grid definition.
        """
        da = self._obj
        # NOTE: y, x are assumed to be the rightmost dims; they should be if
        # opening with the grib2io engine

        # gdtn and gdt are not the entirety of the new section 3
        npoints = grid_def_out.npoints
        s3_new = np.array([0, npoints, 0, 0, grid_def_out.gdtn] + list(grid_def_out.gdt))

        # make new lat lons
        lats, lons = Grib2Message(section3=s3_new, pdtn=0, drtn=0).grid()
        latitude = xr.DataArray(lats, dims=['y', 'x'])
        longitude = xr.DataArray(lons, dims=['y', 'x'])

        # swap the old lat/lon coordinates for the new ones
        new_coords = dict(da.coords)
        del new_coords['latitude']
        del new_coords['longitude']
        new_coords['longitude'] = longitude
        new_coords['latitude'] = latitude

        # make grid def in from section3 on da.attrs
        grid_def_in = self.griddef()

        if da.chunks is None:
            data = interp_nd(da.data, method=method, grid_def_in=grid_def_in,
                             grid_def_out=grid_def_out,
                             method_options=method_options, num_threads=num_threads)
        else:
            # chunked data: interpolate block by block; num_threads has to be
            # forwarded here too or it is silently ignored
            data = da.data.map_blocks(interp_nd, method=method, grid_def_in=grid_def_in,
                                      grid_def_out=grid_def_out, method_options=method_options,
                                      num_threads=num_threads,
                                      chunks=da.chunks[:-2] + latitude.shape, dtype=da.dtype)

        new_da = xr.DataArray(data, dims=da.dims, coords=new_coords, attrs=da.attrs)

        new_da.attrs['GRIB2IO_section3'] = s3_new
        new_da.name = da.name
        return new_da

    def interp_to_stations(self, method, calls, lats, lons, method_options=None, num_threads=1) -> xr.DataArray:
        """
        Perform spatial interpolation to station points.

        Parameters
        ----------
        method
            Interpolate method to use. This can either be an integer or string
            using the following mapping:

            | Interpolate Scheme | Integer Value |
            | :---: | :---: |
            | 'bilinear' | 0 |
            | 'bicubic' | 1 |
            | 'neighbor' | 2 |
            | 'budget' | 3 |
            | 'spectral' | 4 |
            | 'neighbor-budget' | 6 |

        calls
            Station calls used for labeling new station index coordinate
        lats
            Latitudes of the station points.
        lons
            Longitudes of the station points.
        method_options : list of ints, optional
            Interpolation options, forwarded to the interpolation routine.
        num_threads : int, optional
            Number of threads, forwarded to the interpolation routine. The
            default value is 1.

        Returns
        -------
        interp_to_stations
            DataArray interpolated to lat and lon locations and labeled with
            dimension and coordinate 'station'. (..., y, x) -> (..., station)
        """
        da = self._obj
        # TODO ensure that y, x are rightmost dims; they should be if opening
        # with grib2io engine

        calls = np.asarray(calls)
        lats = np.asarray(lats)
        lons = np.asarray(lons)
        latitude = xr.DataArray(lats, dims=['station'])
        longitude = xr.DataArray(lons, dims=['station'])

        # swap the gridded lat/lon coordinates for the station ones
        new_coords = dict(da.coords)
        del new_coords['latitude']
        del new_coords['longitude']
        new_coords['longitude'] = longitude
        new_coords['latitude'] = latitude
        new_coords['station'] = calls

        new_dims = da.dims[:-2] + ('station',)

        # make grid def in from section3 on da attrs
        grid_def_in = self.griddef()

        if da.chunks is None:
            data = interp_nd_stations(da.data, method=method, grid_def_in=grid_def_in, lats=lats,
                                      lons=lons, method_options=method_options, num_threads=num_threads)
        else:
            # chunked data: interpolate block by block; num_threads has to be
            # forwarded here too or it is silently ignored
            data = da.data.map_blocks(interp_nd_stations, method=method, grid_def_in=grid_def_in,
                                      lats=lats, lons=lons, method_options=method_options,
                                      num_threads=num_threads,
                                      drop_axis=-1, chunks=da.chunks[:-2] + latitude.shape,
                                      dtype=da.dtype)

        new_da = xr.DataArray(data, dims=new_dims, coords=new_coords, attrs=da.attrs)

        new_da.name = da.name
        return new_da

    def to_grib2(self, filename, mode: typing.Literal["x", "w", "a"] = "x"):
        """
        Write a DataArray to a grib2 file.

        Parameters
        ----------
        filename
            Name of the grib2 file to write to.
        mode: {"x", "w", "a"}, optional, default="x"
            Persistence mode

            +------+-----------------------------------+
            | mode | Description                       |
            +======+===================================+
            | x    | create (fail if exists)           |
            +------+-----------------------------------+
            | w    | create (overwrite if exists)      |
            +------+-----------------------------------+
            | a    | append (create if does not exist) |
            +------+-----------------------------------+

        Raises
        ------
        ValueError
            If a dimension coordinate has duplicate values.
        """
        da = self._obj.copy(deep=True)

        coords_keys = sorted(da.coords.keys())
        coords_keys = [k for k in coords_keys if k in AVAILABLE_NON_GEO_COORDS]

        # If there are dimension coordinates, the DataArray is a hypercube of
        # grib2 messages.
        #
        # `indexes` is a list with one entry per dimension coordinate; each
        # entry is a list of single-item dictionaries mapping the dimension
        # name to one of its values, so that itertools.product enumerates
        # every grib2 message in the DataArray. For example:
        #     indexes = [
        #         [{"leadTime": 9}, {"leadTime": 12}],
        #         [{"valueOfFirstFixedSurface": 900},
        #          {"valueOfFirstFixedSurface": 925},
        #          {"valueOfFirstFixedSurface": 950}],
        #     ]

        # assign loc indexes to dimensions without indexes for uniform
        # selection by name
        loc_indexes = list()
        for dim in da.dims:
            if dim not in da.indexes:
                da = da.assign_coords({dim: range(da[dim].size)})
                loc_indexes.append(dim)

        indexes = []
        for index in [i for i in AVAILABLE_NON_GEO_DIMS if i in da.dims]:
            values = da.coords[index].values
            if len(values) != len(set(values)):
                raise ValueError(
                    f"Dimension coordinate '{index}' has duplicate values, but to_grib2 requires unique values to find each GRIB2 message in the DataArray."
                )
            indexes.append([{index: value} for value in sorted(values)])

        # coordinates that are not dimensions are copied onto every message
        non_dim_coords = [i for i in coords_keys if i not in da.dims]

        # If `indexes` is [], then the DataArray is a single grib2 message and
        # itertools.product(*indexes) will run once with `selectors = ()`.
        for selectors in itertools.product(*indexes):
            # merge the single-item selectors into one indexer mapping
            filters = {k: v for d in selectors for k, v in d.items()}

            # If `filters` is {}, then the DataArray is a single grib2 message
            # and da.sel(indexers={}) returns the DataArray.
            selected = da.sel(indexers=filters)

            newmsg = Grib2Message(
                selected.attrs["GRIB2IO_section0"],
                selected.attrs["GRIB2IO_section1"],
                selected.attrs["GRIB2IO_section2"],
                selected.attrs["GRIB2IO_section3"],
                selected.attrs["GRIB2IO_section4"],
                selected.attrs["GRIB2IO_section5"],
            )
            newmsg.data = np.array(selected.data)

            # For dimension coordinates, set the grib2 message metadata to the
            # dimension coordinate value.
            for index, value in filters.items():
                if index not in loc_indexes:
                    setattr(newmsg, index, value)

            # For non-dimension coordinates, set the grib2 message metadata to
            # the DataArray coordinate value.
            for index in non_dim_coords:
                setattr(newmsg, index, selected.coords[index].values)

            # Set section 5 attributes from the encoding dictionary.
            for key, value in selected.encoding.items():
                if key in ["dtype", "chunks", "original_shape"]:
                    continue
                setattr(newmsg, key, value)

            # write the message to file; every message after the first appends
            with grib2io.open(filename, mode=mode) as f:
                f.write(newmsg)
            mode = "a"

    def update_attrs(self, **kwargs):
        """
        Update many of the attributes of the DataArray.

        Parameters
        ----------
        **kwargs
            Attributes to update. This can include many of the GRIB2IO message
            attributes that you can find when you print a GRIB2IO message. For
            conflicting updates, the last keyword will be used. Examples:

            - ``shortName="VTMP"``: set shortName to "VTMP", along with
              appropriate discipline, parameterCategory, parameterNumber,
              fullName and units.
            - ``discipline=0, parameterCategory=0, parameterNumber=1``: set
              shortName, discipline, parameterCategory, parameterNumber,
              fullName and units appropriate for "Virtual Temperature".
            - ``discipline=0, parameterCategory=0, parameterNumber=1,
              shortName="TMP"``: conflicting keywords but ``shortName="TMP"``
              wins; everything is set as appropriate for "Temperature".

        Returns
        -------
        DataArray
            Deep copy of the DataArray with updated attributes.

        Raises
        ------
        ValueError
            If gridDefinitionTemplateNumber, productDefinitionTemplateNumber
            or dataRepresentationTemplateNumber is passed.

        Warns
        -----
        Keywords naming a coordinate, or not present on the message, are
        skipped with a warning.
        """
        da = self._obj.copy(deep=True)

        # a scratch message does the attribute bookkeeping; its sections are
        # copied back onto the DataArray attributes at the end
        newmsg = Grib2Message(
            da.attrs["GRIB2IO_section0"],
            da.attrs["GRIB2IO_section1"],
            da.attrs["GRIB2IO_section2"],
            da.attrs["GRIB2IO_section3"],
            da.attrs["GRIB2IO_section4"],
            da.attrs["GRIB2IO_section5"],
        )

        coords_keys = [
            k
            for k in da.coords.keys()
            if k in AVAILABLE_NON_GEO_COORDS
        ]

        for grib2_name, value in kwargs.items():
            if grib2_name == "gridDefinitionTemplateNumber":
                raise ValueError(
                    "The gridDefinitionTemplateNumber attribute cannot be updated.  The best way to change to a different grid is to interpolate the data to a new grid using the grib2io interpolate functions."
                )
            if grib2_name == "productDefinitionTemplateNumber":
                raise ValueError(
                    "The productDefinitionTemplateNumber attribute cannot be updated."
                )
            if grib2_name == "dataRepresentationTemplateNumber":
                raise ValueError(
                    "The dataRepresentationTemplateNumber attribute cannot be updated."
                )
            if grib2_name in coords_keys:
                warnings.warn(
                    f"Skipping attribute '{grib2_name}' because it is a coordinate. Use da.assign_coords() to change coordinate values."
                )
                continue
            if hasattr(newmsg, grib2_name):
                setattr(newmsg, grib2_name, value)
            else:
                warnings.warn(
                    f"Skipping attribute '{grib2_name}' because it is not a valid GRIB2 attribute for this message and cannot be updated."
                )
                continue

        da.attrs["GRIB2IO_section0"] = newmsg.section0
        da.attrs["GRIB2IO_section1"] = newmsg.section1
        da.attrs["GRIB2IO_section2"] = newmsg.section2 or []
        da.attrs["GRIB2IO_section3"] = newmsg.section3
        da.attrs["GRIB2IO_section4"] = newmsg.section4
        da.attrs["GRIB2IO_section5"] = newmsg.section5
        da.attrs["fullName"] = newmsg.fullName
        da.attrs["shortName"] = newmsg.shortName
        da.attrs["units"] = newmsg.units

        return da

    def subset(self, lats, lons) -> xr.DataArray:
        """
        Subset the DataArray to a region defined by latitudes and longitudes.

        Parameters
        ----------
        lats
            Latitude bounds of the region.
        lons
            Longitude bounds of the region.

        Returns
        -------
        subset
            DataArray subset to the region.
        """
        da = self._obj.copy(deep=True)

        newmsg = Grib2Message(
            da.attrs["GRIB2IO_section0"],
            da.attrs["GRIB2IO_section1"],
            da.attrs["GRIB2IO_section2"],
            da.attrs["GRIB2IO_section3"],
            da.attrs["GRIB2IO_section4"],
            da.attrs["GRIB2IO_section5"],
        )

        # zero-filled placeholder data: the message is only used to obtain
        # the subset's section 3 and corner points
        newmsg.data = np.zeros((newmsg.ny, newmsg.nx), dtype=np.float32)

        newmsg = newmsg.subset(lats, lons)

        da.attrs["GRIB2IO_section3"] = newmsg.section3

        # NOTE(review): the latitude mask uses latitudeFirstGridpoint as the
        # upper bound and latitudeLastGridpoint as the lower bound -- confirm
        # this holds for grids scanned south-to-north.
        mask_lat = (da.latitude >= newmsg.latitudeLastGridpoint) & (
            da.latitude <= newmsg.latitudeFirstGridpoint
        )
        mask_lon = (da.longitude >= newmsg.longitudeFirstGridpoint) & (
            da.longitude <= newmsg.longitudeLastGridpoint
        )

        del newmsg

        return da.where((mask_lon & mask_lat).compute(), drop=True)
def interp(self, method, grid_def_out, method_options=None, num_threads=1) -> xr.DataArray:
    """
    Perform grid spatial interpolation.

    Uses the [NCEPLIBS-ip library](https://github.com/NOAA-EMC/NCEPLIBS-ip).

    Parameters
    ----------
    method
        Interpolate method to use. This can either be an integer or string
        using the following mapping:

        | Interpolate Scheme | Integer Value |
        | :---: | :---: |
        | 'bilinear' | 0 |
        | 'bicubic' | 1 |
        | 'neighbor' | 2 |
        | 'budget' | 3 |
        | 'spectral' | 4 |
        | 'neighbor-budget' | 6 |
    grid_def_out
        Grib2GridDef object of the output grid.
    method_options : list of ints, optional
        Interpolation options. See the NCEPLIBS-ip documentation for
        more information on how these are used.
    num_threads : int, optional
        Number of OpenMP threads to use for interpolation. The default
        value is 1. If grib2io_interp was not built with OpenMP, then
        this keyword argument and value will have no impact.

    Returns
    -------
    interp
        DataArray interpolated to new grid definition. The attribute
        GRIB2IO_section3 is replaced with the section3 array from the new
        grid definition.
    """
    da = self._obj
    # NOTE: y, x are assumed to be the rightmost dims; they should be if
    # opening with the grib2io engine

    # gdtn and gdt are not the entirety of the new section 3
    npoints = grid_def_out.npoints
    s3_new = np.array([0, npoints, 0, 0, grid_def_out.gdtn] + list(grid_def_out.gdt))

    # make new lat lons
    lats, lons = Grib2Message(section3=s3_new, pdtn=0, drtn=0).grid()
    latitude = xr.DataArray(lats, dims=['y', 'x'])
    longitude = xr.DataArray(lons, dims=['y', 'x'])

    # swap the old lat/lon coordinates for the new ones
    new_coords = dict(da.coords)
    del new_coords['latitude']
    del new_coords['longitude']
    new_coords['longitude'] = longitude
    new_coords['latitude'] = latitude

    # make grid def in from section3 on da.attrs
    grid_def_in = self.griddef()

    if da.chunks is None:
        data = interp_nd(da.data, method=method, grid_def_in=grid_def_in,
                         grid_def_out=grid_def_out,
                         method_options=method_options, num_threads=num_threads)
    else:
        # chunked data: interpolate block by block; num_threads has to be
        # forwarded here too or it is silently ignored
        data = da.data.map_blocks(interp_nd, method=method, grid_def_in=grid_def_in,
                                  grid_def_out=grid_def_out, method_options=method_options,
                                  num_threads=num_threads,
                                  chunks=da.chunks[:-2] + latitude.shape, dtype=da.dtype)

    new_da = xr.DataArray(data, dims=da.dims, coords=new_coords, attrs=da.attrs)

    new_da.attrs['GRIB2IO_section3'] = s3_new
    new_da.name = da.name
    return new_da
Perform grid spatial interpolation.
Uses the NCEPLIBS-ip library.
Parameters
- method: Interpolate method to use. This can either be an integer or string using the following mapping:
| Interpolate Scheme | Integer Value |
|---|---|
| 'bilinear' | 0 |
| 'bicubic' | 1 |
| 'neighbor' | 2 |
| 'budget' | 3 |
| 'spectral' | 4 |
| 'neighbor-budget' | 6 |
- grid_def_out: Grib2GridDef object of the output grid.
- method_options (list of ints, optional): Interpolation options. See the NCEPLIBS-ip documentation for more information on how these are used.
- num_threads (int, optional): Number of OpenMP threads to use for interpolation. The default value is 1. If grib2io_interp was not built with OpenMP, then this keyword argument and value will have no impact.
Returns
- interp: DataArray interpolated to new grid definition. The attribute GRIB2IO_section3 is replaced with the section3 array from the new grid definition.
def interp_to_stations(self, method, calls, lats, lons, method_options=None, num_threads=1) -> xr.DataArray:
    """
    Perform spatial interpolation to station points.

    Parameters
    ----------
    method
        Interpolate method to use. This can either be an integer or string
        using the following mapping:

        | Interpolate Scheme | Integer Value |
        | :---: | :---: |
        | 'bilinear' | 0 |
        | 'bicubic' | 1 |
        | 'neighbor' | 2 |
        | 'budget' | 3 |
        | 'spectral' | 4 |
        | 'neighbor-budget' | 6 |

    calls
        Station calls used for labeling new station index coordinate
    lats
        Latitudes of the station points.
    lons
        Longitudes of the station points.
    method_options : list of ints, optional
        Interpolation options, forwarded to the interpolation routine.
    num_threads : int, optional
        Number of threads, forwarded to the interpolation routine. The
        default value is 1.

    Returns
    -------
    interp_to_stations
        DataArray interpolated to lat and lon locations and labeled with
        dimension and coordinate 'station'. (..., y, x) -> (..., station)
    """
    da = self._obj
    # TODO ensure that y, x are rightmost dims; they should be if opening
    # with grib2io engine

    calls = np.asarray(calls)
    lats = np.asarray(lats)
    lons = np.asarray(lons)
    latitude = xr.DataArray(lats, dims=['station'])
    longitude = xr.DataArray(lons, dims=['station'])

    # swap the gridded lat/lon coordinates for the station ones
    new_coords = dict(da.coords)
    del new_coords['latitude']
    del new_coords['longitude']
    new_coords['longitude'] = longitude
    new_coords['latitude'] = latitude
    new_coords['station'] = calls

    new_dims = da.dims[:-2] + ('station',)

    # make grid def in from section3 on da attrs
    grid_def_in = self.griddef()

    if da.chunks is None:
        data = interp_nd_stations(da.data, method=method, grid_def_in=grid_def_in, lats=lats,
                                  lons=lons, method_options=method_options, num_threads=num_threads)
    else:
        # chunked data: interpolate block by block; num_threads has to be
        # forwarded here too or it is silently ignored
        data = da.data.map_blocks(interp_nd_stations, method=method, grid_def_in=grid_def_in,
                                  lats=lats, lons=lons, method_options=method_options,
                                  num_threads=num_threads,
                                  drop_axis=-1, chunks=da.chunks[:-2] + latitude.shape,
                                  dtype=da.dtype)

    new_da = xr.DataArray(data, dims=new_dims, coords=new_coords, attrs=da.attrs)

    new_da.name = da.name
    return new_da
Perform spatial interpolation to station points.
Parameters
- method: Interpolate method to use. This can either be an integer or string using the following mapping:
| Interpolate Scheme | Integer Value |
|---|---|
| 'bilinear' | 0 |
| 'bicubic' | 1 |
| 'neighbor' | 2 |
| 'budget' | 3 |
| 'spectral' | 4 |
| 'neighbor-budget' | 6 |
- calls: Station calls used for labeling new station index coordinate
- lats: Latitudes of the station points.
- lons: Longitudes of the station points.
Returns
- interp_to_stations: DataArray interpolated to lat and lon locations and labeled with dimension and coordinate 'station'. (..., y, x) -> (..., station)
def to_grib2(self, filename, mode: typing.Literal["x", "w", "a"] = "x"):
    """
    Write a DataArray to a grib2 file.

    The DataArray is treated as a hypercube of GRIB2 messages: every
    combination of values along the non-geographic dimensions is selected,
    turned into a ``Grib2Message`` and written to ``filename``.

    Parameters
    ----------
    filename
        Name of the grib2 file to write to.
    mode: {"x", "w", "a"}, optional, default="x"
        Persistence mode

        +------+-----------------------------------+
        | mode | Description                       |
        +======+===================================+
        | x    | create (fail if exists)           |
        +------+-----------------------------------+
        | w    | create (overwrite if exists)      |
        +------+-----------------------------------+
        | a    | append (create if does not exist) |
        +------+-----------------------------------+

    Raises
    ------
    ValueError
        If a non-geographic dimension coordinate holds duplicate values.
    """
    arr = self._obj.copy(deep=True)

    # Non-geographic coordinates present on the array, in sorted order.
    meta_coord_names = [
        name for name in sorted(arr.coords.keys())
        if name in AVAILABLE_NON_GEO_COORDS
    ]

    # Dimensions without an index get a positional range coordinate so that
    # every dimension can be selected by label below. They are remembered so
    # the made-up positions are not written into the message metadata.
    positional_dims = list()
    for dim_name in arr.dims:
        if dim_name in arr.indexes:
            continue
        arr = arr.assign_coords({dim_name: range(arr[dim_name].size)})
        positional_dims.append(dim_name)

    # One list of single-entry selector dicts per non-geographic dimension,
    # e.g.
    #   [
    #       [{"leadTime": 9}, {"leadTime": 12}],
    #       [{"valueOfFirstFixedSurface": 900}, {"valueOfFirstFixedSurface": 925}],
    #   ]
    # itertools.product over these lists visits every message of the cube.
    per_dim_selectors = []
    for dim_name in AVAILABLE_NON_GEO_DIMS:
        if dim_name not in arr.dims:
            continue
        dim_values = arr.coords[dim_name].values
        if len(dim_values) != len(set(dim_values)):
            raise ValueError(
                f"Dimension coordinate '{dim_name}' has duplicate values, but to_grib2 requires unique values to find each GRIB2 message in the DataArray."
            )
        per_dim_selectors.append([{dim_name: value} for value in sorted(dim_values)])

    # With no selector lists the product yields a single empty tuple, i.e.
    # the DataArray is written as one message.
    for combination in itertools.product(*per_dim_selectors):
        filters = {}
        for selector in combination:
            filters.update(selector)

        # An empty ``filters`` selects the whole DataArray.
        selected = arr.sel(indexers=filters)

        newmsg = Grib2Message(
            selected.attrs["GRIB2IO_section0"],
            selected.attrs["GRIB2IO_section1"],
            selected.attrs["GRIB2IO_section2"],
            selected.attrs["GRIB2IO_section3"],
            selected.attrs["GRIB2IO_section4"],
            selected.attrs["GRIB2IO_section5"],
        )
        newmsg.data = np.array(selected.data)

        # Dimension coordinate values become message metadata.
        for name, value in filters.items():
            if name in positional_dims:
                continue
            setattr(newmsg, name, value)

        # Non-dimension coordinates become message metadata as well.
        for name in meta_coord_names:
            if name in arr.dims:
                continue
            setattr(newmsg, name, selected.coords[name].values)

        # Remaining encoding entries are applied as message attributes.
        for key, value in selected.encoding.items():
            if key not in ["dtype", "chunks", "original_shape"]:
                setattr(newmsg, key, value)

        with grib2io.open(filename, mode=mode) as f:
            f.write(newmsg)
        # Every message after the first one is appended.
        mode = "a"
Write a DataArray to a grib2 file.
Parameters
- filename: Name of the grib2 file to write to.
mode ({"x", "w", "a"}, optional, default="x"): Persistence mode
+------+-----------------------------------+ | mode | Description | +======+===================================+ | x | create (fail if exists) | +------+-----------------------------------+ | w | create (overwrite if exists) | +------+-----------------------------------+ | a | append (create if does not exist) | +------+-----------------------------------+
def update_attrs(self, **kwargs):
    """
    Update many of the attributes of the DataArray.

    Parameters
    ----------
    **kwargs
        Attributes to update. This can include many of the GRIB2IO message
        attributes that you can find when you print a GRIB2IO message. For
        conflicting updates, the last keyword will be used.

        +-----------------------+------------------------------------------+
        | kwargs                | Description                              |
        +=======================+==========================================+
        | shortName="VTMP"      | Set shortName to "VTMP", along with      |
        |                       | appropriate discipline,                  |
        |                       | parameterCategory, parameterNumber,      |
        |                       | fullName and units.                      |
        +-----------------------+------------------------------------------+
        | discipline=0,         | Set shortName, discipline,               |
        | parameterCategory=0,  | parameterCategory, parameterNumber,      |
        | parameterNumber=1     | fullName and units appropriate for       |
        |                       | "Virtual Temperature".                   |
        +-----------------------+------------------------------------------+
        | discipline=0,         | Conflicting keywords but                 |
        | parameterCategory=0,  | 'shortName="TMP"' wins. Set shortName,   |
        | parameterNumber=1,    | discipline, parameterCategory,           |
        | shortName="TMP"       | parameterNumber, fullName and units      |
        |                       | appropriate for "Temperature".           |
        +-----------------------+------------------------------------------+

    Returns
    -------
    DataArray
        DataArray with updated attributes.

    Raises
    ------
    ValueError
        If one of the template-number attributes is passed; these cannot be
        updated.
    """
    # Template numbers that must not be changed, with the error to raise.
    locked_templates = {
        "gridDefinitionTemplateNumber": "The gridDefinitionTemplateNumber attribute cannot be updated. The best way to change to a different grid is to interpolate the data to a new grid using the grib2io interpolate functions.",
        "productDefinitionTemplateNumber": "The productDefinitionTemplateNumber attribute cannot be updated.",
        "dataRepresentationTemplateNumber": "The dataRepresentationTemplateNumber attribute cannot be updated.",
    }

    result = self._obj.copy(deep=True)
    attrs = result.attrs

    # Rebuild a message from the stored sections so that attribute updates
    # go through the message's own setters.
    newmsg = Grib2Message(
        attrs["GRIB2IO_section0"],
        attrs["GRIB2IO_section1"],
        attrs["GRIB2IO_section2"],
        attrs["GRIB2IO_section3"],
        attrs["GRIB2IO_section4"],
        attrs["GRIB2IO_section5"],
    )

    coordinate_names = [
        name for name in result.coords.keys() if name in AVAILABLE_NON_GEO_COORDS
    ]

    for grib2_name, value in kwargs.items():
        if grib2_name in locked_templates:
            raise ValueError(locked_templates[grib2_name])

        if grib2_name in coordinate_names:
            warnings.warn(
                f"Skipping attribute '{grib2_name}' because it is a coordinate. Use da.assign_coords() to change coordinate values."
            )
        elif hasattr(newmsg, grib2_name):
            setattr(newmsg, grib2_name, value)
        else:
            warnings.warn(
                f"Skipping attribute '{grib2_name}' because it is not a valid GRIB2 attribute for this message and cannot be updated."
            )

    # Copy the (possibly modified) sections and derived names back.
    attrs["GRIB2IO_section0"] = newmsg.section0
    attrs["GRIB2IO_section1"] = newmsg.section1
    attrs["GRIB2IO_section2"] = newmsg.section2 or []
    attrs["GRIB2IO_section3"] = newmsg.section3
    attrs["GRIB2IO_section4"] = newmsg.section4
    attrs["GRIB2IO_section5"] = newmsg.section5
    attrs["fullName"] = newmsg.fullName
    attrs["shortName"] = newmsg.shortName
    attrs["units"] = newmsg.units

    return result
Update many of the attributes of the DataArray.
Parameters
- **kwargs: Attributes to update. This can include many of the GRIB2IO message attributes that you can find when you print a GRIB2IO message. For conflicting updates, the last keyword will be used.
+-----------------------+------------------------------------------+ | kwargs | Description | +=======================+==========================================+ | shortName="VTMP" | Set shortName to "VTMP", along with | | | appropriate discipline, | | | parameterCategory, parameterNumber, | | | fullName and units. | +-----------------------+------------------------------------------+ | discipline=0, | Set shortName, discipline, | | parameterCategory=0, | parameterCategory, parameterNumber, | | parameterNumber=1 | fullName and units appropriate for | | | "Virtual Temperature". | +-----------------------+------------------------------------------+ | discipline=0, | Conflicting keywords but | | parameterCategory=0, | 'shortName="TMP"' wins. Set shortName, | | parameterNumber=1, | discipline, parameterCategory, | | shortName="TMP" | parameterNumber, fullName and units | | | appropriate for "Temperature". | +-----------------------+------------------------------------------+
Returns
- DataArray: DataArray with updated attributes.
def subset(self, lats, lons) -> xr.DataArray:
    """
    Subset the DataArray to a region defined by latitudes and longitudes.

    A throw-away ``Grib2Message`` built from the stored sections is subset
    first; its section 3 replaces the one stored on the result and its
    first/last grid point values define the coordinate mask.

    Parameters
    ----------
    lats
        Latitude bounds of the region.
    lons
        Longitude bounds of the region.

    Returns
    -------
    subset
        DataArray subset to the region.
    """
    region = self._obj.copy(deep=True)
    attrs = region.attrs

    template = Grib2Message(
        attrs["GRIB2IO_section0"],
        attrs["GRIB2IO_section1"],
        attrs["GRIB2IO_section2"],
        attrs["GRIB2IO_section3"],
        attrs["GRIB2IO_section4"],
        attrs["GRIB2IO_section5"],
    )

    # The message is given placeholder data of the grid's shape before it
    # is subset; only its grid metadata is used afterwards.
    template.data = np.zeros((template.ny, template.nx), dtype=np.float32)
    template = template.subset(lats, lons)

    attrs["GRIB2IO_section3"] = template.section3

    # Points are kept when their latitude lies between the last and first
    # grid point latitude and their longitude between the first and last
    # grid point longitude of the subset message.
    latitude = region.latitude
    longitude = region.longitude
    mask_lat = (latitude >= template.latitudeLastGridpoint) & (
        latitude <= template.latitudeFirstGridpoint
    )
    mask_lon = (longitude >= template.longitudeFirstGridpoint) & (
        longitude <= template.longitudeLastGridpoint
    )

    # Release the temporary message before the masking is carried out.
    del template

    keep = (mask_lon & mask_lat).compute()
    return region.where(keep, drop=True)
Subset the DataArray to a region defined by latitudes and longitudes.
Parameters
- lats: Latitude bounds of the region.
- lons: Longitude bounds of the region.
Returns
- subset: DataArray subset to the region.
def build_datatree_from_grib(filename, file_index, filters=None, stack_vertical=False):
    """
    Build a DataTree from GRIB2 messages.

    The message index is filtered, extended with the metadata columns used
    to organize the tree, and split by ``typeOfFirstFixedSurface``. Each
    level type becomes one child of the root, populated by
    ``process_level_branch``.

    Parameters
    ----------
    filename : str
        Path to the GRIB2 file.
    file_index : pandas.DataFrame
        DataFrame of GRIB2 message index.
    filters : dict, optional
        Filter criteria for GRIB2 messages.
    stack_vertical : bool, optional
        Accepted for interface compatibility. This function does not
        currently use it when building the tree.

    Returns
    -------
    xarray.DataTree
        A hierarchical DataTree representation of the GRIB2 data.
    """
    if filters is None:
        filters = {}

    # Apply any filters from user. A filter key that is not an index column
    # yet is first read from the messages themselves.
    for key, wanted in filters.items():
        if key not in file_index.columns:
            file_index = file_index.copy()
            file_index[key] = file_index.msg.apply(
                lambda msg, _key=key: getattr(msg, _key, None)
            )
        file_index = filter_index(file_index, key, wanted)

    # Make a copy to avoid the SettingWithCopyWarning
    file_index = file_index.copy()

    def safe_getattr(obj, name):
        """Return ``obj.name`` (unwrapping Grib2Metadata), or None if unavailable."""
        try:
            attr = getattr(obj, name)
            # Grib2Metadata wraps the raw value; the tree is organized by
            # the raw value.
            if isinstance(attr, grib2io.templates.Grib2Metadata):
                attr = attr.value
            return attr
        except (AttributeError, KeyError):
            return None

    # Extract metadata needed for tree organization.
    for attr in _TREE_HIERARCHY_LEVELS:
        if attr == 'valueOfFirstFixedSurface' or attr in file_index.columns:
            continue
        file_index[attr] = file_index.msg.apply(
            lambda msg, _attr=attr: safe_getattr(msg, _attr)
        )

    # Also extract shortName for variable naming
    if 'shortName' not in file_index.columns:
        file_index = file_index.assign(
            shortName=file_index.msg.apply(lambda msg: getattr(msg, 'shortName', None))
        )
    file_index = file_index.assign(nx=file_index.msg.apply(lambda msg: getattr(msg, 'nx', None)))
    file_index = file_index.assign(ny=file_index.msg.apply(lambda msg: getattr(msg, 'ny', None)))

    root = xr.DataTree()

    # One branch per level type, skipping None/NaN values.
    for level_type in file_index['typeOfFirstFixedSurface'].unique():
        if pd.isna(level_type):
            continue

        # Table entries are indexed with [0] for the branch name. A level
        # type missing from the table falls back to a generated name; the
        # fallback must not be indexed, or it would collapse to its first
        # character.
        level_info = _LEVEL_NAME_MAPPING.get(level_type)
        if level_info is None:
            level_name = f"level_{level_type}"
        else:
            level_name = level_info[0]

        level_df = file_index[file_index['typeOfFirstFixedSurface'] == level_type]

        # Populate the branch before attaching it to the root.
        level_tree = xr.DataTree()
        process_level_branch(level_tree, level_df, filename)
        root[level_name] = level_tree

    return root
Build a DataTree from GRIB2 messages.
Parameters
- filename (str): Path to the GRIB2 file.
- file_index (pandas.DataFrame): DataFrame of GRIB2 message index.
- filters (dict, optional): Filter criteria for GRIB2 messages.
- stack_vertical (bool, optional): If True, vertical levels will be stacked in a single dataset instead of being organized in separate tree nodes.
Returns
- xarray.DataTree: A hierarchical DataTree representation of the GRIB2 data.
def _column_has_multiple_values(df, column):
    """Return True if ``column`` exists in ``df`` with more than one distinct non-null value."""
    return column in df.columns and len(df[column].dropna().unique()) > 1


def _fill_node_with_datasets(node, dss):
    """Store a single dataset on ``node``, or one ``var_<name>`` child per dataset."""
    if len(dss) == 1:
        node.ds = dss[0]
        return
    for ds in dss:
        varname = list(ds.data_vars)[0]
        node[f"var_{varname}"] = ds


def process_level_branch(level_tree, df, filename):
    """
    Process a level type branch of the data tree, organizing by PDTN and other attributes.

    Messages are grouped by ``productDefinitionTemplateNumber``. Each group
    is subdivided by perturbation number or by type of probability when
    that column holds more than one distinct value; otherwise datasets are
    created for the group directly and stored under a ``pdtn_<n>`` node.

    When the level holds a single PDTN, perturbation/probability groups are
    added directly to ``level_tree``; with several PDTNs they are collected
    under a ``pdtn_<n>`` node, which is attached only if it received
    content.

    Parameters
    ----------
    level_tree : xarray.DataTree
        The DataTree node for this level type
    df : pandas.DataFrame
        DataFrame of messages for this level type
    filename : str
        Path to the GRIB2 file
    """
    # Group by PDTN, skipping None/NaN values.
    pdtn_groups = {
        pdtn_value: df[df['productDefinitionTemplateNumber'] == pdtn_value]
        for pdtn_value in df['productDefinitionTemplateNumber'].unique()
        if pd.notna(pdtn_value)
    }
    single_pdtn = len(pdtn_groups) == 1

    for pdtn, pdtn_df in pdtn_groups.items():
        # Use a simple node name that's easy to use in code
        pdtn_name = f"pdtn_{int(pdtn)}"

        # Perturbations take precedence over probabilities.
        if _column_has_multiple_values(pdtn_df, 'perturbationNumber'):
            group_processor = process_perturbation_groups
        elif _column_has_multiple_values(pdtn_df, 'typeOfProbability'):
            group_processor = process_probability_groups
        else:
            group_processor = None

        if group_processor is not None:
            if single_pdtn:
                # No intermediate PDTN node when there is only one PDTN.
                group_processor(level_tree, pdtn_df, filename)
                continue

            pdtn_tree = xr.DataTree()
            group_processor(pdtn_tree, pdtn_df, filename)
            # Only add the PDTN branch if it actually received content. A
            # node's ``ds`` is a dataset view rather than None, so the
            # data variables are checked instead.
            if len(pdtn_tree.children) > 0 or len(pdtn_tree.ds.data_vars) > 0:
                level_tree[pdtn_name] = pdtn_tree
            continue

        pdtn_tree = xr.DataTree()
        try:
            dss = create_datasets_from_df(pdtn_df, filename)
            if dss is not None:
                # Fill the node first, then attach it.
                _fill_node_with_datasets(pdtn_tree, dss)
                level_tree[pdtn_name] = pdtn_tree
        except Exception as e:
            print(f"Error creating dataset for level with pdtn {int(pdtn)}: {e}")

            # Try to separate by variable name as a fallback. With a single
            # PDTN the variables go straight onto the level node.
            if single_pdtn:
                try_process_by_variables(level_tree, pdtn_df, filename)
            else:
                try_process_by_variables(pdtn_tree, pdtn_df, filename)
                level_tree[pdtn_name] = pdtn_tree
Process a level type branch of the data tree, organizing by PDTN and other attributes.
Parameters
- level_tree (xarray.DataTree): The DataTree node for this level type
- df (pandas.DataFrame): DataFrame of messages for this level type
- filename (str): Path to the GRIB2 file
def process_probability_groups(target_tree, pdtn_df, filename):
    """
    Process probability groups and add them to the target tree.

    Messages are grouped by ``typeOfProbability``. For each group the
    datasets are created and stored under a ``prob_<n>`` node: a single
    dataset becomes the node's dataset, several datasets become
    ``var_<name>`` children of that node.

    Parameters
    ----------
    target_tree : xarray.DataTree
        The tree node to add probability groups to
    pdtn_df : pandas.DataFrame
        DataFrame of messages for a specific PDTN
    filename : str
        Path to the GRIB2 file

    Returns
    -------
    bool
        True if at least one probability group was successfully processed
    """
    success = False

    for prob_value in pdtn_df['typeOfProbability'].unique():
        # Skip None/NaN values
        if pd.isna(prob_value):
            continue

        prob_df = pdtn_df[pdtn_df['typeOfProbability'] == prob_value]
        prob_name = f"prob_{int(prob_value)}"

        try:
            dss = create_datasets_from_df(prob_df, filename)
            if not dss:
                # No dataset could be created for this group.
                print(f"Error creating dataset for type of probability {prob_name}: no datasets were created")
                continue

            # Fill the node completely before attaching it to the tree.
            dt = xr.DataTree()
            if len(dss) == 1:
                dt.ds = dss[0]
            else:
                for ds in dss:
                    # data_vars is keyed by variable name, so the first
                    # name is taken from its keys rather than by position.
                    varname = list(ds.data_vars)[0]
                    dt[f"var_{varname}"] = ds
            target_tree[prob_name] = dt
            success = True
        except Exception as e:
            # Log error but continue processing other groups
            print(f"Error creating dataset for type of probability {prob_name}: {e}")

    return success
def process_perturbation_groups(target_tree, pdtn_df, filename):
    """
    Process perturbation groups and add them to the target tree.

    Messages are grouped by ``perturbationNumber``. For each group the
    datasets are created and stored under a ``pert_<n>`` node: a single
    dataset becomes the node's dataset, several datasets become
    ``pert<name>`` children of that node.

    Parameters
    ----------
    target_tree : xarray.DataTree
        The tree node to add perturbation groups to
    pdtn_df : pandas.DataFrame
        DataFrame of messages for a specific PDTN
    filename : str
        Path to the GRIB2 file

    Returns
    -------
    bool
        True if at least one perturbation was successfully processed
    """
    success = False

    for pert_value in pdtn_df['perturbationNumber'].unique():
        # Skip None/NaN values
        if pd.isna(pert_value):
            continue

        pert_df = pdtn_df[pdtn_df['perturbationNumber'] == pert_value]
        pert_name = f"pert_{int(pert_value)}"

        try:
            dss = create_datasets_from_df(pert_df, filename)
            if not dss:
                # No dataset could be created for this group.
                print(f"Error creating dataset for perturbation {pert_name}: no datasets were created")
                continue

            # Fill the node completely before attaching it to the tree.
            dt = xr.DataTree()
            if len(dss) == 1:
                dt.ds = dss[0]
            else:
                for ds in dss:
                    # data_vars is keyed by variable name, so the first
                    # name is taken from its keys rather than by position.
                    varname = list(ds.data_vars)[0]
                    dt[f"pert{varname}"] = ds
            target_tree[pert_name] = dt
            success = True
        except Exception as e:
            # Log error but continue processing other groups
            print(f"Error creating dataset for perturbation {pert_name}: {e}")

    return success
Process perturbation groups and add them to the target tree.
Parameters
- target_tree (xarray.DataTree): The tree node to add perturbation groups to
- pdtn_df (pandas.DataFrame): DataFrame of messages for a specific PDTN
- filename (str): Path to the GRIB2 file
Returns
- bool: True if at least one perturbation was successfully processed
def try_process_by_variables(target_tree, df, filename):
    """
    Try to separate data by variable names and create datasets.

    For every non-null ``shortName`` in ``df`` the matching messages are
    turned into datasets and the first resulting dataset is stored on
    ``target_tree`` under ``var_<shortName>``. Errors are printed and do
    not stop the remaining variables from being processed.

    Parameters
    ----------
    target_tree : xarray.DataTree
        The tree node to add variable datasets to
    df : pandas.DataFrame
        DataFrame of messages
    filename : str
        Path to the GRIB2 file

    Returns
    -------
    bool
        True if at least one variable was successfully processed
    """
    processed_any = False

    try:
        for var_name in df['shortName'].unique():
            if not pd.notna(var_name):
                continue

            rows_for_var = df[df['shortName'] == var_name]
            try:
                created = create_datasets_from_df(rows_for_var, filename)
                if created is None:
                    continue
                # Only the first dataset is kept for the variable.
                target_tree[f"var_{var_name}"] = created[0]
                processed_any = True
            except Exception as var_e:
                print(f"Error creating dataset for variable {var_name}: {var_e}")
    except Exception as nested_e:
        # Reached e.g. when the 'shortName' column is missing.
        print(f"Failed to process variables: {nested_e}")

    return processed_any
Try to separate data by variable names and create datasets.
Parameters
- target_tree (xarray.DataTree): The tree node to add variable datasets to
- df (pandas.DataFrame): DataFrame of messages
- filename (str): Path to the GRIB2 file
Returns
- bool: True if at least one variable was successfully processed
def create_datasets_from_df(
    df,
    filename,
    verbose=False
) -> typing.Optional[typing.List[xr.Dataset]]:
    """
    Create a list of xarray Datasets from a DataFrame of messages.

    Messages are grouped by ``shortName`` and one dataset is built per
    variable. A variable with more than one distinct
    ``valueOfFirstFixedSurface`` is built level by level and concatenated
    along that coordinate; otherwise it is built from its first frame.
    The per-variable datasets are then merged; if the merge fails they are
    returned unmerged.

    Errors are swallowed; they are only reported when ``verbose`` is True.

    Parameters
    ----------
    df : pandas.DataFrame
        DataFrame of GRIB messages
    filename : str
        Path to the GRIB2 file
    verbose : bool, optional
        If True, prints detailed debugging information

    Returns
    -------
    dss
        List of Datasets, or None if creation failed
    """
    level_column = 'valueOfFirstFixedSurface'
    end_banner = "==== END VERBOSE DEBUG INFO ====\n"

    try:
        if verbose:
            print("\n==== VERBOSE DEBUG INFO ====")
            print(f"Creating dataset from DataFrame with {len(df)} messages")
            print(f"DataFrame columns: {df.columns.tolist()}")

            if 'shortName' in df.columns:
                print(f"Variables in group: {df['shortName'].unique().tolist()}")

            if level_column in df.columns:
                print(f"Vertical levels: {df[level_column].unique().tolist()}")

        # Datasets keyed by variable name.
        datasets = {}

        # Process each variable separately, regardless of whether there are vertical levels
        for var_name, var_df in df.groupby('shortName'):
            if verbose:
                print(
                    f"\n Processing variable: {var_name} with {len(var_df)} messages, with pdtn(s) = {var_df['productDefinitionTemplateNumber'].unique()}")

            multiple_levels = (
                level_column in var_df.columns
                and len(var_df[level_column].unique()) > 1
            )

            if not multiple_levels:
                # Single level or no vertical levels
                if verbose:
                    print(f" Variable {var_name} is a single level or has no vertical dimension")
                try:
                    file_index, non_geo_dims, attrs, coord_attrs = parse_grib_index(var_df, {})
                    frames, cube, extra_geo = make_variables(file_index, filename, non_geo_dims, allow_uneven_dims=True)

                    if frames is not None and len(frames) >= 1:
                        several_frames = len(frames) > 1
                        if several_frames and verbose:
                            print(f" Variable {var_name} has multiple frames, possibly different parameters")

                        # Only the first frame is used, also when there
                        # are several (simplified approach).
                        da = build_da_without_coords(frames[0], cube, filename, attrs)
                        var_ds = xr.Dataset()
                        var_ds[da.name] = da
                        var_ds = assign_xr_meta(var_ds, frames, cube, non_geo_dims, extra_geo, coord_attrs)

                        datasets[var_name] = var_ds
                        if verbose:
                            if several_frames:
                                print(f" Created dataset with first frame for {var_name}")
                            else:
                                print(f" Created dataset for {var_name}")
                except Exception as e:
                    if verbose:
                        print(f" Error processing variable {var_name}: {e}")
                continue

            if verbose:
                print(f" Variable {var_name} has multiple vertical levels")

            # Build one DataArray per level.
            level_das = []
            for level, level_df in var_df.groupby(level_column):
                if verbose:
                    print(f" Processing level {level} with {len(level_df)} messages")
                try:
                    file_index, non_geo_dims, attrs, coord_attrs = parse_grib_index(level_df, {})
                    # The level dimension is handled by the concatenation
                    # below, so it is dropped from the parsed dimensions.
                    non_geo_dims = [d for d in non_geo_dims if d.__name__ != "ValueOfFirstFixedSurfaceDim"]

                    frames, cube, extra_geo = make_variables(
                        file_index, filename, non_geo_dims, allow_uneven_dims=True)

                    if frames is not None and len(frames) == 1:
                        level_da = build_da_without_coords(frames[0], cube, filename, attrs)
                        # Label the level's data with its level value.
                        level_das.append(level_da.assign_coords(valueOfFirstFixedSurface=level))
                except Exception as e:
                    if verbose:
                        print(f" Error processing level {level} for {var_name}: {e}")

            if not level_das:
                continue

            if verbose:
                print(f" Combining {len(level_das)} levels for {var_name}")
            try:
                combined_da = xr.concat(level_das, dim=level_column)
                var_ds = xr.Dataset({var_name: combined_da})
                # NOTE(review): the metadata passed here is whatever the
                # level loop left behind (its most recent assignments), not
                # necessarily that of the first level - confirm intent.
                var_ds = assign_xr_meta(var_ds, frames, cube, non_geo_dims, extra_geo, coord_attrs)

                datasets[var_name] = var_ds
                if verbose:
                    print(f" Created dataset for {var_name} with levels")
            except Exception as e:
                if verbose:
                    print(f" Error combining levels for {var_name}: {e}")

        if not datasets:
            if verbose:
                print("No datasets were created for any variables")
                print(end_banner)
            return None

        # Attempt to merge all the variable datasets
        try:
            if verbose:
                print(f"\nMerging {len(datasets)} datasets...")
            ds_list = list(datasets.values())

            try:
                combined_ds = xr.merge(ds_list)
                if verbose:
                    print("Successfully merged all datasets into one.")
                    print(f"Final dataset has variables: {list(combined_ds.data_vars)}")
                    print(end_banner)
                return [combined_ds]
            except Exception as merge_error:
                # Fall back to the unmerged per-variable datasets.
                if verbose:
                    print(f"Error merging all datasets: {merge_error}")
                return ds_list
        except Exception as e:
            if verbose:
                print(f"Error in final merge process: {e}")
                print(end_banner)
            return None

    except Exception as e:
        # If there's an error, log it and return None
        if verbose:
            print(f"Error creating dataset: {e}")
            import traceback
            traceback.print_exc()
            print(end_banner)
        return None
Create a list of xarray Datasets from a DataFrame of messages.
Parameters
- df (pandas.DataFrame): DataFrame of GRIB messages
- filename (str): Path to the GRIB2 file
- verbose (bool, optional): If True, prints detailed debugging information
Returns
- dss: List of Datasets, or None if creation failed
@xr.register_datatree_accessor("grib2io")
class Grib2ioDataTree:
    """
    DataTree accessor for GRIB2 files.

    This accessor provides methods for working with GRIB2 data organized
    in a hierarchical tree structure.
    """

    def __init__(self, datatree_obj):
        self._obj = datatree_obj

    def _map_datasets(self, transform):
        """
        Build a new tree with the same layout and transformed datasets.

        Parameters
        ----------
        transform : callable
            Called with the dataset of every node that has data variables;
            its return value becomes the dataset of the matching new node.

        Returns
        -------
        xarray.DataTree
            New DataTree mirroring the node names of the original tree.
        """
        def rebuild(node):
            new_node = xr.DataTree()
            if node.ds is not None and node.ds.data_vars:
                new_node.ds = transform(node.ds)

            for child_name, child_node in node.children.items():
                # The child subtree is built completely before it is
                # attached: assigning a DataTree as a child grafts a copy,
                # so filling the local object afterwards would not reach
                # the tree being returned.
                new_node[child_name] = rebuild(child_node)

            return new_node

        return rebuild(self._obj)

    def to_grib2(self, filename, mode: typing.Literal["x", "w", "a"] = "x"):
        """
        Write all datasets in the DataTree to a GRIB2 file.

        Nodes are visited depth-first, parents before children. Nodes
        without data variables are skipped.

        Parameters
        ----------
        filename : str
            Name of the GRIB2 file to write to.
        mode : {"x", "w", "a"}, optional
            Persistence mode, default is "x" (create, fail if exists)
        """
        # Start with the specified mode
        current_mode = mode

        def process_tree(node):
            nonlocal current_mode

            if node.ds is not None and node.ds.data_vars:
                node.ds.grib2io.to_grib2(filename, mode=current_mode)
                # Switch to append mode after first write
                current_mode = "a"

            for child_node in node.children.values():
                process_tree(child_node)

        process_tree(self._obj)

    def griddef(self):
        """
        Get the grid definition from the first dataset in the tree that has one.

        Returns
        -------
        grib2io.Grib2GridDef
            Grid definition object, or None if no variable in the tree
            carries a ``GRIB2IO_section3`` attribute.
        """
        def find_griddef(node):
            if node.ds is not None and node.ds.data_vars:
                for var_name in node.ds.data_vars:
                    var_attrs = node.ds[var_name].attrs
                    if 'GRIB2IO_section3' in var_attrs:
                        return Grib2GridDef.from_section3(var_attrs['GRIB2IO_section3'])

            for child_node in node.children.values():
                griddef = find_griddef(child_node)
                if griddef is not None:
                    return griddef

            return None

        return find_griddef(self._obj)

    def interp(self, method, grid_def_out, method_options=None, num_threads=1):
        """
        Interpolate all datasets in the tree to a new grid.

        Parameters
        ----------
        method : str or int
            Interpolation method to use
        grid_def_out : grib2io.Grib2GridDef
            Target grid definition
        method_options : list, optional
            Options for interpolation method
        num_threads : int, optional
            Number of threads to use for interpolation

        Returns
        -------
        xarray.DataTree
            New DataTree with interpolated data
        """
        def interp_dataset(ds):
            return ds.grib2io.interp(method, grid_def_out,
                                     method_options=method_options,
                                     num_threads=num_threads)

        return self._map_datasets(interp_dataset)

    def subset(self, lats, lons):
        """
        Subset all datasets in the tree to a region.

        Parameters
        ----------
        lats : list or tuple
            Latitude bounds [min_lat, max_lat]
        lons : list or tuple
            Longitude bounds [min_lon, max_lon]

        Returns
        -------
        xarray.DataTree
            New DataTree with subset data
        """
        def subset_dataset(ds):
            return ds.grib2io.subset(lats, lons)

        return self._map_datasets(subset_dataset)
DataTree accessor for GRIB2 files.
This accessor provides methods for working with GRIB2 data organized in a hierarchical tree structure.
def to_grib2(self, filename, mode: typing.Literal["x", "w", "a"] = "x"):
    """
    Write all datasets in the DataTree to a GRIB2 file.

    The tree is walked depth-first, parents before children and children
    in mapping order. Every node holding data variables is written through
    the Dataset-level ``grib2io`` accessor. The first write uses ``mode``;
    every later write uses ``"a"`` so that all nodes end up in one file.

    Parameters
    ----------
    filename : str
        Name of the GRIB2 file to write to.
    mode : {"x", "w", "a"}, optional
        Persistence mode, default is "x" (create, fail if exists)
    """
    write_mode = mode

    # Explicit stack instead of recursion; children are pushed in reverse
    # so that they are popped (and written) in their original order.
    pending = [self._obj]
    while pending:
        node = pending.pop()
        dataset = node.ds

        if dataset is not None and dataset.data_vars:
            dataset.grib2io.to_grib2(filename, mode=write_mode)
            # Everything after the first successful write is appended.
            write_mode = "a"

        pending.extend(reversed(list(node.children.values())))
Write all datasets in the DataTree to a GRIB2 file.
Parameters
- filename (str): Name of the GRIB2 file to write to.
- mode ({"x", "w", "a"}, optional): Persistence mode, default is "x" (create, fail if exists)
def griddef(self):
    """
    Get the grid definition from the first dataset in the tree that has one.

    The tree is searched depth-first, parents before children. The first
    data variable whose attributes contain ``GRIB2IO_section3`` supplies
    the section 3 values used to build the grid definition.

    Returns
    -------
    grib2io.Grib2GridDef
        Grid definition object, or ``None`` when no data variable in the
        tree carries a ``GRIB2IO_section3`` attribute.
    """
    # Explicit stack instead of recursion; children are pushed in reverse
    # so that they are visited in their original order.
    pending = [self._obj]
    while pending:
        node = pending.pop()
        dataset = node.ds

        if dataset is not None and dataset.data_vars:
            for var_name in dataset.data_vars:
                attrs = dataset[var_name].attrs
                if 'GRIB2IO_section3' in attrs:
                    return Grib2GridDef.from_section3(attrs['GRIB2IO_section3'])

        pending.extend(reversed(list(node.children.values())))

    return None
Get the grid definition from the first dataset in the tree that has one.
Returns
- grib2io.Grib2GridDef: Grid definition object
def interp(self, method, grid_def_out, method_options=None, num_threads=1):
    """
    Interpolate all datasets in the tree to a new grid.

    Every node that holds data variables is interpolated through the
    Dataset-level ``grib2io`` accessor. Nodes without data variables are
    reproduced as empty nodes, so the returned tree has the same layout
    as the source tree.

    Parameters
    ----------
    method : str or int
        Interpolation method to use
    grid_def_out : grib2io.Grib2GridDef
        Target grid definition
    method_options : list, optional
        Options for interpolation method
    num_threads : int, optional
        Number of threads to use for interpolation

    Returns
    -------
    xarray.DataTree
        New DataTree with interpolated data
    """
    def build_node(node, new_node):
        # Only nodes that actually carry data variables are interpolated.
        if node.ds is not None and node.ds.data_vars:
            new_node.ds = node.ds.grib2io.interp(
                method,
                grid_def_out,
                method_options=method_options,
                num_threads=num_threads,
            )

        for child_name, child_node in node.children.items():
            # Populate the child's whole subtree BEFORE attaching it.
            # NOTE(review): DataTree item assignment appears to graft a
            # copy of the assigned node; confirm for the supported xarray
            # versions. Filling the node after attaching it would then
            # update a detached object and lose the data.
            new_node[child_name] = build_node(child_node, xr.DataTree())

        return new_node

    return build_node(self._obj, xr.DataTree())
Interpolate all datasets in the tree to a new grid.
Parameters
- method (str or int): Interpolation method to use
- grid_def_out (grib2io.Grib2GridDef): Target grid definition
- method_options (list, optional): Options for interpolation method
- num_threads (int, optional): Number of threads to use for interpolation
Returns
- xarray.DataTree: New DataTree with interpolated data
def subset(self, lats, lons):
    """
    Subset all datasets in the tree to a region.

    Every node that holds data variables is subset through the
    Dataset-level ``grib2io`` accessor. Nodes without data variables are
    reproduced as empty nodes, so the returned tree has the same layout
    as the source tree.

    Parameters
    ----------
    lats : list or tuple
        Latitude bounds [min_lat, max_lat]
    lons : list or tuple
        Longitude bounds [min_lon, max_lon]

    Returns
    -------
    xarray.DataTree
        New DataTree with subset data
    """
    def build_node(node, new_node):
        # Only nodes that actually carry data variables are subset.
        if node.ds is not None and node.ds.data_vars:
            new_node.ds = node.ds.grib2io.subset(lats, lons)

        for child_name, child_node in node.children.items():
            # Populate the child's whole subtree BEFORE attaching it.
            # NOTE(review): DataTree item assignment appears to graft a
            # copy of the assigned node; confirm for the supported xarray
            # versions. Filling the node after attaching it would then
            # update a detached object and lose the data.
            new_node[child_name] = build_node(child_node, xr.DataTree())

        return new_node

    return build_node(self._obj, xr.DataTree())
Subset all datasets in the tree to a region.
Parameters
- lats (list or tuple): Latitude bounds [min_lat, max_lat]
- lons (list or tuple): Longitude bounds [min_lon, max_lon]
Returns
- xarray.DataTree: New DataTree with subset data